package task

import (
	"context"
	"fmt"
	"sync"
	"time"

	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
	"github.com/go-nunu/nunu-layout-advanced/internal/model"
	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
	waf2 "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
	"github.com/go-nunu/nunu-layout-advanced/internal/service/api/flexCdn"
	"github.com/go-nunu/nunu-layout-advanced/internal/service/api/waf"
	"github.com/hashicorp/go-multierror"
	"go.uber.org/zap"
)

// WafTask defines the five independent scheduled jobs that drive the WAF
// plan lifecycle. Every method returns an aggregated error describing each
// step that failed (nil when everything succeeded).
type WafTask interface {
	// SynchronizationTime copies the host's expiry time onto WAF records
	// selected with the SynchronousInSeconds offset whose expiry differs
	// from the host's.
	// NOTE(review): the original comment said "expiring within 1 day", but
	// SynchronousInSeconds is 7 days — confirm which window is intended.
	SynchronizationTime(ctx context.Context) error
	// StopPlan disables the CDN websites of every plan that has already expired.
	StopPlan(ctx context.Context) error
	// RecoverRecentPlan restores plans that expired within the last 7 days
	// and have since been renewed.
	RecoverRecentPlan(ctx context.Context) error
	// CleanUpStaleRecords tears down plans that expired more than 7 days ago
	// and are still not renewed.
	CleanUpStaleRecords(ctx context.Context) error
	// RecoverStalePlan restores plans that were renewed more than 7 days
	// after they expired.
	RecoverStalePlan(ctx context.Context) error
}

// =================================================================
// =================== Struct & constructor ========================
// =================================================================

// NewWafTask wires every repository and service the WAF jobs depend on.
func NewWafTask(
	webForWardingRep waf2.WebForwardingRepository,
	tcpforwardingRep waf2.TcpforwardingRepository,
	udpForWardingRep waf2.UdpForWardingRepository,
	cdn flexCdn.CdnService,
	hostRep repository.HostRepository,
	globalLimitRep waf2.GlobalLimitRepository,
	expiredRep repository.ExpiredRepository,
	task *Task,
	gatewayIpRep waf2.GatewayipRepository,
	tcp waf.TcpforwardingService,
	udp waf.UdpForWardingService,
	web waf.WebForwardingService,
	buildAoDun waf.BuildAudunService,
	zzyBgp waf.ZzybgpService,
) WafTask {
	return &wafTask{
		Task:             task,
		webForWardingRep: webForWardingRep,
		tcpforwardingRep: tcpforwardingRep,
		udpForWardingRep: udpForWardingRep,
		cdn:              cdn,
		hostRep:          hostRep,
		globalLimitRep:   globalLimitRep,
		expiredRep:       expiredRep,
		gatewayIpRep:     gatewayIpRep,
		tcp:              tcp,
		udp:              udp,
		web:              web,
		buildAoDun:       buildAoDun,
		zzyBgp:           zzyBgp,
	}
}

// wafTask implements WafTask. The embedded *Task provides shared task
// infrastructure (e.g. the logger used as t.logger).
type wafTask struct {
	*Task
	webForWardingRep waf2.WebForwardingRepository
	tcpforwardingRep waf2.TcpforwardingRepository
	udpForWardingRep waf2.UdpForWardingRepository
	cdn              flexCdn.CdnService
	hostRep          repository.HostRepository
	globalLimitRep   waf2.GlobalLimitRepository
	expiredRep       repository.ExpiredRepository
	gatewayIpRep     waf2.GatewayipRepository
	tcp              waf.TcpforwardingService
	udp              waf.UdpForWardingService
	web              waf.WebForwardingService
	buildAoDun       waf.BuildAudunService
	zzyBgp           waf.ZzybgpService
}

const (
	// SynchronousInSeconds is the offset SynchronizationTime passes to
	// GetGlobalLimitAlmostExpired (currently 7 days).
	SynchronousInSeconds = 7 * 24 * 60 * 60
	// SevenDaysInSeconds is the grace period separating "recently expired"
	// plans from "stale" ones.
	SevenDaysInSeconds = 7 * 24 * 60 * 60
)

