浏览代码

feat(middleware): 重构日志中间件并添加响应日志记录功能

- 重构请求日志中间件,增加更多日志字段和脱敏处理
- 新增响应日志中间件,记录响应状态、耗时和身体内容- 优化数据库日志配置,支持不同级别的日志记录
fusu 2 月之前
父节点
当前提交
0d83ec7212
共有 2 个文件被更改,包括 321 次插入27 次删除
  1. 286 22
      internal/middleware/log.go
  2. 35 5
      internal/repository/repository.go

+ 286 - 22
internal/middleware/log.go

@@ -2,53 +2,317 @@ package middleware
 
 import (
 	"bytes"
-	"github.com/duke-git/lancet/v2/cryptor"
+	"fmt"
+	"io"
+	"strings"
+	"time"
+
 	"github.com/duke-git/lancet/v2/random"
 	"github.com/gin-gonic/gin"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/log"
 	"go.uber.org/zap"
-	"io"
-	"time"
+)
+
+const (
+	MaxBodySize = 10 * 1024 // 10KB
+	TraceIDKey  = "trace_id"
+)
+
+var (
+	// 跳过的路径
+	skipPaths = map[string]bool{
+		"/health":      true,
+		"/metrics":     true,
+		"/favicon.ico": true,
+		"/ping":        true,
+	}
+
+	// 需要记录的请求头
+	importantHeaders = []string{
+		"authorization",
+		"x-request-id",
+		"x-real-ip",
+		"x-forwarded-for",
+		"user-agent",
+		"content-type",
+	}
+
+	// 敏感字段
+	sensitiveFields = []string{
+		"password", "passwd", "pwd",
+		"token", "access_token", "refresh_token",
+		"secret", "api_key", "apikey",
+		"authorization",
+	}
 )
 
 func RequestLogMiddleware(logger *log.Logger) gin.HandlerFunc {
 	return func(ctx *gin.Context) {
-		// The configuration is initialized once per request
-		uuid, err := random.UUIdV4()
-		if err != nil {
+		// 跳过不需要记录的路径
+		if skipPaths[ctx.Request.URL.Path] {
+			ctx.Next()
 			return
 		}
-		trace := cryptor.Md5String(uuid)
-		logger.WithValue(ctx, zap.String("trace", trace))
-		logger.WithValue(ctx, zap.String("request_method", ctx.Request.Method))
-		logger.WithValue(ctx, zap.Any("request_headers", ctx.Request.Header))
-		logger.WithValue(ctx, zap.String("request_url", ctx.Request.URL.String()))
-		if ctx.Request.Body != nil {
-			bodyBytes, _ := ctx.GetRawData()
-			ctx.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // 关键点
-			logger.WithValue(ctx, zap.String("request_params", string(bodyBytes)))
-		}
-		logger.WithContext(ctx).Info("Request")
+
+		// 生成简短的追踪ID
+		traceID := generateTraceID()
+		ctx.Set(TraceIDKey, traceID)
+
+		// 构建日志字段
+		fields := []zap.Field{
+			zap.String("trace_id", traceID),
+			zap.String("method", ctx.Request.Method),
+			zap.String("path", ctx.Request.URL.Path),
+			zap.String("client_ip", ctx.ClientIP()),
+		}
+
+		// 记录查询参数
+		if query := ctx.Request.URL.RawQuery; query != "" {
+			fields = append(fields, zap.String("query", query))
+		}
+
+		// 记录重要请求头
+		if headers := getImportantHeaders(ctx); len(headers) > 0 {
+			fields = append(fields, zap.Any("headers", headers))
+		}
+
+		// 记录请求体(仅限特定方法)
+		if shouldLogRequestBody(ctx) {
+			if bodyLog := getRequestBody(ctx); bodyLog != "" {
+				fields = append(fields, zap.String("body", bodyLog))
+			}
+		}
+
+		// 设置日志上下文
+		logger.WithValue(ctx, zap.String("trace_id", traceID))
+
+		// 记录请求开始时间
+		ctx.Set("start_time", time.Now())
+
+		logger.Info("API Request", fields...)
 		ctx.Next()
 	}
 }
+
 func ResponseLogMiddleware(logger *log.Logger) gin.HandlerFunc {
 	return func(ctx *gin.Context) {
-		blw := &bodyLogWriter{body: bytes.NewBufferString(""), ResponseWriter: ctx.Writer}
+		// 跳过不需要记录的路径
+		if skipPaths[ctx.Request.URL.Path] {
+			ctx.Next()
+			return
+		}
+
+		// 包装响应写入器
+		blw := &bodyLogWriter{
+			body:           bytes.NewBufferString(""),
+			ResponseWriter: ctx.Writer,
+		}
 		ctx.Writer = blw
-		startTime := time.Now()
+
+		// 执行处理
 		ctx.Next()
-		duration := time.Since(startTime).String()
-		logger.WithContext(ctx).Info("Response", zap.Any("response_body", blw.body.String()), zap.Any("time", duration))
+
+		// 计算耗时
+		var duration time.Duration
+		if startTime, exists := ctx.Get("start_time"); exists {
+			if st, ok := startTime.(time.Time); ok {
+				duration = time.Since(st)
+			}
+		}
+
+		// 构建响应日志字段
+		fields := []zap.Field{
+			zap.Int("status", ctx.Writer.Status()),
+			zap.String("duration", duration.String()),
+			zap.Int64("duration_ms", duration.Milliseconds()),
+		}
+
+		// 记录响应体(限制大小和内容类型)
+		if shouldLogResponseBody(ctx) {
+			bodyStr := blw.body.String()
+			if len(bodyStr) > MaxBodySize {
+				fields = append(fields, zap.String("body", fmt.Sprintf("[TRUNCATED: %d bytes]", len(bodyStr))))
+			} else if len(bodyStr) > 0 {
+				fields = append(fields, zap.String("body", maskSensitiveData(bodyStr)))
+			}
+		}
+
+		// 记录错误信息
+		if len(ctx.Errors) > 0 {
+			fields = append(fields, zap.Any("errors", ctx.Errors))
+		}
+
+		// 添加状态分类
+		statusCode := ctx.Writer.Status()
+		if statusCode >= 500 {
+			fields = append(fields, zap.String("level", "error"))
+		} else if statusCode >= 400 {
+			fields = append(fields, zap.String("level", "warn"))
+		} else {
+			fields = append(fields, zap.String("level", "info"))
+		}
+
+		// 统一使用 Info 级别,避免堆栈信息
+		logger.WithContext(ctx).Info("API Response", fields...)
 	}
 }
 
+// bodyLogWriter 包装响应写入器以捕获响应体
 type bodyLogWriter struct {
 	gin.ResponseWriter
 	body *bytes.Buffer
 }
 
-func (w bodyLogWriter) Write(b []byte) (int, error) {
+func (w *bodyLogWriter) Write(b []byte) (int, error) {
 	w.body.Write(b)
 	return w.ResponseWriter.Write(b)
 }
+
+// 辅助函数
+
+func generateTraceID() string {
+	// 使用时间戳 + 4位随机字符串
+	return fmt.Sprintf("%d-%s", time.Now().Unix(), random.RandString(4))
+}
+
+func getImportantHeaders(ctx *gin.Context) map[string]string {
+	headers := make(map[string]string)
+	for _, key := range importantHeaders {
+		if value := ctx.GetHeader(key); value != "" {
+			headers[key] = value
+		}
+	}
+	return headers
+}
+
+func shouldLogRequestBody(ctx *gin.Context) bool {
+	method := ctx.Request.Method
+	return method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE"
+}
+
+func shouldLogResponseBody(ctx *gin.Context) bool {
+	// 检查内容类型
+	contentType := ctx.Writer.Header().Get("Content-Type")
+	return strings.Contains(contentType, "json") ||
+		strings.Contains(contentType, "xml") ||
+		strings.Contains(contentType, "text")
+}
+
+func getRequestBody(ctx *gin.Context) string {
+	if ctx.Request.Body == nil {
+		return ""
+	}
+
+	bodyBytes, err := ctx.GetRawData()
+	if err != nil {
+		return ""
+	}
+
+	// 重置请求体
+	ctx.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
+
+	// 检查大小
+	if len(bodyBytes) == 0 {
+		return ""
+	}
+	if len(bodyBytes) > MaxBodySize {
+		return fmt.Sprintf("[TRUNCATED: %d bytes]", len(bodyBytes))
+	}
+
+	// 脱敏处理
+	return maskSensitiveData(string(bodyBytes))
+}
+
+func maskSensitiveData(data string) string {
+	result := data
+	for _, field := range sensitiveFields {
+		// 简单的JSON字段脱敏
+		result = maskJSONField(result, field)
+		// URL参数脱敏
+		result = maskURLParam(result, field)
+	}
+	return result
+}
+
+func maskJSONField(data, field string) string {
+	lowerData := strings.ToLower(data)
+	lowerField := strings.ToLower(field)
+
+	// 查找字段位置(不区分大小写)
+	idx := strings.Index(lowerData, `"`+lowerField+`"`)
+	if idx == -1 {
+		idx = strings.Index(lowerData, `'`+lowerField+`'`)
+		if idx == -1 {
+			return data
+		}
+	}
+
+	// 找到冒号位置
+	colonIdx := strings.Index(data[idx:], ":")
+	if colonIdx == -1 {
+		return data
+	}
+	colonIdx += idx
+
+	// 找到值的开始和结束位置
+	valueStart := colonIdx + 1
+	for valueStart < len(data) && (data[valueStart] == ' ' || data[valueStart] == '\t') {
+		valueStart++
+	}
+
+	if valueStart >= len(data) {
+		return data
+	}
+
+	// 判断值的类型
+	var valueEnd int
+	if data[valueStart] == '"' || data[valueStart] == '\'' {
+		// 字符串值
+		quote := data[valueStart]
+		valueEnd = valueStart + 1
+		for valueEnd < len(data) && data[valueEnd] != quote {
+			if data[valueEnd] == '\\' {
+				valueEnd++ // 跳过转义字符
+			}
+			valueEnd++
+		}
+		if valueEnd < len(data) {
+			valueEnd++ // 包含结束引号
+		}
+	} else {
+		// 非字符串值(数字、布尔值等)
+		valueEnd = valueStart
+		for valueEnd < len(data) && data[valueEnd] != ',' && data[valueEnd] != '}' && data[valueEnd] != ']' && data[valueEnd] != '\n' && data[valueEnd] != '\r' {
+			valueEnd++
+		}
+	}
+
+	// 替换为脱敏值
+	return data[:valueStart] + `"***"` + data[valueEnd:]
+}
+
+func maskURLParam(data, param string) string {
+	lowerData := strings.ToLower(data)
+	lowerParam := strings.ToLower(param)
+
+	// 查找参数位置
+	idx := strings.Index(lowerData, lowerParam+"=")
+	if idx == -1 {
+		return data
+	}
+
+	// 确保是参数开始位置(前面是?或&)
+	if idx > 0 && data[idx-1] != '?' && data[idx-1] != '&' && data[idx-1] != ' ' && data[idx-1] != '\n' {
+		return data
+	}
+
+	// 找到参数值结束位置
+	valueStart := idx + len(param) + 1
+	valueEnd := valueStart
+	for valueEnd < len(data) && data[valueEnd] != '&' && data[valueEnd] != ' ' && data[valueEnd] != '\n' && data[valueEnd] != '\r' {
+		valueEnd++
+	}
+
+	// 替换为脱敏值
+	return data[:valueStart] + "***" + data[valueEnd:]
+}

+ 35 - 5
internal/repository/repository.go

@@ -11,6 +11,7 @@ import (
 	"gorm.io/driver/mysql"
 	"gorm.io/driver/postgres"
 	"gorm.io/gorm"
+	gormlogger "gorm.io/gorm/logger"
 	"time"
 )
 
@@ -67,10 +68,33 @@ func NewDB(conf *viper.Viper, l *log.Logger) *gorm.DB {
 		err error
 	)
 
-	logger := zapgorm2.New(l.Logger)
 	driver := conf.GetString("data.db.user.driver")
 	dsn := conf.GetString("data.db.user.dsn")
 
+	// 读取日志级别配置
+	logLevelStr := conf.GetString("data.db.user.logLevel")
+	var logLevel gormlogger.LogLevel
+
+	switch logLevelStr {
+	case "silent":
+		logLevel = gormlogger.Silent
+	case "error":
+		logLevel = gormlogger.Error
+	case "warn":
+		logLevel = gormlogger.Warn
+	case "info":
+		logLevel = gormlogger.Info
+	default:
+		// MySQL 默认只记录警告和错误
+		if driver == "mysql" {
+			logLevel = gormlogger.Warn
+		} else {
+			logLevel = gormlogger.Info
+		}
+	}
+
+	logger := zapgorm2.New(l.Logger).LogMode(logLevel)
+
 	// GORM doc: https://gorm.io/docs/connecting_to_the_database.html
 	switch driver {
 	case "mysql":
@@ -80,17 +104,21 @@ func NewDB(conf *viper.Viper, l *log.Logger) *gorm.DB {
 	case "postgres":
 		db, err = gorm.Open(postgres.New(postgres.Config{
 			DSN:                  dsn,
-			PreferSimpleProtocol: true, // disables implicit prepared statement usage
-		}), &gorm.Config{})
+			PreferSimpleProtocol: true,
+		}), &gorm.Config{
+			Logger: logger,
+		})
 	case "sqlite":
-		db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{})
+		db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{
+			Logger: logger,
+		})
 	default:
 		panic("unknown db driver")
 	}
+
 	if err != nil {
 		panic(err)
 	}
-	db = db.Debug()
 
 	// Connection Pool config
 	sqlDB, err := db.DB()
@@ -100,8 +128,10 @@ func NewDB(conf *viper.Viper, l *log.Logger) *gorm.DB {
 	sqlDB.SetMaxIdleConns(10)
 	sqlDB.SetMaxOpenConns(100)
 	sqlDB.SetConnMaxLifetime(time.Hour)
+
 	return db
 }
+
 func NewRedis(conf *viper.Viper) *redis.Client {
 	rdb := redis.NewClient(&redis.Options{
 		Addr:     conf.GetString("data.redis.addr"),