package admin

import (
	"context"
	"errors"
	"fmt"
	"math"
	"strings"
	"time"

	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
	adminApi "github.com/go-nunu/nunu-layout-advanced/api/v1/admin"
	"github.com/go-nunu/nunu-layout-advanced/internal/model"
	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
	"github.com/go-nunu/nunu-layout-advanced/pkg/sharding"
	"gorm.io/gorm"
)

const (
	// wafLogDBName is the named database connection every query in this file uses.
	wafLogDBName = "admin"
	// wafLogBaseTable is the logical (unsharded) table name handed to the sharding manager.
	wafLogBaseTable = "waf_log"

	// Pagination bounds applied to list queries.
	wafLogDefaultPageSize = 10
	wafLogMaxPageSize     = 100

	// wafLogInsertBatchSize caps the number of rows per INSERT statement in
	// BatchAddWafLog so that a large batch is split into several statements.
	wafLogInsertBatchSize = 500

	// Row thresholds passed to the sharding manager when choosing a write table.
	wafLogMaxRows       = 5000000 // waf_log
	wafLogMaxRowsLog    = 3000000 // log
	wafLogMaxRowsOthers = 3000000 // any other table
)

// WafLogRepository is the persistence interface for WAF log records.
type WafLogRepository interface {
	GetWafLog(ctx context.Context, id int64) (*model.WafLog, error)
	GetWafLogList(ctx context.Context, req adminApi.SearchWafLogParams) (*v1.PaginatedResponse[model.WafLog], error)
	AddWafLog(ctx context.Context, log *model.WafLog) error
	BatchAddWafLog(ctx context.Context, logs []*model.WafLog) error
	ExportWafLog(ctx context.Context, req adminApi.ExportWafLog) ([]model.WafLogWithGatewayIP, error)
	ExportWafLogWithPagination(ctx context.Context, req adminApi.ExportWafLog, page, pageSize int) ([]model.WafLogWithGatewayIP, error)
	GetWafLogExportCount(ctx context.Context, req adminApi.ExportWafLog) (int, error)
}

// NewWafLogRepository returns a WafLogRepository backed by the given base repository.
func NewWafLogRepository(
	repository *repository.Repository,
) WafLogRepository {
	return &wafLogRepository{
		Repository: repository,
	}
}

// wafLogRepository implements WafLogRepository on top of the shared Repository.
type wafLogRepository struct {
	*repository.Repository
}

// buildExportQuery builds the filter conditions shared by the export query and
// the export count query.
//
// The main table is aliased as "wl" so that a correlated subquery can refer
// to it. String filters are trimmed before the emptiness check, so a
// whitespace-only value is treated as "no filter" instead of matching the
// empty string.
//
// NOTE(review): this query targets the literal table "waf_log" only; the
// tables resolved through the sharding manager elsewhere in this file are not
// consulted here — confirm this is intended.
func (r *wafLogRepository) buildExportQuery(ctx context.Context, req adminApi.ExportWafLog) *gorm.DB {
	query := r.DBWithName(ctx, wafLogDBName).Model(&model.WafLog{}).Table("waf_log as wl")

	if ip := strings.TrimSpace(req.RequestIp); ip != "" {
		query = query.Where("wl.request_ip = ?", ip)
	}
	if req.Uid != 0 {
		query = query.Where("wl.uid = ?", req.Uid)
	}
	if api := strings.TrimSpace(req.Api); api != "" {
		query = query.Where("wl.api = ?", api)
	}
	if name := strings.TrimSpace(req.Name); name != "" {
		query = query.Where("wl.name = ?", name)
	}
	if req.RuleId != 0 {
		query = query.Where("wl.rule_id = ?", req.RuleId)
	}
	if len(req.HostIds) > 0 {
		query = query.Where("wl.host_id IN ?", req.HostIds)
	}
	if ua := strings.TrimSpace(req.UserAgent); ua != "" {
		query = query.Where("wl.user_agent = ?", ua)
	}
	if len(req.ApiNames) > 0 {
		query = query.Where("wl.api_name IN ?", req.ApiNames)
	}
	if len(req.ApiTypes) > 0 {
		query = query.Where("wl.api_type IN ?", req.ApiTypes)
	}
	// Both time bounds are exclusive.
	if start := strings.TrimSpace(req.StartTime); start != "" {
		query = query.Where("wl.created_at > ?", start)
	}
	if end := strings.TrimSpace(req.EndTime); end != "" {
		query = query.Where("wl.created_at < ?", end)
	}
	return query
}

