package middleware

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/gin-gonic/gin"
	"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/logger"
)

// Logger returns gin's access-log middleware with a custom single-line format
// (kept for callers that still want the plain Gin-style log).
func Logger() gin.HandlerFunc {
	return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
		return fmt.Sprintf("%s - [%s] \"%s %s %s %d %s \"%s\" %s\"\n",
			param.ClientIP,
			param.TimeStamp.Format(time.RFC1123),
			param.Method,
			param.Path,
			param.Request.Proto,
			param.StatusCode,
			param.Latency,
			param.Request.UserAgent(),
			param.ErrorMessage,
		)
	})
}

// Recovery returns gin's built-in panic-recovery middleware.
func Recovery() gin.HandlerFunc {
	return gin.Recovery()
}

// bodyLogWriter wraps gin.ResponseWriter. It forwards every byte to the client
// and keeps a bounded copy of the response body for logging.
type bodyLogWriter struct {
	gin.ResponseWriter
	body  *bytes.Buffer // captured prefix of the body; nil disables capture
	limit int           // max bytes kept in body; <= 0 means unlimited
	total int           // total bytes successfully written to the client
}

// Write forwards b to the client and records the bytes actually written.
func (w *bodyLogWriter) Write(b []byte) (int, error) {
	n, err := w.ResponseWriter.Write(b)
	if k := w.captureLen(n); k > 0 {
		w.body.Write(b[:k])
	}
	w.total += n
	return n, err
}

// WriteString forwards s to the client and records the bytes actually written.
// It must be defined here: otherwise the embedded gin.ResponseWriter's own
// WriteString is promoted, and string writes never pass through Write above.
func (w *bodyLogWriter) WriteString(s string) (int, error) {
	n, err := w.ResponseWriter.WriteString(s)
	if k := w.captureLen(n); k > 0 {
		w.body.WriteString(s[:k])
	}
	w.total += n
	return n, err
}

// captureLen returns how many of n newly written bytes still fit in the
// capture buffer.
func (w *bodyLogWriter) captureLen(n int) int {
	if w.body == nil || n <= 0 {
		return 0
	}
	if w.limit <= 0 {
		return n
	}
	free := w.limit - w.body.Len()
	if free <= 0 {
		return 0
	}
	if n > free {
		return free
	}
	return n
}

// RequestResponseLogger logs every request in two forms:
//   - a human-readable, multi-line block on stdout;
//   - a structured entry via the project logger, at error level for 5xx,
//     warn for 4xx and info otherwise.
//
// Only JSON request/response bodies are logged. Sensitive JSON fields are
// masked, and long bodies are truncated.
func RequestResponseLogger() gin.HandlerFunc {
	const (
		maxBodyLogSize = 10000 // max body bytes written to the log; longer bodies are truncated
		// maxRespCaptureSize bounds the response bytes buffered per request.
		// It is larger than maxBodyLogSize so that typical JSON responses can
		// be parsed and masked in full before being truncated.
		maxRespCaptureSize = 1 << 20
	)

	return func(c *gin.Context) {
		start := time.Now()

		reqBody := requestBodyForLog(c, maxBodyLogSize)

		// Wrap the response writer so the response body can be captured.
		blw := &bodyLogWriter{ResponseWriter: c.Writer, body: &bytes.Buffer{}, limit: maxRespCaptureSize}
		c.Writer = blw

		c.Next()

		latency := time.Since(start)
		status := c.Writer.Status()

		// Prefer the matched route pattern; fall back to the raw URL path
		// (e.g. for unmatched routes).
		path := c.FullPath()
		if path == "" {
			path = c.Request.URL.Path
		}

		requestID := c.GetString("request_id")
		if requestID == "" {
			requestID = "N/A"
		}
		userID := "N/A"
		if uid, exists := c.Get("user_id"); exists {
			userID = fmt.Sprintf("%v", uid)
		}

		var errMsg string
		if len(c.Errors) > 0 {
			errMsg = c.Errors.String()
		}

		rec := httpLogRecord{
			method:    c.Request.Method,
			path:      path,
			rawQuery:  c.Request.URL.RawQuery,
			clientIP:  c.ClientIP(),
			userAgent: c.Request.UserAgent(),
			requestID: requestID,
			userID:    userID,
			status:    status,
			latency:   latency,
			reqBody:   reqBody,
			respBody:  responseBodyForLog(c.Writer.Header().Get("Content-Type"), blw, maxBodyLogSize),
			errMsg:    errMsg,
		}

		// Emit the console block with a single call so concurrent requests
		// do not interleave line by line.
		fmt.Print(rec.consoleBlock())

		entry := logger.WithFields(rec.fields())
		const logMsg = "HTTP Request"
		switch getLogLevel(status) {
		case "error":
			entry.Error(logMsg)
		case "warn":
			entry.Warn(logMsg)
		default:
			entry.Info(logMsg)
		}
	}
}

