init
This commit is contained in:
102
server/internal/repository/admin.go
Normal file
102
server/internal/repository/admin.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"dianshang/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AdminRepository 管理员仓库
|
||||
type AdminRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewAdminRepository 创建管理员仓库
|
||||
func NewAdminRepository(db *gorm.DB) *AdminRepository {
|
||||
return &AdminRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建管理员
|
||||
func (r *AdminRepository) Create(admin *model.AdminUser) error {
|
||||
return r.db.Create(admin).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取管理员
|
||||
func (r *AdminRepository) GetByID(id uint) (*model.AdminUser, error) {
|
||||
var admin model.AdminUser
|
||||
err := r.db.First(&admin, id).Error
|
||||
return &admin, err
|
||||
}
|
||||
|
||||
// GetByIDWithRole 根据ID获取管理员(包含角色信息)
|
||||
func (r *AdminRepository) GetByIDWithRole(id uint) (*model.AdminUser, error) {
|
||||
var admin model.AdminUser
|
||||
err := r.db.Preload("Role").First(&admin, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 不返回密码
|
||||
admin.Password = ""
|
||||
return &admin, nil
|
||||
}
|
||||
|
||||
// GetByUsername 根据用户名获取管理员
|
||||
func (r *AdminRepository) GetByUsername(username string) (*model.AdminUser, error) {
|
||||
var admin model.AdminUser
|
||||
err := r.db.Where("username = ?", username).First(&admin).Error
|
||||
return &admin, err
|
||||
}
|
||||
|
||||
// GetList 获取管理员列表
|
||||
func (r *AdminRepository) GetList(page, pageSize int, keyword string) ([]model.AdminUser, int64, error) {
|
||||
var admins []model.AdminUser
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.AdminUser{}).Preload("Role")
|
||||
|
||||
// 关键词搜索
|
||||
if keyword != "" {
|
||||
query = query.Where("username LIKE ? OR nickname LIKE ? OR email LIKE ?",
|
||||
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&admins).Error
|
||||
|
||||
// 清除密码字段
|
||||
for i := range admins {
|
||||
admins[i].Password = ""
|
||||
}
|
||||
|
||||
return admins, total, err
|
||||
}
|
||||
|
||||
// Update 更新管理员
|
||||
func (r *AdminRepository) Update(id uint, updates map[string]interface{}) error {
|
||||
return r.db.Model(&model.AdminUser{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// Delete 删除管理员(软删除)
|
||||
func (r *AdminRepository) Delete(id uint) error {
|
||||
return r.db.Delete(&model.AdminUser{}, id).Error
|
||||
}
|
||||
|
||||
// GetAdminCount 获取管理员总数
|
||||
func (r *AdminRepository) GetAdminCount() (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.AdminUser{}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetActiveAdminCount 获取活跃管理员总数
|
||||
func (r *AdminRepository) GetActiveAdminCount() (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.AdminUser{}).Where("status = ?", 1).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
178
server/internal/repository/banner.go
Normal file
178
server/internal/repository/banner.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"dianshang/internal/model"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// BannerRepository 轮播图仓储
|
||||
type BannerRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewBannerRepository 创建轮播图仓储
|
||||
func NewBannerRepository(db *gorm.DB) *BannerRepository {
|
||||
return &BannerRepository{db: db}
|
||||
}
|
||||
|
||||
// GetActiveBanners 获取有效的轮播图
|
||||
func (r *BannerRepository) GetActiveBanners() ([]model.Banner, error) {
|
||||
var banners []model.Banner
|
||||
err := r.db.Where("status = ?", 1).Order("sort ASC").Find(&banners).Error
|
||||
return banners, err
|
||||
}
|
||||
|
||||
// GetBannerList 获取轮播图列表(分页)
|
||||
func (r *BannerRepository) GetBannerList(page, pageSize int, status *int) ([]model.Banner, int64, error) {
|
||||
var banners []model.Banner
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Banner{})
|
||||
|
||||
// 状态筛选
|
||||
if status != nil {
|
||||
query = query.Where("status = ?", *status)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err := query.Order("sort ASC, created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&banners).Error
|
||||
|
||||
return banners, total, err
|
||||
}
|
||||
|
||||
// GetBannerByID 根据ID获取轮播图
|
||||
func (r *BannerRepository) GetBannerByID(id uint) (*model.Banner, error) {
|
||||
var banner model.Banner
|
||||
err := r.db.First(&banner, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &banner, nil
|
||||
}
|
||||
|
||||
// CreateBanner 创建轮播图
|
||||
func (r *BannerRepository) CreateBanner(banner *model.Banner) error {
|
||||
return r.db.Create(banner).Error
|
||||
}
|
||||
|
||||
// UpdateBanner 更新轮播图
|
||||
func (r *BannerRepository) UpdateBanner(banner *model.Banner) error {
|
||||
return r.db.Save(banner).Error
|
||||
}
|
||||
|
||||
// DeleteBanner 删除轮播图
|
||||
func (r *BannerRepository) DeleteBanner(id uint) error {
|
||||
return r.db.Delete(&model.Banner{}, id).Error
|
||||
}
|
||||
|
||||
// BatchDeleteBanners 批量删除轮播图
|
||||
func (r *BannerRepository) BatchDeleteBanners(ids []uint) error {
|
||||
return r.db.Delete(&model.Banner{}, ids).Error
|
||||
}
|
||||
|
||||
// UpdateBannerStatus 更新轮播图状态
|
||||
func (r *BannerRepository) UpdateBannerStatus(id uint, status int) error {
|
||||
return r.db.Model(&model.Banner{}).Where("id = ?", id).Update("status", status).Error
|
||||
}
|
||||
|
||||
// BatchUpdateBannerStatus 批量更新轮播图状态
|
||||
func (r *BannerRepository) BatchUpdateBannerStatus(ids []uint, status int) error {
|
||||
return r.db.Model(&model.Banner{}).Where("id IN ?", ids).Update("status", status).Error
|
||||
}
|
||||
|
||||
// UpdateBannerSort 更新轮播图排序
|
||||
func (r *BannerRepository) UpdateBannerSort(id uint, sort int) error {
|
||||
return r.db.Model(&model.Banner{}).Where("id = ?", id).Update("sort", sort).Error
|
||||
}
|
||||
|
||||
// BatchUpdateBannerSort 批量更新轮播图排序
|
||||
func (r *BannerRepository) BatchUpdateBannerSort(sortData []map[string]interface{}) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
for _, data := range sortData {
|
||||
if err := tx.Model(&model.Banner{}).
|
||||
Where("id = ?", data["id"]).
|
||||
Update("sort", data["sort"]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetBannersByDateRange 根据日期范围获取轮播图
|
||||
func (r *BannerRepository) GetBannersByDateRange(startDate, endDate time.Time) ([]model.Banner, error) {
|
||||
var banners []model.Banner
|
||||
err := r.db.Where("created_at BETWEEN ? AND ?", startDate, endDate).
|
||||
Order("sort ASC, created_at DESC").
|
||||
Find(&banners).Error
|
||||
return banners, err
|
||||
}
|
||||
|
||||
// GetBannersByStatus 根据状态获取轮播图
|
||||
func (r *BannerRepository) GetBannersByStatus(status int) ([]model.Banner, error) {
|
||||
var banners []model.Banner
|
||||
err := r.db.Where("status = ?", status).
|
||||
Order("sort ASC, created_at DESC").
|
||||
Find(&banners).Error
|
||||
return banners, err
|
||||
}
|
||||
|
||||
// GetBannerCount 获取轮播图总数
|
||||
func (r *BannerRepository) GetBannerCount() (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.Banner{}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetBannerCountByStatus 根据状态获取轮播图数量
|
||||
func (r *BannerRepository) GetBannerCountByStatus(status int) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.Banner{}).Where("status = ?", status).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetMaxSort 获取最大排序值
|
||||
func (r *BannerRepository) GetMaxSort() (int, error) {
|
||||
var maxSort int
|
||||
err := r.db.Model(&model.Banner{}).Select("COALESCE(MAX(sort), 0)").Scan(&maxSort).Error
|
||||
return maxSort, err
|
||||
}
|
||||
|
||||
// CheckBannerExists 检查轮播图是否存在
|
||||
func (r *BannerRepository) CheckBannerExists(id uint) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.Banner{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// GetExpiredBanners 获取过期的轮播图
|
||||
func (r *BannerRepository) GetExpiredBanners() ([]model.Banner, error) {
|
||||
var banners []model.Banner
|
||||
now := time.Now()
|
||||
err := r.db.Where("end_time IS NOT NULL AND end_time < ? AND status = 1", now).
|
||||
Find(&banners).Error
|
||||
return banners, err
|
||||
}
|
||||
|
||||
// GetActiveBannersWithTimeRange 获取在时间范围内有效的轮播图
|
||||
func (r *BannerRepository) GetActiveBannersWithTimeRange() ([]model.Banner, error) {
|
||||
var banners []model.Banner
|
||||
now := time.Now()
|
||||
err := r.db.Where("status = 1").
|
||||
Where("(start_time IS NULL OR start_time <= ?)"+
|
||||
" AND (end_time IS NULL OR end_time >= ?)", now, now).
|
||||
Order("sort ASC").
|
||||
Find(&banners).Error
|
||||
return banners, err
|
||||
}
|
||||
289
server/internal/repository/comment.go
Normal file
289
server/internal/repository/comment.go
Normal file
@@ -0,0 +1,289 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"dianshang/internal/model"
|
||||
)
|
||||
|
||||
type CommentRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewCommentRepository(db *gorm.DB) *CommentRepository {
|
||||
return &CommentRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建评论
|
||||
func (r *CommentRepository) Create(comment *model.Comment) error {
|
||||
return r.db.Create(comment).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取评论
|
||||
func (r *CommentRepository) GetByID(id uint) (*model.Comment, error) {
|
||||
var comment model.Comment
|
||||
err := r.db.Preload("User").Preload("Product").Preload("Order").Preload("Replies.User").First(&comment, id).Error
|
||||
return &comment, err
|
||||
}
|
||||
|
||||
// GetByProductID 根据商品ID获取评论列表
|
||||
func (r *CommentRepository) GetByProductID(productID uint, offset, limit int, rating int) ([]model.Comment, int64, error) {
|
||||
var comments []model.Comment
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Comment{}).Where("product_id = ? AND status = ?", productID, 1)
|
||||
|
||||
// 按评分筛选
|
||||
if rating > 0 && rating <= 5 {
|
||||
query = query.Where("rating = ?", rating)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取评论列表
|
||||
err := query.Preload("User").Preload("Replies", "status = ?", 1).Preload("Replies.User").
|
||||
Order("created_at DESC").Offset(offset).Limit(limit).Find(&comments).Error
|
||||
|
||||
return comments, total, err
|
||||
}
|
||||
|
||||
// GetByUserID 根据用户ID获取评论列表
|
||||
func (r *CommentRepository) GetByUserID(userID uint, offset, limit int) ([]model.Comment, int64, error) {
|
||||
var comments []model.Comment
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Comment{}).Where("user_id = ? AND status = ?", userID, 1)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取评论列表
|
||||
err := query.Preload("Product").Preload("Replies", "status = ?", 1).Preload("Replies.User").
|
||||
Order("created_at DESC").Offset(offset).Limit(limit).Find(&comments).Error
|
||||
|
||||
return comments, total, err
|
||||
}
|
||||
|
||||
// GetByOrderItemID 根据订单项ID获取评论
|
||||
func (r *CommentRepository) GetByOrderItemID(orderItemID uint) (*model.Comment, error) {
|
||||
var comment model.Comment
|
||||
err := r.db.Where("order_item_id = ? AND status = ?", orderItemID, 1).First(&comment).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &comment, nil
|
||||
}
|
||||
|
||||
// Update 更新评论
|
||||
func (r *CommentRepository) Update(comment *model.Comment) error {
|
||||
return r.db.Save(comment).Error
|
||||
}
|
||||
|
||||
// Delete 删除评论(软删除)
|
||||
func (r *CommentRepository) Delete(id uint) error {
|
||||
return r.db.Model(&model.Comment{}).Where("id = ?", id).Update("status", 3).Error
|
||||
}
|
||||
|
||||
// GetStats 获取商品评论统计
|
||||
func (r *CommentRepository) GetStats(productID uint) (*model.CommentStats, error) {
|
||||
var stats model.CommentStats
|
||||
stats.ProductID = productID
|
||||
|
||||
// 获取总评论数和平均评分
|
||||
err := r.db.Model(&model.Comment{}).
|
||||
Select("COUNT(*) as total_count, AVG(rating) as average_rating").
|
||||
Where("product_id = ? AND status = ?", productID, 1).
|
||||
Scan(&stats).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取各星级评论数
|
||||
var ratingCounts []struct {
|
||||
Rating int `json:"rating"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
err = r.db.Model(&model.Comment{}).
|
||||
Select("rating, COUNT(*) as count").
|
||||
Where("product_id = ? AND status = ?", productID, 1).
|
||||
Group("rating").
|
||||
Scan(&ratingCounts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 填充各星级评论数
|
||||
for _, rc := range ratingCounts {
|
||||
switch rc.Rating {
|
||||
case 1:
|
||||
stats.Rating1Count = rc.Count
|
||||
case 2:
|
||||
stats.Rating2Count = rc.Count
|
||||
case 3:
|
||||
stats.Rating3Count = rc.Count
|
||||
case 4:
|
||||
stats.Rating4Count = rc.Count
|
||||
case 5:
|
||||
stats.Rating5Count = rc.Count
|
||||
}
|
||||
}
|
||||
|
||||
// 获取带图评论数
|
||||
var hasImagesCount int64
|
||||
err = r.db.Model(&model.Comment{}).
|
||||
Where("product_id = ? AND status = ? AND images != '' AND images != '[]'", productID, 1).
|
||||
Count(&hasImagesCount).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.HasImagesCount = int(hasImagesCount)
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetList 获取评论列表(管理端使用)
|
||||
func (r *CommentRepository) GetList(offset, limit int, conditions map[string]interface{}) ([]model.Comment, int64, error) {
|
||||
var comments []model.Comment
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Comment{})
|
||||
|
||||
// 添加查询条件
|
||||
for key, value := range conditions {
|
||||
switch key {
|
||||
case "product_id":
|
||||
query = query.Where("product_id = ?", value)
|
||||
case "user_id":
|
||||
query = query.Where("user_id = ?", value)
|
||||
case "rating":
|
||||
query = query.Where("rating = ?", value)
|
||||
case "status":
|
||||
query = query.Where("status = ?", value)
|
||||
case "keyword":
|
||||
query = query.Where("content LIKE ?", "%"+value.(string)+"%")
|
||||
}
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取评论列表
|
||||
err := query.Preload("User").Preload("Product").Preload("Order").
|
||||
Order("created_at DESC").Offset(offset).Limit(limit).Find(&comments).Error
|
||||
|
||||
return comments, total, err
|
||||
}
|
||||
|
||||
// CreateReply 创建评论回复
|
||||
func (r *CommentRepository) CreateReply(reply *model.CommentReply) error {
|
||||
tx := r.db.Begin()
|
||||
|
||||
// 创建回复
|
||||
if err := tx.Create(reply).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新评论回复数量
|
||||
if err := tx.Model(&model.Comment{}).Where("id = ?", reply.CommentID).
|
||||
UpdateColumn("reply_count", gorm.Expr("reply_count + ?", 1)).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
// GetReplies 获取评论回复列表
|
||||
func (r *CommentRepository) GetReplies(commentID uint) ([]model.CommentReply, error) {
|
||||
var replies []model.CommentReply
|
||||
err := r.db.Where("comment_id = ? AND status = ?", commentID, 1).
|
||||
Preload("User").Order("created_at ASC").Find(&replies).Error
|
||||
return replies, err
|
||||
}
|
||||
|
||||
// LikeComment 点赞评论
|
||||
func (r *CommentRepository) LikeComment(commentID, userID uint) error {
|
||||
tx := r.db.Begin()
|
||||
|
||||
// 检查是否已点赞
|
||||
var count int64
|
||||
if err := tx.Model(&model.CommentLike{}).Where("comment_id = ? AND user_id = ?", commentID, userID).Count(&count).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("已经点赞过了")
|
||||
}
|
||||
|
||||
// 创建点赞记录
|
||||
like := &model.CommentLike{
|
||||
CommentID: commentID,
|
||||
UserID: userID,
|
||||
}
|
||||
if err := tx.Create(like).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新评论点赞数量
|
||||
if err := tx.Model(&model.Comment{}).Where("id = ?", commentID).
|
||||
UpdateColumn("like_count", gorm.Expr("like_count + ?", 1)).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
// UnlikeComment 取消点赞评论
|
||||
func (r *CommentRepository) UnlikeComment(commentID, userID uint) error {
|
||||
tx := r.db.Begin()
|
||||
|
||||
// 删除点赞记录
|
||||
if err := tx.Where("comment_id = ? AND user_id = ?", commentID, userID).Delete(&model.CommentLike{}).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新评论点赞数量
|
||||
if err := tx.Model(&model.Comment{}).Where("id = ?", commentID).
|
||||
UpdateColumn("like_count", gorm.Expr("like_count - ?", 1)).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
// UpdateProductStats 更新商品评论统计
|
||||
func (r *CommentRepository) UpdateProductStats(productID uint) error {
|
||||
var stats struct {
|
||||
Count int `json:"count"`
|
||||
Rating float64 `json:"rating"`
|
||||
}
|
||||
|
||||
err := r.db.Model(&model.Comment{}).
|
||||
Select("COUNT(*) as count, AVG(rating) as rating").
|
||||
Where("product_id = ? AND status = ?", productID, 1).
|
||||
Scan(&stats).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return r.db.Model(&model.Product{}).Where("id = ?", productID).
|
||||
Updates(map[string]interface{}{
|
||||
"comment_count": stats.Count,
|
||||
"average_rating": stats.Rating,
|
||||
}).Error
|
||||
}
|
||||
123
server/internal/repository/coupon.go
Normal file
123
server/internal/repository/coupon.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"dianshang/internal/model"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CouponRepository 优惠券仓储
|
||||
type CouponRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewCouponRepository 创建优惠券仓储
|
||||
func NewCouponRepository(db *gorm.DB) *CouponRepository {
|
||||
return &CouponRepository{db: db}
|
||||
}
|
||||
|
||||
// GetAvailableCoupons 获取可用优惠券列表
|
||||
func (r *CouponRepository) GetAvailableCoupons() ([]model.Coupon, error) {
|
||||
var coupons []model.Coupon
|
||||
now := time.Now()
|
||||
err := r.db.Where("status = ? AND start_time <= ? AND end_time >= ?", 1, now, now).
|
||||
Where("total_count = 0 OR used_count < total_count").
|
||||
Order("created_at DESC").
|
||||
Find(&coupons).Error
|
||||
return coupons, err
|
||||
}
|
||||
|
||||
// GetUserCoupons 获取用户优惠券
|
||||
func (r *CouponRepository) GetUserCoupons(userID uint, status int) ([]model.UserCoupon, error) {
|
||||
var userCoupons []model.UserCoupon
|
||||
query := r.db.Preload("Coupon").Where("user_id = ?", userID)
|
||||
|
||||
if status > 0 {
|
||||
// API状态值到数据库状态值的映射
|
||||
// API: 1=未使用,2=已使用,3=已过期
|
||||
// DB: 0=未使用,1=已使用,2=已过期
|
||||
dbStatus := status - 1
|
||||
query = query.Where("status = ?", dbStatus)
|
||||
}
|
||||
|
||||
err := query.Order("created_at DESC").Find(&userCoupons).Error
|
||||
return userCoupons, err
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取优惠券
|
||||
func (r *CouponRepository) GetByID(id uint) (*model.Coupon, error) {
|
||||
var coupon model.Coupon
|
||||
err := r.db.First(&coupon, id).Error
|
||||
return &coupon, err
|
||||
}
|
||||
|
||||
// CheckUserCouponExists 检查用户是否已领取优惠券
|
||||
func (r *CouponRepository) CheckUserCouponExists(userID, couponID uint) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.UserCoupon{}).
|
||||
Where("user_id = ? AND coupon_id = ?", userID, couponID).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// GetUserCouponCount 获取用户领取某优惠券的数量
|
||||
func (r *CouponRepository) GetUserCouponCount(userID, couponID uint) (int, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.UserCoupon{}).
|
||||
Where("user_id = ? AND coupon_id = ?", userID, couponID).
|
||||
Count(&count).Error
|
||||
return int(count), err
|
||||
}
|
||||
|
||||
// CreateUserCoupon 创建用户优惠券记录
|
||||
func (r *CouponRepository) CreateUserCoupon(userCoupon *model.UserCoupon) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 创建用户优惠券记录
|
||||
if err := tx.Create(userCoupon).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新优惠券使用数量
|
||||
return tx.Model(&model.Coupon{}).
|
||||
Where("id = ?", userCoupon.CouponID).
|
||||
UpdateColumn("used_count", gorm.Expr("used_count + ?", 1)).Error
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserCouponByID 根据ID获取用户优惠券
|
||||
func (r *CouponRepository) GetUserCouponByID(id uint) (*model.UserCoupon, error) {
|
||||
var userCoupon model.UserCoupon
|
||||
err := r.db.Preload("Coupon").First(&userCoupon, id).Error
|
||||
return &userCoupon, err
|
||||
}
|
||||
|
||||
// GetUserCouponByOrderID 根据订单ID获取用户优惠券
|
||||
func (r *CouponRepository) GetUserCouponByOrderID(orderID uint) (*model.UserCoupon, error) {
|
||||
var userCoupon model.UserCoupon
|
||||
err := r.db.Preload("Coupon").Where("order_id = ?", orderID).First(&userCoupon).Error
|
||||
return &userCoupon, err
|
||||
}
|
||||
|
||||
// UseCoupon 使用优惠券
|
||||
func (r *CouponRepository) UseCoupon(userCouponID, orderID uint) error {
|
||||
now := time.Now()
|
||||
return r.db.Model(&model.UserCoupon{}).
|
||||
Where("id = ?", userCouponID).
|
||||
Updates(map[string]interface{}{
|
||||
"status": 1, // 已使用
|
||||
"order_id": orderID,
|
||||
"used_time": &now,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// RestoreCoupon 恢复优惠券(取消订单时使用)
|
||||
func (r *CouponRepository) RestoreCoupon(userCouponID uint) error {
|
||||
return r.db.Model(&model.UserCoupon{}).
|
||||
Where("id = ?", userCouponID).
|
||||
Updates(map[string]interface{}{
|
||||
"status": 0, // 恢复为未使用
|
||||
"order_id": nil, // 清除订单关联
|
||||
"used_time": nil, // 清除使用时间
|
||||
}).Error
|
||||
}
|
||||
621
server/internal/repository/order.go
Normal file
621
server/internal/repository/order.go
Normal file
@@ -0,0 +1,621 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"dianshang/internal/model"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// OrderRepository 订单仓储
|
||||
type OrderRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewOrderRepository 创建订单仓储
|
||||
func NewOrderRepository(db *gorm.DB) *OrderRepository {
|
||||
return &OrderRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建订单
|
||||
func (r *OrderRepository) Create(order *model.Order) error {
|
||||
return r.db.Create(order).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取订单
|
||||
func (r *OrderRepository) GetByID(id uint) (*model.Order, error) {
|
||||
var order model.Order
|
||||
fmt.Printf("🔍 [OrderRepository.GetByID] 开始查询订单 - 查询ID: %d\n", id)
|
||||
|
||||
// 启用 SQL 调试
|
||||
db := r.db.Debug()
|
||||
err := db.Preload("OrderItems").Preload("OrderItems.Product").Preload("OrderItems.SKU").Preload("Store").Where("id = ?", id).First(&order).Error
|
||||
|
||||
if err == nil {
|
||||
fmt.Printf("✅ [OrderRepository.GetByID] 查询成功 - ID: %d, 状态: %d, 订单号: %s\n", order.ID, order.Status, order.OrderNo)
|
||||
|
||||
// 直接查询验证
|
||||
var directOrder model.Order
|
||||
r.db.Raw("SELECT id, order_no, status FROM ai_orders WHERE id = ?", id).Scan(&directOrder)
|
||||
fmt.Printf("🔍 [直接SQL查询] ID: %d, 状态: %d, 订单号: %s\n", directOrder.ID, directOrder.Status, directOrder.OrderNo)
|
||||
} else {
|
||||
fmt.Printf("❌ [OrderRepository.GetByID] 查询失败 - ID: %d, 错误: %v\n", id, err)
|
||||
}
|
||||
|
||||
return &order, err
|
||||
}
|
||||
|
||||
// GetByOrderNo 根据订单号获取订单
|
||||
func (r *OrderRepository) GetByOrderNo(orderNo string) (*model.Order, error) {
|
||||
var order model.Order
|
||||
err := r.db.Preload("OrderItems").Preload("OrderItems.Product").Preload("OrderItems.SKU").Preload("Store").Where("order_no = ?", orderNo).First(&order).Error
|
||||
return &order, err
|
||||
}
|
||||
|
||||
// GetUserOrders 获取用户订单列表
|
||||
func (r *OrderRepository) GetUserOrders(userID uint, status int, offset, limit int) ([]model.Order, int64, error) {
|
||||
var orders []model.Order
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Order{}).Where("user_id = ?", userID)
|
||||
|
||||
if status > 0 {
|
||||
// 前端状态映射:
|
||||
// 前端status=3表示"待发货",对应数据库status=2(已付款/待发货)和status=3(待发货)
|
||||
if status == 3 {
|
||||
query = query.Where("status IN ?", []int{2, 3})
|
||||
} else {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表,预加载订单项和店铺信息
|
||||
err := query.Preload("OrderItems").Preload("OrderItems.Product").Preload("OrderItems.Product.SKUs", "status = ?", 1).Preload("OrderItems.SKU").Preload("Store").
|
||||
Offset(offset).Limit(limit).Order("created_at DESC").Find(&orders).Error
|
||||
return orders, total, err
|
||||
}
|
||||
|
||||
// GetUserOrderStatistics 获取用户订单统计
|
||||
func (r *OrderRepository) GetUserOrderStatistics(userID uint) (map[string]interface{}, error) {
|
||||
var result struct {
|
||||
OrderCount int64 `json:"order_count"`
|
||||
TotalAmount float64 `json:"total_amount"`
|
||||
}
|
||||
|
||||
// 查询用户的订单数量和总消费金额
|
||||
err := r.db.Model(&model.Order{}).
|
||||
Select("COUNT(*) as order_count, COALESCE(SUM(total_amount), 0) as total_amount").
|
||||
Where("user_id = ? AND status != ?", userID, 0). // 排除已取消的订单
|
||||
Scan(&result).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"order_count": result.OrderCount,
|
||||
"total_amount": result.TotalAmount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Update 更新订单
|
||||
func (r *OrderRepository) Update(id uint, updates map[string]interface{}) error {
|
||||
fmt.Printf("🔄 [OrderRepository.Update] 执行数据库更新,订单ID: %d\n", id)
|
||||
fmt.Printf("🔄 [OrderRepository.Update] 更新字段: %+v\n", updates)
|
||||
|
||||
result := r.db.Model(&model.Order{}).Where("id = ?", id).Updates(updates)
|
||||
|
||||
fmt.Printf("🔄 [OrderRepository.Update] 影响行数: %d\n", result.RowsAffected)
|
||||
fmt.Printf("🔄 [OrderRepository.Update] 错误信息: %v\n", result.Error)
|
||||
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// UpdateByID 根据ID更新订单记录
|
||||
func (r *OrderRepository) UpdateByID(orderID uint, updates map[string]interface{}) error {
|
||||
return r.Update(orderID, updates)
|
||||
}
|
||||
|
||||
// UpdateByOrderNo 根据订单号更新订单
|
||||
func (r *OrderRepository) UpdateByOrderNo(orderNo string, updates map[string]interface{}) error {
|
||||
fmt.Printf("🔄 [UpdateByOrderNo] 执行数据库更新,订单号: %s\n", orderNo)
|
||||
fmt.Printf("🔄 [UpdateByOrderNo] 更新字段: %+v\n", updates)
|
||||
|
||||
result := r.db.Model(&model.Order{}).Where("order_no = ?", orderNo).Updates(updates)
|
||||
|
||||
fmt.Printf("🔄 [UpdateByOrderNo] 影响行数: %d\n", result.RowsAffected)
|
||||
fmt.Printf("🔄 [UpdateByOrderNo] 错误信息: %v\n", result.Error)
|
||||
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// Delete 删除订单(软删除)
|
||||
func (r *OrderRepository) Delete(id uint) error {
|
||||
return r.db.Where("id = ?", id).Delete(&model.Order{}).Error
|
||||
}
|
||||
|
||||
// GetList 获取订单列表(管理员)
|
||||
func (r *OrderRepository) GetList(offset, limit int, conditions map[string]interface{}) ([]model.Order, int64, error) {
|
||||
var orders []model.Order
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Order{})
|
||||
|
||||
// 添加查询条件
|
||||
for key, value := range conditions {
|
||||
switch key {
|
||||
case "status":
|
||||
query = query.Where("status = ?", value)
|
||||
case "user_id":
|
||||
query = query.Where("user_id = ?", value)
|
||||
case "order_no":
|
||||
query = query.Where("order_no LIKE ?", "%"+value.(string)+"%")
|
||||
case "start_date":
|
||||
query = query.Where("created_at >= ?", value)
|
||||
case "end_date":
|
||||
query = query.Where("created_at <= ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表,预加载订单项、商品信息、SKU信息、用户信息和店铺信息
|
||||
err := query.Preload("OrderItems").Preload("OrderItems.Product").Preload("OrderItems.SKU").Preload("User").Preload("Store").
|
||||
Offset(offset).Limit(limit).Order("created_at DESC").Find(&orders).Error
|
||||
return orders, total, err
|
||||
}
|
||||
|
||||
// CreateOrderItem 创建订单项
|
||||
func (r *OrderRepository) CreateOrderItem(item *model.OrderItem) error {
|
||||
return r.db.Create(item).Error
|
||||
}
|
||||
|
||||
// GetOrderItems 获取订单项列表
|
||||
func (r *OrderRepository) GetOrderItems(orderID uint) ([]model.OrderItem, error) {
|
||||
var items []model.OrderItem
|
||||
err := r.db.Preload("Product").Where("order_id = ?", orderID).Find(&items).Error
|
||||
return items, err
|
||||
}
|
||||
|
||||
// UpdateOrderItem 更新订单项
|
||||
func (r *OrderRepository) UpdateOrderItem(id uint, updates map[string]interface{}) error {
|
||||
return r.db.Model(&model.OrderItem{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// GetCart 获取购物车
|
||||
func (r *OrderRepository) GetCart(userID uint) ([]model.Cart, error) {
|
||||
var cart []model.Cart
|
||||
err := r.db.Preload("Product").Preload("Product.SKUs", "status = ?", 1).Preload("SKU").Where("user_id = ?", userID).Find(&cart).Error
|
||||
return cart, err
|
||||
}
|
||||
|
||||
// GetCartItem 获取购物车项
|
||||
func (r *OrderRepository) GetCartItem(userID, productID uint) (*model.Cart, error) {
|
||||
var cart model.Cart
|
||||
err := r.db.Where("user_id = ? AND product_id = ?", userID, productID).First(&cart).Error
|
||||
return &cart, err
|
||||
}
|
||||
|
||||
// GetCartItemBySKU 根据SKU获取购物车项
|
||||
func (r *OrderRepository) GetCartItemBySKU(userID, productID, skuID uint) (*model.Cart, error) {
|
||||
var cart model.Cart
|
||||
query := r.db.Where("user_id = ? AND product_id = ?", userID, productID)
|
||||
|
||||
fmt.Printf("🔍 [GetCartItemBySKU] 查询条件 - 用户ID: %d, 产品ID: %d, SKU ID: %d\n",
|
||||
userID, productID, skuID)
|
||||
|
||||
if skuID > 0 {
|
||||
// 查找指定的SKU ID
|
||||
query = query.Where("sk_uid = ?", skuID)
|
||||
fmt.Printf("🔍 [GetCartItemBySKU] 查找指定SKU: %d\n", skuID)
|
||||
} else {
|
||||
// 查找没有SKU的商品(sk_uid为NULL或0)
|
||||
query = query.Where("sk_uid IS NULL OR sk_uid = 0")
|
||||
fmt.Printf("🔍 [GetCartItemBySKU] 查找无SKU商品 (sk_uid IS NULL OR sk_uid = 0)\n")
|
||||
}
|
||||
|
||||
err := query.First(&cart).Error
|
||||
if err != nil {
|
||||
fmt.Printf("❌ [GetCartItemBySKU] 查询失败: %v\n", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fmt.Printf("✅ [GetCartItemBySKU] 找到购物车项 - ID: %d, SKU ID: %v\n",
|
||||
cart.ID, cart.SKUID)
|
||||
return &cart, nil
|
||||
}
|
||||
|
||||
// AddToCart 添加到购物车
|
||||
func (r *OrderRepository) AddToCart(cart *model.Cart) error {
|
||||
return r.db.Create(cart).Error
|
||||
}
|
||||
|
||||
// UpdateCartItem 更新购物车项
|
||||
func (r *OrderRepository) UpdateCartItem(id uint, quantity int) error {
|
||||
return r.db.Model(&model.Cart{}).Where("id = ?", id).Update("quantity", quantity).Error
|
||||
}
|
||||
|
||||
// RemoveFromCart 从购物车移除
|
||||
func (r *OrderRepository) RemoveFromCart(userID, productID uint) error {
|
||||
return r.db.Where("user_id = ? AND product_id = ?", userID, productID).Delete(&model.Cart{}).Error
|
||||
}
|
||||
|
||||
// RemoveFromCartBySKU 根据SKU从购物车移除
|
||||
func (r *OrderRepository) RemoveFromCartBySKU(userID, productID, skuID uint) error {
|
||||
query := r.db.Where("user_id = ? AND product_id = ?", userID, productID)
|
||||
if skuID > 0 {
|
||||
// 删除指定的SKU ID,使用sk_uid字段
|
||||
query = query.Where("sk_uid = ?", skuID)
|
||||
} else {
|
||||
// 删除没有SKU的商品(sk_uid为NULL或0)
|
||||
query = query.Where("sk_uid IS NULL OR sk_uid = 0")
|
||||
}
|
||||
return query.Delete(&model.Cart{}).Error
|
||||
}
|
||||
|
||||
// ClearCart 清空购物车
|
||||
func (r *OrderRepository) ClearCart(userID uint) error {
|
||||
return r.db.Where("user_id = ?", userID).Delete(&model.Cart{}).Error
|
||||
}
|
||||
|
||||
// GetOrderStatistics 获取订单统计
|
||||
func (r *OrderRepository) GetOrderStatistics() (map[string]interface{}, error) {
|
||||
var result map[string]interface{} = make(map[string]interface{})
|
||||
|
||||
// 总订单数
|
||||
var totalOrders int64
|
||||
r.db.Model(&model.Order{}).Count(&totalOrders)
|
||||
result["total_orders"] = totalOrders
|
||||
|
||||
// 待付款订单数
|
||||
var pendingOrders int64
|
||||
r.db.Model(&model.Order{}).Where("status = ?", 1).Count(&pendingOrders)
|
||||
result["pending_orders"] = pendingOrders
|
||||
|
||||
// 待发货订单数(状态2和3都是待发货)
|
||||
var toShipOrders int64
|
||||
r.db.Model(&model.Order{}).Where("status IN (?)", []int{2, 3}).Count(&toShipOrders)
|
||||
result["to_ship_orders"] = toShipOrders
|
||||
|
||||
// 已发货订单数
|
||||
var shippedOrders int64
|
||||
r.db.Model(&model.Order{}).Where("status = ?", 4).Count(&shippedOrders)
|
||||
result["shipped_orders"] = shippedOrders
|
||||
|
||||
// 已完成订单数
|
||||
var completedOrders int64
|
||||
r.db.Model(&model.Order{}).Where("status = ?", 6).Count(&completedOrders)
|
||||
result["completed_orders"] = completedOrders
|
||||
|
||||
// 总销售额(包含待发货、已发货、待收货、已完成状态)
|
||||
var totalAmount float64
|
||||
r.db.Model(&model.Order{}).Where("status IN (?)", []int{2, 3, 4, 5, 6}).Select("COALESCE(SUM(total_amount), 0)").Scan(&totalAmount)
|
||||
result["total_amount"] = totalAmount
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetDailyOrderStatistics 获取每日订单统计
|
||||
func (r *OrderRepository) GetDailyOrderStatistics(days int) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
DATE(created_at) as date,
|
||||
COUNT(*) as order_count,
|
||||
COALESCE(SUM(total_amount), 0) as total_amount
|
||||
FROM ai_orders
|
||||
WHERE created_at >= DATE_SUB(CURDATE(), INTERVAL ? DAY)
|
||||
GROUP BY DATE(created_at)
|
||||
ORDER BY date DESC
|
||||
`
|
||||
|
||||
err := r.db.Raw(query, days).Scan(&results).Error
|
||||
return results, err
|
||||
}
|
||||
|
||||
// SelectCartItem 选择/取消选择购物车项
|
||||
func (r *OrderRepository) SelectCartItem(userID, cartID uint, selected bool) error {
|
||||
return r.db.Model(&model.Cart{}).
|
||||
Where("id = ? AND user_id = ?", cartID, userID).
|
||||
Update("selected", selected).Error
|
||||
}
|
||||
|
||||
// SelectAllCartItems 全选/取消全选购物车
|
||||
func (r *OrderRepository) SelectAllCartItems(userID uint, selected bool) error {
|
||||
return r.db.Model(&model.Cart{}).
|
||||
Where("user_id = ?", userID).
|
||||
Update("selected", selected).Error
|
||||
}
|
||||
|
||||
// UpdateOrderStatus 更新订单状态
|
||||
func (r *OrderRepository) UpdateOrderStatus(orderID uint, status int) error {
|
||||
updates := map[string]interface{}{
|
||||
"status": status,
|
||||
}
|
||||
|
||||
// 根据状态设置相应的时间字段
|
||||
now := time.Now()
|
||||
switch status {
|
||||
case 2: // 已支付
|
||||
updates["paid_at"] = now
|
||||
case 3: // 已发货
|
||||
updates["shipped_at"] = now
|
||||
case 4: // 已完成
|
||||
updates["completed_at"] = now
|
||||
case 5: // 已取消
|
||||
updates["cancelled_at"] = now
|
||||
case 6: // 已退款
|
||||
updates["refunded_at"] = now
|
||||
}
|
||||
|
||||
return r.db.Model(&model.Order{}).Where("id = ?", orderID).Updates(updates).Error
|
||||
}
|
||||
|
||||
// BatchUpdateOrderStatus 批量更新订单状态
|
||||
func (r *OrderRepository) BatchUpdateOrderStatus(orderIDs []uint, status int) error {
|
||||
updates := map[string]interface{}{
|
||||
"status": status,
|
||||
}
|
||||
|
||||
// 根据状态设置相应的时间字段
|
||||
now := time.Now()
|
||||
switch status {
|
||||
case 2: // 已支付
|
||||
updates["paid_at"] = now
|
||||
case 3: // 已发货
|
||||
updates["shipped_at"] = now
|
||||
case 4: // 已完成
|
||||
updates["completed_at"] = now
|
||||
case 5: // 已取消
|
||||
updates["cancelled_at"] = now
|
||||
case 6: // 已退款
|
||||
updates["refunded_at"] = now
|
||||
}
|
||||
|
||||
return r.db.Model(&model.Order{}).Where("id IN ?", orderIDs).Updates(updates).Error
|
||||
}
|
||||
|
||||
// GetOrdersByDateRange 根据日期范围获取订单
|
||||
func (r *OrderRepository) GetOrdersByDateRange(startDate, endDate string, status, offset, limit int) ([]*model.Order, int64, error) {
|
||||
var orders []*model.Order
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Order{})
|
||||
|
||||
if startDate != "" {
|
||||
query = query.Where("created_at >= ?", startDate+" 00:00:00")
|
||||
}
|
||||
if endDate != "" {
|
||||
query = query.Where("created_at <= ?", endDate+" 23:59:59")
|
||||
}
|
||||
if status > 0 {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
err := query.Offset(offset).Limit(limit).Order("created_at DESC").Find(&orders).Error
|
||||
return orders, total, err
|
||||
}
|
||||
|
||||
// GetOrdersForExport 获取用于导出的订单数据
|
||||
func (r *OrderRepository) GetOrdersForExport(conditions map[string]interface{}) ([]*model.Order, error) {
|
||||
var orders []*model.Order
|
||||
|
||||
query := r.db.Model(&model.Order{})
|
||||
|
||||
if startDate, ok := conditions["start_date"]; ok {
|
||||
query = query.Where("created_at >= ?", startDate.(string)+" 00:00:00")
|
||||
}
|
||||
if endDate, ok := conditions["end_date"]; ok {
|
||||
query = query.Where("created_at <= ?", endDate.(string)+" 23:59:59")
|
||||
}
|
||||
if status, ok := conditions["status"]; ok {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
if orderNo, ok := conditions["order_no"]; ok {
|
||||
query = query.Where("order_no LIKE ?", "%"+orderNo.(string)+"%")
|
||||
}
|
||||
|
||||
err := query.Order("created_at DESC").Find(&orders).Error
|
||||
return orders, err
|
||||
}
|
||||
|
||||
// GetOrderTrendData 获取订单趋势数据
|
||||
func (r *OrderRepository) GetOrderTrendData(days int) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
DATE(created_at) as date,
|
||||
COUNT(*) as order_count,
|
||||
SUM(total_amount) as total_amount,
|
||||
COUNT(CASE WHEN status = 6 THEN 1 END) as completed_count
|
||||
FROM ai_orders
|
||||
WHERE created_at >= DATE_SUB(NOW(), INTERVAL ? DAY)
|
||||
GROUP BY DATE(created_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
|
||||
err := r.db.Raw(query, days).Scan(&results).Error
|
||||
return results, err
|
||||
}
|
||||
|
||||
// GetPendingOrdersCount 获取待处理订单数量
|
||||
func (r *OrderRepository) GetPendingOrdersCount() (int64, error) {
|
||||
var count int64
|
||||
// 待处理订单包括:待支付(1)、待发货(2,3)、已发货(4)、待收货(5)
|
||||
err := r.db.Model(&model.Order{}).Where("status IN ?", []int{1, 2, 3, 4, 5}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// UpdateOrderLogistics 更新订单物流信息
|
||||
func (r *OrderRepository) UpdateOrderLogistics(orderID uint, logisticsCompany, trackingNumber string) error {
|
||||
updates := map[string]interface{}{
|
||||
"logistics_company": logisticsCompany,
|
||||
"tracking_number": trackingNumber,
|
||||
"shipped_at": time.Now(),
|
||||
"status": 3, // 已发货
|
||||
}
|
||||
|
||||
return r.db.Model(&model.Order{}).Where("id = ?", orderID).Updates(updates).Error
|
||||
}
|
||||
|
||||
// GetOrderByOrderNo 根据订单号获取订单
|
||||
func (r *OrderRepository) GetOrderByOrderNo(orderNo string) (*model.Order, error) {
|
||||
var order model.Order
|
||||
err := r.db.Where("order_no = ?", orderNo).First(&order).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
// GetOrderByWechatOutTradeNo 根据微信支付订单号获取订单
|
||||
func (r *OrderRepository) GetOrderByWechatOutTradeNo(wechatOutTradeNo string) (*model.Order, error) {
|
||||
var order model.Order
|
||||
err := r.db.Where("wechat_out_trade_no = ?", wechatOutTradeNo).First(&order).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
// GetOrderItemByID 根据ID获取订单项
|
||||
func (r *OrderRepository) GetOrderItemByID(id uint) (*model.OrderItem, error) {
|
||||
var orderItem model.OrderItem
|
||||
err := r.db.Preload("Product").Preload("Order").First(&orderItem, id).Error
|
||||
return &orderItem, err
|
||||
}
|
||||
|
||||
// SaveOrderItem 保存订单项
|
||||
func (r *OrderRepository) SaveOrderItem(orderItem *model.OrderItem) error {
|
||||
return r.db.Save(orderItem).Error
|
||||
}
|
||||
|
||||
// GetUncommentedOrderItems 获取用户未评论的订单项
|
||||
func (r *OrderRepository) GetUncommentedOrderItems(userID uint) ([]model.OrderItem, error) {
|
||||
var orderItems []model.OrderItem
|
||||
|
||||
// 查询已完成订单中未评论的订单项
|
||||
err := r.db.Joins("JOIN ai_orders ON order_items.order_id = ai_orders.id").
|
||||
Where("ai_orders.user_id = ? AND ai_orders.status = ? AND order_items.is_commented = ?",
|
||||
userID, model.OrderStatusCompleted, false).
|
||||
Preload("Product").Preload("Order").
|
||||
Find(&orderItems).Error
|
||||
|
||||
return orderItems, err
|
||||
}
|
||||
|
||||
// UpdateOrderRefund 更新订单退款信息
|
||||
func (r *OrderRepository) UpdateOrderRefund(orderID uint, refundAmount float64, refundReason string) error {
|
||||
updates := map[string]interface{}{
|
||||
"refund_amount": refundAmount,
|
||||
"refund_reason": refundReason,
|
||||
"refunded_at": time.Now(),
|
||||
"status": 6, // 已退款
|
||||
}
|
||||
|
||||
return r.db.Model(&model.Order{}).Where("id = ?", orderID).Updates(updates).Error
|
||||
}
|
||||
|
||||
// GetOrderStatisticsByStatus 根据状态获取订单统计
|
||||
func (r *OrderRepository) GetOrderStatisticsByStatus() (map[string]interface{}, error) {
|
||||
var results []struct {
|
||||
Status int `json:"status"`
|
||||
Count int64 `json:"count"`
|
||||
Amount float64 `json:"amount"`
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
status,
|
||||
COUNT(*) as count,
|
||||
COALESCE(SUM(total_amount), 0) as amount
|
||||
FROM ai_orders
|
||||
GROUP BY status
|
||||
`
|
||||
|
||||
if err := r.db.Raw(query).Scan(&results).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
statistics := make(map[string]interface{})
|
||||
for _, result := range results {
|
||||
statusKey := fmt.Sprintf("status_%d", result.Status)
|
||||
statistics[statusKey] = map[string]interface{}{
|
||||
"count": result.Count,
|
||||
"amount": result.Amount,
|
||||
}
|
||||
}
|
||||
|
||||
return statistics, nil
|
||||
}
|
||||
|
||||
// GetTotalOrderStatistics 获取总订单统计
|
||||
func (r *OrderRepository) GetTotalOrderStatistics() (map[string]interface{}, error) {
|
||||
var result struct {
|
||||
TotalCount int64 `json:"total_count"`
|
||||
TotalAmount float64 `json:"total_amount"`
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
COUNT(*) as total_count,
|
||||
COALESCE(SUM(total_amount), 0) as total_amount
|
||||
FROM ai_orders
|
||||
`
|
||||
|
||||
if err := r.db.Raw(query).Scan(&result).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_count": result.TotalCount,
|
||||
"total_amount": result.TotalAmount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BatchRemoveFromCart 批量从购物车移除
|
||||
func (r *OrderRepository) BatchRemoveFromCart(userID uint, cartIDs []uint) error {
|
||||
return r.db.Where("user_id = ? AND id IN ?", userID, cartIDs).Delete(&model.Cart{}).Error
|
||||
}
|
||||
|
||||
// GetCartItemByID 根据ID获取购物车项
|
||||
func (r *OrderRepository) GetCartItemByID(userID, cartID uint) (*model.Cart, error) {
|
||||
var cart model.Cart
|
||||
err := r.db.Where("user_id = ? AND id = ?", userID, cartID).First(&cart).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &cart, nil
|
||||
}
|
||||
|
||||
// RemoveCartItem 移除购物车项
|
||||
func (r *OrderRepository) RemoveCartItem(cartID uint) error {
|
||||
return r.db.Delete(&model.Cart{}, cartID).Error
|
||||
}
|
||||
|
||||
// GetSelectedCartItems 获取选中的购物车项
|
||||
func (r *OrderRepository) GetSelectedCartItems(userID uint) ([]model.Cart, error) {
|
||||
var cart []model.Cart
|
||||
err := r.db.Where("user_id = ? AND selected = ?", userID, true).
|
||||
Preload("Product").
|
||||
Preload("ProductSKU").
|
||||
Find(&cart).Error
|
||||
return cart, err
|
||||
}
|
||||
118
server/internal/repository/points.go
Normal file
118
server/internal/repository/points.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"dianshang/internal/model"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// PointsRepository 积分数据访问层
|
||||
type PointsRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewPointsRepository 创建积分数据访问层
|
||||
func NewPointsRepository(db *gorm.DB) *PointsRepository {
|
||||
return &PointsRepository{db: db}
|
||||
}
|
||||
|
||||
// GetUserPoints 获取用户积分
|
||||
func (r *PointsRepository) GetUserPoints(userID uint) (int, error) {
|
||||
var user model.User
|
||||
err := r.db.Select("points").Where("id = ?", userID).First(&user).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return user.Points, nil
|
||||
}
|
||||
|
||||
// UpdateUserPoints 更新用户积分
|
||||
func (r *PointsRepository) UpdateUserPoints(userID uint, points int) error {
|
||||
return r.db.Model(&model.User{}).Where("id = ?", userID).Update("points", points).Error
|
||||
}
|
||||
|
||||
// CreatePointsHistory 创建积分历史记录
|
||||
func (r *PointsRepository) CreatePointsHistory(history *model.PointsHistory) error {
|
||||
return r.db.Create(history).Error
|
||||
}
|
||||
|
||||
// GetPointsHistory 获取积分历史记录
|
||||
func (r *PointsRepository) GetPointsHistory(userID uint, page, pageSize int) ([]model.PointsHistory, int64, error) {
|
||||
var histories []model.PointsHistory
|
||||
var total int64
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
// 获取总数
|
||||
err := r.db.Model(&model.PointsHistory{}).Where("user_id = ?", userID).Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
err = r.db.Where("user_id = ?", userID).
|
||||
Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&histories).Error
|
||||
|
||||
return histories, total, err
|
||||
}
|
||||
|
||||
// GetPointsRules 获取积分规则列表
|
||||
func (r *PointsRepository) GetPointsRules() ([]model.PointsRule, error) {
|
||||
var rules []model.PointsRule
|
||||
err := r.db.Where("status = ?", 1).Order("sort ASC, id ASC").Find(&rules).Error
|
||||
return rules, err
|
||||
}
|
||||
|
||||
// GetPointsExchangeList 获取积分兑换商品列表
|
||||
func (r *PointsRepository) GetPointsExchangeList() ([]model.PointsExchange, error) {
|
||||
var exchanges []model.PointsExchange
|
||||
err := r.db.Where("status = ?", 1).Order("sort ASC, id ASC").Find(&exchanges).Error
|
||||
return exchanges, err
|
||||
}
|
||||
|
||||
// GetPointsExchangeByID 根据ID获取积分兑换商品
|
||||
func (r *PointsRepository) GetPointsExchangeByID(id uint) (*model.PointsExchange, error) {
|
||||
var exchange model.PointsExchange
|
||||
err := r.db.Where("id = ? AND status = ?", id, 1).First(&exchange).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &exchange, nil
|
||||
}
|
||||
|
||||
// CreatePointsExchangeRecord 创建积分兑换记录
|
||||
func (r *PointsRepository) CreatePointsExchangeRecord(record *model.PointsExchangeRecord) error {
|
||||
return r.db.Create(record).Error
|
||||
}
|
||||
|
||||
// UpdatePointsExchangeCount 更新兑换商品的兑换次数
|
||||
func (r *PointsRepository) UpdatePointsExchangeCount(id uint) error {
|
||||
return r.db.Model(&model.PointsExchange{}).Where("id = ?", id).
|
||||
UpdateColumn("exchange_count", gorm.Expr("exchange_count + ?", 1)).Error
|
||||
}
|
||||
|
||||
// GetUserExchangeRecords 获取用户兑换记录
|
||||
func (r *PointsRepository) GetUserExchangeRecords(userID uint, page, pageSize int) ([]model.PointsExchangeRecord, int64, error) {
|
||||
var records []model.PointsExchangeRecord
|
||||
var total int64
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
// 获取总数
|
||||
err := r.db.Model(&model.PointsExchangeRecord{}).Where("user_id = ?", userID).Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据,预加载关联数据
|
||||
err = r.db.Preload("PointsExchange").
|
||||
Where("user_id = ?", userID).
|
||||
Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&records).Error
|
||||
|
||||
return records, total, err
|
||||
}
|
||||
905
server/internal/repository/product.go
Normal file
905
server/internal/repository/product.go
Normal file
@@ -0,0 +1,905 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"dianshang/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ProductRepository 产品仓储
|
||||
type ProductRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewProductRepository 创建产品仓储
|
||||
func NewProductRepository(db *gorm.DB) *ProductRepository {
|
||||
return &ProductRepository{db: db}
|
||||
}
|
||||
|
||||
// GetDB 获取数据库连接
|
||||
func (r *ProductRepository) GetDB() *gorm.DB {
|
||||
return r.db
|
||||
}
|
||||
|
||||
// GetList 获取产品列表
|
||||
func (r *ProductRepository) GetList(offset, limit int, conditions map[string]interface{}) ([]model.Product, int64, error) {
|
||||
var products []model.Product
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Product{})
|
||||
|
||||
// 默认只查询上架商品,除非明确指定状态
|
||||
if statusValue, exists := conditions["status"]; exists {
|
||||
if statusValue == "all" {
|
||||
// 管理系统可以传递 "all" 来获取所有状态的商品
|
||||
} else {
|
||||
query = query.Where("status = ?", statusValue)
|
||||
}
|
||||
} else {
|
||||
// 默认只查询上架商品(兼容原有逻辑)
|
||||
query = query.Where("status = ?", 1)
|
||||
}
|
||||
|
||||
// 添加查询条件
|
||||
var sortField string
|
||||
var sortType string
|
||||
for key, value := range conditions {
|
||||
switch key {
|
||||
case "category_id":
|
||||
// 支持包含子分类的筛选
|
||||
var catID uint
|
||||
switch v := value.(type) {
|
||||
case uint:
|
||||
catID = v
|
||||
case int:
|
||||
catID = uint(v)
|
||||
case float64:
|
||||
catID = uint(v)
|
||||
}
|
||||
if catID > 0 {
|
||||
categoryIDs, err := r.getCategoryIDsIncludingChildren(catID)
|
||||
if err == nil && len(categoryIDs) > 0 {
|
||||
query = query.Where("category_id IN (?)", categoryIDs)
|
||||
} else {
|
||||
// 兜底:如果获取子分类失败,退化为当前分类
|
||||
query = query.Where("category_id = ?", catID)
|
||||
}
|
||||
}
|
||||
case "keyword":
|
||||
query = query.Where("name LIKE ? OR description LIKE ?", "%"+value.(string)+"%", "%"+value.(string)+"%")
|
||||
case "min_price":
|
||||
query = query.Where("price >= ?", value)
|
||||
case "max_price":
|
||||
query = query.Where("price <= ?", value)
|
||||
case "is_hot":
|
||||
if value.(string) == "true" {
|
||||
query = query.Where("is_hot = ?", true)
|
||||
} else if value.(string) == "false" {
|
||||
query = query.Where("is_hot = ?", false)
|
||||
}
|
||||
case "is_new":
|
||||
if value.(string) == "true" {
|
||||
query = query.Where("is_new = ?", true)
|
||||
} else if value.(string) == "false" {
|
||||
query = query.Where("is_new = ?", false)
|
||||
}
|
||||
case "is_recommend":
|
||||
if value.(string) == "true" {
|
||||
query = query.Where("is_recommend = ?", true)
|
||||
} else if value.(string) == "false" {
|
||||
query = query.Where("is_recommend = ?", false)
|
||||
}
|
||||
case "sort":
|
||||
sortField = value.(string)
|
||||
case "sort_type":
|
||||
sortType = value.(string)
|
||||
case "status":
|
||||
// 状态条件已在上面处理,这里跳过
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 构建排序条件
|
||||
orderBy := "sort DESC, created_at DESC" // 默认排序
|
||||
if sortField != "" {
|
||||
switch sortField {
|
||||
case "price":
|
||||
if sortType == "asc" {
|
||||
orderBy = "price ASC"
|
||||
} else {
|
||||
orderBy = "price DESC"
|
||||
}
|
||||
case "sales":
|
||||
if sortType == "asc" {
|
||||
orderBy = "sales ASC"
|
||||
} else {
|
||||
orderBy = "sales DESC"
|
||||
}
|
||||
case "created_at":
|
||||
if sortType == "asc" {
|
||||
orderBy = "created_at ASC"
|
||||
} else {
|
||||
orderBy = "created_at DESC"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取列表,预加载分类
|
||||
err := query.Preload("Category").
|
||||
Offset(offset).Limit(limit).Order(orderBy).Find(&products).Error
|
||||
return products, total, err
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取产品详情
|
||||
func (r *ProductRepository) GetByID(id uint) (*model.Product, error) {
|
||||
var product model.Product
|
||||
err := r.db.Preload("Category").Preload("Specs").Preload("SKUs", "status = ?", 1).
|
||||
Where("id = ?", id).First(&product).Error
|
||||
return &product, err
|
||||
}
|
||||
|
||||
// Create 创建产品
|
||||
func (r *ProductRepository) Create(product *model.Product) error {
|
||||
return r.db.Create(product).Error
|
||||
}
|
||||
|
||||
// Update 更新产品
|
||||
func (r *ProductRepository) Update(id uint, updates map[string]interface{}) error {
|
||||
return r.db.Model(&model.Product{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// Delete 删除产品(软删除)
|
||||
func (r *ProductRepository) Delete(id uint) error {
|
||||
// 使用GORM软删除,只删除商品主记录
|
||||
// 关联记录保留,通过商品的deleted_at字段来判断商品是否被删除
|
||||
return r.db.Delete(&model.Product{}, id).Error
|
||||
}
|
||||
|
||||
// UpdateStock 更新库存
|
||||
func (r *ProductRepository) UpdateStock(id uint, quantity int) error {
|
||||
return r.db.Model(&model.Product{}).Where("id = ?", id).
|
||||
Update("stock", gorm.Expr("stock + ?", quantity)).Error
|
||||
}
|
||||
|
||||
// UpdateSales 更新销量
|
||||
func (r *ProductRepository) UpdateSales(id uint, quantity int) error {
|
||||
return r.db.Model(&model.Product{}).Where("id = ?", id).
|
||||
Update("sales", gorm.Expr("sales + ?", quantity)).Error
|
||||
}
|
||||
|
||||
// DeductStock 扣减库存(支付成功时使用)
|
||||
func (r *ProductRepository) DeductStock(id uint, quantity int) error {
|
||||
// 检查库存是否充足并扣减
|
||||
result := r.db.Model(&model.Product{}).
|
||||
Where("id = ? AND stock >= ?", id, quantity).
|
||||
Update("stock", gorm.Expr("stock - ?", quantity))
|
||||
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound // 库存不足或产品不存在
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReserveStock 预扣库存(创建订单时使用)
|
||||
func (r *ProductRepository) ReserveStock(id uint, quantity int) error {
|
||||
// 检查库存是否充足并预扣
|
||||
result := r.db.Model(&model.Product{}).
|
||||
Where("id = ? AND stock >= ?", id, quantity).
|
||||
Update("stock", gorm.Expr("stock - ?", quantity))
|
||||
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound // 库存不足或产品不存在
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestoreStock 恢复库存(取消订单或支付失败时使用)
|
||||
func (r *ProductRepository) RestoreStock(id uint, quantity int) error {
|
||||
return r.db.Model(&model.Product{}).Where("id = ?", id).
|
||||
Update("stock", gorm.Expr("stock + ?", quantity)).Error
|
||||
}
|
||||
|
||||
// GetCategories 获取分类列表
|
||||
func (r *ProductRepository) GetCategories() ([]model.Category, error) {
|
||||
var allCategories []model.Category
|
||||
err := r.db.Where("status = ?", 1).Order("level ASC, sort DESC, created_at ASC").Find(&allCategories).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 构建层级关系
|
||||
categoryMap := make(map[uint]*model.Category)
|
||||
var rootCategories []model.Category
|
||||
|
||||
// 先将所有分类放入map
|
||||
for i := range allCategories {
|
||||
categoryMap[allCategories[i].ID] = &allCategories[i]
|
||||
}
|
||||
|
||||
// 构建父子关系
|
||||
for i := range allCategories {
|
||||
category := &allCategories[i]
|
||||
if category.ParentID == nil {
|
||||
// 一级分类
|
||||
rootCategories = append(rootCategories, *category)
|
||||
} else {
|
||||
// 二级分类,添加到父分类的Children中
|
||||
if parent, exists := categoryMap[*category.ParentID]; exists {
|
||||
parent.Children = append(parent.Children, *category)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 更新根分类的Children并设置HasChildren字段
|
||||
for i := range rootCategories {
|
||||
if parent, exists := categoryMap[rootCategories[i].ID]; exists {
|
||||
rootCategories[i].Children = parent.Children
|
||||
rootCategories[i].HasChildren = len(parent.Children) > 0
|
||||
}
|
||||
// 为子分类设置HasChildren字段
|
||||
for j := range rootCategories[i].Children {
|
||||
rootCategories[i].Children[j].HasChildren = false // 二级分类没有子分类
|
||||
}
|
||||
}
|
||||
|
||||
return rootCategories, nil
|
||||
}
|
||||
|
||||
// GetCategoryByID 根据ID获取分类
|
||||
func (r *ProductRepository) GetCategoryByID(id uint) (*model.Category, error) {
|
||||
var category model.Category
|
||||
err := r.db.Where("id = ? AND status = ?", id, 1).First(&category).Error
|
||||
return &category, err
|
||||
}
|
||||
|
||||
// CreateCategory 创建分类
|
||||
func (r *ProductRepository) CreateCategory(category *model.Category) error {
|
||||
return r.db.Create(category).Error
|
||||
}
|
||||
|
||||
// UpdateCategory 更新分类
|
||||
func (r *ProductRepository) UpdateCategory(id uint, updates map[string]interface{}) error {
|
||||
return r.db.Model(&model.Category{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeleteCategory 删除分类
|
||||
func (r *ProductRepository) DeleteCategory(id uint) error {
|
||||
return r.db.Delete(&model.Category{}, id).Error
|
||||
}
|
||||
|
||||
// GetReviews 获取产品评价列表
|
||||
func (r *ProductRepository) GetReviews(productID uint, offset, limit int) ([]model.ProductReview, int64, error) {
|
||||
var reviews []model.ProductReview
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.ProductReview{}).Where("product_id = ?", productID)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表
|
||||
err := query.Offset(offset).Limit(limit).
|
||||
Order("created_at DESC").Find(&reviews).Error
|
||||
return reviews, total, err
|
||||
}
|
||||
|
||||
// CreateReview 创建评价
|
||||
func (r *ProductRepository) CreateReview(review *model.ProductReview) error {
|
||||
return r.db.Create(review).Error
|
||||
}
|
||||
|
||||
// GetReviewByOrderID 根据订单ID获取评价
|
||||
func (r *ProductRepository) GetReviewByOrderID(userID uint, orderID uint) (*model.ProductReview, error) {
|
||||
var review model.ProductReview
|
||||
err := r.db.Where("user_id = ? AND order_id = ?", userID, orderID).First(&review).Error
|
||||
return &review, err
|
||||
}
|
||||
|
||||
// GetReviewCount 获取产品评价统计
|
||||
func (r *ProductRepository) GetReviewCount(productID uint) (map[string]interface{}, error) {
|
||||
var total int64
|
||||
var goodCount int64 // 好评数 (4-5星)
|
||||
var middleCount int64 // 中评数 (3星)
|
||||
var badCount int64 // 差评数 (1-2星)
|
||||
var hasImageCount int64 // 有图评价数
|
||||
|
||||
// 获取总评价数
|
||||
if err := r.db.Model(&model.ProductReview{}).Where("product_id = ?", productID).Count(&total).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取好评数 (4-5星)
|
||||
if err := r.db.Model(&model.ProductReview{}).Where("product_id = ? AND rating >= ?", productID, 4).
|
||||
Count(&goodCount).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取中评数 (3星)
|
||||
if err := r.db.Model(&model.ProductReview{}).Where("product_id = ? AND rating = ?", productID, 3).
|
||||
Count(&middleCount).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取差评数 (1-2星)
|
||||
if err := r.db.Model(&model.ProductReview{}).Where("product_id = ? AND rating <= ?", productID, 2).
|
||||
Count(&badCount).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取有图评价数
|
||||
if err := r.db.Model(&model.ProductReview{}).Where("product_id = ? AND images IS NOT NULL AND images != ''", productID).
|
||||
Count(&hasImageCount).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 计算好评率
|
||||
var goodRate float64 = 100.0
|
||||
if total > 0 {
|
||||
goodRate = float64(goodCount) / float64(total) * 100
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"total": total,
|
||||
"good_rate": goodRate,
|
||||
"good_count": goodCount,
|
||||
"middle_count": middleCount,
|
||||
"bad_count": badCount,
|
||||
"has_image_count": hasImageCount,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetProductImages 获取产品图片
|
||||
func (r *ProductRepository) GetProductImages(productID uint) ([]model.ProductImage, error) {
|
||||
var images []model.ProductImage
|
||||
err := r.db.Where("product_id = ?", productID).Order("sort ASC").Find(&images).Error
|
||||
return images, err
|
||||
}
|
||||
|
||||
// CreateProductImage 创建产品图片
|
||||
func (r *ProductRepository) CreateProductImage(image *model.ProductImage) error {
|
||||
return r.db.Create(image).Error
|
||||
}
|
||||
|
||||
// DeleteProductImage 删除产品图片
|
||||
func (r *ProductRepository) DeleteProductImage(id uint) error {
|
||||
return r.db.Delete(&model.ProductImage{}, id).Error
|
||||
}
|
||||
|
||||
// GetProductSpecs 获取产品规格
|
||||
func (r *ProductRepository) GetProductSpecs(productID uint) ([]model.ProductSpec, error) {
|
||||
var specs []model.ProductSpec
|
||||
err := r.db.Where("product_id = ?", productID).Order("sort ASC").Find(&specs).Error
|
||||
return specs, err
|
||||
}
|
||||
|
||||
// CreateProductSpec 创建产品规格
|
||||
func (r *ProductRepository) CreateProductSpec(spec *model.ProductSpec) error {
|
||||
return r.db.Create(spec).Error
|
||||
}
|
||||
|
||||
// UpdateProductSpec 更新产品规格
|
||||
func (r *ProductRepository) UpdateProductSpec(id uint, updates map[string]interface{}) error {
|
||||
return r.db.Model(&model.ProductSpec{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeleteProductSpec 删除产品规格
|
||||
func (r *ProductRepository) DeleteProductSpec(id uint) error {
|
||||
return r.db.Delete(&model.ProductSpec{}, id).Error
|
||||
}
|
||||
|
||||
// GetHotProducts 获取热门产品
|
||||
func (r *ProductRepository) GetHotProducts(limit int) ([]model.Product, error) {
|
||||
var products []model.Product
|
||||
err := r.db.Preload("Category").
|
||||
Where("status = ? AND is_hot = ?", 1, 1).
|
||||
Order("sales DESC, created_at DESC").Limit(limit).Find(&products).Error
|
||||
return products, err
|
||||
}
|
||||
|
||||
// GetRecommendProducts 获取推荐产品
|
||||
func (r *ProductRepository) GetRecommendProducts(limit int) ([]model.Product, error) {
|
||||
var products []model.Product
|
||||
err := r.db.Preload("Category").
|
||||
Where("status = ? AND is_recommend = ?", 1, 1).
|
||||
Order("sort DESC, created_at DESC").Limit(limit).Find(&products).Error
|
||||
return products, err
|
||||
}
|
||||
|
||||
// GetProductSKUs 获取产品SKU列表
|
||||
func (r *ProductRepository) GetProductSKUs(productID uint) ([]model.ProductSKU, error) {
|
||||
var skus []model.ProductSKU
|
||||
err := r.db.Where("product_id = ? AND status = ?", productID, 1).
|
||||
Order("weight ASC").Find(&skus).Error
|
||||
return skus, err
|
||||
}
|
||||
|
||||
// GetSKUByID 根据SKU ID获取SKU详情
|
||||
func (r *ProductRepository) GetSKUByID(skuID uint) (*model.ProductSKU, error) {
|
||||
var sku model.ProductSKU
|
||||
err := r.db.Where("id = ? AND status = ?", skuID, 1).First(&sku).Error
|
||||
return &sku, err
|
||||
}
|
||||
|
||||
// GetProductTags 获取产品标签列表
|
||||
func (r *ProductRepository) GetProductTags() ([]model.ProductTag, error) {
|
||||
var tags []model.ProductTag
|
||||
err := r.db.Where("status = ?", 1).Order("sort ASC").Find(&tags).Error
|
||||
return tags, err
|
||||
}
|
||||
|
||||
// GetStores 获取店铺列表
|
||||
func (r *ProductRepository) GetStores() ([]model.Store, error) {
|
||||
var stores []model.Store
|
||||
err := r.db.Where("status = ?", 1).Order("sort ASC").Find(&stores).Error
|
||||
return stores, err
|
||||
}
|
||||
|
||||
// GetStoreByID 根据ID获取店铺信息
|
||||
func (r *ProductRepository) GetStoreByID(id uint) (*model.Store, error) {
|
||||
var store model.Store
|
||||
err := r.db.Where("id = ? AND status = ?", id, 1).First(&store).Error
|
||||
return &store, err
|
||||
}
|
||||
|
||||
// GetProductStatistics 获取产品统计
|
||||
func (r *ProductRepository) GetProductStatistics() (map[string]interface{}, error) {
|
||||
result := make(map[string]interface{})
|
||||
|
||||
// 总商品数
|
||||
var totalProducts int64
|
||||
if err := r.db.Model(&model.Product{}).Count(&totalProducts).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result["total_products"] = totalProducts
|
||||
|
||||
// 上架商品数
|
||||
var onlineProducts int64
|
||||
if err := r.db.Model(&model.Product{}).Where("status = ?", 1).Count(&onlineProducts).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result["online_products"] = onlineProducts
|
||||
|
||||
// 下架商品数
|
||||
var offlineProducts int64
|
||||
if err := r.db.Model(&model.Product{}).Where("status = ?", 0).Count(&offlineProducts).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result["offline_products"] = offlineProducts
|
||||
|
||||
// 库存不足商品数(库存小于10)
|
||||
var lowStockProducts int64
|
||||
if err := r.db.Model(&model.Product{}).Where("stock < ?", 10).Count(&lowStockProducts).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result["low_stock_products"] = lowStockProducts
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetCategorySalesStatistics 获取分类销售统计
|
||||
func (r *ProductRepository) GetCategorySalesStatistics(startDate, endDate string, limit int) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
c.id as category_id,
|
||||
c.name as category_name,
|
||||
COALESCE(SUM(oi.quantity), 0) as sales_count,
|
||||
COALESCE(SUM(oi.total_price), 0) as sales_amount
|
||||
FROM ai_categories c
|
||||
LEFT JOIN ai_products p ON c.id = p.category_id
|
||||
LEFT JOIN order_items oi ON p.id = oi.product_id
|
||||
LEFT JOIN ai_orders o ON oi.order_id = o.id
|
||||
WHERE c.status = 1
|
||||
AND (o.created_at IS NULL OR (o.created_at >= ? AND o.created_at <= ?))
|
||||
AND (o.status IS NULL OR o.status IN (2, 3, 4, 5, 6))
|
||||
GROUP BY c.id, c.name
|
||||
ORDER BY sales_amount DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
err := r.db.Raw(query, startDate+" 00:00:00", endDate+" 23:59:59", limit).Scan(&results).Error
|
||||
return results, err
|
||||
}
|
||||
|
||||
// BatchUpdateStatus 批量更新商品状态
|
||||
func (r *ProductRepository) BatchUpdateStatus(ids []uint, status int) error {
|
||||
return r.db.Model(&model.Product{}).Where("id IN ?", ids).Update("status", status).Error
|
||||
}
|
||||
|
||||
// BatchUpdatePrice 批量更新商品价格
|
||||
func (r *ProductRepository) BatchUpdatePrice(updates []map[string]interface{}) error {
|
||||
tx := r.db.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
for _, update := range updates {
|
||||
if id, ok := update["id"]; ok {
|
||||
delete(update, "id")
|
||||
if err := tx.Model(&model.Product{}).Where("id = ?", id).Updates(update).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
// CountProductsByCategory 统计指定分类下的商品数量(包括子分类)
|
||||
func (r *ProductRepository) CountProductsByCategory(categoryID uint) (int64, error) {
|
||||
var count int64
|
||||
|
||||
// 获取所有子分类ID
|
||||
var categoryIDs []uint
|
||||
categoryIDs = append(categoryIDs, categoryID)
|
||||
|
||||
// 查找所有子分类
|
||||
var childCategories []model.Category
|
||||
err := r.db.Where("parent_id = ?", categoryID).Find(&childCategories).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for _, child := range childCategories {
|
||||
categoryIDs = append(categoryIDs, child.ID)
|
||||
}
|
||||
|
||||
// 统计这些分类下的商品数量(只统计未删除的商品)
|
||||
err = r.db.Model(&model.Product{}).
|
||||
Where("category_id IN (?)", categoryIDs).
|
||||
Count(&count).Error
|
||||
|
||||
return count, err
|
||||
}
|
||||
|
||||
// BatchDelete 批量删除商品(软删除)
|
||||
func (r *ProductRepository) BatchDelete(ids []uint) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 使用GORM软删除,只删除商品主记录
|
||||
// 关联记录保留,通过商品的deleted_at字段来判断商品是否被删除
|
||||
return r.db.Delete(&model.Product{}, ids).Error
|
||||
}
|
||||
|
||||
// CreateSKU 创建商品SKU
|
||||
func (r *ProductRepository) CreateSKU(sku *model.ProductSKU) error {
|
||||
return r.db.Create(sku).Error
|
||||
}
|
||||
|
||||
// UpdateSKU 更新商品SKU
|
||||
func (r *ProductRepository) UpdateSKU(id uint, updates map[string]interface{}) error {
|
||||
return r.db.Model(&model.ProductSKU{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeleteSKU 删除商品SKU
|
||||
func (r *ProductRepository) DeleteSKU(id uint) error {
|
||||
// 检查SKU是否被订单引用
|
||||
var count int64
|
||||
err := r.db.Table("order_items").Where("sk_uid = ?", id).Count(&count).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
// 如果被订单引用,使用软删除(设置status=0)
|
||||
return r.db.Model(&model.ProductSKU{}).Where("id = ?", id).Update("status", 0).Error
|
||||
} else {
|
||||
// 如果没有被引用,可以进行硬删除
|
||||
return r.db.Delete(&model.ProductSKU{}, id).Error
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateProductImageSort 更新商品图片排序
|
||||
func (r *ProductRepository) UpdateProductImageSort(id uint, sort int) error {
|
||||
return r.db.Model(&model.ProductImage{}).Where("id = ?", id).Update("sort", sort).Error
|
||||
}
|
||||
|
||||
// CreateProductTag 创建商品标签
|
||||
func (r *ProductRepository) CreateProductTag(tag *model.ProductTag) error {
|
||||
return r.db.Create(tag).Error
|
||||
}
|
||||
|
||||
// UpdateProductTag 更新商品标签
|
||||
func (r *ProductRepository) UpdateProductTag(id uint, updates map[string]interface{}) error {
|
||||
return r.db.Model(&model.ProductTag{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeleteProductTag 删除商品标签
|
||||
func (r *ProductRepository) DeleteProductTag(id uint) error {
|
||||
return r.db.Delete(&model.ProductTag{}, id).Error
|
||||
}
|
||||
|
||||
// AssignTagsToProduct 为商品分配标签
|
||||
func (r *ProductRepository) AssignTagsToProduct(productID uint, tagIDs []uint) error {
|
||||
// 先清除现有关联
|
||||
if err := r.db.Exec("DELETE FROM ai_product_tag_relations WHERE product_id = ?", productID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 添加新关联
|
||||
if len(tagIDs) > 0 {
|
||||
var relations []map[string]interface{}
|
||||
for _, tagID := range tagIDs {
|
||||
relations = append(relations, map[string]interface{}{
|
||||
"product_id": productID,
|
||||
"tag_id": tagID,
|
||||
})
|
||||
}
|
||||
return r.db.Table("ai_product_tag_relations").Create(relations).Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLowStockProducts 获取低库存商品
|
||||
func (r *ProductRepository) GetLowStockProducts(threshold int) ([]model.Product, error) {
|
||||
var products []model.Product
|
||||
err := r.db.Where("stock <= ? AND status = ?", threshold, 1).
|
||||
Preload("Category").Find(&products).Error
|
||||
return products, err
|
||||
}
|
||||
|
||||
// GetInventoryStatistics 获取库存统计
|
||||
func (r *ProductRepository) GetInventoryStatistics() (map[string]interface{}, error) {
|
||||
var result map[string]interface{}
|
||||
|
||||
// 总商品数
|
||||
var totalProducts int64
|
||||
r.db.Model(&model.Product{}).Where("status = ?", 1).Count(&totalProducts)
|
||||
|
||||
// 低库存商品数(库存小于等于10)
|
||||
var lowStockProducts int64
|
||||
r.db.Model(&model.Product{}).Where("stock <= ? AND status = ?", 10, 1).Count(&lowStockProducts)
|
||||
|
||||
// 缺货商品数
|
||||
var outOfStockProducts int64
|
||||
r.db.Model(&model.Product{}).Where("stock = ? AND status = ?", 0, 1).Count(&outOfStockProducts)
|
||||
|
||||
// 总库存值
|
||||
var totalStockValue float64
|
||||
r.db.Model(&model.Product{}).Where("status = ?", 1).
|
||||
Select("SUM(stock * price)").Scan(&totalStockValue)
|
||||
|
||||
result = map[string]interface{}{
|
||||
"total_products": totalProducts,
|
||||
"low_stock_products": lowStockProducts,
|
||||
"out_of_stock_products": outOfStockProducts,
|
||||
"total_stock_value": totalStockValue,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetProductsForExport 获取用于导出的商品数据
|
||||
func (r *ProductRepository) GetProductsForExport(conditions map[string]interface{}) ([]model.Product, error) {
|
||||
var products []model.Product
|
||||
query := r.db.Model(&model.Product{}).Preload("Category")
|
||||
|
||||
// 添加查询条件
|
||||
for key, value := range conditions {
|
||||
switch key {
|
||||
case "category_id":
|
||||
// 导出也支持子分类
|
||||
var catID uint
|
||||
switch v := value.(type) {
|
||||
case uint:
|
||||
catID = v
|
||||
case int:
|
||||
catID = uint(v)
|
||||
case float64:
|
||||
catID = uint(v)
|
||||
}
|
||||
if catID > 0 {
|
||||
categoryIDs, err := r.getCategoryIDsIncludingChildren(catID)
|
||||
if err == nil && len(categoryIDs) > 0 {
|
||||
query = query.Where("category_id IN (?)", categoryIDs)
|
||||
} else {
|
||||
query = query.Where("category_id = ?", catID)
|
||||
}
|
||||
}
|
||||
case "status":
|
||||
query = query.Where("status = ?", value)
|
||||
case "keyword":
|
||||
query = query.Where("name LIKE ? OR description LIKE ?", "%"+value.(string)+"%", "%"+value.(string)+"%")
|
||||
}
|
||||
}
|
||||
|
||||
err := query.Find(&products).Error
|
||||
return products, err
|
||||
}
|
||||
|
||||
// getCategoryIDsIncludingChildren 获取给定分类ID及其子分类ID集合(仅支持两级分类)
|
||||
func (r *ProductRepository) getCategoryIDsIncludingChildren(categoryID uint) ([]uint, error) {
|
||||
// 将自身分类加入集合
|
||||
ids := []uint{categoryID}
|
||||
|
||||
// 查找直接子分类(系统设计为最多二级)
|
||||
var childCategories []model.Category
|
||||
if err := r.db.Where("parent_id = ? AND status = ?", categoryID, 1).Find(&childCategories).Error; err != nil {
|
||||
return ids, err
|
||||
}
|
||||
for _, child := range childCategories {
|
||||
ids = append(ids, child.ID)
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// DeductSKUStock 扣减SKU库存
|
||||
func (r *ProductRepository) DeductSKUStock(skuID uint, quantity int) error {
|
||||
// 检查SKU库存是否充足并扣减
|
||||
result := r.db.Model(&model.ProductSKU{}).
|
||||
Where("id = ? AND stock >= ?", skuID, quantity).
|
||||
Update("stock", gorm.Expr("stock - ?", quantity))
|
||||
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound // SKU库存不足或SKU不存在
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestoreSKUStock 恢复SKU库存
|
||||
func (r *ProductRepository) RestoreSKUStock(skuID uint, quantity int) error {
|
||||
return r.db.Model(&model.ProductSKU{}).Where("id = ?", skuID).
|
||||
Update("stock", gorm.Expr("stock + ?", quantity)).Error
|
||||
}
|
||||
|
||||
// SyncProductStockFromSKUs 根据SKU库存同步商品总库存
|
||||
func (r *ProductRepository) SyncProductStockFromSKUs(productID uint) error {
|
||||
// 计算所有SKU的库存总和
|
||||
var totalStock int
|
||||
err := r.db.Model(&model.ProductSKU{}).
|
||||
Where("product_id = ? AND status = ?", productID, 1).
|
||||
Select("COALESCE(SUM(stock), 0)").Scan(&totalStock).Error
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新商品总库存
|
||||
return r.db.Model(&model.Product{}).Where("id = ?", productID).
|
||||
Update("stock", totalStock).Error
|
||||
}
|
||||
|
||||
// DeductStockWithSKU 扣减商品和SKU库存(支持SKU商品的库存扣减)
|
||||
func (r *ProductRepository) DeductStockWithSKU(productID uint, skuID *uint, quantity int) error {
|
||||
tx := r.db.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
if skuID != nil && *skuID > 0 {
|
||||
// 有SKU的商品,先扣减SKU库存
|
||||
result := tx.Model(&model.ProductSKU{}).
|
||||
Where("id = ? AND stock >= ?", *skuID, quantity).
|
||||
Update("stock", gorm.Expr("stock - ?", quantity))
|
||||
|
||||
if result.Error != nil {
|
||||
tx.Rollback()
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
tx.Rollback()
|
||||
return gorm.ErrRecordNotFound // SKU库存不足
|
||||
}
|
||||
|
||||
// 同步更新商品总库存
|
||||
var totalStock int
|
||||
err := tx.Model(&model.ProductSKU{}).
|
||||
Where("product_id = ? AND status = ?", productID, 1).
|
||||
Select("COALESCE(SUM(stock), 0)").Scan(&totalStock).Error
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Model(&model.Product{}).Where("id = ?", productID).
|
||||
Update("stock", totalStock).Error
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// 没有SKU的商品,直接扣减商品库存
|
||||
result := tx.Model(&model.Product{}).
|
||||
Where("id = ? AND stock >= ?", productID, quantity).
|
||||
Update("stock", gorm.Expr("stock - ?", quantity))
|
||||
|
||||
if result.Error != nil {
|
||||
tx.Rollback()
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
tx.Rollback()
|
||||
return gorm.ErrRecordNotFound // 商品库存不足
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
// RestoreStockWithSKU 恢复商品和SKU库存(支持SKU商品的库存恢复)
|
||||
func (r *ProductRepository) RestoreStockWithSKU(productID uint, skuID *uint, quantity int) error {
|
||||
tx := r.db.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
if skuID != nil && *skuID > 0 {
|
||||
// 有SKU的商品,先恢复SKU库存
|
||||
err := tx.Model(&model.ProductSKU{}).Where("id = ?", *skuID).
|
||||
Update("stock", gorm.Expr("stock + ?", quantity)).Error
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// 同步更新商品总库存
|
||||
var totalStock int
|
||||
err = tx.Model(&model.ProductSKU{}).
|
||||
Where("product_id = ? AND status = ?", productID, 1).
|
||||
Select("COALESCE(SUM(stock), 0)").Scan(&totalStock).Error
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Model(&model.Product{}).Where("id = ?", productID).
|
||||
Update("stock", totalStock).Error
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// 没有SKU的商品,直接恢复商品库存
|
||||
err := tx.Model(&model.Product{}).Where("id = ?", productID).
|
||||
Update("stock", gorm.Expr("stock + ?", quantity)).Error
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit().Error
|
||||
}
|
||||
406
server/internal/repository/refund.go
Normal file
406
server/internal/repository/refund.go
Normal file
@@ -0,0 +1,406 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"dianshang/internal/model"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type RefundRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewRefundRepository(db *gorm.DB) *RefundRepository {
|
||||
return &RefundRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建退款记录(别名方法)
|
||||
func (r *RefundRepository) Create(refund *model.Refund) error {
|
||||
return r.db.Create(refund).Error
|
||||
}
|
||||
|
||||
// CreateRefund 创建退款记录
|
||||
func (r *RefundRepository) CreateRefund(refund *model.Refund) error {
|
||||
return r.db.Create(refund).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取退款记录(别名方法)
|
||||
func (r *RefundRepository) GetByID(id uint) (*model.Refund, error) {
|
||||
var refund model.Refund
|
||||
err := r.db.Preload("Order").Preload("User").Preload("RefundItems").
|
||||
First(&refund, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &refund, nil
|
||||
}
|
||||
|
||||
// GetRefundByID 根据ID获取退款记录
|
||||
func (r *RefundRepository) GetRefundByID(id uint) (*model.Refund, error) {
|
||||
var refund model.Refund
|
||||
err := r.db.Preload("Order").Preload("User").Preload("RefundItems").
|
||||
First(&refund, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &refund, nil
|
||||
}
|
||||
|
||||
// GetRefundByRefundNo 根据退款单号获取退款记录
|
||||
func (r *RefundRepository) GetRefundByRefundNo(refundNo string) (*model.Refund, error) {
|
||||
var refund model.Refund
|
||||
err := r.db.Preload("Order").Preload("User").Preload("RefundItems").
|
||||
Where("refund_no = ?", refundNo).First(&refund).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &refund, nil
|
||||
}
|
||||
|
||||
// GetByWechatOutRefundNo 根据微信退款单号获取退款记录(别名方法)
|
||||
func (r *RefundRepository) GetByWechatOutRefundNo(wechatOutRefundNo string) (*model.Refund, error) {
|
||||
var refund model.Refund
|
||||
err := r.db.Preload("Order").Preload("User").Preload("RefundItems").
|
||||
Where("wechat_out_refund_no = ?", wechatOutRefundNo).First(&refund).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &refund, nil
|
||||
}
|
||||
|
||||
// GetRefundByWechatOutRefundNo 根据微信退款单号获取退款记录
|
||||
func (r *RefundRepository) GetRefundByWechatOutRefundNo(wechatOutRefundNo string) (*model.Refund, error) {
|
||||
var refund model.Refund
|
||||
err := r.db.Preload("Order").Preload("User").Preload("RefundItems").
|
||||
Where("wechat_out_refund_no = ?", wechatOutRefundNo).First(&refund).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &refund, nil
|
||||
}
|
||||
|
||||
// GetTotalRefundedByOrderID 根据订单ID获取已退款总金额
|
||||
func (r *RefundRepository) GetTotalRefundedByOrderID(orderID uint) (float64, error) {
|
||||
var totalRefunded float64
|
||||
err := r.db.Model(&model.Refund{}).
|
||||
Where("order_id = ? AND status IN (?)", orderID, []int{model.RefundStatusSuccess}).
|
||||
Select("COALESCE(SUM(actual_refund_amount), 0)").
|
||||
Scan(&totalRefunded).Error
|
||||
return totalRefunded, err
|
||||
}
|
||||
|
||||
// GetByOrderID 根据订单ID获取退款记录列表(别名方法)
|
||||
func (r *RefundRepository) GetByOrderID(orderID uint) ([]model.Refund, error) {
|
||||
var refunds []model.Refund
|
||||
err := r.db.Preload("RefundItems").
|
||||
Where("order_id = ?", orderID).
|
||||
Order("created_at DESC").
|
||||
Find(&refunds).Error
|
||||
return refunds, err
|
||||
}
|
||||
|
||||
// GetRefundsByOrderID 根据订单ID获取退款记录列表
|
||||
func (r *RefundRepository) GetRefundsByOrderID(orderID uint) ([]model.Refund, error) {
|
||||
var refunds []model.Refund
|
||||
err := r.db.Preload("RefundItems").
|
||||
Where("order_id = ?", orderID).
|
||||
Order("created_at DESC").
|
||||
Find(&refunds).Error
|
||||
return refunds, err
|
||||
}
|
||||
|
||||
// GetByUserID 根据用户ID获取退款记录列表(别名方法)
|
||||
func (r *RefundRepository) GetByUserID(userID uint, page, pageSize int) ([]model.Refund, int64, error) {
|
||||
var refunds []model.Refund
|
||||
var total int64
|
||||
|
||||
// 计算总数
|
||||
err := r.db.Model(&model.Refund{}).Where("user_id = ?", userID).Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err = r.db.Preload("Order").Preload("RefundItems").
|
||||
Where("user_id = ?", userID).
|
||||
Order("created_at DESC").
|
||||
Offset(offset).Limit(pageSize).
|
||||
Find(&refunds).Error
|
||||
|
||||
return refunds, total, err
|
||||
}
|
||||
|
||||
// GetRefundCountByOrderID 获取订单的退款数量
|
||||
func (r *RefundRepository) GetRefundCountByOrderID(orderID uint) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.Refund{}).Where("order_id = ?", orderID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CreateLog 创建退款日志
|
||||
func (r *RefundRepository) CreateLog(log *model.RefundLog) error {
|
||||
return r.db.Create(log).Error
|
||||
}
|
||||
|
||||
// GetRefundsByUserID 根据用户ID获取退款记录列表
|
||||
func (r *RefundRepository) GetRefundsByUserID(userID uint, page, pageSize int) ([]model.Refund, int64, error) {
|
||||
var refunds []model.Refund
|
||||
var total int64
|
||||
|
||||
// 计算总数
|
||||
err := r.db.Model(&model.Refund{}).Where("user_id = ?", userID).Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err = r.db.Preload("Order").Preload("RefundItems").
|
||||
Where("user_id = ?", userID).
|
||||
Order("created_at DESC").
|
||||
Offset(offset).Limit(pageSize).
|
||||
Find(&refunds).Error
|
||||
|
||||
return refunds, total, err
|
||||
}
|
||||
|
||||
// GetPendingRefunds 获取待审核的退款记录
|
||||
func (r *RefundRepository) GetPendingRefunds(page, pageSize int) ([]model.Refund, int64, error) {
|
||||
var refunds []model.Refund
|
||||
var total int64
|
||||
|
||||
// 计算总数
|
||||
err := r.db.Model(&model.Refund{}).Where("status = ?", model.RefundStatusPending).Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err = r.db.Preload("Order").Preload("User").Preload("RefundItems").
|
||||
Where("status = ?", model.RefundStatusPending).
|
||||
Order("apply_time ASC").
|
||||
Offset(offset).Limit(pageSize).
|
||||
Find(&refunds).Error
|
||||
|
||||
return refunds, total, err
|
||||
}
|
||||
|
||||
// GetAllRefunds 获取所有退款记录(管理员)
|
||||
func (r *RefundRepository) GetAllRefunds(page, pageSize int, conditions map[string]interface{}) ([]model.Refund, int64, error) {
|
||||
var refunds []model.Refund
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Refund{})
|
||||
|
||||
// 添加查询条件
|
||||
for key, value := range conditions {
|
||||
switch key {
|
||||
case "status":
|
||||
query = query.Where("status = ?", value)
|
||||
case "user_id":
|
||||
query = query.Where("user_id = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 计算总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Preload("Order").Preload("User").Preload("RefundItems").
|
||||
Order("apply_time DESC").
|
||||
Offset(offset).Limit(pageSize).
|
||||
Find(&refunds).Error
|
||||
|
||||
return refunds, total, err
|
||||
}
|
||||
|
||||
// UpdateStatus 更新退款状态
|
||||
func (r *RefundRepository) UpdateStatus(refundID uint, status int) error {
|
||||
return r.db.Model(&model.Refund{}).Where("id = ?", refundID).Update("status", status).Error
|
||||
}
|
||||
|
||||
// UpdateByID 根据ID更新退款记录
|
||||
func (r *RefundRepository) UpdateByID(refundID uint, updates map[string]interface{}) error {
|
||||
return r.db.Model(&model.Refund{}).Where("id = ?", refundID).Updates(updates).Error
|
||||
}
|
||||
|
||||
// UpdateRefund 更新退款记录
|
||||
func (r *RefundRepository) UpdateRefund(refundID uint, updates map[string]interface{}) error {
|
||||
return r.db.Model(&model.Refund{}).Where("id = ?", refundID).Updates(updates).Error
|
||||
}
|
||||
|
||||
// UpdateRefundByRefundNo 根据退款单号更新退款记录
|
||||
func (r *RefundRepository) UpdateRefundByRefundNo(refundNo string, updates map[string]interface{}) error {
|
||||
return r.db.Model(&model.Refund{}).Where("refund_no = ?", refundNo).Updates(updates).Error
|
||||
}
|
||||
|
||||
// UpdateRefundByWechatOutRefundNo 根据微信退款单号更新退款记录
|
||||
func (r *RefundRepository) UpdateRefundByWechatOutRefundNo(wechatOutRefundNo string, updates map[string]interface{}) error {
|
||||
return r.db.Model(&model.Refund{}).Where("wechat_out_refund_no = ?", wechatOutRefundNo).Updates(updates).Error
|
||||
}
|
||||
|
||||
// CreateRefundItem 创建退款项目
|
||||
func (r *RefundRepository) CreateRefundItem(refundItem *model.RefundItem) error {
|
||||
return r.db.Create(refundItem).Error
|
||||
}
|
||||
|
||||
// CreateRefundItems 批量创建退款项目
|
||||
func (r *RefundRepository) CreateRefundItems(refundItems []model.RefundItem) error {
|
||||
return r.db.CreateInBatches(refundItems, 100).Error
|
||||
}
|
||||
|
||||
// GetRefundItemsByRefundID 根据退款ID获取退款项目列表
|
||||
func (r *RefundRepository) GetRefundItemsByRefundID(refundID uint) ([]model.RefundItem, error) {
|
||||
var refundItems []model.RefundItem
|
||||
err := r.db.Preload("Product").Preload("SKU").
|
||||
Where("refund_id = ?", refundID).
|
||||
Find(&refundItems).Error
|
||||
return refundItems, err
|
||||
}
|
||||
|
||||
// CreateRefundLog 创建退款日志
|
||||
func (r *RefundRepository) CreateRefundLog(refundLog *model.RefundLog) error {
|
||||
return r.db.Create(refundLog).Error
|
||||
}
|
||||
|
||||
// GetRefundLogsByRefundID 根据退款ID获取操作日志
|
||||
func (r *RefundRepository) GetRefundLogsByRefundID(refundID uint) ([]model.RefundLog, error) {
|
||||
var refundLogs []model.RefundLog
|
||||
err := r.db.Preload("Operator").
|
||||
Where("refund_id = ?", refundID).
|
||||
Order("created_at ASC").
|
||||
Find(&refundLogs).Error
|
||||
return refundLogs, err
|
||||
}
|
||||
|
||||
// GetRefundStatistics 获取退款统计数据
|
||||
func (r *RefundRepository) GetRefundStatistics(startTime, endTime time.Time) (map[string]interface{}, error) {
|
||||
var result struct {
|
||||
TotalCount int64 `json:"total_count"`
|
||||
TotalAmount float64 `json:"total_amount"`
|
||||
PendingCount int64 `json:"pending_count"`
|
||||
ApprovedCount int64 `json:"approved_count"`
|
||||
RejectedCount int64 `json:"rejected_count"`
|
||||
ProcessingCount int64 `json:"processing_count"`
|
||||
SuccessCount int64 `json:"success_count"`
|
||||
FailedCount int64 `json:"failed_count"`
|
||||
SuccessAmount float64 `json:"success_amount"`
|
||||
}
|
||||
|
||||
// 总退款申请数和金额
|
||||
err := r.db.Model(&model.Refund{}).
|
||||
Where("apply_time BETWEEN ? AND ?", startTime, endTime).
|
||||
Select("COUNT(*) as total_count, COALESCE(SUM(refund_amount), 0) as total_amount").
|
||||
Scan(&result).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 各状态统计
|
||||
statusCounts := []struct {
|
||||
Status int `json:"status"`
|
||||
Count int64 `json:"count"`
|
||||
}{}
|
||||
|
||||
err = r.db.Model(&model.Refund{}).
|
||||
Where("apply_time BETWEEN ? AND ?", startTime, endTime).
|
||||
Select("status, COUNT(*) as count").
|
||||
Group("status").
|
||||
Scan(&statusCounts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 分配到对应字段
|
||||
for _, sc := range statusCounts {
|
||||
switch sc.Status {
|
||||
case model.RefundStatusPending:
|
||||
result.PendingCount = sc.Count
|
||||
case model.RefundStatusApproved:
|
||||
result.ApprovedCount = sc.Count
|
||||
case model.RefundStatusRejected:
|
||||
result.RejectedCount = sc.Count
|
||||
case model.RefundStatusProcessing:
|
||||
result.ProcessingCount = sc.Count
|
||||
case model.RefundStatusSuccess:
|
||||
result.SuccessCount = sc.Count
|
||||
case model.RefundStatusFailed:
|
||||
result.FailedCount = sc.Count
|
||||
}
|
||||
}
|
||||
|
||||
// 成功退款金额
|
||||
err = r.db.Model(&model.Refund{}).
|
||||
Where("apply_time BETWEEN ? AND ? AND status = ?", startTime, endTime, model.RefundStatusSuccess).
|
||||
Select("COALESCE(SUM(actual_refund_amount), 0) as success_amount").
|
||||
Scan(&result).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_count": result.TotalCount,
|
||||
"total_amount": result.TotalAmount,
|
||||
"pending_count": result.PendingCount,
|
||||
"approved_count": result.ApprovedCount,
|
||||
"rejected_count": result.RejectedCount,
|
||||
"processing_count": result.ProcessingCount,
|
||||
"success_count": result.SuccessCount,
|
||||
"failed_count": result.FailedCount,
|
||||
"success_amount": result.SuccessAmount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetRefundsByStatus 根据状态获取退款记录列表
|
||||
func (r *RefundRepository) GetRefundsByStatus(status int) ([]model.Refund, error) {
|
||||
var refunds []model.Refund
|
||||
err := r.db.Preload("Order").Preload("RefundItems").
|
||||
Where("status = ?", status).
|
||||
Order("created_at DESC").
|
||||
Find(&refunds).Error
|
||||
return refunds, err
|
||||
}
|
||||
|
||||
// GetRefundTrends 获取退款趋势数据
|
||||
func (r *RefundRepository) GetRefundTrends(days int) ([]map[string]interface{}, error) {
|
||||
var trends []struct {
|
||||
Date string `json:"date"`
|
||||
Count int64 `json:"count"`
|
||||
Amount float64 `json:"amount"`
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
DATE(apply_time) as date,
|
||||
COUNT(*) as count,
|
||||
COALESCE(SUM(refund_amount), 0) as amount
|
||||
FROM ai_refunds
|
||||
WHERE apply_time >= DATE_SUB(CURDATE(), INTERVAL %d DAY)
|
||||
GROUP BY DATE(apply_time)
|
||||
ORDER BY date ASC
|
||||
`, days)
|
||||
|
||||
err := r.db.Raw(query).Scan(&trends).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, len(trends))
|
||||
for i, trend := range trends {
|
||||
result[i] = map[string]interface{}{
|
||||
"date": trend.Date,
|
||||
"count": trend.Count,
|
||||
"amount": trend.Amount,
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
151
server/internal/repository/user.go
Normal file
151
server/internal/repository/user.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"dianshang/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// UserRepository 用户仓储
|
||||
type UserRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewUserRepository 创建用户仓储
|
||||
func NewUserRepository(db *gorm.DB) *UserRepository {
|
||||
return &UserRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建用户
|
||||
func (r *UserRepository) Create(user *model.User) error {
|
||||
return r.db.Create(user).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取用户
|
||||
func (r *UserRepository) GetByID(id uint) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.First(&user, id).Error
|
||||
return &user, err
|
||||
}
|
||||
|
||||
// GetByOpenID 根据OpenID获取用户
|
||||
func (r *UserRepository) GetByOpenID(openID string) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.Where("open_id = ?", openID).First(&user).Error
|
||||
return &user, err
|
||||
}
|
||||
|
||||
// Update 更新用户
|
||||
func (r *UserRepository) Update(id uint, updates map[string]interface{}) error {
|
||||
return r.db.Model(&model.User{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// Delete 删除用户(软删除)
|
||||
func (r *UserRepository) Delete(id uint) error {
|
||||
return r.db.Delete(&model.User{}, id).Error
|
||||
}
|
||||
|
||||
// BatchDelete 批量删除用户(软删除)
|
||||
func (r *UserRepository) BatchDelete(ids []uint) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
return r.db.Delete(&model.User{}, ids).Error
|
||||
}
|
||||
|
||||
// GetList 获取用户列表
|
||||
func (r *UserRepository) GetList(offset, limit int, conditions map[string]interface{}) ([]model.User, int64, error) {
|
||||
var users []model.User
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.User{})
|
||||
|
||||
// 添加查询条件
|
||||
for key, value := range conditions {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表
|
||||
err := query.Offset(offset).Limit(limit).Find(&users).Error
|
||||
return users, total, err
|
||||
}
|
||||
|
||||
// GetAddresses 获取用户地址列表
|
||||
func (r *UserRepository) GetAddresses(userID uint) ([]model.UserAddress, error) {
|
||||
var addresses []model.UserAddress
|
||||
err := r.db.Where("user_id = ?", userID).Order("is_default DESC, created_at DESC").Find(&addresses).Error
|
||||
return addresses, err
|
||||
}
|
||||
|
||||
// CreateAddress 创建用户地址
|
||||
func (r *UserRepository) CreateAddress(address *model.UserAddress) error {
|
||||
return r.db.Create(address).Error
|
||||
}
|
||||
|
||||
// GetAddressByID 根据ID获取地址
|
||||
func (r *UserRepository) GetAddressByID(id uint) (*model.UserAddress, error) {
|
||||
var address model.UserAddress
|
||||
err := r.db.First(&address, id).Error
|
||||
return &address, err
|
||||
}
|
||||
|
||||
// UpdateAddress 更新用户地址
|
||||
func (r *UserRepository) UpdateAddress(id uint, updates map[string]interface{}) error {
|
||||
return r.db.Model(&model.UserAddress{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeleteAddress 删除用户地址
|
||||
func (r *UserRepository) DeleteAddress(id uint) error {
|
||||
return r.db.Delete(&model.UserAddress{}, id).Error
|
||||
}
|
||||
|
||||
// ClearDefaultAddress 清除用户的默认地址
|
||||
func (r *UserRepository) ClearDefaultAddress(userID uint) error {
|
||||
return r.db.Model(&model.UserAddress{}).Where("user_id = ?", userID).Update("is_default", 0).Error
|
||||
}
|
||||
|
||||
// GetDefaultAddress 获取用户默认地址
|
||||
func (r *UserRepository) GetDefaultAddress(userID uint) (*model.UserAddress, error) {
|
||||
var address model.UserAddress
|
||||
err := r.db.Where("user_id = ? AND is_default = 1", userID).First(&address).Error
|
||||
return &address, err
|
||||
}
|
||||
|
||||
// GetFavorites 获取用户收藏列表
|
||||
func (r *UserRepository) GetFavorites(userID uint, offset, limit int) ([]model.UserFavorite, int64, error) {
|
||||
var favorites []model.UserFavorite
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.UserFavorite{}).Where("user_id = ?", userID)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表,预加载商品信息
|
||||
err := query.Preload("Product").Offset(offset).Limit(limit).Order("created_at DESC").Find(&favorites).Error
|
||||
return favorites, total, err
|
||||
}
|
||||
|
||||
// CreateFavorite 创建收藏
|
||||
func (r *UserRepository) CreateFavorite(favorite *model.UserFavorite) error {
|
||||
return r.db.Create(favorite).Error
|
||||
}
|
||||
|
||||
// DeleteFavorite 删除收藏
|
||||
func (r *UserRepository) DeleteFavorite(userID, productID uint) error {
|
||||
return r.db.Where("user_id = ? AND product_id = ?", userID, productID).Delete(&model.UserFavorite{}).Error
|
||||
}
|
||||
|
||||
// IsFavorite 检查是否已收藏
|
||||
func (r *UserRepository) IsFavorite(userID, productID uint) bool {
|
||||
var count int64
|
||||
r.db.Model(&model.UserFavorite{}).Where("user_id = ? AND product_id = ?", userID, productID).Count(&count)
|
||||
return count > 0
|
||||
}
|
||||
Reference in New Issue
Block a user