Initial commit

This commit is contained in:
sjk
2025-11-17 13:32:54 +08:00
commit e788eab6eb
1659 changed files with 171560 additions and 0 deletions

View File

@@ -0,0 +1,99 @@
package middleware
import (
"dianshang/pkg/jwt"
"dianshang/pkg/response"
"strings"
"github.com/gin-gonic/gin"
)
// AuthMiddleware JWT认证中间件
func AuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
token := c.GetHeader("Authorization")
if token == "" {
response.Unauthorized(c)
return
}
// 移除 "Bearer " 前缀
if strings.HasPrefix(token, "Bearer ") {
token = token[7:]
}
claims, err := jwt.ParseToken(token)
if err != nil {
response.Unauthorized(c)
return
}
// 将用户信息存储到上下文中
c.Set("user_id", claims.UserID)
c.Set("user_type", claims.UserType)
c.Next()
}
}
// AdminAuthMiddleware 管理员认证中间件
func AdminAuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
token := c.GetHeader("Authorization")
if token == "" {
response.Unauthorized(c)
return
}
// 移除 "Bearer " 前缀
if strings.HasPrefix(token, "Bearer ") {
token = token[7:]
}
// 开发环境支持模拟token
if strings.HasPrefix(token, "mock_admin_token_") {
// 模拟管理员用户信息
c.Set("user_id", uint(1))
c.Set("user_type", "admin")
c.Next()
return
}
claims, err := jwt.ParseToken(token)
if err != nil {
response.Unauthorized(c)
return
}
// 检查用户类型是否为管理员
if claims.UserType != "admin" {
response.Forbidden(c)
return
}
// 将用户信息存储到上下文中
c.Set("user_id", claims.UserID)
c.Set("user_type", claims.UserType)
c.Next()
}
}
// OptionalAuthMiddleware 可选认证中间件(不强制要求登录)
func OptionalAuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
token := c.GetHeader("Authorization")
if token != "" {
// 移除 "Bearer " 前缀
if strings.HasPrefix(token, "Bearer ") {
token = token[7:]
}
claims, err := jwt.ParseToken(token)
if err == nil {
// 将用户信息存储到上下文中
c.Set("user_id", claims.UserID)
c.Set("user_type", claims.UserType)
}
}
c.Next()
}
}

View File