// requestBodyForLog returns a masked, truncated copy of a JSON request body,
// or "" when the body is not JSON or absent. Form, multipart and binary bodies
// are skipped to keep logs clean and avoid leaking private data.
//
// The body is re-armed so downstream handlers can still read it in full.
func requestBodyForLog(c *gin.Context, maxLen int) string {
	if !strings.Contains(c.GetHeader("Content-Type"), "application/json") {
		return ""
	}
	orig := c.Request.Body
	if orig == nil {
		return ""
	}
	data, err := io.ReadAll(orig)
	// Replay the consumed bytes, then continue with the original reader.
	// If reading failed part-way, the handler therefore still observes the
	// error instead of a silently truncated body. The original Closer is kept.
	c.Request.Body = struct {
		io.Reader
		io.Closer
	}{io.MultiReader(bytes.NewReader(data), orig), orig}
	if err != nil {
		return fmt.Sprintf("(failed to read request body: %v)", err)
	}
	return truncate(maskSensitiveJSON(string(data)), maxLen)
}

// responseBodyForLog describes the captured response body for logging. It
// returns "" for an empty body, a size summary for non-JSON or partially
// captured JSON, and otherwise the masked, truncated JSON.
func responseBodyForLog(contentType string, w *bodyLogWriter, maxLen int) string {
	switch {
	case w.total == 0:
		return ""
	case !strings.Contains(contentType, "application/json"):
		return fmt.Sprintf("(non-JSON response, %d bytes)", w.total)
	case w.body == nil || w.body.Len() < w.total:
		// Only a prefix was buffered. Partial JSON cannot be parsed for
		// masking, so log the size rather than risk leaking unmasked fields.
		return fmt.Sprintf("(JSON response too large to log, %d bytes)", w.total)
	default:
		return truncate(maskSensitiveJSON(w.body.String()), maxLen)
	}
}

// httpLogRecord holds the values collected for one request, so the console
// block and the structured entry are produced from the same data.
type httpLogRecord struct {
	method    string
	path      string
	rawQuery  string
	clientIP  string
	userAgent string
	requestID string
	userID    string
	status    int
	latency   time.Duration
	reqBody   string // masked/truncated JSON body; "" means not logged
	respBody  string // masked/truncated JSON body or a summary; "" means not logged
	errMsg    string
}

// consoleBlock renders the human-readable, emoji-decorated log block as one
// string.
func (r httpLogRecord) consoleBlock() string {
	const separator = "================================================================================"
	var b strings.Builder
	fmt.Fprintf(&b, "\n%s\n", separator)
	fmt.Fprintf(&b, "🌐 %s %s | %s %d | ⏱️ %s\n", r.method, r.path, getStatusEmoji(r.status), r.status, formatLatency(r.latency))
	fmt.Fprintf(&b, "📍 IP: %s | 🆔 Request ID: %s | 👤 User ID: %s\n", r.clientIP, r.requestID, r.userID)
	if r.rawQuery != "" {
		fmt.Fprintf(&b, "🔗 Query: %s\n", r.rawQuery)
	}
	fmt.Fprintf(&b, "🔍 User-Agent: %s\n", r.userAgent)
	if r.reqBody != "" {
		fmt.Fprintf(&b, "📤 Request Body:\n %s\n", r.reqBody)
	}
	if r.respBody != "" {
		fmt.Fprintf(&b, "📥 Response Body:\n %s\n", r.respBody)
	}
	if r.errMsg != "" {
		fmt.Fprintf(&b, "❌ Error: %s\n", r.errMsg)
	}
	fmt.Fprintf(&b, "🕐 Time: %s\n", time.Now().Format("2006-01-02 15:04:05"))
	fmt.Fprintf(&b, "%s\n", separator)
	return b.String()
}

// fields builds the structured-log fields passed to the project logger.
// "duration" is in milliseconds.
func (r httpLogRecord) fields() map[string]interface{} {
	f := map[string]interface{}{
		"type":       "http_request",
		"timestamp":  time.Now().Unix(),
		"method":     r.method,
		"path":       r.path,
		"raw_query":  r.rawQuery,
		"ip":         r.clientIP,
		"user_agent": truncate(r.userAgent, 200),
		"status":     r.status,
		"request_id": r.requestID,
		"user_id":    r.userID,
		"duration":   r.latency.Milliseconds(),
	}
	if r.errMsg != "" {
		// Include handler errors so persisted 4xx/5xx entries are explainable.
		f["errors"] = r.errMsg
	}
	return f
}

