package repository

import (
	"fmt"

	"dianshang/internal/model"
	"gorm.io/gorm"
)

// ProductRepository is the data-access layer for products and the related
// categories, reviews, images, specs, SKUs, tags and stores.
type ProductRepository struct {
	db *gorm.DB
}

// NewProductRepository creates a ProductRepository backed by db.
func NewProductRepository(db *gorm.DB) *ProductRepository {
	return &ProductRepository{db: db}
}

// GetDB returns the underlying database handle.
func (r *ProductRepository) GetDB() *gorm.DB {
	return r.db
}

// GetList returns one page of products matching conditions, with Category
// preloaded, together with the total number of matching rows (ignoring
// offset/limit).
//
// Recognised condition keys:
//   - "status": a concrete status value, or the string "all" to disable the
//     status filter; when the key is absent only products with status 1 are
//     returned.
//   - "category_id": a positive uint/int/int64/uint64/float64 ID; the
//     category's enabled (status 1) direct children are included as well.
//   - "keyword": non-empty string matched with LIKE against name and description.
//   - "min_price" / "max_price": inclusive price bounds.
//   - "is_hot" / "is_new" / "is_recommend": the strings "true"/"false" or a bool.
//   - "sort": "price", "sales" or "created_at"; "sort_type": "asc" for
//     ascending, anything else for descending. Default order is
//     "sort DESC, created_at DESC".
//
// Values of an unexpected type are ignored instead of causing a panic.
func (r *ProductRepository) GetList(offset, limit int, conditions map[string]interface{}) ([]model.Product, int64, error) {
	var (
		products []model.Product
		total    int64
	)

	query := r.db.Model(&model.Product{})

	// Default to on-shelf products only; the admin panel passes "all" to list
	// products in every status.
	if statusValue, exists := conditions["status"]; !exists {
		query = query.Where("status = ?", 1)
	} else if statusValue != "all" {
		query = query.Where("status = ?", statusValue)
	}

	var sortField, sortType string
	for key, value := range conditions {
		switch key {
		case "category_id":
			query = r.applyCategoryFilter(query, value)
		case "keyword":
			if keyword, ok := value.(string); ok {
				query = productKeywordFilter(query, keyword)
			}
		case "min_price":
			query = query.Where("price >= ?", value)
		case "max_price":
			query = query.Where("price <= ?", value)
		case "is_hot", "is_new", "is_recommend":
			// key is one of the three fixed column names matched above, so it
			// is safe to splice into the SQL fragment.
			if flag, ok := productBoolFlag(value); ok {
				query = query.Where(key+" = ?", flag)
			}
		case "sort":
			sortField, _ = value.(string)
		case "sort_type":
			sortType, _ = value.(string)
		}
		// "status" was handled before the loop.
	}

	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	err := query.Preload("Category").
		Offset(offset).Limit(limit).
		Order(productListOrder(sortField, sortType)).
		Find(&products).Error
	return products, total, err
}

// productListOrder maps a whitelisted sort field and direction onto an
// ORDER BY clause. Unknown or empty fields fall back to the default
// "sort DESC, created_at DESC". Only whitelisted names reach the SQL.
func productListOrder(field, direction string) string {
	switch field {
	case "price", "sales", "created_at":
		if direction == "asc" {
			return field + " ASC"
		}
		return field + " DESC"
	default:
		return "sort DESC, created_at DESC"
	}
}

// productBoolFlag interprets a filter value as a boolean. It accepts a bool
// or the exact strings "true"/"false"; ok is false for anything else.
func productBoolFlag(value interface{}) (flag bool, ok bool) {
	switch v := value.(type) {
	case bool:
		return v, true
	case string:
		switch v {
		case "true":
			return true, true
		case "false":
			return false, true
		}
	}
	return false, false
}

// productCategoryIDFromValue converts a numeric filter value into a category
// ID. ok is false for non-numeric or non-positive values, which means "no
// category filter". Fractional floats are truncated; values below 1 are rejected.
func productCategoryIDFromValue(value interface{}) (id uint, ok bool) {
	switch v := value.(type) {
	case uint:
		return v, v > 0
	case uint32:
		return uint(v), v > 0
	case uint64:
		return uint(v), v > 0
	case int:
		if v > 0 {
			return uint(v), true
		}
	case int32:
		if v > 0 {
			return uint(v), true
		}
	case int64:
		if v > 0 {
			return uint(v), true
		}
	case float64:
		// Guard before converting: converting a negative float to uint is
		// implementation-dependent in Go.
		if v >= 1 {
			return uint(v), true
		}
	}
	return 0, false
}

