Files
ai_dianshang/server/internal/repository/product.go

905 lines
27 KiB
Go
Raw Normal View History

2025-11-17 14:11:46 +08:00
package repository
import (
"dianshang/internal/model"
"gorm.io/gorm"
)
// ProductRepository is the data-access layer for products and the records
// that hang off them (categories, reviews, images, specs, SKUs, tags, stores).
type ProductRepository struct {
	// db is the GORM handle every method in this repository queries through.
	db *gorm.DB
}
// NewProductRepository builds a ProductRepository that runs its queries
// through the given GORM handle.
func NewProductRepository(db *gorm.DB) *ProductRepository {
	repo := &ProductRepository{}
	repo.db = db
	return repo
}
// GetDB returns the GORM handle held by the repository.
func (r *ProductRepository) GetDB() *gorm.DB {
	handle := r.db
	return handle
}
// GetList returns one page of products and the total number of products that
// match the filters (the count ignores offset/limit). Category is preloaded.
//
// Recognised keys in conditions:
//   - "status": exact status filter; the string "all" disables it. When the key
//     is absent only status 1 (on shelf) is returned.
//   - "category_id" (uint, uint32, uint64, int, int32, int64 or float64):
//     filters by that category and its direct active children.
//   - "keyword" (string): substring match on name OR description.
//   - "min_price" / "max_price": inclusive price bounds.
//   - "is_hot" / "is_new" / "is_recommend": "true"/"false" strings or a bool.
//   - "sort": "price", "sales" or "created_at"; "sort_type": "asc", else DESC.
//
// Values of an unexpected type are ignored. Previously they panicked on an
// unchecked type assertion.
func (r *ProductRepository) GetList(offset, limit int, conditions map[string]interface{}) ([]model.Product, int64, error) {
	var products []model.Product
	var total int64

	query := r.db.Model(&model.Product{})

	// Default to on-shelf products only; the admin side passes "all" to see every status.
	if statusValue, exists := conditions["status"]; exists {
		if statusValue != "all" {
			query = query.Where("status = ?", statusValue)
		}
	} else {
		query = query.Where("status = ?", 1)
	}

	var sortField, sortType string
	for key, value := range conditions {
		switch key {
		case "category_id":
			if catID := productListCategoryID(value); catID > 0 {
				categoryIDs, err := r.getCategoryIDsIncludingChildren(catID)
				if err == nil && len(categoryIDs) > 0 {
					query = query.Where("category_id IN (?)", categoryIDs)
				} else {
					// Fallback: if the child lookup fails, filter by this category only.
					query = query.Where("category_id = ?", catID)
				}
			}
		case "keyword":
			if keyword, ok := value.(string); ok {
				// NOTE(review): LIKE wildcards ('%', '_') in keyword are not escaped.
				pattern := "%" + keyword + "%"
				query = query.Where("(name LIKE ? OR description LIKE ?)", pattern, pattern)
			}
		case "min_price":
			query = query.Where("price >= ?", value)
		case "max_price":
			query = query.Where("price <= ?", value)
		case "is_hot", "is_new", "is_recommend":
			if flag, ok := productListBoolFlag(value); ok {
				// key is one of the fixed column names in this case label, never user input.
				query = query.Where(key+" = ?", flag)
			}
		case "sort":
			sortField, _ = value.(string)
		case "sort_type":
			sortType, _ = value.(string)
		}
	}

	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	err := query.Preload("Category").
		Offset(offset).Limit(limit).
		Order(productListOrderBy(sortField, sortType)).
		Find(&products).Error
	return products, total, err
}

// productListCategoryID converts a "category_id" condition value to a uint.
// It returns 0 (no filter) for unsupported types.
func productListCategoryID(value interface{}) uint {
	switch v := value.(type) {
	case uint:
		return v
	case uint32:
		return uint(v)
	case uint64:
		return uint(v)
	case int:
		return uint(v)
	case int32:
		return uint(v)
	case int64:
		return uint(v)
	case float64:
		return uint(v)
	}
	return 0
}

// productListBoolFlag interprets an is_hot/is_new/is_recommend value. It accepts
// "true"/"false" strings and bools; ok is false otherwise, so the filter is skipped.
func productListBoolFlag(value interface{}) (flag bool, ok bool) {
	switch v := value.(type) {
	case string:
		switch v {
		case "true":
			return true, true
		case "false":
			return false, true
		}
	case bool:
		return v, true
	}
	return false, false
}

