|
@@ -2,53 +2,317 @@ package middleware
|
|
|
|
|
|
import (
|
|
|
"bytes"
|
|
|
- "github.com/duke-git/lancet/v2/cryptor"
|
|
|
+ "fmt"
|
|
|
+ "io"
|
|
|
+ "strings"
|
|
|
+ "time"
|
|
|
+
|
|
|
"github.com/duke-git/lancet/v2/random"
|
|
|
"github.com/gin-gonic/gin"
|
|
|
"github.com/go-nunu/nunu-layout-advanced/pkg/log"
|
|
|
"go.uber.org/zap"
|
|
|
- "io"
|
|
|
- "time"
|
|
|
+)
|
|
|
+
|
|
|
+const (
|
|
|
+ MaxBodySize = 10 * 1024 // 10KB
|
|
|
+ TraceIDKey = "trace_id"
|
|
|
+)
|
|
|
+
|
|
|
+var (
|
|
|
+ // 跳过的路径
|
|
|
+ skipPaths = map[string]bool{
|
|
|
+ "/health": true,
|
|
|
+ "/metrics": true,
|
|
|
+ "/favicon.ico": true,
|
|
|
+ "/ping": true,
|
|
|
+ }
|
|
|
+
|
|
|
+ // 需要记录的请求头
|
|
|
+ importantHeaders = []string{
|
|
|
+ "authorization",
|
|
|
+ "x-request-id",
|
|
|
+ "x-real-ip",
|
|
|
+ "x-forwarded-for",
|
|
|
+ "user-agent",
|
|
|
+ "content-type",
|
|
|
+ }
|
|
|
+
|
|
|
+ // 敏感字段
|
|
|
+ sensitiveFields = []string{
|
|
|
+ "password", "passwd", "pwd",
|
|
|
+ "token", "access_token", "refresh_token",
|
|
|
+ "secret", "api_key", "apikey",
|
|
|
+ "authorization",
|
|
|
+ }
|
|
|
)
|
|
|
|
|
|
func RequestLogMiddleware(logger *log.Logger) gin.HandlerFunc {
|
|
|
return func(ctx *gin.Context) {
|
|
|
- // The configuration is initialized once per request
|
|
|
- uuid, err := random.UUIdV4()
|
|
|
- if err != nil {
|
|
|
+ // 跳过不需要记录的路径
|
|
|
+ if skipPaths[ctx.Request.URL.Path] {
|
|
|
+ ctx.Next()
|
|
|
return
|
|
|
}
|
|
|
- trace := cryptor.Md5String(uuid)
|
|
|
- logger.WithValue(ctx, zap.String("trace", trace))
|
|
|
- logger.WithValue(ctx, zap.String("request_method", ctx.Request.Method))
|
|
|
- logger.WithValue(ctx, zap.Any("request_headers", ctx.Request.Header))
|
|
|
- logger.WithValue(ctx, zap.String("request_url", ctx.Request.URL.String()))
|
|
|
- if ctx.Request.Body != nil {
|
|
|
- bodyBytes, _ := ctx.GetRawData()
|
|
|
- ctx.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // 关键点
|
|
|
- logger.WithValue(ctx, zap.String("request_params", string(bodyBytes)))
|
|
|
- }
|
|
|
- logger.WithContext(ctx).Info("Request")
|
|
|
+
|
|
|
+ // 生成简短的追踪ID
|
|
|
+ traceID := generateTraceID()
|
|
|
+ ctx.Set(TraceIDKey, traceID)
|
|
|
+
|
|
|
+ // 构建日志字段
|
|
|
+ fields := []zap.Field{
|
|
|
+ zap.String("trace_id", traceID),
|
|
|
+ zap.String("method", ctx.Request.Method),
|
|
|
+ zap.String("path", ctx.Request.URL.Path),
|
|
|
+ zap.String("client_ip", ctx.ClientIP()),
|
|
|
+ }
|
|
|
+
|
|
|
+ // 记录查询参数
|
|
|
+ if query := ctx.Request.URL.RawQuery; query != "" {
|
|
|
+ fields = append(fields, zap.String("query", query))
|
|
|
+ }
|
|
|
+
|
|
|
+ // 记录重要请求头
|
|
|
+ if headers := getImportantHeaders(ctx); len(headers) > 0 {
|
|
|
+ fields = append(fields, zap.Any("headers", headers))
|
|
|
+ }
|
|
|
+
|
|
|
+ // 记录请求体(仅限特定方法)
|
|
|
+ if shouldLogRequestBody(ctx) {
|
|
|
+ if bodyLog := getRequestBody(ctx); bodyLog != "" {
|
|
|
+ fields = append(fields, zap.String("body", bodyLog))
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 设置日志上下文
|
|
|
+ logger.WithValue(ctx, zap.String("trace_id", traceID))
|
|
|
+
|
|
|
+ // 记录请求开始时间
|
|
|
+ ctx.Set("start_time", time.Now())
|
|
|
+
|
|
|
+ logger.Info("API Request", fields...)
|
|
|
ctx.Next()
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
func ResponseLogMiddleware(logger *log.Logger) gin.HandlerFunc {
|
|
|
return func(ctx *gin.Context) {
|
|
|
- blw := &bodyLogWriter{body: bytes.NewBufferString(""), ResponseWriter: ctx.Writer}
|
|
|
+ // 跳过不需要记录的路径
|
|
|
+ if skipPaths[ctx.Request.URL.Path] {
|
|
|
+ ctx.Next()
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // 包装响应写入器
|
|
|
+ blw := &bodyLogWriter{
|
|
|
+ body: bytes.NewBufferString(""),
|
|
|
+ ResponseWriter: ctx.Writer,
|
|
|
+ }
|
|
|
ctx.Writer = blw
|
|
|
- startTime := time.Now()
|
|
|
+
|
|
|
+ // 执行处理
|
|
|
ctx.Next()
|
|
|
- duration := time.Since(startTime).String()
|
|
|
- logger.WithContext(ctx).Info("Response", zap.Any("response_body", blw.body.String()), zap.Any("time", duration))
|
|
|
+
|
|
|
+ // 计算耗时
|
|
|
+ var duration time.Duration
|
|
|
+ if startTime, exists := ctx.Get("start_time"); exists {
|
|
|
+ if st, ok := startTime.(time.Time); ok {
|
|
|
+ duration = time.Since(st)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 构建响应日志字段
|
|
|
+ fields := []zap.Field{
|
|
|
+ zap.Int("status", ctx.Writer.Status()),
|
|
|
+ zap.String("duration", duration.String()),
|
|
|
+ zap.Int64("duration_ms", duration.Milliseconds()),
|
|
|
+ }
|
|
|
+
|
|
|
+ // 记录响应体(限制大小和内容类型)
|
|
|
+ if shouldLogResponseBody(ctx) {
|
|
|
+ bodyStr := blw.body.String()
|
|
|
+ if len(bodyStr) > MaxBodySize {
|
|
|
+ fields = append(fields, zap.String("body", fmt.Sprintf("[TRUNCATED: %d bytes]", len(bodyStr))))
|
|
|
+ } else if len(bodyStr) > 0 {
|
|
|
+ fields = append(fields, zap.String("body", maskSensitiveData(bodyStr)))
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 记录错误信息
|
|
|
+ if len(ctx.Errors) > 0 {
|
|
|
+ fields = append(fields, zap.Any("errors", ctx.Errors))
|
|
|
+ }
|
|
|
+
|
|
|
+ // 添加状态分类
|
|
|
+ statusCode := ctx.Writer.Status()
|
|
|
+ if statusCode >= 500 {
|
|
|
+ fields = append(fields, zap.String("level", "error"))
|
|
|
+ } else if statusCode >= 400 {
|
|
|
+ fields = append(fields, zap.String("level", "warn"))
|
|
|
+ } else {
|
|
|
+ fields = append(fields, zap.String("level", "info"))
|
|
|
+ }
|
|
|
+
|
|
|
+ // 统一使用 Info 级别,避免堆栈信息
|
|
|
+ logger.WithContext(ctx).Info("API Response", fields...)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+// bodyLogWriter 包装响应写入器以捕获响应体
|
|
|
type bodyLogWriter struct {
|
|
|
gin.ResponseWriter
|
|
|
body *bytes.Buffer
|
|
|
}
|
|
|
|
|
|
-func (w bodyLogWriter) Write(b []byte) (int, error) {
|
|
|
+func (w *bodyLogWriter) Write(b []byte) (int, error) {
|
|
|
w.body.Write(b)
|
|
|
return w.ResponseWriter.Write(b)
|
|
|
}
|
|
|
+
|
|
|
+// 辅助函数
|
|
|
+
|
|
|
+func generateTraceID() string {
|
|
|
+ // 使用时间戳 + 4位随机字符串
|
|
|
+ return fmt.Sprintf("%d-%s", time.Now().Unix(), random.RandString(4))
|
|
|
+}
|
|
|
+
|
|
|
+func getImportantHeaders(ctx *gin.Context) map[string]string {
|
|
|
+ headers := make(map[string]string)
|
|
|
+ for _, key := range importantHeaders {
|
|
|
+ if value := ctx.GetHeader(key); value != "" {
|
|
|
+ headers[key] = value
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return headers
|
|
|
+}
|
|
|
+
|
|
|
+func shouldLogRequestBody(ctx *gin.Context) bool {
|
|
|
+ method := ctx.Request.Method
|
|
|
+ return method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE"
|
|
|
+}
|
|
|
+
|
|
|
+func shouldLogResponseBody(ctx *gin.Context) bool {
|
|
|
+ // 检查内容类型
|
|
|
+ contentType := ctx.Writer.Header().Get("Content-Type")
|
|
|
+ return strings.Contains(contentType, "json") ||
|
|
|
+ strings.Contains(contentType, "xml") ||
|
|
|
+ strings.Contains(contentType, "text")
|
|
|
+}
|
|
|
+
|
|
|
+func getRequestBody(ctx *gin.Context) string {
|
|
|
+ if ctx.Request.Body == nil {
|
|
|
+ return ""
|
|
|
+ }
|
|
|
+
|
|
|
+ bodyBytes, err := ctx.GetRawData()
|
|
|
+ if err != nil {
|
|
|
+ return ""
|
|
|
+ }
|
|
|
+
|
|
|
+ // 重置请求体
|
|
|
+ ctx.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
|
+
|
|
|
+ // 检查大小
|
|
|
+ if len(bodyBytes) == 0 {
|
|
|
+ return ""
|
|
|
+ }
|
|
|
+ if len(bodyBytes) > MaxBodySize {
|
|
|
+ return fmt.Sprintf("[TRUNCATED: %d bytes]", len(bodyBytes))
|
|
|
+ }
|
|
|
+
|
|
|
+ // 脱敏处理
|
|
|
+ return maskSensitiveData(string(bodyBytes))
|
|
|
+}
|
|
|
+
|
|
|
+func maskSensitiveData(data string) string {
|
|
|
+ result := data
|
|
|
+ for _, field := range sensitiveFields {
|
|
|
+ // 简单的JSON字段脱敏
|
|
|
+ result = maskJSONField(result, field)
|
|
|
+ // URL参数脱敏
|
|
|
+ result = maskURLParam(result, field)
|
|
|
+ }
|
|
|
+ return result
|
|
|
+}
|
|
|
+
|
|
|
+func maskJSONField(data, field string) string {
|
|
|
+ lowerData := strings.ToLower(data)
|
|
|
+ lowerField := strings.ToLower(field)
|
|
|
+
|
|
|
+ // 查找字段位置(不区分大小写)
|
|
|
+ idx := strings.Index(lowerData, `"`+lowerField+`"`)
|
|
|
+ if idx == -1 {
|
|
|
+ idx = strings.Index(lowerData, `'`+lowerField+`'`)
|
|
|
+ if idx == -1 {
|
|
|
+ return data
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 找到冒号位置
|
|
|
+ colonIdx := strings.Index(data[idx:], ":")
|
|
|
+ if colonIdx == -1 {
|
|
|
+ return data
|
|
|
+ }
|
|
|
+ colonIdx += idx
|
|
|
+
|
|
|
+ // 找到值的开始和结束位置
|
|
|
+ valueStart := colonIdx + 1
|
|
|
+ for valueStart < len(data) && (data[valueStart] == ' ' || data[valueStart] == '\t') {
|
|
|
+ valueStart++
|
|
|
+ }
|
|
|
+
|
|
|
+ if valueStart >= len(data) {
|
|
|
+ return data
|
|
|
+ }
|
|
|
+
|
|
|
+ // 判断值的类型
|
|
|
+ var valueEnd int
|
|
|
+ if data[valueStart] == '"' || data[valueStart] == '\'' {
|
|
|
+ // 字符串值
|
|
|
+ quote := data[valueStart]
|
|
|
+ valueEnd = valueStart + 1
|
|
|
+ for valueEnd < len(data) && data[valueEnd] != quote {
|
|
|
+ if data[valueEnd] == '\\' {
|
|
|
+ valueEnd++ // 跳过转义字符
|
|
|
+ }
|
|
|
+ valueEnd++
|
|
|
+ }
|
|
|
+ if valueEnd < len(data) {
|
|
|
+ valueEnd++ // 包含结束引号
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ // 非字符串值(数字、布尔值等)
|
|
|
+ valueEnd = valueStart
|
|
|
+ for valueEnd < len(data) && data[valueEnd] != ',' && data[valueEnd] != '}' && data[valueEnd] != ']' && data[valueEnd] != '\n' && data[valueEnd] != '\r' {
|
|
|
+ valueEnd++
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 替换为脱敏值
|
|
|
+ return data[:valueStart] + `"***"` + data[valueEnd:]
|
|
|
+}
|
|
|
+
|
|
|
+func maskURLParam(data, param string) string {
|
|
|
+ lowerData := strings.ToLower(data)
|
|
|
+ lowerParam := strings.ToLower(param)
|
|
|
+
|
|
|
+ // 查找参数位置
|
|
|
+ idx := strings.Index(lowerData, lowerParam+"=")
|
|
|
+ if idx == -1 {
|
|
|
+ return data
|
|
|
+ }
|
|
|
+
|
|
|
+ // 确保是参数开始位置(前面是?或&)
|
|
|
+ if idx > 0 && data[idx-1] != '?' && data[idx-1] != '&' && data[idx-1] != ' ' && data[idx-1] != '\n' {
|
|
|
+ return data
|
|
|
+ }
|
|
|
+
|
|
|
+ // 找到参数值结束位置
|
|
|
+ valueStart := idx + len(param) + 1
|
|
|
+ valueEnd := valueStart
|
|
|
+ for valueEnd < len(data) && data[valueEnd] != '&' && data[valueEnd] != ' ' && data[valueEnd] != '\n' && data[valueEnd] != '\r' {
|
|
|
+ valueEnd++
|
|
|
+ }
|
|
|
+
|
|
|
+ // 替换为脱敏值
|
|
|
+ return data[:valueStart] + "***" + data[valueEnd:]
|
|
|
+}
|