// productKeywordFilter adds a LIKE match of keyword against name or
// description. An empty keyword adds no filter. The OR is parenthesised
// explicitly so it cannot bind with neighbouring AND conditions.
func productKeywordFilter(query *gorm.DB, keyword string) *gorm.DB {
	if keyword == "" {
		return query
	}
	pattern := "%" + keyword + "%"
	return query.Where("(name LIKE ? OR description LIKE ?)", pattern, pattern)
}

// applyCategoryFilter restricts query to the category identified by value and
// its enabled direct children. Values that are not a positive ID add no
// filter. If the child lookup fails, the filter falls back to the category itself.
func (r *ProductRepository) applyCategoryFilter(query *gorm.DB, value interface{}) *gorm.DB {
	catID, ok := productCategoryIDFromValue(value)
	if !ok {
		return query
	}
	categoryIDs, err := r.getCategoryIDsIncludingChildren(catID)
	if err != nil || len(categoryIDs) == 0 {
		return query.Where("category_id = ?", catID)
	}
	return query.Where("category_id IN (?)", categoryIDs)
}

// GetByID returns a product with its Category, Specs and enabled (status 1)
// SKUs preloaded. The returned pointer is non-nil even when err != nil, so
// callers must check err first.
func (r *ProductRepository) GetByID(id uint) (*model.Product, error) {
	var product model.Product
	err := r.db.Preload("Category").Preload("Specs").Preload("SKUs", "status = ?", 1).
		Where("id = ?", id).First(&product).Error
	return &product, err
}

// Create inserts a new product.
func (r *ProductRepository) Create(product *model.Product) error {
	return r.db.Create(product).Error
}

// Update applies the given column updates to the product with id.
func (r *ProductRepository) Update(id uint, updates map[string]interface{}) error {
	return r.db.Model(&model.Product{}).Where("id = ?", id).Updates(updates).Error
}

// Delete soft-deletes a product via GORM. Only the product row is marked
// deleted; related rows are kept, and the product's deleted_at column tells
// whether the product has been removed.
func (r *ProductRepository) Delete(id uint) error {
	return r.db.Delete(&model.Product{}, id).Error
}

// UpdateStock adds quantity (which may be negative) to a product's stock
// without any sufficiency check.
func (r *ProductRepository) UpdateStock(id uint, quantity int) error {
	return r.db.Model(&model.Product{}).Where("id = ?", id).
		Update("stock", gorm.Expr("stock + ?", quantity)).Error
}

// UpdateSales adds quantity to a product's sales counter.
func (r *ProductRepository) UpdateSales(id uint, quantity int) error {
	return r.db.Model(&model.Product{}).Where("id = ?", id).
		Update("sales", gorm.Expr("sales + ?", quantity)).Error
}

// DeductStock deducts product stock (intended for use when a payment succeeds).
// It returns gorm.ErrRecordNotFound when the product does not exist or its
// stock is insufficient, and an error for a negative quantity.
func (r *ProductRepository) DeductStock(id uint, quantity int) error {
	return r.decrementStockWith(r.db, &model.Product{}, id, quantity)
}

// ReserveStock pre-deducts product stock (intended for use when an order is
// created). Its semantics are identical to DeductStock.
func (r *ProductRepository) ReserveStock(id uint, quantity int) error {
	return r.decrementStockWith(r.db, &model.Product{}, id, quantity)
}

// RestoreStock adds quantity back to a product's stock (intended for cancelled
// orders or failed payments). A negative quantity is rejected.
func (r *ProductRepository) RestoreStock(id uint, quantity int) error {
	return r.incrementStockWith(r.db, &model.Product{}, id, quantity)
}

// checkProductStockQuantity rejects negative quantities for the
// deduct/reserve/restore helpers. A negative value would invert the
// operation, and for deductions it would also bypass the "stock >= ?" guard.
func checkProductStockQuantity(quantity int) error {
	if quantity < 0 {
		return fmt.Errorf("stock quantity must not be negative: %d", quantity)
	}
	return nil
}

