This commit is contained in:
sjk
2025-11-17 14:11:46 +08:00
commit ad4a600af9
1659 changed files with 171560 additions and 0 deletions

View File

@@ -0,0 +1,226 @@
package service
import (
"dianshang/internal/model"
"dianshang/internal/repository"
"dianshang/pkg/jwt"
"errors"
"time"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// AdminService 管理员服务
type AdminService struct {
adminRepo *repository.AdminRepository
db *gorm.DB
}
// NewAdminService 创建管理员服务
func NewAdminService(db *gorm.DB) *AdminService {
return &AdminService{
adminRepo: repository.NewAdminRepository(db),
db: db,
}
}
// LoginRequest 管理员登录请求
type LoginRequest struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
}
// CreateAdminRequest 创建管理员请求
type CreateAdminRequest struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required,min=6"`
Nickname string `json:"nickname"`
Email string `json:"email"`
Phone string `json:"phone"`
RoleID uint `json:"role_id" binding:"required"`
}
// UpdateAdminRequest 更新管理员请求
type UpdateAdminRequest struct {
Nickname string `json:"nickname"`
Email string `json:"email"`
Phone string `json:"phone"`
RoleID uint `json:"role_id"`
Status *uint8 `json:"status"`
}
// AdminLoginResponse 管理员登录响应
type AdminLoginResponse struct {
Token string `json:"token"`
AdminUser *model.AdminUser `json:"admin_user"`
}
// Login 管理员登录
func (s *AdminService) Login(req *LoginRequest) (*AdminLoginResponse, error) {
// 查找管理员
admin, err := s.adminRepo.GetByUsername(req.Username)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("用户名或密码错误")
}
return nil, err
}
// 检查管理员状态
if admin.Status == 0 {
return nil, errors.New("账户已被禁用")
}
// 验证密码
if err := bcrypt.CompareHashAndPassword([]byte(admin.Password), []byte(req.Password)); err != nil {
return nil, errors.New("用户名或密码错误")
}
// 生成JWT token
tokenExpiry := 8 * 3600 // 8小时有效期
token, err := jwt.GenerateToken(admin.ID, "admin", tokenExpiry)
if err != nil {
return nil, errors.New("生成token失败")
}
// 更新最后登录时间
now := time.Now()
admin.LastLogin = &now
s.adminRepo.Update(admin.ID, map[string]interface{}{
"last_login": now,
})
// 加载角色信息
admin, _ = s.adminRepo.GetByIDWithRole(admin.ID)
return &AdminLoginResponse{
Token: token,
AdminUser: admin,
}, nil
}
// CreateAdmin 创建管理员
func (s *AdminService) CreateAdmin(req *CreateAdminRequest) (*model.AdminUser, error) {
// 检查用户名是否已存在
if _, err := s.adminRepo.GetByUsername(req.Username); err == nil {
return nil, errors.New("用户名已存在")
}
// 加密密码
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
if err != nil {
return nil, errors.New("密码加密失败")
}
// 创建管理员
admin := &model.AdminUser{
Username: req.Username,
Password: string(hashedPassword),
Nickname: req.Nickname,
Email: req.Email,
Phone: req.Phone,
RoleID: req.RoleID,
Status: 1, // 默认启用
}
if err := s.adminRepo.Create(admin); err != nil {
return nil, err
}
// 返回时不包含密码
admin.Password = ""
return admin, nil
}
// GetAdminList 获取管理员列表
func (s *AdminService) GetAdminList(page, pageSize int, keyword string) ([]model.AdminUser, int64, error) {
return s.adminRepo.GetList(page, pageSize, keyword)
}
// GetAdminByID 根据ID获取管理员
func (s *AdminService) GetAdminByID(id uint) (*model.AdminUser, error) {
return s.adminRepo.GetByIDWithRole(id)
}
// UpdateAdmin 更新管理员
func (s *AdminService) UpdateAdmin(id uint, req *UpdateAdminRequest) error {
updates := make(map[string]interface{})
if req.Nickname != "" {
updates["nickname"] = req.Nickname
}
if req.Email != "" {
updates["email"] = req.Email
}
if req.Phone != "" {
updates["phone"] = req.Phone
}
if req.RoleID != 0 {
updates["role_id"] = req.RoleID
}
if req.Status != nil {
updates["status"] = *req.Status
}
return s.adminRepo.Update(id, updates)
}
// DeleteAdmin 删除管理员
func (s *AdminService) DeleteAdmin(id uint) error {
return s.adminRepo.Delete(id)
}
// ChangePassword 修改密码
func (s *AdminService) ChangePassword(id uint, oldPassword, newPassword string) error {
// 获取管理员信息
admin, err := s.adminRepo.GetByID(id)
if err != nil {
return err
}
// 验证旧密码
if err := bcrypt.CompareHashAndPassword([]byte(admin.Password), []byte(oldPassword)); err != nil {
return errors.New("原密码错误")
}
// 加密新密码
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
return errors.New("密码加密失败")
}
// 更新密码
return s.adminRepo.Update(id, map[string]interface{}{
"password": string(hashedPassword),
})
}
// GetProfile 获取管理员个人信息
func (s *AdminService) GetProfile(id uint) (*model.AdminUser, error) {
admin, err := s.adminRepo.GetByIDWithRole(id)
if err != nil {
return nil, err
}
// 不返回密码
admin.Password = ""
return admin, nil
}
// UpdateProfile 更新管理员个人信息
func (s *AdminService) UpdateProfile(id uint, nickname, email, phone string) error {
updates := make(map[string]interface{})
if nickname != "" {
updates["nickname"] = nickname
}
if email != "" {
updates["email"] = email
}
if phone != "" {
updates["phone"] = phone
}
return s.adminRepo.Update(id, updates)
}

View File

@@ -0,0 +1,125 @@
package service
import (
"dianshang/internal/model"
"fmt"
"time"
"gorm.io/gorm"
)
// AfterSaleService 售后服务
type AfterSaleService struct {
db *gorm.DB
}
// NewAfterSaleService 创建售后服务
func NewAfterSaleService(db *gorm.DB) *AfterSaleService {
return &AfterSaleService{
db: db,
}
}
// GetUserAfterSales 获取用户售后列表
func (s *AfterSaleService) GetUserAfterSales(userID uint, page, pageSize int, status int) ([]model.AfterSale, int64, error) {
var afterSales []model.AfterSale
var total int64
query := s.db.Model(&model.AfterSale{}).Where("user_id = ?", userID)
// 如果指定了状态,添加状态过滤
if status > 0 {
query = query.Where("status = ?", status)
}
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 分页查询,预加载关联数据
offset := (page - 1) * pageSize
if err := query.Preload("Order").Preload("OrderItem").Preload("User").
Order("created_at DESC").
Offset(offset).Limit(pageSize).
Find(&afterSales).Error; err != nil {
return nil, 0, err
}
return afterSales, total, nil
}
// GetAfterSaleDetail 获取售后详情
func (s *AfterSaleService) GetAfterSaleDetail(userID, afterSaleID uint) (*model.AfterSale, error) {
var afterSale model.AfterSale
if err := s.db.Where("id = ? AND user_id = ?", afterSaleID, userID).
Preload("Order").Preload("OrderItem").Preload("User").
First(&afterSale).Error; err != nil {
return nil, err
}
return &afterSale, nil
}
// CreateAfterSale 创建售后申请
func (s *AfterSaleService) CreateAfterSale(userID uint, req *CreateAfterSaleRequest) (*model.AfterSale, error) {
// 验证订单是否存在且属于该用户
var order model.Order
if err := s.db.Where("id = ? AND user_id = ?", req.OrderID, userID).First(&order).Error; err != nil {
return nil, fmt.Errorf("订单不存在或无权限")
}
// 验证订单项是否存在
var orderItem model.OrderItem
if err := s.db.Where("id = ? AND order_id = ?", req.OrderItemID, req.OrderID).First(&orderItem).Error; err != nil {
return nil, fmt.Errorf("订单项不存在")
}
// 创建售后记录
afterSale := &model.AfterSale{
OrderID: req.OrderID,
OrderItemID: req.OrderItemID,
UserID: userID,
Type: req.Type,
Reason: req.Reason,
Description: req.Description,
Images: req.Images,
Status: 1, // 1待审核
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := s.db.Create(afterSale).Error; err != nil {
return nil, err
}
// 预加载关联数据
if err := s.db.Preload("Order").Preload("OrderItem").Preload("User").
First(afterSale, afterSale.ID).Error; err != nil {
return nil, err
}
return afterSale, nil
}
// CreateAfterSaleRequest 创建售后申请请求
type CreateAfterSaleRequest struct {
OrderID uint `json:"order_id" binding:"required"`
OrderItemID uint `json:"order_item_id" binding:"required"`
Type int `json:"type" binding:"required"` // 1退货2换货3维修
Reason string `json:"reason" binding:"required"`
Description string `json:"description"`
Images model.JSONSlice `json:"images"`
}
// UpdateAfterSaleStatus 更新售后状态
func (s *AfterSaleService) UpdateAfterSaleStatus(afterSaleID uint, status int, adminRemark string) error {
return s.db.Model(&model.AfterSale{}).
Where("id = ?", afterSaleID).
Updates(map[string]interface{}{
"status": status,
"admin_remark": adminRemark,
"updated_at": time.Now(),
}).Error
}

View File

@@ -0,0 +1,319 @@
package service
import (
"dianshang/internal/model"
"dianshang/internal/repository"
"errors"
"fmt"
"time"
)
// BannerService 轮播图服务
type BannerService struct {
bannerRepo *repository.BannerRepository
}
// NewBannerService 创建轮播图服务
func NewBannerService(bannerRepo *repository.BannerRepository) *BannerService {
return &BannerService{
bannerRepo: bannerRepo,
}
}
// GetActiveBanners 获取有效的轮播图
func (s *BannerService) GetActiveBanners() ([]model.Banner, error) {
return s.bannerRepo.GetActiveBannersWithTimeRange()
}
// GetBannerList 获取轮播图列表(分页)
func (s *BannerService) GetBannerList(page, pageSize int, status *int) ([]model.Banner, int64, error) {
if page <= 0 {
page = 1
}
if pageSize <= 0 || pageSize > 100 {
pageSize = 10
}
return s.bannerRepo.GetBannerList(page, pageSize, status)
}
// GetBannerByID 根据ID获取轮播图
func (s *BannerService) GetBannerByID(id uint) (*model.Banner, error) {
if id == 0 {
return nil, errors.New("轮播图ID不能为空")
}
return s.bannerRepo.GetBannerByID(id)
}
// CreateBanner 创建轮播图
func (s *BannerService) CreateBanner(banner *model.Banner) error {
// 验证必填字段
if err := s.validateBanner(banner); err != nil {
return err
}
// 如果没有设置排序值,自动设置为最大值+10
if banner.Sort == 0 {
maxSort, err := s.bannerRepo.GetMaxSort()
if err != nil {
return fmt.Errorf("获取最大排序值失败: %v", err)
}
banner.Sort = maxSort + 10
}
// 设置默认状态
if banner.Status == 0 {
banner.Status = 1
}
return s.bannerRepo.CreateBanner(banner)
}
// UpdateBanner 更新轮播图
func (s *BannerService) UpdateBanner(id uint, banner *model.Banner) error {
if id == 0 {
return errors.New("轮播图ID不能为空")
}
// 检查轮播图是否存在
exists, err := s.bannerRepo.CheckBannerExists(id)
if err != nil {
return fmt.Errorf("检查轮播图是否存在失败: %v", err)
}
if !exists {
return errors.New("轮播图不存在")
}
// 验证必填字段
if err := s.validateBanner(banner); err != nil {
return err
}
banner.ID = id
return s.bannerRepo.UpdateBanner(banner)
}
// DeleteBanner 删除轮播图
func (s *BannerService) DeleteBanner(id uint) error {
if id == 0 {
return errors.New("轮播图ID不能为空")
}
// 检查轮播图是否存在
exists, err := s.bannerRepo.CheckBannerExists(id)
if err != nil {
return fmt.Errorf("检查轮播图是否存在失败: %v", err)
}
if !exists {
return errors.New("轮播图不存在")
}
return s.bannerRepo.DeleteBanner(id)
}
// BatchDeleteBanners 批量删除轮播图
func (s *BannerService) BatchDeleteBanners(ids []uint) error {
if len(ids) == 0 {
return errors.New("轮播图ID列表不能为空")
}
return s.bannerRepo.BatchDeleteBanners(ids)
}
// UpdateBannerStatus 更新轮播图状态
func (s *BannerService) UpdateBannerStatus(id uint, status int) error {
if id == 0 {
return errors.New("轮播图ID不能为空")
}
if status < 0 || status > 1 {
return errors.New("状态值无效只能是0或1")
}
// 检查轮播图是否存在
exists, err := s.bannerRepo.CheckBannerExists(id)
if err != nil {
return fmt.Errorf("检查轮播图是否存在失败: %v", err)
}
if !exists {
return errors.New("轮播图不存在")
}
return s.bannerRepo.UpdateBannerStatus(id, status)
}
// BatchUpdateBannerStatus 批量更新轮播图状态
func (s *BannerService) BatchUpdateBannerStatus(ids []uint, status int) error {
if len(ids) == 0 {
return errors.New("轮播图ID列表不能为空")
}
if status < 0 || status > 1 {
return errors.New("状态值无效只能是0或1")
}
return s.bannerRepo.BatchUpdateBannerStatus(ids, status)
}
// UpdateBannerSort 更新轮播图排序
func (s *BannerService) UpdateBannerSort(id uint, sort int) error {
if id == 0 {
return errors.New("轮播图ID不能为空")
}
if sort < 0 {
return errors.New("排序值不能为负数")
}
// 检查轮播图是否存在
exists, err := s.bannerRepo.CheckBannerExists(id)
if err != nil {
return fmt.Errorf("检查轮播图是否存在失败: %v", err)
}
if !exists {
return errors.New("轮播图不存在")
}
return s.bannerRepo.UpdateBannerSort(id, sort)
}
// BatchUpdateBannerSort 批量更新轮播图排序
func (s *BannerService) BatchUpdateBannerSort(sortData []map[string]interface{}) error {
if len(sortData) == 0 {
return errors.New("排序数据不能为空")
}
// 验证排序数据
for _, data := range sortData {
id, ok := data["id"]
if !ok {
return errors.New("排序数据中缺少ID字段")
}
sort, ok := data["sort"]
if !ok {
return errors.New("排序数据中缺少sort字段")
}
// 类型检查
if _, ok := id.(uint); !ok {
if idFloat, ok := id.(float64); ok {
data["id"] = uint(idFloat)
} else {
return errors.New("ID字段类型错误")
}
}
if _, ok := sort.(int); !ok {
if sortFloat, ok := sort.(float64); ok {
data["sort"] = int(sortFloat)
} else {
return errors.New("sort字段类型错误")
}
}
}
return s.bannerRepo.BatchUpdateBannerSort(sortData)
}
// GetBannersByDateRange 根据日期范围获取轮播图
func (s *BannerService) GetBannersByDateRange(startDate, endDate time.Time) ([]model.Banner, error) {
if startDate.After(endDate) {
return nil, errors.New("开始日期不能晚于结束日期")
}
return s.bannerRepo.GetBannersByDateRange(startDate, endDate)
}
// GetBannersByStatus 根据状态获取轮播图
func (s *BannerService) GetBannersByStatus(status int) ([]model.Banner, error) {
if status < 0 || status > 1 {
return nil, errors.New("状态值无效只能是0或1")
}
return s.bannerRepo.GetBannersByStatus(status)
}
// GetBannerStatistics 获取轮播图统计信息
func (s *BannerService) GetBannerStatistics() (map[string]interface{}, error) {
// 获取总数
total, err := s.bannerRepo.GetBannerCount()
if err != nil {
return nil, fmt.Errorf("获取轮播图总数失败: %v", err)
}
// 获取启用数量
activeCount, err := s.bannerRepo.GetBannerCountByStatus(1)
if err != nil {
return nil, fmt.Errorf("获取启用轮播图数量失败: %v", err)
}
// 获取禁用数量
inactiveCount, err := s.bannerRepo.GetBannerCountByStatus(0)
if err != nil {
return nil, fmt.Errorf("获取禁用轮播图数量失败: %v", err)
}
// 获取过期轮播图
expiredBanners, err := s.bannerRepo.GetExpiredBanners()
if err != nil {
return nil, fmt.Errorf("获取过期轮播图失败: %v", err)
}
return map[string]interface{}{
"total": total,
"active": activeCount,
"inactive": inactiveCount,
"expired": len(expiredBanners),
"expired_list": expiredBanners,
}, nil
}
// CleanExpiredBanners 清理过期轮播图
func (s *BannerService) CleanExpiredBanners() error {
expiredBanners, err := s.bannerRepo.GetExpiredBanners()
if err != nil {
return fmt.Errorf("获取过期轮播图失败: %v", err)
}
if len(expiredBanners) == 0 {
return nil
}
// 将过期轮播图状态设置为禁用
var ids []uint
for _, banner := range expiredBanners {
ids = append(ids, banner.ID)
}
return s.bannerRepo.BatchUpdateBannerStatus(ids, 0)
}
// validateBanner 验证轮播图数据
func (s *BannerService) validateBanner(banner *model.Banner) error {
if banner.Title == "" {
return errors.New("轮播图标题不能为空")
}
if banner.Image == "" {
return errors.New("轮播图图片不能为空")
}
if banner.LinkType < 1 || banner.LinkType > 4 {
return errors.New("链接类型无效只能是1-4")
}
if banner.Sort < 0 {
return errors.New("排序值不能为负数")
}
// 验证时间范围
if banner.StartTime != nil && banner.EndTime != nil {
if banner.StartTime.After(*banner.EndTime) {
return errors.New("开始时间不能晚于结束时间")
}
}
return nil
}

View File

@@ -0,0 +1,661 @@
package service
import (
"dianshang/internal/model"
"dianshang/internal/repository"
"errors"
"fmt"
"time"
"gorm.io/gorm"
)
// CartService 购物车服务
type CartService struct {
orderRepo *repository.OrderRepository
productRepo *repository.ProductRepository
userRepo *repository.UserRepository
}
// NewCartService 创建购物车服务
func NewCartService(orderRepo *repository.OrderRepository, productRepo *repository.ProductRepository, userRepo *repository.UserRepository) *CartService {
return &CartService{
orderRepo: orderRepo,
productRepo: productRepo,
userRepo: userRepo,
}
}
// GetCart 获取购物车
func (s *CartService) GetCart(userID uint) ([]model.Cart, error) {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return nil, errors.New("用户不存在")
}
return s.orderRepo.GetCart(userID)
}
// AddToCart 添加到购物车
func (s *CartService) AddToCart(userID, productID uint, skuID uint, quantity int) error {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return errors.New("用户不存在")
}
// 检查产品是否存在
product, err := s.productRepo.GetByID(productID)
if err != nil {
return errors.New("产品不存在")
}
// 检查产品状态
if product.Status != 1 {
return errors.New("产品已下架")
}
// 检查库存
if product.Stock < quantity {
return errors.New("库存不足")
}
// 检查购物车中是否已存在该商品包括SKU
existingCart, err := s.orderRepo.GetCartItemBySKU(userID, productID, skuID)
if err == nil && existingCart != nil {
// 已存在,更新数量
newQuantity := existingCart.Quantity + quantity
if product.Stock < newQuantity {
return errors.New("库存不足")
}
return s.orderRepo.UpdateCartItem(existingCart.ID, newQuantity)
}
// 如果不是记录不存在的错误,返回错误
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
// 不存在,添加新项
var skuPtr *uint
if skuID != 0 {
skuPtr = &skuID
}
cart := &model.Cart{
UserID: userID,
ProductID: productID,
SKUID: skuPtr,
Quantity: quantity,
Selected: true,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
return s.orderRepo.AddToCart(cart)
}
// UpdateCartItem 更新购物车项
func (s *CartService) UpdateCartItem(userID, productID uint, quantity int) error {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return errors.New("用户不存在")
}
// 获取该用户该产品的所有购物车项
cartItems, err := s.orderRepo.GetCart(userID)
if err != nil {
return errors.New("获取购物车失败")
}
// 查找匹配的购物车项
var targetCartItem *model.Cart
for _, item := range cartItems {
if item.ProductID == productID {
targetCartItem = &item
break
}
}
if targetCartItem == nil {
return errors.New("购物车项不存在")
}
if quantity == 0 {
// 数量为0删除该项
return s.orderRepo.RemoveFromCart(userID, productID)
}
// 检查产品库存
product, err := s.productRepo.GetByID(productID)
if err != nil {
return errors.New("产品不存在")
}
if product.Stock < quantity {
return errors.New("库存不足")
}
return s.orderRepo.UpdateCartItem(targetCartItem.ID, quantity)
}
// UpdateCartItemBySKU 基于SKU更新购物车项
func (s *CartService) UpdateCartItemBySKU(userID, productID, skuID uint, quantity int) error {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return errors.New("用户不存在")
}
// 使用SKU查找购物车项
cartItem, err := s.orderRepo.GetCartItemBySKU(userID, productID, skuID)
if err != nil {
return errors.New("购物车项不存在")
}
if quantity == 0 {
// 数量为0删除该项
return s.orderRepo.RemoveFromCartBySKU(userID, productID, skuID)
}
// 检查产品库存
product, err := s.productRepo.GetByID(productID)
if err != nil {
return errors.New("产品不存在")
}
if product.Stock < quantity {
return errors.New("库存不足")
}
return s.orderRepo.UpdateCartItem(cartItem.ID, quantity)
}
// RemoveFromCart 从购物车移除
func (s *CartService) RemoveFromCart(userID, productID uint) error {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return errors.New("用户不存在")
}
// 获取该用户该产品的所有购物车项
cartItems, err := s.orderRepo.GetCart(userID)
if err != nil {
return errors.New("获取购物车失败")
}
// 查找匹配的购物车项
var found bool
for _, item := range cartItems {
if item.ProductID == productID {
found = true
break
}
}
if !found {
return errors.New("购物车项不存在")
}
return s.orderRepo.RemoveFromCart(userID, productID)
}
// RemoveFromCartBySKU 基于SKU从购物车移除
func (s *CartService) RemoveFromCartBySKU(userID, productID, skuID uint) error {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return errors.New("用户不存在")
}
// 使用SKU查找购物车项
_, err = s.orderRepo.GetCartItemBySKU(userID, productID, skuID)
if err != nil {
return errors.New("购物车项不存在")
}
return s.orderRepo.RemoveFromCartBySKU(userID, productID, skuID)
}
// ClearCart 清空购物车
func (s *CartService) ClearCart(userID uint) error {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return errors.New("用户不存在")
}
return s.orderRepo.ClearCart(userID)
}
// GetCartCount 获取购物车商品数量
func (s *CartService) GetCartCount(userID uint) (int, error) {
cart, err := s.GetCart(userID)
if err != nil {
return 0, err
}
var count int
for _, item := range cart {
count += int(item.Quantity)
}
return count, nil
}
// GetCartTotal 获取购物车总金额
func (s *CartService) GetCartTotal(userID uint) (float64, error) {
cart, err := s.GetCart(userID)
if err != nil {
return 0, err
}
var total float64
for _, item := range cart {
if item.Product.ID != 0 {
// 将价格从分转换为元
total += (float64(item.Product.Price) / 100) * float64(item.Quantity)
}
}
return total, nil
}
// SelectCartItem 选择/取消选择购物车项
func (s *CartService) SelectCartItem(userID, cartID uint, selected bool) error {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return errors.New("用户不存在")
}
return s.orderRepo.SelectCartItem(userID, cartID, selected)
}
// SelectAllCartItems 全选/取消全选购物车
func (s *CartService) SelectAllCartItems(userID uint, selected bool) error {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return errors.New("用户不存在")
}
return s.orderRepo.SelectAllCartItems(userID, selected)
}
// BatchAddToCart 批量添加到购物车
func (s *CartService) BatchAddToCart(userID uint, items []struct {
ProductID uint `json:"product_id"`
SKUID uint `json:"sku_id"`
Quantity int `json:"quantity"`
}) error {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return errors.New("用户不存在")
}
// 验证所有商品
for _, item := range items {
product, err := s.productRepo.GetByID(item.ProductID)
if err != nil {
return errors.New("商品不存在: " + err.Error())
}
if product.Status != 1 {
return errors.New("商品已下架")
}
if product.Stock < item.Quantity {
return errors.New("商品库存不足")
}
}
// 批量添加
for _, item := range items {
err := s.AddToCart(userID, item.ProductID, item.SKUID, item.Quantity)
if err != nil {
return err
}
}
return nil
}
// BatchRemoveFromCart 批量从购物车移除
func (s *CartService) BatchRemoveFromCart(userID uint, cartIDs []uint) error {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return errors.New("用户不存在")
}
return s.orderRepo.BatchRemoveFromCart(userID, cartIDs)
}
// BatchUpdateCartItems 批量更新购物车项
func (s *CartService) BatchUpdateCartItems(userID uint, updates []struct {
CartID uint `json:"cart_id"`
Quantity int `json:"quantity"`
}) error {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return errors.New("用户不存在")
}
// 获取购物车项并验证
for _, update := range updates {
cartItem, err := s.orderRepo.GetCartItem(userID, update.CartID)
if err != nil {
return errors.New("购物车项不存在")
}
// 检查库存
product, err := s.productRepo.GetByID(cartItem.ProductID)
if err != nil {
return errors.New("商品不存在")
}
if product.Stock < update.Quantity {
return errors.New("商品库存不足")
}
// 更新数量
if update.Quantity == 0 {
err = s.orderRepo.RemoveCartItem(update.CartID)
} else {
err = s.orderRepo.UpdateCartItem(update.CartID, update.Quantity)
}
if err != nil {
return err
}
}
return nil
}
// GetCartWithDetails 获取购物车详细信息(包含商品详情)
func (s *CartService) GetCartWithDetails(userID uint) (map[string]interface{}, error) {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return nil, errors.New("用户不存在")
}
cart, err := s.orderRepo.GetCart(userID)
if err != nil {
return nil, err
}
var validItems []model.Cart
var invalidItems []model.Cart
var totalAmount float64
var totalQuantity int
var selectedAmount float64
var selectedQuantity int
for _, item := range cart {
// 检查商品是否有效
product, err := s.productRepo.GetByID(item.ProductID)
if err != nil || product.Status != 1 {
invalidItems = append(invalidItems, item)
continue
}
// 检查库存
if product.Stock < item.Quantity {
item.Product = *product
invalidItems = append(invalidItems, item)
continue
}
// 计算价格
item.Product = *product
itemPrice := float64(product.Price) / 100 * float64(item.Quantity)
totalAmount += itemPrice
totalQuantity += item.Quantity
if item.Selected {
selectedAmount += itemPrice
selectedQuantity += item.Quantity
}
validItems = append(validItems, item)
}
return map[string]interface{}{
"valid_items": validItems,
"invalid_items": invalidItems,
"total_amount": totalAmount,
"total_quantity": totalQuantity,
"selected_amount": selectedAmount,
"selected_quantity": selectedQuantity,
"valid_count": len(validItems),
"invalid_count": len(invalidItems),
}, nil
}
// ValidateCartItems 验证购物车商品有效性
func (s *CartService) ValidateCartItems(userID uint) (map[string]interface{}, error) {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return nil, errors.New("用户不存在")
}
cart, err := s.orderRepo.GetCart(userID)
if err != nil {
return nil, err
}
var validItems []uint
var invalidItems []struct {
CartID uint `json:"cart_id"`
Reason string `json:"reason"`
}
for _, item := range cart {
// 检查商品是否存在
product, err := s.productRepo.GetByID(item.ProductID)
if err != nil {
invalidItems = append(invalidItems, struct {
CartID uint `json:"cart_id"`
Reason string `json:"reason"`
}{
CartID: item.ID,
Reason: "商品不存在",
})
continue
}
// 检查商品状态
if product.Status != 1 {
invalidItems = append(invalidItems, struct {
CartID uint `json:"cart_id"`
Reason string `json:"reason"`
}{
CartID: item.ID,
Reason: "商品已下架",
})
continue
}
// 检查库存
if product.Stock < item.Quantity {
invalidItems = append(invalidItems, struct {
CartID uint `json:"cart_id"`
Reason string `json:"reason"`
}{
CartID: item.ID,
Reason: "库存不足",
})
continue
}
validItems = append(validItems, item.ID)
}
return map[string]interface{}{
"valid_items": validItems,
"invalid_items": invalidItems,
"valid_count": len(validItems),
"invalid_count": len(invalidItems),
}, nil
}
// CleanInvalidCartItems 清理无效的购物车项
func (s *CartService) CleanInvalidCartItems(userID uint) error {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return errors.New("用户不存在")
}
validation, err := s.ValidateCartItems(userID)
if err != nil {
return err
}
invalidItems := validation["invalid_items"].([]struct {
CartID uint `json:"cart_id"`
Reason string `json:"reason"`
})
if len(invalidItems) == 0 {
return nil
}
var cartIDs []uint
for _, item := range invalidItems {
cartIDs = append(cartIDs, item.CartID)
}
return s.orderRepo.BatchRemoveFromCart(userID, cartIDs)
}
// GetCartSummary 获取购物车摘要信息
func (s *CartService) GetCartSummary(userID uint) (map[string]interface{}, error) {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return nil, errors.New("用户不存在")
}
details, err := s.GetCartWithDetails(userID)
if err != nil {
return nil, err
}
return map[string]interface{}{
"total_items": details["valid_count"].(int) + details["invalid_count"].(int),
"valid_items": details["valid_count"],
"invalid_items": details["invalid_count"],
"total_amount": details["total_amount"],
"selected_amount": details["selected_amount"],
"total_quantity": details["total_quantity"],
"selected_quantity": details["selected_quantity"],
}, nil
}
// MergeCart 合并购物车(用于登录后合并游客购物车)
func (s *CartService) MergeCart(userID uint, guestCartItems []struct {
ProductID uint `json:"product_id"`
SKUID uint `json:"sku_id"`
Quantity int `json:"quantity"`
}) error {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return errors.New("用户不存在")
}
// 获取用户现有购物车
existingCart, err := s.orderRepo.GetCart(userID)
if err != nil {
return err
}
// 创建现有购物车的映射
existingMap := make(map[string]*model.Cart)
for i, item := range existingCart {
key := fmt.Sprintf("%d_%d", item.ProductID, item.SKUID)
if item.SKUID == nil {
key = fmt.Sprintf("%d_0", item.ProductID)
}
existingMap[key] = &existingCart[i]
}
// 合并游客购物车项
for _, guestItem := range guestCartItems {
key := fmt.Sprintf("%d_%d", guestItem.ProductID, guestItem.SKUID)
if guestItem.SKUID == 0 {
key = fmt.Sprintf("%d_0", guestItem.ProductID)
}
if existingItem, exists := existingMap[key]; exists {
// 已存在,更新数量
newQuantity := existingItem.Quantity + guestItem.Quantity
err = s.orderRepo.UpdateCartItem(existingItem.ID, newQuantity)
if err != nil {
return err
}
} else {
// 不存在,添加新项
err = s.AddToCart(userID, guestItem.ProductID, guestItem.SKUID, guestItem.Quantity)
if err != nil {
return err
}
}
}
return nil
}
// GetSelectedCartItems 获取选中的购物车项
func (s *CartService) GetSelectedCartItems(userID uint) ([]model.Cart, error) {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return nil, errors.New("用户不存在")
}
return s.orderRepo.GetSelectedCartItems(userID)
}
// CalculateCartDiscount 计算购物车优惠(预留接口)
func (s *CartService) CalculateCartDiscount(userID uint, couponID uint) (map[string]interface{}, error) {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return nil, errors.New("用户不存在")
}
selectedItems, err := s.GetSelectedCartItems(userID)
if err != nil {
return nil, err
}
var originalAmount float64
for _, item := range selectedItems {
product, err := s.productRepo.GetByID(item.ProductID)
if err != nil {
continue
}
originalAmount += float64(product.Price) / 100 * float64(item.Quantity)
}
// 这里可以添加优惠券计算逻辑
discountAmount := 0.0
finalAmount := originalAmount - discountAmount
return map[string]interface{}{
"original_amount": originalAmount,
"discount_amount": discountAmount,
"final_amount": finalAmount,
"coupon_id": couponID,
}, nil
}

