package task

import (
	"context"
	"fmt"
	"sync"
	"time"

	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
	"github.com/go-nunu/nunu-layout-advanced/internal/model"
	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
	"github.com/go-nunu/nunu-layout-advanced/internal/service"
	"github.com/hashicorp/go-multierror"
	"go.uber.org/zap"
)

// WafTask defines the five independent WAF-related scheduled tasks.
type WafTask interface {
	// SynchronizationTime (task 1) synchronizes the expiry time of plans
	// that are about to expire (within 1 day).
	SynchronizationTime(ctx context.Context) error
	// StopPlan (task 2) stops all plans that have already expired.
	StopPlan(ctx context.Context) error
	// RecoverRecentPlan (task 3) restores plans that were renewed within
	// 7 days of expiring.
	RecoverRecentPlan(ctx context.Context) error
	// CleanUpStaleRecords (task 4) cleans up records that expired more than
	// 7 days ago and still have not been renewed.
	CleanUpStaleRecords(ctx context.Context) error
	// RecoverStalePlan (task 5) restores plans that were renewed only after
	// more than 7 days had passed.
	RecoverStalePlan(ctx context.Context) error
}

// =================================================================
// =================== Struct and constructor ======================
// =================================================================

// NewWafTask wires the given repositories and services into a WafTask
// implementation. Every argument is stored as-is on the returned value.
func NewWafTask(
	webForWardingRep repository.WebForwardingRepository,
	tcpforwardingRep repository.TcpforwardingRepository,
	udpForWardingRep repository.UdpForWardingRepository,
	cdn service.CdnService,
	hostRep repository.HostRepository,
	globalLimitRep repository.GlobalLimitRepository,
	expiredRep repository.ExpiredRepository,
	task *Task,
	gatewayGroupIpRep repository.GateWayGroupIpRepository,
	tcp service.TcpforwardingService,
	udp service.UdpForWardingService,
	web service.WebForwardingService,
) WafTask {
	return &wafTask{
		Task:              task,
		webForWardingRep:  webForWardingRep,
		tcpforwardingRep:  tcpforwardingRep,
		udpForWardingRep:  udpForWardingRep,
		cdn:               cdn,
		hostRep:           hostRep,
		globalLimitRep:    globalLimitRep,
		expiredRep:        expiredRep,
		gatewayGroupIpRep: gatewayGroupIpRep,
		tcp:               tcp,
		udp:               udp,
		web:               web,
	}
}

// wafTask is the WafTask implementation. It embeds *Task (which supplies the
// logger used below) and holds the repositories/services injected through
// NewWafTask.
type wafTask struct {
	*Task
	webForWardingRep  repository.WebForwardingRepository
	tcpforwardingRep  repository.TcpforwardingRepository
	udpForWardingRep  repository.UdpForWardingRepository
	cdn               service.CdnService
	hostRep           repository.HostRepository
	globalLimitRep    repository.GlobalLimitRepository
	expiredRep        repository.ExpiredRepository
	gatewayGroupIpRep repository.GateWayGroupIpRepository
	tcp               service.TcpforwardingService
	udp               service.UdpForWardingService
	web               service.WebForwardingService
}

const (
	// OneDaysInSeconds is the number of seconds in one day.
	OneDaysInSeconds = 1 * 24 * 60 * 60
	// SevenDaysInSeconds is the number of seconds in seven days.
	SevenDaysInSeconds = 7 * 24 * 60 * 60
)

// RenewalRequest identifies a host/plan pair together with its expiry time.
// ExpiredAt is compared against Unix seconds (see _recoverPlans).
type RenewalRequest struct {
	HostId    int
	PlanId    int
	ExpiredAt int64
}

// =================================================================
// =================== Core helpers ================================
// =================================================================
// (wrapTaskError, getCdnWebIdsByHostIds, setCdnWebsitesState and
// executeRenewalActions are unchanged; their bodies are elided here.)

// wrapTaskError is expected to decorate err with the task name and step.
// NOTE(review): the body is elided in this view; confirm that the real
// implementation returns nil when err is nil, since callers pass possibly-nil
// errors straight through it.
func (t *wafTask) wrapTaskError(taskName, step string, err error) error {
	/* ... */
	return nil
}

// getCdnWebIdsByHostIds presumably resolves the CDN website IDs that belong
// to the given hosts. Body elided in this view.
func (t *wafTask) getCdnWebIdsByHostIds(ctx context.Context, hostIds []int) ([]int, error) {
	/* ... */
	return nil, nil
}

