package repository

import (
	"fmt"
	"time"

	"dianshang/internal/model"

	"gorm.io/gorm"
)

// RefundRepository provides persistence operations for refunds, refund items
// and refund logs on top of a *gorm.DB handle.
type RefundRepository struct {
	db *gorm.DB
}

// NewRefundRepository returns a RefundRepository backed by db.
func NewRefundRepository(db *gorm.DB) *RefundRepository {
	return &RefundRepository{db: db}
}

// withRefundDetails starts a new query chain that preloads the Order, User and
// RefundItems associations used by the single-refund lookups.
func (r *RefundRepository) withRefundDetails() *gorm.DB {
	return r.db.Preload("Order").Preload("User").Preload("RefundItems")
}

// pageOffset converts a 1-based page number and page size into a row offset.
// It never returns a negative value, so a page below 1 behaves like page 1.
func pageOffset(page, pageSize int) int {
	offset := (page - 1) * pageSize
	if offset < 0 {
		return 0
	}
	return offset
}

// Create inserts a refund record. It is an alias of CreateRefund.
func (r *RefundRepository) Create(refund *model.Refund) error {
	return r.CreateRefund(refund)
}

// CreateRefund inserts a refund record.
func (r *RefundRepository) CreateRefund(refund *model.Refund) error {
	return r.db.Create(refund).Error
}

// GetByID loads a refund by primary key. It is an alias of GetRefundByID.
func (r *RefundRepository) GetByID(id uint) (*model.Refund, error) {
	return r.GetRefundByID(id)
}

// GetRefundByID loads a refund by primary key together with its Order, User
// and RefundItems associations. Any error from the query is returned as-is.
func (r *RefundRepository) GetRefundByID(id uint) (*model.Refund, error) {
	var refund model.Refund
	if err := r.withRefundDetails().First(&refund, id).Error; err != nil {
		return nil, err
	}
	return &refund, nil
}

// GetRefundByRefundNo loads a refund by its refund number (column refund_no)
// together with its Order, User and RefundItems associations.
func (r *RefundRepository) GetRefundByRefundNo(refundNo string) (*model.Refund, error) {
	var refund model.Refund
	err := r.withRefundDetails().
		Where("refund_no = ?", refundNo).
		First(&refund).Error
	if err != nil {
		return nil, err
	}
	return &refund, nil
}

// GetByWechatOutRefundNo loads a refund by its WeChat out-refund number.
// It is an alias of GetRefundByWechatOutRefundNo.
func (r *RefundRepository) GetByWechatOutRefundNo(wechatOutRefundNo string) (*model.Refund, error) {
	return r.GetRefundByWechatOutRefundNo(wechatOutRefundNo)
}

// GetRefundByWechatOutRefundNo loads a refund by its WeChat out-refund number
// (column wechat_out_refund_no) together with its Order, User and RefundItems
// associations.
func (r *RefundRepository) GetRefundByWechatOutRefundNo(wechatOutRefundNo string) (*model.Refund, error) {
	var refund model.Refund
	err := r.withRefundDetails().
		Where("wechat_out_refund_no = ?", wechatOutRefundNo).
		First(&refund).Error
	if err != nil {
		return nil, err
	}
	return &refund, nil
}

// GetTotalRefundedByOrderID returns the sum of actual_refund_amount over the
// order's refunds whose status is model.RefundStatusSuccess, or 0 when there
// are none.
func (r *RefundRepository) GetTotalRefundedByOrderID(orderID uint) (float64, error) {
	var totalRefunded float64
	err := r.db.Model(&model.Refund{}).
		Where("order_id = ? AND status IN (?)", orderID, []int{model.RefundStatusSuccess}).
		Select("COALESCE(SUM(actual_refund_amount), 0)").
		Scan(&totalRefunded).Error
	return totalRefunded, err
}

// GetByOrderID lists an order's refunds, newest first. It is an alias of
// GetRefundsByOrderID.
func (r *RefundRepository) GetByOrderID(orderID uint) ([]model.Refund, error) {
	return r.GetRefundsByOrderID(orderID)
}

// GetRefundsByOrderID lists an order's refunds with their RefundItems
// preloaded, ordered by created_at descending.
func (r *RefundRepository) GetRefundsByOrderID(orderID uint) ([]model.Refund, error) {
	var refunds []model.Refund
	err := r.db.Preload("RefundItems").
		Where("order_id = ?", orderID).
		Order("created_at DESC").
		Find(&refunds).Error
	return refunds, err
}

// GetByUserID returns one page of a user's refunds plus the user's total
// refund count. It is an alias of GetRefundsByUserID.
func (r *RefundRepository) GetByUserID(userID uint, page, pageSize int) ([]model.Refund, int64, error) {
	return r.GetRefundsByUserID(userID, page, pageSize)
}

// GetRefundCountByOrderID returns how many refund records exist for an order,
// regardless of status.
func (r *RefundRepository) GetRefundCountByOrderID(orderID uint) (int64, error) {
	var count int64
	err := r.db.Model(&model.Refund{}).Where("order_id = ?", orderID).Count(&count).Error
	return count, err
}

