183 lines
4.4 KiB
Go
183 lines
4.4 KiB
Go
|
|
package middleware
|
||
|
|
|
||
|
|
import (
	"errors"
	"net/http"
	"strings"
	"time"

	"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/config"
	"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/common"
	"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/utils"
	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
)
|
||
|
|
|
||
|
|
// JWTClaims JWT声明结构
|
||
|
|
type JWTClaims struct {
|
||
|
|
UserID int64 `json:"user_id"`
|
||
|
|
Username string `json:"username"`
|
||
|
|
Email string `json:"email"`
|
||
|
|
Type string `json:"type"` // access, refresh
|
||
|
|
jwt.RegisteredClaims
|
||
|
|
}
|
||
|
|
|
||
|
|
// GenerateTokens 生成访问令牌和刷新令牌
|
||
|
|
func GenerateTokens(userID int64, username, email string) (accessToken, refreshToken string, err error) {
|
||
|
|
cfg := config.GlobalConfig
|
||
|
|
now := time.Now()
|
||
|
|
|
||
|
|
// 生成访问令牌
|
||
|
|
accessClaims := JWTClaims{
|
||
|
|
UserID: userID,
|
||
|
|
Username: username,
|
||
|
|
Email: email,
|
||
|
|
Type: "access",
|
||
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
||
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(cfg.JWT.AccessTokenTTL) * time.Second)),
|
||
|
|
IssuedAt: jwt.NewNumericDate(now),
|
||
|
|
NotBefore: jwt.NewNumericDate(now),
|
||
|
|
Issuer: cfg.App.Name,
|
||
|
|
Subject: utils.Int64ToString(userID),
|
||
|
|
ID: utils.GenerateUUID(),
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
accessTokenObj := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
|
||
|
|
accessToken, err = accessTokenObj.SignedString([]byte(cfg.JWT.Secret))
|
||
|
|
if err != nil {
|
||
|
|
return "", "", err
|
||
|
|
}
|
||
|
|
|
||
|
|
// 生成刷新令牌
|
||
|
|
refreshClaims := JWTClaims{
|
||
|
|
UserID: userID,
|
||
|
|
Username: username,
|
||
|
|
Email: email,
|
||
|
|
Type: "refresh",
|
||
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
||
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(cfg.JWT.RefreshTokenTTL) * time.Second)),
|
||
|
|
IssuedAt: jwt.NewNumericDate(now),
|
||
|
|
NotBefore: jwt.NewNumericDate(now),
|
||
|
|
Issuer: cfg.App.Name,
|
||
|
|
Subject: utils.Int64ToString(userID),
|
||
|
|
ID: utils.GenerateUUID(),
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
refreshTokenObj := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims)
|
||
|
|
refreshToken, err = refreshTokenObj.SignedString([]byte(cfg.JWT.Secret))
|
||
|
|
if err != nil {
|
||
|
|
return "", "", err
|
||
|
|
}
|
||
|
|
|
||
|
|
return accessToken, refreshToken, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// ParseToken 解析JWT令牌
|
||
|
|
func ParseToken(tokenString string) (*JWTClaims, error) {
|
||
|
|
cfg := config.GlobalConfig
|
||
|
|
|
||
|
|
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||
|
|
return []byte(cfg.JWT.Secret), nil
|
||
|
|
})
|
||
|
|
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid {
|
||
|
|
return claims, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
return nil, jwt.ErrInvalidKey
|
||
|
|
}
|
||
|
|
|
||
|
|
// AuthMiddleware JWT认证中间件
|
||
|
|
func AuthMiddleware() gin.HandlerFunc {
|
||
|
|
return func(c *gin.Context) {
|
||
|
|
// 从请求头获取token
|
||
|
|
authHeader := c.GetHeader("Authorization")
|
||
|
|
if authHeader == "" {
|
||
|
|
common.UnauthorizedResponse(c, "缺少认证令牌")
|
||
|
|
c.Abort()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// 检查Bearer前缀
|
||
|
|
parts := strings.SplitN(authHeader, " ", 2)
|
||
|
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
||
|
|
common.UnauthorizedResponse(c, "认证令牌格式错误")
|
||
|
|
c.Abort()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
tokenString := parts[1]
|
||
|
|
|
||
|
|
// 解析token
|
||
|
|
claims, err := ParseToken(tokenString)
|
||
|
|
if err != nil {
|
||
|
|
if err == jwt.ErrTokenExpired {
|
||
|
|
common.ErrorResponse(c, http.StatusUnauthorized, "访问令牌已过期")
|
||
|
|
} else {
|
||
|
|
common.UnauthorizedResponse(c, "无效的访问令牌")
|
||
|
|
}
|
||
|
|
c.Abort()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// 检查token类型
|
||
|
|
if claims.Type != "access" {
|
||
|
|
common.UnauthorizedResponse(c, "令牌类型错误")
|
||
|
|
c.Abort()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// 将用户信息存储到上下文中
|
||
|
|
c.Set("user_id", claims.UserID)
|
||
|
|
c.Set("username", claims.Username)
|
||
|
|
c.Set("email", claims.Email)
|
||
|
|
|
||
|
|
c.Next()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// OptionalAuthMiddleware 可选认证中间件(不强制要求登录)
|
||
|
|
func OptionalAuthMiddleware() gin.HandlerFunc {
|
||
|
|
return func(c *gin.Context) {
|
||
|
|
// 从请求头获取token
|
||
|
|
authHeader := c.GetHeader("Authorization")
|
||
|
|
if authHeader == "" {
|
||
|
|
c.Next()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// 检查Bearer前缀
|
||
|
|
parts := strings.SplitN(authHeader, " ", 2)
|
||
|
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
||
|
|
c.Next()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
tokenString := parts[1]
|
||
|
|
|
||
|
|
// 解析token
|
||
|
|
claims, err := ParseToken(tokenString)
|
||
|
|
if err != nil {
|
||
|
|
c.Next()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// 检查token类型
|
||
|
|
if claims.Type != "access" {
|
||
|
|
c.Next()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// 将用户信息存储到上下文中
|
||
|
|
c.Set("user_id", claims.UserID)
|
||
|
|
c.Set("username", claims.Username)
|
||
|
|
c.Set("email", claims.Email)
|
||
|
|
|
||
|
|
c.Next()
|
||
|
|
}
|
||
|
|
}
|