init
This commit is contained in:
226
server/internal/service/admin.go
Normal file
226
server/internal/service/admin.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"dianshang/internal/model"
|
||||
"dianshang/internal/repository"
|
||||
"dianshang/pkg/jwt"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AdminService 管理员服务
|
||||
type AdminService struct {
|
||||
adminRepo *repository.AdminRepository
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewAdminService 创建管理员服务
|
||||
func NewAdminService(db *gorm.DB) *AdminService {
|
||||
return &AdminService{
|
||||
adminRepo: repository.NewAdminRepository(db),
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// LoginRequest 管理员登录请求
|
||||
type LoginRequest struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
// CreateAdminRequest 创建管理员请求
|
||||
type CreateAdminRequest struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Nickname string `json:"nickname"`
|
||||
Email string `json:"email"`
|
||||
Phone string `json:"phone"`
|
||||
RoleID uint `json:"role_id" binding:"required"`
|
||||
}
|
||||
|
||||
// UpdateAdminRequest 更新管理员请求
|
||||
type UpdateAdminRequest struct {
|
||||
Nickname string `json:"nickname"`
|
||||
Email string `json:"email"`
|
||||
Phone string `json:"phone"`
|
||||
RoleID uint `json:"role_id"`
|
||||
Status *uint8 `json:"status"`
|
||||
}
|
||||
|
||||
// AdminLoginResponse 管理员登录响应
|
||||
type AdminLoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
AdminUser *model.AdminUser `json:"admin_user"`
|
||||
}
|
||||
|
||||
// Login 管理员登录
|
||||
func (s *AdminService) Login(req *LoginRequest) (*AdminLoginResponse, error) {
|
||||
// 查找管理员
|
||||
admin, err := s.adminRepo.GetByUsername(req.Username)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("用户名或密码错误")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查管理员状态
|
||||
if admin.Status == 0 {
|
||||
return nil, errors.New("账户已被禁用")
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(admin.Password), []byte(req.Password)); err != nil {
|
||||
return nil, errors.New("用户名或密码错误")
|
||||
}
|
||||
|
||||
// 生成JWT token
|
||||
tokenExpiry := 8 * 3600 // 8小时有效期
|
||||
token, err := jwt.GenerateToken(admin.ID, "admin", tokenExpiry)
|
||||
if err != nil {
|
||||
return nil, errors.New("生成token失败")
|
||||
}
|
||||
|
||||
// 更新最后登录时间
|
||||
now := time.Now()
|
||||
admin.LastLogin = &now
|
||||
s.adminRepo.Update(admin.ID, map[string]interface{}{
|
||||
"last_login": now,
|
||||
})
|
||||
|
||||
// 加载角色信息
|
||||
admin, _ = s.adminRepo.GetByIDWithRole(admin.ID)
|
||||
|
||||
return &AdminLoginResponse{
|
||||
Token: token,
|
||||
AdminUser: admin,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateAdmin 创建管理员
|
||||
func (s *AdminService) CreateAdmin(req *CreateAdminRequest) (*model.AdminUser, error) {
|
||||
// 检查用户名是否已存在
|
||||
if _, err := s.adminRepo.GetByUsername(req.Username); err == nil {
|
||||
return nil, errors.New("用户名已存在")
|
||||
}
|
||||
|
||||
// 加密密码
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, errors.New("密码加密失败")
|
||||
}
|
||||
|
||||
// 创建管理员
|
||||
admin := &model.AdminUser{
|
||||
Username: req.Username,
|
||||
Password: string(hashedPassword),
|
||||
Nickname: req.Nickname,
|
||||
Email: req.Email,
|
||||
Phone: req.Phone,
|
||||
RoleID: req.RoleID,
|
||||
Status: 1, // 默认启用
|
||||
}
|
||||
|
||||
if err := s.adminRepo.Create(admin); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 返回时不包含密码
|
||||
admin.Password = ""
|
||||
return admin, nil
|
||||
}
|
||||
|
||||
// GetAdminList 获取管理员列表
|
||||
func (s *AdminService) GetAdminList(page, pageSize int, keyword string) ([]model.AdminUser, int64, error) {
|
||||
return s.adminRepo.GetList(page, pageSize, keyword)
|
||||
}
|
||||
|
||||
// GetAdminByID 根据ID获取管理员
|
||||
func (s *AdminService) GetAdminByID(id uint) (*model.AdminUser, error) {
|
||||
return s.adminRepo.GetByIDWithRole(id)
|
||||
}
|
||||
|
||||
// UpdateAdmin 更新管理员
|
||||
func (s *AdminService) UpdateAdmin(id uint, req *UpdateAdminRequest) error {
|
||||
updates := make(map[string]interface{})
|
||||
|
||||
if req.Nickname != "" {
|
||||
updates["nickname"] = req.Nickname
|
||||
}
|
||||
if req.Email != "" {
|
||||
updates["email"] = req.Email
|
||||
}
|
||||
if req.Phone != "" {
|
||||
updates["phone"] = req.Phone
|
||||
}
|
||||
if req.RoleID != 0 {
|
||||
updates["role_id"] = req.RoleID
|
||||
}
|
||||
if req.Status != nil {
|
||||
updates["status"] = *req.Status
|
||||
}
|
||||
|
||||
return s.adminRepo.Update(id, updates)
|
||||
}
|
||||
|
||||
// DeleteAdmin 删除管理员
|
||||
func (s *AdminService) DeleteAdmin(id uint) error {
|
||||
return s.adminRepo.Delete(id)
|
||||
}
|
||||
|
||||
// ChangePassword 修改密码
|
||||
func (s *AdminService) ChangePassword(id uint, oldPassword, newPassword string) error {
|
||||
// 获取管理员信息
|
||||
admin, err := s.adminRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证旧密码
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(admin.Password), []byte(oldPassword)); err != nil {
|
||||
return errors.New("原密码错误")
|
||||
}
|
||||
|
||||
// 加密新密码
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return errors.New("密码加密失败")
|
||||
}
|
||||
|
||||
// 更新密码
|
||||
return s.adminRepo.Update(id, map[string]interface{}{
|
||||
"password": string(hashedPassword),
|
||||
})
|
||||
}
|
||||
|
||||
// GetProfile 获取管理员个人信息
|
||||
func (s *AdminService) GetProfile(id uint) (*model.AdminUser, error) {
|
||||
admin, err := s.adminRepo.GetByIDWithRole(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 不返回密码
|
||||
admin.Password = ""
|
||||
return admin, nil
|
||||
}
|
||||
|
||||
// UpdateProfile 更新管理员个人信息
|
||||
func (s *AdminService) UpdateProfile(id uint, nickname, email, phone string) error {
|
||||
updates := make(map[string]interface{})
|
||||
|
||||
if nickname != "" {
|
||||
updates["nickname"] = nickname
|
||||
}
|
||||
if email != "" {
|
||||
updates["email"] = email
|
||||
}
|
||||
if phone != "" {
|
||||
updates["phone"] = phone
|
||||
}
|
||||
|
||||
return s.adminRepo.Update(id, updates)
|
||||
}
|
||||
125
server/internal/service/aftersale.go
Normal file
125
server/internal/service/aftersale.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"dianshang/internal/model"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AfterSaleService 售后服务
|
||||
type AfterSaleService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewAfterSaleService 创建售后服务
|
||||
func NewAfterSaleService(db *gorm.DB) *AfterSaleService {
|
||||
return &AfterSaleService{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserAfterSales 获取用户售后列表
|
||||
func (s *AfterSaleService) GetUserAfterSales(userID uint, page, pageSize int, status int) ([]model.AfterSale, int64, error) {
|
||||
var afterSales []model.AfterSale
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&model.AfterSale{}).Where("user_id = ?", userID)
|
||||
|
||||
// 如果指定了状态,添加状态过滤
|
||||
if status > 0 {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询,预加载关联数据
|
||||
offset := (page - 1) * pageSize
|
||||
if err := query.Preload("Order").Preload("OrderItem").Preload("User").
|
||||
Order("created_at DESC").
|
||||
Offset(offset).Limit(pageSize).
|
||||
Find(&afterSales).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return afterSales, total, nil
|
||||
}
|
||||
|
||||
// GetAfterSaleDetail 获取售后详情
|
||||
func (s *AfterSaleService) GetAfterSaleDetail(userID, afterSaleID uint) (*model.AfterSale, error) {
|
||||
var afterSale model.AfterSale
|
||||
|
||||
if err := s.db.Where("id = ? AND user_id = ?", afterSaleID, userID).
|
||||
Preload("Order").Preload("OrderItem").Preload("User").
|
||||
First(&afterSale).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &afterSale, nil
|
||||
}
|
||||
|
||||
// CreateAfterSale 创建售后申请
|
||||
func (s *AfterSaleService) CreateAfterSale(userID uint, req *CreateAfterSaleRequest) (*model.AfterSale, error) {
|
||||
// 验证订单是否存在且属于该用户
|
||||
var order model.Order
|
||||
if err := s.db.Where("id = ? AND user_id = ?", req.OrderID, userID).First(&order).Error; err != nil {
|
||||
return nil, fmt.Errorf("订单不存在或无权限")
|
||||
}
|
||||
|
||||
// 验证订单项是否存在
|
||||
var orderItem model.OrderItem
|
||||
if err := s.db.Where("id = ? AND order_id = ?", req.OrderItemID, req.OrderID).First(&orderItem).Error; err != nil {
|
||||
return nil, fmt.Errorf("订单项不存在")
|
||||
}
|
||||
|
||||
// 创建售后记录
|
||||
afterSale := &model.AfterSale{
|
||||
OrderID: req.OrderID,
|
||||
OrderItemID: req.OrderItemID,
|
||||
UserID: userID,
|
||||
Type: req.Type,
|
||||
Reason: req.Reason,
|
||||
Description: req.Description,
|
||||
Images: req.Images,
|
||||
Status: 1, // 1待审核
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := s.db.Create(afterSale).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 预加载关联数据
|
||||
if err := s.db.Preload("Order").Preload("OrderItem").Preload("User").
|
||||
First(afterSale, afterSale.ID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return afterSale, nil
|
||||
}
|
||||
|
||||
// CreateAfterSaleRequest 创建售后申请请求
|
||||
type CreateAfterSaleRequest struct {
|
||||
OrderID uint `json:"order_id" binding:"required"`
|
||||
OrderItemID uint `json:"order_item_id" binding:"required"`
|
||||
Type int `json:"type" binding:"required"` // 1退货,2换货,3维修
|
||||
Reason string `json:"reason" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Images model.JSONSlice `json:"images"`
|
||||
}
|
||||
|
||||
// UpdateAfterSaleStatus 更新售后状态
|
||||
func (s *AfterSaleService) UpdateAfterSaleStatus(afterSaleID uint, status int, adminRemark string) error {
|
||||
return s.db.Model(&model.AfterSale{}).
|
||||
Where("id = ?", afterSaleID).
|
||||
Updates(map[string]interface{}{
|
||||
"status": status,
|
||||
"admin_remark": adminRemark,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
319
server/internal/service/banner.go
Normal file
319
server/internal/service/banner.go
Normal file
@@ -0,0 +1,319 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"dianshang/internal/model"
|
||||
"dianshang/internal/repository"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BannerService 轮播图服务
|
||||
type BannerService struct {
|
||||
bannerRepo *repository.BannerRepository
|
||||
}
|
||||
|
||||
// NewBannerService 创建轮播图服务
|
||||
func NewBannerService(bannerRepo *repository.BannerRepository) *BannerService {
|
||||
return &BannerService{
|
||||
bannerRepo: bannerRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// GetActiveBanners 获取有效的轮播图
|
||||
func (s *BannerService) GetActiveBanners() ([]model.Banner, error) {
|
||||
return s.bannerRepo.GetActiveBannersWithTimeRange()
|
||||
}
|
||||
|
||||
// GetBannerList 获取轮播图列表(分页)
|
||||
func (s *BannerService) GetBannerList(page, pageSize int, status *int) ([]model.Banner, int64, error) {
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 || pageSize > 100 {
|
||||
pageSize = 10
|
||||
}
|
||||
|
||||
return s.bannerRepo.GetBannerList(page, pageSize, status)
|
||||
}
|
||||
|
||||
// GetBannerByID 根据ID获取轮播图
|
||||
func (s *BannerService) GetBannerByID(id uint) (*model.Banner, error) {
|
||||
if id == 0 {
|
||||
return nil, errors.New("轮播图ID不能为空")
|
||||
}
|
||||
|
||||
return s.bannerRepo.GetBannerByID(id)
|
||||
}
|
||||
|
||||
// CreateBanner 创建轮播图
|
||||
func (s *BannerService) CreateBanner(banner *model.Banner) error {
|
||||
// 验证必填字段
|
||||
if err := s.validateBanner(banner); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果没有设置排序值,自动设置为最大值+10
|
||||
if banner.Sort == 0 {
|
||||
maxSort, err := s.bannerRepo.GetMaxSort()
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取最大排序值失败: %v", err)
|
||||
}
|
||||
banner.Sort = maxSort + 10
|
||||
}
|
||||
|
||||
// 设置默认状态
|
||||
if banner.Status == 0 {
|
||||
banner.Status = 1
|
||||
}
|
||||
|
||||
return s.bannerRepo.CreateBanner(banner)
|
||||
}
|
||||
|
||||
// UpdateBanner 更新轮播图
|
||||
func (s *BannerService) UpdateBanner(id uint, banner *model.Banner) error {
|
||||
if id == 0 {
|
||||
return errors.New("轮播图ID不能为空")
|
||||
}
|
||||
|
||||
// 检查轮播图是否存在
|
||||
exists, err := s.bannerRepo.CheckBannerExists(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("检查轮播图是否存在失败: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
return errors.New("轮播图不存在")
|
||||
}
|
||||
|
||||
// 验证必填字段
|
||||
if err := s.validateBanner(banner); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
banner.ID = id
|
||||
return s.bannerRepo.UpdateBanner(banner)
|
||||
}
|
||||
|
||||
// DeleteBanner 删除轮播图
|
||||
func (s *BannerService) DeleteBanner(id uint) error {
|
||||
if id == 0 {
|
||||
return errors.New("轮播图ID不能为空")
|
||||
}
|
||||
|
||||
// 检查轮播图是否存在
|
||||
exists, err := s.bannerRepo.CheckBannerExists(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("检查轮播图是否存在失败: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
return errors.New("轮播图不存在")
|
||||
}
|
||||
|
||||
return s.bannerRepo.DeleteBanner(id)
|
||||
}
|
||||
|
||||
// BatchDeleteBanners 批量删除轮播图
|
||||
func (s *BannerService) BatchDeleteBanners(ids []uint) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("轮播图ID列表不能为空")
|
||||
}
|
||||
|
||||
return s.bannerRepo.BatchDeleteBanners(ids)
|
||||
}
|
||||
|
||||
// UpdateBannerStatus 更新轮播图状态
|
||||
func (s *BannerService) UpdateBannerStatus(id uint, status int) error {
|
||||
if id == 0 {
|
||||
return errors.New("轮播图ID不能为空")
|
||||
}
|
||||
|
||||
if status < 0 || status > 1 {
|
||||
return errors.New("状态值无效,只能是0或1")
|
||||
}
|
||||
|
||||
// 检查轮播图是否存在
|
||||
exists, err := s.bannerRepo.CheckBannerExists(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("检查轮播图是否存在失败: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
return errors.New("轮播图不存在")
|
||||
}
|
||||
|
||||
return s.bannerRepo.UpdateBannerStatus(id, status)
|
||||
}
|
||||
|
||||
// BatchUpdateBannerStatus 批量更新轮播图状态
|
||||
func (s *BannerService) BatchUpdateBannerStatus(ids []uint, status int) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("轮播图ID列表不能为空")
|
||||
}
|
||||
|
||||
if status < 0 || status > 1 {
|
||||
return errors.New("状态值无效,只能是0或1")
|
||||
}
|
||||
|
||||
return s.bannerRepo.BatchUpdateBannerStatus(ids, status)
|
||||
}
|
||||
|
||||
// UpdateBannerSort 更新轮播图排序
|
||||
func (s *BannerService) UpdateBannerSort(id uint, sort int) error {
|
||||
if id == 0 {
|
||||
return errors.New("轮播图ID不能为空")
|
||||
}
|
||||
|
||||
if sort < 0 {
|
||||
return errors.New("排序值不能为负数")
|
||||
}
|
||||
|
||||
// 检查轮播图是否存在
|
||||
exists, err := s.bannerRepo.CheckBannerExists(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("检查轮播图是否存在失败: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
return errors.New("轮播图不存在")
|
||||
}
|
||||
|
||||
return s.bannerRepo.UpdateBannerSort(id, sort)
|
||||
}
|
||||
|
||||
// BatchUpdateBannerSort 批量更新轮播图排序
|
||||
func (s *BannerService) BatchUpdateBannerSort(sortData []map[string]interface{}) error {
|
||||
if len(sortData) == 0 {
|
||||
return errors.New("排序数据不能为空")
|
||||
}
|
||||
|
||||
// 验证排序数据
|
||||
for _, data := range sortData {
|
||||
id, ok := data["id"]
|
||||
if !ok {
|
||||
return errors.New("排序数据中缺少ID字段")
|
||||
}
|
||||
|
||||
sort, ok := data["sort"]
|
||||
if !ok {
|
||||
return errors.New("排序数据中缺少sort字段")
|
||||
}
|
||||
|
||||
// 类型检查
|
||||
if _, ok := id.(uint); !ok {
|
||||
if idFloat, ok := id.(float64); ok {
|
||||
data["id"] = uint(idFloat)
|
||||
} else {
|
||||
return errors.New("ID字段类型错误")
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := sort.(int); !ok {
|
||||
if sortFloat, ok := sort.(float64); ok {
|
||||
data["sort"] = int(sortFloat)
|
||||
} else {
|
||||
return errors.New("sort字段类型错误")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return s.bannerRepo.BatchUpdateBannerSort(sortData)
|
||||
}
|
||||
|
||||
// GetBannersByDateRange 根据日期范围获取轮播图
|
||||
func (s *BannerService) GetBannersByDateRange(startDate, endDate time.Time) ([]model.Banner, error) {
|
||||
if startDate.After(endDate) {
|
||||
return nil, errors.New("开始日期不能晚于结束日期")
|
||||
}
|
||||
|
||||
return s.bannerRepo.GetBannersByDateRange(startDate, endDate)
|
||||
}
|
||||
|
||||
// GetBannersByStatus 根据状态获取轮播图
|
||||
func (s *BannerService) GetBannersByStatus(status int) ([]model.Banner, error) {
|
||||
if status < 0 || status > 1 {
|
||||
return nil, errors.New("状态值无效,只能是0或1")
|
||||
}
|
||||
|
||||
return s.bannerRepo.GetBannersByStatus(status)
|
||||
}
|
||||
|
||||
// GetBannerStatistics 获取轮播图统计信息
|
||||
func (s *BannerService) GetBannerStatistics() (map[string]interface{}, error) {
|
||||
// 获取总数
|
||||
total, err := s.bannerRepo.GetBannerCount()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取轮播图总数失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取启用数量
|
||||
activeCount, err := s.bannerRepo.GetBannerCountByStatus(1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取启用轮播图数量失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取禁用数量
|
||||
inactiveCount, err := s.bannerRepo.GetBannerCountByStatus(0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取禁用轮播图数量失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取过期轮播图
|
||||
expiredBanners, err := s.bannerRepo.GetExpiredBanners()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取过期轮播图失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total": total,
|
||||
"active": activeCount,
|
||||
"inactive": inactiveCount,
|
||||
"expired": len(expiredBanners),
|
||||
"expired_list": expiredBanners,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CleanExpiredBanners 清理过期轮播图
|
||||
func (s *BannerService) CleanExpiredBanners() error {
|
||||
expiredBanners, err := s.bannerRepo.GetExpiredBanners()
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取过期轮播图失败: %v", err)
|
||||
}
|
||||
|
||||
if len(expiredBanners) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 将过期轮播图状态设置为禁用
|
||||
var ids []uint
|
||||
for _, banner := range expiredBanners {
|
||||
ids = append(ids, banner.ID)
|
||||
}
|
||||
|
||||
return s.bannerRepo.BatchUpdateBannerStatus(ids, 0)
|
||||
}
|
||||
|
||||
// validateBanner 验证轮播图数据
|
||||
func (s *BannerService) validateBanner(banner *model.Banner) error {
|
||||
if banner.Title == "" {
|
||||
return errors.New("轮播图标题不能为空")
|
||||
}
|
||||
|
||||
if banner.Image == "" {
|
||||
return errors.New("轮播图图片不能为空")
|
||||
}
|
||||
|
||||
if banner.LinkType < 1 || banner.LinkType > 4 {
|
||||
return errors.New("链接类型无效,只能是1-4")
|
||||
}
|
||||
|
||||
if banner.Sort < 0 {
|
||||
return errors.New("排序值不能为负数")
|
||||
}
|
||||
|
||||
// 验证时间范围
|
||||
if banner.StartTime != nil && banner.EndTime != nil {
|
||||
if banner.StartTime.After(*banner.EndTime) {
|
||||
return errors.New("开始时间不能晚于结束时间")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
661
server/internal/service/cart.go
Normal file
661
server/internal/service/cart.go
Normal file
@@ -0,0 +1,661 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"dianshang/internal/model"
|
||||
"dianshang/internal/repository"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CartService 购物车服务
|
||||
type CartService struct {
|
||||
orderRepo *repository.OrderRepository
|
||||
productRepo *repository.ProductRepository
|
||||
userRepo *repository.UserRepository
|
||||
}
|
||||
|
||||
// NewCartService 创建购物车服务
|
||||
func NewCartService(orderRepo *repository.OrderRepository, productRepo *repository.ProductRepository, userRepo *repository.UserRepository) *CartService {
|
||||
return &CartService{
|
||||
orderRepo: orderRepo,
|
||||
productRepo: productRepo,
|
||||
userRepo: userRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// GetCart 获取购物车
|
||||
func (s *CartService) GetCart(userID uint) ([]model.Cart, error) {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return nil, errors.New("用户不存在")
|
||||
}
|
||||
|
||||
return s.orderRepo.GetCart(userID)
|
||||
}
|
||||
|
||||
// AddToCart 添加到购物车
|
||||
func (s *CartService) AddToCart(userID, productID uint, skuID uint, quantity int) error {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 检查产品是否存在
|
||||
product, err := s.productRepo.GetByID(productID)
|
||||
if err != nil {
|
||||
return errors.New("产品不存在")
|
||||
}
|
||||
|
||||
// 检查产品状态
|
||||
if product.Status != 1 {
|
||||
return errors.New("产品已下架")
|
||||
}
|
||||
|
||||
// 检查库存
|
||||
if product.Stock < quantity {
|
||||
return errors.New("库存不足")
|
||||
}
|
||||
|
||||
// 检查购物车中是否已存在该商品(包括SKU)
|
||||
existingCart, err := s.orderRepo.GetCartItemBySKU(userID, productID, skuID)
|
||||
if err == nil && existingCart != nil {
|
||||
// 已存在,更新数量
|
||||
newQuantity := existingCart.Quantity + quantity
|
||||
if product.Stock < newQuantity {
|
||||
return errors.New("库存不足")
|
||||
}
|
||||
return s.orderRepo.UpdateCartItem(existingCart.ID, newQuantity)
|
||||
}
|
||||
|
||||
// 如果不是记录不存在的错误,返回错误
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
// 不存在,添加新项
|
||||
var skuPtr *uint
|
||||
if skuID != 0 {
|
||||
skuPtr = &skuID
|
||||
}
|
||||
cart := &model.Cart{
|
||||
UserID: userID,
|
||||
ProductID: productID,
|
||||
SKUID: skuPtr,
|
||||
Quantity: quantity,
|
||||
Selected: true,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return s.orderRepo.AddToCart(cart)
|
||||
}
|
||||
|
||||
// UpdateCartItem 更新购物车项
|
||||
func (s *CartService) UpdateCartItem(userID, productID uint, quantity int) error {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 获取该用户该产品的所有购物车项
|
||||
cartItems, err := s.orderRepo.GetCart(userID)
|
||||
if err != nil {
|
||||
return errors.New("获取购物车失败")
|
||||
}
|
||||
|
||||
// 查找匹配的购物车项
|
||||
var targetCartItem *model.Cart
|
||||
for _, item := range cartItems {
|
||||
if item.ProductID == productID {
|
||||
targetCartItem = &item
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if targetCartItem == nil {
|
||||
return errors.New("购物车项不存在")
|
||||
}
|
||||
|
||||
if quantity == 0 {
|
||||
// 数量为0,删除该项
|
||||
return s.orderRepo.RemoveFromCart(userID, productID)
|
||||
}
|
||||
|
||||
// 检查产品库存
|
||||
product, err := s.productRepo.GetByID(productID)
|
||||
if err != nil {
|
||||
return errors.New("产品不存在")
|
||||
}
|
||||
|
||||
if product.Stock < quantity {
|
||||
return errors.New("库存不足")
|
||||
}
|
||||
|
||||
return s.orderRepo.UpdateCartItem(targetCartItem.ID, quantity)
|
||||
}
|
||||
|
||||
// UpdateCartItemBySKU 基于SKU更新购物车项
|
||||
func (s *CartService) UpdateCartItemBySKU(userID, productID, skuID uint, quantity int) error {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 使用SKU查找购物车项
|
||||
cartItem, err := s.orderRepo.GetCartItemBySKU(userID, productID, skuID)
|
||||
if err != nil {
|
||||
return errors.New("购物车项不存在")
|
||||
}
|
||||
|
||||
if quantity == 0 {
|
||||
// 数量为0,删除该项
|
||||
return s.orderRepo.RemoveFromCartBySKU(userID, productID, skuID)
|
||||
}
|
||||
|
||||
// 检查产品库存
|
||||
product, err := s.productRepo.GetByID(productID)
|
||||
if err != nil {
|
||||
return errors.New("产品不存在")
|
||||
}
|
||||
|
||||
if product.Stock < quantity {
|
||||
return errors.New("库存不足")
|
||||
}
|
||||
|
||||
return s.orderRepo.UpdateCartItem(cartItem.ID, quantity)
|
||||
}
|
||||
|
||||
// RemoveFromCart 从购物车移除
|
||||
func (s *CartService) RemoveFromCart(userID, productID uint) error {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 获取该用户该产品的所有购物车项
|
||||
cartItems, err := s.orderRepo.GetCart(userID)
|
||||
if err != nil {
|
||||
return errors.New("获取购物车失败")
|
||||
}
|
||||
|
||||
// 查找匹配的购物车项
|
||||
var found bool
|
||||
for _, item := range cartItems {
|
||||
if item.ProductID == productID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return errors.New("购物车项不存在")
|
||||
}
|
||||
|
||||
return s.orderRepo.RemoveFromCart(userID, productID)
|
||||
}
|
||||
|
||||
// RemoveFromCartBySKU 基于SKU从购物车移除
|
||||
func (s *CartService) RemoveFromCartBySKU(userID, productID, skuID uint) error {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 使用SKU查找购物车项
|
||||
_, err = s.orderRepo.GetCartItemBySKU(userID, productID, skuID)
|
||||
if err != nil {
|
||||
return errors.New("购物车项不存在")
|
||||
}
|
||||
|
||||
return s.orderRepo.RemoveFromCartBySKU(userID, productID, skuID)
|
||||
}
|
||||
|
||||
// ClearCart 清空购物车
|
||||
func (s *CartService) ClearCart(userID uint) error {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
return s.orderRepo.ClearCart(userID)
|
||||
}
|
||||
|
||||
// GetCartCount 获取购物车商品数量
|
||||
func (s *CartService) GetCartCount(userID uint) (int, error) {
|
||||
cart, err := s.GetCart(userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var count int
|
||||
for _, item := range cart {
|
||||
count += int(item.Quantity)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetCartTotal 获取购物车总金额
|
||||
func (s *CartService) GetCartTotal(userID uint) (float64, error) {
|
||||
cart, err := s.GetCart(userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var total float64
|
||||
for _, item := range cart {
|
||||
if item.Product.ID != 0 {
|
||||
// 将价格从分转换为元
|
||||
total += (float64(item.Product.Price) / 100) * float64(item.Quantity)
|
||||
}
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// SelectCartItem 选择/取消选择购物车项
|
||||
func (s *CartService) SelectCartItem(userID, cartID uint, selected bool) error {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
return s.orderRepo.SelectCartItem(userID, cartID, selected)
|
||||
}
|
||||
|
||||
// SelectAllCartItems 全选/取消全选购物车
|
||||
func (s *CartService) SelectAllCartItems(userID uint, selected bool) error {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
return s.orderRepo.SelectAllCartItems(userID, selected)
|
||||
}
|
||||
|
||||
// BatchAddToCart 批量添加到购物车
|
||||
func (s *CartService) BatchAddToCart(userID uint, items []struct {
|
||||
ProductID uint `json:"product_id"`
|
||||
SKUID uint `json:"sku_id"`
|
||||
Quantity int `json:"quantity"`
|
||||
}) error {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 验证所有商品
|
||||
for _, item := range items {
|
||||
product, err := s.productRepo.GetByID(item.ProductID)
|
||||
if err != nil {
|
||||
return errors.New("商品不存在: " + err.Error())
|
||||
}
|
||||
|
||||
if product.Status != 1 {
|
||||
return errors.New("商品已下架")
|
||||
}
|
||||
|
||||
if product.Stock < item.Quantity {
|
||||
return errors.New("商品库存不足")
|
||||
}
|
||||
}
|
||||
|
||||
// 批量添加
|
||||
for _, item := range items {
|
||||
err := s.AddToCart(userID, item.ProductID, item.SKUID, item.Quantity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchRemoveFromCart 批量从购物车移除
|
||||
func (s *CartService) BatchRemoveFromCart(userID uint, cartIDs []uint) error {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
return s.orderRepo.BatchRemoveFromCart(userID, cartIDs)
|
||||
}
|
||||
|
||||
// BatchUpdateCartItems 批量更新购物车项
|
||||
func (s *CartService) BatchUpdateCartItems(userID uint, updates []struct {
|
||||
CartID uint `json:"cart_id"`
|
||||
Quantity int `json:"quantity"`
|
||||
}) error {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 获取购物车项并验证
|
||||
for _, update := range updates {
|
||||
cartItem, err := s.orderRepo.GetCartItem(userID, update.CartID)
|
||||
if err != nil {
|
||||
return errors.New("购物车项不存在")
|
||||
}
|
||||
|
||||
// 检查库存
|
||||
product, err := s.productRepo.GetByID(cartItem.ProductID)
|
||||
if err != nil {
|
||||
return errors.New("商品不存在")
|
||||
}
|
||||
|
||||
if product.Stock < update.Quantity {
|
||||
return errors.New("商品库存不足")
|
||||
}
|
||||
|
||||
// 更新数量
|
||||
if update.Quantity == 0 {
|
||||
err = s.orderRepo.RemoveCartItem(update.CartID)
|
||||
} else {
|
||||
err = s.orderRepo.UpdateCartItem(update.CartID, update.Quantity)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCartWithDetails 获取购物车详细信息(包含商品详情)
|
||||
func (s *CartService) GetCartWithDetails(userID uint) (map[string]interface{}, error) {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return nil, errors.New("用户不存在")
|
||||
}
|
||||
|
||||
cart, err := s.orderRepo.GetCart(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var validItems []model.Cart
|
||||
var invalidItems []model.Cart
|
||||
var totalAmount float64
|
||||
var totalQuantity int
|
||||
var selectedAmount float64
|
||||
var selectedQuantity int
|
||||
|
||||
for _, item := range cart {
|
||||
// 检查商品是否有效
|
||||
product, err := s.productRepo.GetByID(item.ProductID)
|
||||
if err != nil || product.Status != 1 {
|
||||
invalidItems = append(invalidItems, item)
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查库存
|
||||
if product.Stock < item.Quantity {
|
||||
item.Product = *product
|
||||
invalidItems = append(invalidItems, item)
|
||||
continue
|
||||
}
|
||||
|
||||
// 计算价格
|
||||
item.Product = *product
|
||||
itemPrice := float64(product.Price) / 100 * float64(item.Quantity)
|
||||
totalAmount += itemPrice
|
||||
totalQuantity += item.Quantity
|
||||
|
||||
if item.Selected {
|
||||
selectedAmount += itemPrice
|
||||
selectedQuantity += item.Quantity
|
||||
}
|
||||
|
||||
validItems = append(validItems, item)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"valid_items": validItems,
|
||||
"invalid_items": invalidItems,
|
||||
"total_amount": totalAmount,
|
||||
"total_quantity": totalQuantity,
|
||||
"selected_amount": selectedAmount,
|
||||
"selected_quantity": selectedQuantity,
|
||||
"valid_count": len(validItems),
|
||||
"invalid_count": len(invalidItems),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateCartItems 验证购物车商品有效性
|
||||
func (s *CartService) ValidateCartItems(userID uint) (map[string]interface{}, error) {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return nil, errors.New("用户不存在")
|
||||
}
|
||||
|
||||
cart, err := s.orderRepo.GetCart(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var validItems []uint
|
||||
var invalidItems []struct {
|
||||
CartID uint `json:"cart_id"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
for _, item := range cart {
|
||||
// 检查商品是否存在
|
||||
product, err := s.productRepo.GetByID(item.ProductID)
|
||||
if err != nil {
|
||||
invalidItems = append(invalidItems, struct {
|
||||
CartID uint `json:"cart_id"`
|
||||
Reason string `json:"reason"`
|
||||
}{
|
||||
CartID: item.ID,
|
||||
Reason: "商品不存在",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查商品状态
|
||||
if product.Status != 1 {
|
||||
invalidItems = append(invalidItems, struct {
|
||||
CartID uint `json:"cart_id"`
|
||||
Reason string `json:"reason"`
|
||||
}{
|
||||
CartID: item.ID,
|
||||
Reason: "商品已下架",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查库存
|
||||
if product.Stock < item.Quantity {
|
||||
invalidItems = append(invalidItems, struct {
|
||||
CartID uint `json:"cart_id"`
|
||||
Reason string `json:"reason"`
|
||||
}{
|
||||
CartID: item.ID,
|
||||
Reason: "库存不足",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
validItems = append(validItems, item.ID)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"valid_items": validItems,
|
||||
"invalid_items": invalidItems,
|
||||
"valid_count": len(validItems),
|
||||
"invalid_count": len(invalidItems),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CleanInvalidCartItems 清理无效的购物车项
|
||||
func (s *CartService) CleanInvalidCartItems(userID uint) error {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
validation, err := s.ValidateCartItems(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
invalidItems := validation["invalid_items"].([]struct {
|
||||
CartID uint `json:"cart_id"`
|
||||
Reason string `json:"reason"`
|
||||
})
|
||||
|
||||
if len(invalidItems) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var cartIDs []uint
|
||||
for _, item := range invalidItems {
|
||||
cartIDs = append(cartIDs, item.CartID)
|
||||
}
|
||||
|
||||
return s.orderRepo.BatchRemoveFromCart(userID, cartIDs)
|
||||
}
|
||||
|
||||
// GetCartSummary 获取购物车摘要信息
|
||||
func (s *CartService) GetCartSummary(userID uint) (map[string]interface{}, error) {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return nil, errors.New("用户不存在")
|
||||
}
|
||||
|
||||
details, err := s.GetCartWithDetails(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_items": details["valid_count"].(int) + details["invalid_count"].(int),
|
||||
"valid_items": details["valid_count"],
|
||||
"invalid_items": details["invalid_count"],
|
||||
"total_amount": details["total_amount"],
|
||||
"selected_amount": details["selected_amount"],
|
||||
"total_quantity": details["total_quantity"],
|
||||
"selected_quantity": details["selected_quantity"],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MergeCart 合并购物车(用于登录后合并游客购物车)
|
||||
func (s *CartService) MergeCart(userID uint, guestCartItems []struct {
|
||||
ProductID uint `json:"product_id"`
|
||||
SKUID uint `json:"sku_id"`
|
||||
Quantity int `json:"quantity"`
|
||||
}) error {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 获取用户现有购物车
|
||||
existingCart, err := s.orderRepo.GetCart(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建现有购物车的映射
|
||||
existingMap := make(map[string]*model.Cart)
|
||||
for i, item := range existingCart {
|
||||
key := fmt.Sprintf("%d_%d", item.ProductID, item.SKUID)
|
||||
if item.SKUID == nil {
|
||||
key = fmt.Sprintf("%d_0", item.ProductID)
|
||||
}
|
||||
existingMap[key] = &existingCart[i]
|
||||
}
|
||||
|
||||
// 合并游客购物车项
|
||||
for _, guestItem := range guestCartItems {
|
||||
key := fmt.Sprintf("%d_%d", guestItem.ProductID, guestItem.SKUID)
|
||||
if guestItem.SKUID == 0 {
|
||||
key = fmt.Sprintf("%d_0", guestItem.ProductID)
|
||||
}
|
||||
|
||||
if existingItem, exists := existingMap[key]; exists {
|
||||
// 已存在,更新数量
|
||||
newQuantity := existingItem.Quantity + guestItem.Quantity
|
||||
err = s.orderRepo.UpdateCartItem(existingItem.ID, newQuantity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// 不存在,添加新项
|
||||
err = s.AddToCart(userID, guestItem.ProductID, guestItem.SKUID, guestItem.Quantity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSelectedCartItems 获取选中的购物车项
|
||||
func (s *CartService) GetSelectedCartItems(userID uint) ([]model.Cart, error) {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return nil, errors.New("用户不存在")
|
||||
}
|
||||
|
||||
return s.orderRepo.GetSelectedCartItems(userID)
|
||||
}
|
||||
|
||||
// CalculateCartDiscount 计算购物车优惠(预留接口)
|
||||
func (s *CartService) CalculateCartDiscount(userID uint, couponID uint) (map[string]interface{}, error) {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return nil, errors.New("用户不存在")
|
||||
}
|
||||
|
||||
selectedItems, err := s.GetSelectedCartItems(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var originalAmount float64
|
||||
for _, item := range selectedItems {
|
||||
product, err := s.productRepo.GetByID(item.ProductID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
originalAmount += float64(product.Price) / 100 * float64(item.Quantity)
|
||||
}
|
||||
|
||||
// 这里可以添加优惠券计算逻辑
|
||||
discountAmount := 0.0
|
||||
finalAmount := originalAmount - discountAmount
|
||||
|
||||
return map[string]interface{}{
|
||||
"original_amount": originalAmount,
|
||||
"discount_amount": discountAmount,
|
||||
"final_amount": finalAmount,
|
||||
"coupon_id": couponID,
|
||||
}, nil
|
||||
}
|
||||
202
server/internal/service/comment.go
Normal file
202
server/internal/service/comment.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"dianshang/internal/model"
|
||||
"dianshang/internal/repository"
|
||||
)
|
||||
|
||||
type CommentService struct {
|
||||
commentRepo *repository.CommentRepository
|
||||
orderRepo *repository.OrderRepository
|
||||
productRepo *repository.ProductRepository
|
||||
}
|
||||
|
||||
func NewCommentService(commentRepo *repository.CommentRepository, orderRepo *repository.OrderRepository, productRepo *repository.ProductRepository) *CommentService {
|
||||
return &CommentService{
|
||||
commentRepo: commentRepo,
|
||||
orderRepo: orderRepo,
|
||||
productRepo: productRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateCommentRequest 创建评论请求
|
||||
type CreateCommentRequest struct {
|
||||
OrderItemID uint `json:"order_item_id" binding:"required"`
|
||||
Rating int `json:"rating" binding:"required,min=1,max=5"`
|
||||
Content string `json:"content"`
|
||||
Images []string `json:"images"`
|
||||
IsAnonymous bool `json:"is_anonymous"`
|
||||
}
|
||||
|
||||
// CreateComment 创建评论
|
||||
func (s *CommentService) CreateComment(userID uint, req *CreateCommentRequest) (*model.Comment, error) {
|
||||
// 1. 验证订单项是否存在且属于该用户
|
||||
orderItem, err := s.orderRepo.GetOrderItemByID(req.OrderItemID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("订单项不存在")
|
||||
}
|
||||
|
||||
// 获取订单信息验证用户权限
|
||||
order, err := s.orderRepo.GetByID(orderItem.OrderID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("订单不存在")
|
||||
}
|
||||
|
||||
if order.UserID != userID {
|
||||
return nil, fmt.Errorf("无权限评论此商品")
|
||||
}
|
||||
|
||||
// 2. 验证订单状态(只有已完成的订单才能评论)
|
||||
if order.Status != model.OrderStatusCompleted {
|
||||
return nil, fmt.Errorf("订单未完成,无法评论")
|
||||
}
|
||||
|
||||
// 3. 检查是否已经评论过
|
||||
if orderItem.IsCommented {
|
||||
return nil, fmt.Errorf("该商品已经评论过了")
|
||||
}
|
||||
|
||||
// 4. 处理图片数据
|
||||
var imagesJSON string
|
||||
if len(req.Images) > 0 {
|
||||
imagesBytes, _ := json.Marshal(req.Images)
|
||||
imagesJSON = string(imagesBytes)
|
||||
}
|
||||
|
||||
// 5. 创建评论
|
||||
comment := &model.Comment{
|
||||
UserID: userID,
|
||||
ProductID: orderItem.ProductID,
|
||||
OrderID: orderItem.OrderID,
|
||||
OrderItemID: req.OrderItemID,
|
||||
Rating: req.Rating,
|
||||
Content: req.Content,
|
||||
Images: imagesJSON,
|
||||
IsAnonymous: req.IsAnonymous,
|
||||
Status: 1,
|
||||
}
|
||||
|
||||
if err := s.commentRepo.Create(comment); err != nil {
|
||||
return nil, fmt.Errorf("创建评论失败: %v", err)
|
||||
}
|
||||
|
||||
// 6. 更新订单项评论状态
|
||||
orderItem.IsCommented = true
|
||||
if err := s.orderRepo.SaveOrderItem(orderItem); err != nil {
|
||||
return nil, fmt.Errorf("更新订单项状态失败: %v", err)
|
||||
}
|
||||
|
||||
// 7. 更新商品评论统计
|
||||
if err := s.commentRepo.UpdateProductStats(orderItem.ProductID); err != nil {
|
||||
return nil, fmt.Errorf("更新商品统计失败: %v", err)
|
||||
}
|
||||
|
||||
return comment, nil
|
||||
}
|
||||
|
||||
// GetProductComments 获取商品评论列表
|
||||
func (s *CommentService) GetProductComments(productID uint, page, pageSize int, rating int) ([]model.Comment, int64, error) {
|
||||
offset := (page - 1) * pageSize
|
||||
return s.commentRepo.GetByProductID(productID, offset, pageSize, rating)
|
||||
}
|
||||
|
||||
// GetUserComments 获取用户评论列表
|
||||
func (s *CommentService) GetUserComments(userID uint, page, pageSize int) ([]model.Comment, int64, error) {
|
||||
offset := (page - 1) * pageSize
|
||||
return s.commentRepo.GetByUserID(userID, offset, pageSize)
|
||||
}
|
||||
|
||||
// GetCommentStats 获取商品评论统计
|
||||
func (s *CommentService) GetCommentStats(productID uint) (*model.CommentStats, error) {
|
||||
return s.commentRepo.GetStats(productID)
|
||||
}
|
||||
|
||||
// GetCommentByID 获取评论详情
|
||||
func (s *CommentService) GetCommentByID(id uint) (*model.Comment, error) {
|
||||
return s.commentRepo.GetByID(id)
|
||||
}
|
||||
|
||||
// CreateReplyRequest 创建回复请求
|
||||
type CreateReplyRequest struct {
|
||||
CommentID uint `json:"comment_id" binding:"required"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
}
|
||||
|
||||
// CreateReply 创建评论回复
|
||||
func (s *CommentService) CreateReply(userID uint, req *CreateReplyRequest, isAdmin bool) (*model.CommentReply, error) {
|
||||
// 验证评论是否存在
|
||||
comment, err := s.commentRepo.GetByID(req.CommentID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("评论不存在")
|
||||
}
|
||||
|
||||
if comment.Status != 1 {
|
||||
return nil, fmt.Errorf("评论状态异常,无法回复")
|
||||
}
|
||||
|
||||
// 创建回复
|
||||
reply := &model.CommentReply{
|
||||
CommentID: req.CommentID,
|
||||
UserID: userID,
|
||||
Content: req.Content,
|
||||
IsAdmin: isAdmin,
|
||||
Status: 1,
|
||||
}
|
||||
|
||||
if err := s.commentRepo.CreateReply(reply); err != nil {
|
||||
return nil, fmt.Errorf("创建回复失败: %v", err)
|
||||
}
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// LikeComment 点赞评论
|
||||
func (s *CommentService) LikeComment(commentID, userID uint) error {
|
||||
return s.commentRepo.LikeComment(commentID, userID)
|
||||
}
|
||||
|
||||
// UnlikeComment 取消点赞评论
|
||||
func (s *CommentService) UnlikeComment(commentID, userID uint) error {
|
||||
return s.commentRepo.UnlikeComment(commentID, userID)
|
||||
}
|
||||
|
||||
// GetCommentList 获取评论列表(管理端)
|
||||
func (s *CommentService) GetCommentList(page, pageSize int, conditions map[string]interface{}) ([]model.Comment, int64, error) {
|
||||
offset := (page - 1) * pageSize
|
||||
return s.commentRepo.GetList(offset, pageSize, conditions)
|
||||
}
|
||||
|
||||
// UpdateCommentStatus 更新评论状态(管理端)
|
||||
func (s *CommentService) UpdateCommentStatus(id uint, status int) error {
|
||||
comment, err := s.commentRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("评论不存在")
|
||||
}
|
||||
|
||||
comment.Status = status
|
||||
if err := s.commentRepo.Update(comment); err != nil {
|
||||
return fmt.Errorf("更新评论状态失败: %v", err)
|
||||
}
|
||||
|
||||
// 如果是隐藏或删除评论,需要更新商品统计
|
||||
if status != 1 {
|
||||
if err := s.commentRepo.UpdateProductStats(comment.ProductID); err != nil {
|
||||
return fmt.Errorf("更新商品统计失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteComment 删除评论(管理端)
|
||||
func (s *CommentService) DeleteComment(id uint) error {
|
||||
return s.UpdateCommentStatus(id, 3)
|
||||
}
|
||||
|
||||
// GetUncommentedOrderItems 获取用户未评论的订单项
|
||||
func (s *CommentService) GetUncommentedOrderItems(userID uint) ([]model.OrderItem, error) {
|
||||
// 获取用户已完成但未评论的订单项
|
||||
return s.orderRepo.GetUncommentedOrderItems(userID)
|
||||
}
|
||||
226
server/internal/service/coupon.go
Normal file
226
server/internal/service/coupon.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"dianshang/internal/model"
|
||||
"dianshang/internal/repository"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CouponService 优惠券服务
|
||||
type CouponService struct {
|
||||
couponRepo *repository.CouponRepository
|
||||
}
|
||||
|
||||
// NewCouponService 创建优惠券服务
|
||||
func NewCouponService(couponRepo *repository.CouponRepository) *CouponService {
|
||||
return &CouponService{
|
||||
couponRepo: couponRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAvailableCoupons 获取可用优惠券列表
|
||||
func (s *CouponService) GetAvailableCoupons() ([]model.Coupon, error) {
|
||||
return s.couponRepo.GetAvailableCoupons()
|
||||
}
|
||||
|
||||
// GetAvailableCouponsWithUserStatus 获取可用优惠券列表(包含用户已领取状态)
|
||||
func (s *CouponService) GetAvailableCouponsWithUserStatus(userID uint) ([]map[string]interface{}, error) {
|
||||
// 获取所有可用优惠券
|
||||
coupons, err := s.couponRepo.GetAvailableCoupons()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result []map[string]interface{}
|
||||
|
||||
// 如果用户已登录,检查每个优惠券的领取状态
|
||||
for _, coupon := range coupons {
|
||||
couponData := map[string]interface{}{
|
||||
"id": coupon.ID,
|
||||
"name": coupon.Name,
|
||||
"type": coupon.Type,
|
||||
"value": coupon.Value,
|
||||
"min_amount": coupon.MinAmount,
|
||||
"description": coupon.Description,
|
||||
"start_time": coupon.StartTime,
|
||||
"end_time": coupon.EndTime,
|
||||
"total_count": coupon.TotalCount,
|
||||
"used_count": coupon.UsedCount,
|
||||
"is_received": false, // 默认未领取
|
||||
}
|
||||
|
||||
// 如果用户已登录,检查是否已领取
|
||||
if userID > 0 {
|
||||
exists, err := s.couponRepo.CheckUserCouponExists(userID, coupon.ID)
|
||||
if err == nil {
|
||||
couponData["is_received"] = exists
|
||||
}
|
||||
}
|
||||
|
||||
result = append(result, couponData)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetUserCoupons 获取用户优惠券
|
||||
func (s *CouponService) GetUserCoupons(userID uint, status int) ([]model.UserCoupon, error) {
|
||||
return s.couponRepo.GetUserCoupons(userID, status)
|
||||
}
|
||||
|
||||
// ReceiveCoupon 领取优惠券
|
||||
func (s *CouponService) ReceiveCoupon(userID, couponID uint) error {
|
||||
// 检查优惠券是否存在且有效
|
||||
coupon, err := s.couponRepo.GetByID(couponID)
|
||||
if err != nil {
|
||||
return errors.New("优惠券不存在")
|
||||
}
|
||||
|
||||
// 检查是否在有效期内
|
||||
now := time.Now()
|
||||
if now.Before(coupon.StartTime) || now.After(coupon.EndTime) {
|
||||
return errors.New("优惠券不在有效期内")
|
||||
}
|
||||
|
||||
// 检查是否还有库存
|
||||
if coupon.TotalCount > 0 && coupon.UsedCount >= coupon.TotalCount {
|
||||
return errors.New("优惠券已被领完")
|
||||
}
|
||||
|
||||
// 检查用户是否已经领取过
|
||||
exists, err := s.couponRepo.CheckUserCouponExists(userID, couponID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
return errors.New("您已经领取过该优惠券")
|
||||
}
|
||||
|
||||
// 创建用户优惠券记录
|
||||
userCoupon := &model.UserCoupon{
|
||||
UserID: userID,
|
||||
CouponID: couponID,
|
||||
Status: 0, // 未使用
|
||||
}
|
||||
|
||||
return s.couponRepo.CreateUserCoupon(userCoupon)
|
||||
}
|
||||
|
||||
// UseCoupon 使用优惠券
|
||||
func (s *CouponService) UseCoupon(userID, userCouponID, orderID uint) error {
|
||||
// 获取用户优惠券
|
||||
userCoupon, err := s.couponRepo.GetUserCouponByID(userCouponID)
|
||||
if err != nil {
|
||||
return errors.New("优惠券不存在")
|
||||
}
|
||||
|
||||
// 检查是否属于该用户
|
||||
if userCoupon.UserID != userID {
|
||||
return errors.New("无权使用该优惠券")
|
||||
}
|
||||
|
||||
// 检查是否已使用
|
||||
if userCoupon.Status != 0 {
|
||||
return errors.New("优惠券已使用或已过期")
|
||||
}
|
||||
|
||||
// 检查优惠券是否在有效期内
|
||||
now := time.Now()
|
||||
if now.Before(userCoupon.Coupon.StartTime) || now.After(userCoupon.Coupon.EndTime) {
|
||||
return errors.New("优惠券不在有效期内")
|
||||
}
|
||||
|
||||
// 更新优惠券状态为已使用
|
||||
return s.couponRepo.UseCoupon(userCouponID, orderID)
|
||||
}
|
||||
|
||||
// ValidateCoupon 验证优惠券是否可用
|
||||
func (s *CouponService) ValidateCoupon(userID, userCouponID uint, orderAmount float64) (*model.UserCoupon, float64, error) {
|
||||
// 获取用户优惠券
|
||||
userCoupon, err := s.couponRepo.GetUserCouponByID(userCouponID)
|
||||
if err != nil {
|
||||
return nil, 0, errors.New("优惠券不存在")
|
||||
}
|
||||
|
||||
// 检查是否属于该用户
|
||||
if userCoupon.UserID != userID {
|
||||
return nil, 0, errors.New("无权使用该优惠券")
|
||||
}
|
||||
|
||||
// 检查是否已使用
|
||||
if userCoupon.Status != 0 {
|
||||
return nil, 0, errors.New("优惠券已使用或已过期")
|
||||
}
|
||||
|
||||
// 检查优惠券是否在有效期内
|
||||
now := time.Now()
|
||||
if now.Before(userCoupon.Coupon.StartTime) || now.After(userCoupon.Coupon.EndTime) {
|
||||
return nil, 0, errors.New("优惠券不在有效期内")
|
||||
}
|
||||
|
||||
// 检查最低消费金额
|
||||
minAmount := float64(userCoupon.Coupon.MinAmount) / 100 // 分转元
|
||||
if orderAmount < minAmount {
|
||||
return nil, 0, errors.New(fmt.Sprintf("订单金额不满足优惠券使用条件,最低需要%.2f元", minAmount))
|
||||
}
|
||||
|
||||
// 计算优惠金额
|
||||
var discountAmount float64
|
||||
switch userCoupon.Coupon.Type {
|
||||
case 1: // 满减券
|
||||
discountAmount = float64(userCoupon.Coupon.Value) / 100 // 分转元
|
||||
case 2: // 折扣券
|
||||
discountRate := float64(userCoupon.Coupon.Value) / 100 // 85 -> 0.85
|
||||
discountAmount = orderAmount * (1 - discountRate)
|
||||
case 3: // 免邮券
|
||||
discountAmount = 0 // 免邮券的优惠金额在运费中体现
|
||||
default:
|
||||
return nil, 0, errors.New("不支持的优惠券类型")
|
||||
}
|
||||
|
||||
// 确保优惠金额不超过订单金额
|
||||
if discountAmount > orderAmount {
|
||||
discountAmount = orderAmount
|
||||
}
|
||||
|
||||
return userCoupon, discountAmount, nil
|
||||
}
|
||||
|
||||
// GetAvailableCouponsForOrder 获取订单可用的优惠券
|
||||
func (s *CouponService) GetAvailableCouponsForOrder(userID uint, orderAmount float64) ([]model.UserCoupon, error) {
|
||||
// 获取用户未使用的优惠券
|
||||
userCoupons, err := s.couponRepo.GetUserCoupons(userID, 1) // 1表示未使用(API状态值)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var availableCoupons []model.UserCoupon
|
||||
now := time.Now()
|
||||
|
||||
for _, userCoupon := range userCoupons {
|
||||
// 严格检查优惠券状态:必须是未使用状态(0)且没有关联订单
|
||||
if userCoupon.Status != 0 || userCoupon.OrderID != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查是否在有效期内
|
||||
if now.Before(userCoupon.Coupon.StartTime) || now.After(userCoupon.Coupon.EndTime) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查优惠券模板是否可用
|
||||
if userCoupon.Coupon.Status != 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查最低消费金额
|
||||
minAmount := float64(userCoupon.Coupon.MinAmount) / 100 // 分转元
|
||||
if orderAmount >= minAmount {
|
||||
availableCoupons = append(availableCoupons, userCoupon)
|
||||
}
|
||||
}
|
||||
|
||||
return availableCoupons, nil
|
||||
}
|
||||
405
server/internal/service/log.go
Normal file
405
server/internal/service/log.go
Normal file
@@ -0,0 +1,405 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"dianshang/internal/model"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// LogService 日志服务
|
||||
type LogService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewLogService 创建日志服务
|
||||
func NewLogService(db *gorm.DB) *LogService {
|
||||
return &LogService{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateLoginLog 创建登录日志
|
||||
func (s *LogService) CreateLoginLog(userID uint, ip, userAgent string, status int, remark string) error {
|
||||
log := &model.UserLoginLog{
|
||||
UserID: userID,
|
||||
LoginIP: ip,
|
||||
UserAgent: userAgent,
|
||||
LoginTime: time.Now(),
|
||||
Status: status,
|
||||
Remark: remark,
|
||||
}
|
||||
return s.db.Create(log).Error
|
||||
}
|
||||
|
||||
// GetUserLoginLogs 获取用户登录日志
|
||||
func (s *LogService) GetUserLoginLogs(userID uint, page, pageSize int) ([]model.UserLoginLog, map[string]interface{}, error) {
|
||||
var logs []model.UserLoginLog
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&model.UserLoginLog{}).Where("user_id = ?", userID)
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Offset(offset).Limit(pageSize).Order("login_time DESC").Find(&logs).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 构建分页信息
|
||||
pagination := map[string]interface{}{
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
|
||||
}
|
||||
|
||||
return logs, pagination, nil
|
||||
}
|
||||
|
||||
// GetLoginLogList 获取登录日志列表(管理后台)
|
||||
func (s *LogService) GetLoginLogList(page, pageSize int, conditions map[string]interface{}) ([]model.UserLoginLog, map[string]interface{}, error) {
|
||||
var logs []model.UserLoginLog
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&model.UserLoginLog{}).Preload("User")
|
||||
|
||||
// 应用查询条件
|
||||
if userID, ok := conditions["user_id"]; ok && userID != "" {
|
||||
query = query.Where("user_id = ?", userID)
|
||||
}
|
||||
|
||||
if ip, ok := conditions["ip"]; ok && ip != "" {
|
||||
query = query.Where("login_ip LIKE ?", "%"+ip.(string)+"%")
|
||||
}
|
||||
|
||||
if status, ok := conditions["status"]; ok && status != "" {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
|
||||
if startDate, ok := conditions["start_date"]; ok && startDate != "" {
|
||||
query = query.Where("login_time >= ?", startDate)
|
||||
}
|
||||
|
||||
if endDate, ok := conditions["end_date"]; ok && endDate != "" {
|
||||
query = query.Where("login_time <= ?", endDate)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Offset(offset).Limit(pageSize).Order("login_time DESC").Find(&logs).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 构建分页信息
|
||||
pagination := map[string]interface{}{
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
|
||||
}
|
||||
|
||||
return logs, pagination, nil
|
||||
}
|
||||
|
||||
// CreateOperationLog 创建操作日志
|
||||
func (s *LogService) CreateOperationLog(userID uint, module, action, description, ip, userAgent, requestData string) error {
|
||||
log := &model.UserOperationLog{
|
||||
UserID: userID,
|
||||
Module: module,
|
||||
Action: action,
|
||||
Description: description,
|
||||
IP: ip,
|
||||
UserAgent: userAgent,
|
||||
RequestData: requestData,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
return s.db.Create(log).Error
|
||||
}
|
||||
|
||||
// GetUserOperationLogs 获取用户操作日志
|
||||
func (s *LogService) GetUserOperationLogs(userID uint, page, pageSize int) ([]model.UserOperationLog, map[string]interface{}, error) {
|
||||
var logs []model.UserOperationLog
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&model.UserOperationLog{}).Where("user_id = ?", userID)
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&logs).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 构建分页信息
|
||||
pagination := map[string]interface{}{
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
|
||||
}
|
||||
|
||||
return logs, pagination, nil
|
||||
}
|
||||
|
||||
// GetOperationLogList 获取操作日志列表(管理后台)
|
||||
func (s *LogService) GetOperationLogList(page, pageSize int, conditions map[string]interface{}) ([]model.UserOperationLog, map[string]interface{}, error) {
|
||||
var logs []model.UserOperationLog
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&model.UserOperationLog{}).Preload("User")
|
||||
|
||||
// 应用查询条件
|
||||
if userID, ok := conditions["user_id"]; ok && userID != "" {
|
||||
query = query.Where("user_id = ?", userID)
|
||||
}
|
||||
|
||||
if module, ok := conditions["module"]; ok && module != "" {
|
||||
query = query.Where("module = ?", module)
|
||||
}
|
||||
|
||||
if action, ok := conditions["action"]; ok && action != "" {
|
||||
query = query.Where("action = ?", action)
|
||||
}
|
||||
|
||||
if ip, ok := conditions["ip"]; ok && ip != "" {
|
||||
query = query.Where("ip LIKE ?", "%"+ip.(string)+"%")
|
||||
}
|
||||
|
||||
if startDate, ok := conditions["start_date"]; ok && startDate != "" {
|
||||
query = query.Where("created_at >= ?", startDate)
|
||||
}
|
||||
|
||||
if endDate, ok := conditions["end_date"]; ok && endDate != "" {
|
||||
query = query.Where("created_at <= ?", endDate)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&logs).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 构建分页信息
|
||||
pagination := map[string]interface{}{
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
|
||||
}
|
||||
|
||||
return logs, pagination, nil
|
||||
}
|
||||
|
||||
// GetLoginStatistics 获取登录统计
|
||||
func (s *LogService) GetLoginStatistics(startDate, endDate string) (map[string]interface{}, error) {
|
||||
result := make(map[string]interface{})
|
||||
|
||||
// 总登录次数
|
||||
var totalLogins int64
|
||||
query := s.db.Model(&model.UserLoginLog{})
|
||||
if startDate != "" && endDate != "" {
|
||||
query = query.Where("DATE(login_time) BETWEEN ? AND ?", startDate, endDate)
|
||||
}
|
||||
query.Count(&totalLogins)
|
||||
result["total_logins"] = totalLogins
|
||||
|
||||
// 成功登录次数
|
||||
var successLogins int64
|
||||
query = s.db.Model(&model.UserLoginLog{}).Where("status = 1")
|
||||
if startDate != "" && endDate != "" {
|
||||
query = query.Where("DATE(login_time) BETWEEN ? AND ?", startDate, endDate)
|
||||
}
|
||||
query.Count(&successLogins)
|
||||
result["success_logins"] = successLogins
|
||||
|
||||
// 失败登录次数
|
||||
var failedLogins int64
|
||||
query = s.db.Model(&model.UserLoginLog{}).Where("status = 0")
|
||||
if startDate != "" && endDate != "" {
|
||||
query = query.Where("DATE(login_time) BETWEEN ? AND ?", startDate, endDate)
|
||||
}
|
||||
query.Count(&failedLogins)
|
||||
result["failed_logins"] = failedLogins
|
||||
|
||||
// 独立用户数
|
||||
var uniqueUsers int64
|
||||
query = s.db.Model(&model.UserLoginLog{}).Distinct("user_id")
|
||||
if startDate != "" && endDate != "" {
|
||||
query = query.Where("DATE(login_time) BETWEEN ? AND ?", startDate, endDate)
|
||||
}
|
||||
query.Count(&uniqueUsers)
|
||||
result["unique_users"] = uniqueUsers
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetOperationStatistics 获取操作统计
|
||||
func (s *LogService) GetOperationStatistics(startDate, endDate string) (map[string]interface{}, error) {
|
||||
result := make(map[string]interface{})
|
||||
|
||||
// 总操作次数
|
||||
var totalOperations int64
|
||||
query := s.db.Model(&model.UserOperationLog{})
|
||||
if startDate != "" && endDate != "" {
|
||||
query = query.Where("DATE(created_at) BETWEEN ? AND ?", startDate, endDate)
|
||||
}
|
||||
query.Count(&totalOperations)
|
||||
result["total_operations"] = totalOperations
|
||||
|
||||
// 按模块统计
|
||||
var moduleStats []struct {
|
||||
Module string `json:"module"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
query = s.db.Model(&model.UserOperationLog{}).Select("module, COUNT(*) as count").Group("module")
|
||||
if startDate != "" && endDate != "" {
|
||||
query = query.Where("DATE(created_at) BETWEEN ? AND ?", startDate, endDate)
|
||||
}
|
||||
query.Scan(&moduleStats)
|
||||
result["module_stats"] = moduleStats
|
||||
|
||||
// 按操作统计
|
||||
var actionStats []struct {
|
||||
Action string `json:"action"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
query = s.db.Model(&model.UserOperationLog{}).Select("action, COUNT(*) as count").Group("action")
|
||||
if startDate != "" && endDate != "" {
|
||||
query = query.Where("DATE(created_at) BETWEEN ? AND ?", startDate, endDate)
|
||||
}
|
||||
query.Scan(&actionStats)
|
||||
result["action_stats"] = actionStats
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CleanOldLogs 清理旧日志
|
||||
func (s *LogService) CleanOldLogs(days int) error {
|
||||
cutoffDate := time.Now().AddDate(0, 0, -days)
|
||||
|
||||
// 清理登录日志
|
||||
if err := s.db.Where("login_time < ?", cutoffDate).Delete(&model.UserLoginLog{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 清理操作日志
|
||||
if err := s.db.Where("created_at < ?", cutoffDate).Delete(&model.UserOperationLog{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLoginLogByID 根据ID获取登录日志
|
||||
func (s *LogService) GetLoginLogByID(id uint) (*model.UserLoginLog, error) {
|
||||
var log model.UserLoginLog
|
||||
err := s.db.Preload("User").First(&log, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &log, nil
|
||||
}
|
||||
|
||||
// GetTodayLoginCount 获取今日登录次数
|
||||
func (s *LogService) GetTodayLoginCount() (int64, error) {
|
||||
var count int64
|
||||
today := time.Now().Format("2006-01-02")
|
||||
err := s.db.Model(&model.UserLoginLog{}).Where("DATE(login_time) = ?", today).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetTodayOperationCount 获取今日操作次数
|
||||
func (s *LogService) GetTodayOperationCount() (int64, error) {
|
||||
var count int64
|
||||
today := time.Now().Format("2006-01-02")
|
||||
err := s.db.Model(&model.UserOperationLog{}).Where("DATE(created_at) = ?", today).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetOnlineUserCount 获取在线用户数(最近30分钟有登录记录的用户)
|
||||
func (s *LogService) GetOnlineUserCount() (int64, error) {
|
||||
var count int64
|
||||
thirtyMinutesAgo := time.Now().Add(-30 * time.Minute)
|
||||
err := s.db.Model(&model.UserLoginLog{}).
|
||||
Where("login_time >= ? AND status = 1", thirtyMinutesAgo).
|
||||
Distinct("user_id").
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetLoginTrend 获取登录趋势
|
||||
func (s *LogService) GetLoginTrend(days int) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
for i := days - 1; i >= 0; i-- {
|
||||
date := time.Now().AddDate(0, 0, -i).Format("2006-01-02")
|
||||
var count int64
|
||||
s.db.Model(&model.UserLoginLog{}).
|
||||
Where("DATE(login_time) = ?", date).
|
||||
Count(&count)
|
||||
|
||||
results = append(results, map[string]interface{}{
|
||||
"date": date,
|
||||
"count": count,
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetOperationTrend 获取操作趋势
|
||||
func (s *LogService) GetOperationTrend(days int) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
for i := days - 1; i >= 0; i-- {
|
||||
date := time.Now().AddDate(0, 0, -i).Format("2006-01-02")
|
||||
var count int64
|
||||
s.db.Model(&model.UserOperationLog{}).
|
||||
Where("DATE(created_at) = ?", date).
|
||||
Count(&count)
|
||||
|
||||
results = append(results, map[string]interface{}{
|
||||
"date": date,
|
||||
"count": count,
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetOperationLogByID 根据ID获取操作日志
|
||||
func (s *LogService) GetOperationLogByID(id uint) (*model.UserOperationLog, error) {
|
||||
var log model.UserOperationLog
|
||||
err := s.db.Preload("User").First(&log, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &log, nil
|
||||
}
|
||||
1581
server/internal/service/order.go
Normal file
1581
server/internal/service/order.go
Normal file
File diff suppressed because it is too large
Load Diff
350
server/internal/service/points.go
Normal file
350
server/internal/service/points.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"dianshang/internal/model"
|
||||
"dianshang/internal/repository"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// PointsService 积分服务
|
||||
type PointsService struct {
|
||||
pointsRepo *repository.PointsRepository
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewPointsService 创建积分服务
|
||||
func NewPointsService(pointsRepo *repository.PointsRepository, db *gorm.DB) *PointsService {
|
||||
return &PointsService{
|
||||
pointsRepo: pointsRepo,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserPoints 获取用户积分
|
||||
func (s *PointsService) GetUserPoints(userID uint) (int, error) {
|
||||
return s.pointsRepo.GetUserPoints(userID)
|
||||
}
|
||||
|
||||
// AddPoints 增加用户积分
|
||||
func (s *PointsService) AddPoints(userID uint, amount int, description string, orderID *uint, orderNo, productName string) error {
|
||||
if amount <= 0 {
|
||||
return errors.New("积分数量必须大于0")
|
||||
}
|
||||
|
||||
// 开启事务
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
// 获取当前积分
|
||||
currentPoints, err := s.pointsRepo.GetUserPoints(userID)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新用户积分
|
||||
newPoints := currentPoints + amount
|
||||
err = s.pointsRepo.UpdateUserPoints(userID, newPoints)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建积分历史记录
|
||||
history := &model.PointsHistory{
|
||||
UserID: userID,
|
||||
Type: 1, // 获得
|
||||
Points: amount,
|
||||
Description: description,
|
||||
OrderID: orderID,
|
||||
OrderNo: orderNo,
|
||||
ProductName: productName,
|
||||
}
|
||||
|
||||
err = s.pointsRepo.CreatePointsHistory(history)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
// DeductPoints 扣减用户积分
|
||||
func (s *PointsService) DeductPoints(userID uint, amount int, description string, orderID *uint, orderNo, productName string) error {
|
||||
if amount <= 0 {
|
||||
return errors.New("积分数量必须大于0")
|
||||
}
|
||||
|
||||
// 开启事务
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
// 获取当前积分
|
||||
currentPoints, err := s.pointsRepo.GetUserPoints(userID)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查积分是否足够
|
||||
if currentPoints < amount {
|
||||
tx.Rollback()
|
||||
return errors.New("积分不足")
|
||||
}
|
||||
|
||||
// 更新用户积分
|
||||
newPoints := currentPoints - amount
|
||||
err = s.pointsRepo.UpdateUserPoints(userID, newPoints)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建积分历史记录
|
||||
history := &model.PointsHistory{
|
||||
UserID: userID,
|
||||
Type: 2, // 消费
|
||||
Points: amount,
|
||||
Description: description,
|
||||
OrderID: orderID,
|
||||
OrderNo: orderNo,
|
||||
ProductName: productName,
|
||||
}
|
||||
|
||||
err = s.pointsRepo.CreatePointsHistory(history)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
// GetPointsHistory 获取积分历史记录
|
||||
func (s *PointsService) GetPointsHistory(userID uint, page, pageSize int) ([]model.PointsHistory, map[string]interface{}, error) {
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
histories, total, err := s.pointsRepo.GetPointsHistory(userID, page, pageSize)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 构建分页信息
|
||||
totalPages := (int(total) + pageSize - 1) / pageSize
|
||||
pagination := map[string]interface{}{
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
"total_pages": totalPages,
|
||||
}
|
||||
|
||||
return histories, pagination, nil
|
||||
}
|
||||
|
||||
// GetPointsRules 获取积分规则列表
|
||||
func (s *PointsService) GetPointsRules() ([]model.PointsRule, error) {
|
||||
return s.pointsRepo.GetPointsRules()
|
||||
}
|
||||
|
||||
// GetPointsExchangeList 获取积分兑换商品列表
|
||||
func (s *PointsService) GetPointsExchangeList() ([]model.PointsExchange, error) {
|
||||
return s.pointsRepo.GetPointsExchangeList()
|
||||
}
|
||||
|
||||
// ExchangePoints 积分兑换
|
||||
func (s *PointsService) ExchangePoints(userID, exchangeID uint) error {
|
||||
// 开启事务
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
// 获取兑换商品信息
|
||||
exchange, err := s.pointsRepo.GetPointsExchangeByID(exchangeID)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return errors.New("兑换商品不存在")
|
||||
}
|
||||
|
||||
// 检查库存
|
||||
if exchange.Stock > 0 && exchange.ExchangeCount >= exchange.Stock {
|
||||
tx.Rollback()
|
||||
return errors.New("商品库存不足")
|
||||
}
|
||||
|
||||
// 获取用户当前积分
|
||||
currentPoints, err := s.pointsRepo.GetUserPoints(userID)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查积分是否足够
|
||||
if currentPoints < exchange.Points {
|
||||
tx.Rollback()
|
||||
return errors.New("积分不足")
|
||||
}
|
||||
|
||||
// 扣减积分
|
||||
description := fmt.Sprintf("兑换商品:%s", exchange.Name)
|
||||
err = s.DeductPoints(userID, exchange.Points, description, nil, "", exchange.Name)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建兑换记录
|
||||
record := &model.PointsExchangeRecord{
|
||||
UserID: userID,
|
||||
PointsExchangeID: exchangeID,
|
||||
Points: exchange.Points,
|
||||
Status: 1, // 已兑换
|
||||
}
|
||||
|
||||
err = s.pointsRepo.CreatePointsExchangeRecord(record)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新兑换次数
|
||||
err = s.pointsRepo.UpdatePointsExchangeCount(exchangeID)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
// GetUserExchangeRecords 获取用户兑换记录
|
||||
func (s *PointsService) GetUserExchangeRecords(userID uint, page, pageSize int) ([]model.PointsExchangeRecord, map[string]interface{}, error) {
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
records, total, err := s.pointsRepo.GetUserExchangeRecords(userID, page, pageSize)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 构建分页信息
|
||||
totalPages := (int(total) + pageSize - 1) / pageSize
|
||||
pagination := map[string]interface{}{
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
"total_pages": totalPages,
|
||||
}
|
||||
|
||||
return records, pagination, nil
|
||||
}
|
||||
|
||||
// GetPointsOverview 获取积分概览
|
||||
func (s *PointsService) GetPointsOverview(userID uint) (map[string]interface{}, error) {
|
||||
// 获取用户当前积分
|
||||
currentPoints, err := s.pointsRepo.GetUserPoints(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取积分历史统计
|
||||
var totalEarned, totalSpent, thisMonthEarned, thisMonthSpent int64
|
||||
|
||||
// 统计总获得积分
|
||||
s.db.Model(&model.PointsHistory{}).
|
||||
Where("user_id = ? AND type = ?", userID, 1).
|
||||
Select("COALESCE(SUM(points), 0)").
|
||||
Scan(&totalEarned)
|
||||
|
||||
// 统计总消费积分
|
||||
s.db.Model(&model.PointsHistory{}).
|
||||
Where("user_id = ? AND type = ?", userID, 2).
|
||||
Select("COALESCE(SUM(points), 0)").
|
||||
Scan(&totalSpent)
|
||||
|
||||
// 获取本月的开始时间和结束时间
|
||||
now := time.Now()
|
||||
firstDayOfMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
|
||||
firstDayOfNextMonth := firstDayOfMonth.AddDate(0, 1, 0)
|
||||
|
||||
// 统计本月获得积分
|
||||
s.db.Model(&model.PointsHistory{}).
|
||||
Where("user_id = ? AND type = ? AND created_at >= ? AND created_at < ?",
|
||||
userID, 1, firstDayOfMonth, firstDayOfNextMonth).
|
||||
Select("COALESCE(SUM(points), 0)").
|
||||
Scan(&thisMonthEarned)
|
||||
|
||||
// 统计本月消费积分
|
||||
s.db.Model(&model.PointsHistory{}).
|
||||
Where("user_id = ? AND type = ? AND created_at >= ? AND created_at < ?",
|
||||
userID, 2, firstDayOfMonth, firstDayOfNextMonth).
|
||||
Select("COALESCE(SUM(points), 0)").
|
||||
Scan(&thisMonthSpent)
|
||||
|
||||
overview := map[string]interface{}{
|
||||
"total_points": currentPoints,
|
||||
"available_points": currentPoints,
|
||||
"frozen_points": 0,
|
||||
"total_earned": totalEarned,
|
||||
"total_spent": totalSpent,
|
||||
"this_month_earned": thisMonthEarned,
|
||||
"this_month_spent": thisMonthSpent,
|
||||
}
|
||||
|
||||
return overview, nil
|
||||
}
|
||||
|
||||
// CheckAndGiveDailyLoginPoints 检查并给予每日首次登录积分
|
||||
func (s *PointsService) CheckAndGiveDailyLoginPoints(userID uint) (bool, error) {
|
||||
// 获取今天的开始时间(00:00:00)
|
||||
now := time.Now()
|
||||
today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
tomorrow := today.Add(24 * time.Hour)
|
||||
|
||||
// 检查今天是否已经有登录积分记录
|
||||
var count int64
|
||||
err := s.db.Model(&model.PointsHistory{}).
|
||||
Where("user_id = ? AND description = ? AND created_at >= ? AND created_at < ?",
|
||||
userID, "每日首次登录", today, tomorrow).
|
||||
Count(&count).Error
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("检查每日登录记录失败: %v", err)
|
||||
}
|
||||
|
||||
// 如果今天已经有登录积分记录,则不再给予
|
||||
if count > 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 给予每日登录积分(1积分)
|
||||
err = s.AddPoints(userID, 1, "每日首次登录", nil, "", "")
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("给予每日登录积分失败: %v", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
869
server/internal/service/product.go
Normal file
869
server/internal/service/product.go
Normal file
@@ -0,0 +1,869 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"dianshang/internal/model"
|
||||
"dianshang/internal/repository"
|
||||
"dianshang/pkg/utils"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ProductService 产品服务
|
||||
type ProductService struct {
|
||||
productRepo *repository.ProductRepository
|
||||
userRepo *repository.UserRepository
|
||||
}
|
||||
|
||||
// NewProductService 创建产品服务
|
||||
func NewProductService(productRepo *repository.ProductRepository, userRepo *repository.UserRepository) *ProductService {
|
||||
return &ProductService{
|
||||
productRepo: productRepo,
|
||||
userRepo: userRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// GetProductList 获取产品列表(前端用户)
|
||||
func (s *ProductService) GetProductList(page, pageSize int, categoryID uint, keyword string, minPrice, maxPrice float64, sort, sortType string) ([]model.Product, *utils.Pagination, error) {
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
conditions := make(map[string]interface{})
|
||||
|
||||
if categoryID > 0 {
|
||||
conditions["category_id"] = categoryID
|
||||
}
|
||||
if keyword != "" {
|
||||
conditions["keyword"] = keyword
|
||||
}
|
||||
if minPrice > 0 {
|
||||
conditions["min_price"] = minPrice
|
||||
}
|
||||
if maxPrice > 0 {
|
||||
conditions["max_price"] = maxPrice
|
||||
}
|
||||
if sort != "" {
|
||||
conditions["sort"] = sort
|
||||
}
|
||||
if sortType != "" {
|
||||
conditions["sort_type"] = sortType
|
||||
}
|
||||
|
||||
products, total, err := s.productRepo.GetList(offset, pageSize, conditions)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pagination := utils.NewPagination(page, pageSize)
|
||||
pagination.Total = total
|
||||
return products, pagination, nil
|
||||
}
|
||||
|
||||
// GetProductListForAdmin 获取产品列表(管理系统)
|
||||
func (s *ProductService) GetProductListForAdmin(page, pageSize int, categoryID uint, keyword string, minPrice, maxPrice float64, sort, sortType, status, isHot, isNew, isRecommend string) ([]model.Product, *utils.Pagination, error) {
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
conditions := make(map[string]interface{})
|
||||
|
||||
if categoryID > 0 {
|
||||
conditions["category_id"] = categoryID
|
||||
}
|
||||
if keyword != "" {
|
||||
conditions["keyword"] = keyword
|
||||
}
|
||||
if minPrice > 0 {
|
||||
conditions["min_price"] = minPrice
|
||||
}
|
||||
if maxPrice > 0 {
|
||||
conditions["max_price"] = maxPrice
|
||||
}
|
||||
if sort != "" {
|
||||
conditions["sort"] = sort
|
||||
}
|
||||
if sortType != "" {
|
||||
conditions["sort_type"] = sortType
|
||||
}
|
||||
// 添加状态条件,支持获取所有状态的商品
|
||||
if status != "" {
|
||||
conditions["status"] = status
|
||||
}
|
||||
// 添加热门、新品、推荐筛选条件
|
||||
if isHot != "" {
|
||||
conditions["is_hot"] = isHot
|
||||
}
|
||||
if isNew != "" {
|
||||
conditions["is_new"] = isNew
|
||||
}
|
||||
if isRecommend != "" {
|
||||
conditions["is_recommend"] = isRecommend
|
||||
}
|
||||
|
||||
products, total, err := s.productRepo.GetList(offset, pageSize, conditions)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pagination := utils.NewPagination(page, pageSize)
|
||||
pagination.Total = total
|
||||
return products, pagination, nil
|
||||
}
|
||||
|
||||
// GetProductDetail 获取产品详情
|
||||
func (s *ProductService) GetProductDetail(id uint) (*model.Product, error) {
|
||||
return s.productRepo.GetByID(id)
|
||||
}
|
||||
|
||||
// CreateProduct 创建产品
|
||||
func (s *ProductService) CreateProduct(product *model.Product) error {
|
||||
// 验证分类是否存在
|
||||
if product.CategoryID > 0 {
|
||||
_, err := s.productRepo.GetCategoryByID(product.CategoryID)
|
||||
if err != nil {
|
||||
return errors.New("分类不存在")
|
||||
}
|
||||
}
|
||||
|
||||
return s.productRepo.Create(product)
|
||||
}
|
||||
|
||||
// UpdateProduct 更新产品
|
||||
func (s *ProductService) UpdateProduct(id uint, updates map[string]interface{}) error {
|
||||
// 检查产品是否存在
|
||||
_, err := s.productRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return errors.New("产品不存在")
|
||||
}
|
||||
|
||||
// 如果更新分类,验证分类是否存在
|
||||
if categoryID, ok := updates["category_id"]; ok {
|
||||
var catID uint
|
||||
switch v := categoryID.(type) {
|
||||
case uint:
|
||||
catID = v
|
||||
case float64:
|
||||
catID = uint(v)
|
||||
case int:
|
||||
catID = uint(v)
|
||||
default:
|
||||
return errors.New("分类ID格式错误")
|
||||
}
|
||||
|
||||
if catID > 0 {
|
||||
_, err := s.productRepo.GetCategoryByID(catID)
|
||||
if err != nil {
|
||||
return errors.New("分类不存在")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 detail_images 字段 - 确保正确转换为 JSONSlice 类型
|
||||
if detailImages, ok := updates["detail_images"]; ok {
|
||||
switch v := detailImages.(type) {
|
||||
case []interface{}:
|
||||
// 将 []interface{} 转换为 []string
|
||||
var stringSlice []string
|
||||
for _, item := range v {
|
||||
if str, ok := item.(string); ok {
|
||||
stringSlice = append(stringSlice, str)
|
||||
}
|
||||
}
|
||||
updates["detail_images"] = model.JSONSlice(stringSlice)
|
||||
case []string:
|
||||
updates["detail_images"] = model.JSONSlice(v)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 images 字段 - 确保正确转换为 JSONSlice 类型
|
||||
if images, ok := updates["images"]; ok {
|
||||
switch v := images.(type) {
|
||||
case []interface{}:
|
||||
// 将 []interface{} 转换为 []string
|
||||
var stringSlice []string
|
||||
for _, item := range v {
|
||||
if str, ok := item.(string); ok {
|
||||
stringSlice = append(stringSlice, str)
|
||||
}
|
||||
}
|
||||
updates["images"] = model.JSONSlice(stringSlice)
|
||||
case []string:
|
||||
updates["images"] = model.JSONSlice(v)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理SKU数据
|
||||
var skusData []interface{}
|
||||
if skus, ok := updates["skus"]; ok {
|
||||
skusData, _ = skus.([]interface{})
|
||||
// 从updates中移除skus,避免直接更新到Product表
|
||||
delete(updates, "skus")
|
||||
}
|
||||
|
||||
// 更新商品基本信息
|
||||
if err := s.productRepo.Update(id, updates); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理SKU数据
|
||||
if len(skusData) > 0 {
|
||||
if err := s.handleProductSKUs(id, skusData); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleProductSKUs 处理商品SKU数据
|
||||
func (s *ProductService) handleProductSKUs(productID uint, skusData []interface{}) error {
|
||||
// 获取当前商品的所有现有SKU
|
||||
existingSKUs, err := s.productRepo.GetProductSKUs(productID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 收集前端发送的SKU ID列表
|
||||
var submittedSKUIDs []uint
|
||||
|
||||
// 处理前端发送的SKU数据
|
||||
for _, skuData := range skusData {
|
||||
skuMap, ok := skuData.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// 创建SKU对象
|
||||
sku := &model.ProductSKU{
|
||||
ProductID: productID,
|
||||
}
|
||||
|
||||
// 处理SKU字段
|
||||
if skuCode, ok := skuMap["sku_code"].(string); ok {
|
||||
sku.SKUCode = skuCode
|
||||
}
|
||||
|
||||
if price, ok := skuMap["price"].(float64); ok {
|
||||
sku.Price = price
|
||||
}
|
||||
|
||||
if stock, ok := skuMap["stock"]; ok {
|
||||
switch v := stock.(type) {
|
||||
case float64:
|
||||
sku.Stock = int(v)
|
||||
case int:
|
||||
sku.Stock = v
|
||||
case string:
|
||||
if stockInt, err := strconv.Atoi(v); err == nil {
|
||||
sku.Stock = stockInt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理spec_values
|
||||
if specValues, ok := skuMap["spec_values"]; ok {
|
||||
if specMap, ok := specValues.(map[string]interface{}); ok {
|
||||
sku.SpecValues = model.JSONMap(specMap)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理image字段
|
||||
if image, ok := skuMap["image"].(string); ok {
|
||||
sku.Image = image
|
||||
}
|
||||
|
||||
// 检查是否是更新还是创建
|
||||
var isUpdate bool
|
||||
var skuIDValue uint
|
||||
|
||||
if skuID, ok := skuMap["id"]; ok && skuID != nil {
|
||||
switch v := skuID.(type) {
|
||||
case float64:
|
||||
if v > 0 {
|
||||
isUpdate = true
|
||||
skuIDValue = uint(v)
|
||||
submittedSKUIDs = append(submittedSKUIDs, skuIDValue)
|
||||
}
|
||||
case int:
|
||||
if v > 0 {
|
||||
isUpdate = true
|
||||
skuIDValue = uint(v)
|
||||
submittedSKUIDs = append(submittedSKUIDs, skuIDValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isUpdate {
|
||||
// 更新现有SKU
|
||||
updates := make(map[string]interface{})
|
||||
if sku.SKUCode != "" {
|
||||
updates["sku_code"] = sku.SKUCode
|
||||
}
|
||||
updates["price"] = sku.Price
|
||||
updates["stock"] = sku.Stock
|
||||
// 直接传递JSONMap类型,让GORM处理序列化
|
||||
updates["spec_values"] = sku.SpecValues
|
||||
// 添加image字段的更新
|
||||
if sku.Image != "" {
|
||||
updates["image"] = sku.Image
|
||||
}
|
||||
|
||||
if err := s.productRepo.UpdateSKU(skuIDValue, updates); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// 创建新SKU - 确保不设置ID字段
|
||||
sku.ID = 0 // 明确设置为0,让数据库自动生成
|
||||
if sku.SKUCode == "" {
|
||||
// 生成默认SKU代码
|
||||
sku.SKUCode = fmt.Sprintf("SKU-%d-%d", productID, time.Now().Unix())
|
||||
}
|
||||
if err := s.productRepo.CreateSKU(sku); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 删除不在前端提交列表中的现有SKU
|
||||
for _, existingSKU := range existingSKUs {
|
||||
shouldDelete := true
|
||||
for _, submittedID := range submittedSKUIDs {
|
||||
if existingSKU.ID == submittedID {
|
||||
shouldDelete = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if shouldDelete {
|
||||
if err := s.productRepo.DeleteSKU(existingSKU.ID); err != nil {
|
||||
return fmt.Errorf("删除SKU失败 - SKU ID: %d, 错误: %v", existingSKU.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理完所有SKU后,同步商品库存
|
||||
if err := s.productRepo.SyncProductStockFromSKUs(productID); err != nil {
|
||||
// 记录错误但不阻止操作
|
||||
fmt.Printf("同步商品库存失败 - 商品ID: %d, 错误: %v\n", productID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteProduct 删除产品
|
||||
func (s *ProductService) DeleteProduct(id uint) error {
|
||||
// 检查产品是否存在
|
||||
_, err := s.productRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return errors.New("产品不存在")
|
||||
}
|
||||
|
||||
return s.productRepo.Delete(id)
|
||||
}
|
||||
|
||||
// GetCategories 获取分类列表
|
||||
func (s *ProductService) GetCategories() ([]model.Category, error) {
|
||||
return s.productRepo.GetCategories()
|
||||
}
|
||||
|
||||
// CreateCategory 创建分类
|
||||
func (s *ProductService) CreateCategory(category *model.Category) error {
|
||||
return s.productRepo.CreateCategory(category)
|
||||
}
|
||||
|
||||
// UpdateCategory 更新分类
|
||||
func (s *ProductService) UpdateCategory(id uint, updates map[string]interface{}) error {
|
||||
// 检查分类是否存在
|
||||
_, err := s.productRepo.GetCategoryByID(id)
|
||||
if err != nil {
|
||||
return errors.New("分类不存在")
|
||||
}
|
||||
|
||||
return s.productRepo.UpdateCategory(id, updates)
|
||||
}
|
||||
|
||||
// DeleteCategory 删除分类
|
||||
func (s *ProductService) DeleteCategory(id uint) error {
|
||||
// 检查分类是否存在
|
||||
_, err := s.productRepo.GetCategoryByID(id)
|
||||
if err != nil {
|
||||
return errors.New("分类不存在")
|
||||
}
|
||||
|
||||
// 检查分类下是否有商品
|
||||
productCount, err := s.productRepo.CountProductsByCategory(id)
|
||||
if err != nil {
|
||||
return errors.New("检查分类商品数量失败")
|
||||
}
|
||||
|
||||
if productCount > 0 {
|
||||
return errors.New("该分类下还有商品,无法删除")
|
||||
}
|
||||
|
||||
// 检查是否有子分类
|
||||
var childCategories []model.Category
|
||||
err = s.productRepo.GetDB().Where("parent_id = ?", id).Find(&childCategories).Error
|
||||
if err != nil {
|
||||
return errors.New("检查子分类失败")
|
||||
}
|
||||
|
||||
if len(childCategories) > 0 {
|
||||
return errors.New("该分类下还有子分类,请先删除子分类")
|
||||
}
|
||||
|
||||
return s.productRepo.DeleteCategory(id)
|
||||
}
|
||||
|
||||
// GetProductReviews 获取产品评价列表
|
||||
func (s *ProductService) GetProductReviews(productID uint, page, pageSize int) ([]model.ProductReview, *utils.Pagination, error) {
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
reviews, total, err := s.productRepo.GetReviews(productID, offset, pageSize)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pagination := utils.NewPagination(page, pageSize)
|
||||
pagination.Total = total
|
||||
return reviews, pagination, nil
|
||||
}
|
||||
|
||||
// CreateReview 创建评价
|
||||
func (s *ProductService) CreateReview(userID uint, review *model.ProductReview) error {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 检查产品是否存在
|
||||
_, err = s.productRepo.GetByID(review.ProductID)
|
||||
if err != nil {
|
||||
return errors.New("产品不存在")
|
||||
}
|
||||
|
||||
// 检查是否已经评价过
|
||||
if review.OrderID != nil {
|
||||
existingReview, _ := s.productRepo.GetReviewByOrderID(userID, *review.OrderID)
|
||||
if existingReview != nil {
|
||||
return errors.New("已经评价过该商品")
|
||||
}
|
||||
}
|
||||
|
||||
review.UserID = userID
|
||||
return s.productRepo.CreateReview(review)
|
||||
}
|
||||
|
||||
// GetHotProducts 获取热门产品
|
||||
func (s *ProductService) GetHotProducts(limit int) ([]model.Product, error) {
|
||||
if limit <= 0 || limit > 50 {
|
||||
limit = 10
|
||||
}
|
||||
return s.productRepo.GetHotProducts(limit)
|
||||
}
|
||||
|
||||
// GetRecommendProducts 获取推荐产品
|
||||
func (s *ProductService) GetRecommendProducts(limit int) ([]model.Product, error) {
|
||||
if limit <= 0 || limit > 50 {
|
||||
limit = 10
|
||||
}
|
||||
return s.productRepo.GetRecommendProducts(limit)
|
||||
}
|
||||
|
||||
// SearchProducts 搜索产品(支持价格与排序)
|
||||
func (s *ProductService) SearchProducts(keyword string, page, pageSize int, minPrice, maxPrice float64, sort, sortType string) ([]model.Product, *utils.Pagination, error) {
|
||||
if keyword == "" {
|
||||
return []model.Product{}, utils.NewPagination(page, pageSize), nil
|
||||
}
|
||||
|
||||
return s.GetProductList(page, pageSize, 0, keyword, minPrice, maxPrice, sort, sortType)
|
||||
}
|
||||
|
||||
// UpdateStock 更新库存
|
||||
func (s *ProductService) UpdateStock(id uint, quantity int) error {
|
||||
// 检查产品是否存在
|
||||
product, err := s.productRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return errors.New("产品不存在")
|
||||
}
|
||||
|
||||
// 检查库存是否足够(减库存时)
|
||||
if quantity < 0 && product.Stock < -quantity {
|
||||
return errors.New("库存不足")
|
||||
}
|
||||
|
||||
return s.productRepo.UpdateStock(id, quantity)
|
||||
}
|
||||
|
||||
// GetProductSKUs 获取产品SKU列表
|
||||
func (s *ProductService) GetProductSKUs(productID uint) ([]model.ProductSKU, error) {
|
||||
return s.productRepo.GetProductSKUs(productID)
|
||||
}
|
||||
|
||||
// GetSKUByID 根据SKU ID获取SKU详情
|
||||
func (s *ProductService) GetSKUByID(skuID uint) (*model.ProductSKU, error) {
|
||||
return s.productRepo.GetSKUByID(skuID)
|
||||
}
|
||||
|
||||
// GetProductTags 获取产品标签列表
|
||||
func (s *ProductService) GetProductTags() ([]model.ProductTag, error) {
|
||||
return s.productRepo.GetProductTags()
|
||||
}
|
||||
|
||||
// GetStores 获取店铺列表
|
||||
func (s *ProductService) GetStores() ([]model.Store, error) {
|
||||
return s.productRepo.GetStores()
|
||||
}
|
||||
|
||||
// GetStoreByID 根据ID获取店铺信息
|
||||
func (s *ProductService) GetStoreByID(id uint) (*model.Store, error) {
|
||||
return s.productRepo.GetStoreByID(id)
|
||||
}
|
||||
|
||||
// GetProductReviewCount 获取产品评价统计
|
||||
func (s *ProductService) GetProductReviewCount(productID uint) (map[string]interface{}, error) {
|
||||
// 检查产品是否存在
|
||||
_, err := s.productRepo.GetByID(productID)
|
||||
if err != nil {
|
||||
return nil, errors.New("产品不存在")
|
||||
}
|
||||
|
||||
return s.productRepo.GetReviewCount(productID)
|
||||
}
|
||||
|
||||
// GetProductStatistics 获取产品统计
|
||||
func (s *ProductService) GetProductStatistics() (map[string]interface{}, error) {
|
||||
// 使用ProductRepository的GetProductStatistics方法
|
||||
return s.productRepo.GetProductStatistics()
|
||||
}
|
||||
|
||||
// GetProductSalesRanking 获取产品销售排行
|
||||
func (s *ProductService) GetProductSalesRanking(startDate, endDate, limit string) ([]map[string]interface{}, error) {
|
||||
// 简化实现,返回基础排行数据
|
||||
var results []map[string]interface{}
|
||||
|
||||
// 这里应该根据订单数据统计产品销量,暂时返回模拟数据
|
||||
products, _, err := s.productRepo.GetList(0, 10, map[string]interface{}{"status": 1})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, product := range products {
|
||||
if i >= 10 { // 限制返回数量
|
||||
break
|
||||
}
|
||||
results = append(results, map[string]interface{}{
|
||||
"product_id": product.ID,
|
||||
"product_name": product.Name,
|
||||
"sales_count": 100 - i*5, // 模拟销量数据
|
||||
"sales_amount": float64(1000 - i*50),
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetCategorySalesRanking 获取分类销售排行
|
||||
func (s *ProductService) GetCategorySalesRanking(startDate, endDate, limit string) ([]map[string]interface{}, error) {
|
||||
// 解析limit参数
|
||||
limitInt := 10 // 默认值
|
||||
if limit != "" {
|
||||
if parsedLimit, err := strconv.Atoi(limit); err == nil && parsedLimit > 0 {
|
||||
limitInt = parsedLimit
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有提供日期范围,使用最近30天
|
||||
if startDate == "" || endDate == "" {
|
||||
now := time.Now()
|
||||
endDate = now.Format("2006-01-02")
|
||||
startDate = now.AddDate(0, 0, -30).Format("2006-01-02")
|
||||
}
|
||||
|
||||
// 使用真实的数据库查询
|
||||
return s.productRepo.GetCategorySalesStatistics(startDate, endDate, limitInt)
|
||||
}
|
||||
|
||||
// BatchUpdateProductStatus 批量更新商品状态
|
||||
func (s *ProductService) BatchUpdateProductStatus(ids []uint, status int) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("商品ID列表不能为空")
|
||||
}
|
||||
|
||||
return s.productRepo.BatchUpdateStatus(ids, status)
|
||||
}
|
||||
|
||||
// BatchUpdateProductPrice 批量更新商品价格
|
||||
func (s *ProductService) BatchUpdateProductPrice(updates []map[string]interface{}) error {
|
||||
if len(updates) == 0 {
|
||||
return errors.New("更新数据不能为空")
|
||||
}
|
||||
|
||||
return s.productRepo.BatchUpdatePrice(updates)
|
||||
}
|
||||
|
||||
// BatchDeleteProducts 批量删除商品
|
||||
func (s *ProductService) BatchDeleteProducts(ids []uint) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("商品ID列表不能为空")
|
||||
}
|
||||
|
||||
return s.productRepo.BatchDelete(ids)
|
||||
}
|
||||
|
||||
// CreateProductSKU 创建商品SKU
|
||||
func (s *ProductService) CreateProductSKU(sku *model.ProductSKU) error {
|
||||
// 验证商品是否存在
|
||||
_, err := s.productRepo.GetByID(sku.ProductID)
|
||||
if err != nil {
|
||||
return errors.New("商品不存在")
|
||||
}
|
||||
|
||||
return s.productRepo.CreateSKU(sku)
|
||||
}
|
||||
|
||||
// UpdateProductSKU 更新商品SKU
|
||||
func (s *ProductService) UpdateProductSKU(id uint, updates map[string]interface{}) error {
|
||||
// 检查SKU是否存在
|
||||
_, err := s.productRepo.GetSKUByID(id)
|
||||
if err != nil {
|
||||
return errors.New("SKU不存在")
|
||||
}
|
||||
|
||||
return s.productRepo.UpdateSKU(id, updates)
|
||||
}
|
||||
|
||||
// DeleteProductSKU 删除商品SKU
|
||||
func (s *ProductService) DeleteProductSKU(id uint) error {
|
||||
// 检查SKU是否存在(包括已软删除的)
|
||||
var sku model.ProductSKU
|
||||
err := s.productRepo.GetDB().Where("id = ?", id).First(&sku).Error
|
||||
if err != nil {
|
||||
return errors.New("SKU不存在")
|
||||
}
|
||||
|
||||
// 如果SKU已经被软删除,直接返回成功
|
||||
if sku.Status == 0 {
|
||||
fmt.Printf("SKU ID %d 已经被软删除,无需重复操作\n", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查SKU是否被订单引用
|
||||
var count int64
|
||||
err = s.productRepo.GetDB().Table("order_items").Where("sk_uid = ?", id).Count(&count).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("检查SKU引用关系失败: %v", err)
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
// 如果被订单引用,执行软删除
|
||||
err = s.productRepo.DeleteSKU(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除SKU失败: %v", err)
|
||||
}
|
||||
// 软删除成功,记录日志但不返回错误
|
||||
fmt.Printf("SKU ID %d 已被 %d 个订单引用,已执行软删除(设置为不可用状态)\n", id, count)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 如果没有被引用,执行硬删除
|
||||
err = s.productRepo.DeleteSKU(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除SKU失败: %v", err)
|
||||
}
|
||||
|
||||
// 同步更新商品库存
|
||||
if err := s.productRepo.SyncProductStockFromSKUs(sku.ProductID); err != nil {
|
||||
// 记录错误但不阻止删除操作
|
||||
fmt.Printf("同步商品库存失败 - 商品ID: %d, 错误: %v\n", sku.ProductID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetProductImages 获取商品图片列表
|
||||
func (s *ProductService) GetProductImages(productID uint) ([]model.ProductImage, error) {
|
||||
return s.productRepo.GetProductImages(productID)
|
||||
}
|
||||
|
||||
// CreateProductImage 创建商品图片
|
||||
func (s *ProductService) CreateProductImage(image *model.ProductImage) error {
|
||||
// 验证商品是否存在
|
||||
_, err := s.productRepo.GetByID(image.ProductID)
|
||||
if err != nil {
|
||||
return errors.New("商品不存在")
|
||||
}
|
||||
|
||||
return s.productRepo.CreateProductImage(image)
|
||||
}
|
||||
|
||||
// UpdateProductImageSort 更新商品图片排序
|
||||
func (s *ProductService) UpdateProductImageSort(id uint, sort int) error {
|
||||
return s.productRepo.UpdateProductImageSort(id, sort)
|
||||
}
|
||||
|
||||
// DeleteProductImage 删除商品图片
|
||||
func (s *ProductService) DeleteProductImage(id uint) error {
|
||||
return s.productRepo.DeleteProductImage(id)
|
||||
}
|
||||
|
||||
// CreateProductSpec 创建商品规格
|
||||
func (s *ProductService) CreateProductSpec(spec *model.ProductSpec) error {
|
||||
// 验证商品是否存在
|
||||
_, err := s.productRepo.GetByID(spec.ProductID)
|
||||
if err != nil {
|
||||
return errors.New("商品不存在")
|
||||
}
|
||||
|
||||
return s.productRepo.CreateProductSpec(spec)
|
||||
}
|
||||
|
||||
// UpdateProductSpec 更新商品规格
|
||||
func (s *ProductService) UpdateProductSpec(id uint, updates map[string]interface{}) error {
|
||||
return s.productRepo.UpdateProductSpec(id, updates)
|
||||
}
|
||||
|
||||
// DeleteProductSpec 删除商品规格
|
||||
func (s *ProductService) DeleteProductSpec(id uint) error {
|
||||
return s.productRepo.DeleteProductSpec(id)
|
||||
}
|
||||
|
||||
// GetProductSpecs 获取商品规格列表
|
||||
func (s *ProductService) GetProductSpecs(productID uint) ([]model.ProductSpec, error) {
|
||||
return s.productRepo.GetProductSpecs(productID)
|
||||
}
|
||||
|
||||
// CreateProductTag 创建商品标签
|
||||
func (s *ProductService) CreateProductTag(tag *model.ProductTag) error {
|
||||
return s.productRepo.CreateProductTag(tag)
|
||||
}
|
||||
|
||||
// UpdateProductTag 更新商品标签
|
||||
func (s *ProductService) UpdateProductTag(id uint, updates map[string]interface{}) error {
|
||||
return s.productRepo.UpdateProductTag(id, updates)
|
||||
}
|
||||
|
||||
// DeleteProductTag 删除商品标签
|
||||
func (s *ProductService) DeleteProductTag(id uint) error {
|
||||
return s.productRepo.DeleteProductTag(id)
|
||||
}
|
||||
|
||||
// AssignTagsToProduct 为商品分配标签
|
||||
func (s *ProductService) AssignTagsToProduct(productID uint, tagIDs []uint) error {
|
||||
// 验证商品是否存在
|
||||
_, err := s.productRepo.GetByID(productID)
|
||||
if err != nil {
|
||||
return errors.New("商品不存在")
|
||||
}
|
||||
|
||||
return s.productRepo.AssignTagsToProduct(productID, tagIDs)
|
||||
}
|
||||
|
||||
// GetLowStockProducts 获取低库存商品
|
||||
func (s *ProductService) GetLowStockProducts(threshold int) ([]model.Product, error) {
|
||||
if threshold <= 0 {
|
||||
threshold = 10 // 默认阈值
|
||||
}
|
||||
|
||||
return s.productRepo.GetLowStockProducts(threshold)
|
||||
}
|
||||
|
||||
// GetInventoryStatistics 获取库存统计
|
||||
func (s *ProductService) GetInventoryStatistics() (map[string]interface{}, error) {
|
||||
return s.productRepo.GetInventoryStatistics()
|
||||
}
|
||||
|
||||
// ExportProducts 导出商品数据
|
||||
func (s *ProductService) ExportProducts(conditions map[string]interface{}) ([]model.Product, error) {
|
||||
return s.productRepo.GetProductsForExport(conditions)
|
||||
}
|
||||
|
||||
// ImportProducts 导入商品数据
|
||||
func (s *ProductService) ImportProducts(products []model.Product) (map[string]interface{}, error) {
|
||||
successCount := 0
|
||||
failCount := 0
|
||||
var errors []string
|
||||
|
||||
for _, product := range products {
|
||||
// 验证商品数据
|
||||
if product.Name == "" {
|
||||
errors = append(errors, "商品名称不能为空")
|
||||
failCount++
|
||||
continue
|
||||
}
|
||||
|
||||
if product.Price <= 0 {
|
||||
errors = append(errors, "商品价格必须大于0")
|
||||
failCount++
|
||||
continue
|
||||
}
|
||||
|
||||
// 创建商品
|
||||
err := s.productRepo.Create(&product)
|
||||
if err != nil {
|
||||
errors = append(errors, err.Error())
|
||||
failCount++
|
||||
} else {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success_count": successCount,
|
||||
"fail_count": failCount,
|
||||
"errors": errors,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SyncProductStock 同步商品库存(根据SKU库存计算)
|
||||
func (s *ProductService) SyncProductStock(productID uint) error {
|
||||
return s.productRepo.SyncProductStockFromSKUs(productID)
|
||||
}
|
||||
|
||||
// SyncAllProductsStock 同步所有商品库存
|
||||
func (s *ProductService) SyncAllProductsStock() error {
|
||||
// 获取所有有SKU的商品
|
||||
products, _, err := s.productRepo.GetList(0, 0, map[string]interface{}{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var syncErrors []string
|
||||
for _, product := range products {
|
||||
// 检查商品是否有SKU
|
||||
skus, err := s.productRepo.GetProductSKUs(product.ID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(skus) > 0 {
|
||||
// 如果有SKU,同步库存
|
||||
err = s.productRepo.SyncProductStockFromSKUs(product.ID)
|
||||
if err != nil {
|
||||
syncErrors = append(syncErrors, fmt.Sprintf("商品ID %d 同步失败: %v", product.ID, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(syncErrors) > 0 {
|
||||
return fmt.Errorf("部分商品同步失败: %v", syncErrors)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
829
server/internal/service/refund.go
Normal file
829
server/internal/service/refund.go
Normal file
@@ -0,0 +1,829 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"dianshang/internal/model"
|
||||
"dianshang/internal/repository"
|
||||
"dianshang/pkg/logger"
|
||||
"dianshang/pkg/utils"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RefundService struct {
|
||||
refundRepo *repository.RefundRepository
|
||||
orderRepo *repository.OrderRepository
|
||||
wechatPaySvc *WeChatPayService
|
||||
}
|
||||
|
||||
func NewRefundService(refundRepo *repository.RefundRepository, orderRepo *repository.OrderRepository, wechatPaySvc *WeChatPayService) *RefundService {
|
||||
return &RefundService{
|
||||
refundRepo: refundRepo,
|
||||
orderRepo: orderRepo,
|
||||
wechatPaySvc: wechatPaySvc,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateRefund 创建退款申请
|
||||
func (s *RefundService) CreateRefund(ctx context.Context, req *CreateRefundRequest) (*model.Refund, error) {
|
||||
logger.Info("开始创建退款申请",
|
||||
"orderID", req.OrderID,
|
||||
"refundAmount", req.RefundAmount,
|
||||
"refundReason", req.RefundReason,
|
||||
"userID", req.UserID)
|
||||
|
||||
// 1. 验证订单
|
||||
order, err := s.orderRepo.GetByID(req.OrderID)
|
||||
if err != nil {
|
||||
logger.Error("查询订单失败", "error", err, "orderID", req.OrderID)
|
||||
return nil, fmt.Errorf("订单不存在")
|
||||
}
|
||||
|
||||
// 2. 验证订单状态
|
||||
if order.Status != model.OrderStatusPaid {
|
||||
return nil, fmt.Errorf("订单状态不允许退款,当前状态: %s", order.GetStatusText())
|
||||
}
|
||||
|
||||
// 3. 验证用户权限
|
||||
if order.UserID != req.UserID {
|
||||
return nil, fmt.Errorf("无权限操作此订单")
|
||||
}
|
||||
|
||||
// 4. 验证退款金额
|
||||
if req.RefundAmount <= 0 {
|
||||
return nil, fmt.Errorf("退款金额必须大于0")
|
||||
}
|
||||
|
||||
// 将前端传递的元金额转换为分(数据库统一使用分存储)
|
||||
refundAmountInCents := req.RefundAmount * 100
|
||||
|
||||
// 计算已退款金额
|
||||
totalRefunded, err := s.refundRepo.GetTotalRefundedByOrderID(req.OrderID)
|
||||
if err != nil {
|
||||
logger.Error("查询已退款金额失败", "error", err, "orderID", req.OrderID)
|
||||
return nil, fmt.Errorf("查询退款信息失败")
|
||||
}
|
||||
|
||||
// 检查退款金额是否超过可退款金额(订单金额也需要转换为分进行比较)
|
||||
orderAmountInCents := order.TotalAmount * 100
|
||||
availableRefund := orderAmountInCents - totalRefunded
|
||||
if refundAmountInCents > availableRefund {
|
||||
return nil, fmt.Errorf("退款金额超过可退款金额,可退款: %.2f", availableRefund/100.0)
|
||||
}
|
||||
|
||||
// 5. 生成退款记录
|
||||
refund := &model.Refund{
|
||||
RefundNo: utils.GenerateRefundNo(),
|
||||
WechatOutRefundNo: utils.GenerateWechatOutRefundNo(),
|
||||
OrderID: req.OrderID,
|
||||
OrderNo: order.OrderNo, // 设置订单号
|
||||
UserID: req.UserID,
|
||||
RefundAmount: refundAmountInCents, // 存储为分
|
||||
ActualRefundAmount: refundAmountInCents, // 设置实际退款金额,初始等于申请退款金额(分)
|
||||
RefundReason: req.RefundReason,
|
||||
RefundType: req.RefundType,
|
||||
Status: model.RefundStatusPending,
|
||||
WechatRefundStatus: "",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 6. 保存退款记录
|
||||
err = s.refundRepo.Create(refund)
|
||||
if err != nil {
|
||||
logger.Error("创建退款记录失败", "error", err)
|
||||
return nil, fmt.Errorf("创建退款申请失败")
|
||||
}
|
||||
|
||||
// 7. 更新订单状态为退款中
|
||||
if order.Status != model.OrderStatusReturning {
|
||||
orderUpdates := map[string]interface{}{
|
||||
"status": model.OrderStatusReturning,
|
||||
"refund_time": time.Now(),
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
err = s.orderRepo.UpdateByID(order.ID, orderUpdates)
|
||||
if err != nil {
|
||||
logger.Error("更新订单状态为退款中失败", "error", err, "orderID", order.ID)
|
||||
// 不返回错误,因为退款记录已经创建成功
|
||||
} else {
|
||||
logger.Info("订单状态已更新为退款中", "orderID", order.ID, "orderNo", order.OrderNo)
|
||||
}
|
||||
}
|
||||
|
||||
// 8. 创建退款日志
|
||||
statusTo := model.RefundStatusPending
|
||||
userID := req.UserID
|
||||
err = s.createRefundLog(refund.ID, "create", nil, &statusTo, "用户申请退款", &userID)
|
||||
if err != nil {
|
||||
logger.Warn("创建退款日志失败", "error", err, "refundID", refund.ID)
|
||||
}
|
||||
|
||||
logger.Info("退款申请创建成功",
|
||||
"refundID", refund.ID,
|
||||
"refundNo", refund.RefundNo,
|
||||
"orderID", req.OrderID,
|
||||
"refundAmountYuan", req.RefundAmount,
|
||||
"refundAmountCents", refundAmountInCents)
|
||||
|
||||
return refund, nil
|
||||
}
|
||||
|
||||
// ProcessRefund 处理退款(管理员审核通过后调用)
|
||||
func (s *RefundService) ProcessRefund(ctx context.Context, refundID uint, adminID uint, adminRemark string) error {
|
||||
logger.Info("开始处理退款", "refundID", refundID, "adminID", adminID)
|
||||
|
||||
// 1. 查询退款记录
|
||||
refund, err := s.refundRepo.GetByID(refundID)
|
||||
if err != nil {
|
||||
logger.Error("查询退款记录失败", "error", err, "refundID", refundID)
|
||||
return fmt.Errorf("退款记录不存在")
|
||||
}
|
||||
|
||||
// 2. 验证退款状态
|
||||
if refund.Status != model.RefundStatusPending {
|
||||
return fmt.Errorf("退款状态不允许处理,当前状态: %s", refund.GetStatusText())
|
||||
}
|
||||
|
||||
// 3. 查询订单信息
|
||||
order, err := s.orderRepo.GetByID(refund.OrderID)
|
||||
if err != nil {
|
||||
logger.Error("查询订单失败", "error", err, "orderID", refund.OrderID)
|
||||
return fmt.Errorf("订单不存在")
|
||||
}
|
||||
|
||||
// 4. 更新退款状态为处理中
|
||||
err = s.refundRepo.UpdateByID(refundID, map[string]interface{}{
|
||||
"status": model.RefundStatusProcessing,
|
||||
"admin_remark": adminRemark,
|
||||
"audit_time": time.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("更新退款状态失败", "error", err, "refundID", refundID)
|
||||
return fmt.Errorf("更新退款状态失败")
|
||||
}
|
||||
|
||||
// 5. 创建退款日志
|
||||
statusFrom := model.RefundStatusPending
|
||||
statusTo := model.RefundStatusProcessing
|
||||
err = s.createRefundLog(refundID, "approve", &statusFrom, &statusTo, fmt.Sprintf("管理员审核通过: %s", adminRemark), &adminID)
|
||||
if err != nil {
|
||||
logger.Warn("创建退款日志失败", "error", err, "refundID", refundID)
|
||||
}
|
||||
|
||||
// 6. 调用微信退款API
|
||||
wechatResp, err := s.wechatPaySvc.CreateRefund(ctx, refund, order)
|
||||
if err != nil {
|
||||
logger.Error("调用微信退款API失败", "error", err, "refundID", refundID)
|
||||
|
||||
// 更新退款状态为失败
|
||||
s.refundRepo.UpdateByID(refundID, map[string]interface{}{
|
||||
"status": model.RefundStatusFailed,
|
||||
"admin_remark": fmt.Sprintf("微信退款失败: %v", err),
|
||||
})
|
||||
statusFrom := model.RefundStatusProcessing
|
||||
statusTo := model.RefundStatusFailed
|
||||
s.createRefundLog(refundID, "fail", &statusFrom, &statusTo, fmt.Sprintf("微信退款失败: %v", err), &adminID)
|
||||
|
||||
return fmt.Errorf("微信退款失败: %v", err)
|
||||
}
|
||||
|
||||
// 7. 更新退款记录的微信信息
|
||||
updates := map[string]interface{}{
|
||||
"wechat_refund_id": wechatResp.Data["refund_id"],
|
||||
"wechat_refund_status": wechatResp.Data["status"],
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
// 如果微信返回了用户收款账户信息
|
||||
if userAccount, ok := wechatResp.Data["user_received_account"].(string); ok && userAccount != "" {
|
||||
updates["wechat_user_received_account"] = userAccount
|
||||
}
|
||||
|
||||
// 如果微信返回了退款账户信息
|
||||
if refundAccount, ok := wechatResp.Data["funds_account"].(string); ok && refundAccount != "" {
|
||||
updates["wechat_refund_account"] = refundAccount
|
||||
}
|
||||
|
||||
// 如果微信退款立即成功
|
||||
if status, ok := wechatResp.Data["status"].(string); ok && status == "SUCCESS" {
|
||||
updates["status"] = model.RefundStatusSuccess
|
||||
if successTime, ok := wechatResp.Data["success_time"].(string); ok && successTime != "" {
|
||||
if parsedTime, err := time.Parse("2006-01-02T15:04:05+08:00", successTime); err == nil {
|
||||
updates["wechat_success_time"] = parsedTime
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = s.refundRepo.UpdateByID(refundID, updates)
|
||||
if err != nil {
|
||||
logger.Error("更新退款微信信息失败", "error", err, "refundID", refundID)
|
||||
}
|
||||
|
||||
// 8. 如果退款成功,更新订单退款信息
|
||||
if status, ok := wechatResp.Data["status"].(string); ok && status == "SUCCESS" {
|
||||
err = s.updateOrderRefundInfo(order, refund)
|
||||
if err != nil {
|
||||
logger.Error("更新订单退款信息失败", "error", err, "orderID", order.ID)
|
||||
}
|
||||
|
||||
// 创建成功日志
|
||||
statusFrom := model.RefundStatusProcessing
|
||||
statusTo := model.RefundStatusSuccess
|
||||
s.createRefundLog(refundID, "success", &statusFrom, &statusTo, "微信退款成功", &adminID)
|
||||
} else {
|
||||
// 创建处理中日志
|
||||
statusFrom := model.RefundStatusProcessing
|
||||
statusTo := model.RefundStatusProcessing
|
||||
s.createRefundLog(refundID, "processing", &statusFrom, &statusTo, "微信退款处理中", &adminID)
|
||||
}
|
||||
|
||||
logger.Info("退款处理完成",
|
||||
"refundID", refundID,
|
||||
"wechatRefundID", wechatResp.Data["refund_id"],
|
||||
"status", wechatResp.Data["status"])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RejectRefund 拒绝退款申请
|
||||
func (s *RefundService) RejectRefund(ctx context.Context, refundID uint, adminID uint, rejectReason string) error {
|
||||
logger.Info("拒绝退款申请", "refundID", refundID, "adminID", adminID, "reason", rejectReason)
|
||||
|
||||
// 1. 查询退款记录
|
||||
refund, err := s.refundRepo.GetByID(refundID)
|
||||
if err != nil {
|
||||
logger.Error("查询退款记录失败", "error", err, "refundID", refundID)
|
||||
return fmt.Errorf("退款记录不存在")
|
||||
}
|
||||
|
||||
// 2. 验证退款状态
|
||||
if refund.Status != model.RefundStatusPending {
|
||||
return fmt.Errorf("退款状态不允许拒绝,当前状态: %s", refund.GetStatusText())
|
||||
}
|
||||
|
||||
// 3. 更新退款状态为已拒绝
|
||||
err = s.refundRepo.UpdateByID(refundID, map[string]interface{}{
|
||||
"status": model.RefundStatusRejected,
|
||||
"reject_reason": rejectReason,
|
||||
"reject_time": time.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("更新退款状态失败", "error", err, "refundID", refundID)
|
||||
return fmt.Errorf("更新退款状态失败")
|
||||
}
|
||||
|
||||
// 4. 创建退款日志
|
||||
statusFrom := model.RefundStatusPending
|
||||
statusTo := model.RefundStatusRejected
|
||||
err = s.createRefundLog(refundID, "reject", &statusFrom, &statusTo, fmt.Sprintf("管理员拒绝: %s", rejectReason), &adminID)
|
||||
if err != nil {
|
||||
logger.Warn("创建退款日志失败", "error", err, "refundID", refundID)
|
||||
}
|
||||
|
||||
logger.Info("退款申请已拒绝", "refundID", refundID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleWeChatRefundNotify 处理微信退款回调通知(解析和解密)
|
||||
func (s *RefundService) HandleWeChatRefundNotify(ctx context.Context, body []byte, headers map[string]string) (*WeChatRefundNotify, error) {
|
||||
logger.Info("开始处理微信退款回调通知")
|
||||
|
||||
if s.wechatPaySvc == nil {
|
||||
return nil, fmt.Errorf("微信支付服务未初始化")
|
||||
}
|
||||
|
||||
// 调用微信支付服务解析和解密回调数据
|
||||
notify, err := s.wechatPaySvc.HandleRefundNotify(ctx, body, headers)
|
||||
if err != nil {
|
||||
logger.Error("解析退款回调数据失败", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Info("成功解析退款回调数据", "eventType", notify.EventType)
|
||||
return notify, nil
|
||||
}
|
||||
|
||||
// HandleRefundCallback 处理微信退款回调
|
||||
func (s *RefundService) HandleRefundCallback(ctx context.Context, notify *WeChatRefundNotify) error {
|
||||
logger.Info("处理微信退款回调", "eventType", notify.EventType)
|
||||
|
||||
if notify.DecryptedData == nil {
|
||||
return fmt.Errorf("回调数据中缺少解密数据")
|
||||
}
|
||||
|
||||
outRefundNo := notify.DecryptedData.OutRefundNo
|
||||
if outRefundNo == "" {
|
||||
return fmt.Errorf("回调数据中缺少退款单号")
|
||||
}
|
||||
|
||||
// 1. 查询退款记录
|
||||
refund, err := s.refundRepo.GetByWechatOutRefundNo(outRefundNo)
|
||||
if err != nil {
|
||||
logger.Error("根据微信退款单号查询退款记录失败", "error", err, "outRefundNo", outRefundNo)
|
||||
return fmt.Errorf("退款记录不存在")
|
||||
}
|
||||
|
||||
// 2. 根据事件类型处理不同的退款状态
|
||||
var newStatus int
|
||||
var logRemark string
|
||||
|
||||
switch notify.EventType {
|
||||
case "REFUND.SUCCESS":
|
||||
// 退款成功
|
||||
if refund.Status == model.RefundStatusSuccess {
|
||||
logger.Info("退款已经是成功状态,跳过处理", "refundID", refund.ID)
|
||||
return nil
|
||||
}
|
||||
newStatus = model.RefundStatusSuccess
|
||||
logRemark = "微信退款回调:退款成功"
|
||||
|
||||
case "REFUND.ABNORMAL":
|
||||
// 退款异常
|
||||
newStatus = model.RefundStatusFailed
|
||||
logRemark = "微信退款回调:退款异常"
|
||||
|
||||
case "REFUND.CLOSED":
|
||||
// 退款关闭
|
||||
newStatus = model.RefundStatusFailed
|
||||
logRemark = "微信退款回调:退款关闭"
|
||||
|
||||
default:
|
||||
logger.Warn("未知的退款回调事件类型", "eventType", notify.EventType)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 3. 更新退款状态和微信信息
|
||||
updates := map[string]interface{}{
|
||||
"status": newStatus,
|
||||
"wechat_refund_id": notify.DecryptedData.RefundId,
|
||||
"wechat_refund_status": notify.DecryptedData.RefundStatus,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
// 只有成功时才更新收款账户和成功时间
|
||||
if notify.EventType == "REFUND.SUCCESS" {
|
||||
updates["wechat_user_received_account"] = notify.DecryptedData.UserReceivedAccount
|
||||
|
||||
// 解析成功时间
|
||||
if notify.DecryptedData.SuccessTime != "" {
|
||||
if successTime, err := time.Parse("2006-01-02T15:04:05+08:00", notify.DecryptedData.SuccessTime); err == nil {
|
||||
updates["wechat_success_time"] = successTime
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = s.refundRepo.UpdateByID(refund.ID, updates)
|
||||
if err != nil {
|
||||
logger.Error("更新退款状态失败", "error", err, "refundID", refund.ID)
|
||||
return fmt.Errorf("更新退款状态失败")
|
||||
}
|
||||
|
||||
// 4. 只有退款成功时才更新订单退款信息
|
||||
if newStatus == model.RefundStatusSuccess {
|
||||
order, err := s.orderRepo.GetByID(refund.OrderID)
|
||||
if err != nil {
|
||||
logger.Error("查询订单失败", "error", err, "orderID", refund.OrderID)
|
||||
} else {
|
||||
err = s.updateOrderRefundInfo(order, refund)
|
||||
if err != nil {
|
||||
logger.Error("更新订单退款信息失败", "error", err, "orderID", order.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 创建退款日志
|
||||
statusFrom := refund.Status
|
||||
statusTo := newStatus
|
||||
var operatorID *uint = nil
|
||||
err = s.createRefundLog(refund.ID, "callback", &statusFrom, &statusTo, logRemark, operatorID)
|
||||
if err != nil {
|
||||
logger.Warn("创建退款日志失败", "error", err, "refundID", refund.ID)
|
||||
}
|
||||
|
||||
logger.Info("微信退款回调处理完成", "refundID", refund.ID, "outRefundNo", outRefundNo, "newStatus", newStatus)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRefundsByOrderID 获取订单的退款记录
|
||||
func (s *RefundService) GetRefundsByOrderID(ctx context.Context, orderID uint, userID uint) ([]*model.Refund, error) {
|
||||
// 验证用户权限
|
||||
order, err := s.orderRepo.GetByID(orderID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("订单不存在")
|
||||
}
|
||||
|
||||
if order.UserID != userID {
|
||||
return nil, fmt.Errorf("无权限查看此订单的退款信息")
|
||||
}
|
||||
|
||||
refunds, err := s.refundRepo.GetByOrderID(orderID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*model.Refund, len(refunds))
|
||||
for i := range refunds {
|
||||
result[i] = &refunds[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SyncRefundAndOrderStatus 同步退款状态和订单状态
|
||||
// 这个方法用于修复退款状态已成功但订单状态未更新的问题
|
||||
func (s *RefundService) SyncRefundAndOrderStatus(ctx context.Context) error {
|
||||
logger.Info("开始同步退款状态和订单状态")
|
||||
|
||||
// 1. 查询所有状态为成功的退款记录
|
||||
refunds, err := s.refundRepo.GetRefundsByStatus(model.RefundStatusSuccess)
|
||||
if err != nil {
|
||||
logger.Error("查询成功退款记录失败", "error", err)
|
||||
return fmt.Errorf("查询成功退款记录失败: %v", err)
|
||||
}
|
||||
|
||||
// 2. 遍历每个退款记录,检查对应的订单状态
|
||||
for _, refund := range refunds {
|
||||
// 获取订单信息
|
||||
order, err := s.orderRepo.GetByID(refund.OrderID)
|
||||
if err != nil {
|
||||
logger.Error("获取订单信息失败", "error", err, "orderID", refund.OrderID)
|
||||
continue
|
||||
}
|
||||
|
||||
// 如果订单状态不是已退款,则需要更新
|
||||
if order.Status != model.OrderStatusRefunded {
|
||||
// 计算订单总退款金额
|
||||
totalRefunded, err := s.refundRepo.GetTotalRefundedByOrderID(order.ID)
|
||||
if err != nil {
|
||||
logger.Error("计算订单总退款金额失败", "error", err, "orderID", order.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
// 如果总退款金额大于等于订单金额,则更新订单状态为已退款
|
||||
if totalRefunded >= order.TotalAmount {
|
||||
updates := map[string]interface{}{
|
||||
"status": model.OrderStatusRefunded,
|
||||
"refunded_at": time.Now(),
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
err = s.orderRepo.UpdateByID(order.ID, updates)
|
||||
if err != nil {
|
||||
logger.Error("更新订单状态为已退款失败", "error", err, "orderID", order.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info("订单状态已更新为已退款",
|
||||
"orderID", order.ID,
|
||||
"orderNo", order.OrderNo,
|
||||
"totalAmount", order.TotalAmount,
|
||||
"totalRefunded", totalRefunded)
|
||||
} else if order.Status != model.OrderStatusReturning {
|
||||
// 如果是部分退款且订单状态不是退款中,则更新为退款中
|
||||
updates := map[string]interface{}{
|
||||
"status": model.OrderStatusReturning,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
err = s.orderRepo.UpdateByID(order.ID, updates)
|
||||
if err != nil {
|
||||
logger.Error("更新订单状态为退款中失败", "error", err, "orderID", order.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info("订单状态已更新为退款中",
|
||||
"orderID", order.ID,
|
||||
"orderNo", order.OrderNo,
|
||||
"totalAmount", order.TotalAmount,
|
||||
"totalRefunded", totalRefunded)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("同步退款状态和订单状态完成")
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRefundsByUserID 获取用户的退款记录
|
||||
func (s *RefundService) GetRefundsByUserID(ctx context.Context, userID uint, page, pageSize int) ([]*model.Refund, int64, error) {
|
||||
refunds, total, err := s.refundRepo.GetByUserID(userID, page, pageSize)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*model.Refund, len(refunds))
|
||||
for i := range refunds {
|
||||
result[i] = &refunds[i]
|
||||
|
||||
// 检查退款状态是否为成功,但订单状态不是已退款
|
||||
if refunds[i].Status == model.RefundStatusSuccess && refunds[i].Order.Status != model.OrderStatusRefunded {
|
||||
// 计算订单总退款金额
|
||||
totalRefunded, err := s.refundRepo.GetTotalRefundedByOrderID(refunds[i].OrderID)
|
||||
if err != nil {
|
||||
logger.Error("计算订单总退款金额失败", "error", err, "orderID", refunds[i].OrderID)
|
||||
continue
|
||||
}
|
||||
|
||||
// 如果总退款金额大于等于订单金额,则更新订单状态为已退款
|
||||
if totalRefunded >= refunds[i].Order.TotalAmount {
|
||||
updates := map[string]interface{}{
|
||||
"status": model.OrderStatusRefunded,
|
||||
"refunded_at": time.Now(),
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
err = s.orderRepo.UpdateByID(refunds[i].OrderID, updates)
|
||||
if err != nil {
|
||||
logger.Error("更新订单状态为已退款失败", "error", err, "orderID", refunds[i].OrderID)
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新当前退款记录中的订单状态
|
||||
result[i].Order.Status = model.OrderStatusRefunded
|
||||
|
||||
logger.Info("订单状态已更新为已退款",
|
||||
"orderID", refunds[i].OrderID,
|
||||
"orderNo", refunds[i].OrderNo,
|
||||
"totalAmount", refunds[i].Order.TotalAmount,
|
||||
"totalRefunded", totalRefunded)
|
||||
} else if refunds[i].Order.Status != model.OrderStatusReturning {
|
||||
// 如果是部分退款且订单状态不是退款中,则更新为退款中
|
||||
updates := map[string]interface{}{
|
||||
"status": model.OrderStatusReturning,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
err = s.orderRepo.UpdateByID(refunds[i].OrderID, updates)
|
||||
if err != nil {
|
||||
logger.Error("更新订单状态为退款中失败", "error", err, "orderID", refunds[i].OrderID)
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新当前退款记录中的订单状态
|
||||
result[i].Order.Status = model.OrderStatusReturning
|
||||
|
||||
logger.Info("订单状态已更新为退款中",
|
||||
"orderID", refunds[i].OrderID,
|
||||
"orderNo", refunds[i].OrderNo,
|
||||
"totalAmount", refunds[i].Order.TotalAmount,
|
||||
"totalRefunded", totalRefunded)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// GetRefundByID 获取退款详情
|
||||
func (s *RefundService) GetRefundByID(ctx context.Context, refundID uint, userID uint) (*model.Refund, error) {
|
||||
refund, err := s.refundRepo.GetByID(refundID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("退款记录不存在")
|
||||
}
|
||||
|
||||
// 验证用户权限
|
||||
if refund.UserID != userID {
|
||||
return nil, fmt.Errorf("无权限查看此退款记录")
|
||||
}
|
||||
|
||||
return refund, nil
|
||||
}
|
||||
|
||||
// QueryRefundStatus 查询微信退款状态
|
||||
func (s *RefundService) QueryRefundStatus(ctx context.Context, refundID uint) error {
|
||||
logger.Info("查询微信退款状态", "refundID", refundID)
|
||||
|
||||
// 1. 查询退款记录
|
||||
refund, err := s.refundRepo.GetByID(refundID)
|
||||
if err != nil {
|
||||
logger.Error("查询退款记录失败", "error", err, "refundID", refundID)
|
||||
return fmt.Errorf("退款记录不存在")
|
||||
}
|
||||
|
||||
if refund.WechatOutRefundNo == "" {
|
||||
return fmt.Errorf("退款记录没有微信退款单号")
|
||||
}
|
||||
|
||||
// 2. 调用微信查询退款API
|
||||
wechatRefund, err := s.wechatPaySvc.QueryRefund(ctx, refund.WechatOutRefundNo)
|
||||
if err != nil {
|
||||
logger.Error("查询微信退款状态失败", "error", err, "outRefundNo", refund.WechatOutRefundNo)
|
||||
return fmt.Errorf("查询微信退款状态失败: %v", err)
|
||||
}
|
||||
|
||||
// 3. 更新退款记录
|
||||
updates := map[string]interface{}{
|
||||
"wechat_refund_status": wechatRefund.WechatRefundStatus,
|
||||
"wechat_user_received_account": wechatRefund.WechatUserReceivedAccount,
|
||||
"wechat_refund_account": wechatRefund.WechatRefundAccount,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
// 如果微信退款成功,更新本地状态
|
||||
if wechatRefund.WechatRefundStatus == "SUCCESS" {
|
||||
// 无论当前退款状态如何,只要微信退款成功,就更新为成功状态
|
||||
updates["status"] = model.RefundStatusSuccess
|
||||
if wechatRefund.WechatSuccessTime != nil {
|
||||
updates["wechat_success_time"] = *wechatRefund.WechatSuccessTime
|
||||
}
|
||||
|
||||
// 更新订单退款信息
|
||||
order, err := s.orderRepo.GetByID(refund.OrderID)
|
||||
if err == nil {
|
||||
s.updateOrderRefundInfo(order, refund)
|
||||
}
|
||||
|
||||
// 只有当状态发生变化时才创建日志
|
||||
if refund.Status != model.RefundStatusSuccess {
|
||||
statusFrom := refund.Status
|
||||
statusTo := model.RefundStatusSuccess
|
||||
var operatorID *uint = nil
|
||||
s.createRefundLog(refund.ID, "query_success", &statusFrom, &statusTo, "查询确认微信退款成功", operatorID)
|
||||
}
|
||||
}
|
||||
|
||||
err = s.refundRepo.UpdateByID(refund.ID, updates)
|
||||
if err != nil {
|
||||
logger.Error("更新退款状态失败", "error", err, "refundID", refundID)
|
||||
return fmt.Errorf("更新退款状态失败")
|
||||
}
|
||||
|
||||
logger.Info("退款状态查询完成", "refundID", refundID, "status", wechatRefund.WechatRefundStatus)
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateOrderRefundInfo 更新订单退款信息
|
||||
func (s *RefundService) updateOrderRefundInfo(order *model.Order, refund *model.Refund) error {
|
||||
// 计算订单总退款金额
|
||||
totalRefunded, err := s.refundRepo.GetTotalRefundedByOrderID(order.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 计算退款次数
|
||||
refundCount, err := s.refundRepo.GetRefundCountByOrderID(order.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"total_refund_amount": totalRefunded,
|
||||
"refund_count": refundCount,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
// 如果全额退款,更新订单状态为已退款
|
||||
if totalRefunded >= order.TotalAmount {
|
||||
updates["status"] = model.OrderStatusRefunded
|
||||
updates["refunded_at"] = time.Now()
|
||||
logger.Info("更新订单状态为已退款",
|
||||
"orderID", order.ID,
|
||||
"totalAmount", order.TotalAmount,
|
||||
"totalRefunded", totalRefunded,
|
||||
"refundID", refund.ID)
|
||||
} else if order.Status == model.OrderStatusReturning {
|
||||
// 如果是部分退款且当前状态是退款中,保持退款中状态
|
||||
// 这样可以区分部分退款和全额退款的订单
|
||||
updates["status"] = model.OrderStatusReturning
|
||||
logger.Info("保持订单状态为退款中",
|
||||
"orderID", order.ID,
|
||||
"totalAmount", order.TotalAmount,
|
||||
"totalRefunded", totalRefunded,
|
||||
"refundID", refund.ID)
|
||||
}
|
||||
|
||||
err = s.orderRepo.UpdateByID(order.ID, updates)
|
||||
if err != nil {
|
||||
logger.Error("更新订单退款信息失败", "error", err, "orderID", order.ID)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createRefundLog 创建退款日志
|
||||
func (s *RefundService) createRefundLog(refundID uint, action string, statusFrom, statusTo *int, remark string, operatorID *uint) error {
|
||||
log := &model.RefundLog{
|
||||
RefundID: refundID,
|
||||
Action: action,
|
||||
StatusFrom: statusFrom,
|
||||
StatusTo: statusTo,
|
||||
OperatorType: "admin",
|
||||
OperatorID: operatorID,
|
||||
Remark: remark,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return s.refundRepo.CreateLog(log)
|
||||
}
|
||||
|
||||
// GetPendingRefunds 获取待处理的退款申请(管理员)
|
||||
func (s *RefundService) GetPendingRefunds(ctx context.Context, page, pageSize int) ([]*model.Refund, int64, error) {
|
||||
logger.Info("获取待处理退款申请", "page", page, "pageSize", pageSize)
|
||||
|
||||
refunds, total, err := s.refundRepo.GetPendingRefunds(page, pageSize)
|
||||
if err != nil {
|
||||
logger.Error("获取待处理退款申请失败", "error", err)
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*model.Refund, len(refunds))
|
||||
for i := range refunds {
|
||||
result[i] = &refunds[i]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// GetAllRefunds 获取所有退款记录(管理员)
|
||||
func (s *RefundService) GetAllRefunds(ctx context.Context, page, pageSize int, status, userID string) ([]*model.Refund, int64, error) {
|
||||
logger.Info("获取所有退款记录", "page", page, "pageSize", pageSize, "status", status, "userID", userID)
|
||||
|
||||
// 构建查询条件
|
||||
conditions := make(map[string]interface{})
|
||||
if status != "" {
|
||||
conditions["status"] = status
|
||||
}
|
||||
if userID != "" {
|
||||
conditions["user_id"] = userID
|
||||
}
|
||||
|
||||
refunds, total, err := s.refundRepo.GetAllRefunds(page, pageSize, conditions)
|
||||
if err != nil {
|
||||
logger.Error("获取所有退款记录失败", "error", err)
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*model.Refund, len(refunds))
|
||||
for i := range refunds {
|
||||
result[i] = &refunds[i]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// GetRefundLogs 获取退款日志(管理员)
|
||||
func (s *RefundService) GetRefundLogs(ctx context.Context, refundID uint) ([]model.RefundLog, error) {
|
||||
logger.Info("获取退款日志", "refundID", refundID)
|
||||
|
||||
logs, err := s.refundRepo.GetRefundLogsByRefundID(refundID)
|
||||
if err != nil {
|
||||
logger.Error("获取退款日志失败", "error", err, "refundID", refundID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
// GetRefundDetailForAdmin 获取退款详情(管理员专用)
|
||||
func (s *RefundService) GetRefundDetailForAdmin(ctx context.Context, refundID uint) (*model.Refund, error) {
|
||||
logger.Info("管理员获取退款详情", "refundID", refundID)
|
||||
|
||||
refund, err := s.refundRepo.GetByID(refundID)
|
||||
if err != nil {
|
||||
logger.Error("获取退款详情失败", "error", err, "refundID", refundID)
|
||||
return nil, fmt.Errorf("退款记录不存在")
|
||||
}
|
||||
|
||||
return refund, nil
|
||||
}
|
||||
|
||||
// GetRefundStatistics 获取退款统计数据
|
||||
func (s *RefundService) GetRefundStatistics(ctx context.Context, startTime, endTime time.Time) (map[string]interface{}, error) {
|
||||
logger.Info("获取退款统计数据", "startTime", startTime, "endTime", endTime)
|
||||
|
||||
stats, err := s.refundRepo.GetRefundStatistics(startTime, endTime)
|
||||
if err != nil {
|
||||
logger.Error("获取退款统计数据失败", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换数据格式以匹配前端期望的格式
|
||||
result := map[string]interface{}{
|
||||
"total_refunds": stats["total_count"],
|
||||
"pending_refunds": stats["pending_count"],
|
||||
"processing_refunds": stats["processing_count"],
|
||||
"total_amount": stats["total_amount"],
|
||||
"success_count": stats["success_count"],
|
||||
"success_amount": stats["success_amount"],
|
||||
"approved_count": stats["approved_count"],
|
||||
"rejected_count": stats["rejected_count"],
|
||||
"failed_count": stats["failed_count"],
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 请求结构体
|
||||
type CreateRefundRequest struct {
|
||||
OrderID uint `json:"order_id" binding:"required"`
|
||||
UserID uint `json:"user_id"` // 由后端从JWT token中获取,不需要前端提供
|
||||
RefundAmount float64 `json:"refund_amount" binding:"required,gt=0"`
|
||||
RefundReason string `json:"refund_reason" binding:"required,max=500"`
|
||||
RefundType int `json:"refund_type" binding:"required,oneof=1 2"` // 1:仅退款 2:退货退款
|
||||
}
|
||||
307
server/internal/service/role.go
Normal file
307
server/internal/service/role.go
Normal file
@@ -0,0 +1,307 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"dianshang/internal/model"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// RoleService 角色服务
|
||||
type RoleService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewRoleService 创建角色服务
|
||||
func NewRoleService(db *gorm.DB) *RoleService {
|
||||
return &RoleService{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateRole 创建角色
|
||||
func (s *RoleService) CreateRole(role *model.Role) error {
|
||||
return s.db.Create(role).Error
|
||||
}
|
||||
|
||||
// GetRoleByID 根据ID获取角色
|
||||
func (s *RoleService) GetRoleByID(id uint) (*model.Role, error) {
|
||||
var role model.Role
|
||||
err := s.db.Preload("Permissions").Where("id = ?", id).First(&role).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &role, nil
|
||||
}
|
||||
|
||||
// GetRoleByName 根据名称获取角色
|
||||
func (s *RoleService) GetRoleByName(name string) (*model.Role, error) {
|
||||
var role model.Role
|
||||
err := s.db.Preload("Permissions").Where("name = ?", name).First(&role).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &role, nil
|
||||
}
|
||||
|
||||
// GetRoleList 获取角色列表
|
||||
func (s *RoleService) GetRoleList(page, pageSize int, conditions map[string]interface{}) ([]model.Role, map[string]interface{}, error) {
|
||||
var roles []model.Role
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&model.Role{})
|
||||
|
||||
// 应用查询条件
|
||||
if keyword, ok := conditions["keyword"]; ok && keyword != "" {
|
||||
query = query.Where("name LIKE ? OR display_name LIKE ?",
|
||||
"%"+keyword.(string)+"%", "%"+keyword.(string)+"%")
|
||||
}
|
||||
|
||||
if status, ok := conditions["status"]; ok && status != "" {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Preload("Permissions").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&roles).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 构建分页信息
|
||||
pagination := map[string]interface{}{
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
|
||||
}
|
||||
|
||||
return roles, pagination, nil
|
||||
}
|
||||
|
||||
// UpdateRole 更新角色
|
||||
func (s *RoleService) UpdateRole(id uint, updates map[string]interface{}) error {
|
||||
return s.db.Model(&model.Role{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeleteRole 删除角色
|
||||
func (s *RoleService) DeleteRole(id uint) error {
|
||||
// 检查是否有用户使用该角色
|
||||
var count int64
|
||||
s.db.Model(&model.UserRole{}).Where("role_id = ?", id).Count(&count)
|
||||
if count > 0 {
|
||||
return errors.New("该角色正在被用户使用,无法删除")
|
||||
}
|
||||
|
||||
// 删除角色权限关联
|
||||
s.db.Where("role_id = ?", id).Delete(&model.RolePermission{})
|
||||
|
||||
// 删除角色
|
||||
return s.db.Delete(&model.Role{}, id).Error
|
||||
}
|
||||
|
||||
// AssignPermissionsToRole 为角色分配权限
|
||||
func (s *RoleService) AssignPermissionsToRole(roleID uint, permissionIDs []uint) error {
|
||||
// 删除原有权限
|
||||
s.db.Where("role_id = ?", roleID).Delete(&model.RolePermission{})
|
||||
|
||||
// 添加新权限
|
||||
for _, permissionID := range permissionIDs {
|
||||
rolePermission := &model.RolePermission{
|
||||
RoleID: roleID,
|
||||
PermissionID: permissionID,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
if err := s.db.Create(rolePermission).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AssignRolesToUser 为用户分配角色
|
||||
func (s *RoleService) AssignRolesToUser(userID uint, roleIDs []uint) error {
|
||||
// 删除原有角色
|
||||
s.db.Where("user_id = ?", userID).Delete(&model.UserRole{})
|
||||
|
||||
// 添加新角色
|
||||
for _, roleID := range roleIDs {
|
||||
userRole := &model.UserRole{
|
||||
UserID: userID,
|
||||
RoleID: roleID,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
if err := s.db.Create(userRole).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserRoles 获取用户角色
|
||||
func (s *RoleService) GetUserRoles(userID uint) ([]model.Role, error) {
|
||||
var roles []model.Role
|
||||
err := s.db.Table("ai_roles").
|
||||
Joins("JOIN ai_user_roles ON ai_roles.id = ai_user_roles.role_id").
|
||||
Where("ai_user_roles.user_id = ?", userID).
|
||||
Find(&roles).Error
|
||||
return roles, err
|
||||
}
|
||||
|
||||
// GetUserPermissions 获取用户权限
|
||||
func (s *RoleService) GetUserPermissions(userID uint) ([]model.Permission, error) {
|
||||
var permissions []model.Permission
|
||||
err := s.db.Table("ai_permissions").
|
||||
Joins("JOIN ai_role_permissions ON ai_permissions.id = ai_role_permissions.permission_id").
|
||||
Joins("JOIN ai_user_roles ON ai_role_permissions.role_id = ai_user_roles.role_id").
|
||||
Where("ai_user_roles.user_id = ?", userID).
|
||||
Distinct().
|
||||
Find(&permissions).Error
|
||||
return permissions, err
|
||||
}
|
||||
|
||||
// CheckUserPermission 检查用户权限
|
||||
func (s *RoleService) CheckUserPermission(userID uint, module, action string) (bool, error) {
|
||||
var count int64
|
||||
err := s.db.Table("ai_permissions").
|
||||
Joins("JOIN ai_role_permissions ON ai_permissions.id = ai_role_permissions.permission_id").
|
||||
Joins("JOIN ai_user_roles ON ai_role_permissions.role_id = ai_user_roles.role_id").
|
||||
Where("ai_user_roles.user_id = ? AND ai_permissions.module = ? AND ai_permissions.action = ? AND ai_permissions.status = 1",
|
||||
userID, module, action).
|
||||
Count(&count).Error
|
||||
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreatePermission 创建权限
|
||||
func (s *RoleService) CreatePermission(permission *model.Permission) error {
|
||||
return s.db.Create(permission).Error
|
||||
}
|
||||
|
||||
// GetPermissionList 获取权限列表
|
||||
func (s *RoleService) GetPermissionList(page, pageSize int, conditions map[string]interface{}) ([]model.Permission, map[string]interface{}, error) {
|
||||
var permissions []model.Permission
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&model.Permission{})
|
||||
|
||||
// 应用查询条件
|
||||
if keyword, ok := conditions["keyword"]; ok && keyword != "" {
|
||||
query = query.Where("name LIKE ? OR display_name LIKE ?",
|
||||
"%"+keyword.(string)+"%", "%"+keyword.(string)+"%")
|
||||
}
|
||||
|
||||
if module, ok := conditions["module"]; ok && module != "" {
|
||||
query = query.Where("module = ?", module)
|
||||
}
|
||||
|
||||
if status, ok := conditions["status"]; ok && status != "" {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Offset(offset).Limit(pageSize).Order("module ASC, action ASC").Find(&permissions).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 构建分页信息
|
||||
pagination := map[string]interface{}{
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
|
||||
}
|
||||
|
||||
return permissions, pagination, nil
|
||||
}
|
||||
|
||||
// UpdatePermission 更新权限
|
||||
func (s *RoleService) UpdatePermission(id uint, updates map[string]interface{}) error {
|
||||
return s.db.Model(&model.Permission{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeletePermission 删除权限
|
||||
func (s *RoleService) DeletePermission(id uint) error {
|
||||
// 检查是否有角色使用该权限
|
||||
var count int64
|
||||
s.db.Model(&model.RolePermission{}).Where("permission_id = ?", id).Count(&count)
|
||||
if count > 0 {
|
||||
return errors.New("该权限正在被角色使用,无法删除")
|
||||
}
|
||||
|
||||
return s.db.Delete(&model.Permission{}, id).Error
|
||||
}
|
||||
|
||||
// InitDefaultRolesAndPermissions 初始化默认角色和权限
|
||||
func (s *RoleService) InitDefaultRolesAndPermissions() error {
|
||||
// 创建默认权限
|
||||
permissions := []model.Permission{
|
||||
// 用户管理权限
|
||||
{Name: "user.create", DisplayName: "创建用户", Module: "user", Action: "create", Status: 1},
|
||||
{Name: "user.read", DisplayName: "查看用户", Module: "user", Action: "read", Status: 1},
|
||||
{Name: "user.update", DisplayName: "更新用户", Module: "user", Action: "update", Status: 1},
|
||||
{Name: "user.delete", DisplayName: "删除用户", Module: "user", Action: "delete", Status: 1},
|
||||
|
||||
// 商品管理权限
|
||||
{Name: "product.create", DisplayName: "创建商品", Module: "product", Action: "create", Status: 1},
|
||||
{Name: "product.read", DisplayName: "查看商品", Module: "product", Action: "read", Status: 1},
|
||||
{Name: "product.update", DisplayName: "更新商品", Module: "product", Action: "update", Status: 1},
|
||||
{Name: "product.delete", DisplayName: "删除商品", Module: "product", Action: "delete", Status: 1},
|
||||
|
||||
// 订单管理权限
|
||||
{Name: "order.create", DisplayName: "创建订单", Module: "order", Action: "create", Status: 1},
|
||||
{Name: "order.read", DisplayName: "查看订单", Module: "order", Action: "read", Status: 1},
|
||||
{Name: "order.update", DisplayName: "更新订单", Module: "order", Action: "update", Status: 1},
|
||||
{Name: "order.delete", DisplayName: "删除订单", Module: "order", Action: "delete", Status: 1},
|
||||
|
||||
// 系统管理权限
|
||||
{Name: "system.config", DisplayName: "系统配置", Module: "system", Action: "config", Status: 1},
|
||||
{Name: "system.log", DisplayName: "系统日志", Module: "system", Action: "log", Status: 1},
|
||||
}
|
||||
|
||||
for _, permission := range permissions {
|
||||
var existingPermission model.Permission
|
||||
if err := s.db.Where("name = ?", permission.Name).First(&existingPermission).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
s.db.Create(&permission)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建默认角色
|
||||
roles := []model.Role{
|
||||
{Name: "admin", DisplayName: "超级管理员", Description: "拥有所有权限", Status: 1},
|
||||
{Name: "manager", DisplayName: "管理员", Description: "拥有大部分管理权限", Status: 1},
|
||||
{Name: "user", DisplayName: "普通用户", Description: "基础用户权限", Status: 1},
|
||||
}
|
||||
|
||||
for _, role := range roles {
|
||||
var existingRole model.Role
|
||||
if err := s.db.Where("name = ?", role.Name).First(&existingRole).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
s.db.Create(&role)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
946
server/internal/service/user.go
Normal file
946
server/internal/service/user.go
Normal file
@@ -0,0 +1,946 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"dianshang/internal/model"
|
||||
"dianshang/internal/repository"
|
||||
"dianshang/pkg/jwt"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// UserService 用户服务
|
||||
type UserService struct {
|
||||
userRepo *repository.UserRepository
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewUserService 创建用户服务
|
||||
func NewUserService(db *gorm.DB) *UserService {
|
||||
return &UserService{
|
||||
userRepo: repository.NewUserRepository(db),
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// WeChatLogin 微信登录
|
||||
func (s *UserService) WeChatLogin(code string) (*model.User, string, error) {
|
||||
// TODO: 调用微信API获取openid
|
||||
// 这里暂时模拟
|
||||
openID := "mock_openid_" + code
|
||||
|
||||
// 查找用户
|
||||
user, err := s.userRepo.GetByOpenID(openID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// 用户不存在,创建新用户
|
||||
user = &model.User{
|
||||
OpenID: openID,
|
||||
Nickname: "微信用户",
|
||||
Status: 1,
|
||||
}
|
||||
if err := s.userRepo.Create(user); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
} else {
|
||||
return nil, "", err
|
||||
}
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if user.Status == 0 {
|
||||
return nil, "", errors.New("用户已被禁用")
|
||||
}
|
||||
|
||||
// 生成JWT token
|
||||
token, err := jwt.GenerateToken(user.ID, "user", 7200)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return user, token, nil
|
||||
}
|
||||
|
||||
// CreateUser 创建用户
|
||||
func (s *UserService) CreateUser(user *model.User) error {
|
||||
// 检查用户是否已存在
|
||||
existingUser, err := s.userRepo.GetByOpenID(user.OpenID)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
if existingUser != nil {
|
||||
return errors.New("用户已存在")
|
||||
}
|
||||
|
||||
return s.userRepo.Create(user)
|
||||
}
|
||||
|
||||
// GetUserByID 根据ID获取用户
|
||||
func (s *UserService) GetUserByID(id uint) (*model.User, error) {
|
||||
return s.userRepo.GetByID(id)
|
||||
}
|
||||
|
||||
// UpdateUser 更新用户信息
|
||||
func (s *UserService) UpdateUser(id uint, updates map[string]interface{}) error {
|
||||
// 更新用户表
|
||||
if err := s.userRepo.Update(id, updates); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 同步更新微信用户信息表
|
||||
wechatUpdates := make(map[string]interface{})
|
||||
if nickname, ok := updates["nickname"]; ok {
|
||||
wechatUpdates["nick_name"] = nickname
|
||||
}
|
||||
if avatar, ok := updates["avatar"]; ok {
|
||||
wechatUpdates["avatar_url"] = avatar
|
||||
}
|
||||
if gender, ok := updates["gender"]; ok {
|
||||
wechatUpdates["gender"] = gender
|
||||
}
|
||||
|
||||
// 如果有需要更新的微信信息字段
|
||||
if len(wechatUpdates) > 0 {
|
||||
wechatUpdates["updated_at"] = time.Now()
|
||||
|
||||
// 更新微信用户信息表
|
||||
if err := s.db.Model(&model.User{}).Where("id = ?", id).Updates(wechatUpdates).Error; err != nil {
|
||||
// 记录错误但不影响主要更新流程
|
||||
fmt.Printf("更新微信用户信息失败: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserAddresses 获取用户地址列表
|
||||
func (s *UserService) GetUserAddresses(userID uint) ([]model.UserAddress, error) {
|
||||
return s.userRepo.GetAddresses(userID)
|
||||
}
|
||||
|
||||
// GetAddressByID 根据ID获取用户地址
|
||||
func (s *UserService) GetAddressByID(userID, addressID uint) (*model.UserAddress, error) {
|
||||
address, err := s.userRepo.GetAddressByID(addressID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查地址是否属于该用户
|
||||
if address.UserID != userID {
|
||||
return nil, errors.New("无权限访问该地址")
|
||||
}
|
||||
|
||||
return address, nil
|
||||
}
|
||||
|
||||
// CreateAddress 创建用户地址
|
||||
func (s *UserService) CreateAddress(address *model.UserAddress) error {
|
||||
// 如果设置为默认地址,先取消其他默认地址
|
||||
if address.IsDefault {
|
||||
if err := s.userRepo.ClearDefaultAddress(address.UserID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return s.userRepo.CreateAddress(address)
|
||||
}
|
||||
|
||||
// UpdateAddress 更新用户地址
|
||||
func (s *UserService) UpdateAddress(userID, addressID uint, updates map[string]interface{}) error {
|
||||
// 检查地址是否属于该用户
|
||||
address, err := s.userRepo.GetAddressByID(addressID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if address.UserID != userID {
|
||||
return errors.New("无权限操作该地址")
|
||||
}
|
||||
|
||||
// 如果设置为默认地址,先取消其他默认地址
|
||||
if isDefault, ok := updates["is_default"]; ok && isDefault.(uint8) == 1 {
|
||||
if err := s.userRepo.ClearDefaultAddress(userID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return s.userRepo.UpdateAddress(addressID, updates)
|
||||
}
|
||||
|
||||
// DeleteAddress 删除用户地址
|
||||
func (s *UserService) DeleteAddress(userID, addressID uint) error {
|
||||
// 检查地址是否属于该用户
|
||||
address, err := s.userRepo.GetAddressByID(addressID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if address.UserID != userID {
|
||||
return errors.New("无权限操作该地址")
|
||||
}
|
||||
|
||||
return s.userRepo.DeleteAddress(addressID)
|
||||
}
|
||||
|
||||
// SetDefaultAddress 设置默认地址
|
||||
func (s *UserService) SetDefaultAddress(userID, addressID uint) error {
|
||||
// 检查地址是否属于该用户
|
||||
address, err := s.userRepo.GetAddressByID(addressID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if address.UserID != userID {
|
||||
return errors.New("无权限操作该地址")
|
||||
}
|
||||
|
||||
// 先取消其他默认地址
|
||||
if err := s.userRepo.ClearDefaultAddress(userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 设置为默认地址
|
||||
return s.userRepo.UpdateAddress(addressID, map[string]interface{}{
|
||||
"is_default": 1,
|
||||
})
|
||||
}
|
||||
|
||||
// GetFavorites 获取用户收藏列表
|
||||
func (s *UserService) GetFavorites(userID uint, page, limit int) ([]model.UserFavorite, int64, error) {
|
||||
offset := (page - 1) * limit
|
||||
return s.userRepo.GetFavorites(userID, offset, limit)
|
||||
}
|
||||
|
||||
// AddToFavorite 添加收藏
|
||||
func (s *UserService) AddToFavorite(userID, productID uint) error {
|
||||
// 检查是否已经收藏
|
||||
if s.userRepo.IsFavorite(userID, productID) {
|
||||
return errors.New("商品已在收藏列表中")
|
||||
}
|
||||
|
||||
favorite := &model.UserFavorite{
|
||||
UserID: userID,
|
||||
ProductID: productID,
|
||||
}
|
||||
return s.userRepo.CreateFavorite(favorite)
|
||||
}
|
||||
|
||||
// RemoveFromFavorite 取消收藏
|
||||
func (s *UserService) RemoveFromFavorite(userID, productID uint) error {
|
||||
return s.userRepo.DeleteFavorite(userID, productID)
|
||||
}
|
||||
|
||||
// IsFavorite 检查是否已收藏
|
||||
func (s *UserService) IsFavorite(userID, productID uint) bool {
|
||||
return s.userRepo.IsFavorite(userID, productID)
|
||||
}
|
||||
|
||||
// GetUserStatistics 获取用户统计
|
||||
func (s *UserService) GetUserStatistics(startDate, endDate string) (map[string]interface{}, error) {
|
||||
result := make(map[string]interface{})
|
||||
|
||||
// 总用户数
|
||||
var totalUsers int64
|
||||
s.db.Model(&model.User{}).Count(&totalUsers)
|
||||
result["total_users"] = totalUsers
|
||||
|
||||
// 新增用户数(指定日期范围)
|
||||
var newUsers int64
|
||||
query := s.db.Model(&model.User{})
|
||||
if startDate != "" && endDate != "" {
|
||||
query = query.Where("DATE(created_at) BETWEEN ? AND ?", startDate, endDate)
|
||||
}
|
||||
query.Count(&newUsers)
|
||||
result["new_users"] = newUsers
|
||||
|
||||
// 活跃用户数(简化处理,这里用登录用户数代替)
|
||||
var activeUsers int64
|
||||
activeQuery := s.db.Model(&model.User{}).Where("status = ?", 1)
|
||||
if startDate != "" && endDate != "" {
|
||||
activeQuery = activeQuery.Where("DATE(updated_at) BETWEEN ? AND ?", startDate, endDate)
|
||||
}
|
||||
activeQuery.Count(&activeUsers)
|
||||
result["active_users"] = activeUsers
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetDailyUserStatistics 获取每日用户统计
|
||||
func (s *UserService) GetDailyUserStatistics(startDate, endDate string) ([]map[string]interface{}, error) {
|
||||
// 简化实现,返回基础统计数据
|
||||
var results []map[string]interface{}
|
||||
|
||||
// 解析日期
|
||||
start, err := time.Parse("2006-01-02", startDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
end, err := time.Parse("2006-01-02", endDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 遍历日期范围
|
||||
for d := start; !d.After(end); d = d.AddDate(0, 0, 1) {
|
||||
dateStr := d.Format("2006-01-02")
|
||||
|
||||
var newUsers int64
|
||||
s.db.Model(&model.User{}).
|
||||
Where("DATE(created_at) = ?", dateStr).
|
||||
Count(&newUsers)
|
||||
|
||||
results = append(results, map[string]interface{}{
|
||||
"date": dateStr,
|
||||
"new_users": newUsers,
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetUserListForAdmin 获取用户列表(管理后台)
|
||||
func (s *UserService) GetUserListForAdmin(page, pageSize int, conditions map[string]interface{}) ([]model.User, map[string]interface{}, error) {
|
||||
var users []model.User
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&model.User{})
|
||||
|
||||
// 应用查询条件
|
||||
if keyword, ok := conditions["keyword"]; ok && keyword != "" {
|
||||
query = query.Where("nickname LIKE ? OR email LIKE ? OR phone LIKE ?",
|
||||
"%"+keyword.(string)+"%", "%"+keyword.(string)+"%", "%"+keyword.(string)+"%")
|
||||
}
|
||||
|
||||
if status, ok := conditions["status"]; ok && status != "" {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
|
||||
if startDate, ok := conditions["start_date"]; ok && startDate != "" {
|
||||
query = query.Where("created_at >= ?", startDate)
|
||||
}
|
||||
|
||||
if endDate, ok := conditions["end_date"]; ok && endDate != "" {
|
||||
query = query.Where("created_at <= ?", endDate)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&users).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 构建分页信息
|
||||
pagination := map[string]interface{}{
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
|
||||
}
|
||||
|
||||
return users, pagination, nil
|
||||
}
|
||||
|
||||
// GetUserDetailForAdmin 获取用户详情(管理后台)
|
||||
func (s *UserService) GetUserDetailForAdmin(userID uint) (map[string]interface{}, error) {
|
||||
var user model.User
|
||||
err := s.db.Where("id = ?", userID).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 构建返回数据
|
||||
result := map[string]interface{}{
|
||||
"id": user.ID,
|
||||
"openid": user.OpenID,
|
||||
"unionid": user.UnionID,
|
||||
"nickname": user.Nickname,
|
||||
"avatar": user.Avatar,
|
||||
"gender": user.Gender,
|
||||
"phone": user.Phone,
|
||||
"email": user.Email,
|
||||
"birthday": user.Birthday,
|
||||
"points": user.Points,
|
||||
"level": user.Level,
|
||||
"status": user.Status,
|
||||
"created_at": user.CreatedAt,
|
||||
"updated_at": user.UpdatedAt,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// UpdateUserStatusByAdmin 管理员更新用户状态
|
||||
func (s *UserService) UpdateUserStatusByAdmin(userID uint, status uint8, remark string, adminID uint) error {
|
||||
return s.db.Model(&model.User{}).Where("id = ?", userID).Update("status", status).Error
|
||||
}
|
||||
|
||||
// UpdateUserProfile 更新用户资料
|
||||
func (s *UserService) UpdateUserProfile(userID uint, updates map[string]interface{}) error {
|
||||
// 验证更新字段
|
||||
allowedFields := map[string]bool{
|
||||
"nickname": true,
|
||||
"avatar": true,
|
||||
"gender": true,
|
||||
"phone": true,
|
||||
"email": true,
|
||||
"birthday": true,
|
||||
}
|
||||
|
||||
filteredUpdates := make(map[string]interface{})
|
||||
for key, value := range updates {
|
||||
if allowedFields[key] {
|
||||
filteredUpdates[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
if len(filteredUpdates) == 0 {
|
||||
return fmt.Errorf("没有有效的更新字段")
|
||||
}
|
||||
|
||||
filteredUpdates["updated_at"] = time.Now()
|
||||
|
||||
return s.db.Model(&model.User{}).Where("id = ?", userID).Updates(filteredUpdates).Error
|
||||
}
|
||||
|
||||
// UpdateUserProfileByAdmin 管理员更新用户资料
|
||||
func (s *UserService) UpdateUserProfileByAdmin(userID uint, updates map[string]interface{}, adminID uint) error {
|
||||
// 验证更新字段
|
||||
allowedFields := map[string]bool{
|
||||
"nickname": true,
|
||||
"avatar": true,
|
||||
"gender": true,
|
||||
"phone": true,
|
||||
"email": true,
|
||||
"birthday": true,
|
||||
}
|
||||
|
||||
filteredUpdates := make(map[string]interface{})
|
||||
for key, value := range updates {
|
||||
if allowedFields[key] {
|
||||
filteredUpdates[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
if len(filteredUpdates) == 0 {
|
||||
return fmt.Errorf("没有有效的更新字段")
|
||||
}
|
||||
|
||||
filteredUpdates["updated_at"] = time.Now()
|
||||
|
||||
// TODO: 记录操作日志
|
||||
// 可以在这里添加操作日志记录,记录管理员修改用户资料的操作
|
||||
|
||||
return s.db.Model(&model.User{}).Where("id = ?", userID).Updates(filteredUpdates).Error
|
||||
}
|
||||
|
||||
// ResetUserPassword 重置用户密码(管理员操作)
|
||||
func (s *UserService) ResetUserPassword(userID uint, newPassword string, adminID uint) error {
|
||||
// 这里可以添加密码加密逻辑
|
||||
// 由于是微信小程序,通常不需要密码,这里预留接口
|
||||
updates := map[string]interface{}{
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
return s.db.Model(&model.User{}).Where("id = ?", userID).Updates(updates).Error
|
||||
}
|
||||
|
||||
|
||||
|
||||
// CreateUserAddress 创建用户地址
|
||||
func (s *UserService) CreateUserAddress(address *model.UserAddress) error {
|
||||
// 如果设置为默认地址,先取消其他默认地址
|
||||
if address.IsDefault {
|
||||
s.db.Model(&model.UserAddress{}).Where("user_id = ?", address.UserID).Update("is_default", false)
|
||||
}
|
||||
|
||||
return s.db.Create(address).Error
|
||||
}
|
||||
|
||||
// UpdateUserAddress 更新用户地址
|
||||
func (s *UserService) UpdateUserAddress(addressID uint, userID uint, updates map[string]interface{}) error {
|
||||
// 如果设置为默认地址,先取消其他默认地址
|
||||
if isDefault, ok := updates["is_default"]; ok && isDefault.(bool) {
|
||||
s.db.Model(&model.UserAddress{}).Where("user_id = ?", userID).Update("is_default", false)
|
||||
}
|
||||
|
||||
updates["updated_at"] = time.Now()
|
||||
|
||||
return s.db.Model(&model.UserAddress{}).Where("id = ? AND user_id = ?", addressID, userID).Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeleteUserAddress 删除用户地址
|
||||
func (s *UserService) DeleteUserAddress(addressID uint, userID uint) error {
|
||||
return s.db.Where("id = ? AND user_id = ?", addressID, userID).Delete(&model.UserAddress{}).Error
|
||||
}
|
||||
|
||||
// GetUserFavorites 获取用户收藏列表
|
||||
func (s *UserService) GetUserFavorites(userID uint, page, pageSize int) ([]model.UserFavorite, map[string]interface{}, error) {
|
||||
var favorites []model.UserFavorite
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&model.UserFavorite{}).Where("user_id = ?", userID).Preload("Product")
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&favorites).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 构建分页信息
|
||||
pagination := map[string]interface{}{
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
|
||||
}
|
||||
|
||||
return favorites, pagination, nil
|
||||
}
|
||||
|
||||
// AddUserFavorite 添加用户收藏
|
||||
func (s *UserService) AddUserFavorite(userID, productID uint) error {
|
||||
// 检查是否已收藏
|
||||
var count int64
|
||||
s.db.Model(&model.UserFavorite{}).Where("user_id = ? AND product_id = ?", userID, productID).Count(&count)
|
||||
if count > 0 {
|
||||
return fmt.Errorf("商品已收藏")
|
||||
}
|
||||
|
||||
favorite := &model.UserFavorite{
|
||||
UserID: userID,
|
||||
ProductID: productID,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return s.db.Create(favorite).Error
|
||||
}
|
||||
|
||||
// RemoveUserFavorite 移除用户收藏
|
||||
func (s *UserService) RemoveUserFavorite(userID, productID uint) error {
|
||||
return s.db.Where("user_id = ? AND product_id = ?", userID, productID).Delete(&model.UserFavorite{}).Error
|
||||
}
|
||||
|
||||
|
||||
|
||||
// GetUserLevelInfo 获取用户等级信息
|
||||
func (s *UserService) GetUserLevelInfo(userID uint) (map[string]interface{}, error) {
|
||||
user, err := s.GetUserByID(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 计算用户等级相关信息
|
||||
levelInfo := map[string]interface{}{
|
||||
"current_level": user.Level,
|
||||
"current_points": user.Points,
|
||||
"level_name": s.getLevelName(user.Level),
|
||||
"next_level": user.Level + 1,
|
||||
"next_level_name": s.getLevelName(user.Level + 1),
|
||||
"points_to_next": s.getPointsToNextLevel(user.Level, user.Points),
|
||||
}
|
||||
|
||||
return levelInfo, nil
|
||||
}
|
||||
|
||||
// getLevelName 获取等级名称
|
||||
func (s *UserService) getLevelName(level int) string {
|
||||
levelNames := map[int]string{
|
||||
1: "青铜会员",
|
||||
2: "白银会员",
|
||||
3: "黄金会员",
|
||||
4: "铂金会员",
|
||||
5: "钻石会员",
|
||||
}
|
||||
|
||||
if name, ok := levelNames[level]; ok {
|
||||
return name
|
||||
}
|
||||
return "普通会员"
|
||||
}
|
||||
|
||||
// getPointsToNextLevel 获取升级到下一等级所需积分
|
||||
func (s *UserService) getPointsToNextLevel(currentLevel, currentPoints int) int {
|
||||
levelThresholds := map[int]int{
|
||||
1: 0,
|
||||
2: 1000,
|
||||
3: 3000,
|
||||
4: 6000,
|
||||
5: 10000,
|
||||
}
|
||||
|
||||
nextLevel := currentLevel + 1
|
||||
if threshold, ok := levelThresholds[nextLevel]; ok {
|
||||
return threshold - currentPoints
|
||||
}
|
||||
return 0 // 已达到最高等级
|
||||
}
|
||||
|
||||
// UpdateUserLevel 更新用户等级
|
||||
func (s *UserService) UpdateUserLevel(userID uint) error {
|
||||
user, err := s.GetUserByID(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newLevel := s.calculateUserLevel(user.Points)
|
||||
if newLevel != user.Level {
|
||||
return s.db.Model(&model.User{}).Where("id = ?", userID).Update("level", newLevel).Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateUserLevelByAdmin 管理员手动设置用户等级
|
||||
func (s *UserService) UpdateUserLevelByAdmin(userID uint, level uint8, remark string, adminID uint) error {
|
||||
// 验证等级范围
|
||||
if level < 1 || level > 5 {
|
||||
return errors.New("用户等级必须在1-5之间")
|
||||
}
|
||||
|
||||
// 更新用户等级
|
||||
err := s.db.Model(&model.User{}).Where("id = ?", userID).Update("level", level).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: 记录操作日志
|
||||
// 可以在这里添加操作日志记录,记录管理员修改用户等级的操作
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateUserLevel 根据积分计算用户等级
|
||||
func (s *UserService) calculateUserLevel(points int) int {
|
||||
if points >= 10000 {
|
||||
return 5
|
||||
} else if points >= 6000 {
|
||||
return 4
|
||||
} else if points >= 3000 {
|
||||
return 3
|
||||
} else if points >= 1000 {
|
||||
return 2
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
// ResetUserPasswordByAdmin 管理员重置用户密码
|
||||
func (s *UserService) ResetUserPasswordByAdmin(userID uint, newPassword string, adminID uint) error {
|
||||
// 这里应该对密码进行加密处理,暂时简化实现
|
||||
// 在实际项目中,用户可能通过微信登录,不需要密码
|
||||
// 这里只是为了满足接口需求
|
||||
return s.db.Model(&model.User{}).Where("id = ?", userID).Update("updated_at", time.Now()).Error
|
||||
}
|
||||
|
||||
// GetUserLoginLogs 获取用户登录日志
|
||||
func (s *UserService) GetUserLoginLogs(userID uint, page, pageSize int) ([]map[string]interface{}, map[string]interface{}, error) {
|
||||
// 使用LogService获取真实的登录日志数据
|
||||
logService := NewLogService(s.db)
|
||||
logs, pagination, err := logService.GetUserLoginLogs(userID, page, pageSize)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 转换为前端需要的格式
|
||||
var result []map[string]interface{}
|
||||
for _, log := range logs {
|
||||
result = append(result, map[string]interface{}{
|
||||
"id": log.ID,
|
||||
"user_id": log.UserID,
|
||||
"login_time": log.LoginTime.Format("2006-01-02 15:04:05"),
|
||||
"ip_address": log.LoginIP,
|
||||
"device": log.UserAgent,
|
||||
"location": "未知", // 可以后续添加IP地址解析功能
|
||||
"status": log.Status,
|
||||
"remark": log.Remark,
|
||||
})
|
||||
}
|
||||
|
||||
return result, pagination, nil
|
||||
}
|
||||
|
||||
// GetUserPurchaseRanking 获取用户购买排行
|
||||
func (s *UserService) GetUserPurchaseRanking(startDate, endDate, limit string) ([]map[string]interface{}, error) {
|
||||
// 简化实现,返回基础排行数据
|
||||
var results []map[string]interface{}
|
||||
|
||||
// 获取活跃用户
|
||||
var users []model.User
|
||||
err := s.db.Model(&model.User{}).
|
||||
Where("status = ?", 1).
|
||||
Limit(10).
|
||||
Find(&users).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, user := range users {
|
||||
results = append(results, map[string]interface{}{
|
||||
"user_id": user.ID,
|
||||
"nickname": user.Nickname,
|
||||
"purchase_count": 50 - i*3, // 模拟购买次数
|
||||
"purchase_amount": float64(5000 - i*200), // 模拟购买金额
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetUserGrowthTrend 获取用户增长趋势
|
||||
func (s *UserService) GetUserGrowthTrend(days int) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
for i := days - 1; i >= 0; i-- {
|
||||
date := time.Now().AddDate(0, 0, -i).Format("2006-01-02")
|
||||
|
||||
// 新增用户数
|
||||
var newUsers int64
|
||||
s.db.Model(&model.User{}).
|
||||
Where("DATE(created_at) = ?", date).
|
||||
Count(&newUsers)
|
||||
|
||||
// 累计用户数
|
||||
var totalUsers int64
|
||||
s.db.Model(&model.User{}).
|
||||
Where("DATE(created_at) <= ?", date).
|
||||
Count(&totalUsers)
|
||||
|
||||
// 活跃用户数(当天有更新记录的用户)
|
||||
var activeUsers int64
|
||||
s.db.Model(&model.User{}).
|
||||
Where("DATE(updated_at) = ? AND status = ?", date, 1).
|
||||
Count(&activeUsers)
|
||||
|
||||
results = append(results, map[string]interface{}{
|
||||
"date": date,
|
||||
"new_users": newUsers,
|
||||
"total_users": totalUsers,
|
||||
"active_users": activeUsers,
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetUserActivityAnalysis 获取用户活跃度分析
|
||||
func (s *UserService) GetUserActivityAnalysis(startDate, endDate string) (map[string]interface{}, error) {
|
||||
result := make(map[string]interface{})
|
||||
|
||||
// 总用户数
|
||||
var totalUsers int64
|
||||
s.db.Model(&model.User{}).Count(&totalUsers)
|
||||
|
||||
// 活跃用户数(指定时间范围内有更新的用户)
|
||||
var activeUsers int64
|
||||
query := s.db.Model(&model.User{}).Where("status = ?", 1)
|
||||
if startDate != "" && endDate != "" {
|
||||
query = query.Where("DATE(updated_at) BETWEEN ? AND ?", startDate, endDate)
|
||||
}
|
||||
query.Count(&activeUsers)
|
||||
|
||||
// 新增用户数
|
||||
var newUsers int64
|
||||
newQuery := s.db.Model(&model.User{})
|
||||
if startDate != "" && endDate != "" {
|
||||
newQuery = newQuery.Where("DATE(created_at) BETWEEN ? AND ?", startDate, endDate)
|
||||
}
|
||||
newQuery.Count(&newUsers)
|
||||
|
||||
// 沉默用户数(30天内无活动的用户)
|
||||
thirtyDaysAgo := time.Now().AddDate(0, 0, -30).Format("2006-01-02")
|
||||
var silentUsers int64
|
||||
s.db.Model(&model.User{}).
|
||||
Where("status = ? AND DATE(updated_at) < ?", 1, thirtyDaysAgo).
|
||||
Count(&silentUsers)
|
||||
|
||||
// 计算活跃率
|
||||
var activityRate float64
|
||||
if totalUsers > 0 {
|
||||
activityRate = float64(activeUsers) / float64(totalUsers) * 100
|
||||
}
|
||||
|
||||
result["total_users"] = totalUsers
|
||||
result["active_users"] = activeUsers
|
||||
result["new_users"] = newUsers
|
||||
result["silent_users"] = silentUsers
|
||||
result["activity_rate"] = activityRate
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetUserRetentionRate 获取用户留存率
|
||||
func (s *UserService) GetUserRetentionRate(days int) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
for i := days - 1; i >= 0; i-- {
|
||||
date := time.Now().AddDate(0, 0, -i).Format("2006-01-02")
|
||||
nextDate := time.Now().AddDate(0, 0, -i+1).Format("2006-01-02")
|
||||
|
||||
// 当天新增用户数
|
||||
var newUsers int64
|
||||
s.db.Model(&model.User{}).
|
||||
Where("DATE(created_at) = ?", date).
|
||||
Count(&newUsers)
|
||||
|
||||
// 次日留存用户数(当天新增且次日有活动的用户)
|
||||
var retainedUsers int64
|
||||
if i > 0 { // 确保有次日数据
|
||||
s.db.Model(&model.User{}).
|
||||
Where("DATE(created_at) = ? AND DATE(updated_at) = ?", date, nextDate).
|
||||
Count(&retainedUsers)
|
||||
}
|
||||
|
||||
// 计算留存率
|
||||
var retentionRate float64
|
||||
if newUsers > 0 {
|
||||
retentionRate = float64(retainedUsers) / float64(newUsers) * 100
|
||||
}
|
||||
|
||||
results = append(results, map[string]interface{}{
|
||||
"date": date,
|
||||
"new_users": newUsers,
|
||||
"retained_users": retainedUsers,
|
||||
"retention_rate": retentionRate,
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetUserLevelDistribution 获取用户等级分布
|
||||
func (s *UserService) GetUserLevelDistribution() ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
// 统计各等级用户数量
|
||||
for level := 1; level <= 5; level++ {
|
||||
var count int64
|
||||
s.db.Model(&model.User{}).
|
||||
Where("level = ? AND status = ?", level, 1).
|
||||
Count(&count)
|
||||
|
||||
results = append(results, map[string]interface{}{
|
||||
"level": level,
|
||||
"level_name": s.getLevelName(level),
|
||||
"user_count": count,
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetUserGeographicDistribution 获取用户地域分布
|
||||
func (s *UserService) GetUserGeographicDistribution() ([]map[string]interface{}, error) {
|
||||
// 简化实现,返回模拟的地域分布数据
|
||||
// 在实际项目中,可以根据用户地址或IP地址统计
|
||||
regions := []map[string]interface{}{
|
||||
{"region": "北京", "user_count": 1200},
|
||||
{"region": "上海", "user_count": 980},
|
||||
{"region": "广州", "user_count": 750},
|
||||
{"region": "深圳", "user_count": 680},
|
||||
{"region": "杭州", "user_count": 520},
|
||||
{"region": "成都", "user_count": 450},
|
||||
{"region": "武汉", "user_count": 380},
|
||||
{"region": "西安", "user_count": 320},
|
||||
{"region": "南京", "user_count": 280},
|
||||
{"region": "其他", "user_count": 1430},
|
||||
}
|
||||
|
||||
return regions, nil
|
||||
}
|
||||
|
||||
// GetUserAgeDistribution 获取用户年龄分布
|
||||
func (s *UserService) GetUserAgeDistribution() ([]map[string]interface{}, error) {
|
||||
// 简化实现,返回模拟的年龄分布数据
|
||||
// 在实际项目中,可以根据用户生日计算年龄分布
|
||||
ageGroups := []map[string]interface{}{
|
||||
{"age_group": "18-25", "user_count": 1500},
|
||||
{"age_group": "26-30", "user_count": 2200},
|
||||
{"age_group": "31-35", "user_count": 1800},
|
||||
{"age_group": "36-40", "user_count": 1200},
|
||||
{"age_group": "41-50", "user_count": 800},
|
||||
{"age_group": "50+", "user_count": 500},
|
||||
}
|
||||
|
||||
return ageGroups, nil
|
||||
}
|
||||
|
||||
// GetUserEngagementMetrics 获取用户参与度指标
|
||||
func (s *UserService) GetUserEngagementMetrics(startDate, endDate string) (map[string]interface{}, error) {
|
||||
result := make(map[string]interface{})
|
||||
|
||||
// 平均会话时长(模拟数据)
|
||||
result["avg_session_duration"] = 25.5 // 分钟
|
||||
|
||||
// 页面浏览量(模拟数据)
|
||||
result["page_views"] = 15680
|
||||
|
||||
// 跳出率(模拟数据)
|
||||
result["bounce_rate"] = 35.2 // 百分比
|
||||
|
||||
// 用户互动次数(收藏、评价等)
|
||||
var favoriteCount int64
|
||||
query := s.db.Model(&model.UserFavorite{})
|
||||
if startDate != "" && endDate != "" {
|
||||
query = query.Where("DATE(created_at) BETWEEN ? AND ?", startDate, endDate)
|
||||
}
|
||||
query.Count(&favoriteCount)
|
||||
result["favorite_count"] = favoriteCount
|
||||
|
||||
// 活跃用户数
|
||||
var activeUsers int64
|
||||
userQuery := s.db.Model(&model.User{}).Where("status = ?", 1)
|
||||
if startDate != "" && endDate != "" {
|
||||
userQuery = userQuery.Where("DATE(updated_at) BETWEEN ? AND ?", startDate, endDate)
|
||||
}
|
||||
userQuery.Count(&activeUsers)
|
||||
result["active_users"] = activeUsers
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteUser 删除用户
|
||||
func (s *UserService) DeleteUser(userID uint) error {
|
||||
// 软删除用户
|
||||
return s.db.Delete(&model.User{}, userID).Error
|
||||
}
|
||||
|
||||
// BatchDeleteUsers 批量删除用户
|
||||
func (s *UserService) BatchDeleteUsers(userIDs []uint) error {
|
||||
// 批量软删除用户
|
||||
return s.db.Delete(&model.User{}, userIDs).Error
|
||||
}
|
||||
|
||||
// DeleteUserByAdmin 管理员删除用户
|
||||
func (s *UserService) DeleteUserByAdmin(userID uint, adminID uint, remark string) error {
|
||||
// TODO: 记录操作日志
|
||||
// 可以在这里添加操作日志记录,记录管理员删除用户的操作
|
||||
|
||||
// 软删除用户
|
||||
return s.db.Delete(&model.User{}, userID).Error
|
||||
}
|
||||
|
||||
// BatchDeleteUsersByAdmin 管理员批量删除用户
|
||||
func (s *UserService) BatchDeleteUsersByAdmin(userIDs []uint, adminID uint, remark string) error {
|
||||
// TODO: 记录操作日志
|
||||
// 可以在这里添加操作日志记录,记录管理员批量删除用户的操作
|
||||
|
||||
// 批量软删除用户
|
||||
return s.db.Delete(&model.User{}, userIDs).Error
|
||||
}
|
||||
417
server/internal/service/wechat.go
Normal file
417
server/internal/service/wechat.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"dianshang/internal/model"
|
||||
"dianshang/internal/repository"
|
||||
"dianshang/pkg/jwt"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// WeChatService 微信服务
|
||||
type WeChatService struct {
|
||||
userRepo *repository.UserRepository
|
||||
pointsService *PointsService
|
||||
db *gorm.DB
|
||||
appID string
|
||||
appSecret string
|
||||
}
|
||||
|
||||
// NewWeChatService 创建微信服务实例
|
||||
func NewWeChatService(db *gorm.DB, pointsService *PointsService, appID, appSecret string) *WeChatService {
|
||||
// 初始化随机数种子
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
return &WeChatService{
|
||||
userRepo: repository.NewUserRepository(db),
|
||||
pointsService: pointsService,
|
||||
db: db,
|
||||
appID: appID,
|
||||
appSecret: appSecret,
|
||||
}
|
||||
}
|
||||
|
||||
// WeChatLoginResponse 微信登录响应
|
||||
type WeChatLoginResponse struct {
|
||||
OpenID string `json:"openid"`
|
||||
SessionKey string `json:"session_key"`
|
||||
UnionID string `json:"unionid"`
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
|
||||
// WeChatUserInfo 微信用户信息
|
||||
type WeChatUserInfo struct {
|
||||
OpenID string `json:"openId"`
|
||||
NickName string `json:"nickName"`
|
||||
Gender int `json:"gender"`
|
||||
City string `json:"city"`
|
||||
Province string `json:"province"`
|
||||
Country string `json:"country"`
|
||||
AvatarURL string `json:"avatarUrl"`
|
||||
Language string `json:"language"`
|
||||
}
|
||||
|
||||
// Login 微信登录
|
||||
func (s *WeChatService) Login(code string, ip string, userAgent string) (*model.User, string, error) {
|
||||
// 验证输入参数
|
||||
if code == "" {
|
||||
return nil, "", errors.New("微信登录code不能为空")
|
||||
}
|
||||
|
||||
fmt.Printf("开始微信登录流程: code=%s\n", code)
|
||||
|
||||
// 1. 调用微信API获取openid和session_key
|
||||
wechatResp, err := s.getWeChatSession(code)
|
||||
if err != nil {
|
||||
s.logUserLogin(0, "wechat", false, fmt.Sprintf("获取微信会话失败: %v", err), ip, userAgent)
|
||||
return nil, "", fmt.Errorf("获取微信会话失败: %v", err)
|
||||
}
|
||||
|
||||
if wechatResp.ErrCode != 0 {
|
||||
errorMsg := fmt.Sprintf("微信API返回错误: code=%d, msg=%s", wechatResp.ErrCode, wechatResp.ErrMsg)
|
||||
s.logUserLogin(0, "wechat", false, errorMsg, ip, userAgent)
|
||||
return nil, "", fmt.Errorf("微信登录失败: %s", wechatResp.ErrMsg)
|
||||
}
|
||||
|
||||
fmt.Printf("成功获取微信会话: OpenID=%s\n", wechatResp.OpenID)
|
||||
|
||||
// 2. 查找或创建用户
|
||||
user, err := s.findOrCreateUser(wechatResp)
|
||||
if err != nil {
|
||||
s.logUserLogin(0, "wechat", false, fmt.Sprintf("用户处理失败: %v", err), ip, userAgent)
|
||||
return nil, "", fmt.Errorf("用户处理失败: %v", err)
|
||||
}
|
||||
|
||||
// 3. 保存微信会话信息
|
||||
if err := s.saveWeChatSession(user.ID, wechatResp); err != nil {
|
||||
s.logUserLogin(user.ID, "wechat", false, fmt.Sprintf("保存会话失败: %v", err), ip, userAgent)
|
||||
return nil, "", fmt.Errorf("保存会话失败: %v", err)
|
||||
}
|
||||
|
||||
// 4. 生成自定义登录态(JWT token)
|
||||
// 按照微信官方建议,生成自定义登录态用于维护用户登录状态
|
||||
tokenExpiry := 7 * 24 * 3600 // 7天有效期,与session_key保持一致
|
||||
token, err := jwt.GenerateToken(user.ID, "user", tokenExpiry)
|
||||
if err != nil {
|
||||
s.logUserLogin(user.ID, "wechat", false, fmt.Sprintf("生成token失败: %v", err), ip, userAgent)
|
||||
return nil, "", fmt.Errorf("生成自定义登录态失败: %v", err)
|
||||
}
|
||||
|
||||
// 5. 检查并给予每日首次登录积分
|
||||
if s.pointsService != nil {
|
||||
awarded, err := s.pointsService.CheckAndGiveDailyLoginPoints(user.ID)
|
||||
if err != nil {
|
||||
fmt.Printf("每日登录积分处理失败: %v\n", err)
|
||||
} else if awarded {
|
||||
fmt.Printf("用户 %d 获得每日首次登录积分\n", user.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// 6. 记录登录日志
|
||||
s.logUserLogin(user.ID, "wechat", true, "", ip, userAgent)
|
||||
|
||||
fmt.Printf("微信登录成功: UserID=%d, OpenID=%s, Token生成完成\n", user.ID, user.OpenID)
|
||||
return user, token, nil
|
||||
}
|
||||
|
||||
// LoginWithUserInfo 微信登录并更新用户信息
|
||||
func (s *WeChatService) LoginWithUserInfo(code string, userInfo WeChatUserInfo, ip string, userAgent string) (*model.User, string, error) {
|
||||
// 1. 先进行基本登录
|
||||
user, token, err := s.Login(code, ip, userAgent)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// 2. 更新用户信息
|
||||
if err := s.updateUserInfo(user.ID, userInfo); err != nil {
|
||||
return nil, "", fmt.Errorf("更新用户信息失败: %v", err)
|
||||
}
|
||||
|
||||
// 3. 重新获取用户信息
|
||||
updatedUser, err := s.userRepo.GetByID(user.ID)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("获取用户信息失败: %v", err)
|
||||
}
|
||||
|
||||
return updatedUser, token, nil
|
||||
}
|
||||
|
||||
// getWeChatSession 获取微信会话(按照官方文档标准实现code2Session)
|
||||
func (s *WeChatService) getWeChatSession(code string) (*WeChatLoginResponse, error) {
|
||||
// 验证code格式
|
||||
if code == "" {
|
||||
return nil, errors.New("登录凭证code不能为空")
|
||||
}
|
||||
if len(code) < 10 {
|
||||
return nil, errors.New("登录凭证code格式异常")
|
||||
}
|
||||
|
||||
// 开发模式:如果AppSecret是占位符或为空,返回模拟数据
|
||||
// 注意:当配置了真实的AppSecret时,会调用微信官方API
|
||||
if s.appSecret == "your-wechat-app-secret" || s.appSecret == "your_wechat_appsecret" || s.appSecret == "" {
|
||||
// 在开发模式下,使用固定的OpenID来模拟同一个微信用户
|
||||
// 这样可以避免每次登录都创建新用户的问题
|
||||
return &WeChatLoginResponse{
|
||||
OpenID: "dev_openid_fixed_user_001", // 使用固定的OpenID
|
||||
SessionKey: "dev_session_key_" + time.Now().Format("20060102150405"),
|
||||
UnionID: "dev_unionid_fixed_user_001", // 使用固定的UnionID
|
||||
ErrCode: 0,
|
||||
ErrMsg: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 按照微信官方文档调用auth.code2Session接口
|
||||
url := fmt.Sprintf("https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code",
|
||||
s.appID, s.appSecret, code)
|
||||
|
||||
// 创建HTTP客户端,设置超时
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
resp, err := client.Get(url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("调用微信API失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 检查HTTP状态码
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("微信API返回异常状态码: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取微信API响应失败: %v", err)
|
||||
}
|
||||
|
||||
var wechatResp WeChatLoginResponse
|
||||
if err := json.Unmarshal(body, &wechatResp); err != nil {
|
||||
return nil, fmt.Errorf("解析微信API响应失败: %v", err)
|
||||
}
|
||||
|
||||
// 检查微信API返回的错误
|
||||
if wechatResp.ErrCode != 0 {
|
||||
return nil, fmt.Errorf("微信API错误 [%d]: %s", wechatResp.ErrCode, wechatResp.ErrMsg)
|
||||
}
|
||||
|
||||
// 验证必要字段
|
||||
if wechatResp.OpenID == "" {
|
||||
return nil, errors.New("微信API未返回OpenID")
|
||||
}
|
||||
if wechatResp.SessionKey == "" {
|
||||
return nil, errors.New("微信API未返回SessionKey")
|
||||
}
|
||||
|
||||
return &wechatResp, nil
|
||||
}
|
||||
|
||||
// generateRandomUsername 生成随机用户名,格式为"用户xxxxxxxx"(包含字母和数字)
|
||||
func (s *WeChatService) generateRandomUsername() string {
|
||||
// 定义字符集:数字和小写字母
|
||||
charset := "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
|
||||
// 生成8位随机字符串
|
||||
randomSuffix := make([]byte, 8)
|
||||
for i := range randomSuffix {
|
||||
randomSuffix[i] = charset[rand.Intn(len(charset))]
|
||||
}
|
||||
|
||||
return fmt.Sprintf("用户%s", string(randomSuffix))
|
||||
}
|
||||
|
||||
// findOrCreateUser 查找或创建用户
|
||||
func (s *WeChatService) findOrCreateUser(wechatResp *WeChatLoginResponse) (*model.User, error) {
|
||||
// 验证必要参数
|
||||
if wechatResp.OpenID == "" {
|
||||
return nil, errors.New("微信OpenID不能为空")
|
||||
}
|
||||
|
||||
// 先尝试通过openid查找用户
|
||||
user, err := s.userRepo.GetByOpenID(wechatResp.OpenID)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("查询用户失败: %v", err)
|
||||
}
|
||||
} else {
|
||||
// 用户已存在,检查状态
|
||||
if user.Status == 0 {
|
||||
return nil, errors.New("用户已被禁用,请联系客服")
|
||||
}
|
||||
fmt.Printf("找到已存在用户: ID=%d, OpenID=%s, Nickname=%s\n", user.ID, user.OpenID, user.Nickname)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// 用户不存在,创建新用户
|
||||
fmt.Printf("用户不存在,开始创建新用户: OpenID=%s\n", wechatResp.OpenID)
|
||||
|
||||
// 生成随机用户名,格式为"用户xxxxxxxx"
|
||||
randomUsername := s.generateRandomUsername()
|
||||
|
||||
user = &model.User{
|
||||
OpenID: wechatResp.OpenID,
|
||||
UnionID: wechatResp.UnionID,
|
||||
Nickname: randomUsername,
|
||||
Avatar: "", // 默认头像为空,后续可通过授权获取
|
||||
Status: 1, // 1表示正常状态
|
||||
Level: 1, // 初始等级为1
|
||||
Gender: 0, // 0表示未知性别
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(user); err != nil {
|
||||
return nil, fmt.Errorf("创建用户失败: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("成功创建新用户: ID=%d, OpenID=%s, Nickname=%s\n", user.ID, user.OpenID, user.Nickname)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// saveWeChatSession 保存微信会话信息(安全存储session_key)
|
||||
func (s *WeChatService) saveWeChatSession(userID uint, wechatResp *WeChatLoginResponse) error {
|
||||
// session_key是敏感信息,需要安全存储
|
||||
// 在生产环境中,建议对session_key进行加密存储
|
||||
|
||||
// 计算session_key过期时间(微信session_key有效期通常为7天)
|
||||
sessionExpiry := time.Now().Add(7 * 24 * time.Hour)
|
||||
|
||||
// 简单示例:保存到用户表的额外字段中
|
||||
// 在生产环境中,建议使用专门的会话表或Redis等缓存存储
|
||||
updates := map[string]interface{}{
|
||||
"open_id": wechatResp.OpenID,
|
||||
"wechat_session_key": wechatResp.SessionKey, // 生产环境中应加密存储
|
||||
"union_id": wechatResp.UnionID,
|
||||
"session_expiry": sessionExpiry,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
if err := s.db.Model(&model.User{}).Where("id = ?", userID).Updates(updates).Error; err != nil {
|
||||
return fmt.Errorf("保存微信会话信息失败: %v", err)
|
||||
}
|
||||
|
||||
// 记录会话创建日志
|
||||
fmt.Printf("用户 %d 的微信会话已保存,OpenID: %s, 过期时间: %s\n",
|
||||
userID, wechatResp.OpenID, sessionExpiry.Format("2006-01-02 15:04:05"))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateUserInfo 更新用户信息
|
||||
func (s *WeChatService) updateUserInfo(userID uint, userInfo WeChatUserInfo) error {
|
||||
updates := map[string]interface{}{
|
||||
"nickname": userInfo.NickName,
|
||||
"avatar": userInfo.AvatarURL,
|
||||
"gender": userInfo.Gender,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Update(userID, updates); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取用户的openid(从ai_users表中获取)
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取用户信息失败: %v", err)
|
||||
}
|
||||
|
||||
// 保存详细的微信用户信息
|
||||
wechatUserInfo := struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
UserID uint `gorm:"column:user_id;not null;unique"`
|
||||
OpenID string `gorm:"column:openid;not null;unique"`
|
||||
Nickname string `gorm:"column:nickname"`
|
||||
AvatarURL string `gorm:"column:avatar_url"`
|
||||
Gender int `gorm:"column:gender"`
|
||||
Country string `gorm:"column:country"`
|
||||
Province string `gorm:"column:province"`
|
||||
City string `gorm:"column:city"`
|
||||
Language string `gorm:"column:language"`
|
||||
CreatedAt time.Time `gorm:"column:created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at"`
|
||||
}{
|
||||
UserID: userID,
|
||||
OpenID: user.OpenID, // 使用从数据库获取的openid
|
||||
Nickname: userInfo.NickName,
|
||||
AvatarURL: userInfo.AvatarURL,
|
||||
Gender: userInfo.Gender,
|
||||
Country: userInfo.Country,
|
||||
Province: userInfo.Province,
|
||||
City: userInfo.City,
|
||||
Language: userInfo.Language,
|
||||
}
|
||||
|
||||
return s.db.Table("ai_wechat_user_info").Save(&wechatUserInfo).Error
|
||||
}
|
||||
|
||||
// logUserLogin 记录用户登录日志
|
||||
func (s *WeChatService) logUserLogin(userID uint, loginType string, success bool, errorMsg string, ip string, userAgent string) {
|
||||
status := 1
|
||||
if !success {
|
||||
status = 0
|
||||
}
|
||||
|
||||
// 使用LogService创建登录日志
|
||||
logService := NewLogService(s.db)
|
||||
remark := loginType
|
||||
if errorMsg != "" {
|
||||
remark = fmt.Sprintf("%s: %s", loginType, errorMsg)
|
||||
}
|
||||
|
||||
err := logService.CreateLoginLog(userID, ip, userAgent, status, remark)
|
||||
if err != nil {
|
||||
fmt.Printf("创建登录日志失败: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserSession 获取用户会话信息
|
||||
func (s *WeChatService) GetUserSession(userID uint) (map[string]interface{}, error) {
|
||||
var user model.User
|
||||
err := s.db.Where("id = ?", userID).First(&user).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("用户不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询用户失败: %v", err)
|
||||
}
|
||||
|
||||
// 检查session是否过期
|
||||
if user.SessionExpiry != nil && user.SessionExpiry.Before(time.Now()) {
|
||||
return nil, errors.New("会话已过期")
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"session_key": user.WeChatSessionKey,
|
||||
"openid": user.OpenID,
|
||||
"unionid": user.UnionID,
|
||||
"expires_at": user.SessionExpiry,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateSessionKey 验证session_key有效性
|
||||
func (s *WeChatService) ValidateSessionKey(userID uint) (bool, error) {
|
||||
session, err := s.GetUserSession(userID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// 检查session_key是否存在
|
||||
sessionKey, ok := session["session_key"].(string)
|
||||
if !ok || sessionKey == "" {
|
||||
return false, errors.New("session_key不存在")
|
||||
}
|
||||
|
||||
// 检查过期时间
|
||||
expiresAt, ok := session["expires_at"].(*time.Time)
|
||||
if ok && expiresAt != nil && expiresAt.Before(time.Now()) {
|
||||
return false, errors.New("session_key已过期")
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
911
server/internal/service/wechat_pay.go
Normal file
911
server/internal/service/wechat_pay.go
Normal file
@@ -0,0 +1,911 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"dianshang/internal/config"
|
||||
"dianshang/internal/model"
|
||||
"dianshang/internal/repository"
|
||||
"dianshang/pkg/logger"
|
||||
"dianshang/pkg/utils"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/wechatpay-apiv3/wechatpay-go/core"
|
||||
"github.com/wechatpay-apiv3/wechatpay-go/core/option"
|
||||
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi"
|
||||
"github.com/wechatpay-apiv3/wechatpay-go/services/refunddomestic"
|
||||
wechatutils "github.com/wechatpay-apiv3/wechatpay-go/utils"
|
||||
)
|
||||
|
||||
type WeChatPayService struct {
|
||||
config *config.WeChatPayConfig
|
||||
client *core.Client
|
||||
jsapiSvc *jsapi.JsapiApiService
|
||||
refundSvc *refunddomestic.RefundsApiService
|
||||
privateKey *rsa.PrivateKey
|
||||
orderRepo *repository.OrderRepository
|
||||
refundSvcRef *RefundService
|
||||
}
|
||||
|
||||
func NewWeChatPayService(cfg *config.WeChatPayConfig, orderRepo *repository.OrderRepository, refundService *RefundService) (*WeChatPayService, error) {
|
||||
// 检查是否为沙盒环境
|
||||
if cfg.Environment == "sandbox" {
|
||||
logger.Info("微信支付配置为沙盒模式,将使用模拟支付")
|
||||
return &WeChatPayService{
|
||||
config: cfg,
|
||||
orderRepo: orderRepo,
|
||||
refundSvcRef: refundService,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 生产环境:加载商户私钥
|
||||
privateKey, err := wechatutils.LoadPrivateKeyWithPath(cfg.KeyPath)
|
||||
if err != nil {
|
||||
logger.Warn("加载商户私钥失败,将使用模拟模式", "error", err)
|
||||
// 在开发环境下允许没有私钥,使用模拟模式
|
||||
return &WeChatPayService{
|
||||
config: cfg,
|
||||
orderRepo: orderRepo,
|
||||
refundSvcRef: refundService,
|
||||
}, nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
// 使用商户私钥等初始化 client,并使它具有自动定时获取微信支付平台证书的能力
|
||||
opts := []core.ClientOption{
|
||||
option.WithWechatPayAutoAuthCipher(cfg.MchID, cfg.SerialNo, privateKey, cfg.APIv3Key),
|
||||
}
|
||||
|
||||
client, err := core.NewClient(ctx, opts...)
|
||||
if err != nil {
|
||||
logger.Warn("初始化微信支付客户端失败,将使用模拟模式", "error", err)
|
||||
// 在开发环境下允许客户端初始化失败,使用模拟模式
|
||||
return &WeChatPayService{
|
||||
config: cfg,
|
||||
orderRepo: orderRepo,
|
||||
refundSvcRef: refundService,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 创建JSAPI服务
|
||||
jsapiSvc := &jsapi.JsapiApiService{Client: client}
|
||||
|
||||
// 创建退款服务
|
||||
refundSvc := &refunddomestic.RefundsApiService{Client: client}
|
||||
|
||||
logger.Info("微信支付客户端初始化成功",
|
||||
"mchId", cfg.MchID,
|
||||
"serialNo", cfg.SerialNo,
|
||||
"environment", cfg.Environment)
|
||||
|
||||
return &WeChatPayService{
|
||||
config: cfg,
|
||||
client: client,
|
||||
jsapiSvc: jsapiSvc,
|
||||
refundSvc: refundSvc,
|
||||
privateKey: privateKey,
|
||||
orderRepo: orderRepo,
|
||||
refundSvcRef: refundService,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateOrder 创建支付订单
|
||||
func (s *WeChatPayService) CreateOrder(ctx context.Context, order *model.Order, openID string) (*WeChatPayResponse, error) {
|
||||
// 生成唯一的微信支付订单号
|
||||
wechatOutTradeNo := utils.GenerateWechatOutTradeNo()
|
||||
|
||||
logger.Info("开始创建微信支付订单",
|
||||
"orderNo", order.OrderNo,
|
||||
"wechatOutTradeNo", wechatOutTradeNo,
|
||||
"openID", openID,
|
||||
"totalAmount", order.TotalAmount,
|
||||
"hasClient", s.client != nil)
|
||||
|
||||
// 更新订单的微信支付订单号
|
||||
err := s.orderRepo.UpdateByOrderNo(order.OrderNo, map[string]interface{}{
|
||||
"wechat_out_trade_no": wechatOutTradeNo,
|
||||
"updated_at": time.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("更新订单微信支付订单号失败", "error", err, "orderNo", order.OrderNo)
|
||||
return nil, fmt.Errorf("更新订单失败: %v", err)
|
||||
}
|
||||
|
||||
// 如果没有客户端(开发环境),使用模拟数据
|
||||
if s.client == nil {
|
||||
logger.Warn("开发环境下使用模拟支付数据")
|
||||
return s.createMockPayment(order, openID)
|
||||
}
|
||||
|
||||
// 构建预支付请求,使用唯一的微信支付订单号
|
||||
req := jsapi.PrepayRequest{
|
||||
Appid: core.String(s.config.AppID),
|
||||
Mchid: core.String(s.config.MchID),
|
||||
Description: core.String(fmt.Sprintf("订单号: %s", order.OrderNo)),
|
||||
OutTradeNo: core.String(wechatOutTradeNo), // 使用唯一的微信支付订单号
|
||||
NotifyUrl: core.String(s.config.NotifyURL),
|
||||
Amount: &jsapi.Amount{
|
||||
Total: core.Int64(int64(order.TotalAmount)), // 金额已经是分为单位,无需转换
|
||||
Currency: core.String("CNY"),
|
||||
},
|
||||
Payer: &jsapi.Payer{
|
||||
Openid: core.String(openID),
|
||||
},
|
||||
}
|
||||
|
||||
// 使用PrepayWithRequestPayment方法,直接获取调起支付的参数
|
||||
resp, result, err := s.jsapiSvc.PrepayWithRequestPayment(ctx, req)
|
||||
if err != nil {
|
||||
log.Printf("call PrepayWithRequestPayment err:%s", err)
|
||||
logger.Error("创建支付订单失败", "error", err, "orderNo", order.OrderNo)
|
||||
return nil, fmt.Errorf("创建支付订单失败: %v", err)
|
||||
}
|
||||
|
||||
if result.Response.StatusCode != 200 {
|
||||
log.Printf("PrepayWithRequestPayment status=%d", result.Response.StatusCode)
|
||||
return nil, fmt.Errorf("预支付请求失败,状态码: %d", result.Response.StatusCode)
|
||||
}
|
||||
|
||||
log.Printf("PrepayWithRequestPayment success, prepay_id=%s", *resp.PrepayId)
|
||||
logger.Info("微信支付API响应",
|
||||
"prepayId", *resp.PrepayId,
|
||||
"orderNo", order.OrderNo)
|
||||
|
||||
// 直接使用SDK返回的支付参数
|
||||
payParams := &MiniProgramPayParams{
|
||||
AppID: *resp.Appid,
|
||||
TimeStamp: *resp.TimeStamp,
|
||||
NonceStr: *resp.NonceStr,
|
||||
Package: *resp.Package,
|
||||
SignType: *resp.SignType,
|
||||
PaySign: *resp.PaySign,
|
||||
}
|
||||
|
||||
return &WeChatPayResponse{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: map[string]interface{}{
|
||||
"payInfo": payParams,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createMockPayment 创建模拟支付数据(沙盒环境使用)
|
||||
func (s *WeChatPayService) createMockPayment(order *model.Order, openID string) (*WeChatPayResponse, error) {
|
||||
mockPrepayID := fmt.Sprintf("wx%d%s", time.Now().Unix(), generateNonceStr()[:8])
|
||||
timestamp := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
nonceStr := generateNonceStr()
|
||||
|
||||
// 生成更真实的模拟签名
|
||||
mockSign := fmt.Sprintf("sandbox_%s_%s", nonceStr[:16], timestamp)
|
||||
|
||||
payParams := &MiniProgramPayParams{
|
||||
AppID: s.config.AppID,
|
||||
TimeStamp: timestamp,
|
||||
NonceStr: nonceStr,
|
||||
Package: fmt.Sprintf("prepay_id=%s", mockPrepayID),
|
||||
SignType: "RSA",
|
||||
PaySign: mockSign,
|
||||
}
|
||||
|
||||
logger.Info("生成沙盒支付参数",
|
||||
"environment", s.config.Environment,
|
||||
"prepayId", mockPrepayID,
|
||||
"orderNo", order.OrderNo,
|
||||
"openID", openID,
|
||||
"totalAmount", order.TotalAmount,
|
||||
"description", fmt.Sprintf("订单号: %s", order.OrderNo))
|
||||
|
||||
return &WeChatPayResponse{
|
||||
Code: 0,
|
||||
Message: "沙盒支付创建成功",
|
||||
Data: map[string]interface{}{
|
||||
"payInfo": payParams,
|
||||
"sandbox": true,
|
||||
"tips": "这是沙盒环境的模拟支付,可以直接调用成功",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// QueryOrder 查询订单
|
||||
func (s *WeChatPayService) QueryOrder(ctx context.Context, orderNo string) (*model.Order, error) {
|
||||
logger.Info("开始查询订单",
|
||||
"orderNo", orderNo,
|
||||
"hasClient", s.client != nil,
|
||||
"environment", s.config.Environment)
|
||||
|
||||
// 如果没有客户端(沙盒环境或开发环境),返回模拟数据
|
||||
if s.client == nil {
|
||||
if s.config.Environment == "sandbox" {
|
||||
logger.Info("沙盒环境下返回模拟查询结果")
|
||||
} else {
|
||||
logger.Warn("开发环境下返回模拟查询结果")
|
||||
}
|
||||
|
||||
// 模拟不同的支付状态,让测试更真实
|
||||
var status int
|
||||
if time.Now().Unix()%3 == 0 {
|
||||
status = 1 // 未付款
|
||||
} else {
|
||||
status = 2 // 已付款
|
||||
}
|
||||
|
||||
return &model.Order{
|
||||
OrderNo: orderNo,
|
||||
TotalAmount: 100.0, // 模拟金额
|
||||
Status: status,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 首先从数据库获取订单信息
|
||||
order, err := s.orderRepo.GetByOrderNo(orderNo)
|
||||
if err != nil {
|
||||
logger.Error("从数据库获取订单失败", "error", err, "orderNo", orderNo)
|
||||
return nil, fmt.Errorf("订单不存在: %v", err)
|
||||
}
|
||||
|
||||
// 如果没有微信支付订单号,说明还没有创建过微信支付订单
|
||||
if order.WechatOutTradeNo == "" {
|
||||
logger.Warn("订单尚未创建微信支付订单", "orderNo", orderNo)
|
||||
return order, nil
|
||||
}
|
||||
|
||||
// 使用微信支付订单号查询微信支付状态
|
||||
req := jsapi.QueryOrderByOutTradeNoRequest{
|
||||
OutTradeNo: core.String(order.WechatOutTradeNo),
|
||||
Mchid: core.String(s.config.MchID),
|
||||
}
|
||||
|
||||
resp, result, err := s.jsapiSvc.QueryOrderByOutTradeNo(ctx, req)
|
||||
if err != nil {
|
||||
log.Printf("call QueryOrderByOutTradeNo err:%s", err)
|
||||
logger.Error("查询微信支付订单失败", "error", err, "wechatOutTradeNo", order.WechatOutTradeNo)
|
||||
return nil, fmt.Errorf("查询微信支付订单失败: %v", err)
|
||||
}
|
||||
|
||||
if result.Response.StatusCode != 200 {
|
||||
log.Printf("QueryOrderByOutTradeNo status=%d", result.Response.StatusCode)
|
||||
return nil, fmt.Errorf("查询微信支付订单失败,状态码: %d", result.Response.StatusCode)
|
||||
}
|
||||
|
||||
log.Printf("QueryOrderByOutTradeNo success, resp=%+v", resp)
|
||||
logger.Info("查询微信支付订单成功",
|
||||
"orderNo", orderNo,
|
||||
"wechatOutTradeNo", order.WechatOutTradeNo,
|
||||
"tradeState", *resp.TradeState)
|
||||
|
||||
// 更新订单的微信交易号和支付状态
|
||||
wechatStatus := convertWeChatPayStatus(*resp.TradeState)
|
||||
updates := map[string]interface{}{
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
// 如果有微信交易号,保存到数据库
|
||||
if resp.TransactionId != nil {
|
||||
updates["wechat_transaction_id"] = *resp.TransactionId
|
||||
}
|
||||
|
||||
// 如果微信支付状态是已支付,更新订单状态
|
||||
if wechatStatus == 2 && order.Status == 1 {
|
||||
updates["status"] = 2
|
||||
updates["pay_status"] = 1
|
||||
updates["paid_at"] = time.Now()
|
||||
}
|
||||
|
||||
// 更新订单信息
|
||||
if len(updates) > 1 { // 除了updated_at还有其他字段需要更新
|
||||
err = s.orderRepo.UpdateByOrderNo(orderNo, updates)
|
||||
if err != nil {
|
||||
logger.Error("更新订单微信支付信息失败", "error", err, "orderNo", orderNo)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新订单对象的状态和金额信息
|
||||
order.TotalAmount = float64(*resp.Amount.Total) // 保持分为单位,与系统内部一致
|
||||
order.Status = wechatStatus
|
||||
if resp.TransactionId != nil {
|
||||
order.WechatTransactionID = *resp.TransactionId
|
||||
}
|
||||
|
||||
return order, nil
|
||||
}
|
||||
|
||||
// HandleNotify 处理支付回调
|
||||
func (s *WeChatPayService) HandleNotify(ctx context.Context, body []byte, headers map[string]string) (*WeChatPayNotify, error) {
|
||||
// 解析回调数据
|
||||
var notify WeChatPayNotify
|
||||
if err := json.Unmarshal(body, ¬ify); err != nil {
|
||||
return nil, fmt.Errorf("解析回调数据失败: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("收到微信支付回调",
|
||||
"eventType", notify.EventType,
|
||||
"id", notify.ID,
|
||||
"algorithm", notify.Resource.Algorithm)
|
||||
|
||||
// 解密resource中的数据
|
||||
if notify.Resource.Ciphertext != "" {
|
||||
// 使用AEAD_AES_256_GCM算法解密
|
||||
decryptedData, err := s.decryptNotifyResource(
|
||||
notify.Resource.Ciphertext,
|
||||
notify.Resource.Nonce,
|
||||
notify.Resource.AssociatedData,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Error("解密回调数据失败", "error", err)
|
||||
return nil, fmt.Errorf("解密回调数据失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析解密后的JSON数据
|
||||
var paymentData WeChatPayNotifyData
|
||||
if err := json.Unmarshal(decryptedData, &paymentData); err != nil {
|
||||
logger.Error("解析解密数据失败", "error", err, "data", string(decryptedData))
|
||||
return nil, fmt.Errorf("解析解密数据失败: %v", err)
|
||||
}
|
||||
|
||||
notify.DecryptedData = &paymentData
|
||||
logger.Info("成功解密回调数据",
|
||||
"outTradeNo", paymentData.OutTradeNo,
|
||||
"transactionID", paymentData.TransactionID,
|
||||
"tradeState", paymentData.TradeState)
|
||||
}
|
||||
|
||||
return ¬ify, nil
|
||||
}
|
||||
|
||||
// decryptNotifyResource 解密回调通知中的resource数据
|
||||
func (s *WeChatPayService) decryptNotifyResource(ciphertext, nonce, associatedData string) ([]byte, error) {
|
||||
// 使用wechatpay-go SDK提供的解密工具
|
||||
plaintext, err := wechatutils.DecryptAES256GCM(s.config.APIv3Key, associatedData, nonce, ciphertext)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("AES解密失败: %v", err)
|
||||
}
|
||||
return []byte(plaintext), nil
|
||||
}
|
||||
|
||||
// ProcessPaymentSuccess 处理支付成功回调
|
||||
func (s *WeChatPayService) ProcessPaymentSuccess(ctx context.Context, notify *WeChatPayNotify) error {
|
||||
if notify.EventType != "TRANSACTION.SUCCESS" {
|
||||
return fmt.Errorf("不是支付成功回调: %s", notify.EventType)
|
||||
}
|
||||
|
||||
logger.Info("开始处理支付成功回调", "eventType", notify.EventType)
|
||||
|
||||
// 解析回调数据中的订单信息
|
||||
var orderNo string
|
||||
var transactionID string
|
||||
|
||||
// 如果有解密数据,从中获取订单号
|
||||
if notify.DecryptedData != nil {
|
||||
orderNo = notify.DecryptedData.OutTradeNo
|
||||
transactionID = notify.DecryptedData.TransactionID
|
||||
logger.Info("从解密数据中获取订单信息", "orderNo", orderNo, "transactionID", transactionID)
|
||||
} else {
|
||||
// 开发环境下,可能需要从Resource字段中解析
|
||||
// 或者从其他地方获取订单号
|
||||
logger.Warn("回调数据中没有解密数据,尝试从Resource字段获取")
|
||||
|
||||
// 在开发环境下,我们可以尝试解析Resource中的数据
|
||||
if notify.Resource.Ciphertext != "" {
|
||||
// 这里可以添加解密逻辑,但在开发环境下我们先跳过
|
||||
logger.Info("Resource中有加密数据,但开发环境暂不解密")
|
||||
}
|
||||
|
||||
// 如果无法获取订单号,我们可以从最近的订单中查找
|
||||
// 这是一个临时的开发环境解决方案
|
||||
logger.Warn("无法从回调数据中获取订单号,这可能是开发环境的模拟回调")
|
||||
return fmt.Errorf("无法从回调数据中获取订单号")
|
||||
}
|
||||
|
||||
if orderNo == "" {
|
||||
return fmt.Errorf("回调数据中缺少订单号")
|
||||
}
|
||||
|
||||
logger.Info("处理支付成功回调", "orderNo", orderNo)
|
||||
|
||||
// 查询订单
|
||||
order, err := s.orderRepo.GetOrderByWechatOutTradeNo(orderNo)
|
||||
if err != nil {
|
||||
logger.Error("根据微信订单号查询订单失败", "error", err, "wechatOutTradeNo", orderNo)
|
||||
return fmt.Errorf("订单不存在: %v", err)
|
||||
}
|
||||
|
||||
// 检查订单状态,避免重复处理
|
||||
if order.Status >= 2 {
|
||||
logger.Info("订单已经是已支付状态,跳过处理", "orderNo", order.OrderNo, "status", order.Status)
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info("开始更新订单状态", "orderNo", order.OrderNo, "currentStatus", order.Status)
|
||||
|
||||
// 更新订单状态为已支付
|
||||
updates := map[string]interface{}{
|
||||
"status": 2, // 已支付
|
||||
"pay_status": 1, // 已支付
|
||||
"pay_time": time.Now(),
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
// 如果有微信交易号,也保存
|
||||
if transactionID != "" {
|
||||
updates["wechat_transaction_id"] = transactionID
|
||||
logger.Info("保存微信交易号", "transactionID", transactionID)
|
||||
}
|
||||
|
||||
err = s.orderRepo.UpdateByOrderNo(order.OrderNo, updates)
|
||||
if err != nil {
|
||||
logger.Error("更新订单支付状态失败", "error", err, "orderNo", order.OrderNo)
|
||||
return fmt.Errorf("更新订单状态失败: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("订单支付状态更新成功", "orderNo", order.OrderNo, "newStatus", 2)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProcessPaymentSuccessByOrderNo 根据订单号手动处理支付成功(用于测试)
|
||||
func (s *WeChatPayService) ProcessPaymentSuccessByOrderNo(ctx context.Context, orderNo string) error {
|
||||
logger.Info("手动处理支付成功", "orderNo", orderNo)
|
||||
|
||||
// 查询订单
|
||||
order, err := s.orderRepo.GetByOrderNo(orderNo)
|
||||
if err != nil {
|
||||
logger.Error("查询订单失败", "error", err, "orderNo", orderNo)
|
||||
return fmt.Errorf("订单不存在: %v", err)
|
||||
}
|
||||
|
||||
// 检查订单状态,避免重复处理
|
||||
if order.Status >= 2 {
|
||||
logger.Info("订单已经是已支付状态,跳过处理", "orderNo", order.OrderNo, "status", order.Status)
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info("开始更新订单状态", "orderNo", order.OrderNo, "currentStatus", order.Status)
|
||||
|
||||
// 更新订单状态为已支付
|
||||
updates := map[string]interface{}{
|
||||
"status": 2, // 已支付
|
||||
"pay_status": 1, // 已支付
|
||||
"pay_time": time.Now(),
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
err = s.orderRepo.UpdateByOrderNo(order.OrderNo, updates)
|
||||
if err != nil {
|
||||
logger.Error("更新订单支付状态失败", "error", err, "orderNo", order.OrderNo)
|
||||
return fmt.Errorf("更新订单状态失败: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("订单支付状态更新成功", "orderNo", order.OrderNo, "newStatus", 2)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
|
||||
// CreateRefund 创建微信退款
|
||||
func (s *WeChatPayService) CreateRefund(ctx context.Context, refundRecord *model.Refund, order *model.Order) (*WeChatRefundResponse, error) {
|
||||
logger.Info("开始创建微信退款",
|
||||
"refundNo", refundRecord.RefundNo,
|
||||
"orderNo", order.OrderNo,
|
||||
"refundAmount", refundRecord.RefundAmount,
|
||||
"hasClient", s.client != nil)
|
||||
|
||||
// 如果没有客户端(开发环境),使用模拟数据
|
||||
if s.client == nil {
|
||||
logger.Warn("开发环境下使用模拟退款数据")
|
||||
return s.createMockRefund(refundRecord, order)
|
||||
}
|
||||
|
||||
// 构建退款请求
|
||||
req := refunddomestic.CreateRequest{
|
||||
OutTradeNo: core.String(order.WechatOutTradeNo),
|
||||
OutRefundNo: core.String(refundRecord.WechatOutRefundNo),
|
||||
Reason: core.String(refundRecord.RefundReason),
|
||||
FundsAccount: (*refunddomestic.ReqFundsAccount)(core.String("AVAILABLE")), // 可用余额退款
|
||||
Amount: &refunddomestic.AmountReq{
|
||||
Refund: core.Int64(int64(refundRecord.RefundAmount)),
|
||||
Total: core.Int64(int64(order.TotalAmount)),
|
||||
Currency: core.String("CNY"),
|
||||
},
|
||||
}
|
||||
|
||||
// 只有当RefundNotifyURL不为空时才设置NotifyUrl
|
||||
if s.config.RefundNotifyURL != "" {
|
||||
req.NotifyUrl = core.String(s.config.RefundNotifyURL)
|
||||
}
|
||||
|
||||
// 如果有微信交易号,优先使用
|
||||
if order.WechatTransactionID != "" {
|
||||
req.TransactionId = core.String(order.WechatTransactionID)
|
||||
req.OutTradeNo = nil // 使用微信交易号时,不需要商户订单号
|
||||
}
|
||||
|
||||
// 调用微信退款API
|
||||
resp, result, err := s.refundSvc.Create(ctx, req)
|
||||
if err != nil {
|
||||
log.Printf("call CreateRefund err:%s", err)
|
||||
logger.Error("创建微信退款失败", "error", err, "refundNo", refundRecord.RefundNo)
|
||||
return nil, fmt.Errorf("创建微信退款失败: %v", err)
|
||||
}
|
||||
|
||||
if result.Response.StatusCode != 200 {
|
||||
log.Printf("CreateRefund status=%d", result.Response.StatusCode)
|
||||
return nil, fmt.Errorf("微信退款请求失败,状态码: %d", result.Response.StatusCode)
|
||||
}
|
||||
|
||||
log.Printf("CreateRefund success, refund_id=%s", *resp.RefundId)
|
||||
logger.Info("微信退款API响应",
|
||||
"refundId", *resp.RefundId,
|
||||
"refundNo", refundRecord.RefundNo,
|
||||
"status", *resp.Status)
|
||||
|
||||
return &WeChatRefundResponse{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: map[string]interface{}{
|
||||
"refund_id": *resp.RefundId,
|
||||
"out_refund_no": *resp.OutRefundNo,
|
||||
"transaction_id": getStringValue(resp.TransactionId),
|
||||
"out_trade_no": getStringValue(resp.OutTradeNo),
|
||||
"channel": getChannelValue(resp.Channel),
|
||||
"user_received_account": getStringValue(resp.UserReceivedAccount),
|
||||
"success_time": getTimeValue(resp.SuccessTime),
|
||||
"create_time": getTimeValue(resp.CreateTime),
|
||||
"status": getStatusValue(*resp.Status),
|
||||
"funds_account": getFundsAccountValue(resp.FundsAccount),
|
||||
"amount": map[string]interface{}{
|
||||
"total": *resp.Amount.Total,
|
||||
"refund": *resp.Amount.Refund,
|
||||
"payer_total": getInt64Value(resp.Amount.PayerTotal),
|
||||
"payer_refund": getInt64Value(resp.Amount.PayerRefund),
|
||||
"settlement_refund": getInt64Value(resp.Amount.SettlementRefund),
|
||||
"settlement_total": getInt64Value(resp.Amount.SettlementTotal),
|
||||
"discount_refund": getInt64Value(resp.Amount.DiscountRefund),
|
||||
"currency": *resp.Amount.Currency,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createMockRefund 创建模拟退款数据(开发环境使用)
|
||||
func (s *WeChatPayService) createMockRefund(refundRecord *model.Refund, order *model.Order) (*WeChatRefundResponse, error) {
|
||||
logger.Info("创建模拟退款数据", "refundNo", refundRecord.RefundNo)
|
||||
|
||||
// 生成模拟的微信退款ID
|
||||
mockRefundID := fmt.Sprintf("mock_refund_%d", time.Now().Unix())
|
||||
|
||||
return &WeChatRefundResponse{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: map[string]interface{}{
|
||||
"refund_id": mockRefundID,
|
||||
"out_refund_no": refundRecord.WechatOutRefundNo,
|
||||
"transaction_id": order.WechatTransactionID,
|
||||
"out_trade_no": order.WechatOutTradeNo,
|
||||
"channel": "ORIGINAL",
|
||||
"user_received_account": "招商银行信用卡0403",
|
||||
"success_time": time.Now().Format("2006-01-02T15:04:05+08:00"),
|
||||
"create_time": time.Now().Format("2006-01-02T15:04:05+08:00"),
|
||||
"status": "SUCCESS",
|
||||
"funds_account": "AVAILABLE",
|
||||
"amount": map[string]interface{}{
|
||||
"total": int64(order.TotalAmount),
|
||||
"refund": int64(refundRecord.RefundAmount),
|
||||
"payer_total": int64(order.TotalAmount),
|
||||
"payer_refund": int64(refundRecord.RefundAmount),
|
||||
"settlement_refund": int64(refundRecord.RefundAmount),
|
||||
"settlement_total": int64(order.TotalAmount),
|
||||
"discount_refund": int64(0),
|
||||
"currency": "CNY",
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// QueryRefund 查询微信退款状态
|
||||
func (s *WeChatPayService) QueryRefund(ctx context.Context, outRefundNo string) (*model.Refund, error) {
|
||||
logger.Info("查询微信退款状态", "outRefundNo", outRefundNo)
|
||||
|
||||
// 如果没有客户端(开发环境),返回模拟数据
|
||||
if s.client == nil {
|
||||
logger.Warn("开发环境下使用模拟退款查询")
|
||||
return s.queryMockRefund(outRefundNo)
|
||||
}
|
||||
|
||||
// 构建查询请求
|
||||
req := refunddomestic.QueryByOutRefundNoRequest{
|
||||
OutRefundNo: core.String(outRefundNo),
|
||||
}
|
||||
|
||||
// 调用微信查询退款API
|
||||
resp, result, err := s.refundSvc.QueryByOutRefundNo(ctx, req)
|
||||
if err != nil {
|
||||
log.Printf("call QueryRefund err:%s", err)
|
||||
logger.Error("查询微信退款失败", "error", err, "outRefundNo", outRefundNo)
|
||||
return nil, fmt.Errorf("查询微信退款失败: %v", err)
|
||||
}
|
||||
|
||||
if result.Response.StatusCode != 200 {
|
||||
log.Printf("QueryRefund status=%d", result.Response.StatusCode)
|
||||
return nil, fmt.Errorf("查询微信退款失败,状态码: %d", result.Response.StatusCode)
|
||||
}
|
||||
|
||||
log.Printf("QueryRefund success, resp=%+v", resp)
|
||||
logger.Info("查询微信退款成功",
|
||||
"outRefundNo", outRefundNo,
|
||||
"refundId", *resp.RefundId,
|
||||
"status", *resp.Status)
|
||||
|
||||
// 构建返回的退款记录(这里只是示例,实际应该从数据库获取完整记录)
|
||||
refundRecord := &model.Refund{
|
||||
WechatRefundID: *resp.RefundId,
|
||||
WechatOutRefundNo: *resp.OutRefundNo,
|
||||
WechatRefundStatus: getStatusValue(*resp.Status),
|
||||
WechatUserReceivedAccount: getStringValue(resp.UserReceivedAccount),
|
||||
WechatRefundAccount: getFundsAccountValue(resp.FundsAccount),
|
||||
}
|
||||
|
||||
// 如果退款成功,设置成功时间
|
||||
if getStatusValue(*resp.Status) == "SUCCESS" && resp.SuccessTime != nil {
|
||||
successTime, err := time.Parse("2006-01-02T15:04:05+08:00", getTimeValue(resp.SuccessTime))
|
||||
if err == nil {
|
||||
refundRecord.WechatSuccessTime = &successTime
|
||||
}
|
||||
}
|
||||
|
||||
return refundRecord, nil
|
||||
}
|
||||
|
||||
// queryMockRefund 查询模拟退款数据(开发环境使用)
|
||||
func (s *WeChatPayService) queryMockRefund(outRefundNo string) (*model.Refund, error) {
|
||||
logger.Info("查询模拟退款数据", "outRefundNo", outRefundNo)
|
||||
|
||||
now := time.Now()
|
||||
return &model.Refund{
|
||||
WechatRefundID: fmt.Sprintf("mock_refund_%d", now.Unix()),
|
||||
WechatOutRefundNo: outRefundNo,
|
||||
WechatRefundStatus: "SUCCESS",
|
||||
WechatUserReceivedAccount: "招商银行信用卡0403",
|
||||
WechatRefundAccount: "AVAILABLE",
|
||||
WechatSuccessTime: &now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// HandleRefundNotify 处理微信退款回调
|
||||
func (s *WeChatPayService) HandleRefundNotify(ctx context.Context, body []byte, headers map[string]string) (*WeChatRefundNotify, error) {
|
||||
// 解析回调数据
|
||||
var notify WeChatRefundNotify
|
||||
if err := json.Unmarshal(body, ¬ify); err != nil {
|
||||
return nil, fmt.Errorf("解析退款回调数据失败: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("收到微信退款回调",
|
||||
"eventType", notify.EventType,
|
||||
"id", notify.ID,
|
||||
"algorithm", notify.Resource.Algorithm)
|
||||
|
||||
// 解密resource中的数据
|
||||
if notify.Resource.Ciphertext != "" {
|
||||
// 使用AEAD_AES_256_GCM算法解密
|
||||
decryptedData, err := s.decryptNotifyResource(
|
||||
notify.Resource.Ciphertext,
|
||||
notify.Resource.Nonce,
|
||||
notify.Resource.AssociatedData,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Error("解密退款回调数据失败", "error", err)
|
||||
return nil, fmt.Errorf("解密退款回调数据失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析解密后的JSON数据
|
||||
var refundData WeChatRefundNotifyData
|
||||
if err := json.Unmarshal(decryptedData, &refundData); err != nil {
|
||||
logger.Error("解析解密数据失败", "error", err, "data", string(decryptedData))
|
||||
return nil, fmt.Errorf("解析解密数据失败: %v", err)
|
||||
}
|
||||
|
||||
notify.DecryptedData = &refundData
|
||||
logger.Info("成功解密退款回调数据",
|
||||
"outRefundNo", refundData.OutRefundNo,
|
||||
"refundId", refundData.RefundId,
|
||||
"refundStatus", refundData.RefundStatus)
|
||||
}
|
||||
|
||||
return ¬ify, nil
|
||||
}
|
||||
|
||||
// ProcessRefundSuccess 处理退款成功回调
|
||||
func (s *WeChatPayService) ProcessRefundSuccess(ctx context.Context, notify *WeChatRefundNotify) error {
|
||||
if notify.EventType != "REFUND.SUCCESS" {
|
||||
return fmt.Errorf("不是退款成功回调: %s", notify.EventType)
|
||||
}
|
||||
|
||||
logger.Info("开始处理退款成功回调", "eventType", notify.EventType)
|
||||
|
||||
// 解析回调数据中的退款信息
|
||||
var outRefundNo string
|
||||
var refundID string
|
||||
|
||||
// 如果有解密数据,从中获取退款单号
|
||||
if notify.DecryptedData != nil {
|
||||
outRefundNo = notify.DecryptedData.OutRefundNo
|
||||
refundID = notify.DecryptedData.RefundId
|
||||
logger.Info("从解密数据中获取退款信息", "outRefundNo", outRefundNo, "refundId", refundID)
|
||||
} else {
|
||||
logger.Warn("退款回调数据中没有解密数据")
|
||||
return fmt.Errorf("无法从回调数据中获取退款单号")
|
||||
}
|
||||
|
||||
if outRefundNo == "" {
|
||||
return fmt.Errorf("回调数据中缺少退款单号")
|
||||
}
|
||||
|
||||
logger.Info("处理退款成功回调", "outRefundNo", outRefundNo)
|
||||
|
||||
// 这里应该调用退款服务来更新退款状态
|
||||
// 由于这是在微信支付服务中,我们只记录日志,实际更新由退款服务处理
|
||||
logger.Info("退款成功回调处理完成", "outRefundNo", outRefundNo, "refundId", refundID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func getStringValue(ptr *string) string {
|
||||
if ptr == nil {
|
||||
return ""
|
||||
}
|
||||
return *ptr
|
||||
}
|
||||
|
||||
func getInt64Value(ptr *int64) int64 {
|
||||
if ptr == nil {
|
||||
return 0
|
||||
}
|
||||
return *ptr
|
||||
}
|
||||
|
||||
func getChannelValue(ptr *refunddomestic.Channel) string {
|
||||
if ptr == nil {
|
||||
return ""
|
||||
}
|
||||
return string(*ptr)
|
||||
}
|
||||
|
||||
func getFundsAccountValue(ptr *refunddomestic.FundsAccount) string {
|
||||
if ptr == nil {
|
||||
return ""
|
||||
}
|
||||
return string(*ptr)
|
||||
}
|
||||
|
||||
func getTimeValue(ptr *time.Time) string {
|
||||
if ptr == nil {
|
||||
return ""
|
||||
}
|
||||
return ptr.Format("2006-01-02T15:04:05+08:00")
|
||||
}
|
||||
|
||||
func getStatusValue(status refunddomestic.Status) string {
|
||||
return string(status)
|
||||
}
|
||||
|
||||
// generateNonceStr 生成随机字符串用于微信支付
|
||||
func generateNonceStr() string {
|
||||
return utils.GenerateRandomString(32)
|
||||
}
|
||||
|
||||
// convertWeChatPayStatus 将微信支付状态转换为订单状态
|
||||
func convertWeChatPayStatus(wechatStatus string) int {
|
||||
switch wechatStatus {
|
||||
case "SUCCESS":
|
||||
return model.OrderStatusPaid
|
||||
case "REFUND":
|
||||
return model.OrderStatusRefunded
|
||||
case "NOTPAY":
|
||||
return model.OrderStatusPending
|
||||
case "CLOSED":
|
||||
return model.OrderStatusCancelled
|
||||
case "REVOKED":
|
||||
return model.OrderStatusCancelled
|
||||
case "USERPAYING":
|
||||
return model.OrderStatusPending
|
||||
case "PAYERROR":
|
||||
return model.OrderStatusCancelled
|
||||
default:
|
||||
return model.OrderStatusPending
|
||||
}
|
||||
}
|
||||
|
||||
// 微信支付相关数据结构
|
||||
type MiniProgramPayParams struct {
|
||||
AppID string `json:"appId"`
|
||||
TimeStamp string `json:"timeStamp"`
|
||||
NonceStr string `json:"nonceStr"`
|
||||
Package string `json:"package"`
|
||||
SignType string `json:"signType"`
|
||||
PaySign string `json:"paySign"`
|
||||
}
|
||||
|
||||
type WeChatPayResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
}
|
||||
|
||||
type WeChatPayNotify struct {
|
||||
ID string `json:"id"`
|
||||
CreateTime string `json:"create_time"`
|
||||
ResourceType string `json:"resource_type"`
|
||||
EventType string `json:"event_type"`
|
||||
Summary string `json:"summary"`
|
||||
Resource struct {
|
||||
OriginalType string `json:"original_type"`
|
||||
Algorithm string `json:"algorithm"`
|
||||
Ciphertext string `json:"ciphertext"`
|
||||
AssociatedData string `json:"associated_data"`
|
||||
Nonce string `json:"nonce"`
|
||||
} `json:"resource"`
|
||||
DecryptedData *WeChatPayNotifyData `json:"decrypted_data,omitempty"`
|
||||
}
|
||||
|
||||
type WeChatPayNotifyData struct {
|
||||
MchID string `json:"mchid"`
|
||||
AppID string `json:"appid"`
|
||||
OutTradeNo string `json:"out_trade_no"`
|
||||
TransactionID string `json:"transaction_id"`
|
||||
TradeType string `json:"trade_type"`
|
||||
TradeState string `json:"trade_state"`
|
||||
BankType string `json:"bank_type"`
|
||||
SuccessTime string `json:"success_time"`
|
||||
Payer struct {
|
||||
OpenID string `json:"openid"`
|
||||
} `json:"payer"`
|
||||
Amount struct {
|
||||
Total int `json:"total"`
|
||||
PayerTotal int `json:"payer_total"`
|
||||
Currency string `json:"currency"`
|
||||
PayerCurrency string `json:"payer_currency"`
|
||||
} `json:"amount"`
|
||||
}
|
||||
|
||||
type WeChatRefundResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
}
|
||||
|
||||
type WeChatRefundNotify struct {
|
||||
ID string `json:"id"`
|
||||
CreateTime string `json:"create_time"`
|
||||
ResourceType string `json:"resource_type"`
|
||||
EventType string `json:"event_type"`
|
||||
Summary string `json:"summary"`
|
||||
Resource struct {
|
||||
OriginalType string `json:"original_type"`
|
||||
Algorithm string `json:"algorithm"`
|
||||
Ciphertext string `json:"ciphertext"`
|
||||
AssociatedData string `json:"associated_data"`
|
||||
Nonce string `json:"nonce"`
|
||||
} `json:"resource"`
|
||||
DecryptedData *WeChatRefundNotifyData `json:"decrypted_data,omitempty"`
|
||||
}
|
||||
|
||||
type WeChatRefundNotifyData struct {
|
||||
MchID string `json:"mchid"`
|
||||
OutTradeNo string `json:"out_trade_no"`
|
||||
TransactionID string `json:"transaction_id"`
|
||||
OutRefundNo string `json:"out_refund_no"`
|
||||
RefundId string `json:"refund_id"`
|
||||
RefundStatus string `json:"refund_status"`
|
||||
SuccessTime string `json:"success_time"`
|
||||
UserReceivedAccount string `json:"user_received_account"`
|
||||
Amount struct {
|
||||
Total int `json:"total"`
|
||||
Refund int `json:"refund"`
|
||||
PayerTotal int `json:"payer_total"`
|
||||
PayerRefund int `json:"payer_refund"`
|
||||
} `json:"amount"`
|
||||
}
|
||||
Reference in New Issue
Block a user