// GetWafLog looks up a single WAF log by id across every existing shard table.
//
// Tables are probed in the order returned by the sharding manager and the
// first hit wins. A table that fails with anything other than "record not
// found" does not stop the search, but if no table yields the record that
// failure is returned instead of a misleading "not found" error.
func (r *wafLogRepository) GetWafLog(ctx context.Context, id int64) (*model.WafLog, error) {
	shardingMgr := r.getShardingManager()
	existingTables := shardingMgr.GetTableNamesWithExistenceCheck(r.DBWithName(ctx, wafLogDBName), wafLogBaseTable, nil, nil)

	var lastErr error
	for _, tableName := range existingTables {
		var res model.WafLog
		err := r.DBWithName(ctx, wafLogDBName).Table(tableName).Where("id = ?", id).First(&res).Error
		if err == nil {
			res.SetTableName(tableName)
			return &res, nil
		}
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			lastErr = err
		}
	}

	if lastErr != nil {
		return nil, fmt.Errorf("查询WAF日志失败: %w", lastErr)
	}
	return nil, fmt.Errorf("未找到ID为 %d 的WAF日志记录", id)
}

// GetWafLogList returns one page of WAF logs matching req.
//
// It resolves the existing shard tables and dispatches to the single-table or
// multi-table query. When no shard table exists it returns an empty page
// carrying the normalized page and page size.
func (r *wafLogRepository) GetWafLogList(ctx context.Context, req adminApi.SearchWafLogParams) (*v1.PaginatedResponse[model.WafLog], error) {
	shardingMgr := r.getShardingManager()

	// TODO: derive the time range from req once it carries time fields.
	// Both bounds are nil for now, so the sharding manager decides which
	// tables are in scope.
	var startTime, endTime *time.Time

	existingTables := shardingMgr.GetTableNamesWithExistenceCheck(r.DBWithName(ctx, wafLogDBName), wafLogBaseTable, startTime, endTime)

	switch len(existingTables) {
	case 0:
		page, pageSize := wafLogPageBounds(req.Current, req.PageSize)
		return &v1.PaginatedResponse[model.WafLog]{
			Records:    []model.WafLog{},
			Page:       page,
			PageSize:   pageSize,
			Total:      0,
			TotalPages: 0,
		}, nil
	case 1:
		return r.queryWafLogFromSingleTable(ctx, req, existingTables[0])
	default:
		return r.queryWafLogFromMultipleTables(ctx, req, existingTables)
	}
}

// queryWafLogFromSingleTable runs the filtered, paginated list query against
// one shard table.
//
// The sort column from req is validated before it reaches SQL; an invalid
// column yields an error. Every returned record is tagged with tableName.
func (r *wafLogRepository) queryWafLogFromSingleTable(ctx context.Context, req adminApi.SearchWafLogParams, tableName string) (*v1.PaginatedResponse[model.WafLog], error) {
	orderClause, err := wafLogOrderClause(req.Column, req.Order)
	if err != nil {
		return nil, err
	}

	query := r.applyWafLogFilters(r.DBWithName(ctx, wafLogDBName).Table(tableName), req)

	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, err
	}

	page, pageSize := wafLogPageBounds(req.Current, req.PageSize)
	if orderClause != "" {
		query = query.Order(orderClause)
	}

	var records []model.WafLog
	if err := query.Offset((page - 1) * pageSize).Limit(pageSize).Find(&records).Error; err != nil {
		return nil, err
	}
	// Tag records with their source table, as the multi-table path does.
	for i := range records {
		records[i].SetTableName(tableName)
	}

	return &v1.PaginatedResponse[model.WafLog]{
		Records:    records,
		Page:       page,
		PageSize:   pageSize,
		Total:      total,
		TotalPages: wafLogTotalPages(total, pageSize),
	}, nil
}