// RenewalRequest asks for the WAF record of HostId to be updated to the
// given ExpiredAt (a Unix timestamp, compared against time.Now().Unix()).
type RenewalRequest struct {
	HostId    int
	ExpiredAt int64
}

// =================================================================
// =================== Core helpers ================================
// =================================================================

// wrapTaskError prefixes err with the task and step names so logs show where
// it happened. A nil err yields nil.
func (t *wafTask) wrapTaskError(taskName, step string, err error) error {
	if err == nil {
		return nil
	}
	return fmt.Errorf("执行[%s]-%s失败: %w", taskName, step, err)
}

// getCdnWebIdsByHostIds (formerly GetCdnWebId) collects the CDN website ids of
// all TCP, UDP and web forwarding rules belonging to hostIds.
//
// Every lookup is attempted even if an earlier one fails, so the returned ids
// may be partial while the error is non-nil.
func (t *wafTask) getCdnWebIdsByHostIds(ctx context.Context, hostIds []int) ([]int, error) {
	if len(hostIds) == 0 {
		return nil, nil
	}
	var (
		ids    []int
		result *multierror.Error
	)

	tcpIds, err := t.tcpforwardingRep.GetTcpAll(ctx, hostIds)
	if err != nil {
		result = multierror.Append(result, fmt.Errorf("查询TCP转发规则失败: %w", err))
	}
	ids = append(ids, tcpIds...)

	udpIds, err := t.udpForWardingRep.GetUdpAll(ctx, hostIds)
	if err != nil {
		result = multierror.Append(result, fmt.Errorf("查询UDP转发规则失败: %w", err))
	}
	ids = append(ids, udpIds...)

	webIds, err := t.webForWardingRep.GetWebAll(ctx, hostIds)
	if err != nil {
		result = multierror.Append(result, fmt.Errorf("查询Web转发规则失败: %w", err))
	}
	ids = append(ids, webIds...)

	return ids, result.ErrorOrNil()
}

// setCdnWebsitesState (formerly BanServer) switches a set of CDN websites on
// (enable=true) or off (enable=false). It starts one goroutine per id and
// returns every failure aggregated into a single error.
func (t *wafTask) setCdnWebsitesState(ctx context.Context, ids []int, enable bool) error {
	if len(ids) == 0 {
		return nil
	}
	var wg sync.WaitGroup
	errChan := make(chan error, len(ids)) // sized so no sender ever blocks
	wg.Add(len(ids))
	for _, id := range ids {
		go func(id int) {
			defer wg.Done()
			// NOTE(review): an earlier comment claimed this argument is "isBan"
			// (true = disable), which would require passing !enable. The method
			// name suggests "isOn", and the callers (StopPlan passes false,
			// recovery passes true) rely on that reading, so enable is passed
			// through unchanged — confirm against the flexCdn implementation.
			if err := t.cdn.EditWebIsOn(ctx, int64(id), enable); err != nil {
				errChan <- fmt.Errorf("设置webId %d 状态失败: %w", id, err)
			}
		}(id)
	}
	wg.Wait()
	close(errChan)

	var result *multierror.Error
	for err := range errChan {
		result = multierror.Append(result, err)
	}
	return result.ErrorOrNil()
}

