Procházet zdrojové kódy

refactor(internal): 重构 WAF 任务处理逻辑

- 优化了数据查找与决策层函数,提高代码可读性和维护性- 重构了业务执行与公共 API 层,增强了错误处理和日志记录
- 统一了套餐操作流程,提高了代码复用性
- 调整了部分函数命名,使其更加符合实际功能
- 移除了未使用的 GateWayGroupIpRepository 依赖
fusu před 3 týdny
rodič
revize
1873035fce
2 změnil soubory, kde provedl 271 přidání a 22 odebrání
  1. 1 2
      cmd/task/wire/wire_gen.go
  2. 270 20
      internal/task/waf.go

+ 1 - 2
cmd/task/wire/wire_gen.go

@@ -62,7 +62,6 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	cdnService := service.NewCdnService(serviceService, viperViper, requestService, cdnRepository)
 	globalLimitRepository := repository.NewGlobalLimitRepository(repositoryRepository)
 	expiredRepository := repository.NewExpiredRepository(repositoryRepository)
-	gateWayGroupIpRepository := repository.NewGateWayGroupIpRepository(repositoryRepository)
 	gatewayipRepository := repository.NewGatewayipRepository(repositoryRepository)
 	wafFormatterService := service.NewWafFormatterService(serviceService, globalLimitRepository, hostRepository, requiredService, parserService, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, rabbitMQ, hostService, gatewayipRepository, cdnService)
 	proxyRepository := repository.NewProxyRepository(repositoryRepository)
@@ -73,7 +72,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	sslCertService := service.NewSslCertService(serviceService, webForwardingRepository, cdnService)
 	websocketService := service.NewWebsocketService(serviceService, cdnService, webForwardingRepository)
 	webForwardingService := service.NewWebForwardingService(serviceService, requiredService, webForwardingRepository, crawlerService, parserService, wafFormatterService, aoDunService, rabbitMQ, gatewayipRepository, globalLimitRepository, cdnService, proxyService, sslCertService, websocketService)
-	wafTask := task.NewWafTask(webForwardingRepository, tcpforwardingRepository, udpForWardingRepository, cdnService, hostRepository, globalLimitRepository, expiredRepository, taskTask, gateWayGroupIpRepository, tcpforwardingService, udpForWardingService, webForwardingService)
+	wafTask := task.NewWafTask(webForwardingRepository, tcpforwardingRepository, udpForWardingRepository, cdnService, hostRepository, globalLimitRepository, expiredRepository, taskTask, gatewayipRepository, tcpforwardingService, udpForWardingService, webForwardingService)
 	taskServer := server.NewTaskServer(logger, userTask, gameShieldTask, wafTask)
 	jobJob := job.NewJob(transaction, logger, sidSid, rabbitMQ)
 	userJob := job.NewUserJob(jobJob, userRepository)

+ 270 - 20
internal/task/waf.go

