Initial commit
This commit is contained in:
99
server/internal/middleware/auth.go
Normal file
99
server/internal/middleware/auth.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"dianshang/pkg/jwt"
|
||||
"dianshang/pkg/response"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AuthMiddleware JWT认证中间件
|
||||
func AuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
response.Unauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 移除 "Bearer " 前缀
|
||||
if strings.HasPrefix(token, "Bearer ") {
|
||||
token = token[7:]
|
||||
}
|
||||
|
||||
claims, err := jwt.ParseToken(token)
|
||||
if err != nil {
|
||||
response.Unauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户信息存储到上下文中
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("user_type", claims.UserType)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// AdminAuthMiddleware 管理员认证中间件
|
||||
func AdminAuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
response.Unauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 移除 "Bearer " 前缀
|
||||
if strings.HasPrefix(token, "Bearer ") {
|
||||
token = token[7:]
|
||||
}
|
||||
|
||||
// 开发环境:支持模拟token
|
||||
if strings.HasPrefix(token, "mock_admin_token_") {
|
||||
// 模拟管理员用户信息
|
||||
c.Set("user_id", uint(1))
|
||||
c.Set("user_type", "admin")
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := jwt.ParseToken(token)
|
||||
if err != nil {
|
||||
response.Unauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户类型是否为管理员
|
||||
if claims.UserType != "admin" {
|
||||
response.Forbidden(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户信息存储到上下文中
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("user_type", claims.UserType)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// OptionalAuthMiddleware 可选认证中间件(不强制要求登录)
|
||||
func OptionalAuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
token := c.GetHeader("Authorization")
|
||||
if token != "" {
|
||||
// 移除 "Bearer " 前缀
|
||||
if strings.HasPrefix(token, "Bearer ") {
|
||||
token = token[7:]
|
||||
}
|
||||
|
||||
claims, err := jwt.ParseToken(token)
|
||||
if err == nil {
|
||||
// 将用户信息存储到上下文中
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("user_type", claims.UserType)
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
42
server/internal/middleware/cors.go
Normal file
42
server/internal/middleware/cors.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORSMiddleware 跨域中间件
|
||||
func CORSMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
method := c.Request.Method
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
|
||||
// 设置允许的域名
|
||||
if origin != "" {
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
} else {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// 设置允许的请求头
|
||||
c.Header("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-User-ID")
|
||||
|
||||
// 设置允许的请求方法
|
||||
c.Header("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
|
||||
|
||||
// 设置允许携带凭证
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
// 设置预检请求的缓存时间
|
||||
c.Header("Access-Control-Max-Age", "86400")
|
||||
|
||||
// 处理预检请求
|
||||
if method == "OPTIONS" {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
497
server/internal/middleware/logger.go
Normal file
497
server/internal/middleware/logger.go
Normal file
@@ -0,0 +1,497 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"dianshang/pkg/logger"
|
||||
"dianshang/pkg/utils"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RequestIDMiddleware 请求ID中间件
|
||||
func RequestIDMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 生成请求ID
|
||||
requestID := logger.GenerateRequestID()
|
||||
|
||||
// 设置到上下文
|
||||
ctx := context.WithValue(c.Request.Context(), logger.RequestIDKey, requestID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
// 设置到响应头
|
||||
c.Header("X-Request-ID", requestID)
|
||||
|
||||
// 设置到Gin上下文,方便后续使用
|
||||
c.Set("request_id", requestID)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// UserContextMiddleware 用户上下文中间件
|
||||
func UserContextMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 从JWT或其他认证方式获取用户ID
|
||||
if userID, exists := c.Get("user_id"); exists {
|
||||
ctx := context.WithValue(c.Request.Context(), logger.UserIDKey, userID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// LoggerMiddleware 日志中间件
|
||||
func LoggerMiddleware() gin.HandlerFunc {
|
||||
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
// 记录请求日志
|
||||
logger.LogRequest(
|
||||
param.Method,
|
||||
param.Path,
|
||||
utils.GetClientIP(param.ClientIP, param.Request.Header.Get("X-Forwarded-For"), param.Request.Header.Get("X-Real-IP")),
|
||||
param.Request.UserAgent(),
|
||||
param.StatusCode,
|
||||
param.Latency.Milliseconds(),
|
||||
)
|
||||
return ""
|
||||
})
|
||||
}
|
||||
|
||||
// responseBodyWriter 响应体写入器
|
||||
type responseBodyWriter struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
}
|
||||
|
||||
func (r responseBodyWriter) Write(b []byte) (int, error) {
|
||||
r.body.Write(b)
|
||||
return r.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
// EnhancedLoggerMiddleware 增强的日志中间件
|
||||
func EnhancedLoggerMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
raw := c.Request.URL.RawQuery
|
||||
|
||||
// 读取请求体
|
||||
var requestBody []byte
|
||||
if c.Request.Body != nil {
|
||||
requestBody, _ = io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
}
|
||||
|
||||
// 创建响应体写入器
|
||||
responseBody := &bytes.Buffer{}
|
||||
writer := &responseBodyWriter{
|
||||
ResponseWriter: c.Writer,
|
||||
body: responseBody,
|
||||
}
|
||||
c.Writer = writer
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 计算请求耗时
|
||||
latency := time.Since(start)
|
||||
|
||||
// 获取客户端IP
|
||||
clientIP := utils.GetClientIP(
|
||||
c.ClientIP(),
|
||||
c.GetHeader("X-Forwarded-For"),
|
||||
c.GetHeader("X-Real-IP"),
|
||||
)
|
||||
|
||||
// 构建完整路径
|
||||
if raw != "" {
|
||||
path = path + "?" + raw
|
||||
}
|
||||
|
||||
// 获取用户ID
|
||||
var userID interface{}
|
||||
if uid, exists := c.Get("user_id"); exists {
|
||||
userID = uid
|
||||
}
|
||||
|
||||
// 获取请求ID
|
||||
var requestID interface{}
|
||||
if rid, exists := c.Get("request_id"); exists {
|
||||
requestID = rid
|
||||
}
|
||||
|
||||
// 准备请求数据
|
||||
requestData := map[string]interface{}{
|
||||
"headers": c.Request.Header,
|
||||
"query": c.Request.URL.Query(),
|
||||
}
|
||||
|
||||
// 添加请求体(如果不为空且不是敏感信息)
|
||||
if len(requestBody) > 0 && shouldLogRequestBody(c.Request.Method, path) {
|
||||
requestData["body"] = string(requestBody)
|
||||
}
|
||||
|
||||
// 准备响应数据
|
||||
responseData := map[string]interface{}{
|
||||
"headers": writer.Header(),
|
||||
}
|
||||
|
||||
// 添加响应体(如果不为空且不是敏感信息)
|
||||
if responseBody.Len() > 0 && shouldLogResponseBody(c.Writer.Status(), path) {
|
||||
responseData["body"] = responseBody.String()
|
||||
}
|
||||
|
||||
|
||||
printDetailedConsoleLog(c.Request.Method, path, clientIP, c.Request.UserAgent(),
|
||||
c.Writer.Status(), latency, requestID, userID, requestData, responseData)
|
||||
|
||||
|
||||
|
||||
// 记录请求日志到文件(所有环境)
|
||||
logger.LogRequestWithContext(
|
||||
c.Request.Context(),
|
||||
c.Request.Method,
|
||||
path,
|
||||
clientIP,
|
||||
c.Request.UserAgent(),
|
||||
c.Writer.Status(),
|
||||
latency.Milliseconds(),
|
||||
)
|
||||
|
||||
// 记录性能日志(如果请求耗时较长)
|
||||
if latency > 1*time.Second {
|
||||
logger.LogPerformanceWithContext(
|
||||
c.Request.Context(),
|
||||
"http_request",
|
||||
latency,
|
||||
map[string]interface{}{
|
||||
"method": c.Request.Method,
|
||||
"path": path,
|
||||
"status": c.Writer.Status(),
|
||||
"user_id": userID,
|
||||
"request_id": requestID,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// 如果有错误,记录错误日志
|
||||
if len(c.Errors) > 0 {
|
||||
logger.LogErrorWithContext(
|
||||
c.Request.Context(),
|
||||
c.Errors[0].Err,
|
||||
map[string]interface{}{
|
||||
"method": c.Request.Method,
|
||||
"path": path,
|
||||
"ip": clientIP,
|
||||
"user_id": userID,
|
||||
"request_id": requestID,
|
||||
"errors": c.Errors.String(),
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// printDetailedConsoleLog 打印详细的控制台日志(仅开发环境)
|
||||
func printDetailedConsoleLog(method, path, ip, userAgent string, status int, latency time.Duration,
|
||||
requestID, userID interface{}, requestData, responseData map[string]interface{}) {
|
||||
|
||||
// 获取状态颜色
|
||||
statusColor := getStatusColor(status)
|
||||
methodColor := getMethodColor(method)
|
||||
|
||||
// 打印分隔线
|
||||
fmt.Println("\n" + strings.Repeat("=", 80))
|
||||
|
||||
// 打印基本信息(使用简单的ASCII字符作为分隔符)
|
||||
fmt.Printf("%s%s%s %s%s%s - Status: %s%d%s - Latency: %s%v%s\n",
|
||||
methodColor, method, "\033[0m",
|
||||
"\033[36m", path, "\033[0m",
|
||||
statusColor, status, "\033[0m",
|
||||
"\033[33m", latency, "\033[0m")
|
||||
|
||||
// 打印请求信息
|
||||
fmt.Printf("IP: %s - Request ID: %v - User ID: %v\n", ip, requestID, userID)
|
||||
fmt.Printf("User-Agent: %s\n", userAgent)
|
||||
|
||||
// 打印请求参数
|
||||
if query, ok := requestData["query"].(map[string][]string); ok && len(query) > 0 {
|
||||
fmt.Printf("Query Params:\n")
|
||||
for k, v := range query {
|
||||
fmt.Printf(" %s: %v\n", k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// 打印请求体
|
||||
if body, ok := requestData["body"].(string); ok && body != "" {
|
||||
fmt.Printf("Request Body:\n")
|
||||
if len(body) > 2000 {
|
||||
fmt.Printf(" %s...(truncated)\n", body[:2000])
|
||||
} else {
|
||||
fmt.Printf(" %s\n", body)
|
||||
}
|
||||
}
|
||||
|
||||
// 打印响应体(限制长度,避免刷屏)
|
||||
if body, ok := responseData["body"].(string); ok && body != "" {
|
||||
fmt.Printf("Response Body:\n")
|
||||
if len(body) > 1000 {
|
||||
fmt.Printf(" %s...(truncated, total %d chars)\n", body[:1000], len(body))
|
||||
} else {
|
||||
fmt.Printf(" %s\n", body)
|
||||
}
|
||||
}
|
||||
|
||||
// 打印时间戳
|
||||
fmt.Printf("Time: %s\n", time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
// 打印分隔线
|
||||
fmt.Println(strings.Repeat("=", 80))
|
||||
}
|
||||
|
||||
// getStatusColor 获取状态码颜色
|
||||
func getStatusColor(status int) string {
|
||||
switch {
|
||||
case status >= 200 && status < 300:
|
||||
return "\033[32m" // 绿色
|
||||
case status >= 300 && status < 400:
|
||||
return "\033[33m" // 黄色
|
||||
case status >= 400 && status < 500:
|
||||
return "\033[31m" // 红色
|
||||
case status >= 500:
|
||||
return "\033[35m" // 紫色
|
||||
default:
|
||||
return "\033[37m" // 白色
|
||||
}
|
||||
}
|
||||
|
||||
// getMethodColor 获取方法颜色
|
||||
func getMethodColor(method string) string {
|
||||
switch method {
|
||||
case "GET":
|
||||
return "\033[32m" // 绿色
|
||||
case "POST":
|
||||
return "\033[34m" // 蓝色
|
||||
case "PUT":
|
||||
return "\033[33m" // 黄色
|
||||
case "DELETE":
|
||||
return "\033[31m" // 红色
|
||||
case "PATCH":
|
||||
return "\033[36m" // 青色
|
||||
default:
|
||||
return "\033[37m" // 白色
|
||||
}
|
||||
}
|
||||
|
||||
// shouldLogRequestBody 判断是否应该记录请求体
|
||||
func shouldLogRequestBody(method, path string) bool {
|
||||
// 排除敏感路径
|
||||
sensitivePatterns := []string{
|
||||
"/login", "/register", "/password", "/auth",
|
||||
}
|
||||
|
||||
for _, pattern := range sensitivePatterns {
|
||||
if strings.Contains(path, pattern) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 只记录POST、PUT、PATCH请求的请求体
|
||||
return method == "POST" || method == "PUT" || method == "PATCH"
|
||||
}
|
||||
|
||||
// shouldLogResponseBody 判断是否应该记录响应体
|
||||
func shouldLogResponseBody(status int, path string) bool {
|
||||
// 排除大文件下载等路径
|
||||
excludePatterns := []string{
|
||||
"/download", "/file", "/image", "/video",
|
||||
}
|
||||
|
||||
for _, pattern := range excludePatterns {
|
||||
if strings.Contains(path, pattern) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 只记录成功和客户端错误的响应体
|
||||
return (status >= 200 && status < 300) || (status >= 400 && status < 500)
|
||||
}
|
||||
|
||||
// CustomLoggerMiddleware 自定义日志中间件(保持向后兼容)
|
||||
func CustomLoggerMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
raw := c.Request.URL.RawQuery
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 计算请求耗时
|
||||
latency := time.Since(start)
|
||||
|
||||
// 获取客户端IP
|
||||
clientIP := utils.GetClientIP(
|
||||
c.ClientIP(),
|
||||
c.GetHeader("X-Forwarded-For"),
|
||||
c.GetHeader("X-Real-IP"),
|
||||
)
|
||||
|
||||
// 构建完整路径
|
||||
if raw != "" {
|
||||
path = path + "?" + raw
|
||||
}
|
||||
|
||||
// 记录日志
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"status": c.Writer.Status(),
|
||||
"method": c.Request.Method,
|
||||
"path": path,
|
||||
"ip": clientIP,
|
||||
"user_agent": c.Request.UserAgent(),
|
||||
"latency": latency.String(),
|
||||
"size": c.Writer.Size(),
|
||||
}).Info("HTTP Request")
|
||||
|
||||
// 如果有错误,记录错误日志
|
||||
if len(c.Errors) > 0 {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"method": c.Request.Method,
|
||||
"path": path,
|
||||
"ip": clientIP,
|
||||
"errors": c.Errors.String(),
|
||||
}).Error("Request Errors")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OperationLogMiddleware 操作日志中间件
|
||||
func OperationLogMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
|
||||
// 记录请求数据(仅对特定操作)
|
||||
var requestData interface{}
|
||||
if shouldLogRequestData(c.Request.Method, c.Request.URL.Path) {
|
||||
// 这里可以根据需要记录请求体,但要注意敏感信息
|
||||
requestData = map[string]interface{}{
|
||||
"query_params": c.Request.URL.RawQuery,
|
||||
"content_type": c.Request.Header.Get("Content-Type"),
|
||||
}
|
||||
}
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 计算请求耗时
|
||||
duration := time.Since(start)
|
||||
|
||||
// 获取用户信息
|
||||
var userID uint
|
||||
var userType string = "anonymous"
|
||||
|
||||
if uid, exists := c.Get("user_id"); exists {
|
||||
if id, ok := uid.(uint); ok {
|
||||
userID = id
|
||||
userType = "user"
|
||||
} else if idStr, ok := uid.(string); ok {
|
||||
if id, err := strconv.ParseUint(idStr, 10, 32); err == nil {
|
||||
userID = uint(id)
|
||||
userType = "user"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取客户端IP
|
||||
clientIP := utils.GetClientIP(
|
||||
c.ClientIP(),
|
||||
c.GetHeader("X-Forwarded-For"),
|
||||
c.GetHeader("X-Real-IP"),
|
||||
)
|
||||
|
||||
// 记录操作日志
|
||||
if shouldLogOperation(c.Request.Method, c.Request.URL.Path) {
|
||||
logger.LogOperationWithContext(
|
||||
c.Request.Context(),
|
||||
userID,
|
||||
userType,
|
||||
getOperationType(c.Request.Method),
|
||||
getResourceName(c.Request.URL.Path),
|
||||
c.Request.Method,
|
||||
c.Request.URL.Path,
|
||||
clientIP,
|
||||
requestData,
|
||||
nil, // 响应数据可以根据需要添加
|
||||
c.Writer.Status(),
|
||||
duration,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// shouldLogRequestData 判断是否需要记录请求数据
|
||||
func shouldLogRequestData(method, path string) bool {
|
||||
// 对于POST、PUT、PATCH请求记录请求数据
|
||||
return method == "POST" || method == "PUT" || method == "PATCH"
|
||||
}
|
||||
|
||||
// shouldLogOperation 判断是否需要记录操作日志
|
||||
func shouldLogOperation(method, path string) bool {
|
||||
// 排除健康检查等不重要的请求
|
||||
if path == "/health" || path == "/ping" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 对于修改操作记录日志
|
||||
return method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE"
|
||||
}
|
||||
|
||||
// getOperationType 获取操作类型
|
||||
func getOperationType(method string) string {
|
||||
switch method {
|
||||
case "POST":
|
||||
return "create"
|
||||
case "PUT", "PATCH":
|
||||
return "update"
|
||||
case "DELETE":
|
||||
return "delete"
|
||||
case "GET":
|
||||
return "read"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// getResourceName 从路径获取资源名称
|
||||
func getResourceName(path string) string {
|
||||
// 简单的路径解析,可以根据实际需要完善
|
||||
if len(path) > 1 {
|
||||
parts := strings.Split(path[1:], "/")
|
||||
if len(parts) > 0 {
|
||||
return parts[len(parts)-1]
|
||||
}
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// getEnvironment 获取当前环境
|
||||
func getEnvironment() string {
|
||||
env := os.Getenv("GO_ENV")
|
||||
if env == "" {
|
||||
env = os.Getenv("APP_ENV")
|
||||
}
|
||||
if env == "" {
|
||||
env = os.Getenv("ENVIRONMENT")
|
||||
}
|
||||
if env == "" {
|
||||
env = "development" // 默认环境
|
||||
}
|
||||
return strings.ToLower(env)
|
||||
}
|
||||
Reference in New Issue
Block a user