// decrementStockWith subtracts quantity from the stock column of the row with
// the given id. The table is chosen by modelPtr (e.g. &model.Product{} or
// &model.ProductSKU{}), and db may be the base handle or an open transaction.
//
// The sufficiency check is part of the UPDATE's WHERE clause, so check and
// decrement are a single statement. gorm.ErrRecordNotFound is returned when no
// row was updated: the row does not exist or its stock is below quantity.
func (r *ProductRepository) decrementStockWith(db *gorm.DB, modelPtr interface{}, id uint, quantity int) error {
	if err := checkProductStockQuantity(quantity); err != nil {
		return err
	}
	result := db.Model(modelPtr).
		Where("id = ? AND stock >= ?", id, quantity).
		Update("stock", gorm.Expr("stock - ?", quantity))
	if result.Error != nil {
		return result.Error
	}
	if result.RowsAffected == 0 {
		return gorm.ErrRecordNotFound
	}
	return nil
}

// incrementStockWith adds quantity to the stock column of the row with the
// given id. The table is chosen by modelPtr, and db may be the base handle or
// an open transaction.
func (r *ProductRepository) incrementStockWith(db *gorm.DB, modelPtr interface{}, id uint, quantity int) error {
	if err := checkProductStockQuantity(quantity); err != nil {
		return err
	}
	return db.Model(modelPtr).Where("id = ?", id).
		Update("stock", gorm.Expr("stock + ?", quantity)).Error
}

// GetCategories returns enabled (status 1) categories as a two-level tree.
// Top-level categories (nil ParentID) come first, in "level ASC, sort DESC,
// created_at ASC" order, each with its enabled direct children in Children.
// Children whose parent is not enabled are dropped.
func (r *ProductRepository) GetCategories() ([]model.Category, error) {
	var allCategories []model.Category
	err := r.db.Where("status = ?", 1).Order("level ASC, sort DESC, created_at ASC").Find(&allCategories).Error
	if err != nil {
		return nil, err
	}

	// Index every category by ID so children can be attached to their parent.
	categoryMap := make(map[uint]*model.Category, len(allCategories))
	for i := range allCategories {
		categoryMap[allCategories[i].ID] = &allCategories[i]
	}

	var rootCategories []model.Category
	for i := range allCategories {
		category := &allCategories[i]
		if category.ParentID == nil {
			// Top-level category: copied now, its Children are filled in below.
			rootCategories = append(rootCategories, *category)
			continue
		}
		if parent, exists := categoryMap[*category.ParentID]; exists {
			parent.Children = append(parent.Children, *category)
		}
	}

	// Copy the collected children onto the root copies and set HasChildren.
	for i := range rootCategories {
		if parent, exists := categoryMap[rootCategories[i].ID]; exists {
			rootCategories[i].Children = parent.Children
			rootCategories[i].HasChildren = len(parent.Children) > 0
		}
		// Only two levels are exposed, so second-level entries never report children.
		for j := range rootCategories[i].Children {
			rootCategories[i].Children[j].HasChildren = false
		}
	}

	return rootCategories, nil
}

// GetCategoryByID returns an enabled (status 1) category. The returned
// pointer is non-nil even when err != nil.
func (r *ProductRepository) GetCategoryByID(id uint) (*model.Category, error) {
	var category model.Category
	err := r.db.Where("id = ? AND status = ?", id, 1).First(&category).Error
	return &category, err
}

// CreateCategory inserts a new category.
func (r *ProductRepository) CreateCategory(category *model.Category) error {
	return r.db.Create(category).Error
}

// UpdateCategory applies the given column updates to the category with id.
func (r *ProductRepository) UpdateCategory(id uint, updates map[string]interface{}) error {
	return r.db.Model(&model.Category{}).Where("id = ?", id).Updates(updates).Error
}

// DeleteCategory deletes the category with id via GORM Delete.
func (r *ProductRepository) DeleteCategory(id uint) error {
	return r.db.Delete(&model.Category{}, id).Error
}

