Initial commit

This commit is contained in:
sjk
2025-11-17 13:32:54 +08:00
commit e788eab6eb
1659 changed files with 171560 additions and 0 deletions

View File

@@ -0,0 +1,102 @@
package repository
import (
"dianshang/internal/model"
"gorm.io/gorm"
)
// AdminRepository 管理员仓库
type AdminRepository struct {
db *gorm.DB
}
// NewAdminRepository 创建管理员仓库
func NewAdminRepository(db *gorm.DB) *AdminRepository {
return &AdminRepository{db: db}
}
// Create 创建管理员
func (r *AdminRepository) Create(admin *model.AdminUser) error {
return r.db.Create(admin).Error
}
// GetByID 根据ID获取管理员
func (r *AdminRepository) GetByID(id uint) (*model.AdminUser, error) {
var admin model.AdminUser
err := r.db.First(&admin, id).Error
return &admin, err
}
// GetByIDWithRole 根据ID获取管理员包含角色信息
func (r *AdminRepository) GetByIDWithRole(id uint) (*model.AdminUser, error) {
var admin model.AdminUser
err := r.db.Preload("Role").First(&admin, id).Error
if err != nil {
return nil, err
}
// 不返回密码
admin.Password = ""
return &admin, nil
}
// GetByUsername 根据用户名获取管理员
func (r *AdminRepository) GetByUsername(username string) (*model.AdminUser, error) {
var admin model.AdminUser
err := r.db.Where("username = ?", username).First(&admin).Error
return &admin, err
}
// GetList 获取管理员列表
func (r *AdminRepository) GetList(page, pageSize int, keyword string) ([]model.AdminUser, int64, error) {
var admins []model.AdminUser
var total int64
query := r.db.Model(&model.AdminUser{}).Preload("Role")
// 关键词搜索
if keyword != "" {
query = query.Where("username LIKE ? OR nickname LIKE ? OR email LIKE ?",
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
}
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 分页查询
offset := (page - 1) * pageSize
err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&admins).Error
// 清除密码字段
for i := range admins {
admins[i].Password = ""
}
return admins, total, err
}
// Update 更新管理员
func (r *AdminRepository) Update(id uint, updates map[string]interface{}) error {
return r.db.Model(&model.AdminUser{}).Where("id = ?", id).Updates(updates).Error
}
// Delete 删除管理员(软删除)
func (r *AdminRepository) Delete(id uint) error {
return r.db.Delete(&model.AdminUser{}, id).Error
}
// GetAdminCount 获取管理员总数
func (r *AdminRepository) GetAdminCount() (int64, error) {
var count int64
err := r.db.Model(&model.AdminUser{}).Count(&count).Error
return count, err
}
// GetActiveAdminCount 获取活跃管理员总数
func (r *AdminRepository) GetActiveAdminCount() (int64, error) {
var count int64
err := r.db.Model(&model.AdminUser{}).Where("status = ?", 1).Count(&count).Error
return count, err
}

View File