// setCdnWebsitesState presumably enables or disables the given CDN websites.
// Body elided in this view.
func (t *wafTask) setCdnWebsitesState(ctx context.Context, ids []int, enable bool) error {
	/* ... */
	return nil
}

// executeRenewalActions presumably applies the given renewal requests.
// Body elided in this view.
func (t *wafTask) executeRenewalActions(ctx context.Context, reqs []RenewalRequest) error {
	/* ... */
	return nil
}

// =================================================================
// =================== 1. Lookup and decision layer ================
// =================================================================
// (findPlansNeedingSync, findAllCurrentlyExpiredWAFPlans,
// findRecentlyExpiredWAFPlans and findStaleWAFPlans are unchanged; their
// bodies are elided here.)

// findPlansNeedingSync presumably returns the renewal requests for those
// limits whose expiry needs synchronizing. Body elided in this view.
func (t *wafTask) findPlansNeedingSync(ctx context.Context, wafLimits []model.GlobalLimit) ([]RenewalRequest, error) {
	/* ... */
	return nil, nil
}

// findAllCurrentlyExpiredWAFPlans presumably returns every WAF plan that is
// currently expired. Body elided in this view.
func (t *wafTask) findAllCurrentlyExpiredWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) {
	/* ... */
	return nil, nil
}

// findRecentlyExpiredWAFPlans presumably returns WAF plans that expired
// within the last 7 days. Body elided in this view.
func (t *wafTask) findRecentlyExpiredWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) {
	/* ... */
	return nil, nil
}

// findStaleWAFPlans presumably returns WAF plans that expired more than
// 7 days ago. Body elided in this view.
func (t *wafTask) findStaleWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) {
	/* ... */
	return nil, nil
}

// =================================================================
// ============== 2. Business execution and public API layer =======
// =================================================================
// (SynchronizationTime and StopPlan are unchanged; their bodies are elided
// here.)

// SynchronizationTime implements WafTask. Body elided in this view.
func (t *wafTask) SynchronizationTime(ctx context.Context) error {
	/* ... */
	return nil
}

// StopPlan implements WafTask. Body elided in this view.
func (t *wafTask) StopPlan(ctx context.Context) error {
	/* ... */
	return nil
}

// _recoverPlans is the single, reusable plan-recovery flow shared by
// RecoverRecentPlan and RecoverStalePlan.
//
// It asks findPlansNeedingSync which of limitsToCheck need syncing, keeps only
// the requests whose ExpiredAt is still in the future, and for those hosts:
// re-enables their CDN websites, executes the renewal actions, and removes the
// host IDs from the Redis list named by redisListKey. The enable, renewal and
// Redis steps are all attempted even if an earlier one fails; their errors are
// aggregated and passed through wrapTaskError.
func (t *wafTask) _recoverPlans(ctx context.Context, limitsToCheck []model.GlobalLimit, taskName string, redisListKey repository.PlanListType) error {
	if len(limitsToCheck) == 0 {
		return nil
	}

	requestsToSync, err := t.findPlansNeedingSync(ctx, limitsToCheck)
	if err != nil {
		return t.wrapTaskError(taskName, "决策检查续费状态", err)
	}

	// Evaluate "now" once so every request is judged against the same instant.
	now := time.Now().Unix()
	finalRecoveryRequests := make([]RenewalRequest, 0, len(requestsToSync))
	for _, req := range requestsToSync {
		if req.ExpiredAt > now {
			finalRecoveryRequests = append(finalRecoveryRequests, req)
		}
	}
	if len(finalRecoveryRequests) == 0 {
		t.logger.Info("在检查范围内未发现已续费的套餐", zap.String("task", taskName))
		return nil
	}

	t.logger.Info("开始恢复已续费的WAF服务", zap.String("task", taskName), zap.Int("数量", len(finalRecoveryRequests)))

	// The Redis lists are keyed by host ID (as int64), so build both views of
	// the host IDs in one pass.
	hostIDs := make([]int, 0, len(finalRecoveryRequests))
	redisHostIDs := make([]int64, 0, len(finalRecoveryRequests))
	for _, req := range finalRecoveryRequests {
		hostIDs = append(hostIDs, req.HostId)
		redisHostIDs = append(redisHostIDs, int64(req.HostId))
	}

	var allErrors *multierror.Error

	webIds, err := t.getCdnWebIdsByHostIds(ctx, hostIDs)
	if err != nil {
		allErrors = multierror.Append(allErrors, fmt.Errorf("获取webId失败: %w", err))
	} else if err := t.setCdnWebsitesState(ctx, webIds, true); err != nil {
		allErrors = multierror.Append(allErrors, fmt.Errorf("启用web服务失败: %w", err))
	}

	if err := t.executeRenewalActions(ctx, finalRecoveryRequests); err != nil {
		allErrors = multierror.Append(allErrors, fmt.Errorf("同步续费信息失败: %w", err))
	}

	// Remove the markers from the given Redis list (ClosedPlansList or
	// ExpiringSoonPlansList).
	// NOTE(review): the markers are removed even when the steps above failed;
	// confirm that a later run can still pick these hosts up for recovery.
	if err := t.expiredRep.RemovePlans(ctx, redisListKey, redisHostIDs...); err != nil {
		allErrors = multierror.Append(allErrors, fmt.Errorf("从Redis列表 '%s' 移除标记失败: %w", redisListKey, err))
	}

	return t.wrapTaskError(taskName, "执行恢复", allErrors.ErrorOrNil())
}

