package services

import (
	"database/sql"
	"errors"
	"fmt"
	"log"
	"strconv"
	"time"

	"gorm.io/gorm"

	"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/common"
	"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/models"
	"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/utils"
)

// UserService provides user-account, preference and learning-progress
// operations backed by a GORM database handle.
type UserService struct {
	db *gorm.DB
}

// NewUserService returns a UserService that issues all queries through db.
func NewUserService(db *gorm.DB) *UserService {
	return &UserService{db: db}
}

// findUser loads the first user matching query/args through db (which may
// carry Preload clauses) and maps a missing row to common.ErrUserNotFound.
func (s *UserService) findUser(db *gorm.DB, query string, args ...interface{}) (*models.User, error) {
	var user models.User
	if err := db.Where(query, args...).First(&user).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, common.ErrUserNotFound
		}
		return nil, err
	}
	return &user, nil
}

// columnTaken reports whether some user other than excludeID already has
// value in column. Passing excludeID 0 considers every user (assumes real IDs
// start at 1 — TODO confirm). column is concatenated into the SQL, so it must
// be a hard-coded column name, never caller-supplied input.
func (s *UserService) columnTaken(column string, value interface{}, excludeID int64) (bool, error) {
	query := s.db.Model(&models.User{}).Where(column+" = ?", value)
	if excludeID != 0 {
		query = query.Where("id != ?", excludeID)
	}
	var count int64
	if err := query.Count(&count).Error; err != nil {
		return false, err
	}
	return count > 0, nil
}

// withUpdatedAt returns a copy of updates with "updated_at" set to now.
// Copying leaves the caller's map untouched and makes a nil map safe to pass
// (writing into a nil map would panic).
func withUpdatedAt(updates map[string]interface{}, now time.Time) map[string]interface{} {
	changes := make(map[string]interface{}, len(updates)+1)
	for k, v := range updates {
		changes[k] = v
	}
	changes["updated_at"] = now
	return changes
}

// CreateUser registers a new account and seeds its default preferences.
//
// It returns common.ErrUsernameExists when username is already in use and
// common.ErrEmailExists when email is (username is checked first). Each check
// lets the database decide equality, so any match the database reports is
// treated as a conflict; re-comparing the values in Go with == could disagree
// with a case-insensitive collation and let the request fall through to the
// INSERT.
//
// Only the hash produced by utils.HashPassword is stored. The new user starts
// with Status "active", Timezone "Asia/Shanghai" and Language "zh-CN".
//
// Creating the default UserPreference row is deliberately best-effort: on
// failure the error is logged and the already-created user is still returned.
func (s *UserService) CreateUser(username, email, password string) (*models.User, error) {
	taken, err := s.columnTaken("username", username, 0)
	if err != nil {
		return nil, err
	}
	if taken {
		return nil, common.ErrUsernameExists
	}

	taken, err = s.columnTaken("email", email, 0)
	if err != nil {
		return nil, err
	}
	if taken {
		return nil, common.ErrEmailExists
	}

	passwordHash, err := utils.HashPassword(password)
	if err != nil {
		return nil, err
	}

	now := time.Now()
	user := &models.User{
		Username:     username,
		Email:        email,
		PasswordHash: passwordHash,
		Status:       "active",
		Timezone:     "Asia/Shanghai",
		Language:     "zh-CN",
		CreatedAt:    now,
		UpdatedAt:    now,
	}
	if err := s.db.Create(user).Error; err != nil {
		return nil, err
	}

	preference := &models.UserPreference{
		UserID:          user.ID,
		DailyGoal:       50,
		WeeklyGoal:      350,
		ReminderEnabled: true,
		DifficultyLevel: "beginner",
		LearningMode:    "casual",
		CreatedAt:       now,
		UpdatedAt:       now,
	}
	if err := s.db.Create(preference).Error; err != nil {
		// The account already exists; a missing preference row must not fail
		// registration, but it should not disappear silently either.
		log.Printf("services: creating default preferences for user %v: %v", user.ID, err)
	}

	return user, nil
}

// GetUserByID returns the user with the given ID, preloading its Preferences
// and SocialLinks associations. It returns common.ErrUserNotFound when no such
// user exists.
func (s *UserService) GetUserByID(userID int64) (*models.User, error) {
	return s.findUser(s.db.Preload("Preferences").Preload("SocialLinks"), "id = ?", userID)
}

// GetUserByEmail returns the user with the given email, or
// common.ErrUserNotFound.
func (s *UserService) GetUserByEmail(email string) (*models.User, error) {
	return s.findUser(s.db, "email = ?", email)
}

