Files
ai_english/serve/internal/services/user_service.go
2025-11-17 13:39:05 +08:00

349 lines
9.6 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package services
import (
"errors"
"fmt"
"time"
"gorm.io/gorm"
"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/common"
"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/models"
"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/utils"
)
// UserService 用户服务
type UserService struct {
db *gorm.DB
}
// NewUserService 创建用户服务实例
func NewUserService(db *gorm.DB) *UserService {
return &UserService{db: db}
}
// CreateUser 创建用户
func (s *UserService) CreateUser(username, email, password string) (*models.User, error) {
// 检查用户名是否已存在
var existingUser models.User
if err := s.db.Where("username = ? OR email = ?", username, email).First(&existingUser).Error; err == nil {
if existingUser.Username == username {
return nil, common.ErrUsernameExists
}
if existingUser.Email == email {
return nil, common.ErrEmailExists
}
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
// 生成密码哈希
passwordHash, err := utils.HashPassword(password)
if err != nil {
return nil, err
}
// 创建用户
user := &models.User{
Username: username,
Email: email,
PasswordHash: passwordHash,
Status: "active",
Timezone: "Asia/Shanghai",
Language: "zh-CN",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := s.db.Create(user).Error; err != nil {
return nil, err
}
// 创建用户偏好设置
preference := &models.UserPreference{
UserID: user.ID,
DailyGoal: 50,
WeeklyGoal: 350,
ReminderEnabled: true,
DifficultyLevel: "beginner",
LearningMode: "casual",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := s.db.Create(preference).Error; err != nil {
// 如果创建偏好设置失败,记录日志但不影响用户创建
// 可以在这里添加日志记录
}
return user, nil
}
// GetUserByID 根据ID获取用户
func (s *UserService) GetUserByID(userID int64) (*models.User, error) {
var user models.User
if err := s.db.Preload("Preferences").Preload("SocialLinks").Where("id = ?", userID).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, common.ErrUserNotFound
}
return nil, err
}
return &user, nil
}
// GetUserByEmail 根据邮箱获取用户
func (s *UserService) GetUserByEmail(email string) (*models.User, error) {
var user models.User
if err := s.db.Where("email = ?", email).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, common.ErrUserNotFound
}
return nil, err
}
return &user, nil
}
// GetUserByUsername 根据用户名获取用户
func (s *UserService) GetUserByUsername(username string) (*models.User, error) {
var user models.User
if err := s.db.Where("username = ?", username).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, common.ErrUserNotFound
}
return nil, err
}
return &user, nil
}
// UpdateUser 更新用户信息
func (s *UserService) UpdateUser(userID int64, updates map[string]interface{}) (*models.User, error) {
// 检查用户是否存在
var user models.User
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, common.ErrUserNotFound
}
return nil, err
}
// 如果更新邮箱,检查邮箱是否已被其他用户使用
if email, ok := updates["email"]; ok {
var existingUser models.User
if err := s.db.Where("email = ? AND id != ?", email, userID).First(&existingUser).Error; err == nil {
return nil, common.ErrEmailExists
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
}
// 如果更新用户名,检查用户名是否已被其他用户使用
if username, ok := updates["username"]; ok {
var existingUser models.User
if err := s.db.Where("username = ? AND id != ?", username, userID).First(&existingUser).Error; err == nil {
return nil, common.ErrUsernameExists
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
}
// 更新时间戳
updates["updated_at"] = time.Now()
// 执行更新
if err := s.db.Model(&user).Updates(updates).Error; err != nil {
return nil, err
}
// 重新获取更新后的用户信息
return s.GetUserByID(userID)
}
// UpdatePassword 更新用户密码
func (s *UserService) UpdatePassword(userID int64, oldPassword, newPassword string) error {
// 获取用户
var user models.User
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return common.ErrUserNotFound
}
return err
}
// 验证旧密码
if !utils.CheckPasswordHash(oldPassword, user.PasswordHash) {
return common.ErrInvalidPassword
}
// 生成新密码哈希
newPasswordHash, err := utils.HashPassword(newPassword)
if err != nil {
return err
}
// 更新密码
return s.db.Model(&user).Updates(map[string]interface{}{
"password_hash": newPasswordHash,
"updated_at": time.Now(),
}).Error
}
// UpdateLoginInfo 更新登录信息
func (s *UserService) UpdateLoginInfo(userID int64, loginIP string) error {
now := time.Now()
return s.db.Model(&models.User{}).Where("id = ?", userID).Updates(map[string]interface{}{
"last_login_at": &now,
"last_login_ip": loginIP,
"login_count": gorm.Expr("login_count + 1"),
"updated_at": now,
}).Error
}
// VerifyPassword 验证密码
func (s *UserService) VerifyPassword(userID int64, password string) error {
var user models.User
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return common.ErrUserNotFound
}
return err
}
if !utils.CheckPasswordHash(password, user.PasswordHash) {
return common.ErrInvalidPassword
}
return nil
}
// DeleteUser 删除用户
func (s *UserService) DeleteUser(userID int64) error {
// 检查用户是否存在
var user models.User
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return common.ErrUserNotFound
}
return err
}
// 软删除用户
return s.db.Delete(&user).Error
}
// GetUserPreferences 获取用户偏好设置
func (s *UserService) GetUserPreferences(userID int64) (*models.UserPreference, error) {
var preference models.UserPreference
if err := s.db.Where("user_id = ?", userID).First(&preference).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, common.ErrUserNotFound
}
return nil, err
}
return &preference, nil
}
// UpdateUserPreferences 更新用户偏好设置
func (s *UserService) UpdateUserPreferences(userID int64, updates map[string]interface{}) (*models.UserPreference, error) {
// 检查偏好设置是否存在
var preference models.UserPreference
if err := s.db.Where("user_id = ?", userID).First(&preference).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, common.ErrUserNotFound
}
return nil, err
}
// 更新时间戳
updates["updated_at"] = time.Now()
// 执行更新
if err := s.db.Model(&preference).Updates(updates).Error; err != nil {
return nil, err
}
// 重新获取更新后的偏好设置
return s.GetUserPreferences(userID)
}
// GetUserLearningProgress 获取用户学习进度
func (s *UserService) GetUserLearningProgress(userID string, page, limit int) ([]map[string]interface{}, int64, error) {
// 初始化为空切片而不是nil避免JSON序列化为null
progressList := []map[string]interface{}{}
var total int64
// 查询用户在各个词汇书的学习进度
query := `
SELECT
vb.id,
vb.name as title,
vb.description as category,
vb.level,
COUNT(DISTINCT vbw.vocabulary_id) as total_words,
COUNT(DISTINCT CASE WHEN uwp.status IN ('learning', 'reviewing', 'mastered') THEN uwp.vocabulary_id END) as learned_words,
MAX(uwp.last_studied_at) as last_study_date
FROM ai_vocabulary_books vb
LEFT JOIN ai_vocabulary_book_words vbw ON vbw.book_id = vb.id
LEFT JOIN ai_user_word_progress uwp ON CAST(uwp.vocabulary_id AS UNSIGNED) = CAST(vbw.vocabulary_id AS UNSIGNED) AND uwp.user_id = ?
WHERE vb.is_system = 1
GROUP BY vb.id, vb.name, vb.description, vb.level
HAVING total_words > 0
ORDER BY last_study_date DESC
`
// 获取总数
countQuery := `
SELECT COUNT(*) FROM (
SELECT vb.id
FROM ai_vocabulary_books vb
LEFT JOIN ai_vocabulary_book_words vbw ON vbw.book_id = vb.id
WHERE vb.is_system = 1
GROUP BY vb.id
HAVING COUNT(DISTINCT vbw.vocabulary_id) > 0
) as subquery
`
if err := s.db.Raw(countQuery).Scan(&total).Error; err != nil {
return nil, 0, err
}
// 分页查询
offset := (page - 1) * limit
rows, err := s.db.Raw(query+" LIMIT ? OFFSET ?", userID, limit, offset).Rows()
if err != nil {
return nil, 0, err
}
defer rows.Close()
for rows.Next() {
var (
id int64
title string
category string
level string
totalWords int
learnedWords int
lastStudyDate *time.Time
)
if err := rows.Scan(&id, &title, &category, &level, &totalWords, &learnedWords, &lastStudyDate); err != nil {
continue
}
progress := float64(0)
if totalWords > 0 {
progress = float64(learnedWords) / float64(totalWords)
}
progressList = append(progressList, map[string]interface{}{
"id": fmt.Sprintf("%d", id),
"title": title,
"category": category,
"level": level,
"total_words": totalWords,
"learned_words": learnedWords,
"progress": progress,
"last_study_date": lastStudyDate,
})
}
return progressList, total, nil
}