|
- package task
- import (
- "context"
- "fmt"
- v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
- "github.com/go-nunu/nunu-layout-advanced/internal/model"
- "github.com/go-nunu/nunu-layout-advanced/internal/repository"
- "github.com/go-nunu/nunu-layout-advanced/internal/service"
- "github.com/hashicorp/go-multierror"
- "go.uber.org/zap"
- "sync"
- "time"
- )
// WafTask groups the five independent scheduled jobs that keep WAF plans in
// step with the expiry time of the hosts they protect.
type WafTask interface {
	// SynchronizationTime syncs the expiry time of plans that expire within the next day.
	SynchronizationTime(ctx context.Context) error

	// StopPlan stops every plan that has already expired.
	StopPlan(ctx context.Context) error

	// RecoverRecentPlan restores plans that were renewed within 7 days of expiring.
	RecoverRecentPlan(ctx context.Context) error

	// CleanUpStaleRecords cleans up records that expired more than 7 days ago and were never renewed.
	CleanUpStaleRecords(ctx context.Context) error

	// RecoverStalePlan restores plans that were renewed only after being expired for more than 7 days.
	RecoverStalePlan(ctx context.Context) error
}
- func NewWafTask(
- webForWardingRep repository.WebForwardingRepository,
- tcpforwardingRep repository.TcpforwardingRepository,
- udpForWardingRep repository.UdpForWardingRepository,
- cdn service.CdnService,
- hostRep repository.HostRepository,
- globalLimitRep repository.GlobalLimitRepository,
- expiredRep repository.ExpiredRepository,
- task *Task,
- gatewayGroupIpRep repository.GateWayGroupIpRepository,
- tcp service.TcpforwardingService,
- udp service.UdpForWardingService,
- web service.WebForwardingService,
- ) WafTask {
- return &wafTask{
- Task: task,
- webForWardingRep: webForWardingRep,
- tcpforwardingRep: tcpforwardingRep,
- udpForWardingRep: udpForWardingRep,
- cdn: cdn,
- hostRep: hostRep,
- globalLimitRep: globalLimitRep,
- expiredRep: expiredRep,
- gatewayGroupIpRep: gatewayGroupIpRep,
- tcp: tcp,
- udp: udp,
- web: web,
- }
- }
- type wafTask struct {
- *Task
- webForWardingRep repository.WebForwardingRepository
- tcpforwardingRep repository.TcpforwardingRepository
- udpForWardingRep repository.UdpForWardingRepository
- cdn service.CdnService
- hostRep repository.HostRepository
- globalLimitRep repository.GlobalLimitRepository
- expiredRep repository.ExpiredRepository
- gatewayGroupIpRep repository.GateWayGroupIpRepository
- tcp service.TcpforwardingService
- udp service.UdpForWardingService
- web service.WebForwardingService
- }
// Durations in seconds, used as offsets for the expiry queries.
const (
	// OneDaysInSeconds is the number of seconds in one day.
	OneDaysInSeconds = 24 * 60 * 60
	// SevenDaysInSeconds is the number of seconds in seven days.
	SevenDaysInSeconds = 7 * OneDaysInSeconds
)
// RenewalRequest describes one renewal to apply: the host whose WAF global
// limit is updated, the CDN user plan to renew, and the new expiry time.
type RenewalRequest struct {
	HostId    int   // host whose global limit record is updated
	PlanId    int   // CDN user plan ID passed to RenewPlan
	ExpiredAt int64 // new expiry as a Unix timestamp in seconds
}
- // =================================================================
- // =================== 原始辅助函数 (Helpers) =====================
- // =================================================================
- // 获取cdn web id
- func (t wafTask) GetCdnWebId(ctx context.Context,hostId int) ([]int, error) {
- tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, hostId)
- if err != nil {
- return nil, err
- }
- udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, hostId)
- if err != nil {
- return nil, err
- }
- webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, hostId)
- if err != nil {
- return nil, err
- }
- var ids []int
- ids = append(ids, tcpIds...)
- ids = append(ids, udpIds...)
- ids = append(ids, webIds...)
- return ids, nil
- }
- // BanServer 启用/禁用 网站 (并发执行)
- func (t wafTask) BanServer(ctx context.Context, ids []int, isBan bool) error {
- if len(ids) == 0 { return nil }
- var wg sync.WaitGroup
- errChan := make(chan error, len(ids))
- wg.Add(len(ids))
- for _, id := range ids {
- go func(id int) {
- defer wg.Done()
- if err := t.cdn.EditWebIsOn(ctx, int64(id), isBan); err != nil {
- errChan <- err
- }
- }(id)
- }
- wg.Wait()
- close(errChan)
- var result error
- for err := range errChan {
- result = multierror.Append(result, err)
- }
- return result
- }
- // EditExpired 统一的续费操作入口
- func (t wafTask) EditExpired(ctx context.Context, reqs []RenewalRequest) error {
- if len(reqs) == 0 { return nil }
- var globalLimitUpdates []struct { hostId int; expiredAt int64 }
- var planRenewals []struct { planId int; expiredAt int64 }
- for _, req := range reqs {
- globalLimitUpdates = append(globalLimitUpdates, struct{ hostId int; expiredAt int64 }{req.HostId, req.ExpiredAt})
- planRenewals = append(planRenewals, struct{ planId int; expiredAt int64 }{req.PlanId, req.ExpiredAt})
- }
- var result *multierror.Error
- if err := t.editGlobalLimitState(ctx, globalLimitUpdates, true); err != nil {
- result = multierror.Append(result, err)
- }
- if err := t.renewCdnPlan(ctx, planRenewals); err != nil {
- result = multierror.Append(result, err)
- }
- return result.ErrorOrNil()
- }
- // editGlobalLimitState 内部函数,用于更新数据库中的状态和时间
- func (t wafTask) editGlobalLimitState(ctx context.Context, req []struct { hostId int; expiredAt int64 }, state bool) error {
- var result *multierror.Error
- for _, v := range req {
- err := t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{HostId: v.hostId, ExpiredAt: v.expiredAt, State: state})
- if err != nil { result = multierror.Append(result, err) }
- }
- return result.ErrorOrNil()
- }
- // renewCdnPlan 内部函数,用于调用CDN服务进行续费
- func (t wafTask) renewCdnPlan(ctx context.Context, req []struct { planId int; expiredAt int64 }) error {
- var result *multierror.Error
- for _, v := range req {
- err := t.cdn.RenewPlan(ctx, v1.RenewalPlan{
- UserPlanId: int64(v.planId), IsFree: true, DayTo: time.Unix(v.expiredAt, 0).Format("2006-01-02"),
- Period: "monthly", CountPeriod: 1, PeriodDayTo: time.Unix(v.expiredAt, 0).Format("2006-01-02"),
- })
- if err != nil { result = multierror.Append(result, err) }
- }
- return result.ErrorOrNil()
- }
- // =================================================================
- // =================== 1. 数据查找层 (Finders) =====================
- // =================================================================
- // findMismatchedExpirations 检查 WAF 和 Host 的到期时间差异。这是决策的核心。
- func (t *wafTask) findMismatchedExpirations(ctx context.Context, wafLimits []model.GlobalLimit) ([]RenewalRequest, error) {
- if len(wafLimits) == 0 { return nil, nil }
- wafExpiredMap := make(map[int]int64, len(wafLimits))
- wafPlanMap := make(map[int]int, len(wafLimits))
- var hostIds []int
- for _, limit := range wafLimits {
- hostIds = append(hostIds, limit.HostId)
- wafExpiredMap[limit.HostId] = limit.ExpiredAt
- wafPlanMap[limit.HostId] = limit.RuleId
- }
- hostExpirations, err := t.hostRep.GetExpireTimeByHostId(ctx, hostIds)
- if err != nil { return nil, fmt.Errorf("获取主机到期时间失败: %w", err) }
- hostExpiredMap := make(map[int]int64, len(hostExpirations))
- for _, h := range hostExpirations { hostExpiredMap[h.HostId] = h.ExpiredAt }
- var renewalRequests []RenewalRequest
- for hostId, wafExpiredTime := range wafExpiredMap {
- hostTime, ok := hostExpiredMap[hostId]
- if !ok || hostTime != wafExpiredTime {
- planId, planOk := wafPlanMap[hostId]
- if !planOk {
- t.logger.Warn("数据不一致:在waf_limits中找不到hostId对应的套餐ID", zap.Int("hostId", hostId))
- continue
- }
- renewalRequests = append(renewalRequests, RenewalRequest{HostId: hostId, ExpiredAt: hostTime, PlanId: planId})
- }
- }
- return renewalRequests, nil
- }
- // findAllCurrentlyExpiredPlans 查找所有当前时间点已经到期的WAF记录。
- func (t *wafTask) findAllCurrentlyExpiredPlans(ctx context.Context) ([]model.GlobalLimit, error) {
- return t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, 0)
- }
- // findRecentlyExpiredPlans (精确查找) 查找在过去7天内到期的WAF记录。
- func (t *wafTask) findRecentlyExpiredPlans(ctx context.Context) ([]model.GlobalLimit, error) {
- sevenDaysAgo := time.Now().Add(-7 * 24 * time.Hour).Unix()
- now := time.Now().Unix()
- return t.globalLimitRep.GetGlobalLimitsByExpirationRange(ctx, sevenDaysAgo, now)
- }
- // findStaleExpiredPlans (精确查找) 查找7天前或更早就已到期的WAF记录。
- func (t *wafTask) findStaleExpiredPlans(ctx context.Context) ([]model.GlobalLimit, error) {
- sevenDaysAgoOffset := int64(-1 * SevenDaysInSeconds)
- return t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, sevenDaysAgoOffset)
- }
- // =================================================================
- // =================== 2. 业务决策层 (Filters) =====================
- // =================================================================
- // filterCleanablePlans (精确决策) 从长期过期的列表中,筛选出确认未续费且需要被清理的记录。
- func (t *wafTask) filterCleanablePlans(ctx context.Context, staleLimits []model.GlobalLimit) ([]model.GlobalLimit, error) {
- renewedStalePlans, err := t.findMismatchedExpirations(ctx, staleLimits)
- if err != nil {
- return nil, fmt.Errorf("决策[清理]: 检查续费状态失败: %w", err)
- }
- renewedHostIds := make(map[int]struct{}, len(renewedStalePlans))
- for _, req := range renewedStalePlans {
- renewedHostIds[req.HostId] = struct{}{}
- }
- var plansToClean []model.GlobalLimit
- for _, limit := range staleLimits {
- if _, found := renewedHostIds[limit.HostId]; !found {
- plansToClean = append(plansToClean, limit)
- }
- }
- return plansToClean, nil
- }
- // =================================================================
- // ============== 3. 业务执行层 (Executors & Public API) =============
- // =================================================================
- // executePlanRecovery (可重用) 负责恢复套餐的所有步骤。
- func (t *wafTask) executePlanRecovery(ctx context.Context, renewalRequests []RenewalRequest, taskName string,key repository.PlanListType) error {
- t.logger.Info(fmt.Sprintf("开始执行[%s]套餐恢复流程", taskName), zap.Int("数量", len(renewalRequests)))
- var hostIds []int
- for _, req := range renewalRequests {
- hostIds = append(hostIds, req.HostId)
- }
- var allErrors *multierror.Error
- for _, v := range renewalRequests {
- webIds, err := t.GetCdnWebId(ctx, v.HostId)
- if err != nil {
- allErrors = multierror.Append(allErrors, fmt.Errorf("执行[%s]-获取webId失败: %w", taskName, err))
- }
- if err := t.BanServer(ctx, webIds, true); err != nil {
- allErrors = multierror.Append(allErrors, fmt.Errorf("执行[%s]-封禁webId失败: %w", taskName, err))
- }
- }
- if err := t.EditExpired(ctx, renewalRequests); err != nil {
- allErrors = multierror.Append(allErrors, fmt.Errorf("执行[%s]-同步续费信息失败: %w", taskName, err))
- }
- planIdsToRecover := make([]int64, len(hostIds))
- for i, id := range hostIds { planIdsToRecover[i] = int64(id) }
- if err := t.expiredRep.RemovePlans(ctx, key, planIdsToRecover...); err != nil {
- allErrors = multierror.Append(allErrors, fmt.Errorf("执行[%s]-移除Redis关闭标记失败: %w", taskName, err))
- }
- return allErrors.ErrorOrNil()
- }
- // 1. SynchronizationTime 同步即将到期(1天内)的套餐时间
- func (t *wafTask) SynchronizationTime(ctx context.Context) error {
- wafLimits, err := t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, OneDaysInSeconds)
- if err != nil { return fmt.Errorf("执行[同步]-查找失败: %w", err) }
- renewalRequests, err := t.findMismatchedExpirations(ctx, wafLimits)
- if err != nil { return fmt.Errorf("执行[同步]-决策失败: %w", err) }
- if len(renewalRequests) > 0 {
- t.logger.Info("发现记录需要同步到期时间。", zap.Int("数量", len(renewalRequests)))
- return t.EditExpired(ctx, renewalRequests)
- }
- return nil
- }
- // 2. StopPlan (已优化) 停止所有已到期的套餐
- func (t *wafTask) StopPlan(ctx context.Context) error {
- // 1. 查找所有理论上已到期的记录
- expiredLimits, err := t.findAllCurrentlyExpiredPlans(ctx)
- if err != nil { return fmt.Errorf("执行[停止]-查找失败: %w", err) }
- if len(expiredLimits) == 0 { return nil }
- // 2. 决策 - 第1步:检查这些记录中是否已有续费但未同步的
- renewalRequests, err := t.findMismatchedExpirations(ctx, expiredLimits)
- if err != nil { return fmt.Errorf("执行[停止]-决策检查续费失败: %w", err) }
- renewedHostIds := make(map[int]struct{}, len(renewalRequests))
- for _, req := range renewalRequests {
- renewedHostIds[req.HostId] = struct{}{}
- }
- // 2. 决策 - 第2步:筛选出真正需要停止的记录
- var plansToClose []model.GlobalLimit
- for _, limit := range expiredLimits {
- if _, found := renewedHostIds[limit.HostId]; found {
- t.logger.Info("发现已到期但刚续费的套餐,跳过停止操作", zap.Int("hostId", limit.HostId))
- continue
- }
- isClosed, err := t.expiredRep.IsPlanInList(ctx, repository.ClosedPlansList, int64(limit.HostId))
- if err != nil {
- t.logger.Error("决策[停止]:检查套餐是否已关闭失败", zap.Int("hostId", limit.HostId), zap.Error(err))
- continue
- }
- if !isClosed {
- plansToClose = append(plansToClose, limit)
- }
- }
- if len(plansToClose) == 0 {
- t.logger.Info("没有需要停止的套餐(可能均已续费或已关闭)")
- return nil
- }
- // 3. 执行停止操作
- t.logger.Info("开始关闭到期的WAF服务", zap.Int("数量", len(plansToClose)))
- var hostIds []int
- for _, limit := range plansToClose {
- hostIds = append(hostIds, limit.HostId)
- }
- for _, hostId := range hostIds {
- webIds, err := t.GetCdnWebId(ctx, hostId)
- if err != nil { return fmt.Errorf("执行[停止]-获取cdn_web_id失败: %w", err) }
- if err := t.BanServer(ctx, webIds, false); err != nil {
- return fmt.Errorf("执行[停止]-禁用服务失败: %w", err)
- }
- }
- closedPlanIds := make([]int64, len(hostIds))
- for i, id := range hostIds { closedPlanIds[i] = int64(id) }
- if err := t.expiredRep.AddPlans(ctx, repository.ClosedPlansList, closedPlanIds...); err != nil {
- return fmt.Errorf("执行[停止]-标记为已关闭失败: %w", err)
- }
- return nil
- }
- // 3. RecoverRecentPlan 恢复7天内续费的套餐
- func (t *wafTask) RecoverRecentPlan(ctx context.Context) error {
- recentlyExpiredLimits, err := t.findRecentlyExpiredPlans(ctx)
- if err != nil { return fmt.Errorf("执行[近期恢复]-查找失败: %w", err) }
- if len(recentlyExpiredLimits) == 0 { return nil }
- renewalRequests, err := t.findMismatchedExpirations(ctx, recentlyExpiredLimits)
- if err != nil { return fmt.Errorf("执行[近期恢复]-决策失败: %w", err) }
- if len(renewalRequests) == 0 {
- t.logger.Info("在近期过期的套餐中,没有发现已续费的")
- return nil
- }
- return t.executePlanRecovery(ctx, renewalRequests, "近期恢复",repository.ClosedPlansList)
- }
- // 4. CleanUpStaleRecords 清理过期超过7天且仍未续费的记录
- func (t *wafTask) CleanUpStaleRecords(ctx context.Context) error {
- staleLimits, err := t.findStaleExpiredPlans(ctx)
- if err != nil { return fmt.Errorf("执行[清理]-查找失败: %w", err) }
- if len(staleLimits) == 0 { return nil }
- plansToClean, err := t.filterCleanablePlans(ctx, staleLimits)
- if err != nil { return fmt.Errorf("执行[清理]-决策失败: %w", err) }
- if len(plansToClean) == 0 {
- t.logger.Info("没有长期未续费的记录需要清理")
- return nil
- }
- t.logger.Info("开始清理长期未续费的WAF记录", zap.Int("数量", len(plansToClean)))
- var planIdsToClean []int64
- for _, limit := range plansToClean {
- planIdsToClean = append(planIdsToClean, int64(limit.HostId))
- }
- if err := t.expiredRep.RemovePlans(ctx, repository.ClosedPlansList, planIdsToClean...); err != nil {
- return fmt.Errorf("执行[清理]-从Redis移除关闭标记失败: %w", err)
- }
- if err := t.expiredRep.AddPlans(ctx, repository.ExpiringSoonPlansList, planIdsToClean...); err != nil {
- return fmt.Errorf("执行[清理]-从Redis移除过期标记失败: %w", err)
- }
- // 在这里可以添加从数据库删除或调用CDN API彻底删除的逻辑
- for _, limit := range plansToClean {
- err = t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
- HostId: limit.HostId,
- GatewayGroupId: limit.GatewayGroupId,
- State: true,
- })
- if err != nil {
- return fmt.Errorf("执行[清理]-更新套餐状态失败: %w", err)
- }
- tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, limit.HostId)
- if err != nil {
- return err
- }
- udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, limit.HostId)
- if err != nil {
- return err
- }
- webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, limit.HostId)
- if err != nil {
- return err
- }
- err = t.tcp.DeleteTcpForwarding(ctx, v1.DeleteTcpForwardingRequest{
- Ids: tcpIds,
- Uid: 0,
- HostId: limit.HostId,
- })
- if err != nil {
- return err
- }
- err = t.udp.DeleteUdpForwarding(ctx, udpIds)
- if err != nil {
- return err
- }
- err = t.web.DeleteWebForwarding(ctx, webIds)
- if err != nil {
- return err
- }
- }
- return nil
- }
- // 5. RecoverStalePlan 恢复超过7天后才续费的套餐
- func (t *wafTask) RecoverStalePlan(ctx context.Context) error {
- staleLimits, err := t.findStaleExpiredPlans(ctx)
- if err != nil { return fmt.Errorf("执行[长期恢复]-查找失败: %w", err) }
- if len(staleLimits) == 0 { return nil }
- renewalRequests, err := t.findMismatchedExpirations(ctx, staleLimits)
- if err != nil { return fmt.Errorf("执行[长期恢复]-决策失败: %w", err) }
- if len(renewalRequests) == 0 {
- t.logger.Info("在长期过期的套餐中,没有发现已续费的")
- return nil
- }
- return t.executePlanRecovery(ctx, renewalRequests, "长期恢复",repository.ExpiringSoonPlansList)
- }
|