// formatLatency renders d as whole microseconds below 1ms, whole milliseconds
// below 1s, and seconds with one decimal otherwise.
func formatLatency(d time.Duration) string {
	switch {
	case d < time.Millisecond:
		return fmt.Sprintf("%dµs", d.Microseconds())
	case d < time.Second:
		return fmt.Sprintf("%dms", d.Milliseconds())
	default:
		return fmt.Sprintf("%.1fs", d.Seconds())
	}
}

// getStatusEmoji returns an emoji for the HTTP status class.
func getStatusEmoji(status int) string {
	switch {
	case status >= 200 && status < 300:
		return "✅" // success
	case status >= 300 && status < 400:
		return "🔄" // redirect
	case status >= 400 && status < 500:
		return "⚠️" // client error
	case status >= 500:
		return "❌" // server error
	default:
		return "📝" // anything else (e.g. 1xx)
	}
}

// getMethodEmoji returns an emoji for the HTTP method.
func getMethodEmoji(method string) string {
	switch method {
	case "GET":
		return "📖" // read
	case "POST":
		return "📝" // create
	case "PUT":
		return "✏️" // update
	case "DELETE":
		return "🗑️" // delete
	case "PATCH":
		return "🔧" // patch
	case "OPTIONS":
		return "🔍" // options
	default:
		return "📌" // other
	}
}

// getLogLevel maps a status code to a log level: "error" for 5xx, "warn" for
// 4xx, "info" otherwise.
func getLogLevel(status int) string {
	switch {
	case status >= 500:
		return "error"
	case status >= 400:
		return "warn"
	default:
		return "info"
	}
}

// truncate shortens s to at most maxLen bytes. When it cuts and maxLen > 3,
// the result ends in "...". The cut point is moved back to a UTF-8 rune
// boundary so multi-byte characters (e.g. Chinese text) are never split into
// invalid bytes. The result can therefore be up to 3 bytes shorter than
// maxLen. A non-positive maxLen yields "".
func truncate(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	if maxLen <= 0 {
		return ""
	}
	keep, suffix := maxLen-3, "..."
	if maxLen <= 3 {
		keep, suffix = maxLen, ""
	}
	// keep < len(s) here, so s[keep] is in range. Back up over at most
	// UTFMax-1 continuation bytes to reach the start of a rune.
	for i := 0; i < utf8.UTFMax-1 && keep > 0 && !utf8.RuneStart(s[keep]); i++ {
		keep--
	}
	return s[:keep] + suffix
}

// maskToken masks a token or Authorization header value. Values of 10 bytes
// or fewer become "***"; longer ones keep only the first 6 and last 4 bytes.
func maskToken(s string) string {
	if s == "" {
		return s
	}
	if len(s) <= 10 {
		return "***"
	}
	return s[:6] + "***" + s[len(s)-4:]
}

// maskSensitiveJSON parses s as a single JSON value and replaces the values of
// sensitive keys (see isSensitiveKey) at any depth with "***". It returns s
// unchanged if s is not exactly one valid JSON value.
//
// Numbers are decoded as json.Number so large integers (e.g. int64 IDs) are
// re-emitted verbatim instead of being rounded through float64. Object keys in
// the output are sorted, as produced by json.Marshal.
func maskSensitiveJSON(s string) string {
	dec := json.NewDecoder(strings.NewReader(s))
	dec.UseNumber()
	var obj interface{}
	if err := dec.Decode(&obj); err != nil {
		return s
	}
	// Reject trailing content, matching json.Unmarshal's strictness.
	if _, err := dec.Token(); err != io.EOF {
		return s
	}
	b, err := json.Marshal(maskRecursive(obj))
	if err != nil {
		return s
	}
	return string(b)
}

// isSensitiveKey reports whether a JSON field name looks like a credential.
// The match is case-insensitive and covers:
//   - names containing "password", "passwd" or "secret";
//   - names ending in "token" (token, access_token, refreshToken, ...);
//   - "authorization".
func isSensitiveKey(key string) bool {
	k := strings.ToLower(key)
	return k == "authorization" ||
		strings.HasSuffix(k, "token") ||
		strings.Contains(k, "password") ||
		strings.Contains(k, "passwd") ||
		strings.Contains(k, "secret")
}

// maskRecursive walks decoded JSON (maps, slices, scalars) and replaces the
// values of sensitive keys with "***". Maps and slices are modified in place
// and also returned.
func maskRecursive(v interface{}) interface{} {
	switch t := v.(type) {
	case map[string]interface{}:
		for k, val := range t {
			if isSensitiveKey(k) {
				t[k] = "***"
				continue
			}
			t[k] = maskRecursive(val)
		}
		return t
	case []interface{}:
		for i := range t {
			t[i] = maskRecursive(t[i])
		}
		return t
	default:
		return v
	}
}