// queryWafLogFromMultipleTables paginates across several shard tables as if
// they were concatenated in the order given by tableNames.
//
// Each table is counted once; the counts are used both for the grand total and
// to work out which tables the requested page overlaps. Sorting is applied
// within each table only, so the combined page is ordered by table first and
// by the requested column second.
func (r *wafLogRepository) queryWafLogFromMultipleTables(ctx context.Context, req adminApi.SearchWafLogParams, tableNames []string) (*v1.PaginatedResponse[model.WafLog], error) {
	orderClause, err := wafLogOrderClause(req.Column, req.Order)
	if err != nil {
		return nil, err
	}

	// Count every table once and remember the per-table result.
	counts := make([]int64, len(tableNames))
	var totalCount int64
	for i, tableName := range tableNames {
		countQuery := r.applyWafLogFilters(r.DBWithName(ctx, wafLogDBName).Table(tableName), req)
		if err := countQuery.Count(&counts[i]).Error; err != nil {
			return nil, err
		}
		totalCount += counts[i]
	}

	page, pageSize := wafLogPageBounds(req.Current, req.PageSize)

	skip := (page - 1) * pageSize // rows still to skip before the page starts
	remaining := pageSize         // rows still needed to fill the page
	records := make([]model.WafLog, 0, pageSize)

	for i, tableName := range tableNames {
		if remaining <= 0 {
			break
		}

		tableCount := int(counts[i])
		if tableCount <= skip {
			// The whole table lies before the requested page (or is empty).
			skip -= tableCount
			continue
		}

		query := r.applyWafLogFilters(r.DBWithName(ctx, wafLogDBName).Table(tableName), req)
		if orderClause != "" {
			query = query.Order(orderClause)
		}

		var tableResults []model.WafLog
		if err := query.Offset(skip).Limit(remaining).Find(&tableResults).Error; err != nil {
			return nil, err
		}
		for j := range tableResults {
			tableResults[j].SetTableName(tableName)
		}

		records = append(records, tableResults...)
		remaining -= len(tableResults)
		// The page start has been reached; later tables are read from row 0.
		skip = 0
	}

	return &v1.PaginatedResponse[model.WafLog]{
		Records:    records,
		Page:       page,
		PageSize:   pageSize,
		Total:      totalCount,
		TotalPages: wafLogTotalPages(totalCount, pageSize),
	}, nil
}

// AddWafLog inserts a single WAF log into the write table chosen by the
// sharding manager, creating that table first if necessary.
//
// A zero CreatedAt is replaced with the current time. A nil log is rejected
// with an error.
func (r *wafLogRepository) AddWafLog(ctx context.Context, log *model.WafLog) error {
	if log == nil {
		return errors.New("WAF日志不能为空")
	}
	if log.CreatedAt.IsZero() {
		log.CreatedAt = time.Now()
	}

	shardingMgr := r.getShardingManagerWithThreshold()

	// Pick the write table, taking the row-count threshold into account.
	tableName, err := shardingMgr.GetOptimalWriteTable(ctx, r.DBWithName(ctx, wafLogDBName), log, r.getMaxRowsForTable(wafLogBaseTable))
	if err != nil {
		return fmt.Errorf("获取写入表失败: %w", err)
	}
	log.SetTableName(tableName)

	if err := shardingMgr.EnsureTableExists(ctx, r.DBWithName(ctx, wafLogDBName), tableName, &model.WafLog{}); err != nil {
		return err
	}

	return r.DBWithName(ctx, wafLogDBName).Table(tableName).Create(log).Error
}

