Browse Source

refactor(internal): 重构 WAF 任务处理逻辑

- 新增 GetTcpAll、GetUdpAll 和 GetWebAll 方法,用于获取所有数据
- 修改 GetCdnWebId 方法,支持批量获取 web ID
- 优化 WAF任务处理流程,减少循环中的数据库查询
- 调整 OneDaysInSeconds 常量,改为 1 天后秒数
fusu 3 weeks ago
parent
commit
02c6443598

+ 12 - 1
internal/repository/tcpforwarding.go

@@ -28,6 +28,8 @@ type TcpforwardingRepository interface {
 	GetIpCountByIp(ctx context.Context,ips []string) ([]v1.IpCountResult, error)
 	// 获取端口数量
 	GetPortCount(ctx context.Context,hostId int64, port string) (int64, error)
+	// 获取所有数据
+	GetTcpAll(ctx context.Context, hostIds []int) ([]int, error)
 }
 
 func NewTcpforwardingRepository(
@@ -233,4 +235,13 @@ func (r *tcpforwardingRepository) GetPortCount(ctx context.Context,hostId int64,
 		return 0, err
 	}
 	return count, nil
-}
+}
+
+
+func (r *tcpforwardingRepository) GetTcpAll(ctx context.Context, hostIds []int) ([]int, error) {
+	var res []int
+	if err := r.db.WithContext(ctx).Model(&model.Tcpforwarding{}).Where("host_id IN ?", hostIds).Select("cdn_web_id").Scan(&res).Error; err != nil {
+		return nil, err
+	}
+	return res, nil
+}

+ 10 - 1
internal/repository/udpforwarding.go

@@ -28,6 +28,7 @@ type UdpForWardingRepository interface {
 	GetIpCountByIp(ctx context.Context, ips []string) ([]v1.IpCountResult, error)
 	// 获取端口数量
 	GetPortCount(ctx context.Context, hostId int64, port string) (int64, error)
+	GetUdpAll(ctx context.Context, hostIds []int) ([]int, error)
 }
 
 func NewUdpForWardingRepository(
@@ -228,4 +229,12 @@ func (r *udpForWardingRepository) GetPortCount(ctx context.Context, hostId int64
 		return 0, err
 	}
 	return count, nil
-}
+}
+
+func (r *udpForWardingRepository) GetUdpAll(ctx context.Context, hostIds []int) ([]int, error) {
+	var res []int
+	if err:= r.db.WithContext(ctx).Model(&model.UdpForWarding{}).Where("host_id IN ?", hostIds).Select("cdn_web_id").Scan(&res).Error; err != nil {
+		return nil, err
+	}
+	return res, nil
+}

+ 10 - 0
internal/repository/webforwarding.go

@@ -38,6 +38,7 @@ type WebForwardingRepository interface {
 	GetDomainByHostIdPort(ctx context.Context, hostId int64, port string) ([]v1.Domain, error)
 	// 获取CDN的web配置的id
 	GetWebId(ctx context.Context, serverId int64) (int64, error)
+	GetWebAll(ctx context.Context, hostIds []int) ([]int, error)
 }
 
 func NewWebForwardingRepository(
@@ -337,4 +338,13 @@ func (r *webForwardingRepository) GetWebId(ctx context.Context, serverId int64)
 	}
 	return webId, nil
 
+}
+
+func (r *webForwardingRepository) GetWebAll(ctx context.Context, hostIds []int) ([]int, error) {
+	var res []int
+	if err := r.db.Model(&model.WebForwarding{}).WithContext(ctx).Where("host_id IN ?", hostIds).Select("cdn_web_id").Scan(&res).Error; err != nil {
+		return nil, err
+	}
+
+	return res, nil
 }

+ 25 - 26
internal/task/waf.go

