Files
2025-11-17 13:32:54 +08:00

497 lines
12 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"bytes"
"context"
"dianshang/pkg/logger"
"dianshang/pkg/utils"
"fmt"
"io"
"os"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
// RequestIDMiddleware returns a Gin middleware that tags every request with a
// freshly generated request ID from logger.GenerateRequestID. The ID is made
// available in three places: the request's context.Context (under
// logger.RequestIDKey), the "X-Request-ID" response header, and the Gin
// context key "request_id".
func RequestIDMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		id := logger.GenerateRequestID()

		req := c.Request
		c.Request = req.WithContext(context.WithValue(req.Context(), logger.RequestIDKey, id))

		c.Header("X-Request-ID", id)
		c.Set("request_id", id)

		c.Next()
	}
}
// UserContextMiddleware returns a Gin middleware that copies the "user_id"
// value from the Gin context into the request's context.Context under
// logger.UserIDKey. It only has an effect when an earlier handler (presumably
// the authentication middleware — confirm registration order) has already set
// "user_id"; otherwise the request passes through unchanged.
func UserContextMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		uid, found := c.Get("user_id")
		if found {
			reqCtx := context.WithValue(c.Request.Context(), logger.UserIDKey, uid)
			c.Request = c.Request.WithContext(reqCtx)
		}
		c.Next()
	}
}
// LoggerMiddleware returns Gin's formatter-based logger wired to
// logger.LogRequest. For each request the formatter forwards the method, path,
// resolved client IP (considering X-Forwarded-For / X-Real-IP), user agent,
// status code and latency in milliseconds to logger.LogRequest, then returns
// an empty string so the formatter itself contributes no text.
func LoggerMiddleware() gin.HandlerFunc {
	formatter := func(p gin.LogFormatterParams) string {
		hdr := p.Request.Header
		ip := utils.GetClientIP(p.ClientIP, hdr.Get("X-Forwarded-For"), hdr.Get("X-Real-IP"))
		logger.LogRequest(p.Method, p.Path, ip, p.Request.UserAgent(), p.StatusCode, p.Latency.Milliseconds())
		return ""
	}
	return gin.LoggerWithFormatter(formatter)
}
// responseBodyWriter wraps a gin.ResponseWriter and keeps a copy of every byte
// successfully written to the client in body, so the response can be inspected
// after the handler chain has finished.
type responseBodyWriter struct {
	gin.ResponseWriter
	body *bytes.Buffer
}

// Write forwards b to the underlying writer and records only the bytes that
// were actually written, so the captured copy never contains data the client
// did not receive.
func (r responseBodyWriter) Write(b []byte) (int, error) {
	n, err := r.ResponseWriter.Write(b)
	if n > 0 {
		r.body.Write(b[:n])
	}
	return n, err
}