// BatchAddWafLog inserts many WAF logs, grouping them by the write table the
// sharding manager selects for each entry.
//
// An empty slice is a no-op. A nil entry is rejected with an error before
// anything is written. Rows for one table are inserted in chunks of
// wafLogInsertBatchSize. Tables are written one after another without a
// surrounding transaction, so a failure part-way can leave earlier tables
// already written.
func (r *wafLogRepository) BatchAddWafLog(ctx context.Context, logs []*model.WafLog) error {
	if len(logs) == 0 {
		return nil
	}

	shardingMgr := r.getShardingManagerWithThreshold()
	maxRows := r.getMaxRowsForTable(wafLogBaseTable)

	// Group the entries by destination table.
	tableGroups := make(map[string][]*model.WafLog)
	for i, log := range logs {
		if log == nil {
			return fmt.Errorf("第 %d 条WAF日志不能为空", i)
		}
		if log.CreatedAt.IsZero() {
			log.CreatedAt = time.Now()
		}

		tableName, err := shardingMgr.GetOptimalWriteTable(ctx, r.DBWithName(ctx, wafLogDBName), log, maxRows)
		if err != nil {
			return fmt.Errorf("获取写入表失败: %w", err)
		}
		log.SetTableName(tableName)
		tableGroups[tableName] = append(tableGroups[tableName], log)
	}

	for tableName, tableLogs := range tableGroups {
		if err := shardingMgr.EnsureTableExists(ctx, r.DBWithName(ctx, wafLogDBName), tableName, &model.WafLog{}); err != nil {
			return err
		}
		if err := r.DBWithName(ctx, wafLogDBName).Table(tableName).CreateInBatches(tableLogs, wafLogInsertBatchSize).Error; err != nil {
			return err
		}
	}
	return nil
}

// ExportWafLog returns every log matching req without pagination.
func (r *wafLogRepository) ExportWafLog(ctx context.Context, req adminApi.ExportWafLog) ([]model.WafLogWithGatewayIP, error) {
	return r.ExportWafLogWithPagination(ctx, req, 0, 0)
}

// ExportWafLogWithPagination returns logs matching req, each enriched with a
// gateway_ip_data column.
//
// That column comes from a correlated subquery selecting extra_data of the
// most recent "分配网关组" log with the same host_id and uid whose created_at is
// not later than the row's own created_at.
//
// Pagination is applied only when both page and pageSize are positive; in that
// case rows are ordered by wl.id so that consecutive pages are stable.
func (r *wafLogRepository) ExportWafLogWithPagination(ctx context.Context, req adminApi.ExportWafLog, page, pageSize int) ([]model.WafLogWithGatewayIP, error) {
	query := r.buildExportQuery(ctx, req)

	subQuery := r.DBWithName(ctx, wafLogDBName).Model(&model.WafLog{}).
		Select("extra_data").
		Where("api_name = ?", "分配网关组").
		Where("host_id = wl.host_id").
		Where("uid = wl.uid").
		Where("created_at <= wl.created_at").
		Order("created_at DESC").
		Limit(1)

	query = query.Select("wl.*, (?) as gateway_ip_data", subQuery)

	if page > 0 && pageSize > 0 {
		// OFFSET/LIMIT without ORDER BY has no guaranteed row order, which can
		// repeat or drop rows between pages.
		query = query.Order("wl.id ASC").Offset((page - 1) * pageSize).Limit(pageSize)
	}

	var res []model.WafLogWithGatewayIP
	if err := query.Find(&res).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return []model.WafLogWithGatewayIP{}, nil
		}
		return nil, err
	}
	return res, nil
}

// GetWafLogExportCount returns the number of logs matching req, using the same
// filters as the export query.
func (r *wafLogRepository) GetWafLogExportCount(ctx context.Context, req adminApi.ExportWafLog) (int, error) {
	var count int64
	if err := r.buildExportQuery(ctx, req).Count(&count).Error; err != nil {
		return 0, err
	}
	return int(count), nil
}

// getShardingManager returns a sharding manager built on the monthly sharding
// strategy, without threshold configuration.
func (r *wafLogRepository) getShardingManager() *sharding.ShardingManager {
	strategy := sharding.NewMonthlyShardingStrategy()
	return sharding.NewShardingManager(strategy, r.Logger)
}

