Эх сурвалжийг харах

refactor(internal): 重构 WafTask 业务逻辑

- 新增 WafOperationsService 接口,将部分 WafTask 逻辑下沉到服务层
- 实现 WafOperationsService 的具体方法,包括:
  - GetForwardingRuleIds: 获取转发规则 ID
  - SetCdnWebsitesState: 设置 CDN 网站状态 - ExecuteRenewalActions:执行续费操作
  - CleanupPlan: 清理单个套餐资源
  - RecoverPlans:批量恢复套餐服务
- 更新 WafTask 构造函数,引入 WafOperationsService 依赖- 调整 wire.go 和 wire_gen.go,添加 WafOperationsService 的创建和注入
fusu 1 долоо хоног өмнө
parent
commit
05beddc8fa

+ 1 - 1
cmd/task/wire/wire.go

@@ -105,7 +105,7 @@ var serviceSet = wire.NewSet(
 	waf.NewZzybgpService,
 	admin2.NewWafLogService,
 	admin2.NewWafLogDataCleanService,
-
+	waf.NewWafOperationsService,
 )
 
 // build App

+ 3 - 2
cmd/task/wire/wire_gen.go

@@ -91,7 +91,8 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	webForwardingService := waf2.NewWebForwardingService(serviceService, requiredService, webForwardingRepository, crawlerService, parserService, wafFormatterService, aoDunService, rabbitMQ, gatewayipService, globalLimitRepository, cdnService, proxyService, sslCertService, websocketService, ccService, ccIpListService)
 	buildAudunService := waf2.NewBuildAudunService(serviceService, aoDunService, gatewayipRepository, hostService)
 	zzybgpService := waf2.NewZzybgpService(serviceService, gatewayipRepository, hostService, aoDunService)
-	wafTask := task.NewWafTask(webForwardingRepository, tcpforwardingRepository, udpForWardingRepository, cdnService, hostRepository, globalLimitRepository, expiredRepository, taskTask, gatewayipRepository, tcpforwardingService, udpForWardingService, webForwardingService, buildAudunService, zzybgpService)
+	wafOperationsService := waf2.NewWafOperationsService(serviceService, webForwardingRepository, tcpforwardingRepository, udpForWardingRepository, cdnService, hostRepository, globalLimitRepository, expiredRepository, gatewayipRepository, tcpforwardingService, udpForWardingService, webForwardingService, buildAudunService, zzybgpService)
+	wafTask := task.NewWafTask(webForwardingRepository, tcpforwardingRepository, udpForWardingRepository, cdnService, hostRepository, globalLimitRepository, expiredRepository, taskTask, gatewayipRepository, tcpforwardingService, udpForWardingService, webForwardingService, buildAudunService, zzybgpService, wafOperationsService)
 	taskServer := server.NewTaskServer(logger, userTask, gameShieldTask, wafTask)
 	jobJob := job.NewJob(transaction, logger, sidSid, rabbitMQ)
 	userJob := job.NewUserJob(jobJob, userRepository)