@@ -91,34 +91,268 @@ type RenewalRequest struct {
 // =================== 核心辅助函数 (Core Helpers) =================
 // =================================================================
 
-// (wrapTaskError, getCdnWebIdsByHostIds, setCdnWebsitesState, executeRenewalActions 保持不变)
-// ...
-func (t *wafTask) wrapTaskError(taskName, step string, err error) error { /* ... */ return nil }
-func (t *wafTask) getCdnWebIdsByHostIds(ctx context.Context, hostIds []int) ([]int, error) { /* ... */ return nil, nil }
-func (t *wafTask) setCdnWebsitesState(ctx context.Context, ids []int, enable bool) error { /* ... */ return nil }
-func (t *wafTask) executeRenewalActions(ctx context.Context, reqs []RenewalRequest) error { /* ... */ return nil }
+// wrapTaskError 统一封装任务错误信息,方便日志和调试
+func (t *wafTask) wrapTaskError(taskName, step string, err error) error {
+	if err == nil {
+		return nil
+	}
+	return fmt.Errorf("执行[%s]-%s失败: %w", taskName, step, err)
+}
+
+// getCdnWebIdsByHostIds (原GetCdnWebId) 根据hostId列表获取所有关联的转发规则ID
+func (t *wafTask) getCdnWebIdsByHostIds(ctx context.Context, hostIds []int) ([]int, error) {
+	if len(hostIds) == 0 {
+		return nil, nil
+	}
+	var ids []int
+	var result *multierror.Error
+
+	tcpIds, err := t.tcpforwardingRep.GetTcpAll(ctx, hostIds)
+	if err != nil {
+		result = multierror.Append(result, err)
+	}
+	ids = append(ids, tcpIds...)
+
+	udpIds, err := t.udpForWardingRep.GetUdpAll(ctx, hostIds)
+	if err != nil {
+		result = multierror.Append(result, err)
+	}
+	ids = append(ids, udpIds...)
+
+	webIds, err := t.webForWardingRep.GetWebAll(ctx, hostIds)
+	if err != nil {
+		result = multierror.Append(result, err)
+	}
+	ids = append(ids, webIds...)
+
+	return ids, result.ErrorOrNil()
+}
+
+// setCdnWebsitesState (原BanServer) 启用或禁用一组CDN网站 (并发执行)
+func (t *wafTask) setCdnWebsitesState(ctx context.Context, ids []int, enable bool) error {
+	if len(ids) == 0 {
+		return nil
+	}
+	var wg sync.WaitGroup
+	errChan := make(chan error, len(ids))
+	wg.Add(len(ids))
+	for _, id := range ids {
+		go func(id int) {
+			defer wg.Done()
+			// cdn.EditWebIsOn 的第二个参数 isBan, false=启用, true=禁用
+			// 所以 enable=true 对应 isBan=false
+			if err := t.cdn.EditWebIsOn(ctx, int64(id), !enable); err != nil {
+				errChan <- err
+			}
+		}(id)
+	}
+	wg.Wait()
+	close(errChan)
+	var result *multierror.Error
+	for err := range errChan {
+		result = multierror.Append(result, err)
+	}
+	return result.ErrorOrNil()
+}
+
+// executeRenewalActions (原EditExpired) 执行续费操作,包括更新DB和调用CDN API
+func (t *wafTask) executeRenewalActions(ctx context.Context, reqs []RenewalRequest) error {
+	if len(reqs) == 0 {
+		return nil
+	}
+	var allErrors *multierror.Error
+	var wg sync.WaitGroup
+	wg.Add(len(reqs))
 
+	for _, req := range reqs {
+		go func(r RenewalRequest) {
+			defer wg.Done()
+			// 更新数据库状态
+			err := t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{HostId: r.HostId, ExpiredAt: r.ExpiredAt, State: true})
+			if err != nil {
+				allErrors = multierror.Append(allErrors, err)
+				return // 如果DB更新失败,不继续调用CDN API
+			}
+			// 调用CDN API续费
+			cdnErr := t.cdn.RenewPlan(ctx, v1.RenewalPlan{
+				UserPlanId:  int64(r.PlanId),
+				IsFree:      true,
+				DayTo:       time.Unix(r.ExpiredAt, 0).Format("2006-01-02"),
+				Period:      "monthly",
+				CountPeriod: 1,
+				PeriodDayTo: time.Unix(r.ExpiredAt, 0).Format("2006-01-02"),
+			})
+			if cdnErr != nil {
+				allErrors = multierror.Append(allErrors, cdnErr)
+			}
+		}(req)
+	}
+
+	wg.Wait()
+	return allErrors.ErrorOrNil()
+}
 
 // =================================================================
 // =================== 1. 数据查找与决策层 ==========================
 // =================================================================
 
-// (findPlansNeedingSync, findAllCurrentlyExpiredWAFPlans, findRecentlyExpiredWAFPlans, findStaleWAFPlans 保持不变)
-// ...
-func (t *wafTask) findPlansNeedingSync(ctx context.Context, wafLimits []model.GlobalLimit) ([]RenewalRequest, error) { /* ... */ return nil, nil }
-func (t *wafTask) findAllCurrentlyExpiredWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) { /* ... */ return nil, nil }
-func (t *wafTask) findRecentlyExpiredWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) { /* ... */ return nil, nil }
-func (t *wafTask) findStaleWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) { /* ... */ return nil, nil }
+// findPlansNeedingSync (原findMismatchedExpirations) 检查WAF和Host的到期时间差异,返回需要同步的请求
+func (t *wafTask) findPlansNeedingSync(ctx context.Context, wafLimits []model.GlobalLimit) ([]RenewalRequest, error) {
+	if len(wafLimits) == 0 {
+		return nil, nil
+	}
+	wafExpiredMap := make(map[int]int64, len(wafLimits))
+	wafPlanMap := make(map[int]int, len(wafLimits))
+	var hostIds []int
+	for _, limit := range wafLimits {
+		hostIds = append(hostIds, limit.HostId)
+		wafExpiredMap[limit.HostId] = limit.ExpiredAt
+		wafPlanMap[limit.HostId] = limit.RuleId
+	}
+
+	hostExpirations, err := t.hostRep.GetExpireTimeByHostId(ctx, hostIds)
+	if err != nil {
+		return nil, fmt.Errorf("获取主机到期时间失败: %w", err)
+	}
+	hostExpiredMap := make(map[int]int64, len(hostExpirations))
+	for _, h := range hostExpirations {
+		hostExpiredMap[h.HostId] = h.ExpiredAt
+	}
+
+	var renewalRequests []RenewalRequest
+	for hostId, wafExpiredTime := range wafExpiredMap {
+		hostTime, ok := hostExpiredMap[hostId]
+		if !ok || hostTime != wafExpiredTime {
+			planId, planOk := wafPlanMap[hostId]
+			if !planOk {
+				t.logger.Warn("数据不一致:在waf_limits中找不到hostId对应的套餐ID", zap.Int("hostId", hostId))
+				continue
+			}
+			renewalRequests = append(renewalRequests, RenewalRequest{HostId: hostId, ExpiredAt: hostTime, PlanId: planId})
+		}
+	}
+	return renewalRequests, nil
+}
+
+// findAllCurrentlyExpiredWAFPlans (原findAllCurrentlyExpiredPlans) 查找所有当前时间点已经到期的WAF记录
+func (t *wafTask) findAllCurrentlyExpiredWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) {
+	return t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, 0)
+}
+
+// findRecentlyExpiredWAFPlans (原findRecentlyExpiredPlans) 查找在过去7天内到期的WAF记录
+func (t *wafTask) findRecentlyExpiredWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) {
+	sevenDaysAgo := time.Now().Add(-7 * 24 * time.Hour).Unix()
+	now := time.Now().Unix()
+	return t.globalLimitRep.GetGlobalLimitsByExpirationRange(ctx, sevenDaysAgo, now)
+}
+
+// findStaleWAFPlans (原findStaleExpiredPlans) 查找7天前或更早就已到期的WAF记录
+func (t *wafTask) findStaleWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) {
+	sevenDaysAgoOffset := int64(-1 * SevenDaysInSeconds)
+	return t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, sevenDaysAgoOffset)
+}
 
 // =================================================================
 // ============== 2. 业务执行与公共API层 ===========================
 // =================================================================
 