// getShardingManagerWithThreshold returns a sharding manager built on the
// monthly sharding strategy with a row-count threshold enabled.
//
// The threshold values are hard-coded for now; they could be read from
// configuration instead.
func (r *wafLogRepository) getShardingManagerWithThreshold() *sharding.ShardingManager {
	strategy := sharding.NewMonthlyShardingStrategy()
	thresholdConfig := &sharding.ThresholdConfig{
		Enabled:       true,
		MaxRows:       wafLogMaxRows,
		CheckInterval: time.Hour,
	}
	return sharding.NewShardingManagerWithThreshold(strategy, r.Logger, thresholdConfig)
}

// getMaxRowsForTable returns the configured maximum row count for tableName:
// 5,000,000 for "waf_log" and 3,000,000 for "log" and any other table.
func (r *wafLogRepository) getMaxRowsForTable(tableName string) int64 {
	switch tableName {
	case "log":
		return wafLogMaxRowsLog
	case wafLogBaseTable:
		return wafLogMaxRows
	default:
		return wafLogMaxRowsOthers
	}
}

// applyWafLogFilters adds the list-search conditions from req to query.
//
// Text fields use a substring match (LIKE %value%); numeric fields and
// api_type use equality. Text values are trimmed before the emptiness check,
// so a whitespace-only value adds no condition.
func (r *wafLogRepository) applyWafLogFilters(query *gorm.DB, req adminApi.SearchWafLogParams) *gorm.DB {
	if ip := strings.TrimSpace(req.RequestIp); ip != "" {
		query = query.Where("request_ip LIKE CONCAT('%', ?, '%')", ip)
	}
	if req.Uid != 0 {
		query = query.Where("uid = ?", req.Uid)
	}
	if api := strings.TrimSpace(req.Api); api != "" {
		query = query.Where("api LIKE CONCAT('%', ?, '%')", api)
	}
	if name := strings.TrimSpace(req.Name); name != "" {
		query = query.Where("name LIKE CONCAT('%', ?, '%')", name)
	}
	if req.RuleId != 0 {
		query = query.Where("rule_id = ?", req.RuleId)
	}
	if req.HostId != 0 {
		query = query.Where("host_id = ?", req.HostId)
	}
	if ua := strings.TrimSpace(req.UserAgent); ua != "" {
		query = query.Where("user_agent LIKE CONCAT('%', ?, '%')", ua)
	}
	if apiName := strings.TrimSpace(req.ApiName); apiName != "" {
		query = query.Where("api_name LIKE CONCAT('%', ?, '%')", apiName)
	}
	if req.ApiType != "" {
		query = query.Where("api_type = ?", req.ApiType)
	}
	return query
}

// wafLogPageBounds normalizes pagination input: a non-positive page becomes 1,
// a non-positive page size becomes the default, and a page size above the
// maximum is capped.
func wafLogPageBounds(page, pageSize int) (int, int) {
	if page <= 0 {
		page = 1
	}
	switch {
	case pageSize <= 0:
		pageSize = wafLogDefaultPageSize
	case pageSize > wafLogMaxPageSize:
		pageSize = wafLogMaxPageSize
	}
	return page, pageSize
}

// wafLogTotalPages returns the number of pages needed to hold total rows at
// pageSize rows per page. pageSize must be positive.
func wafLogTotalPages(total int64, pageSize int) int {
	return int(math.Ceil(float64(total) / float64(pageSize)))
}

// wafLogOrderClause turns a caller-supplied sort column and direction into a
// safe ORDER BY fragment.
//
// An empty column yields an empty clause (no ordering). The column may contain
// only ASCII letters, digits and underscores; anything else is rejected with
// an error because the value is interpolated into SQL. The direction is DESC
// for "desc" or "descend" (case-insensitive) and ASC for everything else.
func wafLogOrderClause(column, order string) (string, error) {
	column = strings.TrimSpace(column)
	if column == "" {
		return "", nil
	}
	for _, c := range column {
		isLetter := (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
		isDigit := c >= '0' && c <= '9'
		if !isLetter && !isDigit && c != '_' {
			return "", fmt.Errorf("无效的排序字段: %q", column)
		}
	}

	direction := "ASC"
	switch strings.ToLower(strings.TrimSpace(order)) {
	case "desc", "descend":
		direction = "DESC"
	}
	return column + " " + direction, nil
}