View File

@@ -0,0 +1,202 @@
package service
import (
"encoding/json"
"fmt"
"dianshang/internal/model"
"dianshang/internal/repository"
)
type CommentService struct {
commentRepo *repository.CommentRepository
orderRepo *repository.OrderRepository
productRepo *repository.ProductRepository
}
func NewCommentService(commentRepo *repository.CommentRepository, orderRepo *repository.OrderRepository, productRepo *repository.ProductRepository) *CommentService {
return &CommentService{
commentRepo: commentRepo,
orderRepo: orderRepo,
productRepo: productRepo,
}
}
// CreateCommentRequest 创建评论请求
type CreateCommentRequest struct {
OrderItemID uint `json:"order_item_id" binding:"required"`
Rating int `json:"rating" binding:"required,min=1,max=5"`
Content string `json:"content"`
Images []string `json:"images"`
IsAnonymous bool `json:"is_anonymous"`
}
// CreateComment 创建评论
func (s *CommentService) CreateComment(userID uint, req *CreateCommentRequest) (*model.Comment, error) {
// 1. 验证订单项是否存在且属于该用户
orderItem, err := s.orderRepo.GetOrderItemByID(req.OrderItemID)
if err != nil {
return nil, fmt.Errorf("订单项不存在")
}
// 获取订单信息验证用户权限
order, err := s.orderRepo.GetByID(orderItem.OrderID)
if err != nil {
return nil, fmt.Errorf("订单不存在")
}
if order.UserID != userID {
return nil, fmt.Errorf("无权限评论此商品")
}
// 2. 验证订单状态(只有已完成的订单才能评论)
if order.Status != model.OrderStatusCompleted {
return nil, fmt.Errorf("订单未完成,无法评论")
}
// 3. 检查是否已经评论过
if orderItem.IsCommented {
return nil, fmt.Errorf("该商品已经评论过了")
}
// 4. 处理图片数据
var imagesJSON string
if len(req.Images) > 0 {
imagesBytes, _ := json.Marshal(req.Images)
imagesJSON = string(imagesBytes)
}
// 5. 创建评论
comment := &model.Comment{
UserID: userID,
ProductID: orderItem.ProductID,
OrderID: orderItem.OrderID,
OrderItemID: req.OrderItemID,
Rating: req.Rating,
Content: req.Content,
Images: imagesJSON,
IsAnonymous: req.IsAnonymous,
Status: 1,
}
if err := s.commentRepo.Create(comment); err != nil {
return nil, fmt.Errorf("创建评论失败: %v", err)
}
// 6. 更新订单项评论状态
orderItem.IsCommented = true
if err := s.orderRepo.SaveOrderItem(orderItem); err != nil {
return nil, fmt.Errorf("更新订单项状态失败: %v", err)
}
// 7. 更新商品评论统计
if err := s.commentRepo.UpdateProductStats(orderItem.ProductID); err != nil {
return nil, fmt.Errorf("更新商品统计失败: %v", err)
}
return comment, nil
}
// GetProductComments 获取商品评论列表
func (s *CommentService) GetProductComments(productID uint, page, pageSize int, rating int) ([]model.Comment, int64, error) {
offset := (page - 1) * pageSize
return s.commentRepo.GetByProductID(productID, offset, pageSize, rating)
}
// GetUserComments 获取用户评论列表
func (s *CommentService) GetUserComments(userID uint, page, pageSize int) ([]model.Comment, int64, error) {
offset := (page - 1) * pageSize
return s.commentRepo.GetByUserID(userID, offset, pageSize)
}
// GetCommentStats 获取商品评论统计
func (s *CommentService) GetCommentStats(productID uint) (*model.CommentStats, error) {
return s.commentRepo.GetStats(productID)
}
// GetCommentByID 获取评论详情
func (s *CommentService) GetCommentByID(id uint) (*model.Comment, error) {
return s.commentRepo.GetByID(id)
}
// CreateReplyRequest 创建回复请求
type CreateReplyRequest struct {
CommentID uint `json:"comment_id" binding:"required"`
Content string `json:"content" binding:"required"`
}
// CreateReply 创建评论回复
func (s *CommentService) CreateReply(userID uint, req *CreateReplyRequest, isAdmin bool) (*model.CommentReply, error) {
// 验证评论是否存在
comment, err := s.commentRepo.GetByID(req.CommentID)
if err != nil {
return nil, fmt.Errorf("评论不存在")
}
if comment.Status != 1 {
return nil, fmt.Errorf("评论状态异常,无法回复")
}
// 创建回复
reply := &model.CommentReply{
CommentID: req.CommentID,
UserID: userID,
Content: req.Content,
IsAdmin: isAdmin,
Status: 1,
}
if err := s.commentRepo.CreateReply(reply); err != nil {
return nil, fmt.Errorf("创建回复失败: %v", err)
}
return reply, nil
}
// LikeComment 点赞评论
func (s *CommentService) LikeComment(commentID, userID uint) error {
return s.commentRepo.LikeComment(commentID, userID)
}
// UnlikeComment 取消点赞评论
func (s *CommentService) UnlikeComment(commentID, userID uint) error {
return s.commentRepo.UnlikeComment(commentID, userID)
}
// GetCommentList 获取评论列表(管理端)
func (s *CommentService) GetCommentList(page, pageSize int, conditions map[string]interface{}) ([]model.Comment, int64, error) {
offset := (page - 1) * pageSize
return s.commentRepo.GetList(offset, pageSize, conditions)
}
// UpdateCommentStatus 更新评论状态(管理端)
func (s *CommentService) UpdateCommentStatus(id uint, status int) error {
comment, err := s.commentRepo.GetByID(id)
if err != nil {
return fmt.Errorf("评论不存在")
}
comment.Status = status
if err := s.commentRepo.Update(comment); err != nil {
return fmt.Errorf("更新评论状态失败: %v", err)
}
// 如果是隐藏或删除评论,需要更新商品统计
if status != 1 {
if err := s.commentRepo.UpdateProductStats(comment.ProductID); err != nil {
return fmt.Errorf("更新商品统计失败: %v", err)
}
}
return nil
}
// DeleteComment 删除评论(管理端)
func (s *CommentService) DeleteComment(id uint) error {
return s.UpdateCommentStatus(id, 3)
}
// GetUncommentedOrderItems 获取用户未评论的订单项
func (s *CommentService) GetUncommentedOrderItems(userID uint) ([]model.OrderItem, error) {
// 获取用户已完成但未评论的订单项
return s.orderRepo.GetUncommentedOrderItems(userID)
}

View File

@@ -0,0 +1,226 @@
package service
import (
"dianshang/internal/model"
"dianshang/internal/repository"
"errors"
"fmt"
"time"
)
// CouponService 优惠券服务
type CouponService struct {
couponRepo *repository.CouponRepository
}
// NewCouponService 创建优惠券服务
func NewCouponService(couponRepo *repository.CouponRepository) *CouponService {
return &CouponService{
couponRepo: couponRepo,
}
}
// GetAvailableCoupons 获取可用优惠券列表
func (s *CouponService) GetAvailableCoupons() ([]model.Coupon, error) {
return s.couponRepo.GetAvailableCoupons()
}
// GetAvailableCouponsWithUserStatus 获取可用优惠券列表(包含用户已领取状态)
func (s *CouponService) GetAvailableCouponsWithUserStatus(userID uint) ([]map[string]interface{}, error) {
// 获取所有可用优惠券
coupons, err := s.couponRepo.GetAvailableCoupons()
if err != nil {
return nil, err
}
var result []map[string]interface{}
// 如果用户已登录,检查每个优惠券的领取状态
for _, coupon := range coupons {
couponData := map[string]interface{}{
"id": coupon.ID,
"name": coupon.Name,
"type": coupon.Type,
"value": coupon.Value,
"min_amount": coupon.MinAmount,
"description": coupon.Description,
"start_time": coupon.StartTime,
"end_time": coupon.EndTime,
"total_count": coupon.TotalCount,
"used_count": coupon.UsedCount,
"is_received": false, // 默认未领取
}
// 如果用户已登录,检查是否已领取
if userID > 0 {
exists, err := s.couponRepo.CheckUserCouponExists(userID, coupon.ID)
if err == nil {
couponData["is_received"] = exists
}
}
result = append(result, couponData)
}
return result, nil
}
// GetUserCoupons 获取用户优惠券
func (s *CouponService) GetUserCoupons(userID uint, status int) ([]model.UserCoupon, error) {
return s.couponRepo.GetUserCoupons(userID, status)
}
// ReceiveCoupon 领取优惠券
func (s *CouponService) ReceiveCoupon(userID, couponID uint) error {
// 检查优惠券是否存在且有效
coupon, err := s.couponRepo.GetByID(couponID)
if err != nil {
return errors.New("优惠券不存在")
}
// 检查是否在有效期内
now := time.Now()
if now.Before(coupon.StartTime) || now.After(coupon.EndTime) {
return errors.New("优惠券不在有效期内")
}
// 检查是否还有库存
if coupon.TotalCount > 0 && coupon.UsedCount >= coupon.TotalCount {
return errors.New("优惠券已被领完")
}
// 检查用户是否已经领取过
exists, err := s.couponRepo.CheckUserCouponExists(userID, couponID)
if err != nil {
return err
}
if exists {
return errors.New("您已经领取过该优惠券")
}
// 创建用户优惠券记录
userCoupon := &model.UserCoupon{
UserID: userID,
CouponID: couponID,
Status: 0, // 未使用
}
return s.couponRepo.CreateUserCoupon(userCoupon)
}
// UseCoupon 使用优惠券
func (s *CouponService) UseCoupon(userID, userCouponID, orderID uint) error {
// 获取用户优惠券
userCoupon, err := s.couponRepo.GetUserCouponByID(userCouponID)
if err != nil {
return errors.New("优惠券不存在")
}
// 检查是否属于该用户
if userCoupon.UserID != userID {
return errors.New("无权使用该优惠券")
}
// 检查是否已使用
if userCoupon.Status != 0 {
return errors.New("优惠券已使用或已过期")
}
// 检查优惠券是否在有效期内
now := time.Now()
if now.Before(userCoupon.Coupon.StartTime) || now.After(userCoupon.Coupon.EndTime) {
return errors.New("优惠券不在有效期内")
}
// 更新优惠券状态为已使用
return s.couponRepo.UseCoupon(userCouponID, orderID)
}
// ValidateCoupon 验证优惠券是否可用
func (s *CouponService) ValidateCoupon(userID, userCouponID uint, orderAmount float64) (*model.UserCoupon, float64, error) {
// 获取用户优惠券
userCoupon, err := s.couponRepo.GetUserCouponByID(userCouponID)
if err != nil {
return nil, 0, errors.New("优惠券不存在")
}
// 检查是否属于该用户
if userCoupon.UserID != userID {
return nil, 0, errors.New("无权使用该优惠券")
}
// 检查是否已使用
if userCoupon.Status != 0 {
return nil, 0, errors.New("优惠券已使用或已过期")
}
// 检查优惠券是否在有效期内
now := time.Now()
if now.Before(userCoupon.Coupon.StartTime) || now.After(userCoupon.Coupon.EndTime) {
return nil, 0, errors.New("优惠券不在有效期内")
}
// 检查最低消费金额
minAmount := float64(userCoupon.Coupon.MinAmount) / 100 // 分转元
if orderAmount < minAmount {
return nil, 0, errors.New(fmt.Sprintf("订单金额不满足优惠券使用条件,最低需要%.2f元", minAmount))
}
// 计算优惠金额
var discountAmount float64
switch userCoupon.Coupon.Type {
case 1: // 满减券
discountAmount = float64(userCoupon.Coupon.Value) / 100 // 分转元
case 2: // 折扣券
discountRate := float64(userCoupon.Coupon.Value) / 100 // 85 -> 0.85
discountAmount = orderAmount * (1 - discountRate)
case 3: // 免邮券
discountAmount = 0 // 免邮券的优惠金额在运费中体现
default:
return nil, 0, errors.New("不支持的优惠券类型")
}
// 确保优惠金额不超过订单金额
if discountAmount > orderAmount {
discountAmount = orderAmount
}
return userCoupon, discountAmount, nil
}
// GetAvailableCouponsForOrder 获取订单可用的优惠券
func (s *CouponService) GetAvailableCouponsForOrder(userID uint, orderAmount float64) ([]model.UserCoupon, error) {
// 获取用户未使用的优惠券
userCoupons, err := s.couponRepo.GetUserCoupons(userID, 1) // 1表示未使用(API状态值)
if err != nil {
return nil, err
}
var availableCoupons []model.UserCoupon
now := time.Now()
for _, userCoupon := range userCoupons {
// 严格检查优惠券状态:必须是未使用状态(0)且没有关联订单
if userCoupon.Status != 0 || userCoupon.OrderID != nil {
continue
}
// 检查是否在有效期内
if now.Before(userCoupon.Coupon.StartTime) || now.After(userCoupon.Coupon.EndTime) {
continue
}
// 检查优惠券模板是否可用
if userCoupon.Coupon.Status != 1 {
continue
}
// 检查最低消费金额
minAmount := float64(userCoupon.Coupon.MinAmount) / 100 // 分转元
if orderAmount >= minAmount {
availableCoupons = append(availableCoupons, userCoupon)
}
}
return availableCoupons, nil
}

View File