// GetReviews returns one page of a product's reviews, newest first, with the
// total review count.
func (r *ProductRepository) GetReviews(productID uint, offset, limit int) ([]model.ProductReview, int64, error) {
	var reviews []model.ProductReview
	var total int64

	query := r.db.Model(&model.ProductReview{}).Where("product_id = ?", productID)

	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	err := query.Offset(offset).Limit(limit).
		Order("created_at DESC").Find(&reviews).Error
	return reviews, total, err
}

// CreateReview inserts a new review.
func (r *ProductRepository) CreateReview(review *model.ProductReview) error {
	return r.db.Create(review).Error
}

// GetReviewByOrderID returns the review a user left for an order. The
// returned pointer is non-nil even when err != nil.
func (r *ProductRepository) GetReviewByOrderID(userID uint, orderID uint) (*model.ProductReview, error) {
	var review model.ProductReview
	err := r.db.Where("user_id = ? AND order_id = ?", userID, orderID).First(&review).Error
	return &review, err
}

// GetReviewCount returns review statistics for a product:
//   - total: number of reviews
//   - good_count: rating >= 4
//   - middle_count: rating == 3
//   - bad_count: rating <= 2
//   - has_image_count: reviews with non-empty images
//   - good_rate: good_count / total * 100, or 100 when there are no reviews
//
// All counts are computed in a single aggregate query.
func (r *ProductRepository) GetReviewCount(productID uint) (map[string]interface{}, error) {
	type reviewStats struct {
		Total         int64
		GoodCount     int64
		MiddleCount   int64
		BadCount      int64
		HasImageCount int64
	}

	var stats reviewStats
	err := r.db.Model(&model.ProductReview{}).
		Select(`COUNT(*) AS total,
			COALESCE(SUM(CASE WHEN rating >= 4 THEN 1 ELSE 0 END), 0) AS good_count,
			COALESCE(SUM(CASE WHEN rating = 3 THEN 1 ELSE 0 END), 0) AS middle_count,
			COALESCE(SUM(CASE WHEN rating <= 2 THEN 1 ELSE 0 END), 0) AS bad_count,
			COALESCE(SUM(CASE WHEN images IS NOT NULL AND images != '' THEN 1 ELSE 0 END), 0) AS has_image_count`).
		Where("product_id = ?", productID).
		Scan(&stats).Error
	if err != nil {
		return nil, err
	}

	goodRate := 100.0
	if stats.Total > 0 {
		goodRate = float64(stats.GoodCount) / float64(stats.Total) * 100
	}

	return map[string]interface{}{
		"total":           stats.Total,
		"good_rate":       goodRate,
		"good_count":      stats.GoodCount,
		"middle_count":    stats.MiddleCount,
		"bad_count":       stats.BadCount,
		"has_image_count": stats.HasImageCount,
	}, nil
}

// GetProductImages returns a product's images ordered by sort ascending.
func (r *ProductRepository) GetProductImages(productID uint) ([]model.ProductImage, error) {
	var images []model.ProductImage
	err := r.db.Where("product_id = ?", productID).Order("sort ASC").Find(&images).Error
	return images, err
}

// CreateProductImage inserts a new product image.
func (r *ProductRepository) CreateProductImage(image *model.ProductImage) error {
	return r.db.Create(image).Error
}

// DeleteProductImage deletes the product image with id.
func (r *ProductRepository) DeleteProductImage(id uint) error {
	return r.db.Delete(&model.ProductImage{}, id).Error
}

// GetProductSpecs returns a product's specs ordered by sort ascending.
func (r *ProductRepository) GetProductSpecs(productID uint) ([]model.ProductSpec, error) {
	var specs []model.ProductSpec
	err := r.db.Where("product_id = ?", productID).Order("sort ASC").Find(&specs).Error
	return specs, err
}

// CreateProductSpec inserts a new product spec.
func (r *ProductRepository) CreateProductSpec(spec *model.ProductSpec) error {
	return r.db.Create(spec).Error
}

// UpdateProductSpec applies the given column updates to the spec with id.
func (r *ProductRepository) UpdateProductSpec(id uint, updates map[string]interface{}) error {
	return r.db.Model(&model.ProductSpec{}).Where("id = ?", id).Updates(updates).Error
}

