package task import ( "context" "fmt" v1 "github.com/go-nunu/nunu-layout-advanced/api/v1" "github.com/go-nunu/nunu-layout-advanced/internal/model" "github.com/go-nunu/nunu-layout-advanced/internal/repository" "github.com/go-nunu/nunu-layout-advanced/internal/service" "github.com/hashicorp/go-multierror" "go.uber.org/zap" "sync" "time" ) type WafTask interface { //获取到期时间小于3天的同步时间 SynchronizationTime(ctx context.Context) error } func NewWafTask ( webForWardingRep repository.WebForwardingRepository, tcpforwardingRep repository.TcpforwardingRepository, udpForWardingRep repository.UdpForWardingRepository, cdn service.CdnService, hostRep repository.HostRepository, globalLimitRep repository.GlobalLimitRepository, task *Task, ) WafTask{ return &wafTask{ Task: task, webForWardingRep: webForWardingRep, tcpforwardingRep: tcpforwardingRep, udpForWardingRep: udpForWardingRep, cdn: cdn, hostRep: hostRep, globalLimitRep: globalLimitRep, } } type wafTask struct { *Task webForWardingRep repository.WebForwardingRepository tcpforwardingRep repository.TcpforwardingRepository udpForWardingRep repository.UdpForWardingRepository cdn service.CdnService hostRep repository.HostRepository globalLimitRep repository.GlobalLimitRepository } const ( // 3天后秒数 OneDaysInSeconds = 3 * 24 * 60 * 60 // 7天前秒数 SevenDaysInSeconds = 7 * 24 * 60 * 60 * -1 ) // 获取cdn web id func (t wafTask) GetCdnWebId(ctx context.Context,hostId int) ([]int, error) { tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, hostId) if err != nil { return nil, err } udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, hostId) if err != nil { return nil, err } webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, hostId) if err != nil { return nil, err } var ids []int ids = append(ids, tcpIds...) ids = append(ids, udpIds...) ids = append(ids, webIds...) return ids, nil } // 启用/禁用 网站 func (t wafTask) BanServer(ctx context.Context, ids []int, isBan bool) error { var wg sync.WaitGroup errChan := make(chan error, len(ids)) // 修正1:为每个 goroutine 增加 WaitGroup 的计数 wg.Add(len(ids)) for _, id := range ids { go func(id int) { // 修正2:确保每个 goroutine 在退出时都调用 Done() defer wg.Done() err := t.cdn.EditWebIsOn(ctx, int64(id), isBan) if err != nil { errChan <- err // 这里不需要 return,因为 defer wg.Done() 会在函数退出时执行 } }(id) } // 现在 wg.Wait() 会正确地阻塞,直到所有 goroutine 都调用了 Done() wg.Wait() // 在所有 goroutine 都结束后,安全地关闭 channel close(errChan) var result error for err := range errChan { result = multierror.Append(result, err) // 将多个 error 对象合并成一个单一的 error 对象 } // 修正3:返回收集到的错误,而不是 nil return result } // 获取指定到期时间 func (t wafTask) GetAlmostExpiring(ctx context.Context,hostIds []int,addTime int64) ([]v1.GetAlmostExpireHostResponse,error) { // 3 天 res, err := t.hostRep.GetAlmostExpired(ctx, hostIds, addTime) if err != nil { return nil,err } return res, nil } // 获取waf全局到期时间 func (t wafTask) GetGlobalAlmostExpiring(ctx context.Context,addTime int64) ([]model.GlobalLimit,error) { res, err := t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, addTime) if err != nil { return nil, err } return res, nil } // 修改全局续费 func (t wafTask) EditGlobalExpired(ctx context.Context, req []struct{ hostId int expiredAt int64 }, state bool) error { var result *multierror.Error // 使用 multierror for _, v := range req { err := t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{ HostId: v.hostId, ExpiredAt: v.expiredAt, State: state, }) if err != nil { // 收集错误,而不是直接返回 result = multierror.Append(result, err) } } // 返回所有收集到的错误 return result.ErrorOrNil() } // 续费套餐 func (t wafTask) EnablePlan(ctx context.Context, req []struct{ planId int expiredAt int64 }) error { var result *multierror.Error for _, v := range req { err := t.cdn.RenewPlan(ctx, v1.RenewalPlan{ UserPlanId: int64(v.planId), IsFree: true, DayTo: time.Unix(v.expiredAt, 0).Format("2006-01-02"), Period: "monthly", CountPeriod: 1, PeriodDayTo: time.Unix(v.expiredAt, 0).Format("2006-01-02"), }) if err != nil { result = multierror.Append(result, err) } } return result.ErrorOrNil() } // 续费操作 type RenewalRequest struct { HostId int PlanId int ExpiredAt int64 } // 续费操作 func (t wafTask) EditExpired(ctx context.Context, reqs []RenewalRequest) error { // 如果请求为空,直接返回 if len(reqs) == 0 { return nil } // 1. 准备用于更新 GlobalLimit 的数据 var globalLimitUpdates []struct { hostId int expiredAt int64 } for _, req := range reqs { globalLimitUpdates = append(globalLimitUpdates, struct { hostId int expiredAt int64 }{req.HostId, req.ExpiredAt}) } // 2. 准备用于续费套餐的数据 var planRenewals []struct { planId int expiredAt int64 } for _, req := range reqs { planRenewals = append(planRenewals, struct { planId int expiredAt int64 }{req.PlanId, req.ExpiredAt}) } var result *multierror.Error // 3. 执行更新,并收集错误 if err := t.EditGlobalExpired(ctx, globalLimitUpdates, true); err != nil { result = multierror.Append(result, err) } if err := t.EnablePlan(ctx, planRenewals); err != nil { result = multierror.Append(result, err) } return result.ErrorOrNil() } // findMismatchedExpirations 检查 WAF 和 Host 的到期时间差异,并返回需要同步的请求。 func (t *wafTask) findMismatchedExpirations(ctx context.Context, wafLimits []model.GlobalLimit) ([]RenewalRequest, error) { if len(wafLimits) == 0 { return nil, nil } // 2. 将 WAF 数据组织成 Map wafExpiredMap := make(map[int]int64, len(wafLimits)) wafPlanMap := make(map[int]int, len(wafLimits)) var hostIds []int for _, limit := range wafLimits { hostIds = append(hostIds, limit.HostId) wafExpiredMap[limit.HostId] = limit.ExpiredAt wafPlanMap[limit.HostId] = limit.RuleId } // 3. 获取对应 Host 的到期时间 hostExpirations, err := t.hostRep.GetExpireTimeByHostId(ctx, hostIds) if err != nil { return nil, fmt.Errorf("获取主机到期时间失败: %w", err) } hostExpiredMap := make(map[int]int64, len(hostExpirations)) for _, h := range hostExpirations { hostExpiredMap[h.HostId] = h.ExpiredAt } // 4. 找出时间不一致的记录 var renewalRequests []RenewalRequest for hostId, wafExpiredTime := range wafExpiredMap { hostTime, ok := hostExpiredMap[hostId] // 如果 Host 时间与 WAF 时间不一致,则需要同步 if !ok || hostTime != wafExpiredTime { planId, planOk := wafPlanMap[hostId] if !planOk { t.logger.Warn("数据不一致:在waf_limits中找不到hostId对应的套餐ID", zap.Int("hostId", hostId)) continue } renewalRequests = append(renewalRequests, RenewalRequest{ HostId: hostId, ExpiredAt: hostTime, // 以 WAF 表的时间为准 PlanId: planId, }) } } return renewalRequests, nil } //获取到期时间小于3天的同步时间 func (t *wafTask) SynchronizationTime(ctx context.Context) error { // 1. 获取 WAF 全局配置中即将到期(小于3天)的数据 wafLimits, err := t.GetGlobalAlmostExpiring(ctx, OneDaysInSeconds) if err != nil { return fmt.Errorf("获取全局到期配置失败: %w", err) } // 2. 找出需要同步的数据 renewalRequests, err := t.findMismatchedExpirations(ctx, wafLimits) if err != nil { return err // 错误已在辅助函数中包装 } // 3. 如果有需要同步的数据,执行续费操作 if len(renewalRequests) > 0 { t.logger.Info("发现记录需要同步到期时间。", zap.Int("数量", len(renewalRequests))) return t.EditExpired(ctx, renewalRequests) } return nil } //获取到期的进行关闭套餐操作 // 获取到期的进行关闭套餐操作 func (t *wafTask) StopPlan(ctx context.Context) error { // 1. 获取 WAF 全局配置中已经到期的数据 // 使用 time.Now().Unix() 表示获取所有 expired_at <= 当前时间的记录 wafLimits, err := t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, time.Now().Unix()) if err != nil { return fmt.Errorf("获取全局到期配置失败: %w", err) } if len(wafLimits) == 0 { return nil // 没有到期的,任务完成 } // 2. (可选,但推荐)先同步任何时间不一致的数据,确保状态准确 renewalRequests, err := t.findMismatchedExpirations(ctx, wafLimits) if err != nil { t.logger.Error("在关闭套餐前,同步时间失败", zap.Error(err)) // 根据业务决定是否要继续,这里我们选择继续,但记录错误 } if len(renewalRequests) > 0 { t.logger.Info("关闭套餐前,发现并同步不一致的时间记录", zap.Int("数量", len(renewalRequests))) if err := t.EditExpired(ctx, renewalRequests); err != nil { t.logger.Error("同步不一致的时间记录失败", zap.Error(err)) } } // 3. 关闭所有已经到期的套餐 t.logger.Info("开始关闭已到期的WAF服务", zap.Int("数量", len(wafLimits))) var allErrors *multierror.Error for _, limit := range wafLimits { webIds, err := t.GetCdnWebId(ctx, limit.HostId) if err != nil { allErrors = multierror.Append(allErrors, fmt.Errorf("获取hostId %d 的webId失败: %w", limit.HostId, err)) continue // 继续处理下一个 } if err := t.BanServer(ctx, webIds, false); err != nil { allErrors = multierror.Append(allErrors, fmt.Errorf("关闭hostId %d 的服务失败: %w", limit.HostId, err)) } } return allErrors.ErrorOrNil() } //对于到期7天内续费的产品需要进行恢复操作 // RecoverStopPlan 对于到期7天内续费的产品进行恢复操作 func (t *wafTask) RecoverStopPlan(ctx context.Context) error { // 1. 查找在过去7天内到期,并且当前状态为“已关闭”的 WAF 记录 // 这可能需要一个新的 repository 方法,例如: GetRecentlyClosedLimits // 我们先假设有这样一个方法,它返回 state=false 且 expired_at 在 (now-7天, now] 之间的记录 since := time.Now().Add(-7 * 24 * time.Hour).Unix() // 假设你有一个方法 `GetClosedLimitsSince(ctx, sinceTime)` // closedLimits, err := t.globalLimitRep.GetClosedLimitsSince(ctx, since) // 为简化,我们先获取所有7天内到期的,再在逻辑里判断 // 简单的实现:获取7天内到期的所有记录 wafLimits, err := t.globalLimitRep.GetLimitsExpiredSince(ctx, since) // 假设有这个方法 if err != nil { return fmt.Errorf("获取近期到期配置失败: %w", err) } if len(wafLimits) == 0 { return nil } // 提取 hostIds 并过滤出已关闭的记录 var hostIds []int closedLimitsMap := make(map[int]model.GlobalLimit) for _, limit := range wafLimits { if !limit.State { // 只处理状态为“已关闭”的 hostIds = append(hostIds, limit.HostId) closedLimitsMap[limit.HostId] = limit } } if len(hostIds) == 0 { return nil // 没有已关闭的记录需要检查 } // 2. 获取这些 host 的当前到期时间 hostExpirations, err := t.hostRep.GetExpireTimeByHostId(ctx, hostIds) if err != nil { return fmt.Errorf("获取主机当前到期时间失败: %w", err) } hostExpiredMap := make(map[int]int64) for _, h := range hostExpirations { hostExpiredMap[h.HostId] = h.ExpiredAt } var allErrors *multierror.Error // 3. 比较时间,找出已续费的 host,并恢复服务 for hostId, closedLimit := range closedLimitsMap { currentHostExpiry, ok := hostExpiredMap[hostId] if !ok { continue // host 不存在了,跳过 } // 如果 host 表的到期时间 > global_limit 表的到期时间,说明已续费 if currentHostExpiry > closedLimit.ExpiredAt { t.logger.Info("发现已续费并关闭的WAF服务,准备恢复", zap.Int("hostId", hostId)) // 3a. 恢复网站服务 webIds, err := t.GetCdnWebId(ctx, hostId) if err != nil { allErrors = multierror.Append(allErrors, fmt.Errorf("恢复hostId %d 时获取webId失败: %w", hostId, err)) continue } if err := t.BanServer(ctx, webIds, true); err != nil { // true 表示启用 allErrors = multierror.Append(allErrors, fmt.Errorf("恢复hostId %d 服务失败: %w", hostId, err)) continue } // 3b. 更新 global_limit 表的时间和状态 var singleUpdate []struct{hostId int; expiredAt int64} singleUpdate = append(singleUpdate, struct{hostId int; expiredAt int64}{hostId: hostId, expiredAt: currentHostExpiry}) if err := t.EditGlobalExpired(ctx, singleUpdate, true); err != nil { // true 表示启用 allErrors = multierror.Append(allErrors, fmt.Errorf("更新hostId %d 状态为已恢复失败: %w", hostId, err)) } } } return allErrors.ErrorOrNil() } //对于大于7天的药进行数据情侣操作