Files
2025-11-17 14:09:17 +08:00

183 lines
4.4 KiB
Go

package middleware
import (
	"errors"
	"net/http"
	"strings"
	"time"

	"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/config"
	"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/common"
	"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/utils"
	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
)
// JWTClaims is the claim set carried by the tokens this package issues and
// parses. It adds the user's identity and a token-type marker on top of the
// standard registered claims.
type JWTClaims struct {
// UserID is the numeric ID of the user the token was issued for.
UserID int64 `json:"user_id"`
// Username is the user name supplied when the token was generated.
Username string `json:"username"`
// Email is the e-mail address supplied when the token was generated.
Email string `json:"email"`
// Type tells the two issued token kinds apart; AuthMiddleware and
// OptionalAuthMiddleware only accept "access".
Type string `json:"type"` // "access" or "refresh"
// RegisteredClaims embeds the standard JWT fields (expiry, issued-at,
// not-before, issuer, subject, ID).
jwt.RegisteredClaims
}
// GenerateTokens issues a signed access/refresh token pair for one user.
//
// Both tokens are HS256-signed with config.GlobalConfig.JWT.Secret and carry
// the same user claims, issuer (cfg.App.Name), subject (the user ID as a
// string) and issued-at / not-before instant. They differ only in Type
// ("access" vs. "refresh"), in lifetime (JWT.AccessTokenTTL vs.
// JWT.RefreshTokenTTL, both interpreted as seconds) and in their generated ID.
//
// If signing either token fails, both returned strings are empty and the
// signing error is returned.
func GenerateTokens(userID int64, username, email string) (accessToken, refreshToken string, err error) {
	cfg := config.GlobalConfig
	issuedAt := time.Now()

	// mint builds the claims for one token of the given kind and signs them.
	mint := func(kind string, lifetime time.Duration) (string, error) {
		claims := JWTClaims{
			UserID:   userID,
			Username: username,
			Email:    email,
			Type:     kind,
			RegisteredClaims: jwt.RegisteredClaims{
				ExpiresAt: jwt.NewNumericDate(issuedAt.Add(lifetime)),
				IssuedAt:  jwt.NewNumericDate(issuedAt),
				NotBefore: jwt.NewNumericDate(issuedAt),
				Issuer:    cfg.App.Name,
				Subject:   utils.Int64ToString(userID),
				ID:        utils.GenerateUUID(),
			},
		}
		return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(cfg.JWT.Secret))
	}

	// The access token comes first; the refresh token is only built once the
	// access token has been signed successfully.
	accessToken, err = mint("access", time.Duration(cfg.JWT.AccessTokenTTL)*time.Second)
	if err != nil {
		return "", "", err
	}

	refreshToken, err = mint("refresh", time.Duration(cfg.JWT.RefreshTokenTTL)*time.Second)
	if err != nil {
		return "", "", err
	}

	return accessToken, refreshToken, nil
}
// ParseToken verifies tokenString and returns the custom claims it carries.
//
// The signature is checked against config.GlobalConfig.JWT.Secret. Only
// HS256, the algorithm GenerateTokens signs with, is accepted, so a token
// whose header announces any other algorithm is rejected by the parser
// rather than being handed the HMAC secret.
//
// Errors from the jwt parser (malformed token, bad signature, expired, ...)
// are returned unchanged; callers that need to tell them apart should use
// errors.Is, since the library may return them wrapped.
func ParseToken(tokenString string) (*JWTClaims, error) {
	cfg := config.GlobalConfig

	token, err := jwt.ParseWithClaims(
		tokenString,
		&JWTClaims{},
		func(*jwt.Token) (interface{}, error) {
			return []byte(cfg.JWT.Secret), nil
		},
		// Pin the accepted algorithm instead of trusting the token header.
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
	)
	if err != nil {
		return nil, err
	}

	claims, ok := token.Claims.(*JWTClaims)
	if !ok || !token.Valid {
		// Defensive fallback: not expected once parsing returned a nil error.
		// The sentinel is kept as-is so existing callers see the same value.
		return nil, jwt.ErrInvalidKey
	}
	return claims, nil
}
// AuthMiddleware returns a gin handler that requires a valid access token.
//
// The token is read from the Authorization header, which must have the form
// "Bearer <token>". The request is aborted with an unauthorized response when
// the header is missing or malformed, when the token fails ParseToken, or
// when the token's Type is not "access". An expired token gets its own
// message so clients can tell "expired" from "invalid".
//
// On success the handler stores "user_id" (int64), "username" and "email"
// from the claims in the gin context and continues the chain.
func AuthMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		// Read the token from the request header.
		authHeader := c.GetHeader("Authorization")
		if authHeader == "" {
			common.UnauthorizedResponse(c, "缺少认证令牌")
			c.Abort()
			return
		}

		// Require the "Bearer <token>" form.
		parts := strings.SplitN(authHeader, " ", 2)
		if len(parts) != 2 || parts[0] != "Bearer" {
			common.UnauthorizedResponse(c, "认证令牌格式错误")
			c.Abort()
			return
		}

		claims, err := ParseToken(parts[1])
		if err != nil {
			// jwt/v5 returns validation failures wrapped, so a plain ==
			// comparison never matches; errors.Is looks through the wrapping.
			if errors.Is(err, jwt.ErrTokenExpired) {
				common.ErrorResponse(c, http.StatusUnauthorized, "访问令牌已过期")
			} else {
				common.UnauthorizedResponse(c, "无效的访问令牌")
			}
			c.Abort()
			return
		}

		// Refresh tokens must not be usable as credentials for API calls.
		if claims.Type != "access" {
			common.UnauthorizedResponse(c, "令牌类型错误")
			c.Abort()
			return
		}

		// Expose the authenticated user to downstream handlers.
		c.Set("user_id", claims.UserID)
		c.Set("username", claims.Username)
		c.Set("email", claims.Email)
		c.Next()
	}
}
// OptionalAuthMiddleware returns a gin handler that authenticates a request
// when it can but never rejects it (login is not enforced).
//
// If the Authorization header has the form "Bearer <token>", the token passes
// ParseToken and its Type is "access", then "user_id", "username" and "email"
// from the claims are stored in the gin context. In every other case (no
// header, wrong format, invalid or expired token, non-access token) nothing
// is stored. The chain always continues via c.Next().
func OptionalAuthMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		// Split at the first space: scheme before it, raw token after it.
		scheme, rawToken, hasSpace := strings.Cut(c.GetHeader("Authorization"), " ")

		if hasSpace && scheme == "Bearer" {
			// Any parse failure or non-access token is silently ignored.
			if claims, err := ParseToken(rawToken); err == nil && claims.Type == "access" {
				c.Set("user_id", claims.UserID)
				c.Set("username", claims.Username)
				c.Set("email", claims.Email)
			}
		}

		c.Next()
	}
}