Files
2025-11-17 13:39:05 +08:00

304 lines
8.8 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"bytes"
"fmt"
"io"
"encoding/json"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/logger"
)
// Logger returns Gin's built-in access-log middleware with a custom line
// format, kept for cases where the original Gin-style log output is wanted.
//
// Each request produces one line of the form:
//
//	<client-ip> - [<RFC1123 time>] "<method> <path> <proto> <status> <latency> "<user-agent>" <error>"
func Logger() gin.HandlerFunc {
	formatLine := func(p gin.LogFormatterParams) string {
		stamp := p.TimeStamp.Format(time.RFC1123)
		return fmt.Sprintf("%s - [%s] \"%s %s %s %d %s \"%s\" %s\"\n",
			p.ClientIP, stamp,
			p.Method, p.Path, p.Request.Proto,
			p.StatusCode, p.Latency,
			p.Request.UserAgent(), p.ErrorMessage,
		)
	}
	return gin.LoggerWithFormatter(formatLine)
}
// Recovery returns Gin's panic-recovery middleware. It is a thin wrapper
// around gin.Recovery(), which (per Gin's documentation) recovers from panics
// raised by later handlers and responds with HTTP 500.
func Recovery() gin.HandlerFunc {
	recoverer := gin.Recovery()
	return recoverer
}
// bodyLogWriter wraps a gin.ResponseWriter and keeps a copy of every byte that
// was successfully written to the client, so the response body can be logged
// after the handler chain finishes.
type bodyLogWriter struct {
	gin.ResponseWriter
	body *bytes.Buffer // captured copy of the response body; nil disables capturing
}

// Write forwards b to the underlying writer and records only the bytes that
// writer actually accepted, so the captured copy never runs ahead of what the
// client received.
func (w *bodyLogWriter) Write(b []byte) (int, error) {
	n, err := w.ResponseWriter.Write(b)
	if w.body != nil && n > 0 {
		w.body.Write(b[:n])
	}
	return n, err
}