// WriteString overrides the WriteString promoted from the embedded
// gin.ResponseWriter. Without it, string writes (io.WriteString, direct
// c.Writer.WriteString calls) would go straight to the client and never be
// captured in body.
func (r responseBodyWriter) WriteString(s string) (int, error) {
	n, err := r.ResponseWriter.WriteString(s)
	if n > 0 {
		r.body.WriteString(s[:n])
	}
	return n, err
}
// EnhancedLoggerMiddleware returns a Gin middleware that logs every request
// once the handler chain has finished.
//
// In every environment it:
//   - writes a request record via logger.LogRequestWithContext;
//   - writes a performance record via logger.LogPerformanceWithContext when the
//     request took longer than one second;
//   - writes an error record via logger.LogErrorWithContext when handlers
//     attached errors to the Gin context (the first error is passed as the
//     error value, all of them as the "errors" field).
//
// Only when getEnvironment reports "development" (or "dev") does it also
// buffer the request/response bodies and print a detailed, colored summary to
// stdout via printDetailedConsoleLog, which is a development-only helper. In
// any other environment no bodies are copied or printed, so request payloads
// never reach the console in production and no per-request body buffers are
// allocated. The environment is read once, when the middleware is constructed.
func EnhancedLoggerMiddleware() gin.HandlerFunc {
	env := getEnvironment()
	consoleEnabled := env == "development" || env == "dev"

	return func(c *gin.Context) {
		start := time.Now()

		// Path including the query string, captured before handlers run.
		path := c.Request.URL.Path
		if raw := c.Request.URL.RawQuery; raw != "" {
			path = path + "?" + raw
		}

		// Buffer the request body only when it is going to be printed; the
		// handler then reads an equivalent reader over the same bytes.
		var requestBody []byte
		if consoleEnabled && c.Request.Body != nil && shouldLogRequestBody(c.Request.Method, path) {
			body, err := io.ReadAll(c.Request.Body)
			if err != nil {
				// Hand the handler what was read followed by the remainder of
				// the original stream, so it observes the read failure itself
				// instead of a silently truncated body. The partial body is not
				// logged.
				c.Request.Body = io.NopCloser(io.MultiReader(bytes.NewReader(body), c.Request.Body))
			} else {
				requestBody = body
				c.Request.Body = io.NopCloser(bytes.NewReader(body))
			}
		}

		// Tee the response body into a buffer for the console summary.
		var responseBody *bytes.Buffer
		var writer *responseBodyWriter
		if consoleEnabled {
			responseBody = &bytes.Buffer{}
			writer = &responseBodyWriter{ResponseWriter: c.Writer, body: responseBody}
			c.Writer = writer
		}

		c.Next()

		latency := time.Since(start)
		method := c.Request.Method
		status := c.Writer.Status()
		userAgent := c.Request.UserAgent()
		clientIP := utils.GetClientIP(
			c.ClientIP(),
			c.GetHeader("X-Forwarded-For"),
			c.GetHeader("X-Real-IP"),
		)

		// A missing key yields a nil value, which is what the log records
		// below receive in that case.
		userID, _ := c.Get("user_id")
		requestID, _ := c.Get("request_id")

		if consoleEnabled {
			requestData := map[string]interface{}{
				"headers": c.Request.Header,
				// Converted from url.Values to its underlying map type: a
				// url.Values value does not satisfy a map[string][]string type
				// assertion, which printDetailedConsoleLog uses to find the
				// query parameters.
				"query": map[string][]string(c.Request.URL.Query()),
			}
			if len(requestBody) > 0 {
				requestData["body"] = string(requestBody)
			}

			responseData := map[string]interface{}{
				"headers": writer.Header(),
			}
			if responseBody.Len() > 0 && shouldLogResponseBody(status, path) {
				responseData["body"] = responseBody.String()
			}

			printDetailedConsoleLog(method, path, clientIP, userAgent,
				status, latency, requestID, userID, requestData, responseData)
		}

		// Request log to file (all environments).
		logger.LogRequestWithContext(
			c.Request.Context(),
			method,
			path,
			clientIP,
			userAgent,
			status,
			latency.Milliseconds(),
		)

		// Performance log for slow requests.
		if latency > time.Second {
			logger.LogPerformanceWithContext(
				c.Request.Context(),
				"http_request",
				latency,
				map[string]interface{}{
					"method":     method,
					"path":       path,
					"status":     status,
					"user_id":    userID,
					"request_id": requestID,
				},
			)
		}

		// Error log when handlers reported errors.
		if len(c.Errors) > 0 {
			logger.LogErrorWithContext(
				c.Request.Context(),
				c.Errors[0].Err,
				map[string]interface{}{
					"method":     method,
					"path":       path,
					"ip":         clientIP,
					"user_id":    userID,
					"request_id": requestID,
					"errors":     c.Errors.String(),
				},
			)
		}
	}
}
// printDetailedConsoleLog prints a multi-line, ANSI-colored summary of one
// request to stdout. It is meant for development use; the caller decides when
// to invoke it.
//
// requestData is read for the keys "query" (must be map[string][]string) and
// "body" (string); responseData is read for "body" (string). Other keys, and
// values of other types, are ignored. Request bodies longer than 2000 bytes and
// response bodies longer than 1000 bytes are truncated. The cut is moved back
// to a UTF-8 character boundary so multi-byte text (e.g. Chinese) is never
// split mid-character.
func printDetailedConsoleLog(method, path, ip, userAgent string, status int, latency time.Duration,
	requestID, userID interface{}, requestData, responseData map[string]interface{}) {
	const reset = "\033[0m"

	// truncate returns the longest prefix of s that is at most limit bytes
	// long and ends on a rune boundary. Ranging over a string yields the byte
	// offset of each rune start.
	truncate := func(s string, limit int) string {
		cut := 0
		for i := range s {
			if i > limit {
				break
			}
			cut = i
		}
		if len(s) <= limit {
			return s
		}
		return s[:cut]
	}

	statusColor := getStatusColor(status)
	methodColor := getMethodColor(method)

	// Separator line.
	fmt.Println("\n" + strings.Repeat("=", 80))

	// Summary line: method, path, status, latency.
	fmt.Printf("%s%s%s %s%s%s - Status: %s%d%s - Latency: %s%v%s\n",
		methodColor, method, reset,
		"\033[36m", path, reset,
		statusColor, status, reset,
		"\033[33m", latency, reset)

	// Client information.
	fmt.Printf("IP: %s - Request ID: %v - User ID: %v\n", ip, requestID, userID)
	fmt.Printf("User-Agent: %s\n", userAgent)

	// Query parameters (map iteration order is unspecified).
	if query, ok := requestData["query"].(map[string][]string); ok && len(query) > 0 {
		fmt.Printf("Query Params:\n")
		for k, v := range query {
			fmt.Printf(" %s: %v\n", k, v)
		}
	}

	// Request body.
	if body, ok := requestData["body"].(string); ok && body != "" {
		fmt.Printf("Request Body:\n")
		if len(body) > 2000 {
			fmt.Printf(" %s...(truncated)\n", truncate(body, 2000))
		} else {
			fmt.Printf(" %s\n", body)
		}
	}

	// Response body, kept short to avoid flooding the console. The total is
	// len(body), i.e. a byte count.
	if body, ok := responseData["body"].(string); ok && body != "" {
		fmt.Printf("Response Body:\n")
		if len(body) > 1000 {
			fmt.Printf(" %s...(truncated, total %d bytes)\n", truncate(body, 1000), len(body))
		} else {
			fmt.Printf(" %s\n", body)
		}
	}

	// Timestamp of printing, then the closing separator.
	fmt.Printf("Time: %s\n", time.Now().Format("2006-01-02 15:04:05"))
	fmt.Println(strings.Repeat("=", 80))
}
// getStatusColor maps an HTTP status code to an ANSI foreground-color escape:
// 2xx green, 3xx yellow, 4xx red, 500 and above magenta, and anything below
// 200 white.
func getStatusColor(status int) string {
	if status >= 500 {
		return "\033[35m" // magenta
	}
	if status >= 400 {
		return "\033[31m" // red
	}
	if status >= 300 {
		return "\033[33m" // yellow
	}
	if status >= 200 {
		return "\033[32m" // green
	}
	return "\033[37m" // white
}
// getMethodColor maps an HTTP method to an ANSI foreground-color escape:
// GET green, POST blue, PUT yellow, DELETE red, PATCH cyan. Any other method
// (matching is exact and case-sensitive) gets white.
func getMethodColor(method string) string {
	color := "\033[37m" // white: fallback for unlisted methods
	switch method {
	case "GET":
		color = "\033[32m" // green
	case "POST":
		color = "\033[34m" // blue
	case "PUT":
		color = "\033[33m" // yellow
	case "DELETE":
		color = "\033[31m" // red
	case "PATCH":
		color = "\033[36m" // cyan
	}
	return color
}
// shouldLogRequestBody reports whether the request body for the given method
// and path may be logged.
//
// Bodies are only considered for POST, PUT and PATCH. They are never logged
// for paths that look credential-related: the path, lower-cased, contains
// "/login", "/register", "/auth", or "password" anywhere. Matching is
// case-insensitive, and "password" has no leading slash. Together these ensure
// routes such as "/api/Login" or "/user/changePassword" do not leak
// credentials into logs.
func shouldLogRequestBody(method, path string) bool {
	if method != "POST" && method != "PUT" && method != "PATCH" {
		return false
	}

	lower := strings.ToLower(path)
	for _, pattern := range []string{"/login", "/register", "password", "/auth"} {
		if strings.Contains(lower, pattern) {
			return false
		}
	}
	return true
}
// shouldLogResponseBody reports whether a response body should be logged. Only
// 2xx and 4xx responses qualify. Paths containing "/download", "/file",
// "/image" or "/video" are skipped, since they are expected to carry large or
// binary payloads.
func shouldLogResponseBody(status int, path string) bool {
	successOrClientError := (status >= 200 && status < 300) || (status >= 400 && status < 500)
	if !successOrClientError {
		return false
	}

	for _, skip := range []string{"/download", "/file", "/image", "/video"} {
		if strings.Contains(path, skip) {
			return false
		}
	}
	return true
}
// CustomLoggerMiddleware returns a Gin middleware, kept for backward
// compatibility, that runs after the handler chain. It emits one structured
// "HTTP Request" info record with the status, method, path (including the
// query string), client IP, user agent, latency and response size. When
// handlers attached errors to the Gin context, it also emits a
// "Request Errors" error record.
func CustomLoggerMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		began := time.Now()

		// Path and query are captured before handlers run.
		fullPath := c.Request.URL.Path
		if q := c.Request.URL.RawQuery; q != "" {
			fullPath = fullPath + "?" + q
		}

		c.Next()

		elapsed := time.Since(began)
		ip := utils.GetClientIP(c.ClientIP(), c.GetHeader("X-Forwarded-For"), c.GetHeader("X-Real-IP"))

		logger.WithFields(map[string]interface{}{
			"status":     c.Writer.Status(),
			"method":     c.Request.Method,
			"path":       fullPath,
			"ip":         ip,
			"user_agent": c.Request.UserAgent(),
			"latency":    elapsed.String(),
			"size":       c.Writer.Size(),
		}).Info("HTTP Request")

		if len(c.Errors) == 0 {
			return
		}

		logger.WithFields(map[string]interface{}{
			"method": c.Request.Method,
			"path":   fullPath,
			"ip":     ip,
			"errors": c.Errors.String(),
		}).Error("Request Errors")
	}
}
// OperationLogMiddleware returns a Gin middleware that records an audit
// ("operation") log entry for state-changing requests.
//
// Before the handler chain runs, it captures request metadata (raw query
// string and Content-Type, never the body) when shouldLogRequestData approves
// the method. Afterwards it resolves the acting user from the Gin "user_id"
// key. When shouldLogOperation approves the method/path, it then calls
// logger.LogOperationWithContext with the user, operation type, resource name,
// method, path, client IP, captured metadata, response status and duration.
// Response data is not recorded (nil is passed).
//
// The "user_id" value is accepted as any built-in integer type, a whole
// non-negative float64 below 2^32 (the form numeric JSON/JWT claims decode
// to), or a decimal string that fits in 32 bits. Negative, fractional,
// unparseable or otherwise-typed values are logged as "anonymous" with user
// ID 0.
func OperationLogMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		start := time.Now()

		var requestData interface{}
		if shouldLogRequestData(c.Request.Method, c.Request.URL.Path) {
			// Only metadata is captured here, to keep sensitive payloads out of
			// the operation log.
			requestData = map[string]interface{}{
				"query_params": c.Request.URL.RawQuery,
				"content_type": c.Request.Header.Get("Content-Type"),
			}
		}

		c.Next()

		duration := time.Since(start)

		// Resolve the acting user; unrecognized values stay anonymous.
		var userID uint
		userType := "anonymous"
		if uid, exists := c.Get("user_id"); exists {
			var id uint
			ok := false
			switch v := uid.(type) {
			case uint:
				id, ok = v, true
			case uint64:
				id, ok = uint(v), true
			case uint32:
				id, ok = uint(v), true
			case int:
				id, ok = uint(v), v >= 0
			case int64:
				id, ok = uint(v), v >= 0
			case int32:
				id, ok = uint(v), v >= 0
			case float64:
				// Accept only exact, non-negative whole numbers in the 32-bit
				// range, so the conversion is well-defined.
				if v >= 0 && v < 1<<32 && v == float64(uint32(v)) {
					id, ok = uint(v), true
				}
			case string:
				if parsed, err := strconv.ParseUint(v, 10, 32); err == nil {
					id, ok = uint(parsed), true
				}
			}
			if ok {
				userID = id
				userType = "user"
			}
		}

		clientIP := utils.GetClientIP(
			c.ClientIP(),
			c.GetHeader("X-Forwarded-For"),
			c.GetHeader("X-Real-IP"),
		)

		if shouldLogOperation(c.Request.Method, c.Request.URL.Path) {
			logger.LogOperationWithContext(
				c.Request.Context(),
				userID,
				userType,
				getOperationType(c.Request.Method),
				getResourceName(c.Request.URL.Path),
				c.Request.Method,
				c.Request.URL.Path,
				clientIP,
				requestData,
				nil, // response data is not recorded
				c.Writer.Status(),
				duration,
			)
		}
	}
}
// shouldLogRequestData reports whether request metadata should be attached to
// an operation log entry. It returns true for POST, PUT and PATCH. The path
// argument is accepted but not consulted.
func shouldLogRequestData(method, path string) bool {
	switch method {
	case "POST", "PUT", "PATCH":
		return true
	default:
		return false
	}
}
// shouldLogOperation reports whether a request should produce an operation log
// entry. The exact paths "/health" and "/ping" are always excluded. Otherwise
// only modifying methods (POST, PUT, PATCH, DELETE) are logged.
func shouldLogOperation(method, path string) bool {
	switch path {
	case "/health", "/ping":
		return false
	}

	switch method {
	case "POST", "PUT", "PATCH", "DELETE":
		return true
	}
	return false
}
// getOperationType maps an HTTP method to an operation-log verb: POST is
// "create", PUT/PATCH are "update", DELETE is "delete" and GET is "read".
// Any other method yields "unknown".
func getOperationType(method string) string {
	if method == "POST" {
		return "create"
	}
	if method == "PUT" || method == "PATCH" {
		return "update"
	}
	if method == "DELETE" {
		return "delete"
	}
	if method == "GET" {
		return "read"
	}
	return "unknown"
}
// getResourceName derives a resource name for operation logs from a URL path
// by returning its last meaningful "/"-separated segment.
//
// Empty segments (from trailing or doubled slashes) and purely numeric
// segments (record IDs such as the "123" in "/api/orders/123") are skipped.
// Thus "/api/orders/123/" yields "orders", while "/api/orders/123/cancel"
// yields "cancel". It returns "unknown" when no such segment exists (e.g. ""
// or "/").
func getResourceName(path string) string {
	isNumeric := func(s string) bool {
		for i := 0; i < len(s); i++ {
			if s[i] < '0' || s[i] > '9' {
				return false
			}
		}
		return true
	}

	segments := strings.Split(path, "/")
	for i := len(segments) - 1; i >= 0; i-- {
		seg := segments[i]
		if seg == "" || isNumeric(seg) {
			continue
		}
		return seg
	}
	return "unknown"
}
// getEnvironment returns the current deployment environment name in lower
// case. It checks GO_ENV, then APP_ENV, then ENVIRONMENT, and uses the first
// one that is non-empty. It falls back to "development" when none is set.
func getEnvironment() string {
	for _, key := range []string{"GO_ENV", "APP_ENV", "ENVIRONMENT"} {
		if value := os.Getenv(key); value != "" {
			return strings.ToLower(value)
		}
	}
	return "development"
}