Procházet zdrojové kódy

refactor(admin): 重构 WafLog 导出相关功能

- 新增 buildExportQuery 辅助函数,用于构建导出日志的公共查询条件
- 优化 ExportWafLogWithPagination 函数,使用新的查询构建逻辑
- 重写 GetWafLogExportCount 函数,复用 buildExportQuery 以减少代码重复
- 通过重构,提高了代码的可维护性和可读性,同时减少了冗余代码
fusu před 3 dny
rodič
revize
bb31be1928
1 změnil soubory, kde provedl 62 přidání a 117 odebrání
  1. 62 117
      internal/repository/admin/waflog.go

+ 62 - 117
internal/repository/admin/waflog.go

@@ -16,11 +16,8 @@ type WafLogRepository interface {
 	GetWafLogList(ctx context.Context, req adminApi.SearchWafLogParams) (*v1.PaginatedResponse[model.WafLog], error)
 	AddWafLog(ctx context.Context, log *model.WafLog) error
 	BatchAddWafLog(ctx context.Context, logs []*model.WafLog) error
-	// 导出日志(已更新)
 	ExportWafLog(ctx context.Context, req adminApi.ExportWafLog) ([]model.WafLogWithGatewayIP, error)
-	// 支持分页的导出方法(已更新)
 	ExportWafLogWithPagination(ctx context.Context, req adminApi.ExportWafLog, page, pageSize int) ([]model.WafLogWithGatewayIP, error)
-	// 获取导出数据总数
 	GetWafLogExportCount(ctx context.Context, req adminApi.ExportWafLog) (int, error)
 }
 
@@ -36,6 +33,49 @@ type wafLogRepository struct {
 	*repository.Repository
 }
 