// CreateLog inserts a refund log entry. It is an alias of CreateRefundLog.
func (r *RefundRepository) CreateLog(log *model.RefundLog) error {
	return r.CreateRefundLog(log)
}

// GetRefundsByUserID returns one page (1-based) of a user's refunds, newest
// first, with Order and RefundItems preloaded, plus the user's total refund
// count.
func (r *RefundRepository) GetRefundsByUserID(userID uint, page, pageSize int) ([]model.Refund, int64, error) {
	var total int64
	if err := r.db.Model(&model.Refund{}).Where("user_id = ?", userID).Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var refunds []model.Refund
	err := r.db.Preload("Order").Preload("RefundItems").
		Where("user_id = ?", userID).
		Order("created_at DESC").
		Offset(pageOffset(page, pageSize)).Limit(pageSize).
		Find(&refunds).Error
	return refunds, total, err
}

// GetPendingRefunds returns one page (1-based) of refunds whose status is
// model.RefundStatusPending, oldest apply_time first, plus the total number of
// pending refunds.
func (r *RefundRepository) GetPendingRefunds(page, pageSize int) ([]model.Refund, int64, error) {
	var total int64
	err := r.db.Model(&model.Refund{}).
		Where("status = ?", model.RefundStatusPending).
		Count(&total).Error
	if err != nil {
		return nil, 0, err
	}

	var refunds []model.Refund
	err = r.withRefundDetails().
		Where("status = ?", model.RefundStatusPending).
		Order("apply_time ASC").
		Offset(pageOffset(page, pageSize)).Limit(pageSize).
		Find(&refunds).Error
	return refunds, total, err
}

// GetAllRefunds returns one page (1-based) of refunds for the admin view,
// newest apply_time first, plus the total number of matching refunds.
//
// conditions may contain "status" and/or "user_id"; each present key adds an
// equality filter on the column of the same name. Any other key is ignored.
func (r *RefundRepository) GetAllRefunds(page, pageSize int, conditions map[string]interface{}) ([]model.Refund, int64, error) {
	// The filter is applied as a scope to two independent chains so the Count
	// statement and the paged Find statement never share mutable state.
	filter := func(db *gorm.DB) *gorm.DB {
		for key, value := range conditions {
			switch key {
			case "status":
				db = db.Where("status = ?", value)
			case "user_id":
				db = db.Where("user_id = ?", value)
			}
		}
		return db
	}

	var total int64
	if err := r.db.Model(&model.Refund{}).Scopes(filter).Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var refunds []model.Refund
	err := r.db.Model(&model.Refund{}).Scopes(filter).
		Preload("Order").Preload("User").Preload("RefundItems").
		Order("apply_time DESC").
		Offset(pageOffset(page, pageSize)).Limit(pageSize).
		Find(&refunds).Error
	return refunds, total, err
}

// UpdateStatus sets the status column of the refund with the given ID.
func (r *RefundRepository) UpdateStatus(refundID uint, status int) error {
	return r.db.Model(&model.Refund{}).Where("id = ?", refundID).Update("status", status).Error
}

// UpdateByID applies column updates to the refund with the given ID. It is an
// alias of UpdateRefund.
func (r *RefundRepository) UpdateByID(refundID uint, updates map[string]interface{}) error {
	return r.UpdateRefund(refundID, updates)
}

// UpdateRefund applies the given column updates to the refund with the given
// ID.
func (r *RefundRepository) UpdateRefund(refundID uint, updates map[string]interface{}) error {
	return r.db.Model(&model.Refund{}).Where("id = ?", refundID).Updates(updates).Error
}

// UpdateRefundByRefundNo applies the given column updates to the refund(s)
// whose refund_no matches refundNo.
func (r *RefundRepository) UpdateRefundByRefundNo(refundNo string, updates map[string]interface{}) error {
	return r.db.Model(&model.Refund{}).Where("refund_no = ?", refundNo).Updates(updates).Error
}

// UpdateRefundByWechatOutRefundNo applies the given column updates to the
// refund(s) whose wechat_out_refund_no matches wechatOutRefundNo.
func (r *RefundRepository) UpdateRefundByWechatOutRefundNo(wechatOutRefundNo string, updates map[string]interface{}) error {
	return r.db.Model(&model.Refund{}).Where("wechat_out_refund_no = ?", wechatOutRefundNo).Updates(updates).Error
}

// CreateRefundItem inserts a single refund item.
func (r *RefundRepository) CreateRefundItem(refundItem *model.RefundItem) error {
	return r.db.Create(refundItem).Error
}

// CreateRefundItems inserts refund items in batches of up to 100 rows per
// INSERT statement.
func (r *RefundRepository) CreateRefundItems(refundItems []model.RefundItem) error {
	return r.db.CreateInBatches(refundItems, 100).Error
}

// GetRefundItemsByRefundID lists the items of a refund with their Product and
// SKU associations preloaded.
func (r *RefundRepository) GetRefundItemsByRefundID(refundID uint) ([]model.RefundItem, error) {
	var refundItems []model.RefundItem
	err := r.db.Preload("Product").Preload("SKU").
		Where("refund_id = ?", refundID).
		Find(&refundItems).Error
	return refundItems, err
}

