package middleware import ( "bytes" "context" "dianshang/pkg/logger" "dianshang/pkg/utils" "fmt" "io" "os" "strconv" "strings" "time" "github.com/gin-gonic/gin" ) // RequestIDMiddleware 请求ID中间件 func RequestIDMiddleware() gin.HandlerFunc { return func(c *gin.Context) { // 生成请求ID requestID := logger.GenerateRequestID() // 设置到上下文 ctx := context.WithValue(c.Request.Context(), logger.RequestIDKey, requestID) c.Request = c.Request.WithContext(ctx) // 设置到响应头 c.Header("X-Request-ID", requestID) // 设置到Gin上下文,方便后续使用 c.Set("request_id", requestID) c.Next() } } // UserContextMiddleware 用户上下文中间件 func UserContextMiddleware() gin.HandlerFunc { return func(c *gin.Context) { // 从JWT或其他认证方式获取用户ID if userID, exists := c.Get("user_id"); exists { ctx := context.WithValue(c.Request.Context(), logger.UserIDKey, userID) c.Request = c.Request.WithContext(ctx) } c.Next() } } // LoggerMiddleware 日志中间件 func LoggerMiddleware() gin.HandlerFunc { return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { // 记录请求日志 logger.LogRequest( param.Method, param.Path, utils.GetClientIP(param.ClientIP, param.Request.Header.Get("X-Forwarded-For"), param.Request.Header.Get("X-Real-IP")), param.Request.UserAgent(), param.StatusCode, param.Latency.Milliseconds(), ) return "" }) } // responseBodyWriter 响应体写入器 type responseBodyWriter struct { gin.ResponseWriter body *bytes.Buffer } func (r responseBodyWriter) Write(b []byte) (int, error) { r.body.Write(b) return r.ResponseWriter.Write(b) } // EnhancedLoggerMiddleware 增强的日志中间件 func EnhancedLoggerMiddleware() gin.HandlerFunc { return func(c *gin.Context) { start := time.Now() path := c.Request.URL.Path raw := c.Request.URL.RawQuery // 读取请求体 var requestBody []byte if c.Request.Body != nil { requestBody, _ = io.ReadAll(c.Request.Body) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) } // 创建响应体写入器 responseBody := &bytes.Buffer{} writer := &responseBodyWriter{ ResponseWriter: c.Writer, body: responseBody, } c.Writer = writer // 处理请求 c.Next() // 计算请求耗时 latency := time.Since(start) // 获取客户端IP clientIP := utils.GetClientIP( c.ClientIP(), c.GetHeader("X-Forwarded-For"), c.GetHeader("X-Real-IP"), ) // 构建完整路径 if raw != "" { path = path + "?" + raw } // 获取用户ID var userID interface{} if uid, exists := c.Get("user_id"); exists { userID = uid } // 获取请求ID var requestID interface{} if rid, exists := c.Get("request_id"); exists { requestID = rid } // 准备请求数据 requestData := map[string]interface{}{ "headers": c.Request.Header, "query": c.Request.URL.Query(), } // 添加请求体(如果不为空且不是敏感信息) if len(requestBody) > 0 && shouldLogRequestBody(c.Request.Method, path) { requestData["body"] = string(requestBody) } // 准备响应数据 responseData := map[string]interface{}{ "headers": writer.Header(), } // 添加响应体(如果不为空且不是敏感信息) if responseBody.Len() > 0 && shouldLogResponseBody(c.Writer.Status(), path) { responseData["body"] = responseBody.String() } printDetailedConsoleLog(c.Request.Method, path, clientIP, c.Request.UserAgent(), c.Writer.Status(), latency, requestID, userID, requestData, responseData) // 记录请求日志到文件(所有环境) logger.LogRequestWithContext( c.Request.Context(), c.Request.Method, path, clientIP, c.Request.UserAgent(), c.Writer.Status(), latency.Milliseconds(), ) // 记录性能日志(如果请求耗时较长) if latency > 1*time.Second { logger.LogPerformanceWithContext( c.Request.Context(), "http_request", latency, map[string]interface{}{ "method": c.Request.Method, "path": path, "status": c.Writer.Status(), "user_id": userID, "request_id": requestID, }, ) } // 如果有错误,记录错误日志 if len(c.Errors) > 0 { logger.LogErrorWithContext( c.Request.Context(), c.Errors[0].Err, map[string]interface{}{ "method": c.Request.Method, "path": path, "ip": clientIP, "user_id": userID, "request_id": requestID, "errors": c.Errors.String(), }, ) } } } // printDetailedConsoleLog 打印详细的控制台日志(仅开发环境) func printDetailedConsoleLog(method, path, ip, userAgent string, status int, latency time.Duration, requestID, userID interface{}, requestData, responseData map[string]interface{}) { // 获取状态颜色 statusColor := getStatusColor(status) methodColor := getMethodColor(method) // 打印分隔线 fmt.Println("\n" + strings.Repeat("=", 80)) // 打印基本信息(使用简单的ASCII字符作为分隔符) fmt.Printf("%s%s%s %s%s%s - Status: %s%d%s - Latency: %s%v%s\n", methodColor, method, "\033[0m", "\033[36m", path, "\033[0m", statusColor, status, "\033[0m", "\033[33m", latency, "\033[0m") // 打印请求信息 fmt.Printf("IP: %s - Request ID: %v - User ID: %v\n", ip, requestID, userID) fmt.Printf("User-Agent: %s\n", userAgent) // 打印请求参数 if query, ok := requestData["query"].(map[string][]string); ok && len(query) > 0 { fmt.Printf("Query Params:\n") for k, v := range query { fmt.Printf(" %s: %v\n", k, v) } } // 打印请求体 if body, ok := requestData["body"].(string); ok && body != "" { fmt.Printf("Request Body:\n") if len(body) > 2000 { fmt.Printf(" %s...(truncated)\n", body[:2000]) } else { fmt.Printf(" %s\n", body) } } // 打印响应体(限制长度,避免刷屏) if body, ok := responseData["body"].(string); ok && body != "" { fmt.Printf("Response Body:\n") if len(body) > 1000 { fmt.Printf(" %s...(truncated, total %d chars)\n", body[:1000], len(body)) } else { fmt.Printf(" %s\n", body) } } // 打印时间戳 fmt.Printf("Time: %s\n", time.Now().Format("2006-01-02 15:04:05")) // 打印分隔线 fmt.Println(strings.Repeat("=", 80)) } // getStatusColor 获取状态码颜色 func getStatusColor(status int) string { switch { case status >= 200 && status < 300: return "\033[32m" // 绿色 case status >= 300 && status < 400: return "\033[33m" // 黄色 case status >= 400 && status < 500: return "\033[31m" // 红色 case status >= 500: return "\033[35m" // 紫色 default: return "\033[37m" // 白色 } } // getMethodColor 获取方法颜色 func getMethodColor(method string) string { switch method { case "GET": return "\033[32m" // 绿色 case "POST": return "\033[34m" // 蓝色 case "PUT": return "\033[33m" // 黄色 case "DELETE": return "\033[31m" // 红色 case "PATCH": return "\033[36m" // 青色 default: return "\033[37m" // 白色 } } // shouldLogRequestBody 判断是否应该记录请求体 func shouldLogRequestBody(method, path string) bool { // 排除敏感路径 sensitivePatterns := []string{ "/login", "/register", "/password", "/auth", } for _, pattern := range sensitivePatterns { if strings.Contains(path, pattern) { return false } } // 只记录POST、PUT、PATCH请求的请求体 return method == "POST" || method == "PUT" || method == "PATCH" } // shouldLogResponseBody 判断是否应该记录响应体 func shouldLogResponseBody(status int, path string) bool { // 排除大文件下载等路径 excludePatterns := []string{ "/download", "/file", "/image", "/video", } for _, pattern := range excludePatterns { if strings.Contains(path, pattern) { return false } } // 只记录成功和客户端错误的响应体 return (status >= 200 && status < 300) || (status >= 400 && status < 500) } // CustomLoggerMiddleware 自定义日志中间件(保持向后兼容) func CustomLoggerMiddleware() gin.HandlerFunc { return func(c *gin.Context) { start := time.Now() path := c.Request.URL.Path raw := c.Request.URL.RawQuery // 处理请求 c.Next() // 计算请求耗时 latency := time.Since(start) // 获取客户端IP clientIP := utils.GetClientIP( c.ClientIP(), c.GetHeader("X-Forwarded-For"), c.GetHeader("X-Real-IP"), ) // 构建完整路径 if raw != "" { path = path + "?" + raw } // 记录日志 logger.WithFields(map[string]interface{}{ "status": c.Writer.Status(), "method": c.Request.Method, "path": path, "ip": clientIP, "user_agent": c.Request.UserAgent(), "latency": latency.String(), "size": c.Writer.Size(), }).Info("HTTP Request") // 如果有错误,记录错误日志 if len(c.Errors) > 0 { logger.WithFields(map[string]interface{}{ "method": c.Request.Method, "path": path, "ip": clientIP, "errors": c.Errors.String(), }).Error("Request Errors") } } } // OperationLogMiddleware 操作日志中间件 func OperationLogMiddleware() gin.HandlerFunc { return func(c *gin.Context) { start := time.Now() // 记录请求数据(仅对特定操作) var requestData interface{} if shouldLogRequestData(c.Request.Method, c.Request.URL.Path) { // 这里可以根据需要记录请求体,但要注意敏感信息 requestData = map[string]interface{}{ "query_params": c.Request.URL.RawQuery, "content_type": c.Request.Header.Get("Content-Type"), } } // 处理请求 c.Next() // 计算请求耗时 duration := time.Since(start) // 获取用户信息 var userID uint var userType string = "anonymous" if uid, exists := c.Get("user_id"); exists { if id, ok := uid.(uint); ok { userID = id userType = "user" } else if idStr, ok := uid.(string); ok { if id, err := strconv.ParseUint(idStr, 10, 32); err == nil { userID = uint(id) userType = "user" } } } // 获取客户端IP clientIP := utils.GetClientIP( c.ClientIP(), c.GetHeader("X-Forwarded-For"), c.GetHeader("X-Real-IP"), ) // 记录操作日志 if shouldLogOperation(c.Request.Method, c.Request.URL.Path) { logger.LogOperationWithContext( c.Request.Context(), userID, userType, getOperationType(c.Request.Method), getResourceName(c.Request.URL.Path), c.Request.Method, c.Request.URL.Path, clientIP, requestData, nil, // 响应数据可以根据需要添加 c.Writer.Status(), duration, ) } } } // shouldLogRequestData 判断是否需要记录请求数据 func shouldLogRequestData(method, path string) bool { // 对于POST、PUT、PATCH请求记录请求数据 return method == "POST" || method == "PUT" || method == "PATCH" } // shouldLogOperation 判断是否需要记录操作日志 func shouldLogOperation(method, path string) bool { // 排除健康检查等不重要的请求 if path == "/health" || path == "/ping" { return false } // 对于修改操作记录日志 return method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" } // getOperationType 获取操作类型 func getOperationType(method string) string { switch method { case "POST": return "create" case "PUT", "PATCH": return "update" case "DELETE": return "delete" case "GET": return "read" default: return "unknown" } } // getResourceName 从路径获取资源名称 func getResourceName(path string) string { // 简单的路径解析,可以根据实际需要完善 if len(path) > 1 { parts := strings.Split(path[1:], "/") if len(parts) > 0 { return parts[len(parts)-1] } } return "unknown" } // getEnvironment 获取当前环境 func getEnvironment() string { env := os.Getenv("GO_ENV") if env == "" { env = os.Getenv("APP_ENV") } if env == "" { env = os.Getenv("ENVIRONMENT") } if env == "" { env = "development" // 默认环境 } return strings.ToLower(env) }