package middleware import ( "bytes" "encoding/json" "fmt" "io" "mime/multipart" "strings" "time" "github.com/duke-git/lancet/v2/random" "github.com/gin-gonic/gin" "github.com/go-nunu/nunu-layout-advanced/pkg/log" "go.uber.org/zap" ) const ( MaxBodySize = 10 * 1024 // 10KB TraceIDKey = "trace_id" ) var ( // 跳过的路径 skipPaths = map[string]bool{ "/health": true, "/metrics": true, "/favicon.ico": true, "/ping": true, } // 需要记录的请求头 importantHeaders = []string{ "authorization", "x-request-id", "x-real-ip", "x-forwarded-for", "user-agent", "content-type", } ) func RequestLogMiddleware(logger *log.Logger) gin.HandlerFunc { return func(ctx *gin.Context) { // 跳过不需要记录的路径 if skipPaths[ctx.Request.URL.Path] { ctx.Next() return } // 生成简短的追踪ID traceID := generateTraceID() ctx.Set(TraceIDKey, traceID) // 构建日志字段 fields := []zap.Field{ zap.String("trace_id", traceID), zap.String("method", ctx.Request.Method), zap.String("path", ctx.Request.URL.Path), zap.String("client_ip", ctx.ClientIP()), } // 记录查询参数 if query := ctx.Request.URL.RawQuery; query != "" { fields = append(fields, zap.String("query", query)) } // 记录重要请求头 if headers := getImportantHeaders(ctx); len(headers) > 0 { fields = append(fields, zap.Any("headers", headers)) } // 记录请求体(仅限特定方法) if shouldLogRequestBody(ctx) { contentType := ctx.GetHeader("Content-Type") // 特殊处理 multipart/form-data if strings.Contains(contentType, "multipart/form-data") { if formData := parseMultipartData(ctx); formData != nil { fields = append(fields, zap.Any("form_data", formData)) } } else if strings.Contains(contentType, "application/json") { // 处理 JSON 请求体 if bodyData := getJSONBody(ctx); bodyData != nil { fields = append(fields, zap.Any("body", bodyData)) } } else { // 处理其他类型的请求体 if bodyLog := getRequestBody(ctx); bodyLog != "" { fields = append(fields, zap.String("body", bodyLog)) } } } // 设置日志上下文 logger.WithValue(ctx, zap.String(TraceIDKey, traceID)) // 记录请求开始时间 ctx.Set("start_time", time.Now()) logger.Info("API Request", fields...) ctx.Next() } } func ResponseLogMiddleware(logger *log.Logger) gin.HandlerFunc { return func(ctx *gin.Context) { // 跳过不需要记录的路径 if skipPaths[ctx.Request.URL.Path] { ctx.Next() return } // 包装响应写入器 blw := &bodyLogWriter{ body: bytes.NewBufferString(""), ResponseWriter: ctx.Writer, } ctx.Writer = blw // 执行处理 ctx.Next() // 计算耗时 var duration time.Duration if startTime, exists := ctx.Get("start_time"); exists { if st, ok := startTime.(time.Time); ok { duration = time.Since(st) } } // 构建响应日志字段 fields := []zap.Field{ zap.Int("status", ctx.Writer.Status()), zap.String("duration", duration.String()), zap.Int64("duration_ms", duration.Milliseconds()), } // 记录响应体(限制大小和内容类型) if shouldLogResponseBody(ctx) { bodyStr := blw.body.String() if len(bodyStr) > MaxBodySize { fields = append(fields, zap.String("body", fmt.Sprintf("[TRUNCATED: %d bytes]", len(bodyStr)))) } else if len(bodyStr) > 0 { // 尝试解析 JSON 响应 if json.Valid([]byte(bodyStr)) { var jsonData interface{} if err := json.Unmarshal([]byte(bodyStr), &jsonData); err == nil { fields = append(fields, zap.Any("body", jsonData)) } else { fields = append(fields, zap.String("body", bodyStr)) } } else { fields = append(fields, zap.String("body", bodyStr)) } } } // 记录错误信息 if len(ctx.Errors) > 0 { fields = append(fields, zap.Any("errors", ctx.Errors)) } // 添加状态分类 statusCode := ctx.Writer.Status() if statusCode >= 500 { fields = append(fields, zap.String("level", "error")) } else if statusCode >= 400 { fields = append(fields, zap.String("level", "warn")) } else { fields = append(fields, zap.String("level", "info")) } // 统一使用 Info 级别,避免堆栈信息 logger.WithContext(ctx).Info("API Response", fields...) } } // 获取 JSON 请求体并解析为 map func getJSONBody(ctx *gin.Context) interface{} { if ctx.Request.Body == nil { return nil } bodyBytes, err := ctx.GetRawData() if err != nil { return nil } // 重置请求体 ctx.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // 检查大小 if len(bodyBytes) == 0 { return nil } if len(bodyBytes) > MaxBodySize { return fmt.Sprintf("[TRUNCATED: %d bytes]", len(bodyBytes)) } // 解析 JSON var data interface{} if err := json.Unmarshal(bodyBytes, &data); err != nil { // 如果解析失败,返回原始字符串 return string(bodyBytes) } return data } // 解析 multipart/form-data func parseMultipartData(ctx *gin.Context) map[string]interface{} { // 保存原始请求体 bodyBytes, err := ctx.GetRawData() if err != nil { return nil } // 重置请求体 ctx.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // 创建新的 reader reader := multipart.NewReader(bytes.NewReader(bodyBytes), extractBoundary(ctx.GetHeader("Content-Type"))) if reader == nil { return nil } formData := make(map[string]interface{}) for { part, err := reader.NextPart() if err == io.EOF { break } if err != nil { return nil } name := part.FormName() if name == "" { continue } // 读取内容 value, err := io.ReadAll(part) if err != nil { continue } valueStr := string(value) // 尝试解析为 JSON if json.Valid(value) { var jsonData interface{} if err := json.Unmarshal(value, &jsonData); err == nil { formData[name] = jsonData } else { formData[name] = valueStr } } else { formData[name] = valueStr } part.Close() } return formData } // 提取 boundary func extractBoundary(contentType string) string { if !strings.Contains(contentType, "boundary=") { return "" } parts := strings.Split(contentType, "boundary=") if len(parts) < 2 { return "" } boundary := parts[1] // 移除可能的引号 boundary = strings.Trim(boundary, `"`) return boundary } // bodyLogWriter 包装响应写入器以捕获响应体 type bodyLogWriter struct { gin.ResponseWriter body *bytes.Buffer } func (w *bodyLogWriter) Write(b []byte) (int, error) { w.body.Write(b) return w.ResponseWriter.Write(b) } // 辅助函数 func generateTraceID() string { // 使用时间戳 + 4位随机字符串 return fmt.Sprintf("%d-%s", time.Now().Unix(), random.RandString(4)) } func getImportantHeaders(ctx *gin.Context) map[string]string { headers := make(map[string]string) for _, key := range importantHeaders { if value := ctx.GetHeader(key); value != "" { headers[key] = value } } return headers } func shouldLogRequestBody(ctx *gin.Context) bool { method := ctx.Request.Method return method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" } func shouldLogResponseBody(ctx *gin.Context) bool { // 检查内容类型 contentType := ctx.Writer.Header().Get("Content-Type") return strings.Contains(contentType, "json") || strings.Contains(contentType, "xml") || strings.Contains(contentType, "text") } func getRequestBody(ctx *gin.Context) string { if ctx.Request.Body == nil { return "" } bodyBytes, err := ctx.GetRawData() if err != nil { return "" } // 重置请求体 ctx.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // 检查大小 if len(bodyBytes) == 0 { return "" } if len(bodyBytes) > MaxBodySize { return fmt.Sprintf("[TRUNCATED: %d bytes]", len(bodyBytes)) } return string(bodyBytes) }