// CreateRefundLog inserts a refund log entry.
func (r *RefundRepository) CreateRefundLog(refundLog *model.RefundLog) error {
	return r.db.Create(refundLog).Error
}

// GetRefundLogsByRefundID lists a refund's log entries, oldest first, with the
// Operator association preloaded.
func (r *RefundRepository) GetRefundLogsByRefundID(refundID uint) ([]model.RefundLog, error) {
	var refundLogs []model.RefundLog
	err := r.db.Preload("Operator").
		Where("refund_id = ?", refundID).
		Order("created_at ASC").
		Find(&refundLogs).Error
	return refundLogs, err
}

// GetRefundStatistics aggregates refunds whose apply_time lies in
// [startTime, endTime] (BETWEEN is inclusive on both ends).
//
// The returned map has the keys total_count, total_amount (sum of
// refund_amount), pending_count, approved_count, rejected_count,
// processing_count, success_count, failed_count and success_amount (sum of
// actual_refund_amount over successful refunds). Counts are int64 and amounts
// are float64. A status with no rows reports 0.
func (r *RefundRepository) GetRefundStatistics(startTime, endTime time.Time) (map[string]interface{}, error) {
	// Each query scans into its own destination. GORM's Scan resets a struct
	// destination before filling it, so reusing one struct across queries
	// would wipe the values collected by the earlier ones.
	var totals struct {
		TotalCount  int64
		TotalAmount float64
	}
	err := r.db.Model(&model.Refund{}).
		Where("apply_time BETWEEN ? AND ?", startTime, endTime).
		Select("COUNT(*) as total_count, COALESCE(SUM(refund_amount), 0) as total_amount").
		Scan(&totals).Error
	if err != nil {
		return nil, err
	}

	var statusCounts []struct {
		Status int
		Count  int64
	}
	err = r.db.Model(&model.Refund{}).
		Where("apply_time BETWEEN ? AND ?", startTime, endTime).
		Select("status, COUNT(*) as count").
		Group("status").
		Scan(&statusCounts).Error
	if err != nil {
		return nil, err
	}
	countByStatus := make(map[int]int64, len(statusCounts))
	for _, sc := range statusCounts {
		countByStatus[sc.Status] = sc.Count
	}

	var successAmount float64
	err = r.db.Model(&model.Refund{}).
		Where("apply_time BETWEEN ? AND ? AND status = ?", startTime, endTime, model.RefundStatusSuccess).
		Select("COALESCE(SUM(actual_refund_amount), 0)").
		Scan(&successAmount).Error
	if err != nil {
		return nil, err
	}

	return map[string]interface{}{
		"total_count":      totals.TotalCount,
		"total_amount":     totals.TotalAmount,
		"pending_count":    countByStatus[model.RefundStatusPending],
		"approved_count":   countByStatus[model.RefundStatusApproved],
		"rejected_count":   countByStatus[model.RefundStatusRejected],
		"processing_count": countByStatus[model.RefundStatusProcessing],
		"success_count":    countByStatus[model.RefundStatusSuccess],
		"failed_count":     countByStatus[model.RefundStatusFailed],
		"success_amount":   successAmount,
	}, nil
}

// GetRefundsByStatus lists all refunds with the given status, newest first,
// with Order and RefundItems preloaded. The result is not paginated.
func (r *RefundRepository) GetRefundsByStatus(status int) ([]model.Refund, error) {
	var refunds []model.Refund
	err := r.db.Preload("Order").Preload("RefundItems").
		Where("status = ?", status).
		Order("created_at DESC").
		Find(&refunds).Error
	return refunds, err
}

// GetRefundTrends returns per-day refund counts and refund_amount sums for
// refunds applied within the last `days` days (relative to CURDATE()),
// ordered by date ascending. Each entry has the keys "date", "count" and
// "amount".
func (r *RefundRepository) GetRefundTrends(days int) ([]map[string]interface{}, error) {
	var trends []struct {
		Date   string  `json:"date"`
		Count  int64   `json:"count"`
		Amount float64 `json:"amount"`
	}

	// days is an int formatted with %d, so it cannot inject SQL. The table name
	// is hard-coded; NOTE(review): assumes model.Refund maps to ai_refunds —
	// confirm. DATE_SUB/CURDATE are MySQL functions.
	query := fmt.Sprintf(`
		SELECT 
			DATE(apply_time) as date,
			COUNT(*) as count,
			COALESCE(SUM(refund_amount), 0) as amount
		FROM ai_refunds 
		WHERE apply_time >= DATE_SUB(CURDATE(), INTERVAL %d DAY)
		GROUP BY DATE(apply_time)
		ORDER BY date ASC
	`, days)

	if err := r.db.Raw(query).Scan(&trends).Error; err != nil {
		return nil, err
	}

	result := make([]map[string]interface{}, 0, len(trends))
	for _, trend := range trends {
		result = append(result, map[string]interface{}{
			"date":   trend.Date,
			"count":  trend.Count,
			"amount": trend.Amount,
		})
	}
	return result, nil
}