init
This commit is contained in:
183
serve/internal/middleware/auth.go
Normal file
183
serve/internal/middleware/auth.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/config"
|
||||
"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/common"
|
||||
"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// JWTClaims JWT声明结构
|
||||
type JWTClaims struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
Type string `json:"type"` // access, refresh
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// GenerateTokens 生成访问令牌和刷新令牌
|
||||
func GenerateTokens(userID int64, username, email string) (accessToken, refreshToken string, err error) {
|
||||
cfg := config.GlobalConfig
|
||||
now := time.Now()
|
||||
|
||||
// 生成访问令牌
|
||||
accessClaims := JWTClaims{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
Email: email,
|
||||
Type: "access",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(cfg.JWT.AccessTokenTTL) * time.Second)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Issuer: cfg.App.Name,
|
||||
Subject: utils.Int64ToString(userID),
|
||||
ID: utils.GenerateUUID(),
|
||||
},
|
||||
}
|
||||
|
||||
accessTokenObj := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
|
||||
accessToken, err = accessTokenObj.SignedString([]byte(cfg.JWT.Secret))
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// 生成刷新令牌
|
||||
refreshClaims := JWTClaims{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
Email: email,
|
||||
Type: "refresh",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(cfg.JWT.RefreshTokenTTL) * time.Second)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Issuer: cfg.App.Name,
|
||||
Subject: utils.Int64ToString(userID),
|
||||
ID: utils.GenerateUUID(),
|
||||
},
|
||||
}
|
||||
|
||||
refreshTokenObj := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims)
|
||||
refreshToken, err = refreshTokenObj.SignedString([]byte(cfg.JWT.Secret))
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return accessToken, refreshToken, nil
|
||||
}
|
||||
|
||||
// ParseToken 解析JWT令牌
|
||||
func ParseToken(tokenString string) (*JWTClaims, error) {
|
||||
cfg := config.GlobalConfig
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(cfg.JWT.Secret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, jwt.ErrInvalidKey
|
||||
}
|
||||
|
||||
// AuthMiddleware JWT认证中间件
|
||||
func AuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 从请求头获取token
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
common.UnauthorizedResponse(c, "缺少认证令牌")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查Bearer前缀
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
common.UnauthorizedResponse(c, "认证令牌格式错误")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
|
||||
// 解析token
|
||||
claims, err := ParseToken(tokenString)
|
||||
if err != nil {
|
||||
if err == jwt.ErrTokenExpired {
|
||||
common.ErrorResponse(c, http.StatusUnauthorized, "访问令牌已过期")
|
||||
} else {
|
||||
common.UnauthorizedResponse(c, "无效的访问令牌")
|
||||
}
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查token类型
|
||||
if claims.Type != "access" {
|
||||
common.UnauthorizedResponse(c, "令牌类型错误")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户信息存储到上下文中
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("email", claims.Email)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// OptionalAuthMiddleware 可选认证中间件(不强制要求登录)
|
||||
func OptionalAuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 从请求头获取token
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查Bearer前缀
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
|
||||
// 解析token
|
||||
claims, err := ParseToken(tokenString)
|
||||
if err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查token类型
|
||||
if claims.Type != "access" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户信息存储到上下文中
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("email", claims.Email)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user