Parcourir la source

refactor(internal/task): 重构 WAF 任务处理逻辑

- 修改 GetCdnWebId 函数参数类型,支持批量处理
- 更新相关函数调用,以适应新的参数类型
- 优化错误处理,减少重复代码
-
fusu il y a 3 semaines
Parent
commit
c8e03938b8
1 fichiers modifiés avec 15 ajouts et 18 suppressions
  1. 15 18
      internal/task/waf.go

+ 15 - 18
internal/task/waf.go

@@ -91,16 +91,16 @@ type RenewalRequest struct {
 // =================================================================
 
 // 获取cdn web id
-func (t wafTask) GetCdnWebId(ctx context.Context,hostId int) ([]int, error) {
-	tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, hostId)
+func (t wafTask) GetCdnWebId(ctx context.Context,hostId []int) ([]int, error) {
+	tcpIds, err := t.tcpforwardingRep.GetTcpAll(ctx, hostId)
 	if err != nil {
 		return nil, err
 	}
-	udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, hostId)
+	udpIds, err := t.udpForWardingRep.GetUdpAll(ctx, hostId)
 	if err != nil {
 		return nil, err
 	}
-	webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, hostId)
+	webIds, err := t.webForWardingRep.GetWebAll(ctx, hostId)
 	if err != nil {
 		return nil, err
 	}
@@ -267,14 +267,13 @@ func (t *wafTask) executePlanRecovery(ctx context.Context, renewalRequests []Ren
 
 	var allErrors *multierror.Error
 
-	for _, v := range renewalRequests {
-		webIds, err := t.GetCdnWebId(ctx, v.HostId)
-		if err != nil {
-			allErrors = multierror.Append(allErrors, fmt.Errorf("执行[%s]-获取webId失败: %w", taskName, err))
-		}
-		if err := t.BanServer(ctx, webIds, true); err != nil {
-			allErrors = multierror.Append(allErrors, fmt.Errorf("执行[%s]-封禁webId失败: %w", taskName, err))
-		}
+	webIds, err := t.GetCdnWebId(ctx, hostIds)
+	if err != nil {
+		allErrors = multierror.Append(allErrors, fmt.Errorf("执行[%s]-获取webId失败: %w", taskName, err))
+	}
+
+	if err := t.BanServer(ctx, webIds, true); err != nil {
+		allErrors = multierror.Append(allErrors, fmt.Errorf("执行[%s]-封禁webId失败: %w", taskName, err))
 	}
 
 
@@ -349,12 +348,10 @@ func (t *wafTask) StopPlan(ctx context.Context) error {
 		hostIds = append(hostIds, limit.HostId)
 	}
 
-	for _, hostId := range hostIds {
-		webIds, err := t.GetCdnWebId(ctx, hostId)
-		if err != nil { return fmt.Errorf("执行[停止]-获取cdn_web_id失败: %w", err) }
-		if err := t.BanServer(ctx, webIds, false); err != nil {
-			return fmt.Errorf("执行[停止]-禁用服务失败: %w", err)
-		}
+	webIds, err := t.GetCdnWebId(ctx, hostIds)
+	if err != nil { return fmt.Errorf("执行[停止]-获取cdn_web_id失败: %w", err) }
+	if err := t.BanServer(ctx, webIds, false); err != nil {
+		return fmt.Errorf("执行[停止]-禁用服务失败: %w", err)
 	}