// DeleteProductSpec deletes the spec with id.
func (r *ProductRepository) DeleteProductSpec(id uint) error {
	return r.db.Delete(&model.ProductSpec{}, id).Error
}

// GetHotProducts returns up to limit on-shelf hot products, best-selling first.
func (r *ProductRepository) GetHotProducts(limit int) ([]model.Product, error) {
	var products []model.Product
	err := r.db.Preload("Category").
		Where("status = ? AND is_hot = ?", 1, 1).
		Order("sales DESC, created_at DESC").Limit(limit).Find(&products).Error
	return products, err
}

// GetRecommendProducts returns up to limit on-shelf recommended products.
func (r *ProductRepository) GetRecommendProducts(limit int) ([]model.Product, error) {
	var products []model.Product
	err := r.db.Preload("Category").
		Where("status = ? AND is_recommend = ?", 1, 1).
		Order("sort DESC, created_at DESC").Limit(limit).Find(&products).Error
	return products, err
}

// GetProductSKUs returns a product's enabled (status 1) SKUs ordered by weight.
func (r *ProductRepository) GetProductSKUs(productID uint) ([]model.ProductSKU, error) {
	var skus []model.ProductSKU
	err := r.db.Where("product_id = ? AND status = ?", productID, 1).
		Order("weight ASC").Find(&skus).Error
	return skus, err
}

// GetSKUByID returns an enabled (status 1) SKU. The returned pointer is
// non-nil even when err != nil.
func (r *ProductRepository) GetSKUByID(skuID uint) (*model.ProductSKU, error) {
	var sku model.ProductSKU
	err := r.db.Where("id = ? AND status = ?", skuID, 1).First(&sku).Error
	return &sku, err
}

// GetProductTags returns enabled (status 1) tags ordered by sort ascending.
func (r *ProductRepository) GetProductTags() ([]model.ProductTag, error) {
	var tags []model.ProductTag
	err := r.db.Where("status = ?", 1).Order("sort ASC").Find(&tags).Error
	return tags, err
}

// GetStores returns enabled (status 1) stores ordered by sort ascending.
func (r *ProductRepository) GetStores() ([]model.Store, error) {
	var stores []model.Store
	err := r.db.Where("status = ?", 1).Order("sort ASC").Find(&stores).Error
	return stores, err
}

// GetStoreByID returns an enabled (status 1) store. The returned pointer is
// non-nil even when err != nil.
func (r *ProductRepository) GetStoreByID(id uint) (*model.Store, error) {
	var store model.Store
	err := r.db.Where("id = ? AND status = ?", id, 1).First(&store).Error
	return &store, err
}

// GetProductStatistics returns:
//   - total_products: all products
//   - online_products: status 1
//   - offline_products: status 0
//   - low_stock_products: stock < 10
func (r *ProductRepository) GetProductStatistics() (map[string]interface{}, error) {
	result := make(map[string]interface{})

	var totalProducts int64
	if err := r.db.Model(&model.Product{}).Count(&totalProducts).Error; err != nil {
		return nil, err
	}
	result["total_products"] = totalProducts

	var onlineProducts int64
	if err := r.db.Model(&model.Product{}).Where("status = ?", 1).Count(&onlineProducts).Error; err != nil {
		return nil, err
	}
	result["online_products"] = onlineProducts

	var offlineProducts int64
	if err := r.db.Model(&model.Product{}).Where("status = ?", 0).Count(&offlineProducts).Error; err != nil {
		return nil, err
	}
	result["offline_products"] = offlineProducts

	var lowStockProducts int64
	if err := r.db.Model(&model.Product{}).Where("stock < ?", 10).Count(&lowStockProducts).Error; err != nil {
		return nil, err
	}
	result["low_stock_products"] = lowStockProducts

	return result, nil
}

