357 lines
8.8 KiB
Go
357 lines
8.8 KiB
Go
|
|
package handlers
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"net/http"
|
|||
|
|
|
|||
|
|
"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/common"
|
|||
|
|
"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/middleware"
|
|||
|
|
"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/models"
|
|||
|
|
"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/services"
|
|||
|
|
"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/utils"
|
|||
|
|
"github.com/gin-gonic/gin"
|
|||
|
|
"github.com/go-playground/validator/v10"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
type AuthHandler struct {
|
|||
|
|
userService *services.UserService
|
|||
|
|
validator *validator.Validate
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func NewAuthHandler(userService *services.UserService) *AuthHandler {
|
|||
|
|
return &AuthHandler{
|
|||
|
|
userService: userService,
|
|||
|
|
validator: validator.New(),
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// RegisterRequest 注册请求结构
|
|||
|
|
type RegisterRequest struct {
|
|||
|
|
Username string `json:"username" validate:"required,min=3,max=20"`
|
|||
|
|
Email string `json:"email" validate:"required,email"`
|
|||
|
|
Password string `json:"password" validate:"required,min=6"`
|
|||
|
|
Nickname string `json:"nickname" validate:"required,min=1,max=50"`
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// LoginRequest 登录请求结构
|
|||
|
|
type LoginRequest struct {
|
|||
|
|
Account string `json:"account" validate:"required"` // 用户名或邮箱
|
|||
|
|
Password string `json:"password" validate:"required"`
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// RefreshTokenRequest 刷新令牌请求结构
|
|||
|
|
type RefreshTokenRequest struct {
|
|||
|
|
RefreshToken string `json:"refresh_token" validate:"required"`
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ChangePasswordRequest 修改密码请求结构
|
|||
|
|
type ChangePasswordRequest struct {
|
|||
|
|
OldPassword string `json:"old_password" validate:"required"`
|
|||
|
|
NewPassword string `json:"new_password" validate:"required,min=6"`
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// AuthResponse 认证响应结构
|
|||
|
|
type AuthResponse struct {
|
|||
|
|
User *UserInfo `json:"user"`
|
|||
|
|
AccessToken string `json:"access_token"`
|
|||
|
|
RefreshToken string `json:"refresh_token"`
|
|||
|
|
ExpiresIn int64 `json:"expires_in"`
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// UserInfo 用户信息结构
|
|||
|
|
type UserInfo struct {
|
|||
|
|
ID int64 `json:"id"`
|
|||
|
|
Username string `json:"username"`
|
|||
|
|
Email string `json:"email"`
|
|||
|
|
Nickname string `json:"nickname"`
|
|||
|
|
Avatar string `json:"avatar"`
|
|||
|
|
Level string `json:"level"`
|
|||
|
|
Status string `json:"status"`
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Register 用户注册
|
|||
|
|
func (h *AuthHandler) Register(c *gin.Context) {
|
|||
|
|
var req RegisterRequest
|
|||
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|||
|
|
common.ValidationErrorResponse(c, err)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证请求参数
|
|||
|
|
if err := h.validator.Struct(&req); err != nil {
|
|||
|
|
common.ValidationErrorResponse(c, err)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证邮箱格式
|
|||
|
|
if !utils.IsValidEmail(req.Email) {
|
|||
|
|
common.BadRequestResponse(c, "邮箱格式不正确")
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证密码强度
|
|||
|
|
if !utils.IsStrongPassword(req.Password) {
|
|||
|
|
common.BadRequestResponse(c, "密码强度不足,至少8位且包含大小写字母、数字和特殊字符")
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 创建用户
|
|||
|
|
user, err := h.userService.CreateUser(req.Username, req.Email, req.Password)
|
|||
|
|
if err != nil {
|
|||
|
|
if businessErr, ok := err.(*common.BusinessError); ok {
|
|||
|
|
common.ErrorResponse(c, http.StatusBadRequest, businessErr.Message)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
common.InternalServerErrorResponse(c, "用户创建失败")
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 生成令牌
|
|||
|
|
accessToken, refreshToken, err := middleware.GenerateTokens(user.ID, user.Username, user.Email)
|
|||
|
|
if err != nil {
|
|||
|
|
common.InternalServerErrorResponse(c, "令牌生成失败")
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 更新用户登录信息
|
|||
|
|
h.userService.UpdateLoginInfo(user.ID, utils.GetClientIP(c))
|
|||
|
|
|
|||
|
|
// 构造响应
|
|||
|
|
nickname := ""
|
|||
|
|
if user.Nickname != nil {
|
|||
|
|
nickname = *user.Nickname
|
|||
|
|
}
|
|||
|
|
avatar := ""
|
|||
|
|
if user.Avatar != nil {
|
|||
|
|
avatar = *user.Avatar
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
userInfo := &UserInfo{
|
|||
|
|
ID: user.ID,
|
|||
|
|
Username: user.Username,
|
|||
|
|
Email: user.Email,
|
|||
|
|
Nickname: nickname,
|
|||
|
|
Avatar: avatar,
|
|||
|
|
Level: "beginner",
|
|||
|
|
Status: user.Status,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
response := &AuthResponse{
|
|||
|
|
User: userInfo,
|
|||
|
|
AccessToken: accessToken,
|
|||
|
|
RefreshToken: refreshToken,
|
|||
|
|
ExpiresIn: 7200, // 2小时
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
common.SuccessResponse(c, response)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Login 用户登录
|
|||
|
|
func (h *AuthHandler) Login(c *gin.Context) {
|
|||
|
|
var req LoginRequest
|
|||
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|||
|
|
common.ValidationErrorResponse(c, err)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证请求参数
|
|||
|
|
if err := h.validator.Struct(&req); err != nil {
|
|||
|
|
common.ValidationErrorResponse(c, err)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 根据账号类型获取用户
|
|||
|
|
var user *models.User
|
|||
|
|
var err error
|
|||
|
|
|
|||
|
|
if utils.IsValidEmail(req.Account) {
|
|||
|
|
user, err = h.userService.GetUserByEmail(req.Account)
|
|||
|
|
} else {
|
|||
|
|
user, err = h.userService.GetUserByUsername(req.Account)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err != nil {
|
|||
|
|
if businessErr, ok := err.(*common.BusinessError); ok {
|
|||
|
|
common.ErrorResponse(c, http.StatusBadRequest, businessErr.Message)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
common.InternalServerErrorResponse(c, "用户查询失败")
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证密码
|
|||
|
|
if !utils.CheckPasswordHash(req.Password, user.PasswordHash) {
|
|||
|
|
common.BadRequestResponse(c, "密码错误")
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 检查用户状态
|
|||
|
|
if user.Status != "active" {
|
|||
|
|
common.BadRequestResponse(c, "用户已被禁用")
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 生成令牌
|
|||
|
|
accessToken, refreshToken, err := middleware.GenerateTokens(user.ID, user.Username, user.Email)
|
|||
|
|
if err != nil {
|
|||
|
|
common.InternalServerErrorResponse(c, "令牌生成失败")
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 更新用户登录信息
|
|||
|
|
h.userService.UpdateLoginInfo(user.ID, utils.GetClientIP(c))
|
|||
|
|
|
|||
|
|
// 构造响应
|
|||
|
|
nickname := ""
|
|||
|
|
if user.Nickname != nil {
|
|||
|
|
nickname = *user.Nickname
|
|||
|
|
}
|
|||
|
|
avatar := ""
|
|||
|
|
if user.Avatar != nil {
|
|||
|
|
avatar = *user.Avatar
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
userInfo := &UserInfo{
|
|||
|
|
ID: user.ID,
|
|||
|
|
Username: user.Username,
|
|||
|
|
Email: user.Email,
|
|||
|
|
Nickname: nickname,
|
|||
|
|
Avatar: avatar,
|
|||
|
|
Level: "beginner",
|
|||
|
|
Status: user.Status,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
response := &AuthResponse{
|
|||
|
|
User: userInfo,
|
|||
|
|
AccessToken: accessToken,
|
|||
|
|
RefreshToken: refreshToken,
|
|||
|
|
ExpiresIn: 7200, // 2小时
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
common.SuccessResponse(c, response)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// RefreshToken 刷新访问令牌
|
|||
|
|
func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
|||
|
|
var req RefreshTokenRequest
|
|||
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|||
|
|
common.ValidationErrorResponse(c, err)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证刷新令牌
|
|||
|
|
claims, err := middleware.ParseToken(req.RefreshToken)
|
|||
|
|
if err != nil {
|
|||
|
|
common.BadRequestResponse(c, "无效的刷新令牌")
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 检查令牌类型
|
|||
|
|
if claims.Type != "refresh" {
|
|||
|
|
common.BadRequestResponse(c, "令牌类型错误")
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 生成新的令牌
|
|||
|
|
accessToken, newRefreshToken, err := middleware.GenerateTokens(claims.UserID, claims.Username, claims.Email)
|
|||
|
|
if err != nil {
|
|||
|
|
common.InternalServerErrorResponse(c, "令牌生成失败")
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
response := map[string]interface{}{
|
|||
|
|
"access_token": accessToken,
|
|||
|
|
"refresh_token": newRefreshToken,
|
|||
|
|
"expires_in": 7200,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
common.SuccessResponse(c, response)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Logout 用户登出
|
|||
|
|
func (h *AuthHandler) Logout(c *gin.Context) {
|
|||
|
|
// 这里可以实现令牌黑名单机制
|
|||
|
|
// 目前简单返回成功
|
|||
|
|
common.SuccessResponse(c, map[string]string{"message": "登出成功"})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetProfile 获取用户资料
|
|||
|
|
func (h *AuthHandler) GetProfile(c *gin.Context) {
|
|||
|
|
// 获取当前用户ID
|
|||
|
|
userID, exists := utils.GetUserIDFromContext(c)
|
|||
|
|
if !exists {
|
|||
|
|
common.BadRequestResponse(c, "请先登录")
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
user, err := h.userService.GetUserByID(userID)
|
|||
|
|
if err != nil {
|
|||
|
|
if businessErr, ok := err.(*common.BusinessError); ok {
|
|||
|
|
common.ErrorResponse(c, http.StatusBadRequest, businessErr.Message)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
common.InternalServerErrorResponse(c, "用户查询失败")
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
nickname := ""
|
|||
|
|
if user.Nickname != nil {
|
|||
|
|
nickname = *user.Nickname
|
|||
|
|
}
|
|||
|
|
avatar := ""
|
|||
|
|
if user.Avatar != nil {
|
|||
|
|
avatar = *user.Avatar
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
userInfo := &UserInfo{
|
|||
|
|
ID: user.ID,
|
|||
|
|
Username: user.Username,
|
|||
|
|
Email: user.Email,
|
|||
|
|
Nickname: nickname,
|
|||
|
|
Avatar: avatar,
|
|||
|
|
Level: "beginner",
|
|||
|
|
Status: user.Status,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
common.SuccessResponse(c, userInfo)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ChangePassword 修改密码
|
|||
|
|
func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
|||
|
|
// 获取当前用户ID
|
|||
|
|
userID, exists := utils.GetUserIDFromContext(c)
|
|||
|
|
if !exists {
|
|||
|
|
common.BadRequestResponse(c, "请先登录")
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var req ChangePasswordRequest
|
|||
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|||
|
|
common.ValidationErrorResponse(c, err)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证请求参数
|
|||
|
|
if err := h.validator.Struct(&req); err != nil {
|
|||
|
|
common.ValidationErrorResponse(c, err)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证新密码强度
|
|||
|
|
if !utils.IsStrongPassword(req.NewPassword) {
|
|||
|
|
common.BadRequestResponse(c, "新密码强度不够")
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 更新密码
|
|||
|
|
err := h.userService.UpdatePassword(userID, req.OldPassword, req.NewPassword)
|
|||
|
|
if err != nil {
|
|||
|
|
if businessErr, ok := err.(*common.BusinessError); ok {
|
|||
|
|
common.ErrorResponse(c, http.StatusBadRequest, businessErr.Message)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
common.InternalServerErrorResponse(c, "密码更新失败")
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
common.SuccessResponse(c, map[string]string{"message": "密码修改成功"})
|
|||
|
|
}
|