package task import ( "context" "fmt" v1 "github.com/go-nunu/nunu-layout-advanced/api/v1" "github.com/go-nunu/nunu-layout-advanced/internal/model" "github.com/go-nunu/nunu-layout-advanced/internal/repository" "github.com/go-nunu/nunu-layout-advanced/internal/service" "github.com/hashicorp/go-multierror" "go.uber.org/zap" "sync" "time" ) // WafTask 定义了WAF相关的五个独立定时任务接口 type WafTask interface { // 1. 同步即将到期(1天内)的套餐时间 SynchronizationTime(ctx context.Context) error // 2. 停止所有已到期的套餐 StopPlan(ctx context.Context) error // 3. 恢复7天内续费的套餐 RecoverRecentPlan(ctx context.Context) error // 4. 清理过期超过7天且仍未续费的记录 CleanUpStaleRecords(ctx context.Context) error // 5. 恢复超过7天后才续费的套餐 RecoverStalePlan(ctx context.Context) error } func NewWafTask( webForWardingRep repository.WebForwardingRepository, tcpforwardingRep repository.TcpforwardingRepository, udpForWardingRep repository.UdpForWardingRepository, cdn service.CdnService, hostRep repository.HostRepository, globalLimitRep repository.GlobalLimitRepository, expiredRep repository.ExpiredRepository, task *Task, gatewayGroupIpRep repository.GateWayGroupIpRepository, tcp service.TcpforwardingService, udp service.UdpForWardingService, web service.WebForwardingService, ) WafTask { return &wafTask{ Task: task, webForWardingRep: webForWardingRep, tcpforwardingRep: tcpforwardingRep, udpForWardingRep: udpForWardingRep, cdn: cdn, hostRep: hostRep, globalLimitRep: globalLimitRep, expiredRep: expiredRep, gatewayGroupIpRep: gatewayGroupIpRep, tcp: tcp, udp: udp, web: web, } } type wafTask struct { *Task webForWardingRep repository.WebForwardingRepository tcpforwardingRep repository.TcpforwardingRepository udpForWardingRep repository.UdpForWardingRepository cdn service.CdnService hostRep repository.HostRepository globalLimitRep repository.GlobalLimitRepository expiredRep repository.ExpiredRepository gatewayGroupIpRep repository.GateWayGroupIpRepository tcp service.TcpforwardingService udp service.UdpForWardingService web service.WebForwardingService } const ( // 1天对应的秒数 OneDaysInSeconds = 1 * 24 * 60 * 60 // 7天对应的秒数 SevenDaysInSeconds = 7 * 24 * 60 * 60 ) // RenewalRequest 续费操作请求结构体 type RenewalRequest struct { HostId int PlanId int ExpiredAt int64 } // ================================================================= // =================== 原始辅助函数 (Helpers) ===================== // ================================================================= // 获取cdn web id func (t wafTask) GetCdnWebId(ctx context.Context,hostId int) ([]int, error) { tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, hostId) if err != nil { return nil, err } udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, hostId) if err != nil { return nil, err } webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, hostId) if err != nil { return nil, err } var ids []int ids = append(ids, tcpIds...) ids = append(ids, udpIds...) ids = append(ids, webIds...) return ids, nil } // BanServer 启用/禁用 网站 (并发执行) func (t wafTask) BanServer(ctx context.Context, ids []int, isBan bool) error { if len(ids) == 0 { return nil } var wg sync.WaitGroup errChan := make(chan error, len(ids)) wg.Add(len(ids)) for _, id := range ids { go func(id int) { defer wg.Done() if err := t.cdn.EditWebIsOn(ctx, int64(id), isBan); err != nil { errChan <- err } }(id) } wg.Wait() close(errChan) var result error for err := range errChan { result = multierror.Append(result, err) } return result } // EditExpired 统一的续费操作入口 func (t wafTask) EditExpired(ctx context.Context, reqs []RenewalRequest) error { if len(reqs) == 0 { return nil } var globalLimitUpdates []struct { hostId int; expiredAt int64 } var planRenewals []struct { planId int; expiredAt int64 } for _, req := range reqs { globalLimitUpdates = append(globalLimitUpdates, struct{ hostId int; expiredAt int64 }{req.HostId, req.ExpiredAt}) planRenewals = append(planRenewals, struct{ planId int; expiredAt int64 }{req.PlanId, req.ExpiredAt}) } var result *multierror.Error if err := t.editGlobalLimitState(ctx, globalLimitUpdates, true); err != nil { result = multierror.Append(result, err) } if err := t.renewCdnPlan(ctx, planRenewals); err != nil { result = multierror.Append(result, err) } return result.ErrorOrNil() } // editGlobalLimitState 内部函数,用于更新数据库中的状态和时间 func (t wafTask) editGlobalLimitState(ctx context.Context, req []struct { hostId int; expiredAt int64 }, state bool) error { var result *multierror.Error for _, v := range req { err := t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{HostId: v.hostId, ExpiredAt: v.expiredAt, State: state}) if err != nil { result = multierror.Append(result, err) } } return result.ErrorOrNil() } // renewCdnPlan 内部函数,用于调用CDN服务进行续费 func (t wafTask) renewCdnPlan(ctx context.Context, req []struct { planId int; expiredAt int64 }) error { var result *multierror.Error for _, v := range req { err := t.cdn.RenewPlan(ctx, v1.RenewalPlan{ UserPlanId: int64(v.planId), IsFree: true, DayTo: time.Unix(v.expiredAt, 0).Format("2006-01-02"), Period: "monthly", CountPeriod: 1, PeriodDayTo: time.Unix(v.expiredAt, 0).Format("2006-01-02"), }) if err != nil { result = multierror.Append(result, err) } } return result.ErrorOrNil() } // ================================================================= // =================== 1. 数据查找层 (Finders) ===================== // ================================================================= // findMismatchedExpirations 检查 WAF 和 Host 的到期时间差异。这是决策的核心。 func (t *wafTask) findMismatchedExpirations(ctx context.Context, wafLimits []model.GlobalLimit) ([]RenewalRequest, error) { if len(wafLimits) == 0 { return nil, nil } wafExpiredMap := make(map[int]int64, len(wafLimits)) wafPlanMap := make(map[int]int, len(wafLimits)) var hostIds []int for _, limit := range wafLimits { hostIds = append(hostIds, limit.HostId) wafExpiredMap[limit.HostId] = limit.ExpiredAt wafPlanMap[limit.HostId] = limit.RuleId } hostExpirations, err := t.hostRep.GetExpireTimeByHostId(ctx, hostIds) if err != nil { return nil, fmt.Errorf("获取主机到期时间失败: %w", err) } hostExpiredMap := make(map[int]int64, len(hostExpirations)) for _, h := range hostExpirations { hostExpiredMap[h.HostId] = h.ExpiredAt } var renewalRequests []RenewalRequest for hostId, wafExpiredTime := range wafExpiredMap { hostTime, ok := hostExpiredMap[hostId] if !ok || hostTime != wafExpiredTime { planId, planOk := wafPlanMap[hostId] if !planOk { t.logger.Warn("数据不一致:在waf_limits中找不到hostId对应的套餐ID", zap.Int("hostId", hostId)) continue } renewalRequests = append(renewalRequests, RenewalRequest{HostId: hostId, ExpiredAt: hostTime, PlanId: planId}) } } return renewalRequests, nil } // findAllCurrentlyExpiredPlans 查找所有当前时间点已经到期的WAF记录。 func (t *wafTask) findAllCurrentlyExpiredPlans(ctx context.Context) ([]model.GlobalLimit, error) { return t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, 0) } // findRecentlyExpiredPlans (精确查找) 查找在过去7天内到期的WAF记录。 func (t *wafTask) findRecentlyExpiredPlans(ctx context.Context) ([]model.GlobalLimit, error) { sevenDaysAgo := time.Now().Add(-7 * 24 * time.Hour).Unix() now := time.Now().Unix() return t.globalLimitRep.GetGlobalLimitsByExpirationRange(ctx, sevenDaysAgo, now) } // findStaleExpiredPlans (精确查找) 查找7天前或更早就已到期的WAF记录。 func (t *wafTask) findStaleExpiredPlans(ctx context.Context) ([]model.GlobalLimit, error) { sevenDaysAgoOffset := int64(-1 * SevenDaysInSeconds) return t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, sevenDaysAgoOffset) } // ================================================================= // =================== 2. 业务决策层 (Filters) ===================== // ================================================================= // filterCleanablePlans (精确决策) 从长期过期的列表中,筛选出确认未续费且需要被清理的记录。 func (t *wafTask) filterCleanablePlans(ctx context.Context, staleLimits []model.GlobalLimit) ([]model.GlobalLimit, error) { renewedStalePlans, err := t.findMismatchedExpirations(ctx, staleLimits) if err != nil { return nil, fmt.Errorf("决策[清理]: 检查续费状态失败: %w", err) } renewedHostIds := make(map[int]struct{}, len(renewedStalePlans)) for _, req := range renewedStalePlans { renewedHostIds[req.HostId] = struct{}{} } var plansToClean []model.GlobalLimit for _, limit := range staleLimits { if _, found := renewedHostIds[limit.HostId]; !found { plansToClean = append(plansToClean, limit) } } return plansToClean, nil } // ================================================================= // ============== 3. 业务执行层 (Executors & Public API) ============= // ================================================================= // executePlanRecovery (可重用) 负责恢复套餐的所有步骤。 func (t *wafTask) executePlanRecovery(ctx context.Context, renewalRequests []RenewalRequest, taskName string,key repository.PlanListType) error { t.logger.Info(fmt.Sprintf("开始执行[%s]套餐恢复流程", taskName), zap.Int("数量", len(renewalRequests))) var hostIds []int for _, req := range renewalRequests { hostIds = append(hostIds, req.HostId) } var allErrors *multierror.Error for _, v := range renewalRequests { webIds, err := t.GetCdnWebId(ctx, v.HostId) if err != nil { allErrors = multierror.Append(allErrors, fmt.Errorf("执行[%s]-获取webId失败: %w", taskName, err)) } if err := t.BanServer(ctx, webIds, true); err != nil { allErrors = multierror.Append(allErrors, fmt.Errorf("执行[%s]-封禁webId失败: %w", taskName, err)) } } if err := t.EditExpired(ctx, renewalRequests); err != nil { allErrors = multierror.Append(allErrors, fmt.Errorf("执行[%s]-同步续费信息失败: %w", taskName, err)) } planIdsToRecover := make([]int64, len(hostIds)) for i, id := range hostIds { planIdsToRecover[i] = int64(id) } if err := t.expiredRep.RemovePlans(ctx, key, planIdsToRecover...); err != nil { allErrors = multierror.Append(allErrors, fmt.Errorf("执行[%s]-移除Redis关闭标记失败: %w", taskName, err)) } return allErrors.ErrorOrNil() } // 1. SynchronizationTime 同步即将到期(1天内)的套餐时间 func (t *wafTask) SynchronizationTime(ctx context.Context) error { wafLimits, err := t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, OneDaysInSeconds) if err != nil { return fmt.Errorf("执行[同步]-查找失败: %w", err) } renewalRequests, err := t.findMismatchedExpirations(ctx, wafLimits) if err != nil { return fmt.Errorf("执行[同步]-决策失败: %w", err) } if len(renewalRequests) > 0 { t.logger.Info("发现记录需要同步到期时间。", zap.Int("数量", len(renewalRequests))) return t.EditExpired(ctx, renewalRequests) } return nil } // 2. StopPlan (已优化) 停止所有已到期的套餐 func (t *wafTask) StopPlan(ctx context.Context) error { // 1. 查找所有理论上已到期的记录 expiredLimits, err := t.findAllCurrentlyExpiredPlans(ctx) if err != nil { return fmt.Errorf("执行[停止]-查找失败: %w", err) } if len(expiredLimits) == 0 { return nil } // 2. 决策 - 第1步:检查这些记录中是否已有续费但未同步的 renewalRequests, err := t.findMismatchedExpirations(ctx, expiredLimits) if err != nil { return fmt.Errorf("执行[停止]-决策检查续费失败: %w", err) } renewedHostIds := make(map[int]struct{}, len(renewalRequests)) for _, req := range renewalRequests { renewedHostIds[req.HostId] = struct{}{} } // 2. 决策 - 第2步:筛选出真正需要停止的记录 var plansToClose []model.GlobalLimit for _, limit := range expiredLimits { if _, found := renewedHostIds[limit.HostId]; found { t.logger.Info("发现已到期但刚续费的套餐,跳过停止操作", zap.Int("hostId", limit.HostId)) continue } isClosed, err := t.expiredRep.IsPlanInList(ctx, repository.ClosedPlansList, int64(limit.HostId)) if err != nil { t.logger.Error("决策[停止]:检查套餐是否已关闭失败", zap.Int("hostId", limit.HostId), zap.Error(err)) continue } if !isClosed { plansToClose = append(plansToClose, limit) } } if len(plansToClose) == 0 { t.logger.Info("没有需要停止的套餐(可能均已续费或已关闭)") return nil } // 3. 执行停止操作 t.logger.Info("开始关闭到期的WAF服务", zap.Int("数量", len(plansToClose))) var hostIds []int for _, limit := range plansToClose { hostIds = append(hostIds, limit.HostId) } for _, hostId := range hostIds { webIds, err := t.GetCdnWebId(ctx, hostId) if err != nil { return fmt.Errorf("执行[停止]-获取cdn_web_id失败: %w", err) } if err := t.BanServer(ctx, webIds, false); err != nil { return fmt.Errorf("执行[停止]-禁用服务失败: %w", err) } } closedPlanIds := make([]int64, len(hostIds)) for i, id := range hostIds { closedPlanIds[i] = int64(id) } if err := t.expiredRep.AddPlans(ctx, repository.ClosedPlansList, closedPlanIds...); err != nil { return fmt.Errorf("执行[停止]-标记为已关闭失败: %w", err) } return nil } // 3. RecoverRecentPlan 恢复7天内续费的套餐 func (t *wafTask) RecoverRecentPlan(ctx context.Context) error { recentlyExpiredLimits, err := t.findRecentlyExpiredPlans(ctx) if err != nil { return fmt.Errorf("执行[近期恢复]-查找失败: %w", err) } if len(recentlyExpiredLimits) == 0 { return nil } renewalRequests, err := t.findMismatchedExpirations(ctx, recentlyExpiredLimits) if err != nil { return fmt.Errorf("执行[近期恢复]-决策失败: %w", err) } if len(renewalRequests) == 0 { t.logger.Info("在近期过期的套餐中,没有发现已续费的") return nil } return t.executePlanRecovery(ctx, renewalRequests, "近期恢复",repository.ClosedPlansList) } // 4. CleanUpStaleRecords 清理过期超过7天且仍未续费的记录 func (t *wafTask) CleanUpStaleRecords(ctx context.Context) error { staleLimits, err := t.findStaleExpiredPlans(ctx) if err != nil { return fmt.Errorf("执行[清理]-查找失败: %w", err) } if len(staleLimits) == 0 { return nil } plansToClean, err := t.filterCleanablePlans(ctx, staleLimits) if err != nil { return fmt.Errorf("执行[清理]-决策失败: %w", err) } if len(plansToClean) == 0 { t.logger.Info("没有长期未续费的记录需要清理") return nil } t.logger.Info("开始清理长期未续费的WAF记录", zap.Int("数量", len(plansToClean))) var planIdsToClean []int64 for _, limit := range plansToClean { planIdsToClean = append(planIdsToClean, int64(limit.HostId)) } if err := t.expiredRep.RemovePlans(ctx, repository.ClosedPlansList, planIdsToClean...); err != nil { return fmt.Errorf("执行[清理]-从Redis移除关闭标记失败: %w", err) } if err := t.expiredRep.AddPlans(ctx, repository.ExpiringSoonPlansList, planIdsToClean...); err != nil { return fmt.Errorf("执行[清理]-从Redis移除过期标记失败: %w", err) } // 在这里可以添加从数据库删除或调用CDN API彻底删除的逻辑 for _, limit := range plansToClean { err = t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{ HostId: limit.HostId, GatewayGroupId: limit.GatewayGroupId, State: true, }) if err != nil { return fmt.Errorf("执行[清理]-更新套餐状态失败: %w", err) } tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, limit.HostId) if err != nil { return err } udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, limit.HostId) if err != nil { return err } webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, limit.HostId) if err != nil { return err } err = t.tcp.DeleteTcpForwarding(ctx, v1.DeleteTcpForwardingRequest{ Ids: tcpIds, Uid: 0, HostId: limit.HostId, }) if err != nil { return err } err = t.udp.DeleteUdpForwarding(ctx, udpIds) if err != nil { return err } err = t.web.DeleteWebForwarding(ctx, webIds) if err != nil { return err } } return nil } // 5. RecoverStalePlan 恢复超过7天后才续费的套餐 func (t *wafTask) RecoverStalePlan(ctx context.Context) error { staleLimits, err := t.findStaleExpiredPlans(ctx) if err != nil { return fmt.Errorf("执行[长期恢复]-查找失败: %w", err) } if len(staleLimits) == 0 { return nil } renewalRequests, err := t.findMismatchedExpirations(ctx, staleLimits) if err != nil { return fmt.Errorf("执行[长期恢复]-决策失败: %w", err) } if len(renewalRequests) == 0 { t.logger.Info("在长期过期的套餐中,没有发现已续费的") return nil } return t.executePlanRecovery(ctx, renewalRequests, "长期恢复",repository.ExpiringSoonPlansList) }