// GetCategorySalesStatistics returns per-category sales quantity and amount
// for orders created between startDate 00:00:00 and endDate 23:59:59 whose
// status is in (2, 3, 4, 5, 6). Results are ordered by amount descending and
// limited to limit rows. Every enabled category appears, with zero sales if it
// had none. startDate/endDate are presumably "YYYY-MM-DD" strings, since the
// day bounds are appended.
//
// The order filters live inside the derived table rather than the outer WHERE.
// Filtering LEFT JOINed columns in WHERE would drop categories whose only
// orders fall outside the range, and would count order items with no order.
func (r *ProductRepository) GetCategorySalesStatistics(startDate, endDate string, limit int) ([]map[string]interface{}, error) {
	var results []map[string]interface{}

	query := `
		SELECT
			c.id AS category_id,
			c.name AS category_name,
			COALESCE(SUM(s.quantity), 0) AS sales_count,
			COALESCE(SUM(s.total_price), 0) AS sales_amount
		FROM ai_categories c
		LEFT JOIN ai_products p ON c.id = p.category_id
		LEFT JOIN (
			SELECT oi.product_id, oi.quantity, oi.total_price
			FROM order_items oi
			INNER JOIN ai_orders o ON oi.order_id = o.id
			WHERE o.created_at >= ? AND o.created_at <= ?
				AND o.status IN (2, 3, 4, 5, 6)
		) s ON p.id = s.product_id
		WHERE c.status = 1
		GROUP BY c.id, c.name
		ORDER BY sales_amount DESC
		LIMIT ?
	`

	err := r.db.Raw(query, startDate+" 00:00:00", endDate+" 23:59:59", limit).Scan(&results).Error
	return results, err
}

// BatchUpdateStatus sets status on every product in ids. An empty ids slice
// is a no-op.
func (r *ProductRepository) BatchUpdateStatus(ids []uint, status int) error {
	if len(ids) == 0 {
		return nil
	}
	return r.db.Model(&model.Product{}).Where("id IN ?", ids).Update("status", status).Error
}

// BatchUpdatePrice applies each update map to the product named by its "id"
// entry, all inside a single transaction. Maps without an "id" key are
// skipped. The caller's maps are not modified. Any error rolls back every
// update; a panic rolls back and is re-raised by GORM's Transaction.
func (r *ProductRepository) BatchUpdatePrice(updates []map[string]interface{}) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		for _, update := range updates {
			id, ok := update["id"]
			if !ok {
				continue
			}
			// Copy without "id" so it is not part of the SET clause and the
			// caller's map stays intact.
			fields := make(map[string]interface{}, len(update))
			for k, v := range update {
				if k != "id" {
					fields[k] = v
				}
			}
			if err := tx.Model(&model.Product{}).Where("id = ?", id).Updates(fields).Error; err != nil {
				return err
			}
		}
		return nil
	})
}

// CountProductsByCategory counts products in a category and its direct
// children. Unlike the list filters, children are included regardless of
// their status. Soft-deleted products are excluded by GORM's default scope.
func (r *ProductRepository) CountProductsByCategory(categoryID uint) (int64, error) {
	var childIDs []uint
	if err := r.db.Model(&model.Category{}).
		Where("parent_id = ?", categoryID).
		Pluck("id", &childIDs).Error; err != nil {
		return 0, err
	}
	categoryIDs := append([]uint{categoryID}, childIDs...)

	var count int64
	err := r.db.Model(&model.Product{}).
		Where("category_id IN (?)", categoryIDs).
		Count(&count).Error
	return count, err
}

// BatchDelete soft-deletes the products in ids via GORM. Only the product rows
// are marked deleted; related rows are kept. An empty ids slice is a no-op.
func (r *ProductRepository) BatchDelete(ids []uint) error {
	if len(ids) == 0 {
		return nil
	}
	return r.db.Delete(&model.Product{}, ids).Error
}

// CreateSKU inserts a new SKU.
func (r *ProductRepository) CreateSKU(sku *model.ProductSKU) error {
	return r.db.Create(sku).Error
}

// UpdateSKU applies the given column updates to the SKU with id.
func (r *ProductRepository) UpdateSKU(id uint, updates map[string]interface{}) error {
	return r.db.Model(&model.ProductSKU{}).Where("id = ?", id).Updates(updates).Error
}

// DeleteSKU removes a SKU. If any order_items row references it, the SKU is
// only disabled (status set to 0) to keep order history intact. Otherwise it
// is removed with GORM Delete, which is a soft delete if ProductSKU has a
// DeletedAt field.
func (r *ProductRepository) DeleteSKU(id uint) error {
	var count int64
	// "sk_uid" is the column name GORM's naming strategy derives for SKUID.
	if err := r.db.Table("order_items").Where("sk_uid = ?", id).Count(&count).Error; err != nil {
		return err
	}
	if count > 0 {
		return r.db.Model(&model.ProductSKU{}).Where("id = ?", id).Update("status", 0).Error
	}
	return r.db.Delete(&model.ProductSKU{}, id).Error
}

