first commit
This commit is contained in:
59
go_backend/middleware/auth.go
Normal file
59
go_backend/middleware/auth.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"ai_xhs/common"
|
||||
"ai_xhs/utils"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AuthMiddleware JWT认证中间件
|
||||
func AuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 从请求头获取token
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
common.Error(c, common.CodeUnauthorized, "未登录或token为空")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查token格式
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if !(len(parts) == 2 && parts[0] == "Bearer") {
|
||||
common.Error(c, common.CodeUnauthorized, "token格式错误")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 解析token
|
||||
claims, err := utils.ParseToken(parts[1])
|
||||
if err != nil {
|
||||
common.Error(c, common.CodeUnauthorized, "无效的token")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将员工ID存入上下文
|
||||
c.Set("employee_id", claims.EmployeeID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// CORS 跨域中间件
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
171
go_backend/middleware/logger.go
Normal file
171
go_backend/middleware/logger.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// responseWriter 包装 gin.ResponseWriter 以捕获响应体
|
||||
type responseWriter struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
}
|
||||
|
||||
func (w responseWriter) Write(b []byte) (int, error) {
|
||||
w.body.Write(b)
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
// RequestLogger API请求和响应日志中间件
|
||||
func RequestLogger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 读取请求体
|
||||
var requestBody []byte
|
||||
if c.Request.Body != nil {
|
||||
requestBody, _ = io.ReadAll(c.Request.Body)
|
||||
// 恢复请求体供后续处理使用
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
}
|
||||
|
||||
// 包装 ResponseWriter 以捕获响应
|
||||
blw := &responseWriter{
|
||||
ResponseWriter: c.Writer,
|
||||
body: bytes.NewBufferString(""),
|
||||
}
|
||||
c.Writer = blw
|
||||
|
||||
// 打印请求信息
|
||||
printRequest(c, requestBody)
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 计算请求耗时
|
||||
duration := time.Since(startTime)
|
||||
|
||||
// 打印响应信息
|
||||
printResponse(c, blw.body.Bytes(), duration)
|
||||
}
|
||||
}
|
||||
|
||||
// printRequest 打印请求详情
|
||||
func printRequest(c *gin.Context, body []byte) {
|
||||
fmt.Println("\n" + strings.Repeat("=", 100))
|
||||
fmt.Printf("📥 [REQUEST] %s\n", time.Now().Format("2006-01-02 15:04:05"))
|
||||
fmt.Println(strings.Repeat("=", 100))
|
||||
|
||||
// 请求基本信息
|
||||
fmt.Printf("Method: %s\n", c.Request.Method)
|
||||
fmt.Printf("Path: %s\n", c.Request.URL.Path)
|
||||
fmt.Printf("Full URL: %s\n", c.Request.URL.String())
|
||||
fmt.Printf("Client IP: %s\n", c.ClientIP())
|
||||
fmt.Printf("User-Agent: %s\n", c.Request.UserAgent())
|
||||
|
||||
// 请求头
|
||||
if len(c.Request.Header) > 0 {
|
||||
fmt.Println("\n--- Headers ---")
|
||||
for key, values := range c.Request.Header {
|
||||
// 过滤敏感信息
|
||||
if strings.ToLower(key) == "authorization" || strings.ToLower(key) == "cookie" {
|
||||
fmt.Printf("%s: [HIDDEN]\n", key)
|
||||
} else {
|
||||
fmt.Printf("%s: %s\n", key, strings.Join(values, ", "))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 查询参数
|
||||
if len(c.Request.URL.Query()) > 0 {
|
||||
fmt.Println("\n--- Query Parameters ---")
|
||||
for key, values := range c.Request.URL.Query() {
|
||||
fmt.Printf("%s: %s\n", key, strings.Join(values, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
// 请求体
|
||||
if len(body) > 0 {
|
||||
fmt.Println("\n--- Request Body ---")
|
||||
// 尝试格式化 JSON
|
||||
var prettyJSON bytes.Buffer
|
||||
if err := json.Indent(&prettyJSON, body, "", " "); err == nil {
|
||||
fmt.Println(prettyJSON.String())
|
||||
} else {
|
||||
fmt.Println(string(body))
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println(strings.Repeat("-", 100))
|
||||
}
|
||||
|
||||
// printResponse 打印响应详情
|
||||
func printResponse(c *gin.Context, body []byte, duration time.Duration) {
|
||||
fmt.Println("\n" + strings.Repeat("=", 100))
|
||||
fmt.Printf("📤 [RESPONSE] %s | Duration: %v\n", time.Now().Format("2006-01-02 15:04:05"), duration)
|
||||
fmt.Println(strings.Repeat("=", 100))
|
||||
|
||||
// 响应基本信息
|
||||
fmt.Printf("Status Code: %d %s\n", c.Writer.Status(), getStatusText(c.Writer.Status()))
|
||||
fmt.Printf("Size: %d bytes\n", c.Writer.Size())
|
||||
|
||||
// 响应头
|
||||
if len(c.Writer.Header()) > 0 {
|
||||
fmt.Println("\n--- Response Headers ---")
|
||||
for key, values := range c.Writer.Header() {
|
||||
fmt.Printf("%s: %s\n", key, strings.Join(values, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
// 响应体
|
||||
if len(body) > 0 {
|
||||
fmt.Println("\n--- Response Body ---")
|
||||
// 尝试格式化 JSON
|
||||
var prettyJSON bytes.Buffer
|
||||
if err := json.Indent(&prettyJSON, body, "", " "); err == nil {
|
||||
fmt.Println(prettyJSON.String())
|
||||
} else {
|
||||
fmt.Println(string(body))
|
||||
}
|
||||
}
|
||||
|
||||
// 性能提示
|
||||
if duration > 1*time.Second {
|
||||
fmt.Printf("\n⚠️ WARNING: Request took %.2f seconds (>1s)\n", duration.Seconds())
|
||||
} else if duration > 500*time.Millisecond {
|
||||
fmt.Printf("\n⚡ NOTICE: Request took %.0f milliseconds (>500ms)\n", duration.Milliseconds())
|
||||
}
|
||||
|
||||
fmt.Println(strings.Repeat("=", 100))
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
// getStatusText 获取状态码文本
|
||||
func getStatusText(code int) string {
|
||||
switch code {
|
||||
case 200:
|
||||
return "OK"
|
||||
case 201:
|
||||
return "Created"
|
||||
case 204:
|
||||
return "No Content"
|
||||
case 400:
|
||||
return "Bad Request"
|
||||
case 401:
|
||||
return "Unauthorized"
|
||||
case 403:
|
||||
return "Forbidden"
|
||||
case 404:
|
||||
return "Not Found"
|
||||
case 500:
|
||||
return "Internal Server Error"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user