// RecoverRecentPlan (task 3) restores plans that were renewed within 7 days.
// It feeds the result of findRecentlyExpiredWAFPlans into _recoverPlans and
// clears the recovered hosts from ClosedPlansList.
func (t *wafTask) RecoverRecentPlan(ctx context.Context) error {
	taskName := "恢复近期到期套餐"

	recentlyExpiredLimits, err := t.findRecentlyExpiredWAFPlans(ctx)
	if err != nil {
		return t.wrapTaskError(taskName, "查找近期到期记录", err)
	}
	return t._recoverPlans(ctx, recentlyExpiredLimits, taskName, repository.ClosedPlansList)
}

// CleanUpStaleRecords (task 4) cleans up records that expired more than 7 days
// ago and still have not been renewed.
//
// Candidates come from findStaleWAFPlans. Hosts already present in
// ExpiringSoonPlansList (used here as the "already cleaned" list) are skipped,
// as are hosts whose expiry time reported by hostRep is still in the future.
// The remaining plans are cleaned concurrently with a bounded number of
// workers; per-host errors are aggregated and passed through wrapTaskError.
func (t *wafTask) CleanUpStaleRecords(ctx context.Context) error {
	// Upper bound on simultaneous cleanups so that a large backlog cannot
	// flood Redis, the database and the forwarding services at once.
	const maxConcurrentCleanups = 16

	taskName := "清理陈旧记录"

	// 1. Fetch all stale records from the database as candidates.
	candidateLimits, err := t.findStaleWAFPlans(ctx)
	if err != nil {
		return t.wrapTaskError(taskName, "查找陈旧记录", err)
	}
	if len(candidateLimits) == 0 {
		return nil
	}

	// 2. Idempotency check: drop records already marked as "cleaned".
	// In this project ExpiringSoonPlansList doubles as the cleaned list.
	var uncleanedLimits []model.GlobalLimit
	for _, limit := range candidateLimits {
		isAlreadyCleaned, err := t.expiredRep.IsPlanInList(ctx, repository.ExpiringSoonPlansList, int64(limit.HostId))
		if err != nil {
			// Best effort: a host whose state cannot be read is skipped for
			// this run rather than failing the whole task.
			t.logger.Error("检查Redis清理状态失败,跳过", zap.String("task", taskName), zap.Int("hostId", limit.HostId), zap.Error(err))
			continue
		}
		if !isAlreadyCleaned {
			uncleanedLimits = append(uncleanedLimits, limit)
		}
	}
	if len(uncleanedLimits) == 0 {
		t.logger.Info("没有需要清理的新套餐(可能均已清理)", zap.String("task", taskName))
		return nil
	}

	// 3. Batch-load the real expiry time of every uncleaned record.
	uncleanedHostIds := make([]int, len(uncleanedLimits))
	for i, limit := range uncleanedLimits {
		uncleanedHostIds[i] = limit.HostId
	}
	hostExpirations, err := t.hostRep.GetExpireTimeByHostId(ctx, uncleanedHostIds)
	if err != nil {
		return t.wrapTaskError(taskName, "批量获取主机到期时间", err)
	}
	hostExpiredMap := make(map[int]int64, len(hostExpirations))
	for _, h := range hostExpirations {
		hostExpiredMap[h.HostId] = h.ExpiredAt
	}

	// 4. Decide: keep only records that were not renewed at the last moment.
	// A host with no expiry entry is treated as still expired.
	now := time.Now().Unix()
	plansToClean := make([]model.GlobalLimit, 0, len(uncleanedLimits))
	for _, limit := range uncleanedLimits {
		hostExpiredTime, ok := hostExpiredMap[limit.HostId]
		if !ok || hostExpiredTime <= now {
			plansToClean = append(plansToClean, limit)
		}
	}
	if len(plansToClean) == 0 {
		t.logger.Info("没有长期未续费的记录需要清理(可能均已续费)", zap.String("task", taskName))
		return nil
	}

	// 5. Run the cleanups concurrently, at most maxConcurrentCleanups at once.
	t.logger.Info("开始清理长期未续费的WAF记录", zap.String("task", taskName), zap.Int("数量", len(plansToClean)))

	var wg sync.WaitGroup
	// Buffered to len(plansToClean) so workers never block on reporting.
	errChan := make(chan error, len(plansToClean))
	sem := make(chan struct{}, maxConcurrentCleanups)

	for _, limit := range plansToClean {
		sem <- struct{}{} // acquire a worker slot before spawning
		wg.Add(1)
		go func(l model.GlobalLimit) {
			defer wg.Done()
			defer func() { <-sem }()
			if err := t.executeSinglePlanCleanup(ctx, l); err != nil {
				errChan <- fmt.Errorf("清理hostId %d 失败: %w", l.HostId, err)
			}
		}(limit)
	}
	wg.Wait()
	close(errChan)

	var allErrors *multierror.Error
	for err := range errChan {
		allErrors = multierror.Append(allErrors, err)
	}
	return t.wrapTaskError(taskName, "执行清理", allErrors.ErrorOrNil())
}