// UpdateProductImageSort sets the sort value of a product image.
func (r *ProductRepository) UpdateProductImageSort(id uint, sort int) error {
	return r.db.Model(&model.ProductImage{}).Where("id = ?", id).Update("sort", sort).Error
}

// CreateProductTag inserts a new tag.
func (r *ProductRepository) CreateProductTag(tag *model.ProductTag) error {
	return r.db.Create(tag).Error
}

// UpdateProductTag applies the given column updates to the tag with id.
func (r *ProductRepository) UpdateProductTag(id uint, updates map[string]interface{}) error {
	return r.db.Model(&model.ProductTag{}).Where("id = ?", id).Updates(updates).Error
}

// DeleteProductTag deletes the tag with id.
func (r *ProductRepository) DeleteProductTag(id uint) error {
	return r.db.Delete(&model.ProductTag{}, id).Error
}

// AssignTagsToProduct replaces a product's tag relations with tagIDs. The
// delete and the insert run in one transaction, so a failed insert cannot
// leave the product without its previous tags. An empty tagIDs clears all tags.
func (r *ProductRepository) AssignTagsToProduct(productID uint, tagIDs []uint) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Exec("DELETE FROM ai_product_tag_relations WHERE product_id = ?", productID).Error; err != nil {
			return err
		}
		if len(tagIDs) == 0 {
			return nil
		}

		relations := make([]map[string]interface{}, 0, len(tagIDs))
		for _, tagID := range tagIDs {
			relations = append(relations, map[string]interface{}{
				"product_id": productID,
				"tag_id":     tagID,
			})
		}
		return tx.Table("ai_product_tag_relations").Create(relations).Error
	})
}

// GetLowStockProducts returns on-shelf products whose stock is at most
// threshold, with Category preloaded.
func (r *ProductRepository) GetLowStockProducts(threshold int) ([]model.Product, error) {
	var products []model.Product
	err := r.db.Where("stock <= ? AND status = ?", threshold, 1).
		Preload("Category").Find(&products).Error
	return products, err
}

// GetInventoryStatistics returns inventory figures for on-shelf (status 1)
// products:
//   - total_products
//   - low_stock_products: stock <= 10
//   - out_of_stock_products: stock == 0
//   - total_stock_value: SUM(stock * price), 0 when there are no products
//
// Any query error is returned rather than silently reported as zero.
func (r *ProductRepository) GetInventoryStatistics() (map[string]interface{}, error) {
	// onShelf starts a fresh query each call so conditions do not accumulate.
	onShelf := func() *gorm.DB {
		return r.db.Model(&model.Product{}).Where("status = ?", 1)
	}

	var totalProducts, lowStockProducts, outOfStockProducts int64
	if err := onShelf().Count(&totalProducts).Error; err != nil {
		return nil, err
	}
	if err := onShelf().Where("stock <= ?", 10).Count(&lowStockProducts).Error; err != nil {
		return nil, err
	}
	if err := onShelf().Where("stock = ?", 0).Count(&outOfStockProducts).Error; err != nil {
		return nil, err
	}

	var totalStockValue float64
	// COALESCE: SUM over zero rows is NULL, which cannot be scanned into float64.
	if err := onShelf().Select("COALESCE(SUM(stock * price), 0)").Scan(&totalStockValue).Error; err != nil {
		return nil, err
	}

	return map[string]interface{}{
		"total_products":        totalProducts,
		"low_stock_products":    lowStockProducts,
		"out_of_stock_products": outOfStockProducts,
		"total_stock_value":     totalStockValue,
	}, nil
}