// productListOrderBy builds the ORDER BY clause for GetList. Unknown sort fields
// fall back to "sort DESC, created_at DESC". A trailing id tie-breaker keeps
// pagination stable when many rows share the same sort key.
func productListOrderBy(sortField, sortType string) string {
	direction := "DESC"
	if sortType == "asc" {
		direction = "ASC"
	}
	switch sortField {
	case "price", "sales", "created_at":
		return sortField + " " + direction + ", id " + direction
	default:
		return "sort DESC, created_at DESC, id DESC"
	}
}
// GetByID loads one product by primary key. Category and Specs are preloaded,
// and so are only the SKUs with status 1. The returned pointer is non-nil even
// when err is set (e.g. gorm.ErrRecordNotFound from First).
func (r *ProductRepository) GetByID(id uint) (*model.Product, error) {
	product := &model.Product{}
	tx := r.db.
		Preload("Category").
		Preload("Specs").
		Preload("SKUs", "status = ?", 1).
		Where("id = ?", id)
	err := tx.First(product).Error
	return product, err
}
// Create inserts a new product row.
func (r *ProductRepository) Create(product *model.Product) error {
	result := r.db.Create(product)
	return result.Error
}
// Update writes the given column/value pairs to the product with the given id.
// Because a map is passed to Updates, zero values in it are written too.
func (r *ProductRepository) Update(id uint, updates map[string]interface{}) error {
	target := r.db.Model(&model.Product{}).Where("id = ?", id)
	return target.Updates(updates).Error
}
// Delete removes the product with the given id through GORM's Delete.
// Only the product row is touched: related records are kept, and callers rely on
// the product's deleted_at column to know it was deleted.
// NOTE(review): this is a soft delete only if model.Product has a DeletedAt field.
func (r *ProductRepository) Delete(id uint) error {
	res := r.db.Delete(&model.Product{}, id)
	return res.Error
}
// UpdateStock adds quantity to the product's stock in SQL (stock = stock + quantity).
// A negative quantity lowers it; there is no lower-bound check.
func (r *ProductRepository) UpdateStock(id uint, quantity int) error {
	delta := gorm.Expr("stock + ?", quantity)
	return r.db.Model(&model.Product{}).
		Where("id = ?", id).
		Update("stock", delta).
		Error
}
// UpdateSales adds quantity to the product's sales counter in SQL
// (sales = sales + quantity).
func (r *ProductRepository) UpdateSales(id uint, quantity int) error {
	delta := gorm.Expr("sales + ?", quantity)
	return r.db.Model(&model.Product{}).
		Where("id = ?", id).
		Update("sales", delta).
		Error
}
// DeductStock subtracts quantity from the product's stock (used when payment succeeds).
// The sufficiency check and the decrement form a single conditional UPDATE
// (WHERE stock >= quantity). It returns gorm.ErrRecordNotFound when no row was
// updated, i.e. the product does not exist or its stock is insufficient.
func (r *ProductRepository) DeductStock(id uint, quantity int) error {
	res := r.db.Model(&model.Product{}).
		Where("id = ? AND stock >= ?", id, quantity).
		Update("stock", gorm.Expr("stock - ?", quantity))
	switch {
	case res.Error != nil:
		return res.Error
	case res.RowsAffected == 0:
		return gorm.ErrRecordNotFound
	default:
		return nil
	}
}
// ReserveStock pre-deducts stock when an order is created. It performs exactly
// the same conditional decrement as DeductStock and returns the same errors
// (gorm.ErrRecordNotFound when the product is missing or stock is insufficient).
func (r *ProductRepository) ReserveStock(id uint, quantity int) error {
	return r.DeductStock(id, quantity)
}
// RestoreStock adds quantity back to the product's stock, e.g. after an order is
// cancelled or payment fails. A missing product is not reported as an error.
func (r *ProductRepository) RestoreStock(id uint, quantity int) error {
	delta := gorm.Expr("stock + ?", quantity)
	res := r.db.Model(&model.Product{}).Where("id = ?", id).Update("stock", delta)
	return res.Error
}
// GetCategories returns the active (status = 1) categories as a two-level tree.
//
// Rows are loaded in "level ASC, sort DESC, created_at ASC" order.
// - Categories with a nil ParentID become roots, in load order.
// - A category whose ParentID refers to another loaded category is appended to
//   that category's Children. Children whose parent is not loaded (e.g. the
//   parent is inactive) are dropped.
// - Only roots are returned. HasChildren is set from the number of children,
//   and each child gets HasChildren = false.
//
// NOTE(review): children are stored as value copies taken at append time. A
// third-level category is appended to its parent's entry in categoryMap, not to
// the copy held in the root's Children, so it does not appear in the result.
// This matches the assumption that second-level categories have no children.
func (r *ProductRepository) GetCategories() ([]model.Category, error) {
var allCategories []model.Category
err := r.db.Where("status = ?", 1).Order("level ASC, sort DESC, created_at ASC").Find(&allCategories).Error
if err != nil {
return nil, err
}
// Build the hierarchy.
categoryMap := make(map[uint]*model.Category)
var rootCategories []model.Category
// First index every loaded category by ID (pointers into allCategories).
for i := range allCategories {
categoryMap[allCategories[i].ID] = &allCategories[i]
}
// Then attach each category to its parent (or collect it as a root).
for i := range allCategories {
category := &allCategories[i]
if category.ParentID == nil {
// Top-level category. A copy is stored; its Children are filled in below.
rootCategories = append(rootCategories, *category)
} else {
// Second-level category: append a copy to the parent's Children.
if parent, exists := categoryMap[*category.ParentID]; exists {
parent.Children = append(parent.Children, *category)
}
}
}
// Copy the collected Children onto the root copies and set HasChildren.
for i := range rootCategories {
if parent, exists := categoryMap[rootCategories[i].ID]; exists {
rootCategories[i].Children = parent.Children
rootCategories[i].HasChildren = len(parent.Children) > 0
}
// Mark every child as a leaf.
for j := range rootCategories[i].Children {
rootCategories[i].Children[j].HasChildren = false // second-level categories have no children
}
}
return rootCategories, nil
}
// GetCategoryByID loads an active (status = 1) category by id. The returned
// pointer is non-nil even when err is set; First reports gorm.ErrRecordNotFound
// when no active category has that id.
func (r *ProductRepository) GetCategoryByID(id uint) (*model.Category, error) {
	category := new(model.Category)
	err := r.db.Where("id = ? AND status = ?", id, 1).First(category).Error
	return category, err
}
// CreateCategory inserts a new category row.
func (r *ProductRepository) CreateCategory(category *model.Category) error {
	if err := r.db.Create(category).Error; err != nil {
		return err
	}
	return nil
}
// UpdateCategory writes the given column/value pairs to the category with the given id.
func (r *ProductRepository) UpdateCategory(id uint, updates map[string]interface{}) error {
	scoped := r.db.Model(&model.Category{}).Where("id = ?", id)
	return scoped.Updates(updates).Error
}
// DeleteCategory deletes the category with the given primary key through GORM's
// Delete. Child categories and products are not touched here.
// NOTE(review): soft vs. hard delete depends on whether model.Category has a DeletedAt field.
func (r *ProductRepository) DeleteCategory(id uint) error {
	res := r.db.Delete(&model.Category{}, id)
	return res.Error
}
// GetReviews returns one page of a product's reviews, newest first, together
// with the product's total review count.
func (r *ProductRepository) GetReviews(productID uint, offset, limit int) ([]model.ProductReview, int64, error) {
	base := r.db.Model(&model.ProductReview{}).Where("product_id = ?", productID)

	var total int64
	if err := base.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var reviews []model.ProductReview
	err := base.Order("created_at DESC").Offset(offset).Limit(limit).Find(&reviews).Error
	return reviews, total, err
}
// CreateReview inserts a new product review row.
func (r *ProductRepository) CreateReview(review *model.ProductReview) error {
	res := r.db.Create(review)
	return res.Error
}
// GetReviewByOrderID returns the first review the given user wrote for the given
// order. The returned pointer is non-nil even when err is set
// (gorm.ErrRecordNotFound when there is no such review).
func (r *ProductRepository) GetReviewByOrderID(userID uint, orderID uint) (*model.ProductReview, error) {
	review := new(model.ProductReview)
	err := r.db.
		Where("user_id = ? AND order_id = ?", userID, orderID).
		First(review).
		Error
	return review, err
}
// GetReviewCount returns review statistics for one product as a map with keys:
//   - "total":           number of reviews
//   - "good_count":      reviews with rating >= 4
//   - "middle_count":    reviews with rating == 3
//   - "bad_count":       reviews with rating <= 2
//   - "has_image_count": reviews whose images column is neither NULL nor ''
//   - "good_rate":       good_count / total * 100 (float64); 100 when there are no reviews
//
// All counters come from one aggregate query instead of five separate COUNTs.
// This saves round trips, and all numbers describe the same snapshot: a review
// inserted between two queries can no longer make them disagree.
func (r *ProductRepository) GetReviewCount(productID uint) (map[string]interface{}, error) {
	// Field names map to the snake_case column aliases below (GORM default naming).
	type reviewStats struct {
		Total         int64
		GoodCount     int64
		MiddleCount   int64
		BadCount      int64
		HasImageCount int64
	}

	var stats reviewStats
	err := r.db.Model(&model.ProductReview{}).
		Select("COUNT(*) AS total, " +
			"COALESCE(SUM(CASE WHEN rating >= 4 THEN 1 ELSE 0 END), 0) AS good_count, " +
			"COALESCE(SUM(CASE WHEN rating = 3 THEN 1 ELSE 0 END), 0) AS middle_count, " +
			"COALESCE(SUM(CASE WHEN rating <= 2 THEN 1 ELSE 0 END), 0) AS bad_count, " +
			"COALESCE(SUM(CASE WHEN images IS NOT NULL AND images != '' THEN 1 ELSE 0 END), 0) AS has_image_count").
		Where("product_id = ?", productID).
		Scan(&stats).Error
	if err != nil {
		return nil, err
	}

	// With no reviews the rate is reported as 100% rather than dividing by zero.
	goodRate := 100.0
	if stats.Total > 0 {
		goodRate = float64(stats.GoodCount) / float64(stats.Total) * 100
	}

	return map[string]interface{}{
		"total":           stats.Total,
		"good_rate":       goodRate,
		"good_count":      stats.GoodCount,
		"middle_count":    stats.MiddleCount,
		"bad_count":       stats.BadCount,
		"has_image_count": stats.HasImageCount,
	}, nil
}
// GetProductImages returns every image of a product ordered by sort ascending.
func (r *ProductRepository) GetProductImages(productID uint) ([]model.ProductImage, error) {
	var images []model.ProductImage
	q := r.db.Where("product_id = ?", productID).Order("sort ASC")
	err := q.Find(&images).Error
	return images, err
}
// CreateProductImage inserts a new product image row.
func (r *ProductRepository) CreateProductImage(image *model.ProductImage) error {
	res := r.db.Create(image)
	return res.Error
}
// DeleteProductImage deletes the product image with the given primary key.
func (r *ProductRepository) DeleteProductImage(id uint) error {
	res := r.db.Delete(&model.ProductImage{}, id)
	return res.Error
}
// GetProductSpecs returns every spec of a product ordered by sort ascending.
func (r *ProductRepository) GetProductSpecs(productID uint) ([]model.ProductSpec, error) {
	var specs []model.ProductSpec
	q := r.db.Where("product_id = ?", productID).Order("sort ASC")
	err := q.Find(&specs).Error
	return specs, err
}
// CreateProductSpec inserts a new product spec row.
func (r *ProductRepository) CreateProductSpec(spec *model.ProductSpec) error {
	res := r.db.Create(spec)
	return res.Error
}
// UpdateProductSpec writes the given column/value pairs to the spec with the given id.
func (r *ProductRepository) UpdateProductSpec(id uint, updates map[string]interface{}) error {
	scoped := r.db.Model(&model.ProductSpec{}).Where("id = ?", id)
	return scoped.Updates(updates).Error
}
// DeleteProductSpec deletes the product spec with the given primary key.
func (r *ProductRepository) DeleteProductSpec(id uint) error {
	res := r.db.Delete(&model.ProductSpec{}, id)
	return res.Error
}
// GetHotProducts returns up to limit on-shelf products flagged as hot
// (status = 1, is_hot = 1), best sellers first and newest first on equal sales.
// Category is preloaded.
func (r *ProductRepository) GetHotProducts(limit int) ([]model.Product, error) {
	var products []model.Product
	q := r.db.Preload("Category").Where("status = ? AND is_hot = ?", 1, 1)
	q = q.Order("sales DESC, created_at DESC").Limit(limit)
	err := q.Find(&products).Error
	return products, err
}
// GetRecommendProducts returns up to limit on-shelf recommended products
// (status = 1, is_recommend = 1), ordered by sort then newest first.
// Category is preloaded.
func (r *ProductRepository) GetRecommendProducts(limit int) ([]model.Product, error) {
	var products []model.Product
	q := r.db.Preload("Category").Where("status = ? AND is_recommend = ?", 1, 1)
	q = q.Order("sort DESC, created_at DESC").Limit(limit)
	err := q.Find(&products).Error
	return products, err
}
// GetProductSKUs returns a product's active (status = 1) SKUs ordered by weight ascending.
func (r *ProductRepository) GetProductSKUs(productID uint) ([]model.ProductSKU, error) {
	var skus []model.ProductSKU
	q := r.db.Where("product_id = ? AND status = ?", productID, 1).Order("weight ASC")
	err := q.Find(&skus).Error
	return skus, err
}
// GetSKUByID loads an active (status = 1) SKU by id. The returned pointer is
// non-nil even when err is set (gorm.ErrRecordNotFound when no active SKU matches).
func (r *ProductRepository) GetSKUByID(skuID uint) (*model.ProductSKU, error) {
	sku := new(model.ProductSKU)
	err := r.db.Where("id = ? AND status = ?", skuID, 1).First(sku).Error
	return sku, err
}
// GetProductTags returns all active (status = 1) product tags ordered by sort ascending.
func (r *ProductRepository) GetProductTags() ([]model.ProductTag, error) {
	var tags []model.ProductTag
	q := r.db.Where("status = ?", 1).Order("sort ASC")
	err := q.Find(&tags).Error
	return tags, err
}
// GetStores returns all active (status = 1) stores ordered by sort ascending.
func (r *ProductRepository) GetStores() ([]model.Store, error) {
	var stores []model.Store
	q := r.db.Where("status = ?", 1).Order("sort ASC")
	err := q.Find(&stores).Error
	return stores, err
}
// GetStoreByID loads an active (status = 1) store by id. The returned pointer is
// non-nil even when err is set (gorm.ErrRecordNotFound when no active store matches).
func (r *ProductRepository) GetStoreByID(id uint) (*model.Store, error) {
	store := new(model.Store)
	err := r.db.Where("id = ? AND status = ?", id, 1).First(store).Error
	return store, err
}
// GetProductStatistics returns product counters as int64 values keyed by:
// "total_products" (all), "online_products" (status 1), "offline_products"
// (status 0) and "low_stock_products" (stock < 10). The metrics are queried in
// that order, and the first failing query aborts with its error.
func (r *ProductRepository) GetProductStatistics() (map[string]interface{}, error) {
	metrics := []struct {
		key   string
		where string // empty means no filter
		arg   interface{}
	}{
		{key: "total_products"},
		{key: "online_products", where: "status = ?", arg: 1},
		{key: "offline_products", where: "status = ?", arg: 0},
		{key: "low_stock_products", where: "stock < ?", arg: 10},
	}

	stats := make(map[string]interface{}, len(metrics))
	for _, m := range metrics {
		q := r.db.Model(&model.Product{})
		if m.where != "" {
			q = q.Where(m.where, m.arg)
		}
		var n int64
		if err := q.Count(&n).Error; err != nil {
			return nil, err
		}
		stats[m.key] = n
	}
	return stats, nil
}
// GetCategorySalesStatistics ranks active categories (status = 1) by sales
// amount, returning at most limit rows ordered by sales_amount DESC.
//
// startDate/endDate are expanded to "<startDate> 00:00:00" and
// "<endDate> 23:59:59", presumably "YYYY-MM-DD" strings (confirm with callers).
// Each row contains category_id, category_name, sales_count (sum of order item
// quantity) and sales_amount (sum of order item total_price).
//
// Only order items whose order was created in the range and has status 2-6 are
// counted. Every active category appears, with 0 totals when nothing qualifies.
// The order filter now lives in a derived table. Before, it sat in the outer
// WHERE of a LEFT JOIN, which had two effects:
//   - categories whose sales were all outside the window (or in other statuses)
//     vanished from the result instead of showing 0;
//   - order items with no matching order (o.* NULL) were counted as sales.
func (r *ProductRepository) GetCategorySalesStatistics(startDate, endDate string, limit int) ([]map[string]interface{}, error) {
	var results []map[string]interface{}
	// NOTE(review): order_items is the only table here without the "ai_" prefix;
	// it is kept as-is (DeleteSKU uses the same name) but worth verifying.
	query := `
		SELECT
			c.id AS category_id,
			c.name AS category_name,
			COALESCE(SUM(s.quantity), 0) AS sales_count,
			COALESCE(SUM(s.total_price), 0) AS sales_amount
		FROM ai_categories c
		LEFT JOIN ai_products p ON c.id = p.category_id
		LEFT JOIN (
			SELECT oi.product_id, oi.quantity, oi.total_price
			FROM order_items oi
			INNER JOIN ai_orders o ON oi.order_id = o.id
			WHERE o.created_at >= ? AND o.created_at <= ?
				AND o.status IN (2, 3, 4, 5, 6)
		) s ON p.id = s.product_id
		WHERE c.status = 1
		GROUP BY c.id, c.name
		ORDER BY sales_amount DESC
		LIMIT ?
	`
	err := r.db.Raw(query, startDate+" 00:00:00", endDate+" 23:59:59", limit).Scan(&results).Error
	return results, err
}
// BatchUpdateStatus sets status on every product whose id is in ids
// (GORM expands the slice for the "IN ?" placeholder).
func (r *ProductRepository) BatchUpdateStatus(ids []uint, status int) error {
	res := r.db.Model(&model.Product{}).Where("id IN ?", ids).Update("status", status)
	return res.Error
}
// BatchUpdatePrice applies several per-product updates in one transaction.
//
// Each element of updates must carry an "id" key that identifies the product.
// All other keys are column/value pairs passed to Updates, and elements without
// "id" are skipped. Either every update is committed, or the transaction is
// rolled back and the first error is returned.
//
// Compared with the previous hand-rolled Begin/Rollback:
//   - a failed Begin is now reported;
//   - a panic is re-raised after rollback instead of being swallowed and
//     reported as success;
//   - the caller's maps are no longer mutated (they used to lose their "id" key).
func (r *ProductRepository) BatchUpdatePrice(updates []map[string]interface{}) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		for _, update := range updates {
			id, ok := update["id"]
			if !ok {
				continue
			}
			// Copy every column except "id" so the caller's map stays intact.
			fields := make(map[string]interface{}, len(update))
			for column, value := range update {
				if column != "id" {
					fields[column] = value
				}
			}
			if err := tx.Model(&model.Product{}).Where("id = ?", id).Updates(fields).Error; err != nil {
				return err
			}
		}
		return nil
	})
}
// CountProductsByCategory counts products in a category plus its direct children.
// Children are matched by parent_id alone (no status filter), so products in
// disabled child categories are counted too. Deleted products are excluded only
// if model.Product uses GORM soft delete (presumably; not visible here).
func (r *ProductRepository) CountProductsByCategory(categoryID uint) (int64, error) {
	var childIDs []uint
	if err := r.db.Model(&model.Category{}).Where("parent_id = ?", categoryID).Pluck("id", &childIDs).Error; err != nil {
		return 0, err
	}
	categoryIDs := append([]uint{categoryID}, childIDs...)

	var count int64
	err := r.db.Model(&model.Product{}).Where("category_id IN (?)", categoryIDs).Count(&count).Error
	return count, err
}
// BatchDelete deletes the products with the given ids through GORM's Delete.
// Only product rows are touched; related records are kept, and deletion is
// signalled by the product's deleted_at column. An empty ids is a no-op.
func (r *ProductRepository) BatchDelete(ids []uint) error {
	if len(ids) > 0 {
		return r.db.Delete(&model.Product{}, ids).Error
	}
	// Nothing to delete: avoid issuing a DELETE with an empty id list.
	return nil
}
// CreateSKU inserts a new product SKU row.
func (r *ProductRepository) CreateSKU(sku *model.ProductSKU) error {
	res := r.db.Create(sku)
	return res.Error
}
// UpdateSKU writes the given column/value pairs to the SKU with the given id.
func (r *ProductRepository) UpdateSKU(id uint, updates map[string]interface{}) error {
	scoped := r.db.Model(&model.ProductSKU{}).Where("id = ?", id)
	return scoped.Updates(updates).Error
}
// DeleteSKU removes a SKU, or only disables it when order items reference it.
// - Referenced by any order_items row (column sk_uid): status is set to 0 so
//   order history keeps a valid SKU.
// - Otherwise: the row is deleted through GORM's Delete. This is a physical
//   delete unless model.ProductSKU has a DeletedAt field (not visible here).
func (r *ProductRepository) DeleteSKU(id uint) error {
	var refs int64
	if err := r.db.Table("order_items").Where("sk_uid = ?", id).Count(&refs).Error; err != nil {
		return err
	}
	if refs == 0 {
		return r.db.Delete(&model.ProductSKU{}, id).Error
	}
	return r.db.Model(&model.ProductSKU{}).Where("id = ?", id).Update("status", 0).Error
}
// UpdateProductImageSort sets the sort value of one product image.
func (r *ProductRepository) UpdateProductImageSort(id uint, sort int) error {
	res := r.db.Model(&model.ProductImage{}).Where("id = ?", id).Update("sort", sort)
	return res.Error
}
// CreateProductTag inserts a new product tag row.
func (r *ProductRepository) CreateProductTag(tag *model.ProductTag) error {
	res := r.db.Create(tag)
	return res.Error
}
// UpdateProductTag writes the given column/value pairs to the tag with the given id.
func (r *ProductRepository) UpdateProductTag(id uint, updates map[string]interface{}) error {
	scoped := r.db.Model(&model.ProductTag{}).Where("id = ?", id)
	return scoped.Updates(updates).Error
}
// DeleteProductTag deletes the product tag with the given primary key. Rows in
// ai_product_tag_relations that point at it are not touched here.
func (r *ProductRepository) DeleteProductTag(id uint) error {
	res := r.db.Delete(&model.ProductTag{}, id)
	return res.Error
}
// AssignTagsToProduct replaces the product's tag set with tagIDs.
//
// All rows in ai_product_tag_relations for the product are deleted, then one row
// per distinct tag ID is inserted. An empty tagIDs just clears the product's tags.
//
// Both steps now run in a single transaction. Before, a failed insert left the
// product with its old tags already deleted. Duplicate IDs are also collapsed,
// so they can no longer produce duplicate rows or trip a unique index.
func (r *ProductRepository) AssignTagsToProduct(productID uint, tagIDs []uint) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Exec("DELETE FROM ai_product_tag_relations WHERE product_id = ?", productID).Error; err != nil {
			return err
		}
		if len(tagIDs) == 0 {
			return nil
		}

		seen := make(map[uint]struct{}, len(tagIDs))
		relations := make([]map[string]interface{}, 0, len(tagIDs))
		for _, tagID := range tagIDs {
			if _, dup := seen[tagID]; dup {
				continue
			}
			seen[tagID] = struct{}{}
			relations = append(relations, map[string]interface{}{
				"product_id": productID,
				"tag_id":     tagID,
			})
		}
		return tx.Table("ai_product_tag_relations").Create(relations).Error
	})
}
// GetLowStockProducts returns on-shelf (status = 1) products whose stock is at
// or below threshold. Category is preloaded.
func (r *ProductRepository) GetLowStockProducts(threshold int) ([]model.Product, error) {
	var products []model.Product
	q := r.db.Preload("Category").Where("stock <= ? AND status = ?", threshold, 1)
	err := q.Find(&products).Error
	return products, err
}
// GetInventoryStatistics returns stock metrics over on-shelf (status = 1) products:
//   - "total_products" (int64): number of on-shelf products
//   - "low_stock_products" (int64): stock <= 10
//   - "out_of_stock_products" (int64): stock = 0
//   - "total_stock_value" (float64): SUM(stock * price)
//
// Every query error is now returned. The previous version ignored them all and
// silently reported zeros with a nil error.
func (r *ProductRepository) GetInventoryStatistics() (map[string]interface{}, error) {
	// active starts a fresh query scoped to on-shelf products.
	active := func() *gorm.DB {
		return r.db.Model(&model.Product{}).Where("status = ?", 1)
	}

	var totalProducts, lowStockProducts, outOfStockProducts int64
	if err := active().Count(&totalProducts).Error; err != nil {
		return nil, err
	}
	if err := active().Where("stock <= ?", 10).Count(&lowStockProducts).Error; err != nil {
		return nil, err
	}
	if err := active().Where("stock = ?", 0).Count(&outOfStockProducts).Error; err != nil {
		return nil, err
	}

	// COALESCE keeps the scan valid when there are no on-shelf products: SUM over
	// zero rows is NULL, which cannot be scanned into a float64.
	var totalStockValue float64
	if err := active().Select("COALESCE(SUM(stock * price), 0)").Scan(&totalStockValue).Error; err != nil {
		return nil, err
	}

	return map[string]interface{}{
		"total_products":        totalProducts,
		"low_stock_products":    lowStockProducts,
		"out_of_stock_products": outOfStockProducts,
		"total_stock_value":     totalStockValue,
	}, nil
}
// GetProductsForExport returns all products matching conditions, with Category
// preloaded, for export. No status filter is applied by default.
//
// Recognised keys:
//   - "category_id" (uint, uint32, uint64, int, int32, int64 or float64):
//     that category plus its direct active children.
//   - "status": exact status match; "all" means no filter, as in GetList.
//     Previously the literal "all" was compared against the status column.
//   - "keyword" (string): substring match on name OR description. A non-string
//     value is ignored; it used to panic.
func (r *ProductRepository) GetProductsForExport(conditions map[string]interface{}) ([]model.Product, error) {
	var products []model.Product
	query := r.db.Model(&model.Product{}).Preload("Category")

	for key, value := range conditions {
		switch key {
		case "category_id":
			var catID uint
			switch v := value.(type) {
			case uint:
				catID = v
			case uint32:
				catID = uint(v)
			case uint64:
				catID = uint(v)
			case int:
				catID = uint(v)
			case int32:
				catID = uint(v)
			case int64:
				catID = uint(v)
			case float64:
				catID = uint(v)
			}
			if catID > 0 {
				categoryIDs, err := r.getCategoryIDsIncludingChildren(catID)
				if err == nil && len(categoryIDs) > 0 {
					query = query.Where("category_id IN (?)", categoryIDs)
				} else {
					// Fallback: if the child lookup fails, filter by this category only.
					query = query.Where("category_id = ?", catID)
				}
			}
		case "status":
			if value != "all" {
				query = query.Where("status = ?", value)
			}
		case "keyword":
			if keyword, ok := value.(string); ok {
				pattern := "%" + keyword + "%"
				query = query.Where("(name LIKE ? OR description LIKE ?)", pattern, pattern)
			}
		}
	}

	err := query.Find(&products).Error
	return products, err
}
// getCategoryIDsIncludingChildren returns categoryID followed by the IDs of its
// direct active (status = 1) children; only two category levels are supported.
// On a query error it still returns []uint{categoryID} together with the error.
func (r *ProductRepository) getCategoryIDsIncludingChildren(categoryID uint) ([]uint, error) {
	var children []model.Category
	err := r.db.Where("parent_id = ? AND status = ?", categoryID, 1).Find(&children).Error

	ids := make([]uint, 1, len(children)+1)
	ids[0] = categoryID
	if err != nil {
		return ids, err
	}
	for i := range children {
		ids = append(ids, children[i].ID)
	}
	return ids, nil
}
// DeductSKUStock subtracts quantity from one SKU's stock with a single
// conditional UPDATE (WHERE stock >= quantity). It returns
// gorm.ErrRecordNotFound when no row was updated: the SKU does not exist or its
// stock is insufficient. The parent product's stock is not touched.
func (r *ProductRepository) DeductSKUStock(skuID uint, quantity int) error {
	result := r.db.Model(&model.ProductSKU{}).
		Where("id = ? AND stock >= ?", skuID, quantity).
		Update("stock", gorm.Expr("stock - ?", quantity))
	if err := result.Error; err != nil {
		return err
	}
	if result.RowsAffected > 0 {
		return nil
	}
	return gorm.ErrRecordNotFound
}
// RestoreSKUStock adds quantity back to one SKU's stock. A missing SKU is not
// reported as an error, and the parent product's stock is not touched.
func (r *ProductRepository) RestoreSKUStock(skuID uint, quantity int) error {
	delta := gorm.Expr("stock + ?", quantity)
	res := r.db.Model(&model.ProductSKU{}).Where("id = ?", skuID).Update("stock", delta)
	return res.Error
}
// SyncProductStockFromSKUs sets the product's stock to the sum of its active
// (status = 1) SKUs' stock, or 0 when it has none. The read and the write are
// two separate statements, not wrapped in a transaction.
func (r *ProductRepository) SyncProductStockFromSKUs(productID uint) error {
	var totalStock int
	sum := r.db.Model(&model.ProductSKU{}).
		Select("COALESCE(SUM(stock), 0)").
		Where("product_id = ? AND status = ?", productID, 1)
	if err := sum.Scan(&totalStock).Error; err != nil {
		return err
	}
	return r.db.Model(&model.Product{}).Where("id = ?", productID).Update("stock", totalStock).Error
}
// DeductStockWithSKU deducts stock for an order line in one transaction.
//
// - skuID non-nil and non-zero: the SKU's stock is decremented with a
//   conditional UPDATE (WHERE stock >= quantity). The product's stock is then
//   recomputed as the sum of its active (status = 1) SKUs' stock.
// - Otherwise: the product's own stock is decremented conditionally.
//
// Returns gorm.ErrRecordNotFound when the SKU/product is missing or its stock
// is insufficient; in that case nothing is committed.
//
// The function now uses db.Transaction. With the old Begin/recover pattern, a
// panic was swallowed and the function returned nil (reported success) after
// rolling back, and a failed Begin went unnoticed.
// NOTE(review): the SKU is not checked to belong to productID.
func (r *ProductRepository) DeductStockWithSKU(productID uint, skuID *uint, quantity int) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		if skuID == nil || *skuID == 0 {
			// Product without SKUs: conditional decrement on the product row.
			result := tx.Model(&model.Product{}).
				Where("id = ? AND stock >= ?", productID, quantity).
				Update("stock", gorm.Expr("stock - ?", quantity))
			if result.Error != nil {
				return result.Error
			}
			if result.RowsAffected == 0 {
				return gorm.ErrRecordNotFound // insufficient product stock
			}
			return nil
		}

		result := tx.Model(&model.ProductSKU{}).
			Where("id = ? AND stock >= ?", *skuID, quantity).
			Update("stock", gorm.Expr("stock - ?", quantity))
		if result.Error != nil {
			return result.Error
		}
		if result.RowsAffected == 0 {
			return gorm.ErrRecordNotFound // insufficient SKU stock
		}

		// Keep the product's total in sync with its active SKUs.
		var totalStock int
		if err := tx.Model(&model.ProductSKU{}).
			Where("product_id = ? AND status = ?", productID, 1).
			Select("COALESCE(SUM(stock), 0)").
			Scan(&totalStock).Error; err != nil {
			return err
		}
		return tx.Model(&model.Product{}).Where("id = ?", productID).
			Update("stock", totalStock).Error
	})
}
// RestoreStockWithSKU returns stock for an order line in one transaction.
//
// - skuID non-nil and non-zero: quantity is added back to the SKU. The
//   product's stock is then recomputed as the sum of its active (status = 1)
//   SKUs' stock.
// - Otherwise: quantity is added straight back to the product's stock.
//
// Missing rows are not reported as errors.
//
// The function now uses db.Transaction. With the old Begin/recover pattern, a
// panic was swallowed and the function returned nil (reported success) after
// rolling back, and a failed Begin went unnoticed.
func (r *ProductRepository) RestoreStockWithSKU(productID uint, skuID *uint, quantity int) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		if skuID == nil || *skuID == 0 {
			// Product without SKUs: restore the product's own stock.
			return tx.Model(&model.Product{}).Where("id = ?", productID).
				Update("stock", gorm.Expr("stock + ?", quantity)).Error
		}

		if err := tx.Model(&model.ProductSKU{}).Where("id = ?", *skuID).
			Update("stock", gorm.Expr("stock + ?", quantity)).Error; err != nil {
			return err
		}

		// Keep the product's total in sync with its active SKUs.
		var totalStock int
		if err := tx.Model(&model.ProductSKU{}).
			Where("product_id = ? AND status = ?", productID, 1).
			Select("COALESCE(SUM(stock), 0)").
			Scan(&totalStock).Error; err != nil {
			return err
		}
		return tx.Model(&model.Product{}).Where("id = ?", productID).
			Update("stock", totalStock).Error
	})
}