-// (SynchronizationTime, StopPlan 保持不变)
-// ...
-func (t *wafTask) SynchronizationTime(ctx context.Context) error { /* ... */ return nil }
-func (t *wafTask) StopPlan(ctx context.Context) error { /* ... */ return nil }
+// SynchronizationTime 同步即将到期(1天内)的套餐时间
+func (t *wafTask) SynchronizationTime(ctx context.Context) error {
+	taskName := "同步到期时间"
+	wafLimits, err := t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, OneDaysInSeconds)
+	if err != nil {
+		return t.wrapTaskError(taskName, "查找失败", err)
+	}
+	if len(wafLimits) == 0 {
+		return nil
+	}
+
+	renewalRequests, err := t.findPlansNeedingSync(ctx, wafLimits)
+	if err != nil {
+		return t.wrapTaskError(taskName, "决策失败", err)
+	}
+
+	if len(renewalRequests) > 0 {
+		t.logger.Info("发现记录需要同步到期时间", zap.String("task", taskName), zap.Int("数量", len(renewalRequests)))
+		return t.wrapTaskError(taskName, "执行同步", t.executeRenewalActions(ctx, renewalRequests))
+	}
+	return nil
+}
+
+// StopPlan 停止所有已到期的套餐
+func (t *wafTask) StopPlan(ctx context.Context) error {
+	taskName := "停止到期套餐"
+	// 1. 查找所有理论上已到期的记录
+	expiredLimits, err := t.findAllCurrentlyExpiredWAFPlans(ctx)
+	if err != nil {
+		return t.wrapTaskError(taskName, "查找失败", err)
+	}
+	if len(expiredLimits) == 0 {
+		return nil
+	}
+
+	// 2. 决策 - 第1步:检查这些记录中是否已有续费但未同步的
+	renewalRequests, err := t.findPlansNeedingSync(ctx, expiredLimits)
+	if err != nil {
+		return t.wrapTaskError(taskName, "决策检查续费", err)
+	}
+	renewedHostIds := make(map[int]struct{}, len(renewalRequests))
+	for _, req := range renewalRequests {
+		if req.ExpiredAt > time.Now().Unix() {
+			renewedHostIds[req.HostId] = struct{}{}
+		}
+	}
+
+	// 2. 决策 - 第2步:筛选出真正需要停止的记录
+	var plansToClose []model.GlobalLimit
+	for _, limit := range expiredLimits {
+		if _, found := renewedHostIds[limit.HostId]; found {
+			t.logger.Info("发现已到期但刚续费的套餐,跳过停止操作", zap.String("task", taskName), zap.Int("hostId", limit.HostId))
+			continue
+		}
+		isClosed, err := t.expiredRep.IsPlanInList(ctx, repository.ClosedPlansList, int64(limit.HostId))
+		if err != nil {
+			t.logger.Error("决策[停止]:检查Redis套餐状态失败", zap.String("task", taskName), zap.Int("hostId", limit.HostId), zap.Error(err))
+			continue
+		}
+		if !isClosed {
+			plansToClose = append(plansToClose, limit)
+		}
+	}
+
+	if len(plansToClose) == 0 {
+		t.logger.Info("没有需要停止的套餐(可能均已续费或已在停止列表)", zap.String("task", taskName))
+		return nil
+	}
+
+	// 3. 执行停止操作
+	t.logger.Info("开始关闭到期的WAF服务", zap.String("task", taskName), zap.Int("数量", len(plansToClose)))
+	var hostIds []int
+	for _, limit := range plansToClose {
+		hostIds = append(hostIds, limit.HostId)
+	}
+
+	var allErrors *multierror.Error
+
+	webIds, err := t.getCdnWebIdsByHostIds(ctx, hostIds)
+	if err != nil {
+		allErrors = multierror.Append(allErrors, fmt.Errorf("获取cdn_web_id失败: %w", err))
+	} else {
+		if err := t.setCdnWebsitesState(ctx, webIds, false); err != nil { // enable=false
+			allErrors = multierror.Append(allErrors, fmt.Errorf("禁用服务失败: %w", err))
+		}
+	}
+
+	closedPlanIds := make([]int64, len(hostIds))
+	for i, id := range hostIds {
+		closedPlanIds[i] = int64(id)
+	}
+	if err := t.expiredRep.AddPlans(ctx, repository.ClosedPlansList, closedPlanIds...); err != nil {
+		allErrors = multierror.Append(allErrors, fmt.Errorf("标记为已关闭失败: %w", err))
+	}
 
+	return t.wrapTaskError(taskName, "执行停止", allErrors.ErrorOrNil())
+}
 
 // _recoverPlans 是一个统一的、可重用的套餐恢复流程
 func (t *wafTask) _recoverPlans(ctx context.Context, limitsToCheck []model.GlobalLimit, taskName string, redisListKey repository.PlanListType) error {
@@ -155,7 +389,7 @@ func (t *wafTask) _recoverPlans(ctx context.Context, limitsToCheck []model.Globa
 	if err != nil {
 		allErrors = multierror.Append(allErrors, fmt.Errorf("获取webId失败: %w", err))
 	} else {
-		if err := t.setCdnWebsitesState(ctx, webIds, true); err != nil {
+		if err := t.setCdnWebsitesState(ctx, webIds, true); err != nil { // enable=true
 			allErrors = multierror.Append(allErrors, fmt.Errorf("启用web服务失败: %w", err))
 		}
 	}
@@ -236,6 +470,7 @@ func (t *wafTask) CleanUpStaleRecords(ctx context.Context) error {
 	now := time.Now().Unix()
 	for _, limit := range uncleanedLimits {
 		hostExpiredTime, ok := hostExpiredMap[limit.HostId]
+		// 清理条件:主机记录不存在,或者主机记录的到期时间是过去时
 		if !ok || hostExpiredTime <= now {
 			plansToClean = append(plansToClean, limit)
 		}
@@ -281,7 +516,6 @@ func (t *wafTask) executeSinglePlanCleanup(ctx context.Context, limit model.Glob
 		allErrors = multierror.Append(allErrors, err)
 	}
 
-
 	// 删除关联的转发规则...
 	tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, limit.HostId)
 	if err != nil {
@@ -291,11 +525,27 @@ func (t *wafTask) executeSinglePlanCleanup(ctx context.Context, limit model.Glob
 			allErrors = multierror.Append(allErrors, err)
 		}
 	}
-	// ... 删除 UDP 和 Web 规则的逻辑保持不变
+
+	udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, limit.HostId)
+	if err != nil {
+		allErrors = multierror.Append(allErrors, err)
+	} else if len(udpIds) > 0 {
+		if err := t.udp.DeleteUdpForwarding(ctx, v1.DeleteUdpForwardingRequest{Ids: udpIds, HostId: limit.HostId}); err != nil {
+			allErrors = multierror.Append(allErrors, err)
+		}
+	}
+
+	webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, limit.HostId)
+	if err != nil {
+		allErrors = multierror.Append(allErrors, err)
+	} else if len(webIds) > 0 {
+		if err := t.web.DeleteWebForwarding(ctx, v1.DeleteWebForwardingRequest{Ids: webIds, HostId: limit.HostId}); err != nil {
+			allErrors = multierror.Append(allErrors, err)
+		}
+	}
 
 	// 只有在上述所有步骤都没有出错的情况下,才执行最终的数据库更新和Redis标记
 	if allErrors.ErrorOrNil() == nil {
-		// 执行您指定的数据库“重置”操作
 		err := t.gatewayIpRep.CleanIPByHostId(ctx, []int64{hostId})
 		if err != nil {
 			allErrors = multierror.Append(allErrors, err)