// executeSinglePlanCleanup performs the complete cleanup for one plan; it is
// a separate method so that CleanUpStaleRecords can call it concurrently.
//
// It removes the host from ClosedPlansList and deletes its TCP forwarding
// rules, collecting any errors. Only if none of those steps failed does it
// reset the host's global limit in the database, and only if that reset
// succeeds does it add the host to ExpiringSoonPlansList (the "cleaned"
// marker). A failed cleanup is therefore never marked as done and will be
// retried on the next run.
func (t *wafTask) executeSinglePlanCleanup(ctx context.Context, limit model.GlobalLimit) error {
	var allErrors *multierror.Error
	hostID := int64(limit.HostId)

	// Remove from the "stopped" list, since the host is about to be archived
	// into the "cleaned" list.
	if err := t.expiredRep.RemovePlans(ctx, repository.ClosedPlansList, hostID); err != nil {
		allErrors = multierror.Append(allErrors, err)
	}

	// Delete the associated forwarding rules.
	tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, limit.HostId)
	if err != nil {
		allErrors = multierror.Append(allErrors, err)
	} else if len(tcpIds) > 0 {
		if err := t.tcp.DeleteTcpForwarding(ctx, v1.DeleteTcpForwardingRequest{Ids: tcpIds, HostId: limit.HostId}); err != nil {
			allErrors = multierror.Append(allErrors, err)
		}
	}

	// ... the logic deleting UDP and Web rules is unchanged (elided here).

	// Only when every step above succeeded do we run the final database
	// update and write the Redis marker.
	if err := allErrors.ErrorOrNil(); err != nil {
		return err
	}

	// Perform the database "reset" of the host's global limit.
	err = t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
		GatewayGroupId: 0,
		HostId:         limit.HostId,
		State:          true,
	})
	if err != nil {
		// Do not mark the host as cleaned: the idempotency filter in
		// CleanUpStaleRecords would otherwise skip it forever and the reset
		// would never be retried.
		return multierror.Append(allErrors, err).ErrorOrNil()
	}

	// Idempotency guarantee: mark this host as "cleaned" by adding it to
	// ExpiringSoonPlansList.
	if err := t.expiredRep.AddPlans(ctx, repository.ExpiringSoonPlansList, hostID); err != nil {
		allErrors = multierror.Append(allErrors, fmt.Errorf("将hostId %d标记为已清理失败: %w", hostID, err))
	}

	return allErrors.ErrorOrNil()
}

// RecoverStalePlan (task 5) restores plans that were renewed only after more
// than 7 days had passed. It feeds the result of findStaleWAFPlans into
// _recoverPlans; because such plans have been through cleanup, the recovered
// hosts are removed from the "cleaned" list (ExpiringSoonPlansList).
func (t *wafTask) RecoverStalePlan(ctx context.Context) error {
	taskName := "恢复长期到期套餐"

	staleLimits, err := t.findStaleWAFPlans(ctx)
	if err != nil {
		return t.wrapTaskError(taskName, "查找陈旧记录", err)
	}
	return t._recoverPlans(ctx, staleLimits, taskName, repository.ExpiringSoonPlansList)
}