@@ -0,0 +1,42 @@
package middleware
import (
"net/http"
"github.com/gin-gonic/gin"
)
// CORSMiddleware 跨域中间件
func CORSMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
method := c.Request.Method
origin := c.Request.Header.Get("Origin")
// 设置允许的域名
if origin != "" {
c.Header("Access-Control-Allow-Origin", origin)
} else {
c.Header("Access-Control-Allow-Origin", "*")
}
// 设置允许的请求头
c.Header("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-User-ID")
// 设置允许的请求方法
c.Header("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
// 设置允许携带凭证
c.Header("Access-Control-Allow-Credentials", "true")
// 设置预检请求的缓存时间
c.Header("Access-Control-Max-Age", "86400")
// 处理预检请求
if method == "OPTIONS" {
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}

View File

@@ -0,0 +1,497 @@
package middleware
import (
"bytes"
"context"
"dianshang/pkg/logger"
"dianshang/pkg/utils"
"fmt"
"io"
"os"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
// RequestIDMiddleware 请求ID中间件
func RequestIDMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// 生成请求ID
requestID := logger.GenerateRequestID()
// 设置到上下文
ctx := context.WithValue(c.Request.Context(), logger.RequestIDKey, requestID)
c.Request = c.Request.WithContext(ctx)
// 设置到响应头
c.Header("X-Request-ID", requestID)
// 设置到Gin上下文方便后续使用
c.Set("request_id", requestID)
c.Next()
}
}
// UserContextMiddleware 用户上下文中间件
func UserContextMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// 从JWT或其他认证方式获取用户ID
if userID, exists := c.Get("user_id"); exists {
ctx := context.WithValue(c.Request.Context(), logger.UserIDKey, userID)
c.Request = c.Request.WithContext(ctx)
}
c.Next()
}
}
// LoggerMiddleware 日志中间件
func LoggerMiddleware() gin.HandlerFunc {
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
// 记录请求日志
logger.LogRequest(
param.Method,
param.Path,
utils.GetClientIP(param.ClientIP, param.Request.Header.Get("X-Forwarded-For"), param.Request.Header.Get("X-Real-IP")),
param.Request.UserAgent(),
param.StatusCode,
param.Latency.Milliseconds(),
)
return ""
})
}
// responseBodyWriter 响应体写入器
type responseBodyWriter struct {
gin.ResponseWriter
body *bytes.Buffer
}
func (r responseBodyWriter) Write(b []byte) (int, error) {
r.body.Write(b)
return r.ResponseWriter.Write(b)
}
// EnhancedLoggerMiddleware 增强的日志中间件
func EnhancedLoggerMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
raw := c.Request.URL.RawQuery
// 读取请求体
var requestBody []byte
if c.Request.Body != nil {
requestBody, _ = io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
}
// 创建响应体写入器
responseBody := &bytes.Buffer{}
writer := &responseBodyWriter{
ResponseWriter: c.Writer,
body: responseBody,
}
c.Writer = writer
// 处理请求
c.Next()
// 计算请求耗时
latency := time.Since(start)
// 获取客户端IP
clientIP := utils.GetClientIP(
c.ClientIP(),
c.GetHeader("X-Forwarded-For"),
c.GetHeader("X-Real-IP"),
)
// 构建完整路径
if raw != "" {
path = path + "?" + raw
}
// 获取用户ID
var userID interface{}
if uid, exists := c.Get("user_id"); exists {
userID = uid
}
// 获取请求ID
var requestID interface{}
if rid, exists := c.Get("request_id"); exists {
requestID = rid
}
// 准备请求数据
requestData := map[string]interface{}{
"headers": c.Request.Header,
"query": c.Request.URL.Query(),
}
// 添加请求体(如果不为空且不是敏感信息)
if len(requestBody) > 0 && shouldLogRequestBody(c.Request.Method, path) {
requestData["body"] = string(requestBody)
}
// 准备响应数据
responseData := map[string]interface{}{
"headers": writer.Header(),
}
// 添加响应体(如果不为空且不是敏感信息)
if responseBody.Len() > 0 && shouldLogResponseBody(c.Writer.Status(), path) {
responseData["body"] = responseBody.String()
}
printDetailedConsoleLog(c.Request.Method, path, clientIP, c.Request.UserAgent(),
c.Writer.Status(), latency, requestID, userID, requestData, responseData)
// 记录请求日志到文件(所有环境)
logger.LogRequestWithContext(
c.Request.Context(),
c.Request.Method,
path,
clientIP,
c.Request.UserAgent(),
c.Writer.Status(),
latency.Milliseconds(),
)
// 记录性能日志(如果请求耗时较长)
if latency > 1*time.Second {
logger.LogPerformanceWithContext(
c.Request.Context(),
"http_request",
latency,
map[string]interface{}{
"method": c.Request.Method,
"path": path,
"status": c.Writer.Status(),
"user_id": userID,
"request_id": requestID,
},
)
}
// 如果有错误,记录错误日志
if len(c.Errors) > 0 {
logger.LogErrorWithContext(
c.Request.Context(),
c.Errors[0].Err,
map[string]interface{}{
"method": c.Request.Method,
"path": path,
"ip": clientIP,
"user_id": userID,
"request_id": requestID,
"errors": c.Errors.String(),
},
)
}
}
}
// printDetailedConsoleLog 打印详细的控制台日志(仅开发环境)
func printDetailedConsoleLog(method, path, ip, userAgent string, status int, latency time.Duration,
requestID, userID interface{}, requestData, responseData map[string]interface{}) {
// 获取状态颜色
statusColor := getStatusColor(status)
methodColor := getMethodColor(method)
// 打印分隔线
fmt.Println("\n" + strings.Repeat("=", 80))
// 打印基本信息使用简单的ASCII字符作为分隔符
fmt.Printf("%s%s%s %s%s%s - Status: %s%d%s - Latency: %s%v%s\n",
methodColor, method, "\033[0m",
"\033[36m", path, "\033[0m",
statusColor, status, "\033[0m",
"\033[33m", latency, "\033[0m")
// 打印请求信息
fmt.Printf("IP: %s - Request ID: %v - User ID: %v\n", ip, requestID, userID)
fmt.Printf("User-Agent: %s\n", userAgent)
// 打印请求参数
if query, ok := requestData["query"].(map[string][]string); ok && len(query) > 0 {
fmt.Printf("Query Params:\n")
for k, v := range query {
fmt.Printf(" %s: %v\n", k, v)
}
}
// 打印请求体
if body, ok := requestData["body"].(string); ok && body != "" {
fmt.Printf("Request Body:\n")
if len(body) > 2000 {
fmt.Printf(" %s...(truncated)\n", body[:2000])
} else {
fmt.Printf(" %s\n", body)
}
}
// 打印响应体(限制长度,避免刷屏)
if body, ok := responseData["body"].(string); ok && body != "" {
fmt.Printf("Response Body:\n")
if len(body) > 1000 {
fmt.Printf(" %s...(truncated, total %d chars)\n", body[:1000], len(body))
} else {
fmt.Printf(" %s\n", body)
}
}
// 打印时间戳
fmt.Printf("Time: %s\n", time.Now().Format("2006-01-02 15:04:05"))
// 打印分隔线
fmt.Println(strings.Repeat("=", 80))
}
// getStatusColor 获取状态码颜色
func getStatusColor(status int) string {
switch {
case status >= 200 && status < 300:
return "\033[32m" // 绿色
case status >= 300 && status < 400:
return "\033[33m" // 黄色
case status >= 400 && status < 500:
return "\033[31m" // 红色
case status >= 500:
return "\033[35m" // 紫色
default:
return "\033[37m" // 白色
}
}
// getMethodColor 获取方法颜色
func getMethodColor(method string) string {
switch method {
case "GET":
return "\033[32m" // 绿色
case "POST":
return "\033[34m" // 蓝色
case "PUT":
return "\033[33m" // 黄色
case "DELETE":
return "\033[31m" // 红色
case "PATCH":
return "\033[36m" // 青色
default:
return "\033[37m" // 白色
}
}
// shouldLogRequestBody 判断是否应该记录请求体
func shouldLogRequestBody(method, path string) bool {
// 排除敏感路径
sensitivePatterns := []string{
"/login", "/register", "/password", "/auth",
}
for _, pattern := range sensitivePatterns {
if strings.Contains(path, pattern) {
return false
}
}
// 只记录POST、PUT、PATCH请求的请求体
return method == "POST" || method == "PUT" || method == "PATCH"
}
// shouldLogResponseBody 判断是否应该记录响应体
func shouldLogResponseBody(status int, path string) bool {
// 排除大文件下载等路径
excludePatterns := []string{
"/download", "/file", "/image", "/video",
}
for _, pattern := range excludePatterns {
if strings.Contains(path, pattern) {
return false
}
}
// 只记录成功和客户端错误的响应体
return (status >= 200 && status < 300) || (status >= 400 && status < 500)
}
// CustomLoggerMiddleware 自定义日志中间件(保持向后兼容)
func CustomLoggerMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
raw := c.Request.URL.RawQuery
// 处理请求
c.Next()
// 计算请求耗时
latency := time.Since(start)
// 获取客户端IP
clientIP := utils.GetClientIP(
c.ClientIP(),
c.GetHeader("X-Forwarded-For"),
c.GetHeader("X-Real-IP"),
)
// 构建完整路径
if raw != "" {
path = path + "?" + raw
}
// 记录日志
logger.WithFields(map[string]interface{}{
"status": c.Writer.Status(),
"method": c.Request.Method,
"path": path,
"ip": clientIP,
"user_agent": c.Request.UserAgent(),
"latency": latency.String(),
"size": c.Writer.Size(),
}).Info("HTTP Request")
// 如果有错误,记录错误日志
if len(c.Errors) > 0 {
logger.WithFields(map[string]interface{}{
"method": c.Request.Method,
"path": path,
"ip": clientIP,
"errors": c.Errors.String(),
}).Error("Request Errors")
}
}
}
// OperationLogMiddleware 操作日志中间件
func OperationLogMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
// 记录请求数据(仅对特定操作)
var requestData interface{}
if shouldLogRequestData(c.Request.Method, c.Request.URL.Path) {
// 这里可以根据需要记录请求体,但要注意敏感信息
requestData = map[string]interface{}{
"query_params": c.Request.URL.RawQuery,
"content_type": c.Request.Header.Get("Content-Type"),
}
}
// 处理请求
c.Next()
// 计算请求耗时
duration := time.Since(start)
// 获取用户信息
var userID uint
var userType string = "anonymous"
if uid, exists := c.Get("user_id"); exists {
if id, ok := uid.(uint); ok {
userID = id
userType = "user"
} else if idStr, ok := uid.(string); ok {
if id, err := strconv.ParseUint(idStr, 10, 32); err == nil {
userID = uint(id)
userType = "user"
}
}
}
// 获取客户端IP
clientIP := utils.GetClientIP(
c.ClientIP(),
c.GetHeader("X-Forwarded-For"),
c.GetHeader("X-Real-IP"),
)
// 记录操作日志
if shouldLogOperation(c.Request.Method, c.Request.URL.Path) {
logger.LogOperationWithContext(
c.Request.Context(),
userID,
userType,
getOperationType(c.Request.Method),
getResourceName(c.Request.URL.Path),
c.Request.Method,
c.Request.URL.Path,
clientIP,
requestData,
nil, // 响应数据可以根据需要添加
c.Writer.Status(),
duration,
)
}
}
}
// shouldLogRequestData 判断是否需要记录请求数据
func shouldLogRequestData(method, path string) bool {
// 对于POST、PUT、PATCH请求记录请求数据
return method == "POST" || method == "PUT" || method == "PATCH"
}
// shouldLogOperation 判断是否需要记录操作日志
func shouldLogOperation(method, path string) bool {
// 排除健康检查等不重要的请求
if path == "/health" || path == "/ping" {
return false
}
// 对于修改操作记录日志
return method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE"
}
// getOperationType 获取操作类型
func getOperationType(method string) string {
switch method {
case "POST":
return "create"
case "PUT", "PATCH":
return "update"
case "DELETE":
return "delete"
case "GET":
return "read"
default:
return "unknown"
}
}
// getResourceName 从路径获取资源名称
func getResourceName(path string) string {
// 简单的路径解析,可以根据实际需要完善
if len(path) > 1 {
parts := strings.Split(path[1:], "/")
if len(parts) > 0 {
return parts[len(parts)-1]
}
}
return "unknown"
}
// getEnvironment 获取当前环境
func getEnvironment() string {
env := os.Getenv("GO_ENV")
if env == "" {
env = os.Getenv("APP_ENV")
}
if env == "" {
env = os.Getenv("ENVIRONMENT")
}
if env == "" {
env = "development" // 默认环境
}
return strings.ToLower(env)
}