@@ -114,7 +115,7 @@ var jobSet = wire.NewSet(job.NewJob, job.NewUserJob, job.NewWhitelistJob, job.Ne
 
 var serverSet = wire.NewSet(server.NewTaskServer, server.NewJobServer)
 
-var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, gameShield.NewGameShieldService, service.NewCrawlerService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewHostService, gameShield.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewGameShieldUserIpService, waf2.NewWafFormatterService, flexCdn2.NewCdnService, service.NewRequestService, waf2.NewTcpforwardingService, waf2.NewUdpForWardingService, waf2.NewWebForwardingService, flexCdn2.NewProxyService, flexCdn2.NewSslCertService, flexCdn2.NewWebsocketService, waf2.NewCcService, waf2.NewGatewayipService, service.NewLogService, waf2.NewCcIpListService, waf2.NewBuildAudunService, waf2.NewZzybgpService, admin2.NewWafLogService, admin2.NewWafLogDataCleanService)
+var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, gameShield.NewGameShieldService, service.NewCrawlerService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewHostService, gameShield.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewGameShieldUserIpService, waf2.NewWafFormatterService, flexCdn2.NewCdnService, service.NewRequestService, waf2.NewTcpforwardingService, waf2.NewUdpForWardingService, waf2.NewWebForwardingService, flexCdn2.NewProxyService, flexCdn2.NewSslCertService, flexCdn2.NewWebsocketService, waf2.NewCcService, waf2.NewGatewayipService, service.NewLogService, waf2.NewCcIpListService, waf2.NewBuildAudunService, waf2.NewZzybgpService, admin2.NewWafLogService, admin2.NewWafLogDataCleanService, waf2.NewWafOperationsService)
 
 // build App
 func newApp(task2 *server.TaskServer,

+ 491 - 0
internal/service/api/waf/wafoperations.go

@@ -0,0 +1,491 @@
+package waf
+
+import (
+	"context"
+	"fmt"
+	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
+	"github.com/go-nunu/nunu-layout-advanced/internal/model"
+	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
+	waf2 "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
+	"github.com/go-nunu/nunu-layout-advanced/internal/service"
+	"github.com/go-nunu/nunu-layout-advanced/internal/service/api/flexCdn"
+	"github.com/hashicorp/go-multierror"
+	"go.uber.org/zap"
+
+	"sync"
+	"time"
+)
+
+// WafOperationsService WAF通用操作服务接口
+type WafOperationsService interface {
+	// 清理单个套餐的所有相关资源
+	CleanupPlan(ctx context.Context, limit model.GlobalLimit) error
+	// 批量恢复套餐服务
+	RecoverPlans(ctx context.Context, limits []model.GlobalLimit, redisListKey repository.PlanListType) error
+	// 获取主机关联的所有转发规则ID
+	GetForwardingRuleIds(ctx context.Context, hostIds []int) ([]int, error)
+	// 批量设置CDN网站状态
+	SetCdnWebsitesState(ctx context.Context, ids []int, enable bool) error
+	// 执行续费操作
+	ExecuteRenewalActions(ctx context.Context, reqs []RenewalRequest) error
+}
+
+type RenewalRequest struct {
+	HostId    int
+	ExpiredAt int64
+}
+
+func NewWafOperationsService(
+	service *service.Service,
+	webForWardingRep waf2.WebForwardingRepository,
+	tcpforwardingRep waf2.TcpforwardingRepository,
+	udpForWardingRep waf2.UdpForWardingRepository,
+	cdn flexCdn.CdnService,
+	hostRep repository.HostRepository,
+	globalLimitRep waf2.GlobalLimitRepository,
+	expiredRep repository.ExpiredRepository,
+	gatewayIpRep waf2.GatewayipRepository,
+	tcp TcpforwardingService,
+	udp UdpForWardingService,
+	web WebForwardingService,
+	buildAoDun BuildAudunService,
+	zzyBgp ZzybgpService,
+) WafOperationsService {
+	return &wafOperationsService{
+		Service:          service,
+		webForWardingRep: webForWardingRep,
+		tcpforwardingRep: tcpforwardingRep,
+		udpForWardingRep: udpForWardingRep,
+		cdn:              cdn,
+		hostRep:          hostRep,
+		globalLimitRep:   globalLimitRep,
+		expiredRep:       expiredRep,
+		gatewayIpRep:     gatewayIpRep,
+		tcp:              tcp,
+		udp:              udp,
+		web:              web,
+		buildAoDun:       buildAoDun,
+		zzyBgp:           zzyBgp,
+	}
+}
+
+type wafOperationsService struct {
+	*service.Service
+	webForWardingRep waf2.WebForwardingRepository
+	tcpforwardingRep waf2.TcpforwardingRepository
+	udpForWardingRep waf2.UdpForWardingRepository
+	cdn              flexCdn.CdnService
+	hostRep          repository.HostRepository
+	globalLimitRep   waf2.GlobalLimitRepository
+	expiredRep       repository.ExpiredRepository
+	gatewayIpRep     waf2.GatewayipRepository
+	tcp              TcpforwardingService
+	udp              UdpForWardingService
+	web              WebForwardingService
+	buildAoDun       BuildAudunService
+	zzyBgp           ZzybgpService
+}
+
+// GetForwardingRuleIds 获取主机关联的所有转发规则ID
+func (s *wafOperationsService) GetForwardingRuleIds(ctx context.Context, hostIds []int) ([]int, error) {
+	if len(hostIds) == 0 {
+		return nil, nil
+	}
+	
+	var ids []int
+	var result *multierror.Error
+
+	// 获取TCP转发规则ID
+	tcpIds, err := s.tcpforwardingRep.GetTcpAll(ctx, hostIds)
+	if err != nil {
+		result = multierror.Append(result, fmt.Errorf("获取TCP转发规则失败: %w", err))
+	}
+	ids = append(ids, tcpIds...)
+
+	// 获取UDP转发规则ID
+	udpIds, err := s.udpForWardingRep.GetUdpAll(ctx, hostIds)
+	if err != nil {
+		result = multierror.Append(result, fmt.Errorf("获取UDP转发规则失败: %w", err))
+	}
+	ids = append(ids, udpIds...)
+
+	// 获取Web转发规则ID
+	webIds, err := s.webForWardingRep.GetWebAll(ctx, hostIds)
+	if err != nil {
+		result = multierror.Append(result, fmt.Errorf("获取Web转发规则失败: %w", err))
+	}
+	ids = append(ids, webIds...)
+
+	return ids, result.ErrorOrNil()
+}
+
+// SetCdnWebsitesState 批量设置CDN网站状态(并发执行)
+func (s *wafOperationsService) SetCdnWebsitesState(ctx context.Context, ids []int, enable bool) error {
+	if len(ids) == 0 {
+		return nil
+	}
+	
+	var wg sync.WaitGroup
+	errChan := make(chan error, len(ids))
+	wg.Add(len(ids))
+	
+	for _, id := range ids {
+		go func(id int) {
+			defer wg.Done()
+			// cdn.EditWebIsOn 的第二个参数 enable: true=启用, false=禁用
+			if err := s.cdn.EditWebIsOn(ctx, int64(id), enable); err != nil {
+				errChan <- fmt.Errorf("设置CDN网站状态失败(ID:%d): %w", id, err)
+			}
+		}(id)
+	}
+	
+	wg.Wait()
+	close(errChan)
+	
+	var result *multierror.Error
+	for err := range errChan {
+		result = multierror.Append(result, err)
+	}
+	
+	return result.ErrorOrNil()
+}
+
+// ExecuteRenewalActions 执行续费操作,包括更新DB和调用CDN API
+// 该方法并发更新数据库中的套餐状态,将到期时间和状态同步到GlobalLimit表
+// 主要用于套餐续费后,需要将最新的到期时间从主机表同步到WAF套餐表
+//
+// 执行流程:
+// 1. 参数校验,如果没有续费请求则直接返回
+// 2. 创建goroutine池,每个续费请求对应一个goroutine
+// 3. 并发调用数据库更新操作,提高处理效率
+// 4. 使用互斥锁保护错误收集,避免并发写入冲突
+// 5. 等待所有更新操作完成,返回聚合的错误信息
+//
+// 并发安全:
+// - 使用sync.Mutex保护共享的错误收集器
+// - 每个goroutine独立处理一个续费请求,避免数据竞争
+// - 使用WaitGroup确保所有操作完成后才返回
+//
+// 参数:
+//   - ctx: 上下文对象,用于控制请求生命周期和传递trace信息
+//   - reqs: 续费请求列表,包含HostId和新的到期时间
+//
+// 返回:
+//   - error: 更新过程中的任何错误,如果部分失败会包含所有失败的详细信息
+func (s *wafOperationsService) ExecuteRenewalActions(ctx context.Context, reqs []RenewalRequest) error {
+	// 参数校验:如果没有续费请求,直接返回成功
+	if len(reqs) == 0 {
+		return nil
+	}
+	
+	// 并发控制和错误收集初始化
+	var allErrors *multierror.Error
+	var wg sync.WaitGroup
+	var mu sync.Mutex // 保护allErrors的并发写入
+	
+	wg.Add(len(reqs))
+	
+	// 为每个续费请求创建一个goroutine进行并发处理
+	for _, req := range reqs {
+		go func(r RenewalRequest) {
+			defer wg.Done()
+			// 更新数据库中的套餐状态
+			// 将State设置为true表示套餐处于激活状态
+			// ExpiredAt更新为最新的到期时间
+			err := s.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
+				HostId:    r.HostId,    // 主机ID,用于定位具体的套餐
+				ExpiredAt: r.ExpiredAt, // 新的到期时间戳
+				State:     true,       // 激活状态,表示套餐可用
+			})
+			if err != nil {
+				// 线程安全的错误收集
+				mu.Lock()
+				allErrors = multierror.Append(allErrors, fmt.Errorf("更新主机%d续费状态失败: %w", r.HostId, err))
+				mu.Unlock()
+			}
+		}(req)
+	}
+
+	// 等待所有更新操作完成
+	wg.Wait()
+	return allErrors.ErrorOrNil()
+}
+
+// CleanupPlan 清理单个套餐的所有相关资源
+// 该方法执行套餐过期后的完整清理流程,包括删除转发规则、重置防护设置、清理网络配置等
+// 这是一个复合操作,涉及多个子系统的协调,确保套餐相关的所有资源都被正确清理
+//
+// 清理步骤(按执行顺序):
+// 1. 从Redis "停止列表" 中移除该套餐(因为即将转移到 "已清理列表")
+// 2. 删除TCP转发规则 - 清理所有TCP端口转发配置
+// 3. 删除UDP转发规则 - 清理所有UDP端口转发配置  
+// 4. 删除Web转发规则 - 清理所有HTTP/HTTPS转发配置
+// 5. 重置BGP防护设置 - 将防护等级重置为默认值(10)
+// 6. 清除带宽限制 - 移除小防火墙的带宽限制配置
+// 7. 清理网关IP配置 - 删除该主机关联的所有网关IP
+// 8. 将套餐标记为"已清理" - 添加到Redis "已清理列表"
+//
+// 错误处理策略:
+// - 使用multierror收集所有步骤的错误,不会因单个步骤失败而中断整个流程
+// - 只有在前面所有步骤都成功的情况下,才执行最终的网关IP清理和Redis标记
+// - 记录详细的日志信息,便于问题排查和监控
+//
+// 幂等性保证:
+// - 该方法可以安全地重复调用,不会产生副作用
+// - 如果某个资源已经被清理,相关操作会优雅地处理
+//
+// 参数:
+//   - ctx: 上下文对象,用于控制请求生命周期和传递trace信息
+//   - limit: 需要清理的套餐信息,包含HostId、Uid等关键字段
+//
+// 返回:
+//   - error: 清理过程中的任何错误,使用multierror聚合多个错误
+func (s *wafOperationsService) CleanupPlan(ctx context.Context, limit model.GlobalLimit) error {
+	var allErrors *multierror.Error
+	hostId := int64(limit.HostId)
+
+	// 记录清理开始的日志,便于监控和调试
+	s.Logger.Info("开始清理套餐资源", 
+		zap.Int("hostId", limit.HostId),
+		zap.Int("uid", limit.Uid),
+		zap.String("operation", "cleanup_plan"))
+
+	// 步骤1: 从Redis "停止列表" 中移除该套餐
+	// 这是状态转换的第一步,表示套餐即将从 "已停止" 状态转换到 "已清理" 状态
+	if err := s.expiredRep.RemovePlans(ctx, repository.ClosedPlansList, hostId); err != nil {
+		allErrors = multierror.Append(allErrors, fmt.Errorf("从停止列表移除失败: %w", err))
+		s.Logger.Warn("从停止列表移除失败", zap.Int64("hostId", hostId), zap.Error(err))
+	}
+
+	// 步骤2: 删除TCP转发规则
+	// TCP转发规则通常用于游戏服务器、数据库等需要TCP连接的服务
+	// 需要先获取所有关联的TCP规则ID,然后批量删除
+	tcpIds, err := s.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, limit.HostId)
+	if err != nil {
+		allErrors = multierror.Append(allErrors, fmt.Errorf("获取TCP转发规则失败: %w", err))
+		s.Logger.Warn("获取TCP转发规则失败", zap.Int("hostId", limit.HostId), zap.Error(err))
+	} else if len(tcpIds) > 0 {
+		s.Logger.Info("开始删除TCP转发规则", zap.Int("hostId", limit.HostId), zap.Int("count", len(tcpIds)))
+		if err := s.tcp.DeleteTcpForwarding(ctx, v1.DeleteTcpForwardingRequest{
+			Ids:    tcpIds,    // 需要删除的TCP规则ID列表
+			HostId: limit.HostId, // 主机ID,用于权限验证
+			Uid:    limit.Uid,    // 用户ID,用于权限验证
+		}); err != nil {
+			allErrors = multierror.Append(allErrors, fmt.Errorf("删除TCP转发规则失败: %w", err))
+			s.Logger.Error("删除TCP转发规则失败", zap.Int("hostId", limit.HostId), zap.Error(err))
+		} else {
+			s.Logger.Info("成功删除TCP转发规则", zap.Int("hostId", limit.HostId), zap.Int("count", len(tcpIds)))
+		}
+	} else {
+		s.Logger.Debug("该主机没有TCP转发规则需要删除", zap.Int("hostId", limit.HostId))
+	}
+
+	// 步骤3: 删除UDP转发规则  
+	// UDP转发规则通常用于游戏服务器、DNS服务等需要UDP连接的服务
+	// UDP协议的特点是无连接,但在防护场景下同样需要转发规则
+	udpIds, err := s.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, limit.HostId)
+	if err != nil {
+		allErrors = multierror.Append(allErrors, fmt.Errorf("获取UDP转发规则失败: %w", err))
+		s.Logger.Warn("获取UDP转发规则失败", zap.Int("hostId", limit.HostId), zap.Error(err))
+	} else if len(udpIds) > 0 {
+		s.Logger.Info("开始删除UDP转发规则", zap.Int("hostId", limit.HostId), zap.Int("count", len(udpIds)))
+		if err := s.udp.DeleteUdpForwarding(ctx, v1.DeleteUdpForwardingRequest{
+			Ids:    udpIds,       // 需要删除的UDP规则ID列表
+			HostId: limit.HostId, // 主机ID,用于权限验证
+			Uid:    limit.Uid,    // 用户ID,用于权限验证
+		}); err != nil {
+			allErrors = multierror.Append(allErrors, fmt.Errorf("删除UDP转发规则失败: %w", err))
+			s.Logger.Error("删除UDP转发规则失败", zap.Int("hostId", limit.HostId), zap.Error(err))
+		} else {
+			s.Logger.Info("成功删除UDP转发规则", zap.Int("hostId", limit.HostId), zap.Int("count", len(udpIds)))
+		}
+	} else {
+		s.Logger.Debug("该主机没有UDP转发规则需要删除", zap.Int("hostId", limit.HostId))
+	}
+
+	// 步骤4: 删除Web转发规则
+	// Web转发规则用于HTTP/HTTPS网站服务,是最常见的转发类型
+	// 包括域名解析、SSL证书、负载均衡等复杂配置
+	webIds, err := s.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, limit.HostId)
+	if err != nil {
+		allErrors = multierror.Append(allErrors, fmt.Errorf("获取Web转发规则失败: %w", err))
+		s.Logger.Warn("获取Web转发规则失败", zap.Int("hostId", limit.HostId), zap.Error(err))
+	} else if len(webIds) > 0 {
+		s.Logger.Info("开始删除Web转发规则", zap.Int("hostId", limit.HostId), zap.Int("count", len(webIds)))
+		if err := s.web.DeleteWebForwarding(ctx, v1.DeleteWebForwardingRequest{
+			Ids:    webIds,       // 需要删除的Web规则ID列表
+			HostId: limit.HostId, // 主机ID,用于权限验证
+			Uid:    limit.Uid,    // 用户ID,用于权限验证
+		}); err != nil {
+			allErrors = multierror.Append(allErrors, fmt.Errorf("删除Web转发规则失败: %w", err))
+			s.Logger.Error("删除Web转发规则失败", zap.Int("hostId", limit.HostId), zap.Error(err))
+		} else {
+			s.Logger.Info("成功删除Web转发规则", zap.Int("hostId", limit.HostId), zap.Int("count", len(webIds)))
+		}
+	} else {
+		s.Logger.Debug("该主机没有Web转发规则需要删除", zap.Int("hostId", limit.HostId))
+	}
+
+	// 步骤5: 重置BGP防护设置
+	// 将防护等级重置为默认值(10),这通常是最低的防护级别
+	// BGP防护是网络层面的DDoS防护,重置后将停止高级防护功能
+	s.Logger.Info("开始重置BGP防护设置", zap.Int64("hostId", hostId), zap.Int("defenseLevel", 10))
+	if err := s.zzyBgp.SetDefense(ctx, hostId, 10); err != nil {
+		allErrors = multierror.Append(allErrors, fmt.Errorf("重置BGP防护设置失败: %w", err))
+		s.Logger.Error("重置BGP防护设置失败", zap.Int64("hostId", hostId), zap.Error(err))
+	} else {
+		s.Logger.Info("成功重置BGP防护设置", zap.Int64("hostId", hostId))
+	}
+
+	// 步骤6: 清除小防火墙带宽限制
+	// 移除奥盾防护系统中设置的带宽限制配置
+	// "del" 操作表示删除该主机的所有带宽限制规则
+	s.Logger.Info("开始清除小防火墙带宽限制", zap.Int64("hostId", hostId))
+	if err := s.buildAoDun.Bandwidth(ctx, hostId, "del"); err != nil {
+		allErrors = multierror.Append(allErrors, fmt.Errorf("清除带宽限制失败: %w", err))
+		s.Logger.Error("清除带宽限制失败", zap.Int64("hostId", hostId), zap.Error(err))
+	} else {
+		s.Logger.Info("成功清除小防火墙带宽限制", zap.Int64("hostId", hostId))
+	}
+
+	// 步骤7: 执行最终清理操作(仅在前面步骤都成功时执行)
+	// 这是一个关键的设计决策:只有在所有资源清理都成功的情况下,
+	// 才执行网关IP清理和状态标记,确保数据一致性
+	if allErrors.ErrorOrNil() == nil {
+		s.Logger.Info("前置清理步骤全部成功,开始执行最终清理操作", zap.Int64("hostId", hostId))
+		
+		// 步骤7a: 清理网关IP配置
+		// 删除该主机在网关系统中的所有IP配置,断开网络连接
+		if err := s.gatewayIpRep.CleanIPByHostId(ctx, []int64{hostId}); err != nil {
+			allErrors = multierror.Append(allErrors, fmt.Errorf("清理网关IP失败: %w", err))
+			s.Logger.Error("清理网关IP失败", zap.Int64("hostId", hostId), zap.Error(err))
+		} else {
+			s.Logger.Info("成功清理网关IP", zap.Int64("hostId", hostId))
+		}
+
+		// 步骤7b: 将套餐标记为"已清理"状态
+		// 添加到ExpiringSoonPlansList(已清理列表),表示清理流程完成
+		// 这个标记用于防止重复清理和状态跟踪
+		if err := s.expiredRep.AddPlans(ctx, repository.ExpiringSoonPlansList, hostId); err != nil {
+			allErrors = multierror.Append(allErrors, fmt.Errorf("标记为已清理失败: %w", err))
+			s.Logger.Error("标记为已清理失败", zap.Int64("hostId", hostId), zap.Error(err))
+		} else {
+			s.Logger.Info("成功标记套餐为已清理状态", zap.Int64("hostId", hostId))
+		}
+	} else {
+		// 如果前面的步骤有失败,记录警告日志,不执行最终清理
+		s.Logger.Warn("由于前置清理步骤存在错误,跳过最终清理操作", 
+			zap.Int64("hostId", hostId), 
+			zap.Error(allErrors.ErrorOrNil()))
+	}
+
+	// 记录最终的清理结果
+	if allErrors.ErrorOrNil() != nil {
+		s.Logger.Error("清理套餐资源失败", 
+			zap.Int("hostId", limit.HostId), 
+			zap.Int("uid", limit.Uid),
+			zap.Error(allErrors.ErrorOrNil()))
+	} else {
+		s.Logger.Info("成功清理套餐资源", 
+			zap.Int("hostId", limit.HostId),
+			zap.Int("uid", limit.Uid),
+			zap.String("status", "completed"))
+	}
+
+	return allErrors.ErrorOrNil()
+}
+
+// RecoverPlans 批量恢复套餐服务
+func (s *wafOperationsService) RecoverPlans(ctx context.Context, limits []model.GlobalLimit, redisListKey repository.PlanListType) error {
+	if len(limits) == 0 {
+		return nil
+	}
+
+	// 1. 检查哪些套餐需要恢复(已续费且未过期)
+	var hostIdsToCheck []int
+	for _, limit := range limits {
+		hostIdsToCheck = append(hostIdsToCheck, limit.HostId)
+	}
+
+	// 2. 获取最新的主机到期时间
+	hostExpirations, err := s.hostRep.GetExpireTimeByHostId(ctx, hostIdsToCheck)
+	if err != nil {
+		return fmt.Errorf("获取主机到期时间失败: %w", err)
+	}
+
+	hostExpiredMap := make(map[int]int64, len(hostExpirations))
+	for _, h := range hostExpirations {
+		hostExpiredMap[h.HostId] = h.ExpiredAt
+	}
+
+	// 3. 筛选出需要恢复的套餐
+	var renewalRequests []RenewalRequest
+	var hostIdsToRecover []int
+	now := time.Now().Unix()
+
+	for _, limit := range limits {
+		if hostTime, ok := hostExpiredMap[limit.HostId]; ok && hostTime > now {
+			renewalRequests = append(renewalRequests, RenewalRequest{
+				HostId:    limit.HostId,
+				ExpiredAt: hostTime,
+			})
+			hostIdsToRecover = append(hostIdsToRecover, limit.HostId)
+		}
+	}
+
+	if len(renewalRequests) == 0 {
+		s.Logger.Info("没有需要恢复的套餐")
+		return nil
+	}
+
+	s.Logger.Info("开始恢复已续费的WAF服务", 
+		zap.Int("数量", len(renewalRequests)), 
+		zap.Any("套餐内容", renewalRequests))
+
+	var allErrors *multierror.Error
+
+	// 4. 启用CDN服务
+	webIds, err := s.GetForwardingRuleIds(ctx, hostIdsToRecover)
+	if err != nil {
+		allErrors = multierror.Append(allErrors, fmt.Errorf("获取转发规则ID失败: %w", err))
+	} else {
+		if err := s.SetCdnWebsitesState(ctx, webIds, true); err != nil {
+			allErrors = multierror.Append(allErrors, fmt.Errorf("启用CDN服务失败: %w", err))
+		}
+	}
+
+	// 5. 同步续费信息到数据库
+	if err := s.ExecuteRenewalActions(ctx, renewalRequests); err != nil {
+		allErrors = multierror.Append(allErrors, fmt.Errorf("同步续费信息失败: %w", err))
+	}
+
+	// 步骤7: 状态清理 - 从Redis相关列表中移除停止/清理标记
+	// 将hostId转换为int64类型以符合Redis操作接口要求
+	planIdsToRecover := make([]int64, len(hostIdsToRecover))
+	for i, id := range hostIdsToRecover {
+		planIdsToRecover[i] = int64(id)
+	}
+	
+	s.Logger.Info("开始从Redis列表移除状态标记", 
+		zap.String("listKey", string(redisListKey)),
+		zap.Int("套餐数量", len(planIdsToRecover)))
+	if err := s.expiredRep.RemovePlans(ctx, redisListKey, planIdsToRecover...); err != nil {
+		allErrors = multierror.Append(allErrors, fmt.Errorf("从Redis列表移除标记失败: %w", err))
+		s.Logger.Error("从Redis列表移除标记失败", zap.Error(err))
+	} else {
+		s.Logger.Info("成功从Redis列表移除状态标记")
+	}
+
+	// 记录最终的恢复结果
+	if allErrors.ErrorOrNil() != nil {
+		s.Logger.Error("恢复套餐服务部分失败", 
+			zap.Int("成功数量", len(renewalRequests)),
+			zap.Error(allErrors.ErrorOrNil()))
+	} else {
+		s.Logger.Info("成功恢复套餐服务", 
+			zap.Int("恢复数量", len(renewalRequests)),
+			zap.String("status", "completed"))
+	}
+
+	return allErrors.ErrorOrNil()
+}

