package middleware import ( "bytes" "fmt" "io" "strings" "time" "github.com/duke-git/lancet/v2/random" "github.com/gin-gonic/gin" "github.com/go-nunu/nunu-layout-advanced/pkg/log" "go.uber.org/zap" ) const ( MaxBodySize = 10 * 1024 // 10KB TraceIDKey = "trace_id" ) var ( // 跳过的路径 skipPaths = map[string]bool{ "/health": true, "/metrics": true, "/favicon.ico": true, "/ping": true, } // 需要记录的请求头 importantHeaders = []string{ "authorization", "x-request-id", "x-real-ip", "x-forwarded-for", "user-agent", "content-type", } // 敏感字段 sensitiveFields = []string{ "password", "passwd", "pwd", "token", "access_token", "refresh_token", "secret", "api_key", "apikey", "authorization", } ) func RequestLogMiddleware(logger *log.Logger) gin.HandlerFunc { return func(ctx *gin.Context) { // 跳过不需要记录的路径 if skipPaths[ctx.Request.URL.Path] { ctx.Next() return } // 生成简短的追踪ID traceID := generateTraceID() ctx.Set(TraceIDKey, traceID) // 构建日志字段 fields := []zap.Field{ zap.String("trace_id", traceID), zap.String("method", ctx.Request.Method), zap.String("path", ctx.Request.URL.Path), zap.String("client_ip", ctx.ClientIP()), } // 记录查询参数 if query := ctx.Request.URL.RawQuery; query != "" { fields = append(fields, zap.String("query", query)) } // 记录重要请求头 if headers := getImportantHeaders(ctx); len(headers) > 0 { fields = append(fields, zap.Any("headers", headers)) } // 记录请求体(仅限特定方法) if shouldLogRequestBody(ctx) { if bodyLog := getRequestBody(ctx); bodyLog != "" { fields = append(fields, zap.String("body", bodyLog)) } } // 设置日志上下文 logger.WithValue(ctx, zap.String("trace_id", traceID)) // 记录请求开始时间 ctx.Set("start_time", time.Now()) logger.Info("API Request", fields...) ctx.Next() } } func ResponseLogMiddleware(logger *log.Logger) gin.HandlerFunc { return func(ctx *gin.Context) { // 跳过不需要记录的路径 if skipPaths[ctx.Request.URL.Path] { ctx.Next() return } // 包装响应写入器 blw := &bodyLogWriter{ body: bytes.NewBufferString(""), ResponseWriter: ctx.Writer, } ctx.Writer = blw // 执行处理 ctx.Next() // 计算耗时 var duration time.Duration if startTime, exists := ctx.Get("start_time"); exists { if st, ok := startTime.(time.Time); ok { duration = time.Since(st) } } // 构建响应日志字段 fields := []zap.Field{ zap.Int("status", ctx.Writer.Status()), zap.String("duration", duration.String()), zap.Int64("duration_ms", duration.Milliseconds()), } // 记录响应体(限制大小和内容类型) if shouldLogResponseBody(ctx) { bodyStr := blw.body.String() if len(bodyStr) > MaxBodySize { fields = append(fields, zap.String("body", fmt.Sprintf("[TRUNCATED: %d bytes]", len(bodyStr)))) } else if len(bodyStr) > 0 { fields = append(fields, zap.String("body", maskSensitiveData(bodyStr))) } } // 记录错误信息 if len(ctx.Errors) > 0 { fields = append(fields, zap.Any("errors", ctx.Errors)) } // 添加状态分类 statusCode := ctx.Writer.Status() if statusCode >= 500 { fields = append(fields, zap.String("level", "error")) } else if statusCode >= 400 { fields = append(fields, zap.String("level", "warn")) } else { fields = append(fields, zap.String("level", "info")) } // 统一使用 Info 级别,避免堆栈信息 logger.WithContext(ctx).Info("API Response", fields...) } } // bodyLogWriter 包装响应写入器以捕获响应体 type bodyLogWriter struct { gin.ResponseWriter body *bytes.Buffer } func (w *bodyLogWriter) Write(b []byte) (int, error) { w.body.Write(b) return w.ResponseWriter.Write(b) } // 辅助函数 func generateTraceID() string { // 使用时间戳 + 4位随机字符串 return fmt.Sprintf("%d-%s", time.Now().Unix(), random.RandString(4)) } func getImportantHeaders(ctx *gin.Context) map[string]string { headers := make(map[string]string) for _, key := range importantHeaders { if value := ctx.GetHeader(key); value != "" { headers[key] = value } } return headers } func shouldLogRequestBody(ctx *gin.Context) bool { method := ctx.Request.Method return method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" } func shouldLogResponseBody(ctx *gin.Context) bool { // 检查内容类型 contentType := ctx.Writer.Header().Get("Content-Type") return strings.Contains(contentType, "json") || strings.Contains(contentType, "xml") || strings.Contains(contentType, "text") } func getRequestBody(ctx *gin.Context) string { if ctx.Request.Body == nil { return "" } bodyBytes, err := ctx.GetRawData() if err != nil { return "" } // 重置请求体 ctx.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // 检查大小 if len(bodyBytes) == 0 { return "" } if len(bodyBytes) > MaxBodySize { return fmt.Sprintf("[TRUNCATED: %d bytes]", len(bodyBytes)) } // 脱敏处理 return maskSensitiveData(string(bodyBytes)) } func maskSensitiveData(data string) string { result := data for _, field := range sensitiveFields { // 简单的JSON字段脱敏 result = maskJSONField(result, field) // URL参数脱敏 result = maskURLParam(result, field) } return result } func maskJSONField(data, field string) string { lowerData := strings.ToLower(data) lowerField := strings.ToLower(field) // 查找字段位置(不区分大小写) idx := strings.Index(lowerData, `"`+lowerField+`"`) if idx == -1 { idx = strings.Index(lowerData, `'`+lowerField+`'`) if idx == -1 { return data } } // 找到冒号位置 colonIdx := strings.Index(data[idx:], ":") if colonIdx == -1 { return data } colonIdx += idx // 找到值的开始和结束位置 valueStart := colonIdx + 1 for valueStart < len(data) && (data[valueStart] == ' ' || data[valueStart] == '\t') { valueStart++ } if valueStart >= len(data) { return data } // 判断值的类型 var valueEnd int if data[valueStart] == '"' || data[valueStart] == '\'' { // 字符串值 quote := data[valueStart] valueEnd = valueStart + 1 for valueEnd < len(data) && data[valueEnd] != quote { if data[valueEnd] == '\\' { valueEnd++ // 跳过转义字符 } valueEnd++ } if valueEnd < len(data) { valueEnd++ // 包含结束引号 } } else { // 非字符串值(数字、布尔值等) valueEnd = valueStart for valueEnd < len(data) && data[valueEnd] != ',' && data[valueEnd] != '}' && data[valueEnd] != ']' && data[valueEnd] != '\n' && data[valueEnd] != '\r' { valueEnd++ } } // 替换为脱敏值 return data[:valueStart] + `"***"` + data[valueEnd:] } func maskURLParam(data, param string) string { lowerData := strings.ToLower(data) lowerParam := strings.ToLower(param) // 查找参数位置 idx := strings.Index(lowerData, lowerParam+"=") if idx == -1 { return data } // 确保是参数开始位置(前面是?或&) if idx > 0 && data[idx-1] != '?' && data[idx-1] != '&' && data[idx-1] != ' ' && data[idx-1] != '\n' { return data } // 找到参数值结束位置 valueStart := idx + len(param) + 1 valueEnd := valueStart for valueEnd < len(data) && data[valueEnd] != '&' && data[valueEnd] != ' ' && data[valueEnd] != '\n' && data[valueEnd] != '\r' { valueEnd++ } // 替换为脱敏值 return data[:valueStart] + "***" + data[valueEnd:] }