This commit is contained in:
sjk
2025-11-17 14:09:17 +08:00
commit 31e46c5bf6
479 changed files with 109324 additions and 0 deletions

View File

@@ -0,0 +1,304 @@
package middleware
import (
"bytes"
"fmt"
"io"
"encoding/json"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/logger"
)
// Logger 保留原始Gin格式化日志如需
func Logger() gin.HandlerFunc {
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
return fmt.Sprintf("%s - [%s] \"%s %s %s %d %s \"%s\" %s\"\n",
param.ClientIP,
param.TimeStamp.Format(time.RFC1123),
param.Method,
param.Path,
param.Request.Proto,
param.StatusCode,
param.Latency,
param.Request.UserAgent(),
param.ErrorMessage,
)
})
}
// Recovery 恢复中间件
func Recovery() gin.HandlerFunc {
return gin.Recovery()
}
// bodyLogWriter 用于捕获响应体内容
type bodyLogWriter struct {
gin.ResponseWriter
body *bytes.Buffer
}
func (w *bodyLogWriter) Write(b []byte) (int, error) {
if w.body != nil {
w.body.Write(b)
}
return w.ResponseWriter.Write(b)
}
// RequestResponseLogger 记录清晰的请求入参与响应信息
func RequestResponseLogger() gin.HandlerFunc {
const maxBodyLogSize = 10000 // 最大记录的body长度超过则截断增加到10000
return func(c *gin.Context) {
start := time.Now()
// 捕获请求体仅记录JSON避免记录文件/二进制数据)
reqCT := c.GetHeader("Content-Type")
var reqBodyStr string
if strings.Contains(reqCT, "application/json") {
if c.Request.Body != nil {
data, _ := io.ReadAll(c.Request.Body)
// 复位Body以便后续业务读取
c.Request.Body = io.NopCloser(bytes.NewBuffer(data))
reqBodyStr = maskSensitiveJSON(string(data))
reqBodyStr = truncate(reqBodyStr, maxBodyLogSize)
}
} else {
// 对于表单/多部分/其他类型,不直接记录内容,避免污染日志与隐私泄露
reqBodyStr = "(body skipped for non-JSON content)"
}
// 包装响应写入器,捕获响应体
blw := &bodyLogWriter{ResponseWriter: c.Writer, body: &bytes.Buffer{}}
c.Writer = blw
// 执行后续处理
c.Next()
latency := time.Since(start)
status := c.Writer.Status()
respCT := c.Writer.Header().Get("Content-Type")
respBodyStr := blw.body.String()
if strings.Contains(respCT, "application/json") {
respBodyStr = truncate(respBodyStr, maxBodyLogSize)
} else if respBodyStr != "" {
respBodyStr = fmt.Sprintf("(non-JSON response, %d bytes)", len(respBodyStr))
} else {
respBodyStr = "(empty)"
}
// 路径(优先使用路由匹配的完整路径)
path := c.FullPath()
if path == "" {
path = c.Request.URL.Path
}
// 头信息(仅记录关键字段,并进行脱敏)
auth := c.GetHeader("Authorization")
if auth != "" {
auth = maskToken(auth)
}
// 错误聚合
var errMsg string
if len(c.Errors) > 0 {
errMsg = c.Errors.String()
}
// 根据状态码选择emoji和日志级别
statusEmoji := getStatusEmoji(status)
logLevel := getLogLevel(status)
// 获取或生成Request ID
requestID := c.GetString("request_id")
if requestID == "" {
requestID = "N/A"
}
// 获取User ID如果有
userID := "N/A"
if uid, exists := c.Get("user_id"); exists {
userID = fmt.Sprintf("%v", uid)
}
// 格式化耗时
latencyStr := formatLatency(latency)
// 构建美化的控制台日志
separator := "================================================================================"
fmt.Printf("\n%s\n", separator)
fmt.Printf("🌐 %s %s | %s %d | ⏱️ %s\n", c.Request.Method, path, statusEmoji, status, latencyStr)
fmt.Printf("📍 IP: %s | 🆔 Request ID: %s | 👤 User ID: %s\n", c.ClientIP(), requestID, userID)
if c.Request.URL.RawQuery != "" {
fmt.Printf("🔗 Query: %s\n", c.Request.URL.RawQuery)
}
fmt.Printf("🔍 User-Agent: %s\n", c.Request.UserAgent())
// 显示请求体(如果有)
if reqBodyStr != "" && reqBodyStr != "(body skipped for non-JSON content)" {
fmt.Printf("📤 Request Body:\n %s\n", reqBodyStr)
}
// 显示响应体
if respBodyStr != "" && respBodyStr != "(empty)" {
fmt.Printf("📥 Response Body:\n %s\n", respBodyStr)
}
// 显示错误(如果有)
if errMsg != "" {
fmt.Printf("❌ Error: %s\n", errMsg)
}
fmt.Printf("🕐 Time: %s\n", time.Now().Format("2006-01-02 15:04:05"))
fmt.Printf("%s\n", separator)
// 同时记录结构化日志到文件
logEntry := logger.WithFields(map[string]interface{}{
"type": "http_request",
"timestamp": time.Now().Unix(),
"method": c.Request.Method,
"path": path,
"raw_query": c.Request.URL.RawQuery,
"ip": c.ClientIP(),
"user_agent": truncate(c.Request.UserAgent(), 200),
"status": status,
"request_id": requestID,
"user_id": userID,
"duration": latency.Milliseconds(),
})
logMsg := fmt.Sprintf("HTTP Request")
// 根据状态码选择日志级别
switch logLevel {
case "error":
logEntry.Error(logMsg)
case "warn":
logEntry.Warn(logMsg)
default:
logEntry.Info(logMsg)
}
}
}
// formatLatency 格式化延迟时间
func formatLatency(d time.Duration) string {
if d < time.Millisecond {
return fmt.Sprintf("%dµs", d.Microseconds())
} else if d < time.Second {
return fmt.Sprintf("%dms", d.Milliseconds())
} else {
return fmt.Sprintf("%.1fs", d.Seconds())
}
}
// getStatusEmoji 根据HTTP状态码返回对应的emoji
func getStatusEmoji(status int) string {
switch {
case status >= 200 && status < 300:
return "✅" // 成功
case status >= 300 && status < 400:
return "🔄" // 重定向
case status >= 400 && status < 500:
return "⚠️" // 客户端错误
case status >= 500:
return "❌" // 服务器错误
default:
return "📝" // 其他
}
}
// getMethodEmoji 根据HTTP方法返回对应的emoji
func getMethodEmoji(method string) string {
switch method {
case "GET":
return "📖" // 读取
case "POST":
return "📝" // 创建
case "PUT":
return "✏️" // 更新
case "DELETE":
return "🗑️" // 删除
case "PATCH":
return "🔧" // 修补
case "OPTIONS":
return "🔍" // 选项
default:
return "📌" // 其他
}
}
// getLogLevel 根据状态码返回日志级别
func getLogLevel(status int) string {
switch {
case status >= 500:
return "error"
case status >= 400:
return "warn"
default:
return "info"
}
}
// truncate 截断过长日志内容
func truncate(s string, max int) string {
if len(s) <= max {
return s
}
if max <= 3 {
return s[:max]
}
return s[:max-3] + "..."
}
// maskToken 脱敏令牌或认证头
func maskToken(s string) string {
if s == "" {
return s
}
// 只保留前后少量字符
if len(s) <= 10 {
return "***"
}
return s[:6] + "***" + s[len(s)-4:]
}
// maskSensitiveJSON 尝试解析并递归脱敏JSON中的敏感字段
func maskSensitiveJSON(s string) string {
var obj interface{}
if err := json.Unmarshal([]byte(s), &obj); err != nil {
// 解析失败时直接返回原始字符串随后由truncate处理
return s
}
masked := maskRecursive(obj)
b, err := json.Marshal(masked)
if err != nil {
return s
}
return string(b)
}
func maskRecursive(v interface{}) interface{} {
switch t := v.(type) {
case map[string]interface{}:
for k, val := range t {
lk := strings.ToLower(k)
if lk == "password" || lk == "token" || lk == "secret" || lk == "authorization" {
t[k] = "***"
continue
}
t[k] = maskRecursive(val)
}
return t
case []interface{}:
for i := range t {
t[i] = maskRecursive(t[i])
}
return t
default:
return v
}
}