@@ -0,0 +1,178 @@
package repository
import (
"dianshang/internal/model"
"time"
"gorm.io/gorm"
)
// BannerRepository 轮播图仓储
type BannerRepository struct {
db *gorm.DB
}
// NewBannerRepository 创建轮播图仓储
func NewBannerRepository(db *gorm.DB) *BannerRepository {
return &BannerRepository{db: db}
}
// GetActiveBanners 获取有效的轮播图
func (r *BannerRepository) GetActiveBanners() ([]model.Banner, error) {
var banners []model.Banner
err := r.db.Where("status = ?", 1).Order("sort ASC").Find(&banners).Error
return banners, err
}
// GetBannerList 获取轮播图列表(分页)
func (r *BannerRepository) GetBannerList(page, pageSize int, status *int) ([]model.Banner, int64, error) {
var banners []model.Banner
var total int64
query := r.db.Model(&model.Banner{})
// 状态筛选
if status != nil {
query = query.Where("status = ?", *status)
}
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 分页查询
offset := (page - 1) * pageSize
err := query.Order("sort ASC, created_at DESC").
Offset(offset).
Limit(pageSize).
Find(&banners).Error
return banners, total, err
}
// GetBannerByID 根据ID获取轮播图
func (r *BannerRepository) GetBannerByID(id uint) (*model.Banner, error) {
var banner model.Banner
err := r.db.First(&banner, id).Error
if err != nil {
return nil, err
}
return &banner, nil
}
// CreateBanner 创建轮播图
func (r *BannerRepository) CreateBanner(banner *model.Banner) error {
return r.db.Create(banner).Error
}
// UpdateBanner 更新轮播图
func (r *BannerRepository) UpdateBanner(banner *model.Banner) error {
return r.db.Save(banner).Error
}
// DeleteBanner 删除轮播图
func (r *BannerRepository) DeleteBanner(id uint) error {
return r.db.Delete(&model.Banner{}, id).Error
}
// BatchDeleteBanners 批量删除轮播图
func (r *BannerRepository) BatchDeleteBanners(ids []uint) error {
return r.db.Delete(&model.Banner{}, ids).Error
}
// UpdateBannerStatus 更新轮播图状态
func (r *BannerRepository) UpdateBannerStatus(id uint, status int) error {
return r.db.Model(&model.Banner{}).Where("id = ?", id).Update("status", status).Error
}
// BatchUpdateBannerStatus 批量更新轮播图状态
func (r *BannerRepository) BatchUpdateBannerStatus(ids []uint, status int) error {
return r.db.Model(&model.Banner{}).Where("id IN ?", ids).Update("status", status).Error
}
// UpdateBannerSort 更新轮播图排序
func (r *BannerRepository) UpdateBannerSort(id uint, sort int) error {
return r.db.Model(&model.Banner{}).Where("id = ?", id).Update("sort", sort).Error
}
// BatchUpdateBannerSort 批量更新轮播图排序
func (r *BannerRepository) BatchUpdateBannerSort(sortData []map[string]interface{}) error {
return r.db.Transaction(func(tx *gorm.DB) error {
for _, data := range sortData {
if err := tx.Model(&model.Banner{}).
Where("id = ?", data["id"]).
Update("sort", data["sort"]).Error; err != nil {
return err
}
}
return nil
})
}
// GetBannersByDateRange 根据日期范围获取轮播图
func (r *BannerRepository) GetBannersByDateRange(startDate, endDate time.Time) ([]model.Banner, error) {
var banners []model.Banner
err := r.db.Where("created_at BETWEEN ? AND ?", startDate, endDate).
Order("sort ASC, created_at DESC").
Find(&banners).Error
return banners, err
}
// GetBannersByStatus 根据状态获取轮播图
func (r *BannerRepository) GetBannersByStatus(status int) ([]model.Banner, error) {
var banners []model.Banner
err := r.db.Where("status = ?", status).
Order("sort ASC, created_at DESC").
Find(&banners).Error
return banners, err
}
// GetBannerCount 获取轮播图总数
func (r *BannerRepository) GetBannerCount() (int64, error) {
var count int64
err := r.db.Model(&model.Banner{}).Count(&count).Error
return count, err
}
// GetBannerCountByStatus 根据状态获取轮播图数量
func (r *BannerRepository) GetBannerCountByStatus(status int) (int64, error) {
var count int64
err := r.db.Model(&model.Banner{}).Where("status = ?", status).Count(&count).Error
return count, err
}
// GetMaxSort 获取最大排序值
func (r *BannerRepository) GetMaxSort() (int, error) {
var maxSort int
err := r.db.Model(&model.Banner{}).Select("COALESCE(MAX(sort), 0)").Scan(&maxSort).Error
return maxSort, err
}
// CheckBannerExists 检查轮播图是否存在
func (r *BannerRepository) CheckBannerExists(id uint) (bool, error) {
var count int64
err := r.db.Model(&model.Banner{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// GetExpiredBanners 获取过期的轮播图
func (r *BannerRepository) GetExpiredBanners() ([]model.Banner, error) {
var banners []model.Banner
now := time.Now()
err := r.db.Where("end_time IS NOT NULL AND end_time < ? AND status = 1", now).
Find(&banners).Error
return banners, err
}
// GetActiveBannersWithTimeRange 获取在时间范围内有效的轮播图
func (r *BannerRepository) GetActiveBannersWithTimeRange() ([]model.Banner, error) {
var banners []model.Banner
now := time.Now()
err := r.db.Where("status = 1").
Where("(start_time IS NULL OR start_time <= ?)"+
" AND (end_time IS NULL OR end_time >= ?)", now, now).
Order("sort ASC").
Find(&banners).Error
return banners, err
}

View File

@@ -0,0 +1,289 @@
package repository
import (
"fmt"
"gorm.io/gorm"
"dianshang/internal/model"
)
type CommentRepository struct {
db *gorm.DB
}
func NewCommentRepository(db *gorm.DB) *CommentRepository {
return &CommentRepository{db: db}
}
// Create 创建评论
func (r *CommentRepository) Create(comment *model.Comment) error {
return r.db.Create(comment).Error
}
// GetByID 根据ID获取评论
func (r *CommentRepository) GetByID(id uint) (*model.Comment, error) {
var comment model.Comment
err := r.db.Preload("User").Preload("Product").Preload("Order").Preload("Replies.User").First(&comment, id).Error
return &comment, err
}
// GetByProductID 根据商品ID获取评论列表
func (r *CommentRepository) GetByProductID(productID uint, offset, limit int, rating int) ([]model.Comment, int64, error) {
var comments []model.Comment
var total int64
query := r.db.Model(&model.Comment{}).Where("product_id = ? AND status = ?", productID, 1)
// 按评分筛选
if rating > 0 && rating <= 5 {
query = query.Where("rating = ?", rating)
}
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取评论列表
err := query.Preload("User").Preload("Replies", "status = ?", 1).Preload("Replies.User").
Order("created_at DESC").Offset(offset).Limit(limit).Find(&comments).Error
return comments, total, err
}
// GetByUserID 根据用户ID获取评论列表
func (r *CommentRepository) GetByUserID(userID uint, offset, limit int) ([]model.Comment, int64, error) {
var comments []model.Comment
var total int64
query := r.db.Model(&model.Comment{}).Where("user_id = ? AND status = ?", userID, 1)
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取评论列表
err := query.Preload("Product").Preload("Replies", "status = ?", 1).Preload("Replies.User").
Order("created_at DESC").Offset(offset).Limit(limit).Find(&comments).Error
return comments, total, err
}
// GetByOrderItemID 根据订单项ID获取评论
func (r *CommentRepository) GetByOrderItemID(orderItemID uint) (*model.Comment, error) {
var comment model.Comment
err := r.db.Where("order_item_id = ? AND status = ?", orderItemID, 1).First(&comment).Error
if err != nil {
return nil, err
}
return &comment, nil
}
// Update 更新评论
func (r *CommentRepository) Update(comment *model.Comment) error {
return r.db.Save(comment).Error
}
// Delete 删除评论(软删除)
func (r *CommentRepository) Delete(id uint) error {
return r.db.Model(&model.Comment{}).Where("id = ?", id).Update("status", 3).Error
}
// GetStats 获取商品评论统计
func (r *CommentRepository) GetStats(productID uint) (*model.CommentStats, error) {
var stats model.CommentStats
stats.ProductID = productID
// 获取总评论数和平均评分
err := r.db.Model(&model.Comment{}).
Select("COUNT(*) as total_count, AVG(rating) as average_rating").
Where("product_id = ? AND status = ?", productID, 1).
Scan(&stats).Error
if err != nil {
return nil, err
}
// 获取各星级评论数
var ratingCounts []struct {
Rating int `json:"rating"`
Count int `json:"count"`
}
err = r.db.Model(&model.Comment{}).
Select("rating, COUNT(*) as count").
Where("product_id = ? AND status = ?", productID, 1).
Group("rating").
Scan(&ratingCounts).Error
if err != nil {
return nil, err
}
// 填充各星级评论数
for _, rc := range ratingCounts {
switch rc.Rating {
case 1:
stats.Rating1Count = rc.Count
case 2:
stats.Rating2Count = rc.Count
case 3:
stats.Rating3Count = rc.Count
case 4:
stats.Rating4Count = rc.Count
case 5:
stats.Rating5Count = rc.Count
}
}
// 获取带图评论数
var hasImagesCount int64
err = r.db.Model(&model.Comment{}).
Where("product_id = ? AND status = ? AND images != '' AND images != '[]'", productID, 1).
Count(&hasImagesCount).Error
if err != nil {
return nil, err
}
stats.HasImagesCount = int(hasImagesCount)
return &stats, nil
}
// GetList 获取评论列表(管理端使用)
func (r *CommentRepository) GetList(offset, limit int, conditions map[string]interface{}) ([]model.Comment, int64, error) {
var comments []model.Comment
var total int64
query := r.db.Model(&model.Comment{})
// 添加查询条件
for key, value := range conditions {
switch key {
case "product_id":
query = query.Where("product_id = ?", value)
case "user_id":
query = query.Where("user_id = ?", value)
case "rating":
query = query.Where("rating = ?", value)
case "status":
query = query.Where("status = ?", value)
case "keyword":
query = query.Where("content LIKE ?", "%"+value.(string)+"%")
}
}
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取评论列表
err := query.Preload("User").Preload("Product").Preload("Order").
Order("created_at DESC").Offset(offset).Limit(limit).Find(&comments).Error
return comments, total, err
}
// CreateReply 创建评论回复
func (r *CommentRepository) CreateReply(reply *model.CommentReply) error {
tx := r.db.Begin()
// 创建回复
if err := tx.Create(reply).Error; err != nil {
tx.Rollback()
return err
}
// 更新评论回复数量
if err := tx.Model(&model.Comment{}).Where("id = ?", reply.CommentID).
UpdateColumn("reply_count", gorm.Expr("reply_count + ?", 1)).Error; err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
// GetReplies 获取评论回复列表
func (r *CommentRepository) GetReplies(commentID uint) ([]model.CommentReply, error) {
var replies []model.CommentReply
err := r.db.Where("comment_id = ? AND status = ?", commentID, 1).
Preload("User").Order("created_at ASC").Find(&replies).Error
return replies, err
}
// LikeComment 点赞评论
func (r *CommentRepository) LikeComment(commentID, userID uint) error {
tx := r.db.Begin()
// 检查是否已点赞
var count int64
if err := tx.Model(&model.CommentLike{}).Where("comment_id = ? AND user_id = ?", commentID, userID).Count(&count).Error; err != nil {
tx.Rollback()
return err
}
if count > 0 {
tx.Rollback()
return fmt.Errorf("已经点赞过了")
}
// 创建点赞记录
like := &model.CommentLike{
CommentID: commentID,
UserID: userID,
}
if err := tx.Create(like).Error; err != nil {
tx.Rollback()
return err
}
// 更新评论点赞数量
if err := tx.Model(&model.Comment{}).Where("id = ?", commentID).
UpdateColumn("like_count", gorm.Expr("like_count + ?", 1)).Error; err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
// UnlikeComment 取消点赞评论
func (r *CommentRepository) UnlikeComment(commentID, userID uint) error {
tx := r.db.Begin()
// 删除点赞记录
if err := tx.Where("comment_id = ? AND user_id = ?", commentID, userID).Delete(&model.CommentLike{}).Error; err != nil {
tx.Rollback()
return err
}
// 更新评论点赞数量
if err := tx.Model(&model.Comment{}).Where("id = ?", commentID).
UpdateColumn("like_count", gorm.Expr("like_count - ?", 1)).Error; err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
// UpdateProductStats 更新商品评论统计
func (r *CommentRepository) UpdateProductStats(productID uint) error {
var stats struct {
Count int `json:"count"`
Rating float64 `json:"rating"`
}
err := r.db.Model(&model.Comment{}).
Select("COUNT(*) as count, AVG(rating) as rating").
Where("product_id = ? AND status = ?", productID, 1).
Scan(&stats).Error
if err != nil {
return err
}
return r.db.Model(&model.Product{}).Where("id = ?", productID).
Updates(map[string]interface{}{
"comment_count": stats.Count,
"average_rating": stats.Rating,
}).Error
}

View File

@@ -0,0 +1,123 @@
package repository
import (
"dianshang/internal/model"
"time"
"gorm.io/gorm"
)
// CouponRepository 优惠券仓储
type CouponRepository struct {
db *gorm.DB
}
// NewCouponRepository 创建优惠券仓储
func NewCouponRepository(db *gorm.DB) *CouponRepository {
return &CouponRepository{db: db}
}
// GetAvailableCoupons 获取可用优惠券列表
func (r *CouponRepository) GetAvailableCoupons() ([]model.Coupon, error) {
var coupons []model.Coupon
now := time.Now()
err := r.db.Where("status = ? AND start_time <= ? AND end_time >= ?", 1, now, now).
Where("total_count = 0 OR used_count < total_count").
Order("created_at DESC").
Find(&coupons).Error
return coupons, err
}
// GetUserCoupons 获取用户优惠券
func (r *CouponRepository) GetUserCoupons(userID uint, status int) ([]model.UserCoupon, error) {
var userCoupons []model.UserCoupon
query := r.db.Preload("Coupon").Where("user_id = ?", userID)
if status > 0 {
// API状态值到数据库状态值的映射
// API: 1=未使用,2=已使用,3=已过期
// DB: 0=未使用,1=已使用,2=已过期
dbStatus := status - 1
query = query.Where("status = ?", dbStatus)
}
err := query.Order("created_at DESC").Find(&userCoupons).Error
return userCoupons, err
}
// GetByID 根据ID获取优惠券
func (r *CouponRepository) GetByID(id uint) (*model.Coupon, error) {
var coupon model.Coupon
err := r.db.First(&coupon, id).Error
return &coupon, err
}
// CheckUserCouponExists 检查用户是否已领取优惠券
func (r *CouponRepository) CheckUserCouponExists(userID, couponID uint) (bool, error) {
var count int64
err := r.db.Model(&model.UserCoupon{}).
Where("user_id = ? AND coupon_id = ?", userID, couponID).
Count(&count).Error
return count > 0, err
}
// GetUserCouponCount 获取用户领取某优惠券的数量
func (r *CouponRepository) GetUserCouponCount(userID, couponID uint) (int, error) {
var count int64
err := r.db.Model(&model.UserCoupon{}).
Where("user_id = ? AND coupon_id = ?", userID, couponID).
Count(&count).Error
return int(count), err
}
// CreateUserCoupon 创建用户优惠券记录
func (r *CouponRepository) CreateUserCoupon(userCoupon *model.UserCoupon) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// 创建用户优惠券记录
if err := tx.Create(userCoupon).Error; err != nil {
return err
}
// 更新优惠券使用数量
return tx.Model(&model.Coupon{}).
Where("id = ?", userCoupon.CouponID).
UpdateColumn("used_count", gorm.Expr("used_count + ?", 1)).Error
})
}
// GetUserCouponByID 根据ID获取用户优惠券
func (r *CouponRepository) GetUserCouponByID(id uint) (*model.UserCoupon, error) {
var userCoupon model.UserCoupon
err := r.db.Preload("Coupon").First(&userCoupon, id).Error
return &userCoupon, err
}
// GetUserCouponByOrderID 根据订单ID获取用户优惠券
func (r *CouponRepository) GetUserCouponByOrderID(orderID uint) (*model.UserCoupon, error) {
var userCoupon model.UserCoupon
err := r.db.Preload("Coupon").Where("order_id = ?", orderID).First(&userCoupon).Error
return &userCoupon, err
}
// UseCoupon 使用优惠券
func (r *CouponRepository) UseCoupon(userCouponID, orderID uint) error {
now := time.Now()
return r.db.Model(&model.UserCoupon{}).
Where("id = ?", userCouponID).
Updates(map[string]interface{}{
"status": 1, // 已使用
"order_id": orderID,
"used_time": &now,
}).Error
}
// RestoreCoupon 恢复优惠券(取消订单时使用)
func (r *CouponRepository) RestoreCoupon(userCouponID uint) error {
return r.db.Model(&model.UserCoupon{}).
Where("id = ?", userCouponID).
Updates(map[string]interface{}{
"status": 0, // 恢复为未使用
"order_id": nil, // 清除订单关联
"used_time": nil, // 清除使用时间
}).Error
}

View File

@@ -0,0 +1,621 @@
package repository
import (
"dianshang/internal/model"
"fmt"
"time"
"gorm.io/gorm"
)
// OrderRepository 订单仓储
type OrderRepository struct {
db *gorm.DB
}
// NewOrderRepository 创建订单仓储
func NewOrderRepository(db *gorm.DB) *OrderRepository {
return &OrderRepository{db: db}
}
// Create 创建订单
func (r *OrderRepository) Create(order *model.Order) error {
return r.db.Create(order).Error
}
// GetByID 根据ID获取订单
func (r *OrderRepository) GetByID(id uint) (*model.Order, error) {
var order model.Order
fmt.Printf("🔍 [OrderRepository.GetByID] 开始查询订单 - 查询ID: %d\n", id)
// 启用 SQL 调试
db := r.db.Debug()
err := db.Preload("OrderItems").Preload("OrderItems.Product").Preload("OrderItems.SKU").Preload("Store").Where("id = ?", id).First(&order).Error
if err == nil {
fmt.Printf("✅ [OrderRepository.GetByID] 查询成功 - ID: %d, 状态: %d, 订单号: %s\n", order.ID, order.Status, order.OrderNo)
// 直接查询验证
var directOrder model.Order
r.db.Raw("SELECT id, order_no, status FROM ai_orders WHERE id = ?", id).Scan(&directOrder)
fmt.Printf("🔍 [直接SQL查询] ID: %d, 状态: %d, 订单号: %s\n", directOrder.ID, directOrder.Status, directOrder.OrderNo)
} else {
fmt.Printf("❌ [OrderRepository.GetByID] 查询失败 - ID: %d, 错误: %v\n", id, err)
}
return &order, err
}
// GetByOrderNo 根据订单号获取订单
func (r *OrderRepository) GetByOrderNo(orderNo string) (*model.Order, error) {
var order model.Order
err := r.db.Preload("OrderItems").Preload("OrderItems.Product").Preload("OrderItems.SKU").Preload("Store").Where("order_no = ?", orderNo).First(&order).Error
return &order, err
}
// GetUserOrders 获取用户订单列表
func (r *OrderRepository) GetUserOrders(userID uint, status int, offset, limit int) ([]model.Order, int64, error) {
var orders []model.Order
var total int64
query := r.db.Model(&model.Order{}).Where("user_id = ?", userID)
if status > 0 {
// 前端状态映射:
// 前端status=3表示"待发货"对应数据库status=2已付款/待发货和status=3待发货
if status == 3 {
query = query.Where("status IN ?", []int{2, 3})
} else {
query = query.Where("status = ?", status)
}
}
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表,预加载订单项和店铺信息
err := query.Preload("OrderItems").Preload("OrderItems.Product").Preload("OrderItems.Product.SKUs", "status = ?", 1).Preload("OrderItems.SKU").Preload("Store").
Offset(offset).Limit(limit).Order("created_at DESC").Find(&orders).Error
return orders, total, err
}
// GetUserOrderStatistics 获取用户订单统计
func (r *OrderRepository) GetUserOrderStatistics(userID uint) (map[string]interface{}, error) {
var result struct {
OrderCount int64 `json:"order_count"`
TotalAmount float64 `json:"total_amount"`
}
// 查询用户的订单数量和总消费金额
err := r.db.Model(&model.Order{}).
Select("COUNT(*) as order_count, COALESCE(SUM(total_amount), 0) as total_amount").
Where("user_id = ? AND status != ?", userID, 0). // 排除已取消的订单
Scan(&result).Error
if err != nil {
return nil, err
}
return map[string]interface{}{
"order_count": result.OrderCount,
"total_amount": result.TotalAmount,
}, nil
}
// Update 更新订单
func (r *OrderRepository) Update(id uint, updates map[string]interface{}) error {
fmt.Printf("🔄 [OrderRepository.Update] 执行数据库更新订单ID: %d\n", id)
fmt.Printf("🔄 [OrderRepository.Update] 更新字段: %+v\n", updates)
result := r.db.Model(&model.Order{}).Where("id = ?", id).Updates(updates)
fmt.Printf("🔄 [OrderRepository.Update] 影响行数: %d\n", result.RowsAffected)
fmt.Printf("🔄 [OrderRepository.Update] 错误信息: %v\n", result.Error)
return result.Error
}
// UpdateByID 根据ID更新订单记录
func (r *OrderRepository) UpdateByID(orderID uint, updates map[string]interface{}) error {
return r.Update(orderID, updates)
}
// UpdateByOrderNo 根据订单号更新订单
func (r *OrderRepository) UpdateByOrderNo(orderNo string, updates map[string]interface{}) error {
fmt.Printf("🔄 [UpdateByOrderNo] 执行数据库更新,订单号: %s\n", orderNo)
fmt.Printf("🔄 [UpdateByOrderNo] 更新字段: %+v\n", updates)
result := r.db.Model(&model.Order{}).Where("order_no = ?", orderNo).Updates(updates)
fmt.Printf("🔄 [UpdateByOrderNo] 影响行数: %d\n", result.RowsAffected)
fmt.Printf("🔄 [UpdateByOrderNo] 错误信息: %v\n", result.Error)
return result.Error
}
// Delete 删除订单(软删除)
func (r *OrderRepository) Delete(id uint) error {
return r.db.Where("id = ?", id).Delete(&model.Order{}).Error
}
// GetList 获取订单列表(管理员)
func (r *OrderRepository) GetList(offset, limit int, conditions map[string]interface{}) ([]model.Order, int64, error) {
var orders []model.Order
var total int64
query := r.db.Model(&model.Order{})
// 添加查询条件
for key, value := range conditions {
switch key {
case "status":
query = query.Where("status = ?", value)
case "user_id":
query = query.Where("user_id = ?", value)
case "order_no":
query = query.Where("order_no LIKE ?", "%"+value.(string)+"%")
case "start_date":
query = query.Where("created_at >= ?", value)
case "end_date":
query = query.Where("created_at <= ?", value)
}
}
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表预加载订单项、商品信息、SKU信息、用户信息和店铺信息
err := query.Preload("OrderItems").Preload("OrderItems.Product").Preload("OrderItems.SKU").Preload("User").Preload("Store").
Offset(offset).Limit(limit).Order("created_at DESC").Find(&orders).Error
return orders, total, err
}
// CreateOrderItem 创建订单项
func (r *OrderRepository) CreateOrderItem(item *model.OrderItem) error {
return r.db.Create(item).Error
}
// GetOrderItems 获取订单项列表
func (r *OrderRepository) GetOrderItems(orderID uint) ([]model.OrderItem, error) {
var items []model.OrderItem
err := r.db.Preload("Product").Where("order_id = ?", orderID).Find(&items).Error
return items, err
}
// UpdateOrderItem 更新订单项
func (r *OrderRepository) UpdateOrderItem(id uint, updates map[string]interface{}) error {
return r.db.Model(&model.OrderItem{}).Where("id = ?", id).Updates(updates).Error
}
// GetCart 获取购物车
func (r *OrderRepository) GetCart(userID uint) ([]model.Cart, error) {
var cart []model.Cart
err := r.db.Preload("Product").Preload("Product.SKUs", "status = ?", 1).Preload("SKU").Where("user_id = ?", userID).Find(&cart).Error
return cart, err
}
// GetCartItem 获取购物车项
func (r *OrderRepository) GetCartItem(userID, productID uint) (*model.Cart, error) {
var cart model.Cart
err := r.db.Where("user_id = ? AND product_id = ?", userID, productID).First(&cart).Error
return &cart, err
}
// GetCartItemBySKU 根据SKU获取购物车项
func (r *OrderRepository) GetCartItemBySKU(userID, productID, skuID uint) (*model.Cart, error) {
var cart model.Cart
query := r.db.Where("user_id = ? AND product_id = ?", userID, productID)
fmt.Printf("🔍 [GetCartItemBySKU] 查询条件 - 用户ID: %d, 产品ID: %d, SKU ID: %d\n",
userID, productID, skuID)
if skuID > 0 {
// 查找指定的SKU ID
query = query.Where("sk_uid = ?", skuID)
fmt.Printf("🔍 [GetCartItemBySKU] 查找指定SKU: %d\n", skuID)
} else {
// 查找没有SKU的商品sk_uid为NULL或0
query = query.Where("sk_uid IS NULL OR sk_uid = 0")
fmt.Printf("🔍 [GetCartItemBySKU] 查找无SKU商品 (sk_uid IS NULL OR sk_uid = 0)\n")
}
err := query.First(&cart).Error
if err != nil {
fmt.Printf("❌ [GetCartItemBySKU] 查询失败: %v\n", err)
return nil, err
}
fmt.Printf("✅ [GetCartItemBySKU] 找到购物车项 - ID: %d, SKU ID: %v\n",
cart.ID, cart.SKUID)
return &cart, nil
}
// AddToCart 添加到购物车
func (r *OrderRepository) AddToCart(cart *model.Cart) error {
return r.db.Create(cart).Error
}
// UpdateCartItem 更新购物车项
func (r *OrderRepository) UpdateCartItem(id uint, quantity int) error {
return r.db.Model(&model.Cart{}).Where("id = ?", id).Update("quantity", quantity).Error
}
// RemoveFromCart 从购物车移除
func (r *OrderRepository) RemoveFromCart(userID, productID uint) error {
return r.db.Where("user_id = ? AND product_id = ?", userID, productID).Delete(&model.Cart{}).Error
}
// RemoveFromCartBySKU 根据SKU从购物车移除
func (r *OrderRepository) RemoveFromCartBySKU(userID, productID, skuID uint) error {
query := r.db.Where("user_id = ? AND product_id = ?", userID, productID)
if skuID > 0 {
// 删除指定的SKU ID使用sk_uid字段
query = query.Where("sk_uid = ?", skuID)
} else {
// 删除没有SKU的商品sk_uid为NULL或0
query = query.Where("sk_uid IS NULL OR sk_uid = 0")
}
return query.Delete(&model.Cart{}).Error
}
// ClearCart 清空购物车
func (r *OrderRepository) ClearCart(userID uint) error {
return r.db.Where("user_id = ?", userID).Delete(&model.Cart{}).Error
}
// GetOrderStatistics 获取订单统计
func (r *OrderRepository) GetOrderStatistics() (map[string]interface{}, error) {
var result map[string]interface{} = make(map[string]interface{})
// 总订单数
var totalOrders int64
r.db.Model(&model.Order{}).Count(&totalOrders)
result["total_orders"] = totalOrders
// 待付款订单数
var pendingOrders int64
r.db.Model(&model.Order{}).Where("status = ?", 1).Count(&pendingOrders)
result["pending_orders"] = pendingOrders
// 待发货订单数状态2和3都是待发货
var toShipOrders int64
r.db.Model(&model.Order{}).Where("status IN (?)", []int{2, 3}).Count(&toShipOrders)
result["to_ship_orders"] = toShipOrders
// 已发货订单数
var shippedOrders int64
r.db.Model(&model.Order{}).Where("status = ?", 4).Count(&shippedOrders)
result["shipped_orders"] = shippedOrders
// 已完成订单数
var completedOrders int64
r.db.Model(&model.Order{}).Where("status = ?", 6).Count(&completedOrders)
result["completed_orders"] = completedOrders
// 总销售额(包含待发货、已发货、待收货、已完成状态)
var totalAmount float64
r.db.Model(&model.Order{}).Where("status IN (?)", []int{2, 3, 4, 5, 6}).Select("COALESCE(SUM(total_amount), 0)").Scan(&totalAmount)
result["total_amount"] = totalAmount
return result, nil
}
// GetDailyOrderStatistics 获取每日订单统计
func (r *OrderRepository) GetDailyOrderStatistics(days int) ([]map[string]interface{}, error) {
var results []map[string]interface{}
query := `
SELECT
DATE(created_at) as date,
COUNT(*) as order_count,
COALESCE(SUM(total_amount), 0) as total_amount
FROM ai_orders
WHERE created_at >= DATE_SUB(CURDATE(), INTERVAL ? DAY)
GROUP BY DATE(created_at)
ORDER BY date DESC
`
err := r.db.Raw(query, days).Scan(&results).Error
return results, err
}
// SelectCartItem 选择/取消选择购物车项
func (r *OrderRepository) SelectCartItem(userID, cartID uint, selected bool) error {
return r.db.Model(&model.Cart{}).
Where("id = ? AND user_id = ?", cartID, userID).
Update("selected", selected).Error
}
// SelectAllCartItems 全选/取消全选购物车
func (r *OrderRepository) SelectAllCartItems(userID uint, selected bool) error {
return r.db.Model(&model.Cart{}).
Where("user_id = ?", userID).
Update("selected", selected).Error
}
// UpdateOrderStatus 更新订单状态
func (r *OrderRepository) UpdateOrderStatus(orderID uint, status int) error {
updates := map[string]interface{}{
"status": status,
}
// 根据状态设置相应的时间字段
now := time.Now()
switch status {
case 2: // 已支付
updates["paid_at"] = now
case 3: // 已发货
updates["shipped_at"] = now
case 4: // 已完成
updates["completed_at"] = now
case 5: // 已取消
updates["cancelled_at"] = now
case 6: // 已退款
updates["refunded_at"] = now
}
return r.db.Model(&model.Order{}).Where("id = ?", orderID).Updates(updates).Error
}
// BatchUpdateOrderStatus 批量更新订单状态
func (r *OrderRepository) BatchUpdateOrderStatus(orderIDs []uint, status int) error {
updates := map[string]interface{}{
"status": status,
}
// 根据状态设置相应的时间字段
now := time.Now()
switch status {
case 2: // 已支付
updates["paid_at"] = now
case 3: // 已发货
updates["shipped_at"] = now
case 4: // 已完成
updates["completed_at"] = now
case 5: // 已取消
updates["cancelled_at"] = now
case 6: // 已退款
updates["refunded_at"] = now
}
return r.db.Model(&model.Order{}).Where("id IN ?", orderIDs).Updates(updates).Error
}
// GetOrdersByDateRange 根据日期范围获取订单
func (r *OrderRepository) GetOrdersByDateRange(startDate, endDate string, status, offset, limit int) ([]*model.Order, int64, error) {
var orders []*model.Order
var total int64
query := r.db.Model(&model.Order{})
if startDate != "" {
query = query.Where("created_at >= ?", startDate+" 00:00:00")
}
if endDate != "" {
query = query.Where("created_at <= ?", endDate+" 23:59:59")
}
if status > 0 {
query = query.Where("status = ?", status)
}
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取分页数据
err := query.Offset(offset).Limit(limit).Order("created_at DESC").Find(&orders).Error
return orders, total, err
}
// GetOrdersForExport 获取用于导出的订单数据
func (r *OrderRepository) GetOrdersForExport(conditions map[string]interface{}) ([]*model.Order, error) {
var orders []*model.Order
query := r.db.Model(&model.Order{})
if startDate, ok := conditions["start_date"]; ok {
query = query.Where("created_at >= ?", startDate.(string)+" 00:00:00")
}
if endDate, ok := conditions["end_date"]; ok {
query = query.Where("created_at <= ?", endDate.(string)+" 23:59:59")
}
if status, ok := conditions["status"]; ok {
query = query.Where("status = ?", status)
}
if orderNo, ok := conditions["order_no"]; ok {
query = query.Where("order_no LIKE ?", "%"+orderNo.(string)+"%")
}
err := query.Order("created_at DESC").Find(&orders).Error
return orders, err
}
// GetOrderTrendData 获取订单趋势数据
func (r *OrderRepository) GetOrderTrendData(days int) ([]map[string]interface{}, error) {
var results []map[string]interface{}
query := `
SELECT
DATE(created_at) as date,
COUNT(*) as order_count,
SUM(total_amount) as total_amount,
COUNT(CASE WHEN status = 6 THEN 1 END) as completed_count
FROM ai_orders
WHERE created_at >= DATE_SUB(NOW(), INTERVAL ? DAY)
GROUP BY DATE(created_at)
ORDER BY date ASC
`
err := r.db.Raw(query, days).Scan(&results).Error
return results, err
}
// GetPendingOrdersCount 获取待处理订单数量
func (r *OrderRepository) GetPendingOrdersCount() (int64, error) {
var count int64
// 待处理订单包括:待支付(1)、待发货(2,3)、已发货(4)、待收货(5)
err := r.db.Model(&model.Order{}).Where("status IN ?", []int{1, 2, 3, 4, 5}).Count(&count).Error
return count, err
}
// UpdateOrderLogistics 更新订单物流信息
func (r *OrderRepository) UpdateOrderLogistics(orderID uint, logisticsCompany, trackingNumber string) error {
updates := map[string]interface{}{
"logistics_company": logisticsCompany,
"tracking_number": trackingNumber,
"shipped_at": time.Now(),
"status": 3, // 已发货
}
return r.db.Model(&model.Order{}).Where("id = ?", orderID).Updates(updates).Error
}
// GetOrderByOrderNo 根据订单号获取订单
func (r *OrderRepository) GetOrderByOrderNo(orderNo string) (*model.Order, error) {
var order model.Order
err := r.db.Where("order_no = ?", orderNo).First(&order).Error
if err != nil {
return nil, err
}
return &order, nil
}
// GetOrderByWechatOutTradeNo 根据微信支付订单号获取订单
func (r *OrderRepository) GetOrderByWechatOutTradeNo(wechatOutTradeNo string) (*model.Order, error) {
var order model.Order
err := r.db.Where("wechat_out_trade_no = ?", wechatOutTradeNo).First(&order).Error
if err != nil {
return nil, err
}
return &order, nil
}
// GetOrderItemByID 根据ID获取订单项
func (r *OrderRepository) GetOrderItemByID(id uint) (*model.OrderItem, error) {
var orderItem model.OrderItem
err := r.db.Preload("Product").Preload("Order").First(&orderItem, id).Error
return &orderItem, err
}
// SaveOrderItem 保存订单项
func (r *OrderRepository) SaveOrderItem(orderItem *model.OrderItem) error {
return r.db.Save(orderItem).Error
}
// GetUncommentedOrderItems 获取用户未评论的订单项
func (r *OrderRepository) GetUncommentedOrderItems(userID uint) ([]model.OrderItem, error) {
var orderItems []model.OrderItem
// 查询已完成订单中未评论的订单项
err := r.db.Joins("JOIN ai_orders ON order_items.order_id = ai_orders.id").
Where("ai_orders.user_id = ? AND ai_orders.status = ? AND order_items.is_commented = ?",
userID, model.OrderStatusCompleted, false).
Preload("Product").Preload("Order").
Find(&orderItems).Error
return orderItems, err
}
// UpdateOrderRefund 更新订单退款信息
func (r *OrderRepository) UpdateOrderRefund(orderID uint, refundAmount float64, refundReason string) error {
updates := map[string]interface{}{
"refund_amount": refundAmount,
"refund_reason": refundReason,
"refunded_at": time.Now(),
"status": 6, // 已退款
}
return r.db.Model(&model.Order{}).Where("id = ?", orderID).Updates(updates).Error
}
// GetOrderStatisticsByStatus 根据状态获取订单统计
func (r *OrderRepository) GetOrderStatisticsByStatus() (map[string]interface{}, error) {
var results []struct {
Status int `json:"status"`
Count int64 `json:"count"`
Amount float64 `json:"amount"`
}
query := `
SELECT
status,
COUNT(*) as count,
COALESCE(SUM(total_amount), 0) as amount
FROM ai_orders
GROUP BY status
`
if err := r.db.Raw(query).Scan(&results).Error; err != nil {
return nil, err
}
statistics := make(map[string]interface{})
for _, result := range results {
statusKey := fmt.Sprintf("status_%d", result.Status)
statistics[statusKey] = map[string]interface{}{
"count": result.Count,
"amount": result.Amount,
}
}
return statistics, nil
}
// GetTotalOrderStatistics 获取总订单统计
func (r *OrderRepository) GetTotalOrderStatistics() (map[string]interface{}, error) {
var result struct {
TotalCount int64 `json:"total_count"`
TotalAmount float64 `json:"total_amount"`
}
query := `
SELECT
COUNT(*) as total_count,
COALESCE(SUM(total_amount), 0) as total_amount
FROM ai_orders
`
if err := r.db.Raw(query).Scan(&result).Error; err != nil {
return nil, err
}
return map[string]interface{}{
"total_count": result.TotalCount,
"total_amount": result.TotalAmount,
}, nil
}
// BatchRemoveFromCart 批量从购物车移除
func (r *OrderRepository) BatchRemoveFromCart(userID uint, cartIDs []uint) error {
return r.db.Where("user_id = ? AND id IN ?", userID, cartIDs).Delete(&model.Cart{}).Error
}
// GetCartItemByID 根据ID获取购物车项
func (r *OrderRepository) GetCartItemByID(userID, cartID uint) (*model.Cart, error) {
var cart model.Cart
err := r.db.Where("user_id = ? AND id = ?", userID, cartID).First(&cart).Error
if err != nil {
return nil, err
}
return &cart, nil
}
// RemoveCartItem 移除购物车项
func (r *OrderRepository) RemoveCartItem(cartID uint) error {
return r.db.Delete(&model.Cart{}, cartID).Error
}
// GetSelectedCartItems 获取选中的购物车项
func (r *OrderRepository) GetSelectedCartItems(userID uint) ([]model.Cart, error) {
var cart []model.Cart
err := r.db.Where("user_id = ? AND selected = ?", userID, true).
Preload("Product").
Preload("ProductSKU").
Find(&cart).Error
return cart, err
}

View File

@@ -0,0 +1,118 @@
package repository
import (
"dianshang/internal/model"
"gorm.io/gorm"
)
// PointsRepository 积分数据访问层
type PointsRepository struct {
db *gorm.DB
}
// NewPointsRepository 创建积分数据访问层
func NewPointsRepository(db *gorm.DB) *PointsRepository {
return &PointsRepository{db: db}
}
// GetUserPoints 获取用户积分
func (r *PointsRepository) GetUserPoints(userID uint) (int, error) {
var user model.User
err := r.db.Select("points").Where("id = ?", userID).First(&user).Error
if err != nil {
return 0, err
}
return user.Points, nil
}
// UpdateUserPoints 更新用户积分
func (r *PointsRepository) UpdateUserPoints(userID uint, points int) error {
return r.db.Model(&model.User{}).Where("id = ?", userID).Update("points", points).Error
}
// CreatePointsHistory 创建积分历史记录
func (r *PointsRepository) CreatePointsHistory(history *model.PointsHistory) error {
return r.db.Create(history).Error
}
// GetPointsHistory 获取积分历史记录
func (r *PointsRepository) GetPointsHistory(userID uint, page, pageSize int) ([]model.PointsHistory, int64, error) {
var histories []model.PointsHistory
var total int64
offset := (page - 1) * pageSize
// 获取总数
err := r.db.Model(&model.PointsHistory{}).Where("user_id = ?", userID).Count(&total).Error
if err != nil {
return nil, 0, err
}
// 获取分页数据
err = r.db.Where("user_id = ?", userID).
Order("created_at DESC").
Offset(offset).
Limit(pageSize).
Find(&histories).Error
return histories, total, err
}
// GetPointsRules 获取积分规则列表
func (r *PointsRepository) GetPointsRules() ([]model.PointsRule, error) {
var rules []model.PointsRule
err := r.db.Where("status = ?", 1).Order("sort ASC, id ASC").Find(&rules).Error
return rules, err
}
// GetPointsExchangeList 获取积分兑换商品列表
func (r *PointsRepository) GetPointsExchangeList() ([]model.PointsExchange, error) {
var exchanges []model.PointsExchange
err := r.db.Where("status = ?", 1).Order("sort ASC, id ASC").Find(&exchanges).Error
return exchanges, err
}
// GetPointsExchangeByID 根据ID获取积分兑换商品
func (r *PointsRepository) GetPointsExchangeByID(id uint) (*model.PointsExchange, error) {
var exchange model.PointsExchange
err := r.db.Where("id = ? AND status = ?", id, 1).First(&exchange).Error
if err != nil {
return nil, err
}
return &exchange, nil
}
// CreatePointsExchangeRecord 创建积分兑换记录
func (r *PointsRepository) CreatePointsExchangeRecord(record *model.PointsExchangeRecord) error {
return r.db.Create(record).Error
}
// UpdatePointsExchangeCount 更新兑换商品的兑换次数
func (r *PointsRepository) UpdatePointsExchangeCount(id uint) error {
return r.db.Model(&model.PointsExchange{}).Where("id = ?", id).
UpdateColumn("exchange_count", gorm.Expr("exchange_count + ?", 1)).Error
}
// GetUserExchangeRecords 获取用户兑换记录
func (r *PointsRepository) GetUserExchangeRecords(userID uint, page, pageSize int) ([]model.PointsExchangeRecord, int64, error) {
var records []model.PointsExchangeRecord
var total int64
offset := (page - 1) * pageSize
// 获取总数
err := r.db.Model(&model.PointsExchangeRecord{}).Where("user_id = ?", userID).Count(&total).Error
if err != nil {
return nil, 0, err
}
// 获取分页数据,预加载关联数据
err = r.db.Preload("PointsExchange").
Where("user_id = ?", userID).
Order("created_at DESC").
Offset(offset).
Limit(pageSize).
Find(&records).Error
return records, total, err
}

View File

@@ -0,0 +1,905 @@
package repository
import (
"dianshang/internal/model"
"gorm.io/gorm"
)
// ProductRepository 产品仓储
type ProductRepository struct {
db *gorm.DB
}
// NewProductRepository 创建产品仓储
func NewProductRepository(db *gorm.DB) *ProductRepository {
return &ProductRepository{db: db}
}
// GetDB 获取数据库连接
func (r *ProductRepository) GetDB() *gorm.DB {
return r.db
}
// GetList 获取产品列表
func (r *ProductRepository) GetList(offset, limit int, conditions map[string]interface{}) ([]model.Product, int64, error) {
var products []model.Product
var total int64
query := r.db.Model(&model.Product{})
// 默认只查询上架商品,除非明确指定状态
if statusValue, exists := conditions["status"]; exists {
if statusValue == "all" {
// 管理系统可以传递 "all" 来获取所有状态的商品
} else {
query = query.Where("status = ?", statusValue)
}
} else {
// 默认只查询上架商品(兼容原有逻辑)
query = query.Where("status = ?", 1)
}
// 添加查询条件
var sortField string
var sortType string
for key, value := range conditions {
switch key {
case "category_id":
// 支持包含子分类的筛选
var catID uint
switch v := value.(type) {
case uint:
catID = v
case int:
catID = uint(v)
case float64:
catID = uint(v)
}
if catID > 0 {
categoryIDs, err := r.getCategoryIDsIncludingChildren(catID)
if err == nil && len(categoryIDs) > 0 {
query = query.Where("category_id IN (?)", categoryIDs)
} else {
// 兜底:如果获取子分类失败,退化为当前分类
query = query.Where("category_id = ?", catID)
}
}
case "keyword":
query = query.Where("name LIKE ? OR description LIKE ?", "%"+value.(string)+"%", "%"+value.(string)+"%")
case "min_price":
query = query.Where("price >= ?", value)
case "max_price":
query = query.Where("price <= ?", value)
case "is_hot":
if value.(string) == "true" {
query = query.Where("is_hot = ?", true)
} else if value.(string) == "false" {
query = query.Where("is_hot = ?", false)
}
case "is_new":
if value.(string) == "true" {
query = query.Where("is_new = ?", true)
} else if value.(string) == "false" {
query = query.Where("is_new = ?", false)
}
case "is_recommend":
if value.(string) == "true" {
query = query.Where("is_recommend = ?", true)
} else if value.(string) == "false" {
query = query.Where("is_recommend = ?", false)
}
case "sort":
sortField = value.(string)
case "sort_type":
sortType = value.(string)
case "status":
// 状态条件已在上面处理,这里跳过
continue
}
}
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 构建排序条件
orderBy := "sort DESC, created_at DESC" // 默认排序
if sortField != "" {
switch sortField {
case "price":
if sortType == "asc" {
orderBy = "price ASC"
} else {
orderBy = "price DESC"
}
case "sales":
if sortType == "asc" {
orderBy = "sales ASC"
} else {
orderBy = "sales DESC"
}
case "created_at":
if sortType == "asc" {
orderBy = "created_at ASC"
} else {
orderBy = "created_at DESC"
}
}
}
// 获取列表,预加载分类
err := query.Preload("Category").
Offset(offset).Limit(limit).Order(orderBy).Find(&products).Error
return products, total, err
}
// GetByID 根据ID获取产品详情
func (r *ProductRepository) GetByID(id uint) (*model.Product, error) {
var product model.Product
err := r.db.Preload("Category").Preload("Specs").Preload("SKUs", "status = ?", 1).
Where("id = ?", id).First(&product).Error
return &product, err
}
// Create 创建产品
func (r *ProductRepository) Create(product *model.Product) error {
return r.db.Create(product).Error
}
// Update 更新产品
func (r *ProductRepository) Update(id uint, updates map[string]interface{}) error {
return r.db.Model(&model.Product{}).Where("id = ?", id).Updates(updates).Error
}
// Delete 删除产品(软删除)
func (r *ProductRepository) Delete(id uint) error {
// 使用GORM软删除只删除商品主记录
// 关联记录保留通过商品的deleted_at字段来判断商品是否被删除
return r.db.Delete(&model.Product{}, id).Error
}
// UpdateStock 更新库存
func (r *ProductRepository) UpdateStock(id uint, quantity int) error {
return r.db.Model(&model.Product{}).Where("id = ?", id).
Update("stock", gorm.Expr("stock + ?", quantity)).Error
}
// UpdateSales 更新销量
func (r *ProductRepository) UpdateSales(id uint, quantity int) error {
return r.db.Model(&model.Product{}).Where("id = ?", id).
Update("sales", gorm.Expr("sales + ?", quantity)).Error
}
// DeductStock 扣减库存(支付成功时使用)
func (r *ProductRepository) DeductStock(id uint, quantity int) error {
// 检查库存是否充足并扣减
result := r.db.Model(&model.Product{}).
Where("id = ? AND stock >= ?", id, quantity).
Update("stock", gorm.Expr("stock - ?", quantity))
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound // 库存不足或产品不存在
}
return nil
}
// ReserveStock 预扣库存(创建订单时使用)
func (r *ProductRepository) ReserveStock(id uint, quantity int) error {
// 检查库存是否充足并预扣
result := r.db.Model(&model.Product{}).
Where("id = ? AND stock >= ?", id, quantity).
Update("stock", gorm.Expr("stock - ?", quantity))
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound // 库存不足或产品不存在
}
return nil
}
// RestoreStock 恢复库存(取消订单或支付失败时使用)
func (r *ProductRepository) RestoreStock(id uint, quantity int) error {
return r.db.Model(&model.Product{}).Where("id = ?", id).
Update("stock", gorm.Expr("stock + ?", quantity)).Error
}
// GetCategories 获取分类列表
func (r *ProductRepository) GetCategories() ([]model.Category, error) {
var allCategories []model.Category
err := r.db.Where("status = ?", 1).Order("level ASC, sort DESC, created_at ASC").Find(&allCategories).Error
if err != nil {
return nil, err
}
// 构建层级关系
categoryMap := make(map[uint]*model.Category)
var rootCategories []model.Category
// 先将所有分类放入map
for i := range allCategories {
categoryMap[allCategories[i].ID] = &allCategories[i]
}
// 构建父子关系
for i := range allCategories {
category := &allCategories[i]
if category.ParentID == nil {
// 一级分类
rootCategories = append(rootCategories, *category)
} else {
// 二级分类添加到父分类的Children中
if parent, exists := categoryMap[*category.ParentID]; exists {
parent.Children = append(parent.Children, *category)
}
}
}
// 更新根分类的Children并设置HasChildren字段
for i := range rootCategories {
if parent, exists := categoryMap[rootCategories[i].ID]; exists {
rootCategories[i].Children = parent.Children
rootCategories[i].HasChildren = len(parent.Children) > 0
}
// 为子分类设置HasChildren字段
for j := range rootCategories[i].Children {
rootCategories[i].Children[j].HasChildren = false // 二级分类没有子分类
}
}
return rootCategories, nil
}
// GetCategoryByID 根据ID获取分类
func (r *ProductRepository) GetCategoryByID(id uint) (*model.Category, error) {
var category model.Category
err := r.db.Where("id = ? AND status = ?", id, 1).First(&category).Error
return &category, err
}
// CreateCategory 创建分类
func (r *ProductRepository) CreateCategory(category *model.Category) error {
return r.db.Create(category).Error
}
// UpdateCategory 更新分类
func (r *ProductRepository) UpdateCategory(id uint, updates map[string]interface{}) error {
return r.db.Model(&model.Category{}).Where("id = ?", id).Updates(updates).Error
}
// DeleteCategory 删除分类
func (r *ProductRepository) DeleteCategory(id uint) error {
return r.db.Delete(&model.Category{}, id).Error
}
// GetReviews 获取产品评价列表
func (r *ProductRepository) GetReviews(productID uint, offset, limit int) ([]model.ProductReview, int64, error) {
var reviews []model.ProductReview
var total int64
query := r.db.Model(&model.ProductReview{}).Where("product_id = ?", productID)
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
err := query.Offset(offset).Limit(limit).
Order("created_at DESC").Find(&reviews).Error
return reviews, total, err
}
// CreateReview 创建评价
func (r *ProductRepository) CreateReview(review *model.ProductReview) error {
return r.db.Create(review).Error
}
// GetReviewByOrderID 根据订单ID获取评价
func (r *ProductRepository) GetReviewByOrderID(userID uint, orderID uint) (*model.ProductReview, error) {
var review model.ProductReview
err := r.db.Where("user_id = ? AND order_id = ?", userID, orderID).First(&review).Error
return &review, err
}
// GetReviewCount 获取产品评价统计
func (r *ProductRepository) GetReviewCount(productID uint) (map[string]interface{}, error) {
var total int64
var goodCount int64 // 好评数 (4-5星)
var middleCount int64 // 中评数 (3星)
var badCount int64 // 差评数 (1-2星)
var hasImageCount int64 // 有图评价数
// 获取总评价数
if err := r.db.Model(&model.ProductReview{}).Where("product_id = ?", productID).Count(&total).Error; err != nil {
return nil, err
}
// 获取好评数 (4-5星)
if err := r.db.Model(&model.ProductReview{}).Where("product_id = ? AND rating >= ?", productID, 4).
Count(&goodCount).Error; err != nil {
return nil, err
}
// 获取中评数 (3星)
if err := r.db.Model(&model.ProductReview{}).Where("product_id = ? AND rating = ?", productID, 3).
Count(&middleCount).Error; err != nil {
return nil, err
}
// 获取差评数 (1-2星)
if err := r.db.Model(&model.ProductReview{}).Where("product_id = ? AND rating <= ?", productID, 2).
Count(&badCount).Error; err != nil {
return nil, err
}
// 获取有图评价数
if err := r.db.Model(&model.ProductReview{}).Where("product_id = ? AND images IS NOT NULL AND images != ''", productID).
Count(&hasImageCount).Error; err != nil {
return nil, err
}
// 计算好评率
var goodRate float64 = 100.0
if total > 0 {
goodRate = float64(goodCount) / float64(total) * 100
}
result := map[string]interface{}{
"total": total,
"good_rate": goodRate,
"good_count": goodCount,
"middle_count": middleCount,
"bad_count": badCount,
"has_image_count": hasImageCount,
}
return result, nil
}
// GetProductImages 获取产品图片
func (r *ProductRepository) GetProductImages(productID uint) ([]model.ProductImage, error) {
var images []model.ProductImage
err := r.db.Where("product_id = ?", productID).Order("sort ASC").Find(&images).Error
return images, err
}
// CreateProductImage 创建产品图片
func (r *ProductRepository) CreateProductImage(image *model.ProductImage) error {
return r.db.Create(image).Error
}
// DeleteProductImage 删除产品图片
func (r *ProductRepository) DeleteProductImage(id uint) error {
return r.db.Delete(&model.ProductImage{}, id).Error
}
// GetProductSpecs 获取产品规格
func (r *ProductRepository) GetProductSpecs(productID uint) ([]model.ProductSpec, error) {
var specs []model.ProductSpec
err := r.db.Where("product_id = ?", productID).Order("sort ASC").Find(&specs).Error
return specs, err
}
// CreateProductSpec 创建产品规格
func (r *ProductRepository) CreateProductSpec(spec *model.ProductSpec) error {
return r.db.Create(spec).Error
}
// UpdateProductSpec 更新产品规格
func (r *ProductRepository) UpdateProductSpec(id uint, updates map[string]interface{}) error {
return r.db.Model(&model.ProductSpec{}).Where("id = ?", id).Updates(updates).Error
}
// DeleteProductSpec 删除产品规格
func (r *ProductRepository) DeleteProductSpec(id uint) error {
return r.db.Delete(&model.ProductSpec{}, id).Error
}
// GetHotProducts 获取热门产品
func (r *ProductRepository) GetHotProducts(limit int) ([]model.Product, error) {
var products []model.Product
err := r.db.Preload("Category").
Where("status = ? AND is_hot = ?", 1, 1).
Order("sales DESC, created_at DESC").Limit(limit).Find(&products).Error
return products, err
}
// GetRecommendProducts 获取推荐产品
func (r *ProductRepository) GetRecommendProducts(limit int) ([]model.Product, error) {
var products []model.Product
err := r.db.Preload("Category").
Where("status = ? AND is_recommend = ?", 1, 1).
Order("sort DESC, created_at DESC").Limit(limit).Find(&products).Error
return products, err
}
// GetProductSKUs 获取产品SKU列表
func (r *ProductRepository) GetProductSKUs(productID uint) ([]model.ProductSKU, error) {
var skus []model.ProductSKU
err := r.db.Where("product_id = ? AND status = ?", productID, 1).
Order("weight ASC").Find(&skus).Error
return skus, err
}
// GetSKUByID 根据SKU ID获取SKU详情
func (r *ProductRepository) GetSKUByID(skuID uint) (*model.ProductSKU, error) {
var sku model.ProductSKU
err := r.db.Where("id = ? AND status = ?", skuID, 1).First(&sku).Error
return &sku, err
}
// GetProductTags 获取产品标签列表
func (r *ProductRepository) GetProductTags() ([]model.ProductTag, error) {
var tags []model.ProductTag
err := r.db.Where("status = ?", 1).Order("sort ASC").Find(&tags).Error
return tags, err
}
// GetStores 获取店铺列表
func (r *ProductRepository) GetStores() ([]model.Store, error) {
var stores []model.Store
err := r.db.Where("status = ?", 1).Order("sort ASC").Find(&stores).Error
return stores, err
}
// GetStoreByID 根据ID获取店铺信息
func (r *ProductRepository) GetStoreByID(id uint) (*model.Store, error) {
var store model.Store
err := r.db.Where("id = ? AND status = ?", id, 1).First(&store).Error
return &store, err
}
// GetProductStatistics 获取产品统计
func (r *ProductRepository) GetProductStatistics() (map[string]interface{}, error) {
result := make(map[string]interface{})
// 总商品数
var totalProducts int64
if err := r.db.Model(&model.Product{}).Count(&totalProducts).Error; err != nil {
return nil, err
}
result["total_products"] = totalProducts
// 上架商品数
var onlineProducts int64
if err := r.db.Model(&model.Product{}).Where("status = ?", 1).Count(&onlineProducts).Error; err != nil {
return nil, err
}
result["online_products"] = onlineProducts
// 下架商品数
var offlineProducts int64
if err := r.db.Model(&model.Product{}).Where("status = ?", 0).Count(&offlineProducts).Error; err != nil {
return nil, err
}
result["offline_products"] = offlineProducts
// 库存不足商品数库存小于10
var lowStockProducts int64
if err := r.db.Model(&model.Product{}).Where("stock < ?", 10).Count(&lowStockProducts).Error; err != nil {
return nil, err
}
result["low_stock_products"] = lowStockProducts
return result, nil
}
// GetCategorySalesStatistics 获取分类销售统计
func (r *ProductRepository) GetCategorySalesStatistics(startDate, endDate string, limit int) ([]map[string]interface{}, error) {
var results []map[string]interface{}
query := `
SELECT
c.id as category_id,
c.name as category_name,
COALESCE(SUM(oi.quantity), 0) as sales_count,
COALESCE(SUM(oi.total_price), 0) as sales_amount
FROM ai_categories c
LEFT JOIN ai_products p ON c.id = p.category_id
LEFT JOIN order_items oi ON p.id = oi.product_id
LEFT JOIN ai_orders o ON oi.order_id = o.id
WHERE c.status = 1
AND (o.created_at IS NULL OR (o.created_at >= ? AND o.created_at <= ?))
AND (o.status IS NULL OR o.status IN (2, 3, 4, 5, 6))
GROUP BY c.id, c.name
ORDER BY sales_amount DESC
LIMIT ?
`
err := r.db.Raw(query, startDate+" 00:00:00", endDate+" 23:59:59", limit).Scan(&results).Error
return results, err
}
// BatchUpdateStatus 批量更新商品状态
func (r *ProductRepository) BatchUpdateStatus(ids []uint, status int) error {
return r.db.Model(&model.Product{}).Where("id IN ?", ids).Update("status", status).Error
}
// BatchUpdatePrice 批量更新商品价格
func (r *ProductRepository) BatchUpdatePrice(updates []map[string]interface{}) error {
tx := r.db.Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
for _, update := range updates {
if id, ok := update["id"]; ok {
delete(update, "id")
if err := tx.Model(&model.Product{}).Where("id = ?", id).Updates(update).Error; err != nil {
tx.Rollback()
return err
}
}
}
return tx.Commit().Error
}
// CountProductsByCategory 统计指定分类下的商品数量(包括子分类)
func (r *ProductRepository) CountProductsByCategory(categoryID uint) (int64, error) {
var count int64
// 获取所有子分类ID
var categoryIDs []uint
categoryIDs = append(categoryIDs, categoryID)
// 查找所有子分类
var childCategories []model.Category
err := r.db.Where("parent_id = ?", categoryID).Find(&childCategories).Error
if err != nil {
return 0, err
}
for _, child := range childCategories {
categoryIDs = append(categoryIDs, child.ID)
}
// 统计这些分类下的商品数量(只统计未删除的商品)
err = r.db.Model(&model.Product{}).
Where("category_id IN (?)", categoryIDs).
Count(&count).Error
return count, err
}
// BatchDelete 批量删除商品(软删除)
func (r *ProductRepository) BatchDelete(ids []uint) error {
if len(ids) == 0 {
return nil
}
// 使用GORM软删除只删除商品主记录
// 关联记录保留通过商品的deleted_at字段来判断商品是否被删除
return r.db.Delete(&model.Product{}, ids).Error
}
// CreateSKU 创建商品SKU
func (r *ProductRepository) CreateSKU(sku *model.ProductSKU) error {
return r.db.Create(sku).Error
}
// UpdateSKU 更新商品SKU
func (r *ProductRepository) UpdateSKU(id uint, updates map[string]interface{}) error {
return r.db.Model(&model.ProductSKU{}).Where("id = ?", id).Updates(updates).Error
}
// DeleteSKU 删除商品SKU
func (r *ProductRepository) DeleteSKU(id uint) error {
// 检查SKU是否被订单引用
var count int64
err := r.db.Table("order_items").Where("sk_uid = ?", id).Count(&count).Error
if err != nil {
return err
}
if count > 0 {
// 如果被订单引用使用软删除设置status=0
return r.db.Model(&model.ProductSKU{}).Where("id = ?", id).Update("status", 0).Error
} else {
// 如果没有被引用,可以进行硬删除
return r.db.Delete(&model.ProductSKU{}, id).Error
}
}
// UpdateProductImageSort 更新商品图片排序
func (r *ProductRepository) UpdateProductImageSort(id uint, sort int) error {
return r.db.Model(&model.ProductImage{}).Where("id = ?", id).Update("sort", sort).Error
}
// CreateProductTag 创建商品标签
func (r *ProductRepository) CreateProductTag(tag *model.ProductTag) error {
return r.db.Create(tag).Error
}
// UpdateProductTag 更新商品标签
func (r *ProductRepository) UpdateProductTag(id uint, updates map[string]interface{}) error {
return r.db.Model(&model.ProductTag{}).Where("id = ?", id).Updates(updates).Error
}
// DeleteProductTag 删除商品标签
func (r *ProductRepository) DeleteProductTag(id uint) error {
return r.db.Delete(&model.ProductTag{}, id).Error
}
// AssignTagsToProduct 为商品分配标签
func (r *ProductRepository) AssignTagsToProduct(productID uint, tagIDs []uint) error {
// 先清除现有关联
if err := r.db.Exec("DELETE FROM ai_product_tag_relations WHERE product_id = ?", productID).Error; err != nil {
return err
}
// 添加新关联
if len(tagIDs) > 0 {
var relations []map[string]interface{}
for _, tagID := range tagIDs {
relations = append(relations, map[string]interface{}{
"product_id": productID,
"tag_id": tagID,
})
}
return r.db.Table("ai_product_tag_relations").Create(relations).Error
}
return nil
}
// GetLowStockProducts 获取低库存商品
func (r *ProductRepository) GetLowStockProducts(threshold int) ([]model.Product, error) {
var products []model.Product
err := r.db.Where("stock <= ? AND status = ?", threshold, 1).
Preload("Category").Find(&products).Error
return products, err
}
// GetInventoryStatistics 获取库存统计
func (r *ProductRepository) GetInventoryStatistics() (map[string]interface{}, error) {
var result map[string]interface{}
// 总商品数
var totalProducts int64
r.db.Model(&model.Product{}).Where("status = ?", 1).Count(&totalProducts)
// 低库存商品数库存小于等于10
var lowStockProducts int64
r.db.Model(&model.Product{}).Where("stock <= ? AND status = ?", 10, 1).Count(&lowStockProducts)
// 缺货商品数
var outOfStockProducts int64
r.db.Model(&model.Product{}).Where("stock = ? AND status = ?", 0, 1).Count(&outOfStockProducts)
// 总库存值
var totalStockValue float64
r.db.Model(&model.Product{}).Where("status = ?", 1).
Select("SUM(stock * price)").Scan(&totalStockValue)
result = map[string]interface{}{
"total_products": totalProducts,
"low_stock_products": lowStockProducts,
"out_of_stock_products": outOfStockProducts,
"total_stock_value": totalStockValue,
}
return result, nil
}
// GetProductsForExport 获取用于导出的商品数据
func (r *ProductRepository) GetProductsForExport(conditions map[string]interface{}) ([]model.Product, error) {
var products []model.Product
query := r.db.Model(&model.Product{}).Preload("Category")
// 添加查询条件
for key, value := range conditions {
switch key {
case "category_id":
// 导出也支持子分类
var catID uint
switch v := value.(type) {
case uint:
catID = v
case int:
catID = uint(v)
case float64:
catID = uint(v)
}
if catID > 0 {
categoryIDs, err := r.getCategoryIDsIncludingChildren(catID)
if err == nil && len(categoryIDs) > 0 {
query = query.Where("category_id IN (?)", categoryIDs)
} else {
query = query.Where("category_id = ?", catID)
}
}
case "status":
query = query.Where("status = ?", value)
case "keyword":
query = query.Where("name LIKE ? OR description LIKE ?", "%"+value.(string)+"%", "%"+value.(string)+"%")
}
}
err := query.Find(&products).Error
return products, err
}
// getCategoryIDsIncludingChildren 获取给定分类ID及其子分类ID集合仅支持两级分类
func (r *ProductRepository) getCategoryIDsIncludingChildren(categoryID uint) ([]uint, error) {
// 将自身分类加入集合
ids := []uint{categoryID}
// 查找直接子分类(系统设计为最多二级)
var childCategories []model.Category
if err := r.db.Where("parent_id = ? AND status = ?", categoryID, 1).Find(&childCategories).Error; err != nil {
return ids, err
}
for _, child := range childCategories {
ids = append(ids, child.ID)
}
return ids, nil
}
// DeductSKUStock 扣减SKU库存
func (r *ProductRepository) DeductSKUStock(skuID uint, quantity int) error {
// 检查SKU库存是否充足并扣减
result := r.db.Model(&model.ProductSKU{}).
Where("id = ? AND stock >= ?", skuID, quantity).
Update("stock", gorm.Expr("stock - ?", quantity))
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound // SKU库存不足或SKU不存在
}
return nil
}
// RestoreSKUStock 恢复SKU库存
func (r *ProductRepository) RestoreSKUStock(skuID uint, quantity int) error {
return r.db.Model(&model.ProductSKU{}).Where("id = ?", skuID).
Update("stock", gorm.Expr("stock + ?", quantity)).Error
}
// SyncProductStockFromSKUs 根据SKU库存同步商品总库存
func (r *ProductRepository) SyncProductStockFromSKUs(productID uint) error {
// 计算所有SKU的库存总和
var totalStock int
err := r.db.Model(&model.ProductSKU{}).
Where("product_id = ? AND status = ?", productID, 1).
Select("COALESCE(SUM(stock), 0)").Scan(&totalStock).Error
if err != nil {
return err
}
// 更新商品总库存
return r.db.Model(&model.Product{}).Where("id = ?", productID).
Update("stock", totalStock).Error
}
// DeductStockWithSKU 扣减商品和SKU库存支持SKU商品的库存扣减
func (r *ProductRepository) DeductStockWithSKU(productID uint, skuID *uint, quantity int) error {
tx := r.db.Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
if skuID != nil && *skuID > 0 {
// 有SKU的商品先扣减SKU库存
result := tx.Model(&model.ProductSKU{}).
Where("id = ? AND stock >= ?", *skuID, quantity).
Update("stock", gorm.Expr("stock - ?", quantity))
if result.Error != nil {
tx.Rollback()
return result.Error
}
if result.RowsAffected == 0 {
tx.Rollback()
return gorm.ErrRecordNotFound // SKU库存不足
}
// 同步更新商品总库存
var totalStock int
err := tx.Model(&model.ProductSKU{}).
Where("product_id = ? AND status = ?", productID, 1).
Select("COALESCE(SUM(stock), 0)").Scan(&totalStock).Error
if err != nil {
tx.Rollback()
return err
}
err = tx.Model(&model.Product{}).Where("id = ?", productID).
Update("stock", totalStock).Error
if err != nil {
tx.Rollback()
return err
}
} else {
// 没有SKU的商品直接扣减商品库存
result := tx.Model(&model.Product{}).
Where("id = ? AND stock >= ?", productID, quantity).
Update("stock", gorm.Expr("stock - ?", quantity))
if result.Error != nil {
tx.Rollback()
return result.Error
}
if result.RowsAffected == 0 {
tx.Rollback()
return gorm.ErrRecordNotFound // 商品库存不足
}
}
return tx.Commit().Error
}
// RestoreStockWithSKU 恢复商品和SKU库存支持SKU商品的库存恢复
func (r *ProductRepository) RestoreStockWithSKU(productID uint, skuID *uint, quantity int) error {
tx := r.db.Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
if skuID != nil && *skuID > 0 {
// 有SKU的商品先恢复SKU库存
err := tx.Model(&model.ProductSKU{}).Where("id = ?", *skuID).
Update("stock", gorm.Expr("stock + ?", quantity)).Error
if err != nil {
tx.Rollback()
return err
}
// 同步更新商品总库存
var totalStock int
err = tx.Model(&model.ProductSKU{}).
Where("product_id = ? AND status = ?", productID, 1).
Select("COALESCE(SUM(stock), 0)").Scan(&totalStock).Error
if err != nil {
tx.Rollback()
return err
}
err = tx.Model(&model.Product{}).Where("id = ?", productID).
Update("stock", totalStock).Error
if err != nil {
tx.Rollback()
return err
}
} else {
// 没有SKU的商品直接恢复商品库存
err := tx.Model(&model.Product{}).Where("id = ?", productID).
Update("stock", gorm.Expr("stock + ?", quantity)).Error
if err != nil {
tx.Rollback()
return err
}
}
return tx.Commit().Error
}