// WriteString mirrors Write for string payloads. Without this override the
// method promoted from the embedded gin.ResponseWriter would be used, and any
// c.Writer.WriteString call would bypass the capture entirely.
func (w *bodyLogWriter) WriteString(s string) (int, error) {
	n, err := w.ResponseWriter.WriteString(s)
	if w.body != nil && n > 0 {
		w.body.WriteString(s[:n])
	}
	return n, err
}
// RequestResponseLogger returns middleware that logs each request together
// with the response it produced.
//
// For every request it:
//   - captures the request body when Content-Type contains "application/json"
//     (forms, multipart uploads and other bodies are skipped to keep files,
//     binary data and private form data out of the logs), then restores the
//     body so downstream handlers can still read it;
//   - wraps c.Writer in a bodyLogWriter to capture the response body;
//   - after the handler chain has run, prints a human-readable block to stdout
//     and writes a structured entry via logger.WithFields, at error level for
//     5xx, warn for 4xx and info otherwise (see getLogLevel).
//
// JSON request and response bodies are passed through maskSensitiveJSON and
// truncated to maxBodyLogSize bytes; non-JSON responses are logged by size only.
func RequestResponseLogger() gin.HandlerFunc {
	const maxBodyLogSize = 10000 // maximum number of body bytes to log; longer bodies are truncated
	const separator = "================================================================================"
	const logMsg = "HTTP Request"

	return func(c *gin.Context) {
		start := time.Now()

		// Capture the request body (JSON only).
		var reqBodyStr string
		if strings.Contains(c.GetHeader("Content-Type"), "application/json") && c.Request.Body != nil {
			data, err := io.ReadAll(c.Request.Body)
			if err != nil {
				// Give the handler the bytes already consumed followed by the
				// original body, so it still observes whatever error that body
				// reports on further reads instead of a silently truncated payload.
				c.Request.Body = io.NopCloser(io.MultiReader(bytes.NewReader(data), c.Request.Body))
				reqBodyStr = fmt.Sprintf("(failed to read request body: %v)", err)
			} else {
				// Restore the body so downstream handlers can read it again.
				c.Request.Body = io.NopCloser(bytes.NewReader(data))
				reqBodyStr = truncate(maskSensitiveJSON(string(data)), maxBodyLogSize)
			}
		}

		// Wrap the response writer so the response body is captured.
		blw := &bodyLogWriter{ResponseWriter: c.Writer, body: &bytes.Buffer{}}
		c.Writer = blw

		c.Next()

		end := time.Now()
		latency := end.Sub(start)
		status := c.Writer.Status()

		// Describe the response body. JSON responses are masked like request
		// bodies because they may also carry credentials such as tokens.
		var respBodyStr string
		if raw := blw.body.String(); raw != "" {
			if strings.Contains(c.Writer.Header().Get("Content-Type"), "application/json") {
				respBodyStr = truncate(maskSensitiveJSON(raw), maxBodyLogSize)
			} else {
				respBodyStr = fmt.Sprintf("(non-JSON response, %d bytes)", len(raw))
			}
		}

		// Prefer the matched route pattern; fall back to the raw URL path.
		path := c.FullPath()
		if path == "" {
			path = c.Request.URL.Path
		}

		// Of the request headers, only Authorization is recorded, and only masked.
		auth := c.GetHeader("Authorization")
		if auth != "" {
			auth = maskToken(auth)
		}

		var errMsg string
		if len(c.Errors) > 0 {
			errMsg = c.Errors.String()
		}

		requestID := c.GetString("request_id")
		if requestID == "" {
			requestID = "N/A"
		}
		userID := "N/A"
		if uid, exists := c.Get("user_id"); exists {
			userID = fmt.Sprintf("%v", uid)
		}

		clientIP := c.ClientIP()
		userAgent := c.Request.UserAgent()
		rawQuery := c.Request.URL.RawQuery

		// Build the console block in memory and emit it with a single write so
		// lines from concurrent requests do not interleave on stdout.
		var out strings.Builder
		fmt.Fprintf(&out, "\n%s\n", separator)
		fmt.Fprintf(&out, "🌐 %s %s | %s %d | ⏱️ %s\n", c.Request.Method, path, getStatusEmoji(status), status, formatLatency(latency))
		fmt.Fprintf(&out, "📍 IP: %s | 🆔 Request ID: %s | 👤 User ID: %s\n", clientIP, requestID, userID)
		if rawQuery != "" {
			fmt.Fprintf(&out, "🔗 Query: %s\n", rawQuery)
		}
		fmt.Fprintf(&out, "🔍 User-Agent: %s\n", userAgent)
		if reqBodyStr != "" {
			fmt.Fprintf(&out, "📤 Request Body:\n %s\n", reqBodyStr)
		}
		if respBodyStr != "" {
			fmt.Fprintf(&out, "📥 Response Body:\n %s\n", respBodyStr)
		}
		if errMsg != "" {
			fmt.Fprintf(&out, "❌ Error: %s\n", errMsg)
		}
		fmt.Fprintf(&out, "🕐 Time: %s\n", end.Format("2006-01-02 15:04:05"))
		fmt.Fprintf(&out, "%s\n", separator)
		fmt.Print(out.String())

		// Structured entry for the log file.
		fields := map[string]interface{}{
			"type":       "http_request",
			"timestamp":  end.Unix(),
			"method":     c.Request.Method,
			"path":       path,
			"raw_query":  rawQuery,
			"ip":         clientIP,
			"user_agent": truncate(userAgent, 200),
			"status":     status,
			"request_id": requestID,
			"user_id":    userID,
			"duration":   latency.Milliseconds(),
		}
		if auth != "" {
			fields["authorization"] = auth
		}
		if errMsg != "" {
			fields["error"] = errMsg
		}
		entry := logger.WithFields(fields)

		switch getLogLevel(status) {
		case "error":
			entry.Error(logMsg)
		case "warn":
			entry.Warn(logMsg)
		default:
			entry.Info(logMsg)
		}
	}
}
// formatLatency renders a request duration for the console log using a single
// unit: whole microseconds below 1ms, whole milliseconds below 1s, and seconds
// with one decimal place otherwise (e.g. "1.5s").
func formatLatency(d time.Duration) string {
	switch {
	case d >= time.Second:
		return fmt.Sprintf("%.1fs", d.Seconds())
	case d >= time.Millisecond:
		return fmt.Sprintf("%dms", d.Milliseconds())
	default:
		return fmt.Sprintf("%dµs", d.Microseconds())
	}
}
// getStatusEmoji maps an HTTP status code to a decorative emoji for the
// console log: 2xx success, 3xx redirect, 4xx client error, 5xx and above
// server error, and a generic marker for anything below 200.
func getStatusEmoji(status int) string {
	switch {
	case status >= 500:
		return "❌" // server error
	case status >= 400:
		return "⚠️" // client error
	case status >= 300:
		return "🔄" // redirect
	case status >= 200:
		return "✅" // success
	}
	return "📝" // anything else (e.g. 1xx)
}
// getMethodEmoji returns a decorative emoji for an HTTP method name. Matching
// is exact and case-sensitive ("GET", not "get"); unlisted methods, including
// lower-case spellings, get the generic "📌" marker.
func getMethodEmoji(method string) string {
	icons := map[string]string{
		"GET":     "📖", // read
		"POST":    "📝", // create
		"PUT":     "✏️", // update
		"DELETE":  "🗑️", // delete
		"PATCH":   "🔧", // patch
		"OPTIONS": "🔍", // options
	}
	if icon, ok := icons[method]; ok {
		return icon
	}
	return "📌" // anything else
}
// getLogLevel chooses the structured-log level for a response status:
// "error" for 5xx and above, "warn" for 4xx, and "info" for everything else.
func getLogLevel(status int) string {
	level := "info"
	if status >= 500 {
		level = "error"
	} else if status >= 400 {
		level = "warn"
	}
	return level
}
// truncate shortens s for logging so the result is at most limit bytes.
//
// Strings that already fit are returned unchanged. Otherwise, when limit > 3,
// the result ends in "..." (which counts toward the limit); when
// 0 < limit <= 3 only a prefix is kept, without an ellipsis; a limit <= 0
// yields "". The cut is moved back to a UTF-8 rune boundary, so multi-byte
// characters (e.g. CJK text) are never split into invalid UTF-8.
func truncate(s string, limit int) string {
	if len(s) <= limit {
		return s
	}
	if limit <= 0 {
		return ""
	}
	const ellipsis = "..."
	cut, suffix := limit, ""
	if limit > len(ellipsis) {
		cut, suffix = limit-len(ellipsis), ellipsis
	}
	// Back up while s[cut] is a UTF-8 continuation byte (10xxxxxx). This is
	// the same test as utf8.RuneStart, inlined to keep the file's imports as-is.
	// cut < len(s) holds here because limit < len(s).
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + suffix
}
// maskToken redacts a token or Authorization header value for logging.
//
// Empty input is returned as is. Values shorter than 20 bytes become "***".
// Longer values keep only their first 6 and last 4 bytes, e.g.
// "Bearer eyJ...wxyz" -> "Bearer***wxyz", so at least half of the value is
// always hidden. Shorter values previously exposed 10 of e.g. 11 bytes.
func maskToken(s string) string {
	const keepHead, keepTail = 6, 4
	// Reveal head/tail only when at least as many bytes stay hidden as are shown.
	const minRevealLen = 2 * (keepHead + keepTail)
	if s == "" {
		return s
	}
	if len(s) < minRevealLen {
		return "***"
	}
	return s[:keepHead] + "***" + s[len(s)-keepTail:]
}
// maskSensitiveJSON parses s as a single JSON value, replaces the value of
// every credential-like key (see isSensitiveLogKey) with "***" at any depth,
// and returns the re-encoded JSON.
//
// If s is not exactly one valid JSON value (surrounding whitespace allowed),
// s is returned unchanged and left for the caller to truncate. Numbers are
// decoded as json.Number, so large integer IDs are logged verbatim instead of
// being rounded through float64. Re-encoding uses json.Marshal, so object keys
// come out sorted and HTML characters (<, >, &) are escaped.
func maskSensitiveJSON(s string) string {
	dec := json.NewDecoder(strings.NewReader(s))
	dec.UseNumber()
	var obj interface{}
	if err := dec.Decode(&obj); err != nil {
		return s
	}
	// Reject trailing data after the first value, as json.Unmarshal does.
	if _, err := dec.Token(); err != io.EOF {
		return s
	}
	b, err := json.Marshal(maskRecursive(obj))
	if err != nil {
		return s
	}
	return string(b)
}