+ 22 - 207
internal/task/waf.go

@@ -3,7 +3,6 @@ package task
 import (
 	"context"
 	"fmt"
-	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
 	waf2 "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
@@ -48,6 +47,7 @@ func NewWafTask(
 	web waf.WebForwardingService,
 	buildAoDun waf.BuildAudunService,
 	zzyBgp waf.ZzybgpService,
+	wafOps waf.WafOperationsService,
 ) WafTask {
 	return &wafTask{
 		Task:              task,
@@ -63,7 +63,8 @@ func NewWafTask(
 		udp:               udp,
 		web:               web,
 		buildAoDun:        buildAoDun,
-		zzyBgp :           zzyBgp,
+		zzyBgp:            zzyBgp,
+		wafOps:            wafOps,
 	}
 }
 
@@ -78,10 +79,11 @@ type wafTask struct {
 	expiredRep       repository.ExpiredRepository
 	gatewayIpRep     waf2.GatewayipRepository
 	tcp              waf.TcpforwardingService
-	udp waf.UdpForWardingService
-	web waf.WebForwardingService
-	buildAoDun waf.BuildAudunService
-	zzyBgp waf.ZzybgpService
+	udp              waf.UdpForWardingService
+	web              waf.WebForwardingService
+	buildAoDun       waf.BuildAudunService
+	zzyBgp           waf.ZzybgpService
+	wafOps           waf.WafOperationsService
 }
 
 const (
@@ -89,10 +91,8 @@ const (
 	SevenDaysInSeconds = 7 * 24 * 60 * 60
 )
 
-type RenewalRequest struct {
-	HostId    int
-	ExpiredAt int64
-}
+// RenewalRequest 现在使用service层的定义
+type RenewalRequest = waf.RenewalRequest
 
 // =================================================================
 // =================== 核心辅助函数 (Core Helpers) =================
@@ -106,87 +106,19 @@ func (t *wafTask) wrapTaskError(taskName, step string, err error) error {
 	return fmt.Errorf("执行[%s]-%s失败: %w", taskName, step, err)
 }
 
-// getCdnWebIdsByHostIds (原GetCdnWebId) 根据hostId列表获取所有关联的转发规则ID
+// getCdnWebIdsByHostIds 委托给service层处理
 func (t *wafTask) getCdnWebIdsByHostIds(ctx context.Context, hostIds []int) ([]int, error) {
-	if len(hostIds) == 0 {
-		return nil, nil
-	}
-	var ids []int
-	var result *multierror.Error
-
-	tcpIds, err := t.tcpforwardingRep.GetTcpAll(ctx, hostIds)
-	if err != nil {
-		result = multierror.Append(result, err)
-	}
-	ids = append(ids, tcpIds...)
-
-	udpIds, err := t.udpForWardingRep.GetUdpAll(ctx, hostIds)
-	if err != nil {
-		result = multierror.Append(result, err)
-	}
-	ids = append(ids, udpIds...)
-
-	webIds, err := t.webForWardingRep.GetWebAll(ctx, hostIds)
-	if err != nil {
-		result = multierror.Append(result, err)
-	}
-	ids = append(ids, webIds...)
-
-	return ids, result.ErrorOrNil()
+	return t.wafOps.GetForwardingRuleIds(ctx, hostIds)
 }
 
-// setCdnWebsitesState (原BanServer) 启用或禁用一组CDN网站 (并发执行)
+// setCdnWebsitesState 委托给service层处理
 func (t *wafTask) setCdnWebsitesState(ctx context.Context, ids []int, enable bool) error {
-	if len(ids) == 0 {
-		return nil
-	}
-	var wg sync.WaitGroup
-	errChan := make(chan error, len(ids))
-	wg.Add(len(ids))
-	for _, id := range ids {
-		go func(id int) {
-			defer wg.Done()
-			// cdn.EditWebIsOn 的第二个参数 isBan, false=启用, true=禁用
-			// 所以 enable=true 对应 isBan=false
-			if err := t.cdn.EditWebIsOn(ctx, int64(id), enable); err != nil {
-				errChan <- err
-			}
-		}(id)
-	}
-	wg.Wait()
-	close(errChan)
-	var result *multierror.Error
-	for err := range errChan {
-		result = multierror.Append(result, err)
-	}
-	return result.ErrorOrNil()
+	return t.wafOps.SetCdnWebsitesState(ctx, ids, enable)
 }
 
-// executeRenewalActions (原EditExpired) 执行续费操作,包括更新DB和调用CDN API
+// executeRenewalActions 委托给service层处理
 func (t *wafTask) executeRenewalActions(ctx context.Context, reqs []RenewalRequest) error {
-	if len(reqs) == 0 {
-		return nil
-	}
-	var allErrors *multierror.Error
-	var wg sync.WaitGroup
-	wg.Add(len(reqs))
-	var mu sync.Mutex
-	for _, req := range reqs {
-		go func(r RenewalRequest) {
-			defer wg.Done()
-			// 更新数据库状态
-			err := t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{HostId: r.HostId, ExpiredAt: r.ExpiredAt, State: true})
-			if err != nil {
-				mu.Lock() // 在修改前加锁
-				allErrors = multierror.Append(allErrors, err)
-				mu.Unlock() // 修改后解锁
-				return // 如果DB更新失败,不继续调用CDN API
-			}
-		}(req)
-	}
-
-	wg.Wait()
-	return allErrors.ErrorOrNil()
+	return t.wafOps.ExecuteRenewalActions(ctx, reqs)
 }
 
 // =================================================================
@@ -347,60 +279,12 @@ func (t *wafTask) StopPlan(ctx context.Context) error {
 	return t.wrapTaskError(taskName, "执行停止", allErrors.ErrorOrNil())
 }
 
-// _recoverPlans 是一个统一的、可重用的套餐恢复流程
+// _recoverPlans 委托给service层处理套餐恢复流程
 func (t *wafTask) _recoverPlans(ctx context.Context, limitsToCheck []model.GlobalLimit, taskName string, redisListKey repository.PlanListType) error {
-	if len(limitsToCheck) == 0 {
-		return nil
-	}
-
-	requestsToSync, err := t.findPlansNeedingSync(ctx, limitsToCheck)
-	if err != nil {
-		return t.wrapTaskError(taskName, "决策检查续费状态", err)
-	}
-
-	var finalRecoveryRequests []RenewalRequest
-	for _, req := range requestsToSync {
-		if req.ExpiredAt > time.Now().Unix() {
-			finalRecoveryRequests = append(finalRecoveryRequests, req)
-		}
-	}
-
-	if len(finalRecoveryRequests) == 0 {
-		t.logger.Info("在检查范围内未发现已续费的套餐", zap.String("task", taskName))
-		return nil
-	}
-
-	t.logger.Info("开始恢复已续费的WAF服务", zap.String("task", taskName), zap.Int("数量", len(finalRecoveryRequests)), zap.Any("套餐内容", finalRecoveryRequests))
-
-	var hostIdsToRecover []int
-	for _, req := range finalRecoveryRequests {
-		hostIdsToRecover = append(hostIdsToRecover, req.HostId)
-	}
-
-	var allErrors *multierror.Error
-	webIds, err := t.getCdnWebIdsByHostIds(ctx, hostIdsToRecover)
-	if err != nil {
-		allErrors = multierror.Append(allErrors, fmt.Errorf("获取webId失败: %w", err))
-	} else {
-		if err := t.setCdnWebsitesState(ctx, webIds, true); err != nil { // enable=true
-			allErrors = multierror.Append(allErrors, fmt.Errorf("启用web服务失败: %w", err))
-		}
-	}
-
-	if err := t.executeRenewalActions(ctx, finalRecoveryRequests); err != nil {
-		allErrors = multierror.Append(allErrors, fmt.Errorf("同步续费信息失败: %w", err))
-	}
-
-	planIdsToRecover := make([]int64, len(hostIdsToRecover))
-	for i, id := range hostIdsToRecover {
-		planIdsToRecover[i] = int64(id)
-	}
-	// 从指定的Redis列表中移除标记 (ClosedPlansList 或 ExpiringSoonPlansList)
-	if err := t.expiredRep.RemovePlans(ctx, redisListKey, planIdsToRecover...); err != nil {
-		allErrors = multierror.Append(allErrors, fmt.Errorf("从Redis列表 '%s' 移除标记失败: %w", redisListKey, err))
+	if err := t.wafOps.RecoverPlans(ctx, limitsToCheck, redisListKey); err != nil {
+		return t.wrapTaskError(taskName, "执行恢复", err)
 	}
-
-	return t.wrapTaskError(taskName, "执行恢复", allErrors.ErrorOrNil())
+	return nil
 }
 
 // 3. RecoverRecentPlan 恢复7天内续费的套餐
@@ -499,78 +383,9 @@ func (t *wafTask) CleanUpStaleRecords(ctx context.Context) error {
 	return t.wrapTaskError(taskName, "执行清理", allErrors.ErrorOrNil())
 }
 
-// executeSinglePlanCleanup 执行对单个套餐的完整清理操作,方便并发调用
+// executeSinglePlanCleanup 委托给service层处理单个套餐清理
 func (t *wafTask) executeSinglePlanCleanup(ctx context.Context, limit model.GlobalLimit) error {
-	var allErrors *multierror.Error
-	hostId := int64(limit.HostId)
-
-	// 从“停止列表”中移除,因为它即将被归档到“已清理列表”
-	if err := t.expiredRep.RemovePlans(ctx, repository.ClosedPlansList, hostId); err != nil {
-		allErrors = multierror.Append(allErrors, err)
-	}
-
-	// 删除关联的转发规则...
-	tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, limit.HostId)
-	if err != nil {
-		allErrors = multierror.Append(allErrors, err)
-	} else if len(tcpIds) > 0 {
-		if err := t.tcp.DeleteTcpForwarding(ctx, v1.DeleteTcpForwardingRequest{Ids: tcpIds, HostId: limit.HostId,Uid: limit.Uid}); err != nil {
-			allErrors = multierror.Append(allErrors, err)
-		}
-	}
-
-
-
-	udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, limit.HostId)
-	if err != nil {
-		allErrors = multierror.Append(allErrors, err)
-	} else if len(udpIds) > 0 {
-		if err := t.udp.DeleteUdpForwarding(ctx, v1.DeleteUdpForwardingRequest{Ids: udpIds, HostId: limit.HostId,Uid: limit.Uid}); err != nil {
-			allErrors = multierror.Append(allErrors, err)
-		}
-	}
-
-
-	webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, limit.HostId)
-	if err != nil {
-		allErrors = multierror.Append(allErrors, err)
-	} else if len(webIds) > 0 {
-		if err := t.web.DeleteWebForwarding(ctx, v1.DeleteWebForwardingRequest{Ids: webIds, HostId: limit.HostId,Uid: limit.Uid}); err != nil {
-			allErrors = multierror.Append(allErrors, err)
-		}
-	}
-
-
-	// 重置防护
-	err = t.zzyBgp.SetDefense(ctx, hostId, 10)
-	if err != nil {
-		return err
-	}
-
-	// 清除小防火墙带宽限制
-	if err := t.buildAoDun.Bandwidth(ctx, hostId, "del"); err != nil {
-		allErrors = multierror.Append(allErrors, err)
-	}
-
-
-
-
-
-	// 只有在上述所有步骤都没有出错的情况下,才执行最终的数据库更新和Redis标记
-	if allErrors.ErrorOrNil() == nil {
-		err := t.gatewayIpRep.CleanIPByHostId(ctx, []int64{hostId})
-		if err != nil {
-			allErrors = multierror.Append(allErrors, err)
-		}
-
-		// [CORRECTION] 幂等性保障:将此hostId标记为“已清理”,即添加到 `ExpiringSoonPlansList`
-		if err := t.expiredRep.AddPlans(ctx, repository.ExpiringSoonPlansList, hostId); err != nil {
-			allErrors = multierror.Append(allErrors, fmt.Errorf("将hostId %d标记为已清理失败: %w", hostId, err))
-		}
-	}
-
-
-	return allErrors.ErrorOrNil()
+	return t.wafOps.CleanupPlan(ctx, limit)
 }
 
 // 5. RecoverStalePlan 恢复超过7天后才续费的套餐