View File

@@ -0,0 +1,406 @@
package repository
import (
"dianshang/internal/model"
"fmt"
"time"
"gorm.io/gorm"
)
type RefundRepository struct {
db *gorm.DB
}
func NewRefundRepository(db *gorm.DB) *RefundRepository {
return &RefundRepository{db: db}
}
// Create 创建退款记录(别名方法)
func (r *RefundRepository) Create(refund *model.Refund) error {
return r.db.Create(refund).Error
}
// CreateRefund 创建退款记录
func (r *RefundRepository) CreateRefund(refund *model.Refund) error {
return r.db.Create(refund).Error
}
// GetByID 根据ID获取退款记录别名方法
func (r *RefundRepository) GetByID(id uint) (*model.Refund, error) {
var refund model.Refund
err := r.db.Preload("Order").Preload("User").Preload("RefundItems").
First(&refund, id).Error
if err != nil {
return nil, err
}
return &refund, nil
}
// GetRefundByID 根据ID获取退款记录
func (r *RefundRepository) GetRefundByID(id uint) (*model.Refund, error) {
var refund model.Refund
err := r.db.Preload("Order").Preload("User").Preload("RefundItems").
First(&refund, id).Error
if err != nil {
return nil, err
}
return &refund, nil
}
// GetRefundByRefundNo 根据退款单号获取退款记录
func (r *RefundRepository) GetRefundByRefundNo(refundNo string) (*model.Refund, error) {
var refund model.Refund
err := r.db.Preload("Order").Preload("User").Preload("RefundItems").
Where("refund_no = ?", refundNo).First(&refund).Error
if err != nil {
return nil, err
}
return &refund, nil
}
// GetByWechatOutRefundNo 根据微信退款单号获取退款记录(别名方法)
func (r *RefundRepository) GetByWechatOutRefundNo(wechatOutRefundNo string) (*model.Refund, error) {
var refund model.Refund
err := r.db.Preload("Order").Preload("User").Preload("RefundItems").
Where("wechat_out_refund_no = ?", wechatOutRefundNo).First(&refund).Error
if err != nil {
return nil, err
}
return &refund, nil
}
// GetRefundByWechatOutRefundNo 根据微信退款单号获取退款记录
func (r *RefundRepository) GetRefundByWechatOutRefundNo(wechatOutRefundNo string) (*model.Refund, error) {
var refund model.Refund
err := r.db.Preload("Order").Preload("User").Preload("RefundItems").
Where("wechat_out_refund_no = ?", wechatOutRefundNo).First(&refund).Error
if err != nil {
return nil, err
}
return &refund, nil
}
// GetTotalRefundedByOrderID 根据订单ID获取已退款总金额
func (r *RefundRepository) GetTotalRefundedByOrderID(orderID uint) (float64, error) {
var totalRefunded float64
err := r.db.Model(&model.Refund{}).
Where("order_id = ? AND status IN (?)", orderID, []int{model.RefundStatusSuccess}).
Select("COALESCE(SUM(actual_refund_amount), 0)").
Scan(&totalRefunded).Error
return totalRefunded, err
}
// GetByOrderID 根据订单ID获取退款记录列表别名方法
func (r *RefundRepository) GetByOrderID(orderID uint) ([]model.Refund, error) {
var refunds []model.Refund
err := r.db.Preload("RefundItems").
Where("order_id = ?", orderID).
Order("created_at DESC").
Find(&refunds).Error
return refunds, err
}
// GetRefundsByOrderID 根据订单ID获取退款记录列表
func (r *RefundRepository) GetRefundsByOrderID(orderID uint) ([]model.Refund, error) {
var refunds []model.Refund
err := r.db.Preload("RefundItems").
Where("order_id = ?", orderID).
Order("created_at DESC").
Find(&refunds).Error
return refunds, err
}
// GetByUserID 根据用户ID获取退款记录列表别名方法
func (r *RefundRepository) GetByUserID(userID uint, page, pageSize int) ([]model.Refund, int64, error) {
var refunds []model.Refund
var total int64
// 计算总数
err := r.db.Model(&model.Refund{}).Where("user_id = ?", userID).Count(&total).Error
if err != nil {
return nil, 0, err
}
// 分页查询
offset := (page - 1) * pageSize
err = r.db.Preload("Order").Preload("RefundItems").
Where("user_id = ?", userID).
Order("created_at DESC").
Offset(offset).Limit(pageSize).
Find(&refunds).Error
return refunds, total, err
}
// GetRefundCountByOrderID 获取订单的退款数量
func (r *RefundRepository) GetRefundCountByOrderID(orderID uint) (int64, error) {
var count int64
err := r.db.Model(&model.Refund{}).Where("order_id = ?", orderID).Count(&count).Error
return count, err
}
// CreateLog 创建退款日志
func (r *RefundRepository) CreateLog(log *model.RefundLog) error {
return r.db.Create(log).Error
}
// GetRefundsByUserID 根据用户ID获取退款记录列表
func (r *RefundRepository) GetRefundsByUserID(userID uint, page, pageSize int) ([]model.Refund, int64, error) {
var refunds []model.Refund
var total int64
// 计算总数
err := r.db.Model(&model.Refund{}).Where("user_id = ?", userID).Count(&total).Error
if err != nil {
return nil, 0, err
}
// 分页查询
offset := (page - 1) * pageSize
err = r.db.Preload("Order").Preload("RefundItems").
Where("user_id = ?", userID).
Order("created_at DESC").
Offset(offset).Limit(pageSize).
Find(&refunds).Error
return refunds, total, err
}
// GetPendingRefunds 获取待审核的退款记录
func (r *RefundRepository) GetPendingRefunds(page, pageSize int) ([]model.Refund, int64, error) {
var refunds []model.Refund
var total int64
// 计算总数
err := r.db.Model(&model.Refund{}).Where("status = ?", model.RefundStatusPending).Count(&total).Error
if err != nil {
return nil, 0, err
}
// 分页查询
offset := (page - 1) * pageSize
err = r.db.Preload("Order").Preload("User").Preload("RefundItems").
Where("status = ?", model.RefundStatusPending).
Order("apply_time ASC").
Offset(offset).Limit(pageSize).
Find(&refunds).Error
return refunds, total, err
}
// GetAllRefunds 获取所有退款记录(管理员)
func (r *RefundRepository) GetAllRefunds(page, pageSize int, conditions map[string]interface{}) ([]model.Refund, int64, error) {
var refunds []model.Refund
var total int64
query := r.db.Model(&model.Refund{})
// 添加查询条件
for key, value := range conditions {
switch key {
case "status":
query = query.Where("status = ?", value)
case "user_id":
query = query.Where("user_id = ?", value)
}
}
// 计算总数
err := query.Count(&total).Error
if err != nil {
return nil, 0, err
}
// 分页查询
offset := (page - 1) * pageSize
err = query.Preload("Order").Preload("User").Preload("RefundItems").
Order("apply_time DESC").
Offset(offset).Limit(pageSize).
Find(&refunds).Error
return refunds, total, err
}
// UpdateStatus 更新退款状态
func (r *RefundRepository) UpdateStatus(refundID uint, status int) error {
return r.db.Model(&model.Refund{}).Where("id = ?", refundID).Update("status", status).Error
}
// UpdateByID 根据ID更新退款记录
func (r *RefundRepository) UpdateByID(refundID uint, updates map[string]interface{}) error {
return r.db.Model(&model.Refund{}).Where("id = ?", refundID).Updates(updates).Error
}
// UpdateRefund 更新退款记录
func (r *RefundRepository) UpdateRefund(refundID uint, updates map[string]interface{}) error {
return r.db.Model(&model.Refund{}).Where("id = ?", refundID).Updates(updates).Error
}
// UpdateRefundByRefundNo 根据退款单号更新退款记录
func (r *RefundRepository) UpdateRefundByRefundNo(refundNo string, updates map[string]interface{}) error {
return r.db.Model(&model.Refund{}).Where("refund_no = ?", refundNo).Updates(updates).Error
}
// UpdateRefundByWechatOutRefundNo 根据微信退款单号更新退款记录
func (r *RefundRepository) UpdateRefundByWechatOutRefundNo(wechatOutRefundNo string, updates map[string]interface{}) error {
return r.db.Model(&model.Refund{}).Where("wechat_out_refund_no = ?", wechatOutRefundNo).Updates(updates).Error
}
// CreateRefundItem 创建退款项目
func (r *RefundRepository) CreateRefundItem(refundItem *model.RefundItem) error {
return r.db.Create(refundItem).Error
}
// CreateRefundItems 批量创建退款项目
func (r *RefundRepository) CreateRefundItems(refundItems []model.RefundItem) error {
return r.db.CreateInBatches(refundItems, 100).Error
}
// GetRefundItemsByRefundID 根据退款ID获取退款项目列表
func (r *RefundRepository) GetRefundItemsByRefundID(refundID uint) ([]model.RefundItem, error) {
var refundItems []model.RefundItem
err := r.db.Preload("Product").Preload("SKU").
Where("refund_id = ?", refundID).
Find(&refundItems).Error
return refundItems, err
}
// CreateRefundLog 创建退款日志
func (r *RefundRepository) CreateRefundLog(refundLog *model.RefundLog) error {
return r.db.Create(refundLog).Error
}
// GetRefundLogsByRefundID 根据退款ID获取操作日志
func (r *RefundRepository) GetRefundLogsByRefundID(refundID uint) ([]model.RefundLog, error) {
var refundLogs []model.RefundLog
err := r.db.Preload("Operator").
Where("refund_id = ?", refundID).
Order("created_at ASC").
Find(&refundLogs).Error
return refundLogs, err
}
// GetRefundStatistics 获取退款统计数据
func (r *RefundRepository) GetRefundStatistics(startTime, endTime time.Time) (map[string]interface{}, error) {
var result struct {
TotalCount int64 `json:"total_count"`
TotalAmount float64 `json:"total_amount"`
PendingCount int64 `json:"pending_count"`
ApprovedCount int64 `json:"approved_count"`
RejectedCount int64 `json:"rejected_count"`
ProcessingCount int64 `json:"processing_count"`
SuccessCount int64 `json:"success_count"`
FailedCount int64 `json:"failed_count"`
SuccessAmount float64 `json:"success_amount"`
}
// 总退款申请数和金额
err := r.db.Model(&model.Refund{}).
Where("apply_time BETWEEN ? AND ?", startTime, endTime).
Select("COUNT(*) as total_count, COALESCE(SUM(refund_amount), 0) as total_amount").
Scan(&result).Error
if err != nil {
return nil, err
}
// 各状态统计
statusCounts := []struct {
Status int `json:"status"`
Count int64 `json:"count"`
}{}
err = r.db.Model(&model.Refund{}).
Where("apply_time BETWEEN ? AND ?", startTime, endTime).
Select("status, COUNT(*) as count").
Group("status").
Scan(&statusCounts).Error
if err != nil {
return nil, err
}
// 分配到对应字段
for _, sc := range statusCounts {
switch sc.Status {
case model.RefundStatusPending:
result.PendingCount = sc.Count
case model.RefundStatusApproved:
result.ApprovedCount = sc.Count
case model.RefundStatusRejected:
result.RejectedCount = sc.Count
case model.RefundStatusProcessing:
result.ProcessingCount = sc.Count
case model.RefundStatusSuccess:
result.SuccessCount = sc.Count
case model.RefundStatusFailed:
result.FailedCount = sc.Count
}
}
// 成功退款金额
err = r.db.Model(&model.Refund{}).
Where("apply_time BETWEEN ? AND ? AND status = ?", startTime, endTime, model.RefundStatusSuccess).
Select("COALESCE(SUM(actual_refund_amount), 0) as success_amount").
Scan(&result).Error
if err != nil {
return nil, err
}
return map[string]interface{}{
"total_count": result.TotalCount,
"total_amount": result.TotalAmount,
"pending_count": result.PendingCount,
"approved_count": result.ApprovedCount,
"rejected_count": result.RejectedCount,
"processing_count": result.ProcessingCount,
"success_count": result.SuccessCount,
"failed_count": result.FailedCount,
"success_amount": result.SuccessAmount,
}, nil
}
// GetRefundsByStatus 根据状态获取退款记录列表
func (r *RefundRepository) GetRefundsByStatus(status int) ([]model.Refund, error) {
var refunds []model.Refund
err := r.db.Preload("Order").Preload("RefundItems").
Where("status = ?", status).
Order("created_at DESC").
Find(&refunds).Error
return refunds, err
}
// GetRefundTrends 获取退款趋势数据
func (r *RefundRepository) GetRefundTrends(days int) ([]map[string]interface{}, error) {
var trends []struct {
Date string `json:"date"`
Count int64 `json:"count"`
Amount float64 `json:"amount"`
}
query := fmt.Sprintf(`
SELECT
DATE(apply_time) as date,
COUNT(*) as count,
COALESCE(SUM(refund_amount), 0) as amount
FROM ai_refunds
WHERE apply_time >= DATE_SUB(CURDATE(), INTERVAL %d DAY)
GROUP BY DATE(apply_time)
ORDER BY date ASC
`, days)
err := r.db.Raw(query).Scan(&trends).Error
if err != nil {
return nil, err
}
result := make([]map[string]interface{}, len(trends))
for i, trend := range trends {
result[i] = map[string]interface{}{
"date": trend.Date,
"count": trend.Count,
"amount": trend.Amount,
}
}
return result, nil
}

