Files
2025-11-28 15:18:10 +08:00

158 lines
4.6 KiB
Go

package repository
import (
"dianshang/internal/model"
"gorm.io/gorm"
)
// UserRepository is the data-access layer for users and the records attached
// to them: addresses (model.UserAddress) and favorites (model.UserFavorite).
// Every method issues its query through the wrapped *gorm.DB.
type UserRepository struct {
	db *gorm.DB // handle every query in this repository starts from
}
// NewUserRepository returns a UserRepository backed by db. The handle is
// stored as given (no new session is opened), so the repository inherits
// db's configuration.
func NewUserRepository(db *gorm.DB) *UserRepository {
	repo := &UserRepository{}
	repo.db = db
	return repo
}
// Create inserts user through GORM's Create and returns the database error,
// if any.
func (r *UserRepository) Create(user *model.User) error {
	result := r.db.Create(user)
	return result.Error
}
// GetByID loads the user whose primary key is id.
//
// The returned pointer is never nil, even when err is non-nil (for example
// gorm.ErrRecordNotFound when no row matches), so callers must check err
// before trusting the result.
func (r *UserRepository) GetByID(id uint) (*model.User, error) {
	found := new(model.User)
	result := r.db.First(found, id)
	return found, result.Error
}
// GetByOpenID loads the first user whose open_id column equals openID.
//
// A non-nil *model.User is returned on every path, including the error path;
// check err first.
func (r *UserRepository) GetByOpenID(openID string) (*model.User, error) {
	u := &model.User{}
	if err := r.db.Where("open_id = ?", openID).First(u).Error; err != nil {
		return u, err
	}
	return u, nil
}
// GetByEmail loads the first user whose email column equals email.
//
// Like the other single-user getters, it returns a non-nil pointer even when
// err is non-nil.
func (r *UserRepository) GetByEmail(email string) (*model.User, error) {
	var u model.User
	byEmail := r.db.Where("email = ?", email)
	err := byEmail.First(&u).Error
	return &u, err
}
// Update writes the column/value pairs in updates to the user row with the
// given id. Because updates is a map, GORM writes zero values too (struct
// updates would skip them).
func (r *UserRepository) Update(id uint, updates map[string]interface{}) error {
	target := r.db.Model(&model.User{}).Where("id = ?", id)
	return target.Updates(updates).Error
}
// Delete removes the user with primary key id via GORM's Delete.
//
// NOTE(review): the original comment calls this a soft delete; that holds
// only if model.User carries a gorm.DeletedAt field (e.g. embeds gorm.Model).
// Confirm against the model.
func (r *UserRepository) Delete(id uint) error {
	result := r.db.Delete(&model.User{}, id)
	return result.Error
}
// BatchDelete removes every user whose primary key is in ids. An empty or
// nil slice is a no-op and returns nil.
//
// NOTE(review): described as a soft delete; that depends on model.User having
// a gorm.DeletedAt field. Confirm.
func (r *UserRepository) BatchDelete(ids []uint) error {
	// Only hit the database when there is something to delete. With no ids
	// there would be no primary-key filter for the DELETE to use.
	if len(ids) > 0 {
		return r.db.Delete(&model.User{}, ids).Error
	}
	return nil
}
// GetList returns one page of users matching conditions, together with the
// total number of matching users before pagination.
//
// conditions maps column names to required values. The entries are combined
// with AND as equality filters. The map is passed to GORM as a map condition
// instead of being concatenated into SQL text, which has these effects:
//   - column names are quoted as identifiers by the dialect rather than spliced
//     in raw, so a crafted key can no longer rewrite the WHERE clause;
//   - a nil value becomes "col IS NULL" instead of the never-true "col = NULL";
//   - a slice value becomes "col IN (...)";
//   - keys are applied in a deterministic (sorted) order.
//
// An empty or nil map applies no filter.
//
// offset and limit are passed straight to GORM's Offset/Limit. If the count
// query fails, GetList returns (nil, 0, err).
func (r *UserRepository) GetList(offset, limit int, conditions map[string]interface{}) ([]model.User, int64, error) {
	query := r.db.Model(&model.User{})
	if len(conditions) > 0 {
		query = query.Where(conditions)
	}

	// Count with the same filters before pagination is applied.
	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var users []model.User
	err := query.Offset(offset).Limit(limit).Find(&users).Error
	return users, total, err
}
// GetAddresses returns all addresses owned by userID. The default address
// comes first, followed by the rest, newest first.
func (r *UserRepository) GetAddresses(userID uint) ([]model.UserAddress, error) {
	const ordering = "is_default DESC, created_at DESC"
	var list []model.UserAddress
	err := r.db.Where("user_id = ?", userID).Order(ordering).Find(&list).Error
	return list, err
}
// CreateAddress inserts address through GORM's Create and returns the
// database error, if any.
func (r *UserRepository) CreateAddress(address *model.UserAddress) error {
	result := r.db.Create(address)
	return result.Error
}
// GetAddressByID loads the address whose primary key is id. It does not check
// which user owns the address; callers must do that themselves.
//
// A non-nil pointer is returned even when err is non-nil (e.g.
// gorm.ErrRecordNotFound).
func (r *UserRepository) GetAddressByID(id uint) (*model.UserAddress, error) {
	addr := new(model.UserAddress)
	result := r.db.First(addr, id)
	return addr, result.Error
}
// UpdateAddress writes the column/value pairs in updates to the address row
// with the given id. Zero values in the map are written as well.
func (r *UserRepository) UpdateAddress(id uint, updates map[string]interface{}) error {
	target := r.db.Model(&model.UserAddress{}).Where("id = ?", id)
	return target.Updates(updates).Error
}
// DeleteAddress removes the address with primary key id via GORM's Delete.
// Whether this is a hard or soft delete depends on model.UserAddress.
func (r *UserRepository) DeleteAddress(id uint) error {
	result := r.db.Delete(&model.UserAddress{}, id)
	return result.Error
}
// ClearDefaultAddress sets is_default to 0 on every address owned by userID,
// leaving the user with no default address.
func (r *UserRepository) ClearDefaultAddress(userID uint) error {
	owned := r.db.Model(&model.UserAddress{}).Where("user_id = ?", userID)
	return owned.Update("is_default", 0).Error
}
// GetDefaultAddress loads the first address owned by userID with
// is_default = 1.
//
// A non-nil pointer is returned even when err is non-nil. A user without a
// default address yields GORM's record-not-found error.
func (r *UserRepository) GetDefaultAddress(userID uint) (*model.UserAddress, error) {
	addr := &model.UserAddress{}
	if err := r.db.Where("user_id = ? AND is_default = 1", userID).First(addr).Error; err != nil {
		return addr, err
	}
	return addr, nil
}
// GetFavorites returns one page of userID's favorites, newest first, with
// each favorite's Product association preloaded. It also returns the total
// number of favorites for the user.
//
// If the count query fails, GetFavorites returns (nil, 0, err).
func (r *UserRepository) GetFavorites(userID uint, offset, limit int) ([]model.UserFavorite, int64, error) {
	base := r.db.Model(&model.UserFavorite{}).Where("user_id = ?", userID)

	var total int64
	if err := base.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var page []model.UserFavorite
	err := base.
		Preload("Product").
		Offset(offset).
		Limit(limit).
		Order("created_at DESC").
		Find(&page).Error
	return page, total, err
}
// CreateFavorite inserts favorite through GORM's Create and returns the
// database error, if any.
func (r *UserRepository) CreateFavorite(favorite *model.UserFavorite) error {
	result := r.db.Create(favorite)
	return result.Error
}
// DeleteFavorite removes every favorite row linking userID to productID.
// Removing a favorite that does not exist is not reported as an error here.
func (r *UserRepository) DeleteFavorite(userID, productID uint) error {
	match := r.db.Where("user_id = ? AND product_id = ?", userID, productID)
	return match.Delete(&model.UserFavorite{}).Error
}
// FavoriteExists reports whether userID has favorited productID.
//
// Any database error is returned to the caller instead of being folded into
// a false result.
func (r *UserRepository) FavoriteExists(userID, productID uint) (bool, error) {
	var count int64
	err := r.db.Model(&model.UserFavorite{}).
		Where("user_id = ? AND product_id = ?", userID, productID).
		Count(&count).Error
	if err != nil {
		return false, err
	}
	return count > 0, nil
}

// IsFavorite reports whether userID has favorited productID.
//
// It keeps the original bool-only signature, so it cannot surface errors: a
// failed lookup is reported as "not favorited". Callers that need to tell
// "no" apart from "lookup failed" should use FavoriteExists.
func (r *UserRepository) IsFavorite(userID, productID uint) bool {
	exists, err := r.FavoriteExists(userID, productID)
	// Deliberate best-effort: err has nowhere to go in this signature.
	return err == nil && exists
}