// maskRecursive walks a value decoded from JSON into interface{} and replaces
// the value of every object key that isSensitiveLogKey flags with "***".
// Maps and slices are modified in place and also returned; any other value is
// returned unchanged.
func maskRecursive(v interface{}) interface{} {
	switch t := v.(type) {
	case map[string]interface{}:
		for k, val := range t {
			if isSensitiveLogKey(k) {
				t[k] = "***"
				continue
			}
			t[k] = maskRecursive(val)
		}
		return t
	case []interface{}:
		for i := range t {
			t[i] = maskRecursive(t[i])
		}
		return t
	default:
		return v
	}
}

// isSensitiveLogKey reports whether a JSON object key names a credential that
// must not appear in logs. Matching is case-insensitive. It covers the exact
// keys "authorization" and "token" plus common variants such as
// "new_password", "passwd", "client_secret", "access_token", "refreshToken"
// and "api_key". Counters like "max_tokens" are not matched, because only a
// "token" suffix counts.
func isSensitiveLogKey(key string) bool {
	k := strings.ToLower(key)
	switch {
	case k == "authorization":
		return true
	case strings.Contains(k, "password"), strings.Contains(k, "passwd"), strings.Contains(k, "secret"):
		return true
	case strings.HasSuffix(k, "token"):
		return true
	case strings.Contains(k, "api_key"), strings.Contains(k, "apikey"):
		return true
	}
	return false
}