// GetProductsForExport returns all products matching conditions, with
// Category preloaded. Recognised keys:
//   - "category_id": includes enabled direct children, as in GetList.
//   - "status": a concrete value, or "all" for no filter. There is no default
//     status filter here.
//   - "keyword": LIKE against name and description.
//
// Values of an unexpected type are ignored.
func (r *ProductRepository) GetProductsForExport(conditions map[string]interface{}) ([]model.Product, error) {
	var products []model.Product
	query := r.db.Model(&model.Product{}).Preload("Category")

	for key, value := range conditions {
		switch key {
		case "category_id":
			query = r.applyCategoryFilter(query, value)
		case "status":
			if value != "all" {
				query = query.Where("status = ?", value)
			}
		case "keyword":
			if keyword, ok := value.(string); ok {
				query = productKeywordFilter(query, keyword)
			}
		}
	}

	err := query.Find(&products).Error
	return products, err
}

// getCategoryIDsIncludingChildren returns categoryID followed by the IDs of
// its enabled (status 1) direct children. Only two category levels are
// supported. On error the slice still contains categoryID.
func (r *ProductRepository) getCategoryIDsIncludingChildren(categoryID uint) ([]uint, error) {
	ids := []uint{categoryID}

	var childIDs []uint
	if err := r.db.Model(&model.Category{}).
		Where("parent_id = ? AND status = ?", categoryID, 1).
		Pluck("id", &childIDs).Error; err != nil {
		return ids, err
	}
	return append(ids, childIDs...), nil
}

// DeductSKUStock deducts SKU stock. It returns gorm.ErrRecordNotFound when
// the SKU does not exist or its stock is insufficient, and an error for a
// negative quantity.
func (r *ProductRepository) DeductSKUStock(skuID uint, quantity int) error {
	return r.decrementStockWith(r.db, &model.ProductSKU{}, skuID, quantity)
}

// RestoreSKUStock adds quantity back to a SKU's stock. A negative quantity is
// rejected.
func (r *ProductRepository) RestoreSKUStock(skuID uint, quantity int) error {
	return r.incrementStockWith(r.db, &model.ProductSKU{}, skuID, quantity)
}

// SyncProductStockFromSKUs sets a product's stock to the sum of its enabled
// (status 1) SKUs' stock, or 0 if it has none.
func (r *ProductRepository) SyncProductStockFromSKUs(productID uint) error {
	return r.syncProductStockWith(r.db, productID)
}

// syncProductStockWith recomputes a product's stock from its enabled SKUs
// using db, which may be the base handle or an open transaction.
func (r *ProductRepository) syncProductStockWith(db *gorm.DB, productID uint) error {
	var totalStock int
	if err := db.Model(&model.ProductSKU{}).
		Where("product_id = ? AND status = ?", productID, 1).
		Select("COALESCE(SUM(stock), 0)").Scan(&totalStock).Error; err != nil {
		return err
	}
	return db.Model(&model.Product{}).Where("id = ?", productID).
		Update("stock", totalStock).Error
}

// DeductStockWithSKU deducts stock inside one transaction.
//   - When skuID is non-nil and non-zero: deducts the SKU's stock, then
//     recomputes the product's stock from its enabled SKUs.
//   - Otherwise: deducts the product's stock directly.
//
// Insufficient stock yields gorm.ErrRecordNotFound. Any error rolls back the
// transaction; a panic rolls back and is re-raised by GORM's Transaction.
func (r *ProductRepository) DeductStockWithSKU(productID uint, skuID *uint, quantity int) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		if skuID == nil || *skuID == 0 {
			return r.decrementStockWith(tx, &model.Product{}, productID, quantity)
		}
		if err := r.decrementStockWith(tx, &model.ProductSKU{}, *skuID, quantity); err != nil {
			return err
		}
		return r.syncProductStockWith(tx, productID)
	})
}

// RestoreStockWithSKU restores stock inside one transaction.
//   - When skuID is non-nil and non-zero: adds quantity back to the SKU, then
//     recomputes the product's stock from its enabled SKUs.
//   - Otherwise: adds quantity back to the product directly.
//
// Any error rolls back the transaction; a panic rolls back and is re-raised
// by GORM's Transaction.
func (r *ProductRepository) RestoreStockWithSKU(productID uint, skuID *uint, quantity int) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		if skuID == nil || *skuID == 0 {
			return r.incrementStockWith(tx, &model.Product{}, productID, quantity)
		}
		if err := r.incrementStockWith(tx, &model.ProductSKU{}, *skuID, quantity); err != nil {
			return err
		}
		return r.syncProductStockWith(tx, productID)
	})
}