|
@@ -91,34 +91,268 @@ type RenewalRequest struct {
|
|
|
// =================== 核心辅助函数 (Core Helpers) =================
|
|
|
// =================================================================
|
|
|
|
|
|
-// (wrapTaskError, getCdnWebIdsByHostIds, setCdnWebsitesState, executeRenewalActions 保持不变)
|
|
|
-// ...
|
|
|
-func (t *wafTask) wrapTaskError(taskName, step string, err error) error { /* ... */ return nil }
|
|
|
-func (t *wafTask) getCdnWebIdsByHostIds(ctx context.Context, hostIds []int) ([]int, error) { /* ... */ return nil, nil }
|
|
|
-func (t *wafTask) setCdnWebsitesState(ctx context.Context, ids []int, enable bool) error { /* ... */ return nil }
|
|
|
-func (t *wafTask) executeRenewalActions(ctx context.Context, reqs []RenewalRequest) error { /* ... */ return nil }
|
|
|
+// wrapTaskError 统一封装任务错误信息,方便日志和调试
|
|
|
+func (t *wafTask) wrapTaskError(taskName, step string, err error) error {
|
|
|
+ if err == nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ return fmt.Errorf("执行[%s]-%s失败: %w", taskName, step, err)
|
|
|
+}
|
|
|
+
|
|
|
+// getCdnWebIdsByHostIds (原GetCdnWebId) 根据hostId列表获取所有关联的转发规则ID
|
|
|
+func (t *wafTask) getCdnWebIdsByHostIds(ctx context.Context, hostIds []int) ([]int, error) {
|
|
|
+ if len(hostIds) == 0 {
|
|
|
+ return nil, nil
|
|
|
+ }
|
|
|
+ var ids []int
|
|
|
+ var result *multierror.Error
|
|
|
+
|
|
|
+ tcpIds, err := t.tcpforwardingRep.GetTcpAll(ctx, hostIds)
|
|
|
+ if err != nil {
|
|
|
+ result = multierror.Append(result, err)
|
|
|
+ }
|
|
|
+ ids = append(ids, tcpIds...)
|
|
|
+
|
|
|
+ udpIds, err := t.udpForWardingRep.GetUdpAll(ctx, hostIds)
|
|
|
+ if err != nil {
|
|
|
+ result = multierror.Append(result, err)
|
|
|
+ }
|
|
|
+ ids = append(ids, udpIds...)
|
|
|
+
|
|
|
+ webIds, err := t.webForWardingRep.GetWebAll(ctx, hostIds)
|
|
|
+ if err != nil {
|
|
|
+ result = multierror.Append(result, err)
|
|
|
+ }
|
|
|
+ ids = append(ids, webIds...)
|
|
|
+
|
|
|
+ return ids, result.ErrorOrNil()
|
|
|
+}
|
|
|
+
|
|
|
+// setCdnWebsitesState (原BanServer) 启用或禁用一组CDN网站 (并发执行)
|
|
|
+func (t *wafTask) setCdnWebsitesState(ctx context.Context, ids []int, enable bool) error {
|
|
|
+ if len(ids) == 0 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ var wg sync.WaitGroup
|
|
|
+ errChan := make(chan error, len(ids))
|
|
|
+ wg.Add(len(ids))
|
|
|
+ for _, id := range ids {
|
|
|
+ go func(id int) {
|
|
|
+ defer wg.Done()
|
|
|
+ // cdn.EditWebIsOn 的第二个参数 isBan, false=启用, true=禁用
|
|
|
+ // 所以 enable=true 对应 isBan=false
|
|
|
+ if err := t.cdn.EditWebIsOn(ctx, int64(id), !enable); err != nil {
|
|
|
+ errChan <- err
|
|
|
+ }
|
|
|
+ }(id)
|
|
|
+ }
|
|
|
+ wg.Wait()
|
|
|
+ close(errChan)
|
|
|
+ var result *multierror.Error
|
|
|
+ for err := range errChan {
|
|
|
+ result = multierror.Append(result, err)
|
|
|
+ }
|
|
|
+ return result.ErrorOrNil()
|
|
|
+}
|
|
|
+
|
|
|
+// executeRenewalActions (原EditExpired) 执行续费操作,包括更新DB和调用CDN API
|
|
|
+func (t *wafTask) executeRenewalActions(ctx context.Context, reqs []RenewalRequest) error {
|
|
|
+ if len(reqs) == 0 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ var allErrors *multierror.Error
|
|
|
+ var wg sync.WaitGroup
|
|
|
+ wg.Add(len(reqs))
|
|
|
|
|
|
+ for _, req := range reqs {
|
|
|
+ go func(r RenewalRequest) {
|
|
|
+ defer wg.Done()
|
|
|
+ // 更新数据库状态
|
|
|
+ err := t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{HostId: r.HostId, ExpiredAt: r.ExpiredAt, State: true})
|
|
|
+ if err != nil {
|
|
|
+ allErrors = multierror.Append(allErrors, err)
|
|
|
+ return // 如果DB更新失败,不继续调用CDN API
|
|
|
+ }
|
|
|
+ // 调用CDN API续费
|
|
|
+ cdnErr := t.cdn.RenewPlan(ctx, v1.RenewalPlan{
|
|
|
+ UserPlanId: int64(r.PlanId),
|
|
|
+ IsFree: true,
|
|
|
+ DayTo: time.Unix(r.ExpiredAt, 0).Format("2006-01-02"),
|
|
|
+ Period: "monthly",
|
|
|
+ CountPeriod: 1,
|
|
|
+ PeriodDayTo: time.Unix(r.ExpiredAt, 0).Format("2006-01-02"),
|
|
|
+ })
|
|
|
+ if cdnErr != nil {
|
|
|
+ allErrors = multierror.Append(allErrors, cdnErr)
|
|
|
+ }
|
|
|
+ }(req)
|
|
|
+ }
|
|
|
+
|
|
|
+ wg.Wait()
|
|
|
+ return allErrors.ErrorOrNil()
|
|
|
+}
|
|
|
|
|
|
// =================================================================
|
|
|
// =================== 1. 数据查找与决策层 ==========================
|
|
|
// =================================================================
|
|
|
|
|
|
-// (findPlansNeedingSync, findAllCurrentlyExpiredWAFPlans, findRecentlyExpiredWAFPlans, findStaleWAFPlans 保持不变)
|
|
|
-// ...
|
|
|
-func (t *wafTask) findPlansNeedingSync(ctx context.Context, wafLimits []model.GlobalLimit) ([]RenewalRequest, error) { /* ... */ return nil, nil }
|
|
|
-func (t *wafTask) findAllCurrentlyExpiredWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) { /* ... */ return nil, nil }
|
|
|
-func (t *wafTask) findRecentlyExpiredWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) { /* ... */ return nil, nil }
|
|
|
-func (t *wafTask) findStaleWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) { /* ... */ return nil, nil }
|
|
|
+// findPlansNeedingSync (原findMismatchedExpirations) 检查WAF和Host的到期时间差异,返回需要同步的请求
|
|
|
+func (t *wafTask) findPlansNeedingSync(ctx context.Context, wafLimits []model.GlobalLimit) ([]RenewalRequest, error) {
|
|
|
+ if len(wafLimits) == 0 {
|
|
|
+ return nil, nil
|
|
|
+ }
|
|
|
+ wafExpiredMap := make(map[int]int64, len(wafLimits))
|
|
|
+ wafPlanMap := make(map[int]int, len(wafLimits))
|
|
|
+ var hostIds []int
|
|
|
+ for _, limit := range wafLimits {
|
|
|
+ hostIds = append(hostIds, limit.HostId)
|
|
|
+ wafExpiredMap[limit.HostId] = limit.ExpiredAt
|
|
|
+ wafPlanMap[limit.HostId] = limit.RuleId
|
|
|
+ }
|
|
|
+
|
|
|
+ hostExpirations, err := t.hostRep.GetExpireTimeByHostId(ctx, hostIds)
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("获取主机到期时间失败: %w", err)
|
|
|
+ }
|
|
|
+ hostExpiredMap := make(map[int]int64, len(hostExpirations))
|
|
|
+ for _, h := range hostExpirations {
|
|
|
+ hostExpiredMap[h.HostId] = h.ExpiredAt
|
|
|
+ }
|
|
|
+
|
|
|
+ var renewalRequests []RenewalRequest
|
|
|
+ for hostId, wafExpiredTime := range wafExpiredMap {
|
|
|
+ hostTime, ok := hostExpiredMap[hostId]
|
|
|
+ if !ok || hostTime != wafExpiredTime {
|
|
|
+ planId, planOk := wafPlanMap[hostId]
|
|
|
+ if !planOk {
|
|
|
+ t.logger.Warn("数据不一致:在waf_limits中找不到hostId对应的套餐ID", zap.Int("hostId", hostId))
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ renewalRequests = append(renewalRequests, RenewalRequest{HostId: hostId, ExpiredAt: hostTime, PlanId: planId})
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return renewalRequests, nil
|
|
|
+}
|
|
|
+
|
|
|
+// findAllCurrentlyExpiredWAFPlans (原findAllCurrentlyExpiredPlans) 查找所有当前时间点已经到期的WAF记录
|
|
|
+func (t *wafTask) findAllCurrentlyExpiredWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) {
|
|
|
+ return t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, 0)
|
|
|
+}
|
|
|
+
|
|
|
+// findRecentlyExpiredWAFPlans (原findRecentlyExpiredPlans) 查找在过去7天内到期的WAF记录
|
|
|
+func (t *wafTask) findRecentlyExpiredWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) {
|
|
|
+ sevenDaysAgo := time.Now().Add(-7 * 24 * time.Hour).Unix()
|
|
|
+ now := time.Now().Unix()
|
|
|
+ return t.globalLimitRep.GetGlobalLimitsByExpirationRange(ctx, sevenDaysAgo, now)
|
|
|
+}
|
|
|
+
|
|
|
+// findStaleWAFPlans (原findStaleExpiredPlans) 查找7天前或更早就已到期的WAF记录
|
|
|
+func (t *wafTask) findStaleWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) {
|
|
|
+ sevenDaysAgoOffset := int64(-1 * SevenDaysInSeconds)
|
|
|
+ return t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, sevenDaysAgoOffset)
|
|
|
+}
|
|
|
|
|
|
// =================================================================
|
|
|
// ============== 2. 业务执行与公共API层 ===========================
|
|
|
// =================================================================
|
|
|
|
|
|
-// (SynchronizationTime, StopPlan 保持不变)
|
|
|
-// ...
|
|
|
-func (t *wafTask) SynchronizationTime(ctx context.Context) error { /* ... */ return nil }
|
|
|
-func (t *wafTask) StopPlan(ctx context.Context) error { /* ... */ return nil }
|
|
|
+// SynchronizationTime 同步即将到期(1天内)的套餐时间
|
|
|
+func (t *wafTask) SynchronizationTime(ctx context.Context) error {
|
|
|
+ taskName := "同步到期时间"
|
|
|
+ wafLimits, err := t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, OneDaysInSeconds)
|
|
|
+ if err != nil {
|
|
|
+ return t.wrapTaskError(taskName, "查找失败", err)
|
|
|
+ }
|
|
|
+ if len(wafLimits) == 0 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ renewalRequests, err := t.findPlansNeedingSync(ctx, wafLimits)
|
|
|
+ if err != nil {
|
|
|
+ return t.wrapTaskError(taskName, "决策失败", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(renewalRequests) > 0 {
|
|
|
+ t.logger.Info("发现记录需要同步到期时间", zap.String("task", taskName), zap.Int("数量", len(renewalRequests)))
|
|
|
+ return t.wrapTaskError(taskName, "执行同步", t.executeRenewalActions(ctx, renewalRequests))
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+// StopPlan 停止所有已到期的套餐
|
|
|
+func (t *wafTask) StopPlan(ctx context.Context) error {
|
|
|
+ taskName := "停止到期套餐"
|
|
|
+ // 1. 查找所有理论上已到期的记录
|
|
|
+ expiredLimits, err := t.findAllCurrentlyExpiredWAFPlans(ctx)
|
|
|
+ if err != nil {
|
|
|
+ return t.wrapTaskError(taskName, "查找失败", err)
|
|
|
+ }
|
|
|
+ if len(expiredLimits) == 0 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ // 2. 决策 - 第1步:检查这些记录中是否已有续费但未同步的
|
|
|
+ renewalRequests, err := t.findPlansNeedingSync(ctx, expiredLimits)
|
|
|
+ if err != nil {
|
|
|
+ return t.wrapTaskError(taskName, "决策检查续费", err)
|
|
|
+ }
|
|
|
+ renewedHostIds := make(map[int]struct{}, len(renewalRequests))
|
|
|
+ for _, req := range renewalRequests {
|
|
|
+ if req.ExpiredAt > time.Now().Unix() {
|
|
|
+ renewedHostIds[req.HostId] = struct{}{}
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 2. 决策 - 第2步:筛选出真正需要停止的记录
|
|
|
+ var plansToClose []model.GlobalLimit
|
|
|
+ for _, limit := range expiredLimits {
|
|
|
+ if _, found := renewedHostIds[limit.HostId]; found {
|
|
|
+ t.logger.Info("发现已到期但刚续费的套餐,跳过停止操作", zap.String("task", taskName), zap.Int("hostId", limit.HostId))
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ isClosed, err := t.expiredRep.IsPlanInList(ctx, repository.ClosedPlansList, int64(limit.HostId))
|
|
|
+ if err != nil {
|
|
|
+ t.logger.Error("决策[停止]:检查Redis套餐状态失败", zap.String("task", taskName), zap.Int("hostId", limit.HostId), zap.Error(err))
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if !isClosed {
|
|
|
+ plansToClose = append(plansToClose, limit)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(plansToClose) == 0 {
|
|
|
+ t.logger.Info("没有需要停止的套餐(可能均已续费或已在停止列表)", zap.String("task", taskName))
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ // 3. 执行停止操作
|
|
|
+ t.logger.Info("开始关闭到期的WAF服务", zap.String("task", taskName), zap.Int("数量", len(plansToClose)))
|
|
|
+ var hostIds []int
|
|
|
+ for _, limit := range plansToClose {
|
|
|
+ hostIds = append(hostIds, limit.HostId)
|
|
|
+ }
|
|
|
+
|
|
|
+ var allErrors *multierror.Error
|
|
|
+
|
|
|
+ webIds, err := t.getCdnWebIdsByHostIds(ctx, hostIds)
|
|
|
+ if err != nil {
|
|
|
+ allErrors = multierror.Append(allErrors, fmt.Errorf("获取cdn_web_id失败: %w", err))
|
|
|
+ } else {
|
|
|
+ if err := t.setCdnWebsitesState(ctx, webIds, false); err != nil { // enable=false
|
|
|
+ allErrors = multierror.Append(allErrors, fmt.Errorf("禁用服务失败: %w", err))
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ closedPlanIds := make([]int64, len(hostIds))
|
|
|
+ for i, id := range hostIds {
|
|
|
+ closedPlanIds[i] = int64(id)
|
|
|
+ }
|
|
|
+ if err := t.expiredRep.AddPlans(ctx, repository.ClosedPlansList, closedPlanIds...); err != nil {
|
|
|
+ allErrors = multierror.Append(allErrors, fmt.Errorf("标记为已关闭失败: %w", err))
|
|
|
+ }
|
|
|
|
|
|
+ return t.wrapTaskError(taskName, "执行停止", allErrors.ErrorOrNil())
|
|
|
+}
|
|
|
|
|
|
// _recoverPlans 是一个统一的、可重用的套餐恢复流程
|
|
|
func (t *wafTask) _recoverPlans(ctx context.Context, limitsToCheck []model.GlobalLimit, taskName string, redisListKey repository.PlanListType) error {
|
|
@@ -155,7 +389,7 @@ func (t *wafTask) _recoverPlans(ctx context.Context, limitsToCheck []model.Globa
|
|
|
if err != nil {
|
|
|
allErrors = multierror.Append(allErrors, fmt.Errorf("获取webId失败: %w", err))
|
|
|
} else {
|
|
|
- if err := t.setCdnWebsitesState(ctx, webIds, true); err != nil {
|
|
|
+ if err := t.setCdnWebsitesState(ctx, webIds, true); err != nil { // enable=true
|
|
|
allErrors = multierror.Append(allErrors, fmt.Errorf("启用web服务失败: %w", err))
|
|
|
}
|
|
|
}
|
|
@@ -236,6 +470,7 @@ func (t *wafTask) CleanUpStaleRecords(ctx context.Context) error {
|
|
|
now := time.Now().Unix()
|
|
|
for _, limit := range uncleanedLimits {
|
|
|
hostExpiredTime, ok := hostExpiredMap[limit.HostId]
|
|
|
+ // 清理条件:主机记录不存在,或者主机记录的到期时间是过去时
|
|
|
if !ok || hostExpiredTime <= now {
|
|
|
plansToClean = append(plansToClean, limit)
|
|
|
}
|
|
@@ -281,7 +516,6 @@ func (t *wafTask) executeSinglePlanCleanup(ctx context.Context, limit model.Glob
|
|
|
allErrors = multierror.Append(allErrors, err)
|
|
|
}
|
|
|
|
|
|
-
|
|
|
// 删除关联的转发规则...
|
|
|
tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, limit.HostId)
|
|
|
if err != nil {
|
|
@@ -291,11 +525,27 @@ func (t *wafTask) executeSinglePlanCleanup(ctx context.Context, limit model.Glob
|
|
|
allErrors = multierror.Append(allErrors, err)
|
|
|
}
|
|
|
}
|
|
|
- // ... 删除 UDP 和 Web 规则的逻辑保持不变
|
|
|
+
|
|
|
+ udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, limit.HostId)
|
|
|
+ if err != nil {
|
|
|
+ allErrors = multierror.Append(allErrors, err)
|
|
|
+ } else if len(udpIds) > 0 {
|
|
|
+ if err := t.udp.DeleteUdpForwarding(ctx, v1.DeleteUdpForwardingRequest{Ids: udpIds, HostId: limit.HostId}); err != nil {
|
|
|
+ allErrors = multierror.Append(allErrors, err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, limit.HostId)
|
|
|
+ if err != nil {
|
|
|
+ allErrors = multierror.Append(allErrors, err)
|
|
|
+ } else if len(webIds) > 0 {
|
|
|
+ if err := t.web.DeleteWebForwarding(ctx, v1.DeleteWebForwardingRequest{Ids: webIds, HostId: limit.HostId}); err != nil {
|
|
|
+ allErrors = multierror.Append(allErrors, err)
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
// 只有在上述所有步骤都没有出错的情况下,才执行最终的数据库更新和Redis标记
|
|
|
if allErrors.ErrorOrNil() == nil {
|
|
|
- // 执行您指定的数据库“重置”操作
|
|
|
err := t.gatewayIpRep.CleanIPByHostId(ctx, []int64{hostId})
|
|
|
if err != nil {
|
|
|
allErrors = multierror.Append(allErrors, err)
|