@@ -49,22 +49,22 @@ type wafTask struct {
 
 
 const (
-	// 3天后秒数
-	OneDaysInSeconds = 3 * 24 * 60 * 60
+	// 1天后秒数
+	OneDaysInSeconds = 1 * 24 * 60 * 60
 	// 7天前秒数
 	SevenDaysInSeconds = 7 * 24 * 60 * 60 * -1
 )
 // 获取cdn web id
-func (t wafTask) GetCdnWebId(ctx context.Context,hostId int) ([]int, error) {
-	tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, hostId)
+func (t wafTask) GetCdnWebId(ctx context.Context,hostId []int) ([]int, error) {
+	tcpIds, err := t.tcpforwardingRep.GetTcpAll(ctx, hostId)
 	if err != nil {
 		return nil, err
 	}
-	udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, hostId)
+	udpIds, err := t.udpForWardingRep.GetUdpAll(ctx, hostId)
 	if err != nil {
 		return nil, err
 	}
-	webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, hostId)
+	webIds, err := t.webForWardingRep.GetWebAll(ctx, hostId)
 	if err != nil {
 		return nil, err
 	}
@@ -347,20 +347,18 @@ func (t *wafTask) StopPlan(ctx context.Context) error {
 	t.logger.Info("开始关闭已到期的WAF服务", zap.Int("数量", len(wafLimits)))
 	var allErrors *multierror.Error
 
+	var webIds []int
 	for _, limit := range wafLimits {
+		webIds = append(webIds, limit.HostId)
+	}
 
-		webIds, err := t.GetCdnWebId(ctx, limit.HostId)
-		if err != nil {
-			allErrors = multierror.Append(allErrors, fmt.Errorf("获取hostId %d 的webId失败: %w", limit.HostId, err))
-			continue // 继续处理下一个
-		}
-
-		if err := t.BanServer(ctx, webIds, false); err != nil {
-			allErrors = multierror.Append(allErrors, fmt.Errorf("关闭hostId %d 的服务失败: %w", limit.HostId, err))
-		}
 
+	if err := t.BanServer(ctx, webIds, false); err != nil {
+		allErrors = multierror.Append(allErrors, fmt.Errorf("关闭hostId %d 的服务失败: %w", webIds, err))
 	}
 
+
+
 	return allErrors.ErrorOrNil()
 }
 //对于到期7天内续费的产品需要进行恢复操作
@@ -394,25 +392,26 @@ func (t *wafTask) RecoverStopPlan(ctx context.Context) error {
 	t.logger.Info("发现已续费、需要恢复的WAF服务", zap.Int("数量", len(renewalRequests)))
 	var allErrors *multierror.Error
 
+
+	var webIds []int
 	for _, req := range renewalRequests {
-		// 启用CDN服务
-		webIds, err := t.GetCdnWebId(ctx, req.HostId)
-		if err != nil {
-			allErrors = multierror.Append(allErrors, fmt.Errorf("恢复hostId %d: 获取webId失败: %w", req.HostId, err))
-			continue
-		}
+		webIds = append(webIds, req.HostId)
+	}
 
-		if err := t.BanServer(ctx, webIds, true); err != nil {
-			allErrors = multierror.Append(allErrors, fmt.Errorf("恢复hostId %d: 启用服务失败: %w", req.HostId, err))
-			continue // 服务启用失败,暂时不更新数据库状态
-		}
 
-		// 更新数据库状态(到期时间),state 保持为 true
+	if err := t.BanServer(ctx, webIds, true); err != nil {
+		allErrors = multierror.Append(allErrors, fmt.Errorf("恢复hostId %d: 启用服务失败: %w", webIds, err))
+	}
+
+
+
+	for _, req := range renewalRequests {
 		if err := t.EditExpired(ctx, []RenewalRequest{req}); err != nil {
 			allErrors = multierror.Append(allErrors, fmt.Errorf("恢复hostId %d: 更新数据库状态失败: %w", req.HostId, err))
 		}
 	}
 
+
 	return allErrors.ErrorOrNil()
 }