// GetUserByUsername returns the user with the given username, or
// common.ErrUserNotFound.
func (s *UserService) GetUserByUsername(username string) (*models.User, error) {
	return s.findUser(s.db, "username = ?", username)
}

// UpdateUser applies updates (column name -> new value) to the user and
// returns the reloaded user, as GetUserByID does.
//
// It returns common.ErrUserNotFound if the user does not exist, and
// common.ErrEmailExists / common.ErrUsernameExists if "email" / "username"
// in updates already belongs to a different user. "updated_at" is always set
// to the current time. The caller's map is not modified, and nil is accepted.
//
// NOTE(review): every key in updates is written as-is, so callers must
// restrict it to columns the requester may change.
func (s *UserService) UpdateUser(userID int64, updates map[string]interface{}) (*models.User, error) {
	user, err := s.findUser(s.db, "id = ?", userID)
	if err != nil {
		return nil, err
	}

	if email, ok := updates["email"]; ok {
		taken, err := s.columnTaken("email", email, userID)
		if err != nil {
			return nil, err
		}
		if taken {
			return nil, common.ErrEmailExists
		}
	}

	if username, ok := updates["username"]; ok {
		taken, err := s.columnTaken("username", username, userID)
		if err != nil {
			return nil, err
		}
		if taken {
			return nil, common.ErrUsernameExists
		}
	}

	if err := s.db.Model(user).Updates(withUpdatedAt(updates, time.Now())).Error; err != nil {
		return nil, err
	}

	return s.GetUserByID(userID)
}

// UpdatePassword replaces the user's password hash after checking oldPassword
// against the stored hash.
//
// It returns common.ErrUserNotFound if the user does not exist and
// common.ErrInvalidPassword if oldPassword does not match.
func (s *UserService) UpdatePassword(userID int64, oldPassword, newPassword string) error {
	user, err := s.findUser(s.db, "id = ?", userID)
	if err != nil {
		return err
	}

	if !utils.CheckPasswordHash(oldPassword, user.PasswordHash) {
		return common.ErrInvalidPassword
	}

	newPasswordHash, err := utils.HashPassword(newPassword)
	if err != nil {
		return err
	}

	return s.db.Model(user).Updates(map[string]interface{}{
		"password_hash": newPasswordHash,
		"updated_at":    time.Now(),
	}).Error
}

// UpdateLoginInfo records a login: it sets last_login_at/updated_at to now,
// stores loginIP, and increments login_count in the database (not in Go).
// It does not report an error when no user has the given ID.
func (s *UserService) UpdateLoginInfo(userID int64, loginIP string) error {
	now := time.Now()
	return s.db.Model(&models.User{}).Where("id = ?", userID).Updates(map[string]interface{}{
		"last_login_at": &now,
		"last_login_ip": loginIP,
		"login_count":   gorm.Expr("login_count + 1"),
		"updated_at":    now,
	}).Error
}

// VerifyPassword checks password against the user's stored hash.
//
// It returns nil on a match, common.ErrInvalidPassword on a mismatch and
// common.ErrUserNotFound if the user does not exist.
func (s *UserService) VerifyPassword(userID int64, password string) error {
	user, err := s.findUser(s.db, "id = ?", userID)
	if err != nil {
		return err
	}
	if !utils.CheckPasswordHash(password, user.PasswordHash) {
		return common.ErrInvalidPassword
	}
	return nil
}

// DeleteUser deletes the user, returning common.ErrUserNotFound if it does
// not exist. GORM performs a soft delete only if models.User declares a
// gorm.DeletedAt field (not visible here — TODO confirm).
func (s *UserService) DeleteUser(userID int64) error {
	user, err := s.findUser(s.db, "id = ?", userID)
	if err != nil {
		return err
	}
	return s.db.Delete(user).Error
}

// GetUserPreferences returns the preference row for userID.
//
// NOTE(review): a missing preference row is reported as
// common.ErrUserNotFound (kept for compatibility), even if the user exists.
func (s *UserService) GetUserPreferences(userID int64) (*models.UserPreference, error) {
	var preference models.UserPreference
	if err := s.db.Where("user_id = ?", userID).First(&preference).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, common.ErrUserNotFound
		}
		return nil, err
	}
	return &preference, nil
}

// UpdateUserPreferences applies updates (column name -> new value) to the
// user's preference row and returns the reloaded row. "updated_at" is always
// set to the current time. The caller's map is not modified, and nil is
// accepted. It returns common.ErrUserNotFound if the preference row is
// missing.
func (s *UserService) UpdateUserPreferences(userID int64, updates map[string]interface{}) (*models.UserPreference, error) {
	preference, err := s.GetUserPreferences(userID)
	if err != nil {
		return nil, err
	}

	if err := s.db.Model(preference).Updates(withUpdatedAt(updates, time.Now())).Error; err != nil {
		return nil, err
	}

	return s.GetUserPreferences(userID)
}