View File

@@ -0,0 +1,151 @@
package repository
import (
"dianshang/internal/model"
"gorm.io/gorm"
)
// UserRepository 用户仓储
type UserRepository struct {
db *gorm.DB
}
// NewUserRepository 创建用户仓储
func NewUserRepository(db *gorm.DB) *UserRepository {
return &UserRepository{db: db}
}
// Create 创建用户
func (r *UserRepository) Create(user *model.User) error {
return r.db.Create(user).Error
}
// GetByID 根据ID获取用户
func (r *UserRepository) GetByID(id uint) (*model.User, error) {
var user model.User
err := r.db.First(&user, id).Error
return &user, err
}
// GetByOpenID 根据OpenID获取用户
func (r *UserRepository) GetByOpenID(openID string) (*model.User, error) {
var user model.User
err := r.db.Where("open_id = ?", openID).First(&user).Error
return &user, err
}
// Update 更新用户
func (r *UserRepository) Update(id uint, updates map[string]interface{}) error {
return r.db.Model(&model.User{}).Where("id = ?", id).Updates(updates).Error
}
// Delete 删除用户(软删除)
func (r *UserRepository) Delete(id uint) error {
return r.db.Delete(&model.User{}, id).Error
}
// BatchDelete 批量删除用户(软删除)
func (r *UserRepository) BatchDelete(ids []uint) error {
if len(ids) == 0 {
return nil
}
return r.db.Delete(&model.User{}, ids).Error
}
// GetList 获取用户列表
func (r *UserRepository) GetList(offset, limit int, conditions map[string]interface{}) ([]model.User, int64, error) {
var users []model.User
var total int64
query := r.db.Model(&model.User{})
// 添加查询条件
for key, value := range conditions {
query = query.Where(key+" = ?", value)
}
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
err := query.Offset(offset).Limit(limit).Find(&users).Error
return users, total, err
}
// GetAddresses 获取用户地址列表
func (r *UserRepository) GetAddresses(userID uint) ([]model.UserAddress, error) {
var addresses []model.UserAddress
err := r.db.Where("user_id = ?", userID).Order("is_default DESC, created_at DESC").Find(&addresses).Error
return addresses, err
}
// CreateAddress 创建用户地址
func (r *UserRepository) CreateAddress(address *model.UserAddress) error {
return r.db.Create(address).Error
}
// GetAddressByID 根据ID获取地址
func (r *UserRepository) GetAddressByID(id uint) (*model.UserAddress, error) {
var address model.UserAddress
err := r.db.First(&address, id).Error
return &address, err
}
// UpdateAddress 更新用户地址
func (r *UserRepository) UpdateAddress(id uint, updates map[string]interface{}) error {
return r.db.Model(&model.UserAddress{}).Where("id = ?", id).Updates(updates).Error
}
// DeleteAddress 删除用户地址
func (r *UserRepository) DeleteAddress(id uint) error {
return r.db.Delete(&model.UserAddress{}, id).Error
}
// ClearDefaultAddress 清除用户的默认地址
func (r *UserRepository) ClearDefaultAddress(userID uint) error {
return r.db.Model(&model.UserAddress{}).Where("user_id = ?", userID).Update("is_default", 0).Error
}
// GetDefaultAddress 获取用户默认地址
func (r *UserRepository) GetDefaultAddress(userID uint) (*model.UserAddress, error) {
var address model.UserAddress
err := r.db.Where("user_id = ? AND is_default = 1", userID).First(&address).Error
return &address, err
}
// GetFavorites 获取用户收藏列表
func (r *UserRepository) GetFavorites(userID uint, offset, limit int) ([]model.UserFavorite, int64, error) {
var favorites []model.UserFavorite
var total int64
query := r.db.Model(&model.UserFavorite{}).Where("user_id = ?", userID)
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表,预加载商品信息
err := query.Preload("Product").Offset(offset).Limit(limit).Order("created_at DESC").Find(&favorites).Error
return favorites, total, err
}
// CreateFavorite 创建收藏
func (r *UserRepository) CreateFavorite(favorite *model.UserFavorite) error {
return r.db.Create(favorite).Error
}
// DeleteFavorite 删除收藏
func (r *UserRepository) DeleteFavorite(userID, productID uint) error {
return r.db.Where("user_id = ? AND product_id = ?", userID, productID).Delete(&model.UserFavorite{}).Error
}
// IsFavorite 检查是否已收藏
func (r *UserRepository) IsFavorite(userID, productID uint) bool {
var count int64
r.db.Model(&model.UserFavorite{}).Where("user_id = ? AND product_id = ?", userID, productID).Count(&count)
return count > 0
}