// executeRenewalActions (formerly EditExpired) writes each request's ExpiredAt
// onto the GlobalLimit row of its HostId and sets State to true. Updates run
// concurrently and all failures are aggregated.
//
// Only the database is touched here; no CDN API is called.
func (t *wafTask) executeRenewalActions(ctx context.Context, reqs []RenewalRequest) error {
	if len(reqs) == 0 {
		return nil
	}
	var (
		wg        sync.WaitGroup
		mu        sync.Mutex // guards allErrors
		allErrors *multierror.Error
	)
	wg.Add(len(reqs))
	for _, req := range reqs {
		go func(r RenewalRequest) {
			defer wg.Done()
			limit := &model.GlobalLimit{HostId: r.HostId, ExpiredAt: r.ExpiredAt, State: true}
			if err := t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, limit); err != nil {
				mu.Lock()
				allErrors = multierror.Append(allErrors, fmt.Errorf("更新hostId %d 到期时间失败: %w", r.HostId, err))
				mu.Unlock()
			}
		}(req)
	}
	wg.Wait()
	return allErrors.ErrorOrNil()
}
数据查找与决策层 ========================== // ================================================================= // findPlansNeedingSync (原findMismatchedExpirations) 检查WAF和Host的到期时间差异,返回需要同步的请求 func (t *wafTask) findPlansNeedingSync(ctx context.Context, wafLimits []model.GlobalLimit) ([]RenewalRequest, error) { if len(wafLimits) == 0 { return nil, nil } wafExpiredMap := make(map[int]int64, len(wafLimits)) var hostIds []int for _, limit := range wafLimits { hostIds = append(hostIds, limit.HostId) wafExpiredMap[limit.HostId] = limit.ExpiredAt } hostExpirations, err := t.hostRep.GetExpireTimeByHostId(ctx, hostIds) if err != nil { return nil, fmt.Errorf("获取主机到期时间失败: %w", err) } hostExpiredMap := make(map[int]int64, len(hostExpirations)) for _, h := range hostExpirations { hostExpiredMap[h.HostId] = h.ExpiredAt } var renewalRequests []RenewalRequest for hostId, wafExpiredTime := range wafExpiredMap { hostTime, ok := hostExpiredMap[hostId] if !ok || hostTime != wafExpiredTime { renewalRequests = append(renewalRequests, RenewalRequest{HostId: hostId, ExpiredAt: hostTime}) } } return renewalRequests, nil } // findAllCurrentlyExpiredWAFPlans (原findAllCurrentlyExpiredPlans) 查找所有当前时间点已经到期的WAF记录 func (t *wafTask) findAllCurrentlyExpiredWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) { return t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, 0) } // findRecentlyExpiredWAFPlans (原findRecentlyExpiredPlans) 查找在过去7天内到期的WAF记录 func (t *wafTask) findRecentlyExpiredWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) { sevenDaysAgo := time.Now().Add(-7 * 24 * time.Hour).Unix() now := time.Now().Unix() return t.globalLimitRep.GetGlobalLimitsByExpirationRange(ctx, sevenDaysAgo, now) } // findStaleWAFPlans (原findStaleExpiredPlans) 查找7天前或更早就已到期的WAF记录 func (t *wafTask) findStaleWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) { sevenDaysAgoOffset := int64(-1 * SevenDaysInSeconds) return t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, sevenDaysAgoOffset) } // 
================================================================= // ============== 2. 业务执行与公共API层 =========================== // ================================================================= // SynchronizationTime 同步即将到期(1天内)的套餐时间 func (t *wafTask) SynchronizationTime(ctx context.Context) error { taskName := "同步到期时间" wafLimits, err := t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, SynchronousInSeconds) if err != nil { return t.wrapTaskError(taskName, "查找失败", err) } if len(wafLimits) == 0 { return nil } renewalRequests, err := t.findPlansNeedingSync(ctx, wafLimits) if err != nil { return t.wrapTaskError(taskName, "决策失败", err) } if len(renewalRequests) > 0 { t.logger.Info("发现记录需要同步到期时间", zap.String("task", taskName), zap.Int("数量", len(renewalRequests)), zap.Any("套餐内容", renewalRequests)) return t.wrapTaskError(taskName, "执行同步", t.executeRenewalActions(ctx, renewalRequests)) } return nil } // StopPlan 停止所有已到期的套餐 func (t *wafTask) StopPlan(ctx context.Context) error { taskName := "停止到期套餐" // 1. 查找所有理论上已到期的记录 expiredLimits, err := t.findAllCurrentlyExpiredWAFPlans(ctx) if err != nil { return t.wrapTaskError(taskName, "查找失败", err) } if len(expiredLimits) == 0 { return nil } // 2. 决策 - 第1步:检查这些记录中是否已有续费但未同步的 renewalRequests, err := t.findPlansNeedingSync(ctx, expiredLimits) if err != nil { return t.wrapTaskError(taskName, "决策检查续费", err) } renewedHostIds := make(map[int]struct{}, len(renewalRequests)) for _, req := range renewalRequests { if req.ExpiredAt > time.Now().Unix() { renewedHostIds[req.HostId] = struct{}{} } } // 2. 
决策 - 第2步:筛选出真正需要停止的记录 var plansToClose []model.GlobalLimit for _, limit := range expiredLimits { if _, found := renewedHostIds[limit.HostId]; found { t.logger.Info("发现已到期但刚续费的套餐,跳过停止操作", zap.String("task", taskName), zap.Int("hostId", limit.HostId)) continue } isClosed, err := t.expiredRep.IsPlanInList(ctx, repository.ClosedPlansList, int64(limit.HostId)) if err != nil { t.logger.Error("决策[停止]:检查Redis套餐状态失败", zap.String("task", taskName), zap.Int("hostId", limit.HostId), zap.Error(err)) continue } if !isClosed { plansToClose = append(plansToClose, limit) } } if len(plansToClose) == 0 { t.logger.Info("没有需要停止的套餐(可能均已续费或已在停止列表)", zap.String("task", taskName)) return nil } // 3. 执行停止操作 t.logger.Info("开始关闭到期的WAF服务", zap.String("task", taskName), zap.Int("数量", len(plansToClose)), zap.Any("套餐内容", renewalRequests)) var hostIds []int for _, limit := range plansToClose { hostIds = append(hostIds, limit.HostId) } var allErrors *multierror.Error webIds, err := t.getCdnWebIdsByHostIds(ctx, hostIds) if err != nil { allErrors = multierror.Append(allErrors, fmt.Errorf("获取cdn_web_id失败: %w", err)) } else { if err := t.setCdnWebsitesState(ctx, webIds, false); err != nil { // enable=false allErrors = multierror.Append(allErrors, fmt.Errorf("禁用服务失败: %w", err)) } } closedPlanIds := make([]int64, len(hostIds)) for i, id := range hostIds { closedPlanIds[i] = int64(id) } if err := t.expiredRep.AddPlans(ctx, repository.ClosedPlansList, closedPlanIds...); err != nil { allErrors = multierror.Append(allErrors, fmt.Errorf("标记为已关闭失败: %w", err)) } return t.wrapTaskError(taskName, "执行停止", allErrors.ErrorOrNil()) } // _recoverPlans 是一个统一的、可重用的套餐恢复流程 func (t *wafTask) _recoverPlans(ctx context.Context, limitsToCheck []model.GlobalLimit, taskName string, redisListKey repository.PlanListType) error { if len(limitsToCheck) == 0 { return nil } requestsToSync, err := t.findPlansNeedingSync(ctx, limitsToCheck) if err != nil { return t.wrapTaskError(taskName, "决策检查续费状态", err) } var finalRecoveryRequests 
[]RenewalRequest for _, req := range requestsToSync { if req.ExpiredAt > time.Now().Unix() { finalRecoveryRequests = append(finalRecoveryRequests, req) } } if len(finalRecoveryRequests) == 0 { t.logger.Info("在检查范围内未发现已续费的套餐", zap.String("task", taskName)) return nil } t.logger.Info("开始恢复已续费的WAF服务", zap.String("task", taskName), zap.Int("数量", len(finalRecoveryRequests)), zap.Any("套餐内容", finalRecoveryRequests)) var hostIdsToRecover []int for _, req := range finalRecoveryRequests { hostIdsToRecover = append(hostIdsToRecover, req.HostId) } var allErrors *multierror.Error webIds, err := t.getCdnWebIdsByHostIds(ctx, hostIdsToRecover) if err != nil { allErrors = multierror.Append(allErrors, fmt.Errorf("获取webId失败: %w", err)) } else { if err := t.setCdnWebsitesState(ctx, webIds, true); err != nil { // enable=true allErrors = multierror.Append(allErrors, fmt.Errorf("启用web服务失败: %w", err)) } } if err := t.executeRenewalActions(ctx, finalRecoveryRequests); err != nil { allErrors = multierror.Append(allErrors, fmt.Errorf("同步续费信息失败: %w", err)) } planIdsToRecover := make([]int64, len(hostIdsToRecover)) for i, id := range hostIdsToRecover { planIdsToRecover[i] = int64(id) } // 从指定的Redis列表中移除标记 (ClosedPlansList 或 ExpiringSoonPlansList) if err := t.expiredRep.RemovePlans(ctx, redisListKey, planIdsToRecover...); err != nil { allErrors = multierror.Append(allErrors, fmt.Errorf("从Redis列表 '%s' 移除标记失败: %w", redisListKey, err)) } return t.wrapTaskError(taskName, "执行恢复", allErrors.ErrorOrNil()) } // 3. RecoverRecentPlan 恢复7天内续费的套餐 func (t *wafTask) RecoverRecentPlan(ctx context.Context) error { taskName := "恢复近期到期套餐" recentlyExpiredLimits, err := t.findRecentlyExpiredWAFPlans(ctx) if err != nil { return t.wrapTaskError(taskName, "查找近期到期记录", err) } return t._recoverPlans(ctx, recentlyExpiredLimits, taskName, repository.ClosedPlansList) } // 4. CleanUpStaleRecords 清理过期超过7天且仍未续费的记录 func (t *wafTask) CleanUpStaleRecords(ctx context.Context) error { taskName := "清理陈旧记录" // 1. 
从数据库查找所有陈旧的记录作为候选 candidateLimits, err := t.findStaleWAFPlans(ctx) if err != nil { return t.wrapTaskError(taskName, "查找陈旧记录", err) } if len(candidateLimits) == 0 { return nil } // 2. [CORRECTION] 幂等性检查: 过滤掉那些已经被标记为“已清理”的记录 // 根据您的定义,`ExpiringSoonPlansList` 就是已清理列表。 var uncleanedLimits []model.GlobalLimit for _, limit := range candidateLimits { isAlreadyCleaned, err := t.expiredRep.IsPlanInList(ctx, repository.ExpiringSoonPlansList, int64(limit.HostId)) if err != nil { t.logger.Error("检查Redis清理状态失败,跳过", zap.String("task", taskName), zap.Int("hostId", limit.HostId), zap.Error(err)) continue } if !isAlreadyCleaned { uncleanedLimits = append(uncleanedLimits, limit) } } if len(uncleanedLimits) == 0 { t.logger.Info("没有需要清理的新套餐(可能均已清理)", zap.String("task", taskName)) return nil } // 3. [性能优化] 批量获取未清理记录的真实到期时间 uncleanedHostIds := make([]int, len(uncleanedLimits)) for i, limit := range uncleanedLimits { uncleanedHostIds[i] = limit.HostId } hostExpirations, err := t.hostRep.GetExpireTimeByHostId(ctx, uncleanedHostIds) if err != nil { return t.wrapTaskError(taskName, "批量获取主机到期时间", err) } hostExpiredMap := make(map[int]int64, len(hostExpirations)) for _, h := range hostExpirations { hostExpiredMap[h.HostId] = h.ExpiredAt } // 4. 决策:筛选出最终需要清理的记录(未在最后一刻续费) var plansToClean []model.GlobalLimit now := time.Now().Unix() for _, limit := range uncleanedLimits { hostExpiredTime, ok := hostExpiredMap[limit.HostId] // 清理条件:主机记录不存在,或者主机记录的到期时间是过去时 if !ok || hostExpiredTime <= now { plansToClean = append(plansToClean, limit) } } if len(plansToClean) == 0 { t.logger.Info("没有长期未续费的记录需要清理(可能均已续费)", zap.String("task", taskName)) return nil } // 5. 
并发执行清理操作 t.logger.Info("开始清理长期未续费的WAF记录", zap.String("task", taskName), zap.Int("数量", len(plansToClean)), zap.Any("套餐内容", plansToClean)) var wg sync.WaitGroup errChan := make(chan error, len(plansToClean)) wg.Add(len(plansToClean)) for _, limit := range plansToClean { go func(l model.GlobalLimit) { defer wg.Done() if err := t.executeSinglePlanCleanup(ctx, l); err != nil { errChan <- fmt.Errorf("清理hostId %d 失败: %w", l.HostId, err) } }(limit) } wg.Wait() close(errChan) var allErrors *multierror.Error for err := range errChan { allErrors = multierror.Append(allErrors, err) } return t.wrapTaskError(taskName, "执行清理", allErrors.ErrorOrNil()) } // executeSinglePlanCleanup 执行对单个套餐的完整清理操作,方便并发调用 func (t *wafTask) executeSinglePlanCleanup(ctx context.Context, limit model.GlobalLimit) error { var allErrors *multierror.Error hostId := int64(limit.HostId) // 从“停止列表”中移除,因为它即将被归档到“已清理列表” if err := t.expiredRep.RemovePlans(ctx, repository.ClosedPlansList, hostId); err != nil { allErrors = multierror.Append(allErrors, err) } // 删除关联的转发规则... 
tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, limit.HostId) if err != nil { allErrors = multierror.Append(allErrors, err) } else if len(tcpIds) > 0 { if err := t.tcp.DeleteTcpForwarding(ctx, v1.DeleteTcpForwardingRequest{Ids: tcpIds, HostId: limit.HostId,Uid: limit.Uid}); err != nil { allErrors = multierror.Append(allErrors, err) } } udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, limit.HostId) if err != nil { allErrors = multierror.Append(allErrors, err) } else if len(udpIds) > 0 { if err := t.udp.DeleteUdpForwarding(ctx, v1.DeleteUdpForwardingRequest{Ids: udpIds, HostId: limit.HostId,Uid: limit.Uid}); err != nil { allErrors = multierror.Append(allErrors, err) } } webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, limit.HostId) if err != nil { allErrors = multierror.Append(allErrors, err) } else if len(webIds) > 0 { if err := t.web.DeleteWebForwarding(ctx, v1.DeleteWebForwardingRequest{Ids: webIds, HostId: limit.HostId,Uid: limit.Uid}); err != nil { allErrors = multierror.Append(allErrors, err) } } // 重置防护 err = t.zzyBgp.SetDefense(ctx, hostId, 10) if err != nil { return err } // 清除小防火墙带宽限制 if err := t.buildAoDun.Bandwidth(ctx, hostId, "del"); err != nil { allErrors = multierror.Append(allErrors, err) } // 只有在上述所有步骤都没有出错的情况下,才执行最终的数据库更新和Redis标记 if allErrors.ErrorOrNil() == nil { err := t.gatewayIpRep.CleanIPByHostId(ctx, []int64{hostId}) if err != nil { allErrors = multierror.Append(allErrors, err) } // [CORRECTION] 幂等性保障:将此hostId标记为“已清理”,即添加到 `ExpiringSoonPlansList` if err := t.expiredRep.AddPlans(ctx, repository.ExpiringSoonPlansList, hostId); err != nil { allErrors = multierror.Append(allErrors, fmt.Errorf("将hostId %d标记为已清理失败: %w", hostId, err)) } } return allErrors.ErrorOrNil() } // 5. 
RecoverStalePlan 恢复超过7天后才续费的套餐 func (t *wafTask) RecoverStalePlan(ctx context.Context) error { taskName := "恢复长期到期套餐" staleLimits, err := t.findStaleWAFPlans(ctx) if err != nil { return t.wrapTaskError(taskName, "查找陈旧记录", err) } // [CORRECTION] 当恢复一个被清理过的套餐时,需要从“已清理列表”(`ExpiringSoonPlansList`)中移除它。 return t._recoverPlans(ctx, staleLimits, taskName, repository.ExpiringSoonPlansList) }