// learningProgressQuery aggregates, per system vocabulary book that has at
// least one word, the book's word count and how many of those words the user
// (the single ? placeholder) has in status learning/reviewing/mastered.
// vocabulary_id is CAST on both sides of the join, presumably because the two
// tables store it with different column types — TODO confirm. vb.id breaks
// ties in the ORDER BY so that LIMIT/OFFSET paging is deterministic (all
// never-studied books share a NULL last_study_date).
const learningProgressQuery = `
	SELECT
		vb.id,
		vb.name as title,
		vb.description as category,
		vb.level,
		COUNT(DISTINCT vbw.vocabulary_id) as total_words,
		COUNT(DISTINCT CASE WHEN uwp.status IN ('learning', 'reviewing', 'mastered') THEN uwp.vocabulary_id END) as learned_words,
		MAX(uwp.last_studied_at) as last_study_date
	FROM ai_vocabulary_books vb
	LEFT JOIN ai_vocabulary_book_words vbw ON vbw.book_id = vb.id
	LEFT JOIN ai_user_word_progress uwp ON CAST(uwp.vocabulary_id AS UNSIGNED) = CAST(vbw.vocabulary_id AS UNSIGNED) AND uwp.user_id = ?
	WHERE vb.is_system = 1
	GROUP BY vb.id, vb.name, vb.description, vb.level
	HAVING total_words > 0
	ORDER BY last_study_date DESC, vb.id ASC
`

// learningProgressCountQuery counts the books learningProgressQuery can
// return (system books with at least one word), independent of the user.
const learningProgressCountQuery = `
	SELECT COUNT(*) FROM (
		SELECT vb.id
		FROM ai_vocabulary_books vb
		LEFT JOIN ai_vocabulary_book_words vbw ON vbw.book_id = vb.id
		WHERE vb.is_system = 1
		GROUP BY vb.id
		HAVING COUNT(DISTINCT vbw.vocabulary_id) > 0
	) as subquery
`

// GetUserLearningProgress returns one page of the user's progress across the
// system vocabulary books (is_system = 1) that contain at least one word,
// together with the total number of such books.
//
// page is 1-based (values below 1 are treated as 1) and limit is the page
// size (negative values are treated as 0). Books are ordered by the user's
// latest study date, newest first, then by book ID. Each entry has keys id
// (decimal string), title, category, level, total_words, learned_words,
// progress (learned_words/total_words, 0 when there are no words) and
// last_study_date (nil if never studied). NULL text columns become "".
//
// The returned slice is never nil on success, so it encodes as [] rather
// than null in JSON. Any row scan or iteration error is returned instead of
// silently dropping rows.
func (s *UserService) GetUserLearningProgress(userID string, page, limit int) ([]map[string]interface{}, int64, error) {
	if page < 1 {
		page = 1
	}
	if limit < 0 {
		limit = 0
	}

	var total int64
	if err := s.db.Raw(learningProgressCountQuery).Scan(&total).Error; err != nil {
		return nil, 0, err
	}

	offset := (page - 1) * limit
	rows, err := s.db.Raw(learningProgressQuery+" LIMIT ? OFFSET ?", userID, limit, offset).Rows()
	if err != nil {
		return nil, 0, err
	}
	defer rows.Close()

	progressList := []map[string]interface{}{}
	for rows.Next() {
		var (
			id                       int64
			title, category, level   sql.NullString // nullable text columns
			totalWords, learnedWords int
			// Assumes the driver yields DATETIME as time.Time (e.g. MySQL DSN
			// parseTime=true) — TODO confirm.
			lastStudyDate *time.Time
		)
		if err := rows.Scan(&id, &title, &category, &level, &totalWords, &learnedWords, &lastStudyDate); err != nil {
			return nil, 0, fmt.Errorf("scanning learning progress row: %w", err)
		}

		progress := 0.0
		if totalWords > 0 {
			progress = float64(learnedWords) / float64(totalWords)
		}

		progressList = append(progressList, map[string]interface{}{
			"id":              strconv.FormatInt(id, 10),
			"title":           title.String,
			"category":        category.String,
			"level":           level.String,
			"total_words":     totalWords,
			"learned_words":   learnedWords,
			"progress":        progress,
			"last_study_date": lastStudyDate,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("iterating learning progress rows: %w", err)
	}

	return progressList, total, nil
}