Files
2025-11-17 13:32:54 +08:00

905 lines
27 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"dianshang/internal/model"
"gorm.io/gorm"
)
// ProductRepository is the data-access layer for products and the related
// tables queried in this file (categories, reviews, images, specs, SKUs, tags,
// stores, tag relations). It wraps a single *gorm.DB handle.
type ProductRepository struct {
db *gorm.DB
}
// NewProductRepository builds a ProductRepository that issues all of its
// queries through db.
func NewProductRepository(db *gorm.DB) *ProductRepository {
	repo := &ProductRepository{}
	repo.db = db
	return repo
}
// GetDB returns the *gorm.DB handle the repository was constructed with.
func (r *ProductRepository) GetDB() *gorm.DB {
return r.db
}
// GetList returns one page of products that match conditions. It also returns
// the total number of matching rows before pagination. The Category
// association is preloaded.
//
// Recognised condition keys:
//   - "status": exact status filter. The string "all" disables it. When the key
//     is absent, only status = 1 (on-shelf) products are returned.
//   - "category_id": a positive integer (uint, uint64, uint32, int, int64, int32
//     or float64). The category's direct active children are included.
//   - "keyword": non-empty string matched as a substring of name or description.
//   - "min_price" / "max_price": inclusive price bounds, passed to the driver as-is.
//   - "is_hot" / "is_new" / "is_recommend": "true"/"false" strings or bools.
//   - "sort": "price", "sales" or "created_at". "sort_type" of "asc"/"ASC" means
//     ascending; anything else means descending. Any other sort value keeps the
//     default "sort DESC, created_at DESC" ordering.
//
// Values of an unexpected type are ignored instead of causing a panic. Every
// ordering ends with "id DESC", so offset pagination stays stable when the
// primary sort key ties.
func (r *ProductRepository) GetList(offset, limit int, conditions map[string]interface{}) ([]model.Product, int64, error) {
	// categoryID converts a numeric condition value into a positive category
	// id. It returns 0 for an unusable value: wrong type, zero, or negative
	// (which would otherwise wrap to a huge uint).
	categoryID := func(v interface{}) uint {
		switch n := v.(type) {
		case uint:
			return n
		case uint64:
			return uint(n)
		case uint32:
			return uint(n)
		case int:
			if n > 0 {
				return uint(n)
			}
		case int64:
			if n > 0 {
				return uint(n)
			}
		case int32:
			if n > 0 {
				return uint(n)
			}
		case float64:
			if n > 0 {
				return uint(n)
			}
		}
		return 0
	}
	// boolFlag interprets "true"/"false" or a bool. ok is false for anything else.
	boolFlag := func(v interface{}) (flag, ok bool) {
		switch b := v.(type) {
		case bool:
			return b, true
		case string:
			switch b {
			case "true":
				return true, true
			case "false":
				return false, true
			}
		}
		return false, false
	}

	query := r.db.Model(&model.Product{})

	// Only on-shelf products by default. Management callers pass "all" to see
	// every status.
	if statusValue, exists := conditions["status"]; !exists {
		query = query.Where("status = ?", 1)
	} else if statusValue != "all" {
		query = query.Where("status = ?", statusValue)
	}

	var sortField, sortType string
	for key, value := range conditions {
		switch key {
		case "category_id":
			catID := categoryID(value)
			if catID == 0 {
				continue
			}
			if ids, err := r.getCategoryIDsIncludingChildren(catID); err == nil && len(ids) > 0 {
				query = query.Where("category_id IN (?)", ids)
			} else {
				// Fallback: if the child lookup fails, filter on this category only.
				query = query.Where("category_id = ?", catID)
			}
		case "keyword":
			if kw, ok := value.(string); ok && kw != "" {
				pattern := "%" + kw + "%"
				// Explicit parentheses scope the OR to this filter regardless of
				// how GORM combines it with the other conditions.
				query = query.Where("(name LIKE ? OR description LIKE ?)", pattern, pattern)
			}
		case "min_price":
			query = query.Where("price >= ?", value)
		case "max_price":
			query = query.Where("price <= ?", value)
		case "is_hot", "is_new", "is_recommend":
			// key comes from this fixed case list, so splicing it into SQL is safe.
			if flag, ok := boolFlag(value); ok {
				query = query.Where(key+" = ?", flag)
			}
		case "sort":
			sortField, _ = value.(string)
		case "sort_type":
			sortType, _ = value.(string)
		}
	}

	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	orderBy := "sort DESC, created_at DESC, id DESC"
	switch sortField {
	case "price", "sales", "created_at":
		direction := "DESC"
		if sortType == "asc" || sortType == "ASC" {
			direction = "ASC"
		}
		// sortField is whitelisted by this case, so it is safe to splice in.
		orderBy = sortField + " " + direction + ", id DESC"
	}

	var products []model.Product
	err := query.Preload("Category").
		Offset(offset).Limit(limit).Order(orderBy).Find(&products).Error
	return products, total, err
}
// GetByID loads one product by primary key. It preloads the Category, the
// Specs, and only the SKUs whose status is 1.
//
// On failure the error is returned alongside a non-nil pointer to a zero-value
// product. A missing row surfaces as GORM's record-not-found error from First.
func (r *ProductRepository) GetByID(id uint) (*model.Product, error) {
	product := new(model.Product)
	err := r.db.
		Preload("Category").
		Preload("Specs").
		Preload("SKUs", "status = ?", 1).
		Where("id = ?", id).
		First(product).Error
	return product, err
}
// Create inserts product through GORM and returns any database error.
func (r *ProductRepository) Create(product *model.Product) error {
	result := r.db.Create(product)
	return result.Error
}
// Update writes the column/value pairs in updates to the product with the given
// id. RowsAffected is not inspected, so an unknown id is not reported as an
// error.
func (r *ProductRepository) Update(id uint, updates map[string]interface{}) error {
	target := r.db.Model(&model.Product{}).Where("id = ?", id)
	return target.Updates(updates).Error
}
// Delete removes the product with the given id through GORM's Delete.
//
// Only the product row is touched. Related rows (images, specs, SKUs, ...) are
// left in place. The intent is a soft delete, with the product's deleted_at
// column marking it as removed. That relies on model.Product declaring a GORM
// DeletedAt field (not visible here).
func (r *ProductRepository) Delete(id uint) error {
	result := r.db.Delete(&model.Product{}, id)
	return result.Error
}
// UpdateStock adds quantity to the product's stock in one
// "SET stock = stock + ?" statement. quantity may be negative, and no lower
// bound is enforced, so stock can drop below zero.
func (r *ProductRepository) UpdateStock(id uint, quantity int) error {
	delta := gorm.Expr("stock + ?", quantity)
	return r.db.Model(&model.Product{}).
		Where("id = ?", id).
		Update("stock", delta).Error
}
// UpdateSales adds quantity to the product's sales counter in one
// "SET sales = sales + ?" statement.
func (r *ProductRepository) UpdateSales(id uint, quantity int) error {
	increment := gorm.Expr("sales + ?", quantity)
	return r.db.Model(&model.Product{}).
		Where("id = ?", id).
		Update("sales", increment).Error
}
// DeductStock subtracts quantity from a product's stock. The original comment
// says it is used when payment succeeds. The availability check and the
// decrement are the same conditional UPDATE statement.
//
// It returns gorm.ErrInvalidData for a negative quantity. Such a value would
// otherwise satisfy "stock >= ?" trivially and increase stock. It returns
// gorm.ErrRecordNotFound when the product does not exist or its stock is below
// quantity; the two cases are not distinguished.
func (r *ProductRepository) DeductStock(id uint, quantity int) error {
	if quantity < 0 {
		return gorm.ErrInvalidData
	}
	result := r.db.Model(&model.Product{}).
		Where("id = ? AND stock >= ?", id, quantity).
		Update("stock", gorm.Expr("stock - ?", quantity))
	if result.Error != nil {
		return result.Error
	}
	if result.RowsAffected == 0 {
		// Insufficient stock or unknown product.
		return gorm.ErrRecordNotFound
	}
	return nil
}
// ReserveStock pre-deducts quantity from a product's stock. The original
// comment says it is used when an order is created. The check and the
// decrement are one conditional UPDATE statement.
//
// It returns gorm.ErrInvalidData for a negative quantity, which would
// otherwise bypass the stock check and add stock. It returns
// gorm.ErrRecordNotFound when the product is missing or has too little stock.
func (r *ProductRepository) ReserveStock(id uint, quantity int) error {
	if quantity < 0 {
		return gorm.ErrInvalidData
	}
	result := r.db.Model(&model.Product{}).
		Where("id = ? AND stock >= ?", id, quantity).
		Update("stock", gorm.Expr("stock - ?", quantity))
	if result.Error != nil {
		return result.Error
	}
	if result.RowsAffected == 0 {
		// Insufficient stock or unknown product.
		return gorm.ErrRecordNotFound
	}
	return nil
}
// RestoreStock adds quantity back to a product's stock. The original comment
// says it is used on order cancellation or payment failure. An unknown id is
// not reported because RowsAffected is not checked.
func (r *ProductRepository) RestoreStock(id uint, quantity int) error {
	restored := gorm.Expr("stock + ?", quantity)
	result := r.db.Model(&model.Product{}).Where("id = ?", id).Update("stock", restored)
	return result.Error
}
// GetCategories returns the active (status = 1) categories as a two-level tree.
//
// Rows are loaded ordered by "level ASC, sort DESC, created_at ASC". Each
// category with a nil ParentID becomes a root, in load order. Every other
// category is appended, in load order, to the Children of its parent. A
// category whose parent is not among the loaded rows is dropped. A root's
// HasChildren reports whether it received any children. HasChildren is forced
// to false on every child.
func (r *ProductRepository) GetCategories() ([]model.Category, error) {
	var all []model.Category
	if err := r.db.Where("status = ?", 1).
		Order("level ASC, sort DESC, created_at ASC").
		Find(&all).Error; err != nil {
		return nil, err
	}

	// Index every loaded category so children can find their parent.
	byID := make(map[uint]*model.Category, len(all))
	for i := range all {
		byID[all[i].ID] = &all[i]
	}

	var roots []model.Category
	for i := range all {
		cat := &all[i]
		if cat.ParentID == nil {
			roots = append(roots, *cat)
			continue
		}
		if parent, ok := byID[*cat.ParentID]; ok {
			parent.Children = append(parent.Children, *cat)
		}
	}

	// The root copies were taken before their children were attached, so pull
	// the accumulated Children from the indexed originals.
	for i := range roots {
		root := &roots[i]
		if src, ok := byID[root.ID]; ok {
			root.Children = src.Children
			root.HasChildren = len(src.Children) > 0
		}
		for j := range root.Children {
			root.Children[j].HasChildren = false
		}
	}
	return roots, nil
}
// GetCategoryByID fetches an active (status = 1) category by id. An inactive or
// missing category yields First's error together with a pointer to a
// zero-value Category.
func (r *ProductRepository) GetCategoryByID(id uint) (*model.Category, error) {
	category := new(model.Category)
	err := r.db.Where("id = ? AND status = ?", id, 1).First(category).Error
	return category, err
}
// CreateCategory inserts category and returns any database error.
func (r *ProductRepository) CreateCategory(category *model.Category) error {
	result := r.db.Create(category)
	return result.Error
}
// UpdateCategory writes the column/value pairs in updates to the category with
// the given id.
func (r *ProductRepository) UpdateCategory(id uint, updates map[string]interface{}) error {
	target := r.db.Model(&model.Category{}).Where("id = ?", id)
	return target.Updates(updates).Error
}
// DeleteCategory deletes the category with the given id through GORM's Delete.
// This is a soft delete only if model.Category declares a DeletedAt field
// (not visible here).
func (r *ProductRepository) DeleteCategory(id uint) error {
	result := r.db.Delete(&model.Category{}, id)
	return result.Error
}
// GetReviews returns one page of reviews for productID, newest first. It also
// returns the total number of reviews for that product.
func (r *ProductRepository) GetReviews(productID uint, offset, limit int) ([]model.ProductReview, int64, error) {
	base := r.db.Model(&model.ProductReview{}).Where("product_id = ?", productID)

	var total int64
	if err := base.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var reviews []model.ProductReview
	err := base.Order("created_at DESC").Offset(offset).Limit(limit).Find(&reviews).Error
	return reviews, total, err
}
// CreateReview inserts review and returns any database error.
func (r *ProductRepository) CreateReview(review *model.ProductReview) error {
	result := r.db.Create(review)
	return result.Error
}
// GetReviewByOrderID looks up the review that userID left for orderID. When
// none exists, First's error is returned together with a pointer to a
// zero-value review.
func (r *ProductRepository) GetReviewByOrderID(userID uint, orderID uint) (*model.ProductReview, error) {
	review := new(model.ProductReview)
	err := r.db.Where("user_id = ? AND order_id = ?", userID, orderID).First(review).Error
	return review, err
}
// GetReviewCount summarises the reviews of productID.
//
// The returned map has these keys:
//   - "total": number of reviews.
//   - "good_count": rating >= 4.
//   - "middle_count": rating == 3.
//   - "bad_count": rating <= 2.
//   - "has_image_count": images neither NULL nor empty.
//   - "good_rate": good_count / total * 100. It is 100 when there are no reviews.
//
// The counts are int64 and good_rate is float64. Five COUNT queries run in the
// order listed above, and the first failure is returned.
func (r *ProductRepository) GetReviewCount(productID uint) (map[string]interface{}, error) {
	countWhere := func(cond string, args ...interface{}) (int64, error) {
		var n int64
		err := r.db.Model(&model.ProductReview{}).Where(cond, args...).Count(&n).Error
		return n, err
	}

	total, err := countWhere("product_id = ?", productID)
	if err != nil {
		return nil, err
	}
	goodCount, err := countWhere("product_id = ? AND rating >= ?", productID, 4)
	if err != nil {
		return nil, err
	}
	middleCount, err := countWhere("product_id = ? AND rating = ?", productID, 3)
	if err != nil {
		return nil, err
	}
	badCount, err := countWhere("product_id = ? AND rating <= ?", productID, 2)
	if err != nil {
		return nil, err
	}
	hasImageCount, err := countWhere("product_id = ? AND images IS NOT NULL AND images != ''", productID)
	if err != nil {
		return nil, err
	}

	goodRate := 100.0
	if total > 0 {
		goodRate = float64(goodCount) / float64(total) * 100
	}

	return map[string]interface{}{
		"total":           total,
		"good_rate":       goodRate,
		"good_count":      goodCount,
		"middle_count":    middleCount,
		"bad_count":       badCount,
		"has_image_count": hasImageCount,
	}, nil
}
// GetProductImages lists the images of productID, ordered by sort ascending.
func (r *ProductRepository) GetProductImages(productID uint) ([]model.ProductImage, error) {
	var images []model.ProductImage
	scoped := r.db.Where("product_id = ?", productID).Order("sort ASC")
	err := scoped.Find(&images).Error
	return images, err
}
// CreateProductImage inserts image and returns any database error.
func (r *ProductRepository) CreateProductImage(image *model.ProductImage) error {
	result := r.db.Create(image)
	return result.Error
}
// DeleteProductImage deletes the product image with the given id.
func (r *ProductRepository) DeleteProductImage(id uint) error {
	result := r.db.Delete(&model.ProductImage{}, id)
	return result.Error
}
// GetProductSpecs lists the specs of productID, ordered by sort ascending.
func (r *ProductRepository) GetProductSpecs(productID uint) ([]model.ProductSpec, error) {
	var specs []model.ProductSpec
	scoped := r.db.Where("product_id = ?", productID).Order("sort ASC")
	err := scoped.Find(&specs).Error
	return specs, err
}
// CreateProductSpec inserts spec and returns any database error.
func (r *ProductRepository) CreateProductSpec(spec *model.ProductSpec) error {
	result := r.db.Create(spec)
	return result.Error
}
// UpdateProductSpec writes the column/value pairs in updates to the spec with
// the given id.
func (r *ProductRepository) UpdateProductSpec(id uint, updates map[string]interface{}) error {
	target := r.db.Model(&model.ProductSpec{}).Where("id = ?", id)
	return target.Updates(updates).Error
}
// DeleteProductSpec deletes the spec with the given id.
func (r *ProductRepository) DeleteProductSpec(id uint) error {
	result := r.db.Delete(&model.ProductSpec{}, id)
	return result.Error
}
// GetHotProducts returns up to limit on-shelf products flagged is_hot, with
// Category preloaded. They are ordered by sales, then creation time, both
// descending.
func (r *ProductRepository) GetHotProducts(limit int) ([]model.Product, error) {
	var hot []model.Product
	err := r.db.
		Where("status = ? AND is_hot = ?", 1, 1).
		Order("sales DESC, created_at DESC").
		Limit(limit).
		Preload("Category").
		Find(&hot).Error
	return hot, err
}
// GetRecommendProducts returns up to limit on-shelf products flagged
// is_recommend, with Category preloaded. They are ordered by sort weight, then
// creation time, both descending.
func (r *ProductRepository) GetRecommendProducts(limit int) ([]model.Product, error) {
	var picks []model.Product
	err := r.db.
		Where("status = ? AND is_recommend = ?", 1, 1).
		Order("sort DESC, created_at DESC").
		Limit(limit).
		Preload("Category").
		Find(&picks).Error
	return picks, err
}
// GetProductSKUs lists the active (status = 1) SKUs of productID, ordered by
// weight ascending.
func (r *ProductRepository) GetProductSKUs(productID uint) ([]model.ProductSKU, error) {
	var skus []model.ProductSKU
	scoped := r.db.Where("product_id = ? AND status = ?", productID, 1).Order("weight ASC")
	err := scoped.Find(&skus).Error
	return skus, err
}
// GetSKUByID fetches an active (status = 1) SKU by id. On failure, First's
// error is returned with a pointer to a zero-value SKU.
func (r *ProductRepository) GetSKUByID(skuID uint) (*model.ProductSKU, error) {
	sku := new(model.ProductSKU)
	err := r.db.Where("id = ? AND status = ?", skuID, 1).First(sku).Error
	return sku, err
}
// GetProductTags lists the active (status = 1) tags, ordered by sort ascending.
func (r *ProductRepository) GetProductTags() ([]model.ProductTag, error) {
	var tags []model.ProductTag
	scoped := r.db.Where("status = ?", 1).Order("sort ASC")
	err := scoped.Find(&tags).Error
	return tags, err
}
// GetStores lists the active (status = 1) stores, ordered by sort ascending.
func (r *ProductRepository) GetStores() ([]model.Store, error) {
	var stores []model.Store
	scoped := r.db.Where("status = ?", 1).Order("sort ASC")
	err := scoped.Find(&stores).Error
	return stores, err
}
// GetStoreByID fetches an active (status = 1) store by id. On failure, First's
// error is returned with a pointer to a zero-value Store.
func (r *ProductRepository) GetStoreByID(id uint) (*model.Store, error) {
	store := new(model.Store)
	err := r.db.Where("id = ? AND status = ?", id, 1).First(store).Error
	return store, err
}
// GetProductStatistics counts products for a dashboard. It returns int64
// values under these keys:
//   - "total_products": all rows.
//   - "online_products": status = 1.
//   - "offline_products": status = 0.
//   - "low_stock_products": stock < 10.
//
// The counts run in that order, and the first failure is returned.
func (r *ProductRepository) GetProductStatistics() (map[string]interface{}, error) {
	metrics := []struct {
		key   string
		scope func(*gorm.DB) *gorm.DB
	}{
		{"total_products", func(db *gorm.DB) *gorm.DB { return db }},
		{"online_products", func(db *gorm.DB) *gorm.DB { return db.Where("status = ?", 1) }},
		{"offline_products", func(db *gorm.DB) *gorm.DB { return db.Where("status = ?", 0) }},
		{"low_stock_products", func(db *gorm.DB) *gorm.DB { return db.Where("stock < ?", 10) }},
	}

	result := make(map[string]interface{}, len(metrics))
	for _, m := range metrics {
		var n int64
		if err := m.scope(r.db.Model(&model.Product{})).Count(&n).Error; err != nil {
			return nil, err
		}
		result[m.key] = n
	}
	return result, nil
}
// GetCategorySalesStatistics reports sales per active category (status = 1).
// It counts the quantity and amount sold by orders whose status is one of
// 2, 3, 4, 5, 6 and whose created_at lies between startDate 00:00:00 and
// endDate 23:59:59. Rows are ordered by amount descending and capped at limit.
//
// startDate and endDate are concatenated with a time of day, so they are
// presumably "YYYY-MM-DD" strings (confirm with callers). Each row maps
// category_id, category_name, sales_count and sales_amount.
//
// The order filters live inside a derived table that is LEFT JOINed to the
// products. Putting them in the outer WHERE on outer-joined columns had two
// effects: categories whose only sales fell outside the window (or had other
// statuses) disappeared instead of reporting 0, and order items whose order
// row was missing were counted.
func (r *ProductRepository) GetCategorySalesStatistics(startDate, endDate string, limit int) ([]map[string]interface{}, error) {
	var results []map[string]interface{}
	query := `
		SELECT
			c.id AS category_id,
			c.name AS category_name,
			COALESCE(SUM(s.quantity), 0) AS sales_count,
			COALESCE(SUM(s.total_price), 0) AS sales_amount
		FROM ai_categories c
		LEFT JOIN ai_products p ON p.category_id = c.id
		LEFT JOIN (
			SELECT oi.product_id, oi.quantity, oi.total_price
			FROM order_items oi
			INNER JOIN ai_orders o ON o.id = oi.order_id
			WHERE o.created_at >= ? AND o.created_at <= ?
				AND o.status IN (2, 3, 4, 5, 6)
		) s ON s.product_id = p.id
		WHERE c.status = 1
		GROUP BY c.id, c.name
		ORDER BY sales_amount DESC, c.id ASC
		LIMIT ?
	`
	err := r.db.Raw(query, startDate+" 00:00:00", endDate+" 23:59:59", limit).Scan(&results).Error
	return results, err
}
// BatchUpdateStatus sets status on every product whose id is in ids, in a
// single UPDATE ... WHERE id IN (...).
func (r *ProductRepository) BatchUpdateStatus(ids []uint, status int) error {
	result := r.db.Model(&model.Product{}).
		Where("id IN ?", ids).
		Update("status", status)
	return result.Error
}
// BatchUpdatePrice applies several per-product column updates atomically.
//
// Each element of updates carries the product primary key under "id". The
// remaining entries are the columns to set (e.g. "price"). Elements without an
// "id" key are skipped. The caller's maps are no longer modified. If any update
// fails, or the callback panics, the whole batch is rolled back. A failure is
// returned to the caller, and a panic is re-raised (db.Transaction semantics),
// rather than the panic being swallowed and nil returned.
//
// NOTE(review): every column present in the maps is written, not only price.
// Confirm that callers restrict the keys when the maps come from client input.
func (r *ProductRepository) BatchUpdatePrice(updates []map[string]interface{}) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		for _, update := range updates {
			id, ok := update["id"]
			if !ok {
				continue
			}
			// Copy the columns so the caller's map keeps its "id" entry.
			columns := make(map[string]interface{}, len(update))
			for k, v := range update {
				if k != "id" {
					columns[k] = v
				}
			}
			if err := tx.Model(&model.Product{}).Where("id = ?", id).Updates(columns).Error; err != nil {
				return err
			}
		}
		return nil
	})
}
// CountProductsByCategory counts the products in categoryID and in its direct
// child categories. Children are matched on parent_id regardless of their
// status. The product count goes through GORM's Model query, so soft-deleted
// rows are excluded if model.Product declares a DeletedAt field.
func (r *ProductRepository) CountProductsByCategory(categoryID uint) (int64, error) {
	var children []model.Category
	if err := r.db.Where("parent_id = ?", categoryID).Find(&children).Error; err != nil {
		return 0, err
	}

	ids := make([]uint, 0, len(children)+1)
	ids = append(ids, categoryID)
	for _, child := range children {
		ids = append(ids, child.ID)
	}

	var count int64
	err := r.db.Model(&model.Product{}).Where("category_id IN (?)", ids).Count(&count).Error
	return count, err
}
// BatchDelete deletes the products whose ids are listed. An empty ids slice is
// a no-op that issues no query.
//
// Only product rows are touched; related rows are kept. The original comment
// describes this as a soft delete via deleted_at, which relies on
// model.Product declaring a DeletedAt field.
func (r *ProductRepository) BatchDelete(ids []uint) error {
	if len(ids) == 0 {
		return nil
	}
	result := r.db.Delete(&model.Product{}, ids)
	return result.Error
}
// CreateSKU inserts sku and returns any database error.
func (r *ProductRepository) CreateSKU(sku *model.ProductSKU) error {
	result := r.db.Create(sku)
	return result.Error
}
// UpdateSKU writes the column/value pairs in updates to the SKU with the given
// id.
func (r *ProductRepository) UpdateSKU(id uint, updates map[string]interface{}) error {
	target := r.db.Model(&model.ProductSKU{}).Where("id = ?", id)
	return target.Updates(updates).Error
}
// DeleteSKU removes a SKU, keeping it when orders still reference it.
//
// If any order_items row references the SKU, the SKU is only deactivated
// (status = 0) so order history stays intact. Otherwise it is deleted through
// GORM's Delete.
//
// NOTE(review): the referencing column is queried as "sk_uid". This is
// presumably GORM's default column name for a SKUID field; verify it against
// the schema.
func (r *ProductRepository) DeleteSKU(id uint) error {
	var refs int64
	if err := r.db.Table("order_items").Where("sk_uid = ?", id).Count(&refs).Error; err != nil {
		return err
	}
	if refs == 0 {
		return r.db.Delete(&model.ProductSKU{}, id).Error
	}
	return r.db.Model(&model.ProductSKU{}).Where("id = ?", id).Update("status", 0).Error
}
// UpdateProductImageSort sets the sort position of the image with the given id.
func (r *ProductRepository) UpdateProductImageSort(id uint, sort int) error {
	result := r.db.Model(&model.ProductImage{}).Where("id = ?", id).Update("sort", sort)
	return result.Error
}
// CreateProductTag inserts tag and returns any database error.
func (r *ProductRepository) CreateProductTag(tag *model.ProductTag) error {
	result := r.db.Create(tag)
	return result.Error
}
// UpdateProductTag writes the column/value pairs in updates to the tag with the
// given id.
func (r *ProductRepository) UpdateProductTag(id uint, updates map[string]interface{}) error {
	target := r.db.Model(&model.ProductTag{}).Where("id = ?", id)
	return target.Updates(updates).Error
}
// DeleteProductTag deletes the tag with the given id. Rows in
// ai_product_tag_relations are not touched here.
func (r *ProductRepository) DeleteProductTag(id uint) error {
	result := r.db.Delete(&model.ProductTag{}, id)
	return result.Error
}
// AssignTagsToProduct replaces the tag set of productID with tagIDs.
//
// The existing rows in ai_product_tag_relations are deleted and the new ones
// inserted inside one transaction. A failed insert therefore no longer leaves
// the product with no tags at all. Duplicate ids in tagIDs are collapsed, and
// first-seen order is kept. An empty tagIDs slice simply clears the product's
// tags.
func (r *ProductRepository) AssignTagsToProduct(productID uint, tagIDs []uint) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Exec("DELETE FROM ai_product_tag_relations WHERE product_id = ?", productID).Error; err != nil {
			return err
		}

		seen := make(map[uint]struct{}, len(tagIDs))
		relations := make([]map[string]interface{}, 0, len(tagIDs))
		for _, tagID := range tagIDs {
			if _, dup := seen[tagID]; dup {
				continue
			}
			seen[tagID] = struct{}{}
			relations = append(relations, map[string]interface{}{
				"product_id": productID,
				"tag_id":     tagID,
			})
		}
		if len(relations) == 0 {
			return nil
		}
		return tx.Table("ai_product_tag_relations").Create(relations).Error
	})
}
// GetLowStockProducts returns the on-shelf (status = 1) products whose stock is
// at or below threshold, with Category preloaded.
func (r *ProductRepository) GetLowStockProducts(threshold int) ([]model.Product, error) {
	var low []model.Product
	err := r.db.
		Preload("Category").
		Where("stock <= ? AND status = ?", threshold, 1).
		Find(&low).Error
	return low, err
}
// GetInventoryStatistics summarises the stock of on-shelf (status = 1)
// products. The returned keys are:
//   - "total_products" (int64)
//   - "low_stock_products" (int64): stock <= 10
//   - "out_of_stock_products" (int64): stock = 0
//   - "total_stock_value" (float64): SUM(stock * price)
//
// Query errors are now propagated. Previously they were discarded and the
// function always returned nil, silently reporting zeros. The sum is wrapped in
// COALESCE because SUM over zero rows is NULL, which cannot be scanned into a
// float64.
//
// NOTE(review): "low stock" here means stock <= 10, while
// GetProductStatistics uses stock < 10. Confirm which threshold is intended.
func (r *ProductRepository) GetInventoryStatistics() (map[string]interface{}, error) {
	onShelf := func() *gorm.DB {
		return r.db.Model(&model.Product{})
	}

	var totalProducts int64
	if err := onShelf().Where("status = ?", 1).Count(&totalProducts).Error; err != nil {
		return nil, err
	}

	var lowStockProducts int64
	if err := onShelf().Where("stock <= ? AND status = ?", 10, 1).Count(&lowStockProducts).Error; err != nil {
		return nil, err
	}

	var outOfStockProducts int64
	if err := onShelf().Where("stock = ? AND status = ?", 0, 1).Count(&outOfStockProducts).Error; err != nil {
		return nil, err
	}

	var totalStockValue float64
	if err := onShelf().Where("status = ?", 1).
		Select("COALESCE(SUM(stock * price), 0)").
		Scan(&totalStockValue).Error; err != nil {
		return nil, err
	}

	return map[string]interface{}{
		"total_products":        totalProducts,
		"low_stock_products":    lowStockProducts,
		"out_of_stock_products": outOfStockProducts,
		"total_stock_value":     totalStockValue,
	}, nil
}
// GetProductsForExport returns every product matching conditions, with
// Category preloaded, for export. No pagination is applied.
//
// Recognised condition keys:
//   - "category_id": a positive uint, int or float64. Direct active child
//     categories are included.
//   - "status": exact status filter. "all" disables it, matching GetList.
//     (Previously "all" was compared against the numeric status column.)
//   - "keyword": non-empty string matched as a substring of name or description.
//
// Values of an unexpected type are ignored instead of causing a panic. Negative
// category ids are ignored instead of wrapping to a huge uint.
func (r *ProductRepository) GetProductsForExport(conditions map[string]interface{}) ([]model.Product, error) {
	var products []model.Product
	query := r.db.Model(&model.Product{}).Preload("Category")

	for key, value := range conditions {
		switch key {
		case "category_id":
			var catID uint
			switch v := value.(type) {
			case uint:
				catID = v
			case int:
				if v > 0 {
					catID = uint(v)
				}
			case float64:
				if v > 0 {
					catID = uint(v)
				}
			}
			if catID == 0 {
				continue
			}
			if ids, err := r.getCategoryIDsIncludingChildren(catID); err == nil && len(ids) > 0 {
				query = query.Where("category_id IN (?)", ids)
			} else {
				// Fallback: if the child lookup fails, filter on this category only.
				query = query.Where("category_id = ?", catID)
			}
		case "status":
			if value != "all" {
				query = query.Where("status = ?", value)
			}
		case "keyword":
			if kw, ok := value.(string); ok && kw != "" {
				pattern := "%" + kw + "%"
				query = query.Where("(name LIKE ? OR description LIKE ?)", pattern, pattern)
			}
		}
	}

	err := query.Find(&products).Error
	return products, err
}
// getCategoryIDsIncludingChildren returns categoryID followed by the ids of its
// direct active (status = 1) children. Only one level is looked up; the
// original comment says the category tree has at most two levels.
//
// If the lookup fails, the returned slice still holds categoryID alone, next to
// the error.
func (r *ProductRepository) getCategoryIDsIncludingChildren(categoryID uint) ([]uint, error) {
	var children []model.Category
	lookupErr := r.db.Where("parent_id = ? AND status = ?", categoryID, 1).Find(&children).Error

	ids := []uint{categoryID}
	if lookupErr != nil {
		return ids, lookupErr
	}
	for i := range children {
		ids = append(ids, children[i].ID)
	}
	return ids, nil
}
// DeductSKUStock subtracts quantity from a SKU's stock. The check and the
// decrement are one conditional UPDATE statement. The product's total stock is
// not touched here.
//
// It returns gorm.ErrInvalidData for a negative quantity, which would otherwise
// bypass the stock check and add stock. It returns gorm.ErrRecordNotFound when
// the SKU is missing or has too little stock.
func (r *ProductRepository) DeductSKUStock(skuID uint, quantity int) error {
	if quantity < 0 {
		return gorm.ErrInvalidData
	}
	result := r.db.Model(&model.ProductSKU{}).
		Where("id = ? AND stock >= ?", skuID, quantity).
		Update("stock", gorm.Expr("stock - ?", quantity))
	if result.Error != nil {
		return result.Error
	}
	if result.RowsAffected == 0 {
		// Insufficient SKU stock or unknown SKU.
		return gorm.ErrRecordNotFound
	}
	return nil
}
// RestoreSKUStock adds quantity back to a SKU's stock. An unknown SKU id is not
// reported because RowsAffected is not checked.
func (r *ProductRepository) RestoreSKUStock(skuID uint, quantity int) error {
	restored := gorm.Expr("stock + ?", quantity)
	result := r.db.Model(&model.ProductSKU{}).Where("id = ?", skuID).Update("stock", restored)
	return result.Error
}
// SyncProductStockFromSKUs sets a product's stock to the sum of the stock of
// its active (status = 1) SKUs. The sum is 0 when there are none.
//
// NOTE(review): the sum is read and then written in two separate statements,
// outside any transaction. A concurrent SKU stock change between them can be
// lost from the product total.
func (r *ProductRepository) SyncProductStockFromSKUs(productID uint) error {
	var totalStock int
	sum := r.db.Model(&model.ProductSKU{}).
		Where("product_id = ? AND status = ?", productID, 1).
		Select("COALESCE(SUM(stock), 0)")
	if err := sum.Scan(&totalStock).Error; err != nil {
		return err
	}

	result := r.db.Model(&model.Product{}).Where("id = ?", productID).Update("stock", totalStock)
	return result.Error
}
// DeductStockWithSKU deducts quantity from a product, going through its SKU
// when one is given. Everything runs in a single transaction.
//
// When skuID is non-nil and non-zero, the SKU's stock is decremented. The SKU
// must belong to productID and hold at least quantity. The product's stock is
// then recomputed as the sum over its active (status = 1) SKUs. Otherwise the
// product's own stock is decremented directly, again only if it holds at least
// quantity.
//
// It returns gorm.ErrInvalidData for a negative quantity, which would otherwise
// bypass the stock check. It returns gorm.ErrRecordNotFound when the SKU or
// product is missing, does not match, or has too little stock.
//
// db.Transaction rolls back on error or panic and re-raises the panic. The
// previous Begin/recover version swallowed panics and returned nil, and it
// never checked the Begin error.
func (r *ProductRepository) DeductStockWithSKU(productID uint, skuID *uint, quantity int) error {
	if quantity < 0 {
		return gorm.ErrInvalidData
	}
	return r.db.Transaction(func(tx *gorm.DB) error {
		if skuID == nil || *skuID == 0 {
			// Product without SKUs: deduct from the product row itself.
			result := tx.Model(&model.Product{}).
				Where("id = ? AND stock >= ?", productID, quantity).
				Update("stock", gorm.Expr("stock - ?", quantity))
			if result.Error != nil {
				return result.Error
			}
			if result.RowsAffected == 0 {
				return gorm.ErrRecordNotFound // insufficient product stock
			}
			return nil
		}

		// Product with SKUs: deduct from the SKU first.
		result := tx.Model(&model.ProductSKU{}).
			Where("id = ? AND product_id = ? AND stock >= ?", *skuID, productID, quantity).
			Update("stock", gorm.Expr("stock - ?", quantity))
		if result.Error != nil {
			return result.Error
		}
		if result.RowsAffected == 0 {
			return gorm.ErrRecordNotFound // insufficient SKU stock or SKU not in product
		}

		// Keep the product total in step with its active SKUs.
		var totalStock int
		if err := tx.Model(&model.ProductSKU{}).
			Where("product_id = ? AND status = ?", productID, 1).
			Select("COALESCE(SUM(stock), 0)").Scan(&totalStock).Error; err != nil {
			return err
		}
		return tx.Model(&model.Product{}).Where("id = ?", productID).
			Update("stock", totalStock).Error
	})
}
// RestoreStockWithSKU adds quantity back to a product, going through its SKU
// when one is given. Everything runs in a single transaction.
//
// When skuID is non-nil and non-zero, the SKU's stock is incremented. The
// product's stock is then recomputed as the sum over its active (status = 1)
// SKUs. Otherwise the product's own stock is incremented directly. Unknown ids
// are not reported because RowsAffected is not checked.
//
// db.Transaction rolls back on error or panic and re-raises the panic. The
// previous Begin/recover version swallowed panics and returned nil, and it
// never checked the Begin error.
func (r *ProductRepository) RestoreStockWithSKU(productID uint, skuID *uint, quantity int) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		if skuID == nil || *skuID == 0 {
			// Product without SKUs: restore the product row itself.
			return tx.Model(&model.Product{}).Where("id = ?", productID).
				Update("stock", gorm.Expr("stock + ?", quantity)).Error
		}

		// Product with SKUs: restore the SKU first.
		if err := tx.Model(&model.ProductSKU{}).Where("id = ?", *skuID).
			Update("stock", gorm.Expr("stock + ?", quantity)).Error; err != nil {
			return err
		}

		// Keep the product total in step with its active SKUs.
		var totalStock int
		if err := tx.Model(&model.ProductSKU{}).
			Where("product_id = ? AND status = ?", productID, 1).
			Select("COALESCE(SUM(stock), 0)").Scan(&totalStock).Error; err != nil {
			return err
		}
		return tx.Model(&model.Product{}).Where("id = ?", productID).
			Update("stock", totalStock).Error
	})
}