497 lines
12 KiB
Go
497 lines
12 KiB
Go
package middleware
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"dianshang/pkg/logger"
|
||
"dianshang/pkg/utils"
|
||
"fmt"
|
||
"io"
|
||
"os"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// RequestIDMiddleware 请求ID中间件
|
||
func RequestIDMiddleware() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
// 生成请求ID
|
||
requestID := logger.GenerateRequestID()
|
||
|
||
// 设置到上下文
|
||
ctx := context.WithValue(c.Request.Context(), logger.RequestIDKey, requestID)
|
||
c.Request = c.Request.WithContext(ctx)
|
||
|
||
// 设置到响应头
|
||
c.Header("X-Request-ID", requestID)
|
||
|
||
// 设置到Gin上下文,方便后续使用
|
||
c.Set("request_id", requestID)
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// UserContextMiddleware 用户上下文中间件
|
||
func UserContextMiddleware() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
// 从JWT或其他认证方式获取用户ID
|
||
if userID, exists := c.Get("user_id"); exists {
|
||
ctx := context.WithValue(c.Request.Context(), logger.UserIDKey, userID)
|
||
c.Request = c.Request.WithContext(ctx)
|
||
}
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// LoggerMiddleware 日志中间件
|
||
func LoggerMiddleware() gin.HandlerFunc {
|
||
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||
// 记录请求日志
|
||
logger.LogRequest(
|
||
param.Method,
|
||
param.Path,
|
||
utils.GetClientIP(param.ClientIP, param.Request.Header.Get("X-Forwarded-For"), param.Request.Header.Get("X-Real-IP")),
|
||
param.Request.UserAgent(),
|
||
param.StatusCode,
|
||
param.Latency.Milliseconds(),
|
||
)
|
||
return ""
|
||
})
|
||
}
|
||
|
||
// responseBodyWriter 响应体写入器
|
||
type responseBodyWriter struct {
|
||
gin.ResponseWriter
|
||
body *bytes.Buffer
|
||
}
|
||
|
||
func (r responseBodyWriter) Write(b []byte) (int, error) {
|
||
r.body.Write(b)
|
||
return r.ResponseWriter.Write(b)
|
||
}
|
||
|
||
// EnhancedLoggerMiddleware 增强的日志中间件
|
||
func EnhancedLoggerMiddleware() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
start := time.Now()
|
||
path := c.Request.URL.Path
|
||
raw := c.Request.URL.RawQuery
|
||
|
||
// 读取请求体
|
||
var requestBody []byte
|
||
if c.Request.Body != nil {
|
||
requestBody, _ = io.ReadAll(c.Request.Body)
|
||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||
}
|
||
|
||
// 创建响应体写入器
|
||
responseBody := &bytes.Buffer{}
|
||
writer := &responseBodyWriter{
|
||
ResponseWriter: c.Writer,
|
||
body: responseBody,
|
||
}
|
||
c.Writer = writer
|
||
|
||
// 处理请求
|
||
c.Next()
|
||
|
||
// 计算请求耗时
|
||
latency := time.Since(start)
|
||
|
||
// 获取客户端IP
|
||
clientIP := utils.GetClientIP(
|
||
c.ClientIP(),
|
||
c.GetHeader("X-Forwarded-For"),
|
||
c.GetHeader("X-Real-IP"),
|
||
)
|
||
|
||
// 构建完整路径
|
||
if raw != "" {
|
||
path = path + "?" + raw
|
||
}
|
||
|
||
// 获取用户ID
|
||
var userID interface{}
|
||
if uid, exists := c.Get("user_id"); exists {
|
||
userID = uid
|
||
}
|
||
|
||
// 获取请求ID
|
||
var requestID interface{}
|
||
if rid, exists := c.Get("request_id"); exists {
|
||
requestID = rid
|
||
}
|
||
|
||
// 准备请求数据
|
||
requestData := map[string]interface{}{
|
||
"headers": c.Request.Header,
|
||
"query": c.Request.URL.Query(),
|
||
}
|
||
|
||
// 添加请求体(如果不为空且不是敏感信息)
|
||
if len(requestBody) > 0 && shouldLogRequestBody(c.Request.Method, path) {
|
||
requestData["body"] = string(requestBody)
|
||
}
|
||
|
||
// 准备响应数据
|
||
responseData := map[string]interface{}{
|
||
"headers": writer.Header(),
|
||
}
|
||
|
||
// 添加响应体(如果不为空且不是敏感信息)
|
||
if responseBody.Len() > 0 && shouldLogResponseBody(c.Writer.Status(), path) {
|
||
responseData["body"] = responseBody.String()
|
||
}
|
||
|
||
|
||
printDetailedConsoleLog(c.Request.Method, path, clientIP, c.Request.UserAgent(),
|
||
c.Writer.Status(), latency, requestID, userID, requestData, responseData)
|
||
|
||
|
||
|
||
// 记录请求日志到文件(所有环境)
|
||
logger.LogRequestWithContext(
|
||
c.Request.Context(),
|
||
c.Request.Method,
|
||
path,
|
||
clientIP,
|
||
c.Request.UserAgent(),
|
||
c.Writer.Status(),
|
||
latency.Milliseconds(),
|
||
)
|
||
|
||
// 记录性能日志(如果请求耗时较长)
|
||
if latency > 1*time.Second {
|
||
logger.LogPerformanceWithContext(
|
||
c.Request.Context(),
|
||
"http_request",
|
||
latency,
|
||
map[string]interface{}{
|
||
"method": c.Request.Method,
|
||
"path": path,
|
||
"status": c.Writer.Status(),
|
||
"user_id": userID,
|
||
"request_id": requestID,
|
||
},
|
||
)
|
||
}
|
||
|
||
// 如果有错误,记录错误日志
|
||
if len(c.Errors) > 0 {
|
||
logger.LogErrorWithContext(
|
||
c.Request.Context(),
|
||
c.Errors[0].Err,
|
||
map[string]interface{}{
|
||
"method": c.Request.Method,
|
||
"path": path,
|
||
"ip": clientIP,
|
||
"user_id": userID,
|
||
"request_id": requestID,
|
||
"errors": c.Errors.String(),
|
||
},
|
||
)
|
||
}
|
||
}
|
||
}
|
||
|
||
// printDetailedConsoleLog 打印详细的控制台日志(仅开发环境)
|
||
func printDetailedConsoleLog(method, path, ip, userAgent string, status int, latency time.Duration,
|
||
requestID, userID interface{}, requestData, responseData map[string]interface{}) {
|
||
|
||
// 获取状态颜色
|
||
statusColor := getStatusColor(status)
|
||
methodColor := getMethodColor(method)
|
||
|
||
// 打印分隔线
|
||
fmt.Println("\n" + strings.Repeat("=", 80))
|
||
|
||
// 打印基本信息(使用简单的ASCII字符作为分隔符)
|
||
fmt.Printf("%s%s%s %s%s%s - Status: %s%d%s - Latency: %s%v%s\n",
|
||
methodColor, method, "\033[0m",
|
||
"\033[36m", path, "\033[0m",
|
||
statusColor, status, "\033[0m",
|
||
"\033[33m", latency, "\033[0m")
|
||
|
||
// 打印请求信息
|
||
fmt.Printf("IP: %s - Request ID: %v - User ID: %v\n", ip, requestID, userID)
|
||
fmt.Printf("User-Agent: %s\n", userAgent)
|
||
|
||
// 打印请求参数
|
||
if query, ok := requestData["query"].(map[string][]string); ok && len(query) > 0 {
|
||
fmt.Printf("Query Params:\n")
|
||
for k, v := range query {
|
||
fmt.Printf(" %s: %v\n", k, v)
|
||
}
|
||
}
|
||
|
||
// 打印请求体
|
||
if body, ok := requestData["body"].(string); ok && body != "" {
|
||
fmt.Printf("Request Body:\n")
|
||
if len(body) > 2000 {
|
||
fmt.Printf(" %s...(truncated)\n", body[:2000])
|
||
} else {
|
||
fmt.Printf(" %s\n", body)
|
||
}
|
||
}
|
||
|
||
// 打印响应体(限制长度,避免刷屏)
|
||
if body, ok := responseData["body"].(string); ok && body != "" {
|
||
fmt.Printf("Response Body:\n")
|
||
if len(body) > 1000 {
|
||
fmt.Printf(" %s...(truncated, total %d chars)\n", body[:1000], len(body))
|
||
} else {
|
||
fmt.Printf(" %s\n", body)
|
||
}
|
||
}
|
||
|
||
// 打印时间戳
|
||
fmt.Printf("Time: %s\n", time.Now().Format("2006-01-02 15:04:05"))
|
||
|
||
// 打印分隔线
|
||
fmt.Println(strings.Repeat("=", 80))
|
||
}
|
||
|
||
// getStatusColor returns the ANSI colour escape used to print an HTTP status
// code: green for 2xx, yellow for 3xx, red for 4xx, purple for 500 and above,
// and white for anything else (1xx, zero and negative values).
func getStatusColor(status int) string {
	if status >= 500 {
		return "\033[35m" // purple
	}
	switch status / 100 {
	case 2:
		return "\033[32m" // green
	case 3:
		return "\033[33m" // yellow
	case 4:
		return "\033[31m" // red
	}
	return "\033[37m" // white
}
|
||
|
||
// getMethodColor returns the ANSI colour escape used to print an HTTP method.
// Matching is case-sensitive; any method other than GET, POST, PUT, DELETE
// or PATCH is printed in white.
func getMethodColor(method string) string {
	color := "\033[37m" // white
	switch method {
	case "GET":
		color = "\033[32m" // green
	case "POST":
		color = "\033[34m" // blue
	case "PUT":
		color = "\033[33m" // yellow
	case "DELETE":
		color = "\033[31m" // red
	case "PATCH":
		color = "\033[36m" // cyan
	}
	return color
}
|
||
|
||
// shouldLogRequestBody reports whether a request body may be included in the
// detailed log. Only POST, PUT and PATCH bodies qualify, and never when the
// path contains "/login", "/register", "/password" or "/auth". These are plain
// case-sensitive substring checks, so the query string (if part of path) is
// checked too.
func shouldLogRequestBody(method, path string) bool {
	switch method {
	case "POST", "PUT", "PATCH":
		// eligible; fall through to the sensitive-path check
	default:
		return false
	}
	for _, sensitive := range [...]string{"/login", "/register", "/password", "/auth"} {
		if strings.Contains(path, sensitive) {
			return false
		}
	}
	return true
}
|
||
|
||
// shouldLogResponseBody reports whether a response body may be included in the
// detailed log. Only 2xx and 4xx responses qualify, and never when the path
// contains "/download", "/file", "/image" or "/video" (case-sensitive
// substring match).
func shouldLogResponseBody(status int, path string) bool {
	success := status >= 200 && status < 300
	clientError := status >= 400 && status < 500
	if !success && !clientError {
		return false
	}
	for _, bulky := range [...]string{"/download", "/file", "/image", "/video"} {
		if strings.Contains(path, bulky) {
			return false
		}
	}
	return true
}
|
||
|
||
// CustomLoggerMiddleware 自定义日志中间件(保持向后兼容)
|
||
func CustomLoggerMiddleware() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
start := time.Now()
|
||
path := c.Request.URL.Path
|
||
raw := c.Request.URL.RawQuery
|
||
|
||
// 处理请求
|
||
c.Next()
|
||
|
||
// 计算请求耗时
|
||
latency := time.Since(start)
|
||
|
||
// 获取客户端IP
|
||
clientIP := utils.GetClientIP(
|
||
c.ClientIP(),
|
||
c.GetHeader("X-Forwarded-For"),
|
||
c.GetHeader("X-Real-IP"),
|
||
)
|
||
|
||
// 构建完整路径
|
||
if raw != "" {
|
||
path = path + "?" + raw
|
||
}
|
||
|
||
// 记录日志
|
||
logger.WithFields(map[string]interface{}{
|
||
"status": c.Writer.Status(),
|
||
"method": c.Request.Method,
|
||
"path": path,
|
||
"ip": clientIP,
|
||
"user_agent": c.Request.UserAgent(),
|
||
"latency": latency.String(),
|
||
"size": c.Writer.Size(),
|
||
}).Info("HTTP Request")
|
||
|
||
// 如果有错误,记录错误日志
|
||
if len(c.Errors) > 0 {
|
||
logger.WithFields(map[string]interface{}{
|
||
"method": c.Request.Method,
|
||
"path": path,
|
||
"ip": clientIP,
|
||
"errors": c.Errors.String(),
|
||
}).Error("Request Errors")
|
||
}
|
||
}
|
||
}
|
||
|
||
// OperationLogMiddleware 操作日志中间件
|
||
func OperationLogMiddleware() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
start := time.Now()
|
||
|
||
// 记录请求数据(仅对特定操作)
|
||
var requestData interface{}
|
||
if shouldLogRequestData(c.Request.Method, c.Request.URL.Path) {
|
||
// 这里可以根据需要记录请求体,但要注意敏感信息
|
||
requestData = map[string]interface{}{
|
||
"query_params": c.Request.URL.RawQuery,
|
||
"content_type": c.Request.Header.Get("Content-Type"),
|
||
}
|
||
}
|
||
|
||
// 处理请求
|
||
c.Next()
|
||
|
||
// 计算请求耗时
|
||
duration := time.Since(start)
|
||
|
||
// 获取用户信息
|
||
var userID uint
|
||
var userType string = "anonymous"
|
||
|
||
if uid, exists := c.Get("user_id"); exists {
|
||
if id, ok := uid.(uint); ok {
|
||
userID = id
|
||
userType = "user"
|
||
} else if idStr, ok := uid.(string); ok {
|
||
if id, err := strconv.ParseUint(idStr, 10, 32); err == nil {
|
||
userID = uint(id)
|
||
userType = "user"
|
||
}
|
||
}
|
||
}
|
||
|
||
// 获取客户端IP
|
||
clientIP := utils.GetClientIP(
|
||
c.ClientIP(),
|
||
c.GetHeader("X-Forwarded-For"),
|
||
c.GetHeader("X-Real-IP"),
|
||
)
|
||
|
||
// 记录操作日志
|
||
if shouldLogOperation(c.Request.Method, c.Request.URL.Path) {
|
||
logger.LogOperationWithContext(
|
||
c.Request.Context(),
|
||
userID,
|
||
userType,
|
||
getOperationType(c.Request.Method),
|
||
getResourceName(c.Request.URL.Path),
|
||
c.Request.Method,
|
||
c.Request.URL.Path,
|
||
clientIP,
|
||
requestData,
|
||
nil, // 响应数据可以根据需要添加
|
||
c.Writer.Status(),
|
||
duration,
|
||
)
|
||
}
|
||
}
|
||
}
|
||
|
||
// shouldLogRequestData reports whether the operation log should carry request
// metadata: true for POST, PUT and PATCH. The path parameter is currently
// unused and kept only for signature compatibility.
func shouldLogRequestData(method, path string) bool {
	switch method {
	case "POST", "PUT", "PATCH":
		return true
	}
	return false
}
|
||
|
||
// shouldLogOperation reports whether a request deserves an operation log
// entry. The paths "/health" and "/ping" are always skipped (exact match);
// otherwise only state-changing methods (POST, PUT, PATCH, DELETE) qualify.
func shouldLogOperation(method, path string) bool {
	switch path {
	case "/health", "/ping":
		return false
	}
	switch method {
	case "POST", "PUT", "PATCH", "DELETE":
		return true
	default:
		return false
	}
}
|
||
|
||
// getOperationType maps an HTTP method to an operation-log action name:
// GET -> "read", POST -> "create", PUT/PATCH -> "update", DELETE -> "delete",
// and "unknown" for anything else (matching is case-sensitive).
func getOperationType(method string) string {
	op := "unknown"
	switch method {
	case "GET":
		op = "read"
	case "POST":
		op = "create"
	case "PUT", "PATCH":
		op = "update"
	case "DELETE":
		op = "delete"
	}
	return op
}
|
||
|
||
// getResourceName derives a short resource label from a URL path for the
// operation log. It returns the last non-empty path segment that is not purely
// numeric, so "/api/orders/42" yields "orders" rather than the record ID, and a
// trailing slash ("/api/orders/") no longer yields an empty string. If every
// segment is numeric, the last segment is returned. A path with no segments
// (e.g. "" or "/") yields "unknown".
func getResourceName(path string) string {
	segments := strings.FieldsFunc(path, func(r rune) bool { return r == '/' })
	if len(segments) == 0 {
		return "unknown"
	}
	for i := len(segments) - 1; i >= 0; i-- {
		// All-digit segments are treated as record IDs, not resource names.
		if strings.TrimLeft(segments[i], "0123456789") != "" {
			return segments[i]
		}
	}
	return segments[len(segments)-1]
}
|
||
|
||
// getEnvironment returns the current deployment environment, lower-cased.
// It uses the first non-empty value among the GO_ENV, APP_ENV and ENVIRONMENT
// variables, checked in that order, and falls back to "development".
func getEnvironment() string {
	for _, key := range []string{"GO_ENV", "APP_ENV", "ENVIRONMENT"} {
		if value := os.Getenv(key); value != "" {
			return strings.ToLower(value)
		}
	}
	return "development"
}