@@ -0,0 +1,405 @@
package service
import (
"dianshang/internal/model"
"time"
"gorm.io/gorm"
)
// LogService 日志服务
type LogService struct {
db *gorm.DB
}
// NewLogService 创建日志服务
func NewLogService(db *gorm.DB) *LogService {
return &LogService{
db: db,
}
}
// CreateLoginLog 创建登录日志
func (s *LogService) CreateLoginLog(userID uint, ip, userAgent string, status int, remark string) error {
log := &model.UserLoginLog{
UserID: userID,
LoginIP: ip,
UserAgent: userAgent,
LoginTime: time.Now(),
Status: status,
Remark: remark,
}
return s.db.Create(log).Error
}
// GetUserLoginLogs 获取用户登录日志
func (s *LogService) GetUserLoginLogs(userID uint, page, pageSize int) ([]model.UserLoginLog, map[string]interface{}, error) {
var logs []model.UserLoginLog
var total int64
query := s.db.Model(&model.UserLoginLog{}).Where("user_id = ?", userID)
// 获取总数
err := query.Count(&total).Error
if err != nil {
return nil, nil, err
}
// 分页查询
offset := (page - 1) * pageSize
err = query.Offset(offset).Limit(pageSize).Order("login_time DESC").Find(&logs).Error
if err != nil {
return nil, nil, err
}
// 构建分页信息
pagination := map[string]interface{}{
"total": total,
"page": page,
"page_size": pageSize,
"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
}
return logs, pagination, nil
}
// GetLoginLogList 获取登录日志列表(管理后台)
func (s *LogService) GetLoginLogList(page, pageSize int, conditions map[string]interface{}) ([]model.UserLoginLog, map[string]interface{}, error) {
var logs []model.UserLoginLog
var total int64
query := s.db.Model(&model.UserLoginLog{}).Preload("User")
// 应用查询条件
if userID, ok := conditions["user_id"]; ok && userID != "" {
query = query.Where("user_id = ?", userID)
}
if ip, ok := conditions["ip"]; ok && ip != "" {
query = query.Where("login_ip LIKE ?", "%"+ip.(string)+"%")
}
if status, ok := conditions["status"]; ok && status != "" {
query = query.Where("status = ?", status)
}
if startDate, ok := conditions["start_date"]; ok && startDate != "" {
query = query.Where("login_time >= ?", startDate)
}
if endDate, ok := conditions["end_date"]; ok && endDate != "" {
query = query.Where("login_time <= ?", endDate)
}
// 获取总数
err := query.Count(&total).Error
if err != nil {
return nil, nil, err
}
// 分页查询
offset := (page - 1) * pageSize
err = query.Offset(offset).Limit(pageSize).Order("login_time DESC").Find(&logs).Error
if err != nil {
return nil, nil, err
}
// 构建分页信息
pagination := map[string]interface{}{
"total": total,
"page": page,
"page_size": pageSize,
"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
}
return logs, pagination, nil
}
// CreateOperationLog 创建操作日志
func (s *LogService) CreateOperationLog(userID uint, module, action, description, ip, userAgent, requestData string) error {
log := &model.UserOperationLog{
UserID: userID,
Module: module,
Action: action,
Description: description,
IP: ip,
UserAgent: userAgent,
RequestData: requestData,
CreatedAt: time.Now(),
}
return s.db.Create(log).Error
}
// GetUserOperationLogs 获取用户操作日志
func (s *LogService) GetUserOperationLogs(userID uint, page, pageSize int) ([]model.UserOperationLog, map[string]interface{}, error) {
var logs []model.UserOperationLog
var total int64
query := s.db.Model(&model.UserOperationLog{}).Where("user_id = ?", userID)
// 获取总数
err := query.Count(&total).Error
if err != nil {
return nil, nil, err
}
// 分页查询
offset := (page - 1) * pageSize
err = query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&logs).Error
if err != nil {
return nil, nil, err
}
// 构建分页信息
pagination := map[string]interface{}{
"total": total,
"page": page,
"page_size": pageSize,
"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
}
return logs, pagination, nil
}
// GetOperationLogList 获取操作日志列表(管理后台)
func (s *LogService) GetOperationLogList(page, pageSize int, conditions map[string]interface{}) ([]model.UserOperationLog, map[string]interface{}, error) {
var logs []model.UserOperationLog
var total int64
query := s.db.Model(&model.UserOperationLog{}).Preload("User")
// 应用查询条件
if userID, ok := conditions["user_id"]; ok && userID != "" {
query = query.Where("user_id = ?", userID)
}
if module, ok := conditions["module"]; ok && module != "" {
query = query.Where("module = ?", module)
}
if action, ok := conditions["action"]; ok && action != "" {
query = query.Where("action = ?", action)
}
if ip, ok := conditions["ip"]; ok && ip != "" {
query = query.Where("ip LIKE ?", "%"+ip.(string)+"%")
}
if startDate, ok := conditions["start_date"]; ok && startDate != "" {
query = query.Where("created_at >= ?", startDate)
}
if endDate, ok := conditions["end_date"]; ok && endDate != "" {
query = query.Where("created_at <= ?", endDate)
}
// 获取总数
err := query.Count(&total).Error
if err != nil {
return nil, nil, err
}
// 分页查询
offset := (page - 1) * pageSize
err = query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&logs).Error
if err != nil {
return nil, nil, err
}
// 构建分页信息
pagination := map[string]interface{}{
"total": total,
"page": page,
"page_size": pageSize,
"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
}
return logs, pagination, nil
}
// GetLoginStatistics 获取登录统计
func (s *LogService) GetLoginStatistics(startDate, endDate string) (map[string]interface{}, error) {
result := make(map[string]interface{})
// 总登录次数
var totalLogins int64
query := s.db.Model(&model.UserLoginLog{})
if startDate != "" && endDate != "" {
query = query.Where("DATE(login_time) BETWEEN ? AND ?", startDate, endDate)
}
query.Count(&totalLogins)
result["total_logins"] = totalLogins
// 成功登录次数
var successLogins int64
query = s.db.Model(&model.UserLoginLog{}).Where("status = 1")
if startDate != "" && endDate != "" {
query = query.Where("DATE(login_time) BETWEEN ? AND ?", startDate, endDate)
}
query.Count(&successLogins)
result["success_logins"] = successLogins
// 失败登录次数
var failedLogins int64
query = s.db.Model(&model.UserLoginLog{}).Where("status = 0")
if startDate != "" && endDate != "" {
query = query.Where("DATE(login_time) BETWEEN ? AND ?", startDate, endDate)
}
query.Count(&failedLogins)
result["failed_logins"] = failedLogins
// 独立用户数
var uniqueUsers int64
query = s.db.Model(&model.UserLoginLog{}).Distinct("user_id")
if startDate != "" && endDate != "" {
query = query.Where("DATE(login_time) BETWEEN ? AND ?", startDate, endDate)
}
query.Count(&uniqueUsers)
result["unique_users"] = uniqueUsers
return result, nil
}
// GetOperationStatistics 获取操作统计
func (s *LogService) GetOperationStatistics(startDate, endDate string) (map[string]interface{}, error) {
result := make(map[string]interface{})
// 总操作次数
var totalOperations int64
query := s.db.Model(&model.UserOperationLog{})
if startDate != "" && endDate != "" {
query = query.Where("DATE(created_at) BETWEEN ? AND ?", startDate, endDate)
}
query.Count(&totalOperations)
result["total_operations"] = totalOperations
// 按模块统计
var moduleStats []struct {
Module string `json:"module"`
Count int64 `json:"count"`
}
query = s.db.Model(&model.UserOperationLog{}).Select("module, COUNT(*) as count").Group("module")
if startDate != "" && endDate != "" {
query = query.Where("DATE(created_at) BETWEEN ? AND ?", startDate, endDate)
}
query.Scan(&moduleStats)
result["module_stats"] = moduleStats
// 按操作统计
var actionStats []struct {
Action string `json:"action"`
Count int64 `json:"count"`
}
query = s.db.Model(&model.UserOperationLog{}).Select("action, COUNT(*) as count").Group("action")
if startDate != "" && endDate != "" {
query = query.Where("DATE(created_at) BETWEEN ? AND ?", startDate, endDate)
}
query.Scan(&actionStats)
result["action_stats"] = actionStats
return result, nil
}
// CleanOldLogs 清理旧日志
func (s *LogService) CleanOldLogs(days int) error {
cutoffDate := time.Now().AddDate(0, 0, -days)
// 清理登录日志
if err := s.db.Where("login_time < ?", cutoffDate).Delete(&model.UserLoginLog{}).Error; err != nil {
return err
}
// 清理操作日志
if err := s.db.Where("created_at < ?", cutoffDate).Delete(&model.UserOperationLog{}).Error; err != nil {
return err
}
return nil
}
// GetLoginLogByID 根据ID获取登录日志
func (s *LogService) GetLoginLogByID(id uint) (*model.UserLoginLog, error) {
var log model.UserLoginLog
err := s.db.Preload("User").First(&log, id).Error
if err != nil {
return nil, err
}
return &log, nil
}
// GetTodayLoginCount 获取今日登录次数
func (s *LogService) GetTodayLoginCount() (int64, error) {
var count int64
today := time.Now().Format("2006-01-02")
err := s.db.Model(&model.UserLoginLog{}).Where("DATE(login_time) = ?", today).Count(&count).Error
return count, err
}
// GetTodayOperationCount 获取今日操作次数
func (s *LogService) GetTodayOperationCount() (int64, error) {
var count int64
today := time.Now().Format("2006-01-02")
err := s.db.Model(&model.UserOperationLog{}).Where("DATE(created_at) = ?", today).Count(&count).Error
return count, err
}
// GetOnlineUserCount 获取在线用户数最近30分钟有登录记录的用户
func (s *LogService) GetOnlineUserCount() (int64, error) {
var count int64
thirtyMinutesAgo := time.Now().Add(-30 * time.Minute)
err := s.db.Model(&model.UserLoginLog{}).
Where("login_time >= ? AND status = 1", thirtyMinutesAgo).
Distinct("user_id").
Count(&count).Error
return count, err
}
// GetLoginTrend 获取登录趋势
func (s *LogService) GetLoginTrend(days int) ([]map[string]interface{}, error) {
var results []map[string]interface{}
for i := days - 1; i >= 0; i-- {
date := time.Now().AddDate(0, 0, -i).Format("2006-01-02")
var count int64
s.db.Model(&model.UserLoginLog{}).
Where("DATE(login_time) = ?", date).
Count(&count)
results = append(results, map[string]interface{}{
"date": date,
"count": count,
})
}
return results, nil
}
// GetOperationTrend 获取操作趋势
func (s *LogService) GetOperationTrend(days int) ([]map[string]interface{}, error) {
var results []map[string]interface{}
for i := days - 1; i >= 0; i-- {
date := time.Now().AddDate(0, 0, -i).Format("2006-01-02")
var count int64
s.db.Model(&model.UserOperationLog{}).
Where("DATE(created_at) = ?", date).
Count(&count)
results = append(results, map[string]interface{}{
"date": date,
"count": count,
})
}
return results, nil
}
// GetOperationLogByID 根据ID获取操作日志
func (s *LogService) GetOperationLogByID(id uint) (*model.UserOperationLog, error) {
var log model.UserOperationLog
err := s.db.Preload("User").First(&log, id).Error
if err != nil {
return nil, err
}
return &log, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,350 @@
package service
import (
"dianshang/internal/model"
"dianshang/internal/repository"
"errors"
"fmt"
"time"
"gorm.io/gorm"
)
// PointsService 积分服务
type PointsService struct {
pointsRepo *repository.PointsRepository
db *gorm.DB
}
// NewPointsService 创建积分服务
func NewPointsService(pointsRepo *repository.PointsRepository, db *gorm.DB) *PointsService {
return &PointsService{
pointsRepo: pointsRepo,
db: db,
}
}
// GetUserPoints 获取用户积分
func (s *PointsService) GetUserPoints(userID uint) (int, error) {
return s.pointsRepo.GetUserPoints(userID)
}
// AddPoints 增加用户积分
func (s *PointsService) AddPoints(userID uint, amount int, description string, orderID *uint, orderNo, productName string) error {
if amount <= 0 {
return errors.New("积分数量必须大于0")
}
// 开启事务
tx := s.db.Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// 获取当前积分
currentPoints, err := s.pointsRepo.GetUserPoints(userID)
if err != nil {
tx.Rollback()
return err
}
// 更新用户积分
newPoints := currentPoints + amount
err = s.pointsRepo.UpdateUserPoints(userID, newPoints)
if err != nil {
tx.Rollback()
return err
}
// 创建积分历史记录
history := &model.PointsHistory{
UserID: userID,
Type: 1, // 获得
Points: amount,
Description: description,
OrderID: orderID,
OrderNo: orderNo,
ProductName: productName,
}
err = s.pointsRepo.CreatePointsHistory(history)
if err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
// DeductPoints 扣减用户积分
func (s *PointsService) DeductPoints(userID uint, amount int, description string, orderID *uint, orderNo, productName string) error {
if amount <= 0 {
return errors.New("积分数量必须大于0")
}
// 开启事务
tx := s.db.Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// 获取当前积分
currentPoints, err := s.pointsRepo.GetUserPoints(userID)
if err != nil {
tx.Rollback()
return err
}
// 检查积分是否足够
if currentPoints < amount {
tx.Rollback()
return errors.New("积分不足")
}
// 更新用户积分
newPoints := currentPoints - amount
err = s.pointsRepo.UpdateUserPoints(userID, newPoints)
if err != nil {
tx.Rollback()
return err
}
// 创建积分历史记录
history := &model.PointsHistory{
UserID: userID,
Type: 2, // 消费
Points: amount,
Description: description,
OrderID: orderID,
OrderNo: orderNo,
ProductName: productName,
}
err = s.pointsRepo.CreatePointsHistory(history)
if err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
// GetPointsHistory 获取积分历史记录
func (s *PointsService) GetPointsHistory(userID uint, page, pageSize int) ([]model.PointsHistory, map[string]interface{}, error) {
if page <= 0 {
page = 1
}
if pageSize <= 0 || pageSize > 100 {
pageSize = 20
}
histories, total, err := s.pointsRepo.GetPointsHistory(userID, page, pageSize)
if err != nil {
return nil, nil, err
}
// 构建分页信息
totalPages := (int(total) + pageSize - 1) / pageSize
pagination := map[string]interface{}{
"total": total,
"page": page,
"page_size": pageSize,
"total_pages": totalPages,
}
return histories, pagination, nil
}
// GetPointsRules 获取积分规则列表
func (s *PointsService) GetPointsRules() ([]model.PointsRule, error) {
return s.pointsRepo.GetPointsRules()
}
// GetPointsExchangeList 获取积分兑换商品列表
func (s *PointsService) GetPointsExchangeList() ([]model.PointsExchange, error) {
return s.pointsRepo.GetPointsExchangeList()
}
// ExchangePoints 积分兑换
func (s *PointsService) ExchangePoints(userID, exchangeID uint) error {
// 开启事务
tx := s.db.Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// 获取兑换商品信息
exchange, err := s.pointsRepo.GetPointsExchangeByID(exchangeID)
if err != nil {
tx.Rollback()
return errors.New("兑换商品不存在")
}
// 检查库存
if exchange.Stock > 0 && exchange.ExchangeCount >= exchange.Stock {
tx.Rollback()
return errors.New("商品库存不足")
}
// 获取用户当前积分
currentPoints, err := s.pointsRepo.GetUserPoints(userID)
if err != nil {
tx.Rollback()
return err
}
// 检查积分是否足够
if currentPoints < exchange.Points {
tx.Rollback()
return errors.New("积分不足")
}
// 扣减积分
description := fmt.Sprintf("兑换商品:%s", exchange.Name)
err = s.DeductPoints(userID, exchange.Points, description, nil, "", exchange.Name)
if err != nil {
tx.Rollback()
return err
}
// 创建兑换记录
record := &model.PointsExchangeRecord{
UserID: userID,
PointsExchangeID: exchangeID,
Points: exchange.Points,
Status: 1, // 已兑换
}
err = s.pointsRepo.CreatePointsExchangeRecord(record)
if err != nil {
tx.Rollback()
return err
}
// 更新兑换次数
err = s.pointsRepo.UpdatePointsExchangeCount(exchangeID)
if err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
// GetUserExchangeRecords 获取用户兑换记录
func (s *PointsService) GetUserExchangeRecords(userID uint, page, pageSize int) ([]model.PointsExchangeRecord, map[string]interface{}, error) {
if page <= 0 {
page = 1
}
if pageSize <= 0 || pageSize > 100 {
pageSize = 20
}
records, total, err := s.pointsRepo.GetUserExchangeRecords(userID, page, pageSize)
if err != nil {
return nil, nil, err
}
// 构建分页信息
totalPages := (int(total) + pageSize - 1) / pageSize
pagination := map[string]interface{}{
"total": total,
"page": page,
"page_size": pageSize,
"total_pages": totalPages,
}
return records, pagination, nil
}
// GetPointsOverview 获取积分概览
func (s *PointsService) GetPointsOverview(userID uint) (map[string]interface{}, error) {
// 获取用户当前积分
currentPoints, err := s.pointsRepo.GetUserPoints(userID)
if err != nil {
return nil, err
}
// 获取积分历史统计
var totalEarned, totalSpent, thisMonthEarned, thisMonthSpent int64
// 统计总获得积分
s.db.Model(&model.PointsHistory{}).
Where("user_id = ? AND type = ?", userID, 1).
Select("COALESCE(SUM(points), 0)").
Scan(&totalEarned)
// 统计总消费积分
s.db.Model(&model.PointsHistory{}).
Where("user_id = ? AND type = ?", userID, 2).
Select("COALESCE(SUM(points), 0)").
Scan(&totalSpent)
// 获取本月的开始时间和结束时间
now := time.Now()
firstDayOfMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
firstDayOfNextMonth := firstDayOfMonth.AddDate(0, 1, 0)
// 统计本月获得积分
s.db.Model(&model.PointsHistory{}).
Where("user_id = ? AND type = ? AND created_at >= ? AND created_at < ?",
userID, 1, firstDayOfMonth, firstDayOfNextMonth).
Select("COALESCE(SUM(points), 0)").
Scan(&thisMonthEarned)
// 统计本月消费积分
s.db.Model(&model.PointsHistory{}).
Where("user_id = ? AND type = ? AND created_at >= ? AND created_at < ?",
userID, 2, firstDayOfMonth, firstDayOfNextMonth).
Select("COALESCE(SUM(points), 0)").
Scan(&thisMonthSpent)
overview := map[string]interface{}{
"total_points": currentPoints,
"available_points": currentPoints,
"frozen_points": 0,
"total_earned": totalEarned,
"total_spent": totalSpent,
"this_month_earned": thisMonthEarned,
"this_month_spent": thisMonthSpent,
}
return overview, nil
}
// CheckAndGiveDailyLoginPoints 检查并给予每日首次登录积分
func (s *PointsService) CheckAndGiveDailyLoginPoints(userID uint) (bool, error) {
// 获取今天的开始时间00:00:00
now := time.Now()
today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
tomorrow := today.Add(24 * time.Hour)
// 检查今天是否已经有登录积分记录
var count int64
err := s.db.Model(&model.PointsHistory{}).
Where("user_id = ? AND description = ? AND created_at >= ? AND created_at < ?",
userID, "每日首次登录", today, tomorrow).
Count(&count).Error
if err != nil {
return false, fmt.Errorf("检查每日登录记录失败: %v", err)
}
// 如果今天已经有登录积分记录,则不再给予
if count > 0 {
return false, nil
}
// 给予每日登录积分1积分
err = s.AddPoints(userID, 1, "每日首次登录", nil, "", "")
if err != nil {
return false, fmt.Errorf("给予每日登录积分失败: %v", err)
}
return true, nil
}

View File

@@ -0,0 +1,869 @@
package service
import (
"dianshang/internal/model"
"dianshang/internal/repository"
"dianshang/pkg/utils"
"errors"
"fmt"
"strconv"
"time"
)
// ProductService 产品服务
type ProductService struct {
productRepo *repository.ProductRepository
userRepo *repository.UserRepository
}
// NewProductService 创建产品服务
func NewProductService(productRepo *repository.ProductRepository, userRepo *repository.UserRepository) *ProductService {
return &ProductService{
productRepo: productRepo,
userRepo: userRepo,
}
}
// GetProductList 获取产品列表(前端用户)
func (s *ProductService) GetProductList(page, pageSize int, categoryID uint, keyword string, minPrice, maxPrice float64, sort, sortType string) ([]model.Product, *utils.Pagination, error) {
if page <= 0 {
page = 1
}
if pageSize <= 0 || pageSize > 100 {
pageSize = 20
}
offset := (page - 1) * pageSize
conditions := make(map[string]interface{})
if categoryID > 0 {
conditions["category_id"] = categoryID
}
if keyword != "" {
conditions["keyword"] = keyword
}
if minPrice > 0 {
conditions["min_price"] = minPrice
}
if maxPrice > 0 {
conditions["max_price"] = maxPrice
}
if sort != "" {
conditions["sort"] = sort
}
if sortType != "" {
conditions["sort_type"] = sortType
}
products, total, err := s.productRepo.GetList(offset, pageSize, conditions)
if err != nil {
return nil, nil, err
}
pagination := utils.NewPagination(page, pageSize)
pagination.Total = total
return products, pagination, nil
}
// GetProductListForAdmin 获取产品列表(管理系统)
func (s *ProductService) GetProductListForAdmin(page, pageSize int, categoryID uint, keyword string, minPrice, maxPrice float64, sort, sortType, status, isHot, isNew, isRecommend string) ([]model.Product, *utils.Pagination, error) {
if page <= 0 {
page = 1
}
if pageSize <= 0 || pageSize > 100 {
pageSize = 20
}
offset := (page - 1) * pageSize
conditions := make(map[string]interface{})
if categoryID > 0 {
conditions["category_id"] = categoryID
}
if keyword != "" {
conditions["keyword"] = keyword
}
if minPrice > 0 {
conditions["min_price"] = minPrice
}
if maxPrice > 0 {
conditions["max_price"] = maxPrice
}
if sort != "" {
conditions["sort"] = sort
}
if sortType != "" {
conditions["sort_type"] = sortType
}
// 添加状态条件,支持获取所有状态的商品
if status != "" {
conditions["status"] = status
}
// 添加热门、新品、推荐筛选条件
if isHot != "" {
conditions["is_hot"] = isHot
}
if isNew != "" {
conditions["is_new"] = isNew
}
if isRecommend != "" {
conditions["is_recommend"] = isRecommend
}
products, total, err := s.productRepo.GetList(offset, pageSize, conditions)
if err != nil {
return nil, nil, err
}
pagination := utils.NewPagination(page, pageSize)
pagination.Total = total
return products, pagination, nil
}
// GetProductDetail 获取产品详情
func (s *ProductService) GetProductDetail(id uint) (*model.Product, error) {
return s.productRepo.GetByID(id)
}
// CreateProduct 创建产品
func (s *ProductService) CreateProduct(product *model.Product) error {
// 验证分类是否存在
if product.CategoryID > 0 {
_, err := s.productRepo.GetCategoryByID(product.CategoryID)
if err != nil {
return errors.New("分类不存在")
}
}
return s.productRepo.Create(product)
}
// UpdateProduct 更新产品
func (s *ProductService) UpdateProduct(id uint, updates map[string]interface{}) error {
// 检查产品是否存在
_, err := s.productRepo.GetByID(id)
if err != nil {
return errors.New("产品不存在")
}
// 如果更新分类,验证分类是否存在
if categoryID, ok := updates["category_id"]; ok {
var catID uint
switch v := categoryID.(type) {
case uint:
catID = v
case float64:
catID = uint(v)
case int:
catID = uint(v)
default:
return errors.New("分类ID格式错误")
}
if catID > 0 {
_, err := s.productRepo.GetCategoryByID(catID)
if err != nil {
return errors.New("分类不存在")
}
}
}
// 处理 detail_images 字段 - 确保正确转换为 JSONSlice 类型
if detailImages, ok := updates["detail_images"]; ok {
switch v := detailImages.(type) {
case []interface{}:
// 将 []interface{} 转换为 []string
var stringSlice []string
for _, item := range v {
if str, ok := item.(string); ok {
stringSlice = append(stringSlice, str)
}
}
updates["detail_images"] = model.JSONSlice(stringSlice)
case []string:
updates["detail_images"] = model.JSONSlice(v)
}
}
// 处理 images 字段 - 确保正确转换为 JSONSlice 类型
if images, ok := updates["images"]; ok {
switch v := images.(type) {
case []interface{}:
// 将 []interface{} 转换为 []string
var stringSlice []string
for _, item := range v {
if str, ok := item.(string); ok {
stringSlice = append(stringSlice, str)
}
}
updates["images"] = model.JSONSlice(stringSlice)
case []string:
updates["images"] = model.JSONSlice(v)
}
}
// 处理SKU数据
var skusData []interface{}
if skus, ok := updates["skus"]; ok {
skusData, _ = skus.([]interface{})
// 从updates中移除skus避免直接更新到Product表
delete(updates, "skus")
}
// 更新商品基本信息
if err := s.productRepo.Update(id, updates); err != nil {
return err
}
// 处理SKU数据
if len(skusData) > 0 {
if err := s.handleProductSKUs(id, skusData); err != nil {
return err
}
}
return nil
}
// handleProductSKUs 处理商品SKU数据
func (s *ProductService) handleProductSKUs(productID uint, skusData []interface{}) error {
// 获取当前商品的所有现有SKU
existingSKUs, err := s.productRepo.GetProductSKUs(productID)
if err != nil {
return err
}
// 收集前端发送的SKU ID列表
var submittedSKUIDs []uint
// 处理前端发送的SKU数据
for _, skuData := range skusData {
skuMap, ok := skuData.(map[string]interface{})
if !ok {
continue
}
// 创建SKU对象
sku := &model.ProductSKU{
ProductID: productID,
}
// 处理SKU字段
if skuCode, ok := skuMap["sku_code"].(string); ok {
sku.SKUCode = skuCode
}
if price, ok := skuMap["price"].(float64); ok {
sku.Price = price
}
if stock, ok := skuMap["stock"]; ok {
switch v := stock.(type) {
case float64:
sku.Stock = int(v)
case int:
sku.Stock = v
case string:
if stockInt, err := strconv.Atoi(v); err == nil {
sku.Stock = stockInt
}
}
}
// 处理spec_values
if specValues, ok := skuMap["spec_values"]; ok {
if specMap, ok := specValues.(map[string]interface{}); ok {
sku.SpecValues = model.JSONMap(specMap)
}
}
// 处理image字段
if image, ok := skuMap["image"].(string); ok {
sku.Image = image
}
// 检查是否是更新还是创建
var isUpdate bool
var skuIDValue uint
if skuID, ok := skuMap["id"]; ok && skuID != nil {
switch v := skuID.(type) {
case float64:
if v > 0 {
isUpdate = true
skuIDValue = uint(v)
submittedSKUIDs = append(submittedSKUIDs, skuIDValue)
}
case int:
if v > 0 {
isUpdate = true
skuIDValue = uint(v)
submittedSKUIDs = append(submittedSKUIDs, skuIDValue)
}
}
}
if isUpdate {
// 更新现有SKU
updates := make(map[string]interface{})
if sku.SKUCode != "" {
updates["sku_code"] = sku.SKUCode
}
updates["price"] = sku.Price
updates["stock"] = sku.Stock
// 直接传递JSONMap类型让GORM处理序列化
updates["spec_values"] = sku.SpecValues
// 添加image字段的更新
if sku.Image != "" {
updates["image"] = sku.Image
}
if err := s.productRepo.UpdateSKU(skuIDValue, updates); err != nil {
return err
}
} else {
// 创建新SKU - 确保不设置ID字段
sku.ID = 0 // 明确设置为0让数据库自动生成
if sku.SKUCode == "" {
// 生成默认SKU代码
sku.SKUCode = fmt.Sprintf("SKU-%d-%d", productID, time.Now().Unix())
}
if err := s.productRepo.CreateSKU(sku); err != nil {
return err
}
}
}
// 删除不在前端提交列表中的现有SKU
for _, existingSKU := range existingSKUs {
shouldDelete := true
for _, submittedID := range submittedSKUIDs {
if existingSKU.ID == submittedID {
shouldDelete = false
break
}
}
if shouldDelete {
if err := s.productRepo.DeleteSKU(existingSKU.ID); err != nil {
return fmt.Errorf("删除SKU失败 - SKU ID: %d, 错误: %v", existingSKU.ID, err)
}
}
}
// 处理完所有SKU后同步商品库存
if err := s.productRepo.SyncProductStockFromSKUs(productID); err != nil {
// 记录错误但不阻止操作
fmt.Printf("同步商品库存失败 - 商品ID: %d, 错误: %v\n", productID, err)
}
return nil
}
// DeleteProduct 删除产品
func (s *ProductService) DeleteProduct(id uint) error {
// 检查产品是否存在
_, err := s.productRepo.GetByID(id)
if err != nil {
return errors.New("产品不存在")
}
return s.productRepo.Delete(id)
}
// GetCategories 获取分类列表
func (s *ProductService) GetCategories() ([]model.Category, error) {
return s.productRepo.GetCategories()
}
// CreateCategory 创建分类
func (s *ProductService) CreateCategory(category *model.Category) error {
return s.productRepo.CreateCategory(category)
}
// UpdateCategory 更新分类
func (s *ProductService) UpdateCategory(id uint, updates map[string]interface{}) error {
// 检查分类是否存在
_, err := s.productRepo.GetCategoryByID(id)
if err != nil {
return errors.New("分类不存在")
}
return s.productRepo.UpdateCategory(id, updates)
}
// DeleteCategory 删除分类
func (s *ProductService) DeleteCategory(id uint) error {
// 检查分类是否存在
_, err := s.productRepo.GetCategoryByID(id)
if err != nil {
return errors.New("分类不存在")
}
// 检查分类下是否有商品
productCount, err := s.productRepo.CountProductsByCategory(id)
if err != nil {
return errors.New("检查分类商品数量失败")
}
if productCount > 0 {
return errors.New("该分类下还有商品,无法删除")
}
// 检查是否有子分类
var childCategories []model.Category
err = s.productRepo.GetDB().Where("parent_id = ?", id).Find(&childCategories).Error
if err != nil {
return errors.New("检查子分类失败")
}
if len(childCategories) > 0 {
return errors.New("该分类下还有子分类,请先删除子分类")
}
return s.productRepo.DeleteCategory(id)
}
// GetProductReviews 获取产品评价列表
func (s *ProductService) GetProductReviews(productID uint, page, pageSize int) ([]model.ProductReview, *utils.Pagination, error) {
if page <= 0 {
page = 1
}
if pageSize <= 0 || pageSize > 100 {
pageSize = 20
}
offset := (page - 1) * pageSize
reviews, total, err := s.productRepo.GetReviews(productID, offset, pageSize)
if err != nil {
return nil, nil, err
}
pagination := utils.NewPagination(page, pageSize)
pagination.Total = total
return reviews, pagination, nil
}
// CreateReview 创建评价
func (s *ProductService) CreateReview(userID uint, review *model.ProductReview) error {
// 检查用户是否存在
_, err := s.userRepo.GetByID(userID)
if err != nil {
return errors.New("用户不存在")
}
// 检查产品是否存在
_, err = s.productRepo.GetByID(review.ProductID)
if err != nil {
return errors.New("产品不存在")
}
// 检查是否已经评价过
if review.OrderID != nil {
existingReview, _ := s.productRepo.GetReviewByOrderID(userID, *review.OrderID)
if existingReview != nil {
return errors.New("已经评价过该商品")
}
}
review.UserID = userID
return s.productRepo.CreateReview(review)
}
// GetHotProducts 获取热门产品
func (s *ProductService) GetHotProducts(limit int) ([]model.Product, error) {
if limit <= 0 || limit > 50 {
limit = 10
}
return s.productRepo.GetHotProducts(limit)
}
// GetRecommendProducts 获取推荐产品
func (s *ProductService) GetRecommendProducts(limit int) ([]model.Product, error) {
if limit <= 0 || limit > 50 {
limit = 10
}
return s.productRepo.GetRecommendProducts(limit)
}
// SearchProducts 搜索产品(支持价格与排序)
func (s *ProductService) SearchProducts(keyword string, page, pageSize int, minPrice, maxPrice float64, sort, sortType string) ([]model.Product, *utils.Pagination, error) {
if keyword == "" {
return []model.Product{}, utils.NewPagination(page, pageSize), nil
}
return s.GetProductList(page, pageSize, 0, keyword, minPrice, maxPrice, sort, sortType)
}
// UpdateStock 更新库存
func (s *ProductService) UpdateStock(id uint, quantity int) error {
// 检查产品是否存在
product, err := s.productRepo.GetByID(id)
if err != nil {
return errors.New("产品不存在")
}
// 检查库存是否足够(减库存时)
if quantity < 0 && product.Stock < -quantity {
return errors.New("库存不足")
}
return s.productRepo.UpdateStock(id, quantity)
}
// GetProductSKUs 获取产品SKU列表
func (s *ProductService) GetProductSKUs(productID uint) ([]model.ProductSKU, error) {
return s.productRepo.GetProductSKUs(productID)
}
// GetSKUByID 根据SKU ID获取SKU详情
func (s *ProductService) GetSKUByID(skuID uint) (*model.ProductSKU, error) {
return s.productRepo.GetSKUByID(skuID)
}
// GetProductTags 获取产品标签列表
func (s *ProductService) GetProductTags() ([]model.ProductTag, error) {
return s.productRepo.GetProductTags()
}
// GetStores 获取店铺列表
func (s *ProductService) GetStores() ([]model.Store, error) {
return s.productRepo.GetStores()
}
// GetStoreByID 根据ID获取店铺信息
func (s *ProductService) GetStoreByID(id uint) (*model.Store, error) {
return s.productRepo.GetStoreByID(id)
}
// GetProductReviewCount 获取产品评价统计
func (s *ProductService) GetProductReviewCount(productID uint) (map[string]interface{}, error) {
// 检查产品是否存在
_, err := s.productRepo.GetByID(productID)
if err != nil {
return nil, errors.New("产品不存在")
}
return s.productRepo.GetReviewCount(productID)
}
// GetProductStatistics 获取产品统计
func (s *ProductService) GetProductStatistics() (map[string]interface{}, error) {
// 使用ProductRepository的GetProductStatistics方法
return s.productRepo.GetProductStatistics()
}
// GetProductSalesRanking 获取产品销售排行
func (s *ProductService) GetProductSalesRanking(startDate, endDate, limit string) ([]map[string]interface{}, error) {
// 简化实现,返回基础排行数据
var results []map[string]interface{}
// 这里应该根据订单数据统计产品销量,暂时返回模拟数据
products, _, err := s.productRepo.GetList(0, 10, map[string]interface{}{"status": 1})
if err != nil {
return nil, err
}
for i, product := range products {
if i >= 10 { // 限制返回数量
break
}
results = append(results, map[string]interface{}{
"product_id": product.ID,
"product_name": product.Name,
"sales_count": 100 - i*5, // 模拟销量数据
"sales_amount": float64(1000 - i*50),
})
}
return results, nil
}
// GetCategorySalesRanking 获取分类销售排行
func (s *ProductService) GetCategorySalesRanking(startDate, endDate, limit string) ([]map[string]interface{}, error) {
// 解析limit参数
limitInt := 10 // 默认值
if limit != "" {
if parsedLimit, err := strconv.Atoi(limit); err == nil && parsedLimit > 0 {
limitInt = parsedLimit
}
}
// 如果没有提供日期范围使用最近30天
if startDate == "" || endDate == "" {
now := time.Now()
endDate = now.Format("2006-01-02")
startDate = now.AddDate(0, 0, -30).Format("2006-01-02")
}
// 使用真实的数据库查询
return s.productRepo.GetCategorySalesStatistics(startDate, endDate, limitInt)
}
// BatchUpdateProductStatus 批量更新商品状态
func (s *ProductService) BatchUpdateProductStatus(ids []uint, status int) error {
if len(ids) == 0 {
return errors.New("商品ID列表不能为空")
}
return s.productRepo.BatchUpdateStatus(ids, status)
}
// BatchUpdateProductPrice 批量更新商品价格
func (s *ProductService) BatchUpdateProductPrice(updates []map[string]interface{}) error {
if len(updates) == 0 {
return errors.New("更新数据不能为空")
}
return s.productRepo.BatchUpdatePrice(updates)
}
// BatchDeleteProducts 批量删除商品
func (s *ProductService) BatchDeleteProducts(ids []uint) error {
if len(ids) == 0 {
return errors.New("商品ID列表不能为空")
}
return s.productRepo.BatchDelete(ids)
}
// CreateProductSKU 创建商品SKU
func (s *ProductService) CreateProductSKU(sku *model.ProductSKU) error {
// 验证商品是否存在
_, err := s.productRepo.GetByID(sku.ProductID)
if err != nil {
return errors.New("商品不存在")
}
return s.productRepo.CreateSKU(sku)
}
// UpdateProductSKU 更新商品SKU
func (s *ProductService) UpdateProductSKU(id uint, updates map[string]interface{}) error {
// 检查SKU是否存在
_, err := s.productRepo.GetSKUByID(id)
if err != nil {
return errors.New("SKU不存在")
}
return s.productRepo.UpdateSKU(id, updates)
}
// DeleteProductSKU 删除商品SKU
func (s *ProductService) DeleteProductSKU(id uint) error {
// 检查SKU是否存在包括已软删除的
var sku model.ProductSKU
err := s.productRepo.GetDB().Where("id = ?", id).First(&sku).Error
if err != nil {
return errors.New("SKU不存在")
}
// 如果SKU已经被软删除直接返回成功
if sku.Status == 0 {
fmt.Printf("SKU ID %d 已经被软删除,无需重复操作\n", id)
return nil
}
// 检查SKU是否被订单引用
var count int64
err = s.productRepo.GetDB().Table("order_items").Where("sk_uid = ?", id).Count(&count).Error
if err != nil {
return fmt.Errorf("检查SKU引用关系失败: %v", err)
}
if count > 0 {
// 如果被订单引用,执行软删除
err = s.productRepo.DeleteSKU(id)
if err != nil {
return fmt.Errorf("删除SKU失败: %v", err)
}
// 软删除成功,记录日志但不返回错误
fmt.Printf("SKU ID %d 已被 %d 个订单引用,已执行软删除(设置为不可用状态)\n", id, count)
return nil
}
// 如果没有被引用,执行硬删除
err = s.productRepo.DeleteSKU(id)
if err != nil {
return fmt.Errorf("删除SKU失败: %v", err)
}
// 同步更新商品库存
if err := s.productRepo.SyncProductStockFromSKUs(sku.ProductID); err != nil {
// 记录错误但不阻止删除操作
fmt.Printf("同步商品库存失败 - 商品ID: %d, 错误: %v\n", sku.ProductID, err)
}
return nil
}
// GetProductImages 获取商品图片列表
func (s *ProductService) GetProductImages(productID uint) ([]model.ProductImage, error) {
return s.productRepo.GetProductImages(productID)
}
// CreateProductImage 创建商品图片
func (s *ProductService) CreateProductImage(image *model.ProductImage) error {
// 验证商品是否存在
_, err := s.productRepo.GetByID(image.ProductID)
if err != nil {
return errors.New("商品不存在")
}
return s.productRepo.CreateProductImage(image)
}
// UpdateProductImageSort 更新商品图片排序
func (s *ProductService) UpdateProductImageSort(id uint, sort int) error {
return s.productRepo.UpdateProductImageSort(id, sort)
}
// DeleteProductImage 删除商品图片
func (s *ProductService) DeleteProductImage(id uint) error {
return s.productRepo.DeleteProductImage(id)
}
// CreateProductSpec 创建商品规格
func (s *ProductService) CreateProductSpec(spec *model.ProductSpec) error {
// 验证商品是否存在
_, err := s.productRepo.GetByID(spec.ProductID)
if err != nil {
return errors.New("商品不存在")
}
return s.productRepo.CreateProductSpec(spec)
}
// UpdateProductSpec 更新商品规格
func (s *ProductService) UpdateProductSpec(id uint, updates map[string]interface{}) error {
return s.productRepo.UpdateProductSpec(id, updates)
}
// DeleteProductSpec 删除商品规格
func (s *ProductService) DeleteProductSpec(id uint) error {
return s.productRepo.DeleteProductSpec(id)
}
// GetProductSpecs 获取商品规格列表
func (s *ProductService) GetProductSpecs(productID uint) ([]model.ProductSpec, error) {
return s.productRepo.GetProductSpecs(productID)
}
// CreateProductTag 创建商品标签
func (s *ProductService) CreateProductTag(tag *model.ProductTag) error {
return s.productRepo.CreateProductTag(tag)
}
// UpdateProductTag 更新商品标签
func (s *ProductService) UpdateProductTag(id uint, updates map[string]interface{}) error {
return s.productRepo.UpdateProductTag(id, updates)
}
// DeleteProductTag 删除商品标签
func (s *ProductService) DeleteProductTag(id uint) error {
return s.productRepo.DeleteProductTag(id)
}
// AssignTagsToProduct 为商品分配标签
func (s *ProductService) AssignTagsToProduct(productID uint, tagIDs []uint) error {
// 验证商品是否存在
_, err := s.productRepo.GetByID(productID)
if err != nil {
return errors.New("商品不存在")
}
return s.productRepo.AssignTagsToProduct(productID, tagIDs)
}
// GetLowStockProducts 获取低库存商品
func (s *ProductService) GetLowStockProducts(threshold int) ([]model.Product, error) {
if threshold <= 0 {
threshold = 10 // 默认阈值
}
return s.productRepo.GetLowStockProducts(threshold)
}
// GetInventoryStatistics 获取库存统计
func (s *ProductService) GetInventoryStatistics() (map[string]interface{}, error) {
return s.productRepo.GetInventoryStatistics()
}
// ExportProducts 导出商品数据
func (s *ProductService) ExportProducts(conditions map[string]interface{}) ([]model.Product, error) {
return s.productRepo.GetProductsForExport(conditions)
}
// ImportProducts 导入商品数据
func (s *ProductService) ImportProducts(products []model.Product) (map[string]interface{}, error) {
successCount := 0
failCount := 0
var errors []string
for _, product := range products {
// 验证商品数据
if product.Name == "" {
errors = append(errors, "商品名称不能为空")
failCount++
continue
}
if product.Price <= 0 {
errors = append(errors, "商品价格必须大于0")
failCount++
continue
}
// 创建商品
err := s.productRepo.Create(&product)
if err != nil {
errors = append(errors, err.Error())
failCount++
} else {
successCount++
}
}
return map[string]interface{}{
"success_count": successCount,
"fail_count": failCount,
"errors": errors,
}, nil
}
// SyncProductStock 同步商品库存根据SKU库存计算
func (s *ProductService) SyncProductStock(productID uint) error {
return s.productRepo.SyncProductStockFromSKUs(productID)
}
// SyncAllProductsStock 同步所有商品库存
func (s *ProductService) SyncAllProductsStock() error {
// 获取所有有SKU的商品
products, _, err := s.productRepo.GetList(0, 0, map[string]interface{}{})
if err != nil {
return err
}
var syncErrors []string
for _, product := range products {
// 检查商品是否有SKU
skus, err := s.productRepo.GetProductSKUs(product.ID)
if err != nil {
continue
}
if len(skus) > 0 {
// 如果有SKU同步库存
err = s.productRepo.SyncProductStockFromSKUs(product.ID)
if err != nil {
syncErrors = append(syncErrors, fmt.Sprintf("商品ID %d 同步失败: %v", product.ID, err))
}
}
}
if len(syncErrors) > 0 {
return fmt.Errorf("部分商品同步失败: %v", syncErrors)
}
return nil
}

View File

@@ -0,0 +1,829 @@
package service
import (
"context"
"dianshang/internal/model"
"dianshang/internal/repository"
"dianshang/pkg/logger"
"dianshang/pkg/utils"
"fmt"
"time"
)
type RefundService struct {
refundRepo *repository.RefundRepository
orderRepo *repository.OrderRepository
wechatPaySvc *WeChatPayService
}
func NewRefundService(refundRepo *repository.RefundRepository, orderRepo *repository.OrderRepository, wechatPaySvc *WeChatPayService) *RefundService {
return &RefundService{
refundRepo: refundRepo,
orderRepo: orderRepo,
wechatPaySvc: wechatPaySvc,
}
}
// CreateRefund 创建退款申请
func (s *RefundService) CreateRefund(ctx context.Context, req *CreateRefundRequest) (*model.Refund, error) {
logger.Info("开始创建退款申请",
"orderID", req.OrderID,
"refundAmount", req.RefundAmount,
"refundReason", req.RefundReason,
"userID", req.UserID)
// 1. 验证订单
order, err := s.orderRepo.GetByID(req.OrderID)
if err != nil {
logger.Error("查询订单失败", "error", err, "orderID", req.OrderID)
return nil, fmt.Errorf("订单不存在")
}
// 2. 验证订单状态
if order.Status != model.OrderStatusPaid {
return nil, fmt.Errorf("订单状态不允许退款,当前状态: %s", order.GetStatusText())
}
// 3. 验证用户权限
if order.UserID != req.UserID {
return nil, fmt.Errorf("无权限操作此订单")
}
// 4. 验证退款金额
if req.RefundAmount <= 0 {
return nil, fmt.Errorf("退款金额必须大于0")
}
// 将前端传递的元金额转换为分(数据库统一使用分存储)
refundAmountInCents := req.RefundAmount * 100
// 计算已退款金额
totalRefunded, err := s.refundRepo.GetTotalRefundedByOrderID(req.OrderID)
if err != nil {
logger.Error("查询已退款金额失败", "error", err, "orderID", req.OrderID)
return nil, fmt.Errorf("查询退款信息失败")
}
// 检查退款金额是否超过可退款金额(订单金额也需要转换为分进行比较)
orderAmountInCents := order.TotalAmount * 100
availableRefund := orderAmountInCents - totalRefunded
if refundAmountInCents > availableRefund {
return nil, fmt.Errorf("退款金额超过可退款金额,可退款: %.2f", availableRefund/100.0)
}
// 5. 生成退款记录
refund := &model.Refund{
RefundNo: utils.GenerateRefundNo(),
WechatOutRefundNo: utils.GenerateWechatOutRefundNo(),
OrderID: req.OrderID,
OrderNo: order.OrderNo, // 设置订单号
UserID: req.UserID,
RefundAmount: refundAmountInCents, // 存储为分
ActualRefundAmount: refundAmountInCents, // 设置实际退款金额,初始等于申请退款金额(分)
RefundReason: req.RefundReason,
RefundType: req.RefundType,
Status: model.RefundStatusPending,
WechatRefundStatus: "",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// 6. 保存退款记录
err = s.refundRepo.Create(refund)
if err != nil {
logger.Error("创建退款记录失败", "error", err)
return nil, fmt.Errorf("创建退款申请失败")
}
// 7. 更新订单状态为退款中
if order.Status != model.OrderStatusReturning {
orderUpdates := map[string]interface{}{
"status": model.OrderStatusReturning,
"refund_time": time.Now(),
"updated_at": time.Now(),
}
err = s.orderRepo.UpdateByID(order.ID, orderUpdates)
if err != nil {
logger.Error("更新订单状态为退款中失败", "error", err, "orderID", order.ID)
// 不返回错误,因为退款记录已经创建成功
} else {
logger.Info("订单状态已更新为退款中", "orderID", order.ID, "orderNo", order.OrderNo)
}
}
// 8. 创建退款日志
statusTo := model.RefundStatusPending
userID := req.UserID
err = s.createRefundLog(refund.ID, "create", nil, &statusTo, "用户申请退款", &userID)
if err != nil {
logger.Warn("创建退款日志失败", "error", err, "refundID", refund.ID)
}
logger.Info("退款申请创建成功",
"refundID", refund.ID,
"refundNo", refund.RefundNo,
"orderID", req.OrderID,
"refundAmountYuan", req.RefundAmount,
"refundAmountCents", refundAmountInCents)
return refund, nil
}
// ProcessRefund 处理退款(管理员审核通过后调用)
func (s *RefundService) ProcessRefund(ctx context.Context, refundID uint, adminID uint, adminRemark string) error {
logger.Info("开始处理退款", "refundID", refundID, "adminID", adminID)
// 1. 查询退款记录
refund, err := s.refundRepo.GetByID(refundID)
if err != nil {
logger.Error("查询退款记录失败", "error", err, "refundID", refundID)
return fmt.Errorf("退款记录不存在")
}
// 2. 验证退款状态
if refund.Status != model.RefundStatusPending {
return fmt.Errorf("退款状态不允许处理,当前状态: %s", refund.GetStatusText())
}
// 3. 查询订单信息
order, err := s.orderRepo.GetByID(refund.OrderID)
if err != nil {
logger.Error("查询订单失败", "error", err, "orderID", refund.OrderID)
return fmt.Errorf("订单不存在")
}
// 4. 更新退款状态为处理中
err = s.refundRepo.UpdateByID(refundID, map[string]interface{}{
"status": model.RefundStatusProcessing,
"admin_remark": adminRemark,
"audit_time": time.Now(),
})
if err != nil {
logger.Error("更新退款状态失败", "error", err, "refundID", refundID)
return fmt.Errorf("更新退款状态失败")
}
// 5. 创建退款日志
statusFrom := model.RefundStatusPending
statusTo := model.RefundStatusProcessing
err = s.createRefundLog(refundID, "approve", &statusFrom, &statusTo, fmt.Sprintf("管理员审核通过: %s", adminRemark), &adminID)
if err != nil {
logger.Warn("创建退款日志失败", "error", err, "refundID", refundID)
}
// 6. 调用微信退款API
wechatResp, err := s.wechatPaySvc.CreateRefund(ctx, refund, order)
if err != nil {
logger.Error("调用微信退款API失败", "error", err, "refundID", refundID)
// 更新退款状态为失败
s.refundRepo.UpdateByID(refundID, map[string]interface{}{
"status": model.RefundStatusFailed,
"admin_remark": fmt.Sprintf("微信退款失败: %v", err),
})
statusFrom := model.RefundStatusProcessing
statusTo := model.RefundStatusFailed
s.createRefundLog(refundID, "fail", &statusFrom, &statusTo, fmt.Sprintf("微信退款失败: %v", err), &adminID)
return fmt.Errorf("微信退款失败: %v", err)
}
// 7. 更新退款记录的微信信息
updates := map[string]interface{}{
"wechat_refund_id": wechatResp.Data["refund_id"],
"wechat_refund_status": wechatResp.Data["status"],
"updated_at": time.Now(),
}
// 如果微信返回了用户收款账户信息
if userAccount, ok := wechatResp.Data["user_received_account"].(string); ok && userAccount != "" {
updates["wechat_user_received_account"] = userAccount
}
// 如果微信返回了退款账户信息
if refundAccount, ok := wechatResp.Data["funds_account"].(string); ok && refundAccount != "" {
updates["wechat_refund_account"] = refundAccount
}
// 如果微信退款立即成功
if status, ok := wechatResp.Data["status"].(string); ok && status == "SUCCESS" {
updates["status"] = model.RefundStatusSuccess
if successTime, ok := wechatResp.Data["success_time"].(string); ok && successTime != "" {
if parsedTime, err := time.Parse("2006-01-02T15:04:05+08:00", successTime); err == nil {
updates["wechat_success_time"] = parsedTime
}
}
}
err = s.refundRepo.UpdateByID(refundID, updates)
if err != nil {
logger.Error("更新退款微信信息失败", "error", err, "refundID", refundID)
}
// 8. 如果退款成功,更新订单退款信息
if status, ok := wechatResp.Data["status"].(string); ok && status == "SUCCESS" {
err = s.updateOrderRefundInfo(order, refund)
if err != nil {
logger.Error("更新订单退款信息失败", "error", err, "orderID", order.ID)
}
// 创建成功日志
statusFrom := model.RefundStatusProcessing
statusTo := model.RefundStatusSuccess
s.createRefundLog(refundID, "success", &statusFrom, &statusTo, "微信退款成功", &adminID)
} else {
// 创建处理中日志
statusFrom := model.RefundStatusProcessing
statusTo := model.RefundStatusProcessing
s.createRefundLog(refundID, "processing", &statusFrom, &statusTo, "微信退款处理中", &adminID)
}
logger.Info("退款处理完成",
"refundID", refundID,
"wechatRefundID", wechatResp.Data["refund_id"],
"status", wechatResp.Data["status"])
return nil
}
// RejectRefund 拒绝退款申请
func (s *RefundService) RejectRefund(ctx context.Context, refundID uint, adminID uint, rejectReason string) error {
logger.Info("拒绝退款申请", "refundID", refundID, "adminID", adminID, "reason", rejectReason)
// 1. 查询退款记录
refund, err := s.refundRepo.GetByID(refundID)
if err != nil {
logger.Error("查询退款记录失败", "error", err, "refundID", refundID)
return fmt.Errorf("退款记录不存在")
}
// 2. 验证退款状态
if refund.Status != model.RefundStatusPending {
return fmt.Errorf("退款状态不允许拒绝,当前状态: %s", refund.GetStatusText())
}
// 3. 更新退款状态为已拒绝
err = s.refundRepo.UpdateByID(refundID, map[string]interface{}{
"status": model.RefundStatusRejected,
"reject_reason": rejectReason,
"reject_time": time.Now(),
})
if err != nil {
logger.Error("更新退款状态失败", "error", err, "refundID", refundID)
return fmt.Errorf("更新退款状态失败")
}
// 4. 创建退款日志
statusFrom := model.RefundStatusPending
statusTo := model.RefundStatusRejected
err = s.createRefundLog(refundID, "reject", &statusFrom, &statusTo, fmt.Sprintf("管理员拒绝: %s", rejectReason), &adminID)
if err != nil {
logger.Warn("创建退款日志失败", "error", err, "refundID", refundID)
}
logger.Info("退款申请已拒绝", "refundID", refundID)
return nil
}
// HandleWeChatRefundNotify 处理微信退款回调通知(解析和解密)
func (s *RefundService) HandleWeChatRefundNotify(ctx context.Context, body []byte, headers map[string]string) (*WeChatRefundNotify, error) {
logger.Info("开始处理微信退款回调通知")
if s.wechatPaySvc == nil {
return nil, fmt.Errorf("微信支付服务未初始化")
}
// 调用微信支付服务解析和解密回调数据
notify, err := s.wechatPaySvc.HandleRefundNotify(ctx, body, headers)
if err != nil {
logger.Error("解析退款回调数据失败", "error", err)
return nil, err
}
logger.Info("成功解析退款回调数据", "eventType", notify.EventType)
return notify, nil
}
// HandleRefundCallback 处理微信退款回调
func (s *RefundService) HandleRefundCallback(ctx context.Context, notify *WeChatRefundNotify) error {
logger.Info("处理微信退款回调", "eventType", notify.EventType)
if notify.DecryptedData == nil {
return fmt.Errorf("回调数据中缺少解密数据")
}
outRefundNo := notify.DecryptedData.OutRefundNo
if outRefundNo == "" {
return fmt.Errorf("回调数据中缺少退款单号")
}
// 1. 查询退款记录
refund, err := s.refundRepo.GetByWechatOutRefundNo(outRefundNo)
if err != nil {
logger.Error("根据微信退款单号查询退款记录失败", "error", err, "outRefundNo", outRefundNo)
return fmt.Errorf("退款记录不存在")
}
// 2. 根据事件类型处理不同的退款状态
var newStatus int
var logRemark string
switch notify.EventType {
case "REFUND.SUCCESS":
// 退款成功
if refund.Status == model.RefundStatusSuccess {
logger.Info("退款已经是成功状态,跳过处理", "refundID", refund.ID)
return nil
}
newStatus = model.RefundStatusSuccess
logRemark = "微信退款回调:退款成功"
case "REFUND.ABNORMAL":
// 退款异常
newStatus = model.RefundStatusFailed
logRemark = "微信退款回调:退款异常"
case "REFUND.CLOSED":
// 退款关闭
newStatus = model.RefundStatusFailed
logRemark = "微信退款回调:退款关闭"
default:
logger.Warn("未知的退款回调事件类型", "eventType", notify.EventType)
return nil
}
// 3. 更新退款状态和微信信息
updates := map[string]interface{}{
"status": newStatus,
"wechat_refund_id": notify.DecryptedData.RefundId,
"wechat_refund_status": notify.DecryptedData.RefundStatus,
"updated_at": time.Now(),
}
// 只有成功时才更新收款账户和成功时间
if notify.EventType == "REFUND.SUCCESS" {
updates["wechat_user_received_account"] = notify.DecryptedData.UserReceivedAccount
// 解析成功时间
if notify.DecryptedData.SuccessTime != "" {
if successTime, err := time.Parse("2006-01-02T15:04:05+08:00", notify.DecryptedData.SuccessTime); err == nil {
updates["wechat_success_time"] = successTime
}
}
}
err = s.refundRepo.UpdateByID(refund.ID, updates)
if err != nil {
logger.Error("更新退款状态失败", "error", err, "refundID", refund.ID)
return fmt.Errorf("更新退款状态失败")
}
// 4. 只有退款成功时才更新订单退款信息
if newStatus == model.RefundStatusSuccess {
order, err := s.orderRepo.GetByID(refund.OrderID)
if err != nil {
logger.Error("查询订单失败", "error", err, "orderID", refund.OrderID)
} else {
err = s.updateOrderRefundInfo(order, refund)
if err != nil {
logger.Error("更新订单退款信息失败", "error", err, "orderID", order.ID)
}
}
}
// 5. 创建退款日志
statusFrom := refund.Status
statusTo := newStatus
var operatorID *uint = nil
err = s.createRefundLog(refund.ID, "callback", &statusFrom, &statusTo, logRemark, operatorID)
if err != nil {
logger.Warn("创建退款日志失败", "error", err, "refundID", refund.ID)
}
logger.Info("微信退款回调处理完成", "refundID", refund.ID, "outRefundNo", outRefundNo, "newStatus", newStatus)
return nil
}
// GetRefundsByOrderID 获取订单的退款记录
func (s *RefundService) GetRefundsByOrderID(ctx context.Context, orderID uint, userID uint) ([]*model.Refund, error) {
// 验证用户权限
order, err := s.orderRepo.GetByID(orderID)
if err != nil {
return nil, fmt.Errorf("订单不存在")
}
if order.UserID != userID {
return nil, fmt.Errorf("无权限查看此订单的退款信息")
}
refunds, err := s.refundRepo.GetByOrderID(orderID)
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*model.Refund, len(refunds))
for i := range refunds {
result[i] = &refunds[i]
}
return result, nil
}
// SyncRefundAndOrderStatus 同步退款状态和订单状态
// 这个方法用于修复退款状态已成功但订单状态未更新的问题
func (s *RefundService) SyncRefundAndOrderStatus(ctx context.Context) error {
logger.Info("开始同步退款状态和订单状态")
// 1. 查询所有状态为成功的退款记录
refunds, err := s.refundRepo.GetRefundsByStatus(model.RefundStatusSuccess)
if err != nil {
logger.Error("查询成功退款记录失败", "error", err)
return fmt.Errorf("查询成功退款记录失败: %v", err)
}
// 2. 遍历每个退款记录,检查对应的订单状态
for _, refund := range refunds {
// 获取订单信息
order, err := s.orderRepo.GetByID(refund.OrderID)
if err != nil {
logger.Error("获取订单信息失败", "error", err, "orderID", refund.OrderID)
continue
}
// 如果订单状态不是已退款,则需要更新
if order.Status != model.OrderStatusRefunded {
// 计算订单总退款金额
totalRefunded, err := s.refundRepo.GetTotalRefundedByOrderID(order.ID)
if err != nil {
logger.Error("计算订单总退款金额失败", "error", err, "orderID", order.ID)
continue
}
// 如果总退款金额大于等于订单金额,则更新订单状态为已退款
if totalRefunded >= order.TotalAmount {
updates := map[string]interface{}{
"status": model.OrderStatusRefunded,
"refunded_at": time.Now(),
"updated_at": time.Now(),
}
err = s.orderRepo.UpdateByID(order.ID, updates)
if err != nil {
logger.Error("更新订单状态为已退款失败", "error", err, "orderID", order.ID)
continue
}
logger.Info("订单状态已更新为已退款",
"orderID", order.ID,
"orderNo", order.OrderNo,
"totalAmount", order.TotalAmount,
"totalRefunded", totalRefunded)
} else if order.Status != model.OrderStatusReturning {
// 如果是部分退款且订单状态不是退款中,则更新为退款中
updates := map[string]interface{}{
"status": model.OrderStatusReturning,
"updated_at": time.Now(),
}
err = s.orderRepo.UpdateByID(order.ID, updates)
if err != nil {
logger.Error("更新订单状态为退款中失败", "error", err, "orderID", order.ID)
continue
}
logger.Info("订单状态已更新为退款中",
"orderID", order.ID,
"orderNo", order.OrderNo,
"totalAmount", order.TotalAmount,
"totalRefunded", totalRefunded)
}
}
}
logger.Info("同步退款状态和订单状态完成")
return nil
}
// GetRefundsByUserID 获取用户的退款记录
func (s *RefundService) GetRefundsByUserID(ctx context.Context, userID uint, page, pageSize int) ([]*model.Refund, int64, error) {
refunds, total, err := s.refundRepo.GetByUserID(userID, page, pageSize)
if err != nil {
return nil, 0, err
}
// 转换为指针切片
result := make([]*model.Refund, len(refunds))
for i := range refunds {
result[i] = &refunds[i]
// 检查退款状态是否为成功,但订单状态不是已退款
if refunds[i].Status == model.RefundStatusSuccess && refunds[i].Order.Status != model.OrderStatusRefunded {
// 计算订单总退款金额
totalRefunded, err := s.refundRepo.GetTotalRefundedByOrderID(refunds[i].OrderID)
if err != nil {
logger.Error("计算订单总退款金额失败", "error", err, "orderID", refunds[i].OrderID)
continue
}
// 如果总退款金额大于等于订单金额,则更新订单状态为已退款
if totalRefunded >= refunds[i].Order.TotalAmount {
updates := map[string]interface{}{
"status": model.OrderStatusRefunded,
"refunded_at": time.Now(),
"updated_at": time.Now(),
}
err = s.orderRepo.UpdateByID(refunds[i].OrderID, updates)
if err != nil {
logger.Error("更新订单状态为已退款失败", "error", err, "orderID", refunds[i].OrderID)
continue
}
// 更新当前退款记录中的订单状态
result[i].Order.Status = model.OrderStatusRefunded
logger.Info("订单状态已更新为已退款",
"orderID", refunds[i].OrderID,
"orderNo", refunds[i].OrderNo,
"totalAmount", refunds[i].Order.TotalAmount,
"totalRefunded", totalRefunded)
} else if refunds[i].Order.Status != model.OrderStatusReturning {
// 如果是部分退款且订单状态不是退款中,则更新为退款中
updates := map[string]interface{}{
"status": model.OrderStatusReturning,
"updated_at": time.Now(),
}
err = s.orderRepo.UpdateByID(refunds[i].OrderID, updates)
if err != nil {
logger.Error("更新订单状态为退款中失败", "error", err, "orderID", refunds[i].OrderID)
continue
}
// 更新当前退款记录中的订单状态
result[i].Order.Status = model.OrderStatusReturning
logger.Info("订单状态已更新为退款中",
"orderID", refunds[i].OrderID,
"orderNo", refunds[i].OrderNo,
"totalAmount", refunds[i].Order.TotalAmount,
"totalRefunded", totalRefunded)
}
}
}
return result, total, nil
}
// GetRefundByID 获取退款详情
func (s *RefundService) GetRefundByID(ctx context.Context, refundID uint, userID uint) (*model.Refund, error) {
refund, err := s.refundRepo.GetByID(refundID)
if err != nil {
return nil, fmt.Errorf("退款记录不存在")
}
// 验证用户权限
if refund.UserID != userID {
return nil, fmt.Errorf("无权限查看此退款记录")
}
return refund, nil
}
// QueryRefundStatus 查询微信退款状态
func (s *RefundService) QueryRefundStatus(ctx context.Context, refundID uint) error {
logger.Info("查询微信退款状态", "refundID", refundID)
// 1. 查询退款记录
refund, err := s.refundRepo.GetByID(refundID)
if err != nil {
logger.Error("查询退款记录失败", "error", err, "refundID", refundID)
return fmt.Errorf("退款记录不存在")
}
if refund.WechatOutRefundNo == "" {
return fmt.Errorf("退款记录没有微信退款单号")
}
// 2. 调用微信查询退款API
wechatRefund, err := s.wechatPaySvc.QueryRefund(ctx, refund.WechatOutRefundNo)
if err != nil {
logger.Error("查询微信退款状态失败", "error", err, "outRefundNo", refund.WechatOutRefundNo)
return fmt.Errorf("查询微信退款状态失败: %v", err)
}
// 3. 更新退款记录
updates := map[string]interface{}{
"wechat_refund_status": wechatRefund.WechatRefundStatus,
"wechat_user_received_account": wechatRefund.WechatUserReceivedAccount,
"wechat_refund_account": wechatRefund.WechatRefundAccount,
"updated_at": time.Now(),
}
// 如果微信退款成功,更新本地状态
if wechatRefund.WechatRefundStatus == "SUCCESS" {
// 无论当前退款状态如何,只要微信退款成功,就更新为成功状态
updates["status"] = model.RefundStatusSuccess
if wechatRefund.WechatSuccessTime != nil {
updates["wechat_success_time"] = *wechatRefund.WechatSuccessTime
}
// 更新订单退款信息
order, err := s.orderRepo.GetByID(refund.OrderID)
if err == nil {
s.updateOrderRefundInfo(order, refund)
}
// 只有当状态发生变化时才创建日志
if refund.Status != model.RefundStatusSuccess {
statusFrom := refund.Status
statusTo := model.RefundStatusSuccess
var operatorID *uint = nil
s.createRefundLog(refund.ID, "query_success", &statusFrom, &statusTo, "查询确认微信退款成功", operatorID)
}
}
err = s.refundRepo.UpdateByID(refund.ID, updates)
if err != nil {
logger.Error("更新退款状态失败", "error", err, "refundID", refundID)
return fmt.Errorf("更新退款状态失败")
}
logger.Info("退款状态查询完成", "refundID", refundID, "status", wechatRefund.WechatRefundStatus)
return nil
}
// updateOrderRefundInfo 更新订单退款信息
func (s *RefundService) updateOrderRefundInfo(order *model.Order, refund *model.Refund) error {
// 计算订单总退款金额
totalRefunded, err := s.refundRepo.GetTotalRefundedByOrderID(order.ID)
if err != nil {
return err
}
// 计算退款次数
refundCount, err := s.refundRepo.GetRefundCountByOrderID(order.ID)
if err != nil {
return err
}
updates := map[string]interface{}{
"total_refund_amount": totalRefunded,
"refund_count": refundCount,
"updated_at": time.Now(),
}
// 如果全额退款,更新订单状态为已退款
if totalRefunded >= order.TotalAmount {
updates["status"] = model.OrderStatusRefunded
updates["refunded_at"] = time.Now()
logger.Info("更新订单状态为已退款",
"orderID", order.ID,
"totalAmount", order.TotalAmount,
"totalRefunded", totalRefunded,
"refundID", refund.ID)
} else if order.Status == model.OrderStatusReturning {
// 如果是部分退款且当前状态是退款中,保持退款中状态
// 这样可以区分部分退款和全额退款的订单
updates["status"] = model.OrderStatusReturning
logger.Info("保持订单状态为退款中",
"orderID", order.ID,
"totalAmount", order.TotalAmount,
"totalRefunded", totalRefunded,
"refundID", refund.ID)
}
err = s.orderRepo.UpdateByID(order.ID, updates)
if err != nil {
logger.Error("更新订单退款信息失败", "error", err, "orderID", order.ID)
return err
}
return nil
}
// createRefundLog 创建退款日志
func (s *RefundService) createRefundLog(refundID uint, action string, statusFrom, statusTo *int, remark string, operatorID *uint) error {
log := &model.RefundLog{
RefundID: refundID,
Action: action,
StatusFrom: statusFrom,
StatusTo: statusTo,
OperatorType: "admin",
OperatorID: operatorID,
Remark: remark,
CreatedAt: time.Now(),
}
return s.refundRepo.CreateLog(log)
}
// GetPendingRefunds 获取待处理的退款申请(管理员)
func (s *RefundService) GetPendingRefunds(ctx context.Context, page, pageSize int) ([]*model.Refund, int64, error) {
logger.Info("获取待处理退款申请", "page", page, "pageSize", pageSize)
refunds, total, err := s.refundRepo.GetPendingRefunds(page, pageSize)
if err != nil {
logger.Error("获取待处理退款申请失败", "error", err)
return nil, 0, err
}
// 转换为指针切片
result := make([]*model.Refund, len(refunds))
for i := range refunds {
result[i] = &refunds[i]
}
return result, total, nil
}
// GetAllRefunds 获取所有退款记录(管理员)
func (s *RefundService) GetAllRefunds(ctx context.Context, page, pageSize int, status, userID string) ([]*model.Refund, int64, error) {
logger.Info("获取所有退款记录", "page", page, "pageSize", pageSize, "status", status, "userID", userID)
// 构建查询条件
conditions := make(map[string]interface{})
if status != "" {
conditions["status"] = status
}
if userID != "" {
conditions["user_id"] = userID
}
refunds, total, err := s.refundRepo.GetAllRefunds(page, pageSize, conditions)
if err != nil {
logger.Error("获取所有退款记录失败", "error", err)
return nil, 0, err
}
// 转换为指针切片
result := make([]*model.Refund, len(refunds))
for i := range refunds {
result[i] = &refunds[i]
}
return result, total, nil
}
// GetRefundLogs 获取退款日志(管理员)
func (s *RefundService) GetRefundLogs(ctx context.Context, refundID uint) ([]model.RefundLog, error) {
logger.Info("获取退款日志", "refundID", refundID)
logs, err := s.refundRepo.GetRefundLogsByRefundID(refundID)
if err != nil {
logger.Error("获取退款日志失败", "error", err, "refundID", refundID)
return nil, err
}
return logs, nil
}
// GetRefundDetailForAdmin 获取退款详情(管理员专用)
func (s *RefundService) GetRefundDetailForAdmin(ctx context.Context, refundID uint) (*model.Refund, error) {
logger.Info("管理员获取退款详情", "refundID", refundID)
refund, err := s.refundRepo.GetByID(refundID)
if err != nil {
logger.Error("获取退款详情失败", "error", err, "refundID", refundID)
return nil, fmt.Errorf("退款记录不存在")
}
return refund, nil
}
// GetRefundStatistics 获取退款统计数据
func (s *RefundService) GetRefundStatistics(ctx context.Context, startTime, endTime time.Time) (map[string]interface{}, error) {
logger.Info("获取退款统计数据", "startTime", startTime, "endTime", endTime)
stats, err := s.refundRepo.GetRefundStatistics(startTime, endTime)
if err != nil {
logger.Error("获取退款统计数据失败", "error", err)
return nil, err
}
// 转换数据格式以匹配前端期望的格式
result := map[string]interface{}{
"total_refunds": stats["total_count"],
"pending_refunds": stats["pending_count"],
"processing_refunds": stats["processing_count"],
"total_amount": stats["total_amount"],
"success_count": stats["success_count"],
"success_amount": stats["success_amount"],
"approved_count": stats["approved_count"],
"rejected_count": stats["rejected_count"],
"failed_count": stats["failed_count"],
}
return result, nil
}
// 请求结构体
type CreateRefundRequest struct {
OrderID uint `json:"order_id" binding:"required"`
UserID uint `json:"user_id"` // 由后端从JWT token中获取不需要前端提供
RefundAmount float64 `json:"refund_amount" binding:"required,gt=0"`
RefundReason string `json:"refund_reason" binding:"required,max=500"`
RefundType int `json:"refund_type" binding:"required,oneof=1 2"` // 1:仅退款 2:退货退款
}

View File

@@ -0,0 +1,307 @@
package service
import (
"dianshang/internal/model"
"errors"
"time"
"gorm.io/gorm"
)
// RoleService 角色服务
type RoleService struct {
db *gorm.DB
}
// NewRoleService 创建角色服务
func NewRoleService(db *gorm.DB) *RoleService {
return &RoleService{
db: db,
}
}
// CreateRole 创建角色
func (s *RoleService) CreateRole(role *model.Role) error {
return s.db.Create(role).Error
}
// GetRoleByID 根据ID获取角色
func (s *RoleService) GetRoleByID(id uint) (*model.Role, error) {
var role model.Role
err := s.db.Preload("Permissions").Where("id = ?", id).First(&role).Error
if err != nil {
return nil, err
}
return &role, nil
}
// GetRoleByName 根据名称获取角色
func (s *RoleService) GetRoleByName(name string) (*model.Role, error) {
var role model.Role
err := s.db.Preload("Permissions").Where("name = ?", name).First(&role).Error
if err != nil {
return nil, err
}
return &role, nil
}
// GetRoleList 获取角色列表
func (s *RoleService) GetRoleList(page, pageSize int, conditions map[string]interface{}) ([]model.Role, map[string]interface{}, error) {
var roles []model.Role
var total int64
query := s.db.Model(&model.Role{})
// 应用查询条件
if keyword, ok := conditions["keyword"]; ok && keyword != "" {
query = query.Where("name LIKE ? OR display_name LIKE ?",
"%"+keyword.(string)+"%", "%"+keyword.(string)+"%")
}
if status, ok := conditions["status"]; ok && status != "" {
query = query.Where("status = ?", status)
}
// 获取总数
err := query.Count(&total).Error
if err != nil {
return nil, nil, err
}
// 分页查询
offset := (page - 1) * pageSize
err = query.Preload("Permissions").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&roles).Error
if err != nil {
return nil, nil, err
}
// 构建分页信息
pagination := map[string]interface{}{
"total": total,
"page": page,
"page_size": pageSize,
"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
}
return roles, pagination, nil
}
// UpdateRole 更新角色
func (s *RoleService) UpdateRole(id uint, updates map[string]interface{}) error {
return s.db.Model(&model.Role{}).Where("id = ?", id).Updates(updates).Error
}
// DeleteRole 删除角色
func (s *RoleService) DeleteRole(id uint) error {
// 检查是否有用户使用该角色
var count int64
s.db.Model(&model.UserRole{}).Where("role_id = ?", id).Count(&count)
if count > 0 {
return errors.New("该角色正在被用户使用,无法删除")
}
// 删除角色权限关联
s.db.Where("role_id = ?", id).Delete(&model.RolePermission{})
// 删除角色
return s.db.Delete(&model.Role{}, id).Error
}
// AssignPermissionsToRole 为角色分配权限
func (s *RoleService) AssignPermissionsToRole(roleID uint, permissionIDs []uint) error {
// 删除原有权限
s.db.Where("role_id = ?", roleID).Delete(&model.RolePermission{})
// 添加新权限
for _, permissionID := range permissionIDs {
rolePermission := &model.RolePermission{
RoleID: roleID,
PermissionID: permissionID,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := s.db.Create(rolePermission).Error; err != nil {
return err
}
}
return nil
}
// AssignRolesToUser 为用户分配角色
func (s *RoleService) AssignRolesToUser(userID uint, roleIDs []uint) error {
// 删除原有角色
s.db.Where("user_id = ?", userID).Delete(&model.UserRole{})
// 添加新角色
for _, roleID := range roleIDs {
userRole := &model.UserRole{
UserID: userID,
RoleID: roleID,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := s.db.Create(userRole).Error; err != nil {
return err
}
}
return nil
}
// GetUserRoles 获取用户角色
func (s *RoleService) GetUserRoles(userID uint) ([]model.Role, error) {
var roles []model.Role
err := s.db.Table("ai_roles").
Joins("JOIN ai_user_roles ON ai_roles.id = ai_user_roles.role_id").
Where("ai_user_roles.user_id = ?", userID).
Find(&roles).Error
return roles, err
}
// GetUserPermissions 获取用户权限
func (s *RoleService) GetUserPermissions(userID uint) ([]model.Permission, error) {
var permissions []model.Permission
err := s.db.Table("ai_permissions").
Joins("JOIN ai_role_permissions ON ai_permissions.id = ai_role_permissions.permission_id").
Joins("JOIN ai_user_roles ON ai_role_permissions.role_id = ai_user_roles.role_id").
Where("ai_user_roles.user_id = ?", userID).
Distinct().
Find(&permissions).Error
return permissions, err
}
// CheckUserPermission 检查用户权限
func (s *RoleService) CheckUserPermission(userID uint, module, action string) (bool, error) {
var count int64
err := s.db.Table("ai_permissions").
Joins("JOIN ai_role_permissions ON ai_permissions.id = ai_role_permissions.permission_id").
Joins("JOIN ai_user_roles ON ai_role_permissions.role_id = ai_user_roles.role_id").
Where("ai_user_roles.user_id = ? AND ai_permissions.module = ? AND ai_permissions.action = ? AND ai_permissions.status = 1",
userID, module, action).
Count(&count).Error
return count > 0, err
}
// CreatePermission 创建权限
func (s *RoleService) CreatePermission(permission *model.Permission) error {
return s.db.Create(permission).Error
}
// GetPermissionList 获取权限列表
func (s *RoleService) GetPermissionList(page, pageSize int, conditions map[string]interface{}) ([]model.Permission, map[string]interface{}, error) {
var permissions []model.Permission
var total int64
query := s.db.Model(&model.Permission{})
// 应用查询条件
if keyword, ok := conditions["keyword"]; ok && keyword != "" {
query = query.Where("name LIKE ? OR display_name LIKE ?",
"%"+keyword.(string)+"%", "%"+keyword.(string)+"%")
}
if module, ok := conditions["module"]; ok && module != "" {
query = query.Where("module = ?", module)
}
if status, ok := conditions["status"]; ok && status != "" {
query = query.Where("status = ?", status)
}
// 获取总数
err := query.Count(&total).Error
if err != nil {
return nil, nil, err
}
// 分页查询
offset := (page - 1) * pageSize
err = query.Offset(offset).Limit(pageSize).Order("module ASC, action ASC").Find(&permissions).Error
if err != nil {
return nil, nil, err
}
// 构建分页信息
pagination := map[string]interface{}{
"total": total,
"page": page,
"page_size": pageSize,
"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
}
return permissions, pagination, nil
}
// UpdatePermission 更新权限
func (s *RoleService) UpdatePermission(id uint, updates map[string]interface{}) error {
return s.db.Model(&model.Permission{}).Where("id = ?", id).Updates(updates).Error
}
// DeletePermission 删除权限
func (s *RoleService) DeletePermission(id uint) error {
// 检查是否有角色使用该权限
var count int64
s.db.Model(&model.RolePermission{}).Where("permission_id = ?", id).Count(&count)
if count > 0 {
return errors.New("该权限正在被角色使用,无法删除")
}
return s.db.Delete(&model.Permission{}, id).Error
}
// InitDefaultRolesAndPermissions 初始化默认角色和权限
func (s *RoleService) InitDefaultRolesAndPermissions() error {
// 创建默认权限
permissions := []model.Permission{
// 用户管理权限
{Name: "user.create", DisplayName: "创建用户", Module: "user", Action: "create", Status: 1},
{Name: "user.read", DisplayName: "查看用户", Module: "user", Action: "read", Status: 1},
{Name: "user.update", DisplayName: "更新用户", Module: "user", Action: "update", Status: 1},
{Name: "user.delete", DisplayName: "删除用户", Module: "user", Action: "delete", Status: 1},
// 商品管理权限
{Name: "product.create", DisplayName: "创建商品", Module: "product", Action: "create", Status: 1},
{Name: "product.read", DisplayName: "查看商品", Module: "product", Action: "read", Status: 1},
{Name: "product.update", DisplayName: "更新商品", Module: "product", Action: "update", Status: 1},
{Name: "product.delete", DisplayName: "删除商品", Module: "product", Action: "delete", Status: 1},
// 订单管理权限
{Name: "order.create", DisplayName: "创建订单", Module: "order", Action: "create", Status: 1},
{Name: "order.read", DisplayName: "查看订单", Module: "order", Action: "read", Status: 1},
{Name: "order.update", DisplayName: "更新订单", Module: "order", Action: "update", Status: 1},
{Name: "order.delete", DisplayName: "删除订单", Module: "order", Action: "delete", Status: 1},
// 系统管理权限
{Name: "system.config", DisplayName: "系统配置", Module: "system", Action: "config", Status: 1},
{Name: "system.log", DisplayName: "系统日志", Module: "system", Action: "log", Status: 1},
}
for _, permission := range permissions {
var existingPermission model.Permission
if err := s.db.Where("name = ?", permission.Name).First(&existingPermission).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
s.db.Create(&permission)
}
}
}
// 创建默认角色
roles := []model.Role{
{Name: "admin", DisplayName: "超级管理员", Description: "拥有所有权限", Status: 1},
{Name: "manager", DisplayName: "管理员", Description: "拥有大部分管理权限", Status: 1},
{Name: "user", DisplayName: "普通用户", Description: "基础用户权限", Status: 1},
}
for _, role := range roles {
var existingRole model.Role
if err := s.db.Where("name = ?", role.Name).First(&existingRole).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
s.db.Create(&role)
}
}
}
return nil
}

View File

@@ -0,0 +1,946 @@
package service
import (
"dianshang/internal/model"
"dianshang/internal/repository"
"dianshang/pkg/jwt"
"errors"
"fmt"
"time"
"gorm.io/gorm"
)
// UserService 用户服务
type UserService struct {
userRepo *repository.UserRepository
db *gorm.DB
}
// NewUserService 创建用户服务
func NewUserService(db *gorm.DB) *UserService {
return &UserService{
userRepo: repository.NewUserRepository(db),
db: db,
}
}
// WeChatLogin 微信登录
func (s *UserService) WeChatLogin(code string) (*model.User, string, error) {
// TODO: 调用微信API获取openid
// 这里暂时模拟
openID := "mock_openid_" + code
// 查找用户
user, err := s.userRepo.GetByOpenID(openID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// 用户不存在,创建新用户
user = &model.User{
OpenID: openID,
Nickname: "微信用户",
Status: 1,
}
if err := s.userRepo.Create(user); err != nil {
return nil, "", err
}
} else {
return nil, "", err
}
}
// 检查用户状态
if user.Status == 0 {
return nil, "", errors.New("用户已被禁用")
}
// 生成JWT token
token, err := jwt.GenerateToken(user.ID, "user", 7200)
if err != nil {
return nil, "", err
}
return user, token, nil
}
// CreateUser 创建用户
func (s *UserService) CreateUser(user *model.User) error {
// 检查用户是否已存在
existingUser, err := s.userRepo.GetByOpenID(user.OpenID)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
if existingUser != nil {
return errors.New("用户已存在")
}
return s.userRepo.Create(user)
}
// GetUserByID 根据ID获取用户
func (s *UserService) GetUserByID(id uint) (*model.User, error) {
return s.userRepo.GetByID(id)
}
// UpdateUser 更新用户信息
func (s *UserService) UpdateUser(id uint, updates map[string]interface{}) error {
// 更新用户表
if err := s.userRepo.Update(id, updates); err != nil {
return err
}
// 同步更新微信用户信息表
wechatUpdates := make(map[string]interface{})
if nickname, ok := updates["nickname"]; ok {
wechatUpdates["nick_name"] = nickname
}
if avatar, ok := updates["avatar"]; ok {
wechatUpdates["avatar_url"] = avatar
}
if gender, ok := updates["gender"]; ok {
wechatUpdates["gender"] = gender
}
// 如果有需要更新的微信信息字段
if len(wechatUpdates) > 0 {
wechatUpdates["updated_at"] = time.Now()
// 更新微信用户信息表
if err := s.db.Model(&model.User{}).Where("id = ?", id).Updates(wechatUpdates).Error; err != nil {
// 记录错误但不影响主要更新流程
fmt.Printf("更新微信用户信息失败: %v\n", err)
}
}
return nil
}
// GetUserAddresses 获取用户地址列表
func (s *UserService) GetUserAddresses(userID uint) ([]model.UserAddress, error) {
return s.userRepo.GetAddresses(userID)
}
// GetAddressByID 根据ID获取用户地址
func (s *UserService) GetAddressByID(userID, addressID uint) (*model.UserAddress, error) {
address, err := s.userRepo.GetAddressByID(addressID)
if err != nil {
return nil, err
}
// 检查地址是否属于该用户
if address.UserID != userID {
return nil, errors.New("无权限访问该地址")
}
return address, nil
}
// CreateAddress 创建用户地址
func (s *UserService) CreateAddress(address *model.UserAddress) error {
// 如果设置为默认地址,先取消其他默认地址
if address.IsDefault {
if err := s.userRepo.ClearDefaultAddress(address.UserID); err != nil {
return err
}
}
return s.userRepo.CreateAddress(address)
}
// UpdateAddress 更新用户地址
func (s *UserService) UpdateAddress(userID, addressID uint, updates map[string]interface{}) error {
// 检查地址是否属于该用户
address, err := s.userRepo.GetAddressByID(addressID)
if err != nil {
return err
}
if address.UserID != userID {
return errors.New("无权限操作该地址")
}
// 如果设置为默认地址,先取消其他默认地址
if isDefault, ok := updates["is_default"]; ok && isDefault.(uint8) == 1 {
if err := s.userRepo.ClearDefaultAddress(userID); err != nil {
return err
}
}
return s.userRepo.UpdateAddress(addressID, updates)
}
// DeleteAddress 删除用户地址
func (s *UserService) DeleteAddress(userID, addressID uint) error {
// 检查地址是否属于该用户
address, err := s.userRepo.GetAddressByID(addressID)
if err != nil {
return err
}
if address.UserID != userID {
return errors.New("无权限操作该地址")
}
return s.userRepo.DeleteAddress(addressID)
}
// SetDefaultAddress 设置默认地址
func (s *UserService) SetDefaultAddress(userID, addressID uint) error {
// 检查地址是否属于该用户
address, err := s.userRepo.GetAddressByID(addressID)
if err != nil {
return err
}
if address.UserID != userID {
return errors.New("无权限操作该地址")
}
// 先取消其他默认地址
if err := s.userRepo.ClearDefaultAddress(userID); err != nil {
return err
}
// 设置为默认地址
return s.userRepo.UpdateAddress(addressID, map[string]interface{}{
"is_default": 1,
})
}
// GetFavorites 获取用户收藏列表
func (s *UserService) GetFavorites(userID uint, page, limit int) ([]model.UserFavorite, int64, error) {
offset := (page - 1) * limit
return s.userRepo.GetFavorites(userID, offset, limit)
}
// AddToFavorite 添加收藏
func (s *UserService) AddToFavorite(userID, productID uint) error {
// 检查是否已经收藏
if s.userRepo.IsFavorite(userID, productID) {
return errors.New("商品已在收藏列表中")
}
favorite := &model.UserFavorite{
UserID: userID,
ProductID: productID,
}
return s.userRepo.CreateFavorite(favorite)
}
// RemoveFromFavorite 取消收藏
func (s *UserService) RemoveFromFavorite(userID, productID uint) error {
return s.userRepo.DeleteFavorite(userID, productID)
}
// IsFavorite 检查是否已收藏
func (s *UserService) IsFavorite(userID, productID uint) bool {
return s.userRepo.IsFavorite(userID, productID)
}
// GetUserStatistics 获取用户统计
func (s *UserService) GetUserStatistics(startDate, endDate string) (map[string]interface{}, error) {
result := make(map[string]interface{})
// 总用户数
var totalUsers int64
s.db.Model(&model.User{}).Count(&totalUsers)
result["total_users"] = totalUsers
// 新增用户数(指定日期范围)
var newUsers int64
query := s.db.Model(&model.User{})
if startDate != "" && endDate != "" {
query = query.Where("DATE(created_at) BETWEEN ? AND ?", startDate, endDate)
}
query.Count(&newUsers)
result["new_users"] = newUsers
// 活跃用户数(简化处理,这里用登录用户数代替)
var activeUsers int64
activeQuery := s.db.Model(&model.User{}).Where("status = ?", 1)
if startDate != "" && endDate != "" {
activeQuery = activeQuery.Where("DATE(updated_at) BETWEEN ? AND ?", startDate, endDate)
}
activeQuery.Count(&activeUsers)
result["active_users"] = activeUsers
return result, nil
}
// GetDailyUserStatistics 获取每日用户统计
func (s *UserService) GetDailyUserStatistics(startDate, endDate string) ([]map[string]interface{}, error) {
// 简化实现,返回基础统计数据
var results []map[string]interface{}
// 解析日期
start, err := time.Parse("2006-01-02", startDate)
if err != nil {
return nil, err
}
end, err := time.Parse("2006-01-02", endDate)
if err != nil {
return nil, err
}
// 遍历日期范围
for d := start; !d.After(end); d = d.AddDate(0, 0, 1) {
dateStr := d.Format("2006-01-02")
var newUsers int64
s.db.Model(&model.User{}).
Where("DATE(created_at) = ?", dateStr).
Count(&newUsers)
results = append(results, map[string]interface{}{
"date": dateStr,
"new_users": newUsers,
})
}
return results, nil
}
// GetUserListForAdmin 获取用户列表(管理后台)
func (s *UserService) GetUserListForAdmin(page, pageSize int, conditions map[string]interface{}) ([]model.User, map[string]interface{}, error) {
var users []model.User
var total int64
query := s.db.Model(&model.User{})
// 应用查询条件
if keyword, ok := conditions["keyword"]; ok && keyword != "" {
query = query.Where("nickname LIKE ? OR email LIKE ? OR phone LIKE ?",
"%"+keyword.(string)+"%", "%"+keyword.(string)+"%", "%"+keyword.(string)+"%")
}
if status, ok := conditions["status"]; ok && status != "" {
query = query.Where("status = ?", status)
}
if startDate, ok := conditions["start_date"]; ok && startDate != "" {
query = query.Where("created_at >= ?", startDate)
}
if endDate, ok := conditions["end_date"]; ok && endDate != "" {
query = query.Where("created_at <= ?", endDate)
}
// 获取总数
err := query.Count(&total).Error
if err != nil {
return nil, nil, err
}
// 分页查询
offset := (page - 1) * pageSize
err = query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&users).Error
if err != nil {
return nil, nil, err
}
// 构建分页信息
pagination := map[string]interface{}{
"total": total,
"page": page,
"page_size": pageSize,
"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
}
return users, pagination, nil
}
// GetUserDetailForAdmin 获取用户详情(管理后台)
func (s *UserService) GetUserDetailForAdmin(userID uint) (map[string]interface{}, error) {
var user model.User
err := s.db.Where("id = ?", userID).First(&user).Error
if err != nil {
return nil, err
}
// 构建返回数据
result := map[string]interface{}{
"id": user.ID,
"openid": user.OpenID,
"unionid": user.UnionID,
"nickname": user.Nickname,
"avatar": user.Avatar,
"gender": user.Gender,
"phone": user.Phone,
"email": user.Email,
"birthday": user.Birthday,
"points": user.Points,
"level": user.Level,
"status": user.Status,
"created_at": user.CreatedAt,
"updated_at": user.UpdatedAt,
}
return result, nil
}
// UpdateUserStatusByAdmin 管理员更新用户状态
func (s *UserService) UpdateUserStatusByAdmin(userID uint, status uint8, remark string, adminID uint) error {
return s.db.Model(&model.User{}).Where("id = ?", userID).Update("status", status).Error
}
// UpdateUserProfile 更新用户资料
func (s *UserService) UpdateUserProfile(userID uint, updates map[string]interface{}) error {
// 验证更新字段
allowedFields := map[string]bool{
"nickname": true,
"avatar": true,
"gender": true,
"phone": true,
"email": true,
"birthday": true,
}
filteredUpdates := make(map[string]interface{})
for key, value := range updates {
if allowedFields[key] {
filteredUpdates[key] = value
}
}
if len(filteredUpdates) == 0 {
return fmt.Errorf("没有有效的更新字段")
}
filteredUpdates["updated_at"] = time.Now()
return s.db.Model(&model.User{}).Where("id = ?", userID).Updates(filteredUpdates).Error
}
// UpdateUserProfileByAdmin 管理员更新用户资料
func (s *UserService) UpdateUserProfileByAdmin(userID uint, updates map[string]interface{}, adminID uint) error {
// 验证更新字段
allowedFields := map[string]bool{
"nickname": true,
"avatar": true,
"gender": true,
"phone": true,
"email": true,
"birthday": true,
}
filteredUpdates := make(map[string]interface{})
for key, value := range updates {
if allowedFields[key] {
filteredUpdates[key] = value
}
}
if len(filteredUpdates) == 0 {
return fmt.Errorf("没有有效的更新字段")
}
filteredUpdates["updated_at"] = time.Now()
// TODO: 记录操作日志
// 可以在这里添加操作日志记录,记录管理员修改用户资料的操作
return s.db.Model(&model.User{}).Where("id = ?", userID).Updates(filteredUpdates).Error
}
// ResetUserPassword 重置用户密码(管理员操作)
func (s *UserService) ResetUserPassword(userID uint, newPassword string, adminID uint) error {
// 这里可以添加密码加密逻辑
// 由于是微信小程序,通常不需要密码,这里预留接口
updates := map[string]interface{}{
"updated_at": time.Now(),
}
return s.db.Model(&model.User{}).Where("id = ?", userID).Updates(updates).Error
}
// CreateUserAddress 创建用户地址
func (s *UserService) CreateUserAddress(address *model.UserAddress) error {
// 如果设置为默认地址,先取消其他默认地址
if address.IsDefault {
s.db.Model(&model.UserAddress{}).Where("user_id = ?", address.UserID).Update("is_default", false)
}
return s.db.Create(address).Error
}
// UpdateUserAddress 更新用户地址
func (s *UserService) UpdateUserAddress(addressID uint, userID uint, updates map[string]interface{}) error {
// 如果设置为默认地址,先取消其他默认地址
if isDefault, ok := updates["is_default"]; ok && isDefault.(bool) {
s.db.Model(&model.UserAddress{}).Where("user_id = ?", userID).Update("is_default", false)
}
updates["updated_at"] = time.Now()
return s.db.Model(&model.UserAddress{}).Where("id = ? AND user_id = ?", addressID, userID).Updates(updates).Error
}
// DeleteUserAddress 删除用户地址
func (s *UserService) DeleteUserAddress(addressID uint, userID uint) error {
return s.db.Where("id = ? AND user_id = ?", addressID, userID).Delete(&model.UserAddress{}).Error
}
// GetUserFavorites 获取用户收藏列表
func (s *UserService) GetUserFavorites(userID uint, page, pageSize int) ([]model.UserFavorite, map[string]interface{}, error) {
var favorites []model.UserFavorite
var total int64
query := s.db.Model(&model.UserFavorite{}).Where("user_id = ?", userID).Preload("Product")
// 获取总数
err := query.Count(&total).Error
if err != nil {
return nil, nil, err
}
// 分页查询
offset := (page - 1) * pageSize
err = query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&favorites).Error
if err != nil {
return nil, nil, err
}
// 构建分页信息
pagination := map[string]interface{}{
"total": total,
"page": page,
"page_size": pageSize,
"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
}
return favorites, pagination, nil
}
// AddUserFavorite 添加用户收藏
func (s *UserService) AddUserFavorite(userID, productID uint) error {
// 检查是否已收藏
var count int64
s.db.Model(&model.UserFavorite{}).Where("user_id = ? AND product_id = ?", userID, productID).Count(&count)
if count > 0 {
return fmt.Errorf("商品已收藏")
}
favorite := &model.UserFavorite{
UserID: userID,
ProductID: productID,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
return s.db.Create(favorite).Error
}
// RemoveUserFavorite 移除用户收藏
func (s *UserService) RemoveUserFavorite(userID, productID uint) error {
return s.db.Where("user_id = ? AND product_id = ?", userID, productID).Delete(&model.UserFavorite{}).Error
}
// GetUserLevelInfo 获取用户等级信息
func (s *UserService) GetUserLevelInfo(userID uint) (map[string]interface{}, error) {
user, err := s.GetUserByID(userID)
if err != nil {
return nil, err
}
// 计算用户等级相关信息
levelInfo := map[string]interface{}{
"current_level": user.Level,
"current_points": user.Points,
"level_name": s.getLevelName(user.Level),
"next_level": user.Level + 1,
"next_level_name": s.getLevelName(user.Level + 1),
"points_to_next": s.getPointsToNextLevel(user.Level, user.Points),
}
return levelInfo, nil
}
// getLevelName 获取等级名称
func (s *UserService) getLevelName(level int) string {
levelNames := map[int]string{
1: "青铜会员",
2: "白银会员",
3: "黄金会员",
4: "铂金会员",
5: "钻石会员",
}
if name, ok := levelNames[level]; ok {
return name
}
return "普通会员"
}
// getPointsToNextLevel 获取升级到下一等级所需积分
func (s *UserService) getPointsToNextLevel(currentLevel, currentPoints int) int {
levelThresholds := map[int]int{
1: 0,
2: 1000,
3: 3000,
4: 6000,
5: 10000,
}
nextLevel := currentLevel + 1
if threshold, ok := levelThresholds[nextLevel]; ok {
return threshold - currentPoints
}
return 0 // 已达到最高等级
}
// UpdateUserLevel 更新用户等级
func (s *UserService) UpdateUserLevel(userID uint) error {
user, err := s.GetUserByID(userID)
if err != nil {
return err
}
newLevel := s.calculateUserLevel(user.Points)
if newLevel != user.Level {
return s.db.Model(&model.User{}).Where("id = ?", userID).Update("level", newLevel).Error
}
return nil
}
// UpdateUserLevelByAdmin 管理员手动设置用户等级
func (s *UserService) UpdateUserLevelByAdmin(userID uint, level uint8, remark string, adminID uint) error {
// 验证等级范围
if level < 1 || level > 5 {
return errors.New("用户等级必须在1-5之间")
}
// 更新用户等级
err := s.db.Model(&model.User{}).Where("id = ?", userID).Update("level", level).Error
if err != nil {
return err
}
// TODO: 记录操作日志
// 可以在这里添加操作日志记录,记录管理员修改用户等级的操作
return nil
}
// calculateUserLevel 根据积分计算用户等级
func (s *UserService) calculateUserLevel(points int) int {
if points >= 10000 {
return 5
} else if points >= 6000 {
return 4
} else if points >= 3000 {
return 3
} else if points >= 1000 {
return 2
}
return 1
}
// ResetUserPasswordByAdmin 管理员重置用户密码
func (s *UserService) ResetUserPasswordByAdmin(userID uint, newPassword string, adminID uint) error {
// 这里应该对密码进行加密处理,暂时简化实现
// 在实际项目中,用户可能通过微信登录,不需要密码
// 这里只是为了满足接口需求
return s.db.Model(&model.User{}).Where("id = ?", userID).Update("updated_at", time.Now()).Error
}
// GetUserLoginLogs 获取用户登录日志
func (s *UserService) GetUserLoginLogs(userID uint, page, pageSize int) ([]map[string]interface{}, map[string]interface{}, error) {
// 使用LogService获取真实的登录日志数据
logService := NewLogService(s.db)
logs, pagination, err := logService.GetUserLoginLogs(userID, page, pageSize)
if err != nil {
return nil, nil, err
}
// 转换为前端需要的格式
var result []map[string]interface{}
for _, log := range logs {
result = append(result, map[string]interface{}{
"id": log.ID,
"user_id": log.UserID,
"login_time": log.LoginTime.Format("2006-01-02 15:04:05"),
"ip_address": log.LoginIP,
"device": log.UserAgent,
"location": "未知", // 可以后续添加IP地址解析功能
"status": log.Status,
"remark": log.Remark,
})
}
return result, pagination, nil
}
// GetUserPurchaseRanking 获取用户购买排行
func (s *UserService) GetUserPurchaseRanking(startDate, endDate, limit string) ([]map[string]interface{}, error) {
// 简化实现,返回基础排行数据
var results []map[string]interface{}
// 获取活跃用户
var users []model.User
err := s.db.Model(&model.User{}).
Where("status = ?", 1).
Limit(10).
Find(&users).Error
if err != nil {
return nil, err
}
for i, user := range users {
results = append(results, map[string]interface{}{
"user_id": user.ID,
"nickname": user.Nickname,
"purchase_count": 50 - i*3, // 模拟购买次数
"purchase_amount": float64(5000 - i*200), // 模拟购买金额
})
}
return results, nil
}
// GetUserGrowthTrend 获取用户增长趋势
func (s *UserService) GetUserGrowthTrend(days int) ([]map[string]interface{}, error) {
var results []map[string]interface{}
for i := days - 1; i >= 0; i-- {
date := time.Now().AddDate(0, 0, -i).Format("2006-01-02")
// 新增用户数
var newUsers int64
s.db.Model(&model.User{}).
Where("DATE(created_at) = ?", date).
Count(&newUsers)
// 累计用户数
var totalUsers int64
s.db.Model(&model.User{}).
Where("DATE(created_at) <= ?", date).
Count(&totalUsers)
// 活跃用户数(当天有更新记录的用户)
var activeUsers int64
s.db.Model(&model.User{}).
Where("DATE(updated_at) = ? AND status = ?", date, 1).
Count(&activeUsers)
results = append(results, map[string]interface{}{
"date": date,
"new_users": newUsers,
"total_users": totalUsers,
"active_users": activeUsers,
})
}
return results, nil
}
// GetUserActivityAnalysis 获取用户活跃度分析
func (s *UserService) GetUserActivityAnalysis(startDate, endDate string) (map[string]interface{}, error) {
result := make(map[string]interface{})
// 总用户数
var totalUsers int64
s.db.Model(&model.User{}).Count(&totalUsers)
// 活跃用户数(指定时间范围内有更新的用户)
var activeUsers int64
query := s.db.Model(&model.User{}).Where("status = ?", 1)
if startDate != "" && endDate != "" {
query = query.Where("DATE(updated_at) BETWEEN ? AND ?", startDate, endDate)
}
query.Count(&activeUsers)
// 新增用户数
var newUsers int64
newQuery := s.db.Model(&model.User{})
if startDate != "" && endDate != "" {
newQuery = newQuery.Where("DATE(created_at) BETWEEN ? AND ?", startDate, endDate)
}
newQuery.Count(&newUsers)
// 沉默用户数30天内无活动的用户
thirtyDaysAgo := time.Now().AddDate(0, 0, -30).Format("2006-01-02")
var silentUsers int64
s.db.Model(&model.User{}).
Where("status = ? AND DATE(updated_at) < ?", 1, thirtyDaysAgo).
Count(&silentUsers)
// 计算活跃率
var activityRate float64
if totalUsers > 0 {
activityRate = float64(activeUsers) / float64(totalUsers) * 100
}
result["total_users"] = totalUsers
result["active_users"] = activeUsers
result["new_users"] = newUsers
result["silent_users"] = silentUsers
result["activity_rate"] = activityRate
return result, nil
}
// GetUserRetentionRate 获取用户留存率
func (s *UserService) GetUserRetentionRate(days int) ([]map[string]interface{}, error) {
var results []map[string]interface{}
for i := days - 1; i >= 0; i-- {
date := time.Now().AddDate(0, 0, -i).Format("2006-01-02")
nextDate := time.Now().AddDate(0, 0, -i+1).Format("2006-01-02")
// 当天新增用户数
var newUsers int64
s.db.Model(&model.User{}).
Where("DATE(created_at) = ?", date).
Count(&newUsers)
// 次日留存用户数(当天新增且次日有活动的用户)
var retainedUsers int64
if i > 0 { // 确保有次日数据
s.db.Model(&model.User{}).
Where("DATE(created_at) = ? AND DATE(updated_at) = ?", date, nextDate).
Count(&retainedUsers)
}
// 计算留存率
var retentionRate float64
if newUsers > 0 {
retentionRate = float64(retainedUsers) / float64(newUsers) * 100
}
results = append(results, map[string]interface{}{
"date": date,
"new_users": newUsers,
"retained_users": retainedUsers,
"retention_rate": retentionRate,
})
}
return results, nil
}
// GetUserLevelDistribution 获取用户等级分布
func (s *UserService) GetUserLevelDistribution() ([]map[string]interface{}, error) {
var results []map[string]interface{}
// 统计各等级用户数量
for level := 1; level <= 5; level++ {
var count int64
s.db.Model(&model.User{}).
Where("level = ? AND status = ?", level, 1).
Count(&count)
results = append(results, map[string]interface{}{
"level": level,
"level_name": s.getLevelName(level),
"user_count": count,
})
}
return results, nil
}
// GetUserGeographicDistribution 获取用户地域分布
func (s *UserService) GetUserGeographicDistribution() ([]map[string]interface{}, error) {
// 简化实现,返回模拟的地域分布数据
// 在实际项目中可以根据用户地址或IP地址统计
regions := []map[string]interface{}{
{"region": "北京", "user_count": 1200},
{"region": "上海", "user_count": 980},
{"region": "广州", "user_count": 750},
{"region": "深圳", "user_count": 680},
{"region": "杭州", "user_count": 520},
{"region": "成都", "user_count": 450},
{"region": "武汉", "user_count": 380},
{"region": "西安", "user_count": 320},
{"region": "南京", "user_count": 280},
{"region": "其他", "user_count": 1430},
}
return regions, nil
}
// GetUserAgeDistribution 获取用户年龄分布
func (s *UserService) GetUserAgeDistribution() ([]map[string]interface{}, error) {
// 简化实现,返回模拟的年龄分布数据
// 在实际项目中,可以根据用户生日计算年龄分布
ageGroups := []map[string]interface{}{
{"age_group": "18-25", "user_count": 1500},
{"age_group": "26-30", "user_count": 2200},
{"age_group": "31-35", "user_count": 1800},
{"age_group": "36-40", "user_count": 1200},
{"age_group": "41-50", "user_count": 800},
{"age_group": "50+", "user_count": 500},
}
return ageGroups, nil
}
// GetUserEngagementMetrics 获取用户参与度指标
func (s *UserService) GetUserEngagementMetrics(startDate, endDate string) (map[string]interface{}, error) {
result := make(map[string]interface{})
// 平均会话时长(模拟数据)
result["avg_session_duration"] = 25.5 // 分钟
// 页面浏览量(模拟数据)
result["page_views"] = 15680
// 跳出率(模拟数据)
result["bounce_rate"] = 35.2 // 百分比
// 用户互动次数(收藏、评价等)
var favoriteCount int64
query := s.db.Model(&model.UserFavorite{})
if startDate != "" && endDate != "" {
query = query.Where("DATE(created_at) BETWEEN ? AND ?", startDate, endDate)
}
query.Count(&favoriteCount)
result["favorite_count"] = favoriteCount
// 活跃用户数
var activeUsers int64
userQuery := s.db.Model(&model.User{}).Where("status = ?", 1)
if startDate != "" && endDate != "" {
userQuery = userQuery.Where("DATE(updated_at) BETWEEN ? AND ?", startDate, endDate)
}
userQuery.Count(&activeUsers)
result["active_users"] = activeUsers
return result, nil
}
// DeleteUser 删除用户
func (s *UserService) DeleteUser(userID uint) error {
// 软删除用户
return s.db.Delete(&model.User{}, userID).Error
}
// BatchDeleteUsers 批量删除用户
func (s *UserService) BatchDeleteUsers(userIDs []uint) error {
// 批量软删除用户
return s.db.Delete(&model.User{}, userIDs).Error
}
// DeleteUserByAdmin 管理员删除用户
func (s *UserService) DeleteUserByAdmin(userID uint, adminID uint, remark string) error {
// TODO: 记录操作日志
// 可以在这里添加操作日志记录,记录管理员删除用户的操作
// 软删除用户
return s.db.Delete(&model.User{}, userID).Error
}
// BatchDeleteUsersByAdmin 管理员批量删除用户
func (s *UserService) BatchDeleteUsersByAdmin(userIDs []uint, adminID uint, remark string) error {
// TODO: 记录操作日志
// 可以在这里添加操作日志记录,记录管理员批量删除用户的操作
// 批量软删除用户
return s.db.Delete(&model.User{}, userIDs).Error
}

View File

@@ -0,0 +1,417 @@
package service
import (
"dianshang/internal/model"
"dianshang/internal/repository"
"dianshang/pkg/jwt"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"net/http"
"time"
"gorm.io/gorm"
)
// WeChatService 微信服务
type WeChatService struct {
userRepo *repository.UserRepository
pointsService *PointsService
db *gorm.DB
appID string
appSecret string
}
// NewWeChatService 创建微信服务实例
func NewWeChatService(db *gorm.DB, pointsService *PointsService, appID, appSecret string) *WeChatService {
// 初始化随机数种子
rand.Seed(time.Now().UnixNano())
return &WeChatService{
userRepo: repository.NewUserRepository(db),
pointsService: pointsService,
db: db,
appID: appID,
appSecret: appSecret,
}
}
// WeChatLoginResponse 微信登录响应
type WeChatLoginResponse struct {
OpenID string `json:"openid"`
SessionKey string `json:"session_key"`
UnionID string `json:"unionid"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
// WeChatUserInfo 微信用户信息
type WeChatUserInfo struct {
OpenID string `json:"openId"`
NickName string `json:"nickName"`
Gender int `json:"gender"`
City string `json:"city"`
Province string `json:"province"`
Country string `json:"country"`
AvatarURL string `json:"avatarUrl"`
Language string `json:"language"`
}
// Login 微信登录
func (s *WeChatService) Login(code string, ip string, userAgent string) (*model.User, string, error) {
// 验证输入参数
if code == "" {
return nil, "", errors.New("微信登录code不能为空")
}
fmt.Printf("开始微信登录流程: code=%s\n", code)
// 1. 调用微信API获取openid和session_key
wechatResp, err := s.getWeChatSession(code)
if err != nil {
s.logUserLogin(0, "wechat", false, fmt.Sprintf("获取微信会话失败: %v", err), ip, userAgent)
return nil, "", fmt.Errorf("获取微信会话失败: %v", err)
}
if wechatResp.ErrCode != 0 {
errorMsg := fmt.Sprintf("微信API返回错误: code=%d, msg=%s", wechatResp.ErrCode, wechatResp.ErrMsg)
s.logUserLogin(0, "wechat", false, errorMsg, ip, userAgent)
return nil, "", fmt.Errorf("微信登录失败: %s", wechatResp.ErrMsg)
}
fmt.Printf("成功获取微信会话: OpenID=%s\n", wechatResp.OpenID)
// 2. 查找或创建用户
user, err := s.findOrCreateUser(wechatResp)
if err != nil {
s.logUserLogin(0, "wechat", false, fmt.Sprintf("用户处理失败: %v", err), ip, userAgent)
return nil, "", fmt.Errorf("用户处理失败: %v", err)
}
// 3. 保存微信会话信息
if err := s.saveWeChatSession(user.ID, wechatResp); err != nil {
s.logUserLogin(user.ID, "wechat", false, fmt.Sprintf("保存会话失败: %v", err), ip, userAgent)
return nil, "", fmt.Errorf("保存会话失败: %v", err)
}
// 4. 生成自定义登录态JWT token
// 按照微信官方建议,生成自定义登录态用于维护用户登录状态
tokenExpiry := 7 * 24 * 3600 // 7天有效期与session_key保持一致
token, err := jwt.GenerateToken(user.ID, "user", tokenExpiry)
if err != nil {
s.logUserLogin(user.ID, "wechat", false, fmt.Sprintf("生成token失败: %v", err), ip, userAgent)
return nil, "", fmt.Errorf("生成自定义登录态失败: %v", err)
}
// 5. 检查并给予每日首次登录积分
if s.pointsService != nil {
awarded, err := s.pointsService.CheckAndGiveDailyLoginPoints(user.ID)
if err != nil {
fmt.Printf("每日登录积分处理失败: %v\n", err)
} else if awarded {
fmt.Printf("用户 %d 获得每日首次登录积分\n", user.ID)
}
}
// 6. 记录登录日志
s.logUserLogin(user.ID, "wechat", true, "", ip, userAgent)
fmt.Printf("微信登录成功: UserID=%d, OpenID=%s, Token生成完成\n", user.ID, user.OpenID)
return user, token, nil
}
// LoginWithUserInfo 微信登录并更新用户信息
func (s *WeChatService) LoginWithUserInfo(code string, userInfo WeChatUserInfo, ip string, userAgent string) (*model.User, string, error) {
// 1. 先进行基本登录
user, token, err := s.Login(code, ip, userAgent)
if err != nil {
return nil, "", err
}
// 2. 更新用户信息
if err := s.updateUserInfo(user.ID, userInfo); err != nil {
return nil, "", fmt.Errorf("更新用户信息失败: %v", err)
}
// 3. 重新获取用户信息
updatedUser, err := s.userRepo.GetByID(user.ID)
if err != nil {
return nil, "", fmt.Errorf("获取用户信息失败: %v", err)
}
return updatedUser, token, nil
}
// getWeChatSession 获取微信会话按照官方文档标准实现code2Session
func (s *WeChatService) getWeChatSession(code string) (*WeChatLoginResponse, error) {
// 验证code格式
if code == "" {
return nil, errors.New("登录凭证code不能为空")
}
if len(code) < 10 {
return nil, errors.New("登录凭证code格式异常")
}
// 开发模式如果AppSecret是占位符或为空返回模拟数据
// 注意当配置了真实的AppSecret时会调用微信官方API
if s.appSecret == "your-wechat-app-secret" || s.appSecret == "your_wechat_appsecret" || s.appSecret == "" {
// 在开发模式下使用固定的OpenID来模拟同一个微信用户
// 这样可以避免每次登录都创建新用户的问题
return &WeChatLoginResponse{
OpenID: "dev_openid_fixed_user_001", // 使用固定的OpenID
SessionKey: "dev_session_key_" + time.Now().Format("20060102150405"),
UnionID: "dev_unionid_fixed_user_001", // 使用固定的UnionID
ErrCode: 0,
ErrMsg: "",
}, nil
}
// 按照微信官方文档调用auth.code2Session接口
url := fmt.Sprintf("https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code",
s.appID, s.appSecret, code)
// 创建HTTP客户端设置超时
client := &http.Client{
Timeout: 10 * time.Second,
}
resp, err := client.Get(url)
if err != nil {
return nil, fmt.Errorf("调用微信API失败: %v", err)
}
defer resp.Body.Close()
// 检查HTTP状态码
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("微信API返回异常状态码: %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取微信API响应失败: %v", err)
}
var wechatResp WeChatLoginResponse
if err := json.Unmarshal(body, &wechatResp); err != nil {
return nil, fmt.Errorf("解析微信API响应失败: %v", err)
}
// 检查微信API返回的错误
if wechatResp.ErrCode != 0 {
return nil, fmt.Errorf("微信API错误 [%d]: %s", wechatResp.ErrCode, wechatResp.ErrMsg)
}
// 验证必要字段
if wechatResp.OpenID == "" {
return nil, errors.New("微信API未返回OpenID")
}
if wechatResp.SessionKey == "" {
return nil, errors.New("微信API未返回SessionKey")
}
return &wechatResp, nil
}
// generateRandomUsername 生成随机用户名,格式为"用户xxxxxxxx"(包含字母和数字)
func (s *WeChatService) generateRandomUsername() string {
// 定义字符集:数字和小写字母
charset := "0123456789abcdefghijklmnopqrstuvwxyz"
// 生成8位随机字符串
randomSuffix := make([]byte, 8)
for i := range randomSuffix {
randomSuffix[i] = charset[rand.Intn(len(charset))]
}
return fmt.Sprintf("用户%s", string(randomSuffix))
}
// findOrCreateUser 查找或创建用户
func (s *WeChatService) findOrCreateUser(wechatResp *WeChatLoginResponse) (*model.User, error) {
// 验证必要参数
if wechatResp.OpenID == "" {
return nil, errors.New("微信OpenID不能为空")
}
// 先尝试通过openid查找用户
user, err := s.userRepo.GetByOpenID(wechatResp.OpenID)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("查询用户失败: %v", err)
}
} else {
// 用户已存在,检查状态
if user.Status == 0 {
return nil, errors.New("用户已被禁用,请联系客服")
}
fmt.Printf("找到已存在用户: ID=%d, OpenID=%s, Nickname=%s\n", user.ID, user.OpenID, user.Nickname)
return user, nil
}
// 用户不存在,创建新用户
fmt.Printf("用户不存在,开始创建新用户: OpenID=%s\n", wechatResp.OpenID)
// 生成随机用户名,格式为"用户xxxxxxxx"
randomUsername := s.generateRandomUsername()
user = &model.User{
OpenID: wechatResp.OpenID,
UnionID: wechatResp.UnionID,
Nickname: randomUsername,
Avatar: "", // 默认头像为空,后续可通过授权获取
Status: 1, // 1表示正常状态
Level: 1, // 初始等级为1
Gender: 0, // 0表示未知性别
}
if err := s.userRepo.Create(user); err != nil {
return nil, fmt.Errorf("创建用户失败: %v", err)
}
fmt.Printf("成功创建新用户: ID=%d, OpenID=%s, Nickname=%s\n", user.ID, user.OpenID, user.Nickname)
return user, nil
}
// saveWeChatSession 保存微信会话信息安全存储session_key
func (s *WeChatService) saveWeChatSession(userID uint, wechatResp *WeChatLoginResponse) error {
// session_key是敏感信息需要安全存储
// 在生产环境中建议对session_key进行加密存储
// 计算session_key过期时间微信session_key有效期通常为7天
sessionExpiry := time.Now().Add(7 * 24 * time.Hour)
// 简单示例:保存到用户表的额外字段中
// 在生产环境中建议使用专门的会话表或Redis等缓存存储
updates := map[string]interface{}{
"open_id": wechatResp.OpenID,
"wechat_session_key": wechatResp.SessionKey, // 生产环境中应加密存储
"union_id": wechatResp.UnionID,
"session_expiry": sessionExpiry,
"updated_at": time.Now(),
}
if err := s.db.Model(&model.User{}).Where("id = ?", userID).Updates(updates).Error; err != nil {
return fmt.Errorf("保存微信会话信息失败: %v", err)
}
// 记录会话创建日志
fmt.Printf("用户 %d 的微信会话已保存OpenID: %s, 过期时间: %s\n",
userID, wechatResp.OpenID, sessionExpiry.Format("2006-01-02 15:04:05"))
return nil
}
// updateUserInfo 更新用户信息
func (s *WeChatService) updateUserInfo(userID uint, userInfo WeChatUserInfo) error {
updates := map[string]interface{}{
"nickname": userInfo.NickName,
"avatar": userInfo.AvatarURL,
"gender": userInfo.Gender,
}
if err := s.userRepo.Update(userID, updates); err != nil {
return err
}
// 获取用户的openid从ai_users表中获取
user, err := s.userRepo.GetByID(userID)
if err != nil {
return fmt.Errorf("获取用户信息失败: %v", err)
}
// 保存详细的微信用户信息
wechatUserInfo := struct {
ID uint `gorm:"primaryKey"`
UserID uint `gorm:"column:user_id;not null;unique"`
OpenID string `gorm:"column:openid;not null;unique"`
Nickname string `gorm:"column:nickname"`
AvatarURL string `gorm:"column:avatar_url"`
Gender int `gorm:"column:gender"`
Country string `gorm:"column:country"`
Province string `gorm:"column:province"`
City string `gorm:"column:city"`
Language string `gorm:"column:language"`
CreatedAt time.Time `gorm:"column:created_at"`
UpdatedAt time.Time `gorm:"column:updated_at"`
}{
UserID: userID,
OpenID: user.OpenID, // 使用从数据库获取的openid
Nickname: userInfo.NickName,
AvatarURL: userInfo.AvatarURL,
Gender: userInfo.Gender,
Country: userInfo.Country,
Province: userInfo.Province,
City: userInfo.City,
Language: userInfo.Language,
}
return s.db.Table("ai_wechat_user_info").Save(&wechatUserInfo).Error
}
// logUserLogin 记录用户登录日志
func (s *WeChatService) logUserLogin(userID uint, loginType string, success bool, errorMsg string, ip string, userAgent string) {
status := 1
if !success {
status = 0
}
// 使用LogService创建登录日志
logService := NewLogService(s.db)
remark := loginType
if errorMsg != "" {
remark = fmt.Sprintf("%s: %s", loginType, errorMsg)
}
err := logService.CreateLoginLog(userID, ip, userAgent, status, remark)
if err != nil {
fmt.Printf("创建登录日志失败: %v\n", err)
}
}
// GetUserSession 获取用户会话信息
func (s *WeChatService) GetUserSession(userID uint) (map[string]interface{}, error) {
var user model.User
err := s.db.Where("id = ?", userID).First(&user).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("用户不存在")
}
return nil, fmt.Errorf("查询用户失败: %v", err)
}
// 检查session是否过期
if user.SessionExpiry != nil && user.SessionExpiry.Before(time.Now()) {
return nil, errors.New("会话已过期")
}
return map[string]interface{}{
"session_key": user.WeChatSessionKey,
"openid": user.OpenID,
"unionid": user.UnionID,
"expires_at": user.SessionExpiry,
}, nil
}
// ValidateSessionKey 验证session_key有效性
func (s *WeChatService) ValidateSessionKey(userID uint) (bool, error) {
session, err := s.GetUserSession(userID)
if err != nil {
return false, err
}
// 检查session_key是否存在
sessionKey, ok := session["session_key"].(string)
if !ok || sessionKey == "" {
return false, errors.New("session_key不存在")
}
// 检查过期时间
expiresAt, ok := session["expires_at"].(*time.Time)
if ok && expiresAt != nil && expiresAt.Before(time.Now()) {
return false, errors.New("session_key已过期")
}
return true, nil
}

View File

@@ -0,0 +1,911 @@
package service
import (
"context"
"crypto/rsa"
"dianshang/internal/config"
"dianshang/internal/model"
"dianshang/internal/repository"
"dianshang/pkg/logger"
"dianshang/pkg/utils"
"encoding/json"
"fmt"
"log"
"strconv"
"time"
"github.com/wechatpay-apiv3/wechatpay-go/core"
"github.com/wechatpay-apiv3/wechatpay-go/core/option"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi"
"github.com/wechatpay-apiv3/wechatpay-go/services/refunddomestic"
wechatutils "github.com/wechatpay-apiv3/wechatpay-go/utils"
)
type WeChatPayService struct {
config *config.WeChatPayConfig
client *core.Client
jsapiSvc *jsapi.JsapiApiService
refundSvc *refunddomestic.RefundsApiService
privateKey *rsa.PrivateKey
orderRepo *repository.OrderRepository
refundSvcRef *RefundService
}
func NewWeChatPayService(cfg *config.WeChatPayConfig, orderRepo *repository.OrderRepository, refundService *RefundService) (*WeChatPayService, error) {
// 检查是否为沙盒环境
if cfg.Environment == "sandbox" {
logger.Info("微信支付配置为沙盒模式,将使用模拟支付")
return &WeChatPayService{
config: cfg,
orderRepo: orderRepo,
refundSvcRef: refundService,
}, nil
}
// 生产环境:加载商户私钥
privateKey, err := wechatutils.LoadPrivateKeyWithPath(cfg.KeyPath)
if err != nil {
logger.Warn("加载商户私钥失败,将使用模拟模式", "error", err)
// 在开发环境下允许没有私钥,使用模拟模式
return &WeChatPayService{
config: cfg,
orderRepo: orderRepo,
refundSvcRef: refundService,
}, nil
}
ctx := context.Background()
// 使用商户私钥等初始化 client并使它具有自动定时获取微信支付平台证书的能力
opts := []core.ClientOption{
option.WithWechatPayAutoAuthCipher(cfg.MchID, cfg.SerialNo, privateKey, cfg.APIv3Key),
}
client, err := core.NewClient(ctx, opts...)
if err != nil {
logger.Warn("初始化微信支付客户端失败,将使用模拟模式", "error", err)
// 在开发环境下允许客户端初始化失败,使用模拟模式
return &WeChatPayService{
config: cfg,
orderRepo: orderRepo,
refundSvcRef: refundService,
}, nil
}
// 创建JSAPI服务
jsapiSvc := &jsapi.JsapiApiService{Client: client}
// 创建退款服务
refundSvc := &refunddomestic.RefundsApiService{Client: client}
logger.Info("微信支付客户端初始化成功",
"mchId", cfg.MchID,
"serialNo", cfg.SerialNo,
"environment", cfg.Environment)
return &WeChatPayService{
config: cfg,
client: client,
jsapiSvc: jsapiSvc,
refundSvc: refundSvc,
privateKey: privateKey,
orderRepo: orderRepo,
refundSvcRef: refundService,
}, nil
}
// CreateOrder 创建支付订单
func (s *WeChatPayService) CreateOrder(ctx context.Context, order *model.Order, openID string) (*WeChatPayResponse, error) {
// 生成唯一的微信支付订单号
wechatOutTradeNo := utils.GenerateWechatOutTradeNo()
logger.Info("开始创建微信支付订单",
"orderNo", order.OrderNo,
"wechatOutTradeNo", wechatOutTradeNo,
"openID", openID,
"totalAmount", order.TotalAmount,
"hasClient", s.client != nil)
// 更新订单的微信支付订单号
err := s.orderRepo.UpdateByOrderNo(order.OrderNo, map[string]interface{}{
"wechat_out_trade_no": wechatOutTradeNo,
"updated_at": time.Now(),
})
if err != nil {
logger.Error("更新订单微信支付订单号失败", "error", err, "orderNo", order.OrderNo)
return nil, fmt.Errorf("更新订单失败: %v", err)
}
// 如果没有客户端(开发环境),使用模拟数据
if s.client == nil {
logger.Warn("开发环境下使用模拟支付数据")
return s.createMockPayment(order, openID)
}
// 构建预支付请求,使用唯一的微信支付订单号
req := jsapi.PrepayRequest{
Appid: core.String(s.config.AppID),
Mchid: core.String(s.config.MchID),
Description: core.String(fmt.Sprintf("订单号: %s", order.OrderNo)),
OutTradeNo: core.String(wechatOutTradeNo), // 使用唯一的微信支付订单号
NotifyUrl: core.String(s.config.NotifyURL),
Amount: &jsapi.Amount{
Total: core.Int64(int64(order.TotalAmount)), // 金额已经是分为单位,无需转换
Currency: core.String("CNY"),
},
Payer: &jsapi.Payer{
Openid: core.String(openID),
},
}
// 使用PrepayWithRequestPayment方法直接获取调起支付的参数
resp, result, err := s.jsapiSvc.PrepayWithRequestPayment(ctx, req)
if err != nil {
log.Printf("call PrepayWithRequestPayment err:%s", err)
logger.Error("创建支付订单失败", "error", err, "orderNo", order.OrderNo)
return nil, fmt.Errorf("创建支付订单失败: %v", err)
}
if result.Response.StatusCode != 200 {
log.Printf("PrepayWithRequestPayment status=%d", result.Response.StatusCode)
return nil, fmt.Errorf("预支付请求失败,状态码: %d", result.Response.StatusCode)
}
log.Printf("PrepayWithRequestPayment success, prepay_id=%s", *resp.PrepayId)
logger.Info("微信支付API响应",
"prepayId", *resp.PrepayId,
"orderNo", order.OrderNo)
// 直接使用SDK返回的支付参数
payParams := &MiniProgramPayParams{
AppID: *resp.Appid,
TimeStamp: *resp.TimeStamp,
NonceStr: *resp.NonceStr,
Package: *resp.Package,
SignType: *resp.SignType,
PaySign: *resp.PaySign,
}
return &WeChatPayResponse{
Code: 0,
Message: "success",
Data: map[string]interface{}{
"payInfo": payParams,
},
}, nil
}
// createMockPayment 创建模拟支付数据(沙盒环境使用)
func (s *WeChatPayService) createMockPayment(order *model.Order, openID string) (*WeChatPayResponse, error) {
mockPrepayID := fmt.Sprintf("wx%d%s", time.Now().Unix(), generateNonceStr()[:8])
timestamp := strconv.FormatInt(time.Now().Unix(), 10)
nonceStr := generateNonceStr()
// 生成更真实的模拟签名
mockSign := fmt.Sprintf("sandbox_%s_%s", nonceStr[:16], timestamp)
payParams := &MiniProgramPayParams{
AppID: s.config.AppID,
TimeStamp: timestamp,
NonceStr: nonceStr,
Package: fmt.Sprintf("prepay_id=%s", mockPrepayID),
SignType: "RSA",
PaySign: mockSign,
}
logger.Info("生成沙盒支付参数",
"environment", s.config.Environment,
"prepayId", mockPrepayID,
"orderNo", order.OrderNo,
"openID", openID,
"totalAmount", order.TotalAmount,
"description", fmt.Sprintf("订单号: %s", order.OrderNo))
return &WeChatPayResponse{
Code: 0,
Message: "沙盒支付创建成功",
Data: map[string]interface{}{
"payInfo": payParams,
"sandbox": true,
"tips": "这是沙盒环境的模拟支付,可以直接调用成功",
},
}, nil
}
// QueryOrder 查询订单
func (s *WeChatPayService) QueryOrder(ctx context.Context, orderNo string) (*model.Order, error) {
logger.Info("开始查询订单",
"orderNo", orderNo,
"hasClient", s.client != nil,
"environment", s.config.Environment)
// 如果没有客户端(沙盒环境或开发环境),返回模拟数据
if s.client == nil {
if s.config.Environment == "sandbox" {
logger.Info("沙盒环境下返回模拟查询结果")
} else {
logger.Warn("开发环境下返回模拟查询结果")
}
// 模拟不同的支付状态,让测试更真实
var status int
if time.Now().Unix()%3 == 0 {
status = 1 // 未付款
} else {
status = 2 // 已付款
}
return &model.Order{
OrderNo: orderNo,
TotalAmount: 100.0, // 模拟金额
Status: status,
}, nil
}
// 首先从数据库获取订单信息
order, err := s.orderRepo.GetByOrderNo(orderNo)
if err != nil {
logger.Error("从数据库获取订单失败", "error", err, "orderNo", orderNo)
return nil, fmt.Errorf("订单不存在: %v", err)
}
// 如果没有微信支付订单号,说明还没有创建过微信支付订单
if order.WechatOutTradeNo == "" {
logger.Warn("订单尚未创建微信支付订单", "orderNo", orderNo)
return order, nil
}
// 使用微信支付订单号查询微信支付状态
req := jsapi.QueryOrderByOutTradeNoRequest{
OutTradeNo: core.String(order.WechatOutTradeNo),
Mchid: core.String(s.config.MchID),
}
resp, result, err := s.jsapiSvc.QueryOrderByOutTradeNo(ctx, req)
if err != nil {
log.Printf("call QueryOrderByOutTradeNo err:%s", err)
logger.Error("查询微信支付订单失败", "error", err, "wechatOutTradeNo", order.WechatOutTradeNo)
return nil, fmt.Errorf("查询微信支付订单失败: %v", err)
}
if result.Response.StatusCode != 200 {
log.Printf("QueryOrderByOutTradeNo status=%d", result.Response.StatusCode)
return nil, fmt.Errorf("查询微信支付订单失败,状态码: %d", result.Response.StatusCode)
}
log.Printf("QueryOrderByOutTradeNo success, resp=%+v", resp)
logger.Info("查询微信支付订单成功",
"orderNo", orderNo,
"wechatOutTradeNo", order.WechatOutTradeNo,
"tradeState", *resp.TradeState)
// 更新订单的微信交易号和支付状态
wechatStatus := convertWeChatPayStatus(*resp.TradeState)
updates := map[string]interface{}{
"updated_at": time.Now(),
}
// 如果有微信交易号,保存到数据库
if resp.TransactionId != nil {
updates["wechat_transaction_id"] = *resp.TransactionId
}
// 如果微信支付状态是已支付,更新订单状态
if wechatStatus == 2 && order.Status == 1 {
updates["status"] = 2
updates["pay_status"] = 1
updates["paid_at"] = time.Now()
}
// 更新订单信息
if len(updates) > 1 { // 除了updated_at还有其他字段需要更新
err = s.orderRepo.UpdateByOrderNo(orderNo, updates)
if err != nil {
logger.Error("更新订单微信支付信息失败", "error", err, "orderNo", orderNo)
}
}
// 更新订单对象的状态和金额信息
order.TotalAmount = float64(*resp.Amount.Total) // 保持分为单位,与系统内部一致
order.Status = wechatStatus
if resp.TransactionId != nil {
order.WechatTransactionID = *resp.TransactionId
}
return order, nil
}
// HandleNotify 处理支付回调
func (s *WeChatPayService) HandleNotify(ctx context.Context, body []byte, headers map[string]string) (*WeChatPayNotify, error) {
// 解析回调数据
var notify WeChatPayNotify
if err := json.Unmarshal(body, &notify); err != nil {
return nil, fmt.Errorf("解析回调数据失败: %v", err)
}
logger.Info("收到微信支付回调",
"eventType", notify.EventType,
"id", notify.ID,
"algorithm", notify.Resource.Algorithm)
// 解密resource中的数据
if notify.Resource.Ciphertext != "" {
// 使用AEAD_AES_256_GCM算法解密
decryptedData, err := s.decryptNotifyResource(
notify.Resource.Ciphertext,
notify.Resource.Nonce,
notify.Resource.AssociatedData,
)
if err != nil {
logger.Error("解密回调数据失败", "error", err)
return nil, fmt.Errorf("解密回调数据失败: %v", err)
}
// 解析解密后的JSON数据
var paymentData WeChatPayNotifyData
if err := json.Unmarshal(decryptedData, &paymentData); err != nil {
logger.Error("解析解密数据失败", "error", err, "data", string(decryptedData))
return nil, fmt.Errorf("解析解密数据失败: %v", err)
}
notify.DecryptedData = &paymentData
logger.Info("成功解密回调数据",
"outTradeNo", paymentData.OutTradeNo,
"transactionID", paymentData.TransactionID,
"tradeState", paymentData.TradeState)
}
return &notify, nil
}
// decryptNotifyResource 解密回调通知中的resource数据
func (s *WeChatPayService) decryptNotifyResource(ciphertext, nonce, associatedData string) ([]byte, error) {
// 使用wechatpay-go SDK提供的解密工具
plaintext, err := wechatutils.DecryptAES256GCM(s.config.APIv3Key, associatedData, nonce, ciphertext)
if err != nil {
return nil, fmt.Errorf("AES解密失败: %v", err)
}
return []byte(plaintext), nil
}
// ProcessPaymentSuccess 处理支付成功回调
func (s *WeChatPayService) ProcessPaymentSuccess(ctx context.Context, notify *WeChatPayNotify) error {
if notify.EventType != "TRANSACTION.SUCCESS" {
return fmt.Errorf("不是支付成功回调: %s", notify.EventType)
}
logger.Info("开始处理支付成功回调", "eventType", notify.EventType)
// 解析回调数据中的订单信息
var orderNo string
var transactionID string
// 如果有解密数据,从中获取订单号
if notify.DecryptedData != nil {
orderNo = notify.DecryptedData.OutTradeNo
transactionID = notify.DecryptedData.TransactionID
logger.Info("从解密数据中获取订单信息", "orderNo", orderNo, "transactionID", transactionID)
} else {
// 开发环境下可能需要从Resource字段中解析
// 或者从其他地方获取订单号
logger.Warn("回调数据中没有解密数据尝试从Resource字段获取")
// 在开发环境下我们可以尝试解析Resource中的数据
if notify.Resource.Ciphertext != "" {
// 这里可以添加解密逻辑,但在开发环境下我们先跳过
logger.Info("Resource中有加密数据但开发环境暂不解密")
}
// 如果无法获取订单号,我们可以从最近的订单中查找
// 这是一个临时的开发环境解决方案
logger.Warn("无法从回调数据中获取订单号,这可能是开发环境的模拟回调")
return fmt.Errorf("无法从回调数据中获取订单号")
}
if orderNo == "" {
return fmt.Errorf("回调数据中缺少订单号")
}
logger.Info("处理支付成功回调", "orderNo", orderNo)
// 查询订单
order, err := s.orderRepo.GetOrderByWechatOutTradeNo(orderNo)
if err != nil {
logger.Error("根据微信订单号查询订单失败", "error", err, "wechatOutTradeNo", orderNo)
return fmt.Errorf("订单不存在: %v", err)
}
// 检查订单状态,避免重复处理
if order.Status >= 2 {
logger.Info("订单已经是已支付状态,跳过处理", "orderNo", order.OrderNo, "status", order.Status)
return nil
}
logger.Info("开始更新订单状态", "orderNo", order.OrderNo, "currentStatus", order.Status)
// 更新订单状态为已支付
updates := map[string]interface{}{
"status": 2, // 已支付
"pay_status": 1, // 已支付
"pay_time": time.Now(),
"updated_at": time.Now(),
}
// 如果有微信交易号,也保存
if transactionID != "" {
updates["wechat_transaction_id"] = transactionID
logger.Info("保存微信交易号", "transactionID", transactionID)
}
err = s.orderRepo.UpdateByOrderNo(order.OrderNo, updates)
if err != nil {
logger.Error("更新订单支付状态失败", "error", err, "orderNo", order.OrderNo)
return fmt.Errorf("更新订单状态失败: %v", err)
}
logger.Info("订单支付状态更新成功", "orderNo", order.OrderNo, "newStatus", 2)
return nil
}
// ProcessPaymentSuccessByOrderNo 根据订单号手动处理支付成功(用于测试)
func (s *WeChatPayService) ProcessPaymentSuccessByOrderNo(ctx context.Context, orderNo string) error {
logger.Info("手动处理支付成功", "orderNo", orderNo)
// 查询订单
order, err := s.orderRepo.GetByOrderNo(orderNo)
if err != nil {
logger.Error("查询订单失败", "error", err, "orderNo", orderNo)
return fmt.Errorf("订单不存在: %v", err)
}
// 检查订单状态,避免重复处理
if order.Status >= 2 {
logger.Info("订单已经是已支付状态,跳过处理", "orderNo", order.OrderNo, "status", order.Status)
return nil
}
logger.Info("开始更新订单状态", "orderNo", order.OrderNo, "currentStatus", order.Status)
// 更新订单状态为已支付
updates := map[string]interface{}{
"status": 2, // 已支付
"pay_status": 1, // 已支付
"pay_time": time.Now(),
"updated_at": time.Now(),
}
err = s.orderRepo.UpdateByOrderNo(order.OrderNo, updates)
if err != nil {
logger.Error("更新订单支付状态失败", "error", err, "orderNo", order.OrderNo)
return fmt.Errorf("更新订单状态失败: %v", err)
}
logger.Info("订单支付状态更新成功", "orderNo", order.OrderNo, "newStatus", 2)
return nil
}
// CreateRefund 创建微信退款
func (s *WeChatPayService) CreateRefund(ctx context.Context, refundRecord *model.Refund, order *model.Order) (*WeChatRefundResponse, error) {
logger.Info("开始创建微信退款",
"refundNo", refundRecord.RefundNo,
"orderNo", order.OrderNo,
"refundAmount", refundRecord.RefundAmount,
"hasClient", s.client != nil)
// 如果没有客户端(开发环境),使用模拟数据
if s.client == nil {
logger.Warn("开发环境下使用模拟退款数据")
return s.createMockRefund(refundRecord, order)
}
// 构建退款请求
req := refunddomestic.CreateRequest{
OutTradeNo: core.String(order.WechatOutTradeNo),
OutRefundNo: core.String(refundRecord.WechatOutRefundNo),
Reason: core.String(refundRecord.RefundReason),
FundsAccount: (*refunddomestic.ReqFundsAccount)(core.String("AVAILABLE")), // 可用余额退款
Amount: &refunddomestic.AmountReq{
Refund: core.Int64(int64(refundRecord.RefundAmount)),
Total: core.Int64(int64(order.TotalAmount)),
Currency: core.String("CNY"),
},
}
// 只有当RefundNotifyURL不为空时才设置NotifyUrl
if s.config.RefundNotifyURL != "" {
req.NotifyUrl = core.String(s.config.RefundNotifyURL)
}
// 如果有微信交易号,优先使用
if order.WechatTransactionID != "" {
req.TransactionId = core.String(order.WechatTransactionID)
req.OutTradeNo = nil // 使用微信交易号时,不需要商户订单号
}
// 调用微信退款API
resp, result, err := s.refundSvc.Create(ctx, req)
if err != nil {
log.Printf("call CreateRefund err:%s", err)
logger.Error("创建微信退款失败", "error", err, "refundNo", refundRecord.RefundNo)
return nil, fmt.Errorf("创建微信退款失败: %v", err)
}
if result.Response.StatusCode != 200 {
log.Printf("CreateRefund status=%d", result.Response.StatusCode)
return nil, fmt.Errorf("微信退款请求失败,状态码: %d", result.Response.StatusCode)
}
log.Printf("CreateRefund success, refund_id=%s", *resp.RefundId)
logger.Info("微信退款API响应",
"refundId", *resp.RefundId,
"refundNo", refundRecord.RefundNo,
"status", *resp.Status)
return &WeChatRefundResponse{
Code: 0,
Message: "success",
Data: map[string]interface{}{
"refund_id": *resp.RefundId,
"out_refund_no": *resp.OutRefundNo,
"transaction_id": getStringValue(resp.TransactionId),
"out_trade_no": getStringValue(resp.OutTradeNo),
"channel": getChannelValue(resp.Channel),
"user_received_account": getStringValue(resp.UserReceivedAccount),
"success_time": getTimeValue(resp.SuccessTime),
"create_time": getTimeValue(resp.CreateTime),
"status": getStatusValue(*resp.Status),
"funds_account": getFundsAccountValue(resp.FundsAccount),
"amount": map[string]interface{}{
"total": *resp.Amount.Total,
"refund": *resp.Amount.Refund,
"payer_total": getInt64Value(resp.Amount.PayerTotal),
"payer_refund": getInt64Value(resp.Amount.PayerRefund),
"settlement_refund": getInt64Value(resp.Amount.SettlementRefund),
"settlement_total": getInt64Value(resp.Amount.SettlementTotal),
"discount_refund": getInt64Value(resp.Amount.DiscountRefund),
"currency": *resp.Amount.Currency,
},
},
}, nil
}
// createMockRefund 创建模拟退款数据(开发环境使用)
func (s *WeChatPayService) createMockRefund(refundRecord *model.Refund, order *model.Order) (*WeChatRefundResponse, error) {
logger.Info("创建模拟退款数据", "refundNo", refundRecord.RefundNo)
// 生成模拟的微信退款ID
mockRefundID := fmt.Sprintf("mock_refund_%d", time.Now().Unix())
return &WeChatRefundResponse{
Code: 0,
Message: "success",
Data: map[string]interface{}{
"refund_id": mockRefundID,
"out_refund_no": refundRecord.WechatOutRefundNo,
"transaction_id": order.WechatTransactionID,
"out_trade_no": order.WechatOutTradeNo,
"channel": "ORIGINAL",
"user_received_account": "招商银行信用卡0403",
"success_time": time.Now().Format("2006-01-02T15:04:05+08:00"),
"create_time": time.Now().Format("2006-01-02T15:04:05+08:00"),
"status": "SUCCESS",
"funds_account": "AVAILABLE",
"amount": map[string]interface{}{
"total": int64(order.TotalAmount),
"refund": int64(refundRecord.RefundAmount),
"payer_total": int64(order.TotalAmount),
"payer_refund": int64(refundRecord.RefundAmount),
"settlement_refund": int64(refundRecord.RefundAmount),
"settlement_total": int64(order.TotalAmount),
"discount_refund": int64(0),
"currency": "CNY",
},
},
}, nil
}
// QueryRefund 查询微信退款状态
func (s *WeChatPayService) QueryRefund(ctx context.Context, outRefundNo string) (*model.Refund, error) {
logger.Info("查询微信退款状态", "outRefundNo", outRefundNo)
// 如果没有客户端(开发环境),返回模拟数据
if s.client == nil {
logger.Warn("开发环境下使用模拟退款查询")
return s.queryMockRefund(outRefundNo)
}
// 构建查询请求
req := refunddomestic.QueryByOutRefundNoRequest{
OutRefundNo: core.String(outRefundNo),
}
// 调用微信查询退款API
resp, result, err := s.refundSvc.QueryByOutRefundNo(ctx, req)
if err != nil {
log.Printf("call QueryRefund err:%s", err)
logger.Error("查询微信退款失败", "error", err, "outRefundNo", outRefundNo)
return nil, fmt.Errorf("查询微信退款失败: %v", err)
}
if result.Response.StatusCode != 200 {
log.Printf("QueryRefund status=%d", result.Response.StatusCode)
return nil, fmt.Errorf("查询微信退款失败,状态码: %d", result.Response.StatusCode)
}
log.Printf("QueryRefund success, resp=%+v", resp)
logger.Info("查询微信退款成功",
"outRefundNo", outRefundNo,
"refundId", *resp.RefundId,
"status", *resp.Status)
// 构建返回的退款记录(这里只是示例,实际应该从数据库获取完整记录)
refundRecord := &model.Refund{
WechatRefundID: *resp.RefundId,
WechatOutRefundNo: *resp.OutRefundNo,
WechatRefundStatus: getStatusValue(*resp.Status),
WechatUserReceivedAccount: getStringValue(resp.UserReceivedAccount),
WechatRefundAccount: getFundsAccountValue(resp.FundsAccount),
}
// 如果退款成功,设置成功时间
if getStatusValue(*resp.Status) == "SUCCESS" && resp.SuccessTime != nil {
successTime, err := time.Parse("2006-01-02T15:04:05+08:00", getTimeValue(resp.SuccessTime))
if err == nil {
refundRecord.WechatSuccessTime = &successTime
}
}
return refundRecord, nil
}
// queryMockRefund 查询模拟退款数据(开发环境使用)
func (s *WeChatPayService) queryMockRefund(outRefundNo string) (*model.Refund, error) {
logger.Info("查询模拟退款数据", "outRefundNo", outRefundNo)
now := time.Now()
return &model.Refund{
WechatRefundID: fmt.Sprintf("mock_refund_%d", now.Unix()),
WechatOutRefundNo: outRefundNo,
WechatRefundStatus: "SUCCESS",
WechatUserReceivedAccount: "招商银行信用卡0403",
WechatRefundAccount: "AVAILABLE",
WechatSuccessTime: &now,
}, nil
}
// HandleRefundNotify 处理微信退款回调
func (s *WeChatPayService) HandleRefundNotify(ctx context.Context, body []byte, headers map[string]string) (*WeChatRefundNotify, error) {
// 解析回调数据
var notify WeChatRefundNotify
if err := json.Unmarshal(body, &notify); err != nil {
return nil, fmt.Errorf("解析退款回调数据失败: %v", err)
}
logger.Info("收到微信退款回调",
"eventType", notify.EventType,
"id", notify.ID,
"algorithm", notify.Resource.Algorithm)
// 解密resource中的数据
if notify.Resource.Ciphertext != "" {
// 使用AEAD_AES_256_GCM算法解密
decryptedData, err := s.decryptNotifyResource(
notify.Resource.Ciphertext,
notify.Resource.Nonce,
notify.Resource.AssociatedData,
)
if err != nil {
logger.Error("解密退款回调数据失败", "error", err)
return nil, fmt.Errorf("解密退款回调数据失败: %v", err)
}
// 解析解密后的JSON数据
var refundData WeChatRefundNotifyData
if err := json.Unmarshal(decryptedData, &refundData); err != nil {
logger.Error("解析解密数据失败", "error", err, "data", string(decryptedData))
return nil, fmt.Errorf("解析解密数据失败: %v", err)
}
notify.DecryptedData = &refundData
logger.Info("成功解密退款回调数据",
"outRefundNo", refundData.OutRefundNo,
"refundId", refundData.RefundId,
"refundStatus", refundData.RefundStatus)
}
return &notify, nil
}
// ProcessRefundSuccess 处理退款成功回调
func (s *WeChatPayService) ProcessRefundSuccess(ctx context.Context, notify *WeChatRefundNotify) error {
if notify.EventType != "REFUND.SUCCESS" {
return fmt.Errorf("不是退款成功回调: %s", notify.EventType)
}
logger.Info("开始处理退款成功回调", "eventType", notify.EventType)
// 解析回调数据中的退款信息
var outRefundNo string
var refundID string
// 如果有解密数据,从中获取退款单号
if notify.DecryptedData != nil {
outRefundNo = notify.DecryptedData.OutRefundNo
refundID = notify.DecryptedData.RefundId
logger.Info("从解密数据中获取退款信息", "outRefundNo", outRefundNo, "refundId", refundID)
} else {
logger.Warn("退款回调数据中没有解密数据")
return fmt.Errorf("无法从回调数据中获取退款单号")
}
if outRefundNo == "" {
return fmt.Errorf("回调数据中缺少退款单号")
}
logger.Info("处理退款成功回调", "outRefundNo", outRefundNo)
// 这里应该调用退款服务来更新退款状态
// 由于这是在微信支付服务中,我们只记录日志,实际更新由退款服务处理
logger.Info("退款成功回调处理完成", "outRefundNo", outRefundNo, "refundId", refundID)
return nil
}
// 辅助函数
func getStringValue(ptr *string) string {
if ptr == nil {
return ""
}
return *ptr
}
func getInt64Value(ptr *int64) int64 {
if ptr == nil {
return 0
}
return *ptr
}
func getChannelValue(ptr *refunddomestic.Channel) string {
if ptr == nil {
return ""
}
return string(*ptr)
}
func getFundsAccountValue(ptr *refunddomestic.FundsAccount) string {
if ptr == nil {
return ""
}
return string(*ptr)
}
func getTimeValue(ptr *time.Time) string {
if ptr == nil {
return ""
}
return ptr.Format("2006-01-02T15:04:05+08:00")
}
func getStatusValue(status refunddomestic.Status) string {
return string(status)
}
// generateNonceStr 生成随机字符串用于微信支付
func generateNonceStr() string {
return utils.GenerateRandomString(32)
}
// convertWeChatPayStatus 将微信支付状态转换为订单状态
func convertWeChatPayStatus(wechatStatus string) int {
switch wechatStatus {
case "SUCCESS":
return model.OrderStatusPaid
case "REFUND":
return model.OrderStatusRefunded
case "NOTPAY":
return model.OrderStatusPending
case "CLOSED":
return model.OrderStatusCancelled
case "REVOKED":
return model.OrderStatusCancelled
case "USERPAYING":
return model.OrderStatusPending
case "PAYERROR":
return model.OrderStatusCancelled
default:
return model.OrderStatusPending
}
}
// 微信支付相关数据结构
type MiniProgramPayParams struct {
AppID string `json:"appId"`
TimeStamp string `json:"timeStamp"`
NonceStr string `json:"nonceStr"`
Package string `json:"package"`
SignType string `json:"signType"`
PaySign string `json:"paySign"`
}
type WeChatPayResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data map[string]interface{} `json:"data"`
}
type WeChatPayNotify struct {
ID string `json:"id"`
CreateTime string `json:"create_time"`
ResourceType string `json:"resource_type"`
EventType string `json:"event_type"`
Summary string `json:"summary"`
Resource struct {
OriginalType string `json:"original_type"`
Algorithm string `json:"algorithm"`
Ciphertext string `json:"ciphertext"`
AssociatedData string `json:"associated_data"`
Nonce string `json:"nonce"`
} `json:"resource"`
DecryptedData *WeChatPayNotifyData `json:"decrypted_data,omitempty"`
}
type WeChatPayNotifyData struct {
MchID string `json:"mchid"`
AppID string `json:"appid"`
OutTradeNo string `json:"out_trade_no"`
TransactionID string `json:"transaction_id"`
TradeType string `json:"trade_type"`
TradeState string `json:"trade_state"`
BankType string `json:"bank_type"`
SuccessTime string `json:"success_time"`
Payer struct {
OpenID string `json:"openid"`
} `json:"payer"`
Amount struct {
Total int `json:"total"`
PayerTotal int `json:"payer_total"`
Currency string `json:"currency"`
PayerCurrency string `json:"payer_currency"`
} `json:"amount"`
}
type WeChatRefundResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data map[string]interface{} `json:"data"`
}
type WeChatRefundNotify struct {
ID string `json:"id"`
CreateTime string `json:"create_time"`
ResourceType string `json:"resource_type"`
EventType string `json:"event_type"`
Summary string `json:"summary"`
Resource struct {
OriginalType string `json:"original_type"`
Algorithm string `json:"algorithm"`
Ciphertext string `json:"ciphertext"`
AssociatedData string `json:"associated_data"`
Nonce string `json:"nonce"`
} `json:"resource"`
DecryptedData *WeChatRefundNotifyData `json:"decrypted_data,omitempty"`
}
type WeChatRefundNotifyData struct {
MchID string `json:"mchid"`
OutTradeNo string `json:"out_trade_no"`
TransactionID string `json:"transaction_id"`
OutRefundNo string `json:"out_refund_no"`
RefundId string `json:"refund_id"`
RefundStatus string `json:"refund_status"`
SuccessTime string `json:"success_time"`
UserReceivedAccount string `json:"user_received_account"`
Amount struct {
Total int `json:"total"`
Refund int `json:"refund"`
PayerTotal int `json:"payer_total"`
PayerRefund int `json:"payer_refund"`
} `json:"amount"`
}