+// buildExportQuery 是一个辅助函数,用于构建导出日志的公共查询条件
+func (r *wafLogRepository) buildExportQuery(ctx context.Context, req adminApi.ExportWafLog) *gorm.DB {
+	// 使用 Table("waf_log as wl") 是为了给主表起一个别名,方便子查询中引用
+	query := r.DBWithName(ctx, "admin").Model(&model.WafLog{}).Table("waf_log as wl")
+
+	if req.RequestIp != "" {
+		query = query.Where("wl.request_ip = ?", strings.TrimSpace(req.RequestIp))
+	}
+	if req.Uid != 0 {
+		query = query.Where("wl.uid = ?", req.Uid)
+	}
+	if req.Api != "" {
+		query = query.Where("wl.api = ?", strings.TrimSpace(req.Api))
+	}
+	if req.Name != "" {
+		query = query.Where("wl.name = ?", strings.TrimSpace(req.Name))
+	}
+	if req.RuleId != 0 {
+		query = query.Where("wl.rule_id = ?", req.RuleId)
+	}
+	if len(req.HostIds) > 0 {
+		query = query.Where("wl.host_id IN ?", req.HostIds)
+	}
+	if req.UserAgent != "" {
+		query = query.Where("wl.user_agent = ?", strings.TrimSpace(req.UserAgent))
+	}
+	if len(req.ApiNames) > 0 {
+		query = query.Where("wl.api_name IN ?", req.ApiNames)
+	}
+	if len(req.ApiTypes) > 0 {
+		query = query.Where("wl.api_type IN ?", req.ApiTypes)
+	}
+	if req.StartTime != "" {
+		query = query.Where("wl.created_at > ?", strings.TrimSpace(req.StartTime))
+	}
+	if req.EndTime != "" {
+		query = query.Where("wl.created_at < ?", strings.TrimSpace(req.EndTime))
+	}
+
+	return query
+}
+
+
 func (r *wafLogRepository) GetWafLog(ctx context.Context, id int64) (*model.WafLog, error) {
 	var res model.WafLog
 	return &res, r.DBWithName(ctx,"admin").Where("id = ?", id).First(&res).Error
@@ -151,70 +191,28 @@ func (r *wafLogRepository) ExportWafLog(ctx context.Context, req adminApi.Export
 func (r *wafLogRepository) ExportWafLogWithPagination(ctx context.Context, req adminApi.ExportWafLog, page, pageSize int) ([]model.WafLogWithGatewayIP, error) {
 	var res []model.WafLogWithGatewayIP
 	
-	// 主查询,我们将其命名为 "wl" 以便在子查询中引用
-	query := r.DBWithName(ctx, "admin").Model(&model.WafLog{}).Table("waf_log as wl")
+	// 1. 使用辅助函数构建基础查询
+	query := r.buildExportQuery(ctx, req)
 
-	// --- 构建子查询 ---
-	// 这个子查询的目标是:对于 "wl" 表中的每一行,找到在它创建时间点之前(或同时)的、
-	// host_id 和 uid 都匹配的、最新的那条 "分配网关组" 的日志,并返回其 extra_data。
+	// 2. 构建子查询
 	subQuery := r.DBWithName(ctx, "admin").Model(&model.WafLog{}).
 		Select("extra_data").
 		Where("api_name = ?", "分配网关组").
-		Where("host_id = wl.host_id"). // 关联主查询的 host_id
-		Where("uid = wl.uid").         // 关联主查询的 uid
-		Where("created_at <= wl.created_at"). // 时间条件:必须是历史或当前记录
-		Order("created_at DESC"). // 按时间降序,保证第一条是最新
-		Limit(1) // 只取最新的一条
-
-	// --- 构建主查询的选择列表 ---
-	// "wl.*" 选择 waf_log 表的所有字段
-	// 第二个参数是使用子查询作为 "gateway_ip_data" 字段的值
-	query = query.Select("wl.*, (?) as gateway_ip_data", subQuery)
+		Where("host_id = wl.host_id").
+		Where("uid = wl.uid").
+		Where("created_at <= wl.created_at").
+		Order("created_at DESC").
+		Limit(1)
 
-	// --- 应用过滤条件 (与原函数保持一致) ---
-	if req.RequestIp != "" {
-		query = query.Where("wl.request_ip = ?", strings.TrimSpace(req.RequestIp))
-	}
-	if req.Uid != 0 {
-		query = query.Where("wl.uid = ?", req.Uid)
-	}
-	if req.Api != "" {
-		query = query.Where("wl.api = ?", strings.TrimSpace(req.Api))
-	}
-	if req.Name != "" {
-		query = query.Where("wl.name = ?", strings.TrimSpace(req.Name))
-	}
-	if req.RuleId != 0 {
-		query = query.Where("wl.rule_id = ?", req.RuleId)
-	}
-	if len(req.HostIds) > 0 {
-		query = query.Where("wl.host_id IN ?", req.HostIds)
-	}
-	if req.UserAgent != "" {
-		query = query.Where("wl.user_agent = ?", strings.TrimSpace(req.UserAgent))
-	}
-	if len(req.ApiNames) > 0 {
-		query = query.Where("wl.api_name IN ?", req.ApiNames)
-	}
-	if len(req.ApiTypes) > 0 {
-		query = query.Where("wl.api_type IN ?", req.ApiTypes)
-	}
-	if req.StartTime != "" {
-		query = query.Where("wl.created_at > ?", strings.TrimSpace(req.StartTime))
-	}
-	if req.EndTime != "" {
-		query = query.Where("wl.created_at < ?", strings.TrimSpace(req.EndTime))
-	}
-
-	// --- 应用分页 ---
+	// 3. 添加 Select 和分页
+	query = query.Select("wl.*, (?) as gateway_ip_data", subQuery)
 	if page > 0 && pageSize > 0 {
 		offset := (page - 1) * pageSize
 		query = query.Offset(offset).Limit(pageSize)
 	}
 
-	// 执行查询
+	// 4. 执行查询
 	if err := query.Find(&res).Error; err != nil {
-		// 如果是记录未找到的错误,我们希望返回一个空切片而不是错误
 		if err == gorm.ErrRecordNotFound {
 			return []model.WafLogWithGatewayIP{}, nil
 		}
@@ -225,69 +223,16 @@ func (r *wafLogRepository) ExportWafLogWithPagination(ctx context.Context, req a
 }
 
 
-// GetWafLogExportCount 获取导出数据总数
+// GetWafLogExportCount 获取导出数据总数(已优化)
 func (r *wafLogRepository) GetWafLogExportCount(ctx context.Context, req adminApi.ExportWafLog) (int, error) {
 	var count int64
-	query := r.DBWithName(ctx,"admin").Model(&model.WafLog{})
 	
-	// 复用ExportWafLog的查询条件
-	if req.RequestIp != "" {
-		trimmedName := strings.TrimSpace(req.RequestIp)
-		query = query.Where("request_ip = ?", trimmedName)
-	}
-
-	if req.Uid != 0 {
-		query = query.Where("uid = ?", req.Uid)
-	}
-
-	if req.Api != "" {
-		trimmedName := strings.TrimSpace(req.Api)
-		query = query.Where("api = ?", trimmedName)
-	}
-
-	if req.Name != "" {
-		trimmedName := strings.TrimSpace(req.Name)
-		query = query.Where("name = ?", trimmedName)
-	}
-
-	if req.RuleId != 0 {
-		query = query.Where("rule_id = ?", req.RuleId)
-	}
-
-	if len(req.HostIds) > 0 {
-		query = query.Where("host_id IN ?", req.HostIds)
-	}
-
-	if req.UserAgent != "" {
-		trimmedName := strings.TrimSpace(req.UserAgent)
-		query = query.Where("user_agent = ?", trimmedName)
-	}
-
-	if len(req.ApiNames) > 0 {
-		trimmedNames := make([]string, len(req.ApiNames))
-		for i, apiName := range req.ApiNames {
-			trimmedNames[i] = strings.TrimSpace(apiName)
-		}
-		query = query.Where("api_name IN ?", trimmedNames)
-	}
-
-	if len(req.ApiTypes) > 0 {
-		query = query.Where("api_type IN ?", req.ApiTypes)
-	}
-
-	if req.StartTime != "" {
-		trimmedName := strings.TrimSpace(req.StartTime)
-		query = query.Where("created_at > ?", trimmedName)
-	}
-
-	if req.EndTime != "" {
-		trimmedName := strings.TrimSpace(req.EndTime)
-		query = query.Where("created_at < ?", trimmedName)
-	}
-
-	result := query.Count(&count)
-	if result.Error != nil {
-		return 0, result.Error
+	// 直接复用 buildExportQuery 来构建查询
+	query := r.buildExportQuery(ctx, req)
+	
+	if err := query.Count(&count).Error; err != nil {
+		return 0, err
 	}
+	
 	return int(count), nil
-}
+}