Sfoglia il codice sorgente

feat(task): 实现 WAF套餐到期处理功能

- 新增 WafTask 接口和实现,包含同步到期时间、停止套餐和恢复停止的套餐等功能
- 在任务服务器中集成 WafTask,并添加相应的定时任务- 新增 ExpiredRepository 接口和实现,用于管理已到期的套餐信息
- 更新相关的依赖注入和配置
fusu 3 settimane fa
parent
commit
33bcadcf5c

+ 1 - 0
cmd/server/wire/wire.go

@@ -49,6 +49,7 @@ var repositorySet = wire.NewSet(
 	repository.NewAllowAndDenyIpRepository,
 	repository.NewProxyRepository,
 	repository.NewCcRepository,
+	repository.NewExpiredRepository,
 
 )
 

+ 1 - 1
cmd/server/wire/wire_gen.go

@@ -114,7 +114,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 
 // wire.go:
 
-var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewCasbinEnforcer, repository.NewMongoClient, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewAdminRepository, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldPublicIpRepository, repository.NewWebForwardingRepository, repository.NewTcpforwardingRepository, repository.NewUdpForWardingRepository, repository.NewGameShieldUserIpRepository, repository.NewWebLimitRepository, repository.NewTcpLimitRepository, repository.NewUdpLimitRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldSdkIpRepository, repository.NewHostRepository, repository.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository, repository.NewCdnRepository, repository.NewAllowAndDenyIpRepository, repository.NewProxyRepository, repository.NewCcRepository)
+var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewCasbinEnforcer, repository.NewMongoClient, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewAdminRepository, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldPublicIpRepository, repository.NewWebForwardingRepository, repository.NewTcpforwardingRepository, repository.NewUdpForWardingRepository, repository.NewGameShieldUserIpRepository, repository.NewWebLimitRepository, repository.NewTcpLimitRepository, repository.NewUdpLimitRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldSdkIpRepository, repository.NewHostRepository, repository.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository, repository.NewCdnRepository, repository.NewAllowAndDenyIpRepository, repository.NewProxyRepository, repository.NewCcRepository, repository.NewExpiredRepository)
 
 var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, service.NewUserService, service.NewAdminService, service.NewGameShieldService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewCrawlerService, service.NewWebForwardingService, service.NewTcpforwardingService, service.NewUdpForWardingService, service.NewGameShieldUserIpService, service.NewWebLimitService, service.NewTcpLimitService, service.NewUdpLimitService, service.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewHostService, service.NewGlobalLimitService, service.NewGatewayGroupService, service.NewWafFormatterService, service.NewGateWayGroupIpService, service.NewRequestService, service.NewCdnService, service.NewAllowAndDenyIpService, service.NewProxyService, service.NewSslCertService, service.NewWebsocketService, service.NewCcService)
 

+ 2 - 0
cmd/task/wire/wire.go

@@ -43,12 +43,14 @@ var repositorySet = wire.NewSet(
 	repository.NewGatewayGroupRepository,
 	repository.NewGateWayGroupIpRepository,
 	repository.NewCdnRepository,
+	repository.NewExpiredRepository,
 )
 
 var taskSet = wire.NewSet(
 	task.NewTask,
 	task.NewUserTask,
 	task.NewGameShieldTask,
+	task.NewWafTask,
 )
 
 var jobSet = wire.NewSet(

+ 12 - 10
cmd/task/wire/wire_gen.go

@@ -54,19 +54,21 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	gameShieldService := service.NewGameShieldService(serviceService, gameShieldRepository, crawlerService, gameShieldPublicIpService, duedateService, formatterService, parserService, requiredService, viperViper, gameShieldSdkIpService)
 	gameShieldBackendService := service.NewGameShieldBackendService(serviceService, gameShieldBackendRepository, gameShieldRepository, crawlerService, gameShieldPublicIpService, duedateService, formatterService, parserService, requiredService, viperViper, gameShieldService, hostService)
 	gameShieldTask := task.NewGameShieldTask(taskTask, gameShieldRepository, gameShieldBackendService)
-	taskServer := server.NewTaskServer(logger, userTask, gameShieldTask)
-	jobJob := job.NewJob(transaction, logger, sidSid, rabbitMQ)
-	userJob := job.NewUserJob(jobJob, userRepository)
-	aoDunService := service.NewAoDunService(serviceService, viperViper)
-	globalLimitRepository := repository.NewGlobalLimitRepository(repositoryRepository)
+	webForwardingRepository := repository.NewWebForwardingRepository(repositoryRepository)
 	tcpforwardingRepository := repository.NewTcpforwardingRepository(repositoryRepository)
 	udpForWardingRepository := repository.NewUdpForWardingRepository(repositoryRepository)
-	webForwardingRepository := repository.NewWebForwardingRepository(repositoryRepository)
-	gatewayGroupRepository := repository.NewGatewayGroupRepository(repositoryRepository)
-	gateWayGroupIpRepository := repository.NewGateWayGroupIpRepository(repositoryRepository)
 	requestService := service.NewRequestService(serviceService)
 	cdnRepository := repository.NewCdnRepository(repositoryRepository)
 	cdnService := service.NewCdnService(serviceService, viperViper, requestService, cdnRepository)
+	globalLimitRepository := repository.NewGlobalLimitRepository(repositoryRepository)
+	expiredRepository := repository.NewExpiredRepository(repositoryRepository)
+	wafTask := task.NewWafTask(webForwardingRepository, tcpforwardingRepository, udpForWardingRepository, cdnService, hostRepository, globalLimitRepository, expiredRepository, taskTask)
+	taskServer := server.NewTaskServer(logger, userTask, gameShieldTask, wafTask)
+	jobJob := job.NewJob(transaction, logger, sidSid, rabbitMQ)
+	userJob := job.NewUserJob(jobJob, userRepository)
+	aoDunService := service.NewAoDunService(serviceService, viperViper)
+	gatewayGroupRepository := repository.NewGatewayGroupRepository(repositoryRepository)
+	gateWayGroupIpRepository := repository.NewGateWayGroupIpRepository(repositoryRepository)
 	wafFormatterService := service.NewWafFormatterService(serviceService, globalLimitRepository, hostRepository, requiredService, parserService, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, rabbitMQ, hostService, gatewayGroupRepository, gateWayGroupIpRepository, cdnService)
 	whitelistJob := job.NewWhitelistJob(jobJob, aoDunService, wafFormatterService)
 	jobServer := server.NewJobServer(logger, userJob, whitelistJob)
@@ -78,9 +80,9 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 
 // wire.go:
 
-var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewMongoClient, repository.NewCasbinEnforcer, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldPublicIpRepository, repository.NewHostRepository, repository.NewGameShieldUserIpRepository, repository.NewGameShieldSdkIpRepository, repository.NewWebForwardingRepository, repository.NewTcpforwardingRepository, repository.NewUdpForWardingRepository, repository.NewWebLimitRepository, repository.NewTcpLimitRepository, repository.NewUdpLimitRepository, repository.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository, repository.NewCdnRepository)
+var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewMongoClient, repository.NewCasbinEnforcer, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldPublicIpRepository, repository.NewHostRepository, repository.NewGameShieldUserIpRepository, repository.NewGameShieldSdkIpRepository, repository.NewWebForwardingRepository, repository.NewTcpforwardingRepository, repository.NewUdpForWardingRepository, repository.NewWebLimitRepository, repository.NewTcpLimitRepository, repository.NewUdpLimitRepository, repository.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository, repository.NewCdnRepository, repository.NewExpiredRepository)
 
-var taskSet = wire.NewSet(task.NewTask, task.NewUserTask, task.NewGameShieldTask)
+var taskSet = wire.NewSet(task.NewTask, task.NewUserTask, task.NewGameShieldTask, task.NewWafTask)
 
 var jobSet = wire.NewSet(job.NewJob, job.NewUserJob, job.NewWhitelistJob)
 

+ 168 - 0
internal/repository/expired.go

@@ -0,0 +1,168 @@
+package repository
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"strconv"
+	"strings"
+	"time"
+)
+
+// ExpiredInfo 包含了关闭套餐的详细信息
+type ExpiredInfo struct {
+	HostID int64     `json:"host_id"` // hostid, a.k.a. planid
+	Expiry time.Time `json:"expiry"` // 过期时间
+}
+
+// ExpiredRepository 定义了与过期套餐相关的操作接口
+type ExpiredRepository interface {
+	// AddClosePlans 批量添加要关闭的套餐信息,并设置过期时间
+	AddClosePlans(ctx context.Context, infos ...ExpiredInfo) error
+	// GetClosePlanInfo 获取单个已关闭套餐的详细信息
+	GetClosePlanInfo(ctx context.Context, planId int64) (*ExpiredInfo, error)
+	// RemoveClosePlanIds 批量移除已关闭的套餐
+	RemoveClosePlanIds(ctx context.Context, planIds ...int64) error
+	// GetAllClosePlanIds 获取所有当前套餐ID
+	GetAllClosePlanIds(ctx context.Context) ([]int64, error)
+	// IsPlanClosed 检查一个套餐是否被标记为关闭
+	IsPlanClosed(ctx context.Context, planId int64) (bool, error)
+}
+
+func NewExpiredRepository(
+	repository *Repository,
+) ExpiredRepository {
+	return &expiredRepository{
+		Repository: repository,
+	}
+}
+
+type expiredRepository struct {
+	*Repository
+}
+
+// Key的前缀,用于标识所有已关闭套餐的Key
+const closePlanIdKeyPrefix = "waf:closed_plan:"
+
+// 辅助函数:根据 planId 生成对应的 Redis Key
+func (r *expiredRepository) getPlanKey(planId int64) string {
+	return fmt.Sprintf("%s%d", closePlanIdKeyPrefix, planId)
+}
+
+// AddClosePlans 为每个套餐创建一个独立的 key,并将详细信息作为 value 存储
+func (r *expiredRepository) AddClosePlans(ctx context.Context, infos ...ExpiredInfo) error {
+	if len(infos) == 0 {
+		return nil
+	}
+
+	pipe := r.rdb.Pipeline()
+	for _, info := range infos {
+		key := r.getPlanKey(info.HostID)
+		
+		// 将结构体序列化为 JSON 字符串
+		value, err := json.Marshal(info)
+		if err != nil {
+			// 在实际应用中,这里应该记录错误日志
+			// log.Printf("Error marshalling ExpiredInfo for plan %d: %v", info.HostID, err)
+			continue // 跳过这个错误的数据
+		}
+
+		// 设置一个固定的7天过期时间,用于记录关闭状态
+		// 这样可以确保在7天内,该套餐的状态是“已关闭”,可以被恢复
+		// 7天后,该key会自动过期
+		const sevenDays = 7 * 24 * time.Hour
+		pipe.Set(ctx, key, value, sevenDays)
+	}
+
+	_, err := pipe.Exec(ctx)
+	return err
+}
+
+// GetClosePlanInfo 获取并解析单个套餐的信息
+func (r *expiredRepository) GetClosePlanInfo(ctx context.Context, planId int64) (*ExpiredInfo, error) {
+	key := r.getPlanKey(planId)
+	value, err := r.rdb.Get(ctx, key).Result()
+	if err != nil {
+		return nil, err // 包括 redis.Nil 的情况
+	}
+
+	var info ExpiredInfo
+	if err := json.Unmarshal([]byte(value), &info); err != nil {
+		return nil, fmt.Errorf("failed to unmarshal plan info for key %s: %w", key, err)
+	}
+
+	return &info, nil
+}
+
+// RemoveClosePlanIds 删除每个 planId 对应的 key
+func (r *expiredRepository) RemoveClosePlanIds(ctx context.Context, planIds ...int64) error {
+	if len(planIds) == 0 {
+		return nil
+	}
+
+	// 生成所有需要删除的 key
+	keys := make([]string, len(planIds))
+	for i, id := range planIds {
+		keys[i] = r.getPlanKey(id)
+	}
+
+	// DEL 命令可以一次性删除多个 key
+	return r.rdb.Del(ctx, keys...).Err()
+}
+
+// GetAllClosePlanIds 使用 SCAN 遍历所有匹配的 key 并解析出 planId
+func (r *expiredRepository) GetAllClosePlanIds(ctx context.Context) ([]int64, error) {
+	var cursor uint64
+	var allKeys []string
+
+	// 使用 SCAN 命令来安全地遍历大量的 key,避免阻塞 Redis
+	// KEYS 命令会导致性能问题,在生产环境中严禁使用
+	scanPattern := closePlanIdKeyPrefix + "*"
+
+	for {
+		var keys []string
+		var err error
+		keys, cursor, err = r.rdb.Scan(ctx, cursor, scanPattern, 100).Result() // 每次扫描100个
+		if err != nil {
+			return nil, err
+		}
+
+		allKeys = append(allKeys, keys...)
+
+		// 如果 cursor 回到 0,表示遍历完成
+		if cursor == 0 {
+			break
+		}
+	}
+
+	// 从 key 的字符串中解析出 planId
+	planIds := make([]int64, 0, len(allKeys))
+	for _, key := range allKeys {
+		// key 的格式是 "waf:closed_plan:12345"
+		// 我们需要移除前缀 "waf:closed_plan:" 来获取ID部分
+		idStr := strings.TrimPrefix(key, closePlanIdKeyPrefix)
+		id, err := strconv.ParseInt(idStr, 10, 64)
+		if err != nil {
+			// 如果有无法解析的key,最好记录日志并跳过
+			// log.Printf("Warning: could not parse planId from key '%s': %v", key, err)
+			continue
+		}
+		planIds = append(planIds, id)
+	}
+
+	return planIds, nil
+}
+
+// IsPlanClosed 检查 planId 对应的 key 是否存在
+func (r *expiredRepository) IsPlanClosed(ctx context.Context, planId int64) (bool, error) {
+	key := r.getPlanKey(planId)
+
+	// EXISTS 命令是 O(1) 的高效操作,返回存在的key的数量
+	count, err := r.rdb.Exists(ctx, key).Result()
+	if err != nil {
+		return false, err
+	}
+
+	// 如果 count > 0,说明key存在,即套餐已关闭
+	return count > 0, nil
+}

+ 1 - 1
internal/repository/globallimit.go

@@ -204,4 +204,4 @@ func (r *globalLimitRepository) GetGlobalLimitAlmostExpired(ctx context.Context,
 	}
 	return res, nil
 
-}
+}

+ 52 - 9
internal/server/task.go

@@ -14,17 +14,20 @@ type TaskServer struct {
 	scheduler      *gocron.Scheduler
 	userTask       task.UserTask
 	gameShieldTask task.GameShieldTask
+	wafTask            task.WafTask
 }
 
 func NewTaskServer(
 	log *log.Logger,
 	userTask task.UserTask,
 	gameShieldTask task.GameShieldTask,
+	wafTask task.WafTask,
 ) *TaskServer {
 	return &TaskServer{
 		log:            log,
 		userTask:       userTask,
 		gameShieldTask: gameShieldTask,
+		wafTask:        wafTask,
 	}
 }
 func (t *TaskServer) Start(ctx context.Context) error {
@@ -49,26 +52,66 @@ func (t *TaskServer) Start(ctx context.Context) error {
 	//}
 
 	// 添加游戏盾检查任务 - 每1小时执行
-	_, err := t.scheduler.Cron("0 */1 * * *").Do(func() {
-		err := t.gameShieldTask.CheckGameShield(ctx)
+	//_, err := t.scheduler.Cron("0 */1 * * *").Do(func() {
+	//	err := t.gameShieldTask.CheckGameShield(ctx)
+	//	if err != nil {
+	//		t.log.Error("CheckGameShield error", zap.Error(err))
+	//	}
+	//})
+	//if err != nil {
+	//	t.log.Error("Register CheckGameShield task error", zap.Error(err))
+	//}
+	//
+	//// 添加游戏盾数据同步任务 - 每天凌晨3点执行
+	//_, err = t.scheduler.Cron("0 3 * * *").Do(func() {
+	//	err := t.gameShieldTask.SyncAllExpireTimeFromHost(ctx)
+	//	if err != nil {
+	//		t.log.Error("SyncAllExpireTimeFromHost error", zap.Error(err))
+	//	}
+	//})
+	//if err != nil {
+	//	t.log.Error("Register SyncAllExpireTimeFromHost task error", zap.Error(err))
+	//}
+
+
+
+
+	_, err := t.scheduler.Cron("* * * * *").Do(func() {
+		err := t.wafTask.SynchronizationTime(ctx)
 		if err != nil {
-			t.log.Error("CheckGameShield error", zap.Error(err))
+			t.log.Error("同步到期时间失败", zap.Error(err))
 		}
 	})
 	if err != nil {
-		t.log.Error("Register CheckGameShield task error", zap.Error(err))
+		t.log.Error("同步到期时间注册任务失败", zap.Error(err))
 	}
 
-	// 添加游戏盾数据同步任务 - 每天凌晨3点执行
-	_, err = t.scheduler.Cron("0 3 * * *").Do(func() {
-		err := t.gameShieldTask.SyncAllExpireTimeFromHost(ctx)
+	_, err = t.scheduler.Cron("* * * * *").Do(func() {
+		err := t.wafTask.StopPlan(ctx)
 		if err != nil {
-			t.log.Error("SyncAllExpireTimeFromHost error", zap.Error(err))
+			t.log.Error("停止套餐失败", zap.Error(err))
 		}
 	})
 	if err != nil {
-		t.log.Error("Register SyncAllExpireTimeFromHost task error", zap.Error(err))
+		t.log.Error("停止套餐注册任务失败", zap.Error(err))
 	}
+
+
+	_, err = t.scheduler.Cron("* * * * *").Do(func() {
+		err := t.wafTask.RecoverStopPlan(ctx)
+		if err != nil {
+			t.log.Error("续费失败", zap.Error(err))
+		}
+	})
+	if err != nil {
+		t.log.Error("续费注册任务失败", zap.Error(err))
+	}
+
+
+
+
+
+
 	// 使用非阻塞方式启动调度器,添加一条日志表明任务服务已启动
 	t.scheduler.StartAsync()
 	t.log.Info("task server starting...")

+ 57 - 18
internal/task/waf.go

@@ -16,6 +16,8 @@ import (
 type WafTask interface {
 	//获取到期时间小于3天的同步时间
 	SynchronizationTime(ctx context.Context) error
+	StopPlan(ctx context.Context) error
+	RecoverStopPlan(ctx context.Context) error
 }
 
 func NewWafTask (
@@ -25,6 +27,7 @@ func NewWafTask (
 	cdn service.CdnService,
 	hostRep repository.HostRepository,
 	globalLimitRep repository.GlobalLimitRepository,
+	expiredRep repository.ExpiredRepository,
 	task *Task,
 	) WafTask{
 	return &wafTask{
@@ -35,6 +38,7 @@ func NewWafTask (
 		cdn: cdn,
 		hostRep: hostRep,
 		globalLimitRep: globalLimitRep,
+		expiredRep: expiredRep,
 	}
 }
 type wafTask struct {
@@ -45,6 +49,7 @@ type wafTask struct {
 	cdn service.CdnService
 	hostRep repository.HostRepository
 	globalLimitRep repository.GlobalLimitRepository
+	expiredRep repository.ExpiredRepository
 }
 
 
@@ -291,7 +296,7 @@ func (t *wafTask) findMismatchedExpirations(ctx context.Context, wafLimits []mod
 }
 
 
-//获取到期时间小于1天的同步时间
+//获取同步到期时间小于1天的套餐
 
 func (t *wafTask) SynchronizationTime(ctx context.Context) error {
 	// 1. 获取 WAF 全局配置中即将到期(小于3天)的数据
@@ -317,7 +322,6 @@ func (t *wafTask) SynchronizationTime(ctx context.Context) error {
 
 
 
-//获取到期的进行关闭套餐操作
 // 获取到期的进行关闭套餐操作
 func (t *wafTask) StopPlan(ctx context.Context) error {
 	// 1. 获取 WAF 全局配置中已经到期的数据
@@ -343,27 +347,56 @@ func (t *wafTask) StopPlan(ctx context.Context) error {
 		}
 	}
 
-	// 3. 关闭所有已经到期的套餐
-	t.logger.Info("开始关闭已到期的WAF服务", zap.Int("数量", len(wafLimits)))
+	// 3. 筛选出尚未被关闭的套餐
+	var plansToClose []model.GlobalLimit
+	for _, limit := range wafLimits {
+		isClosed, err := t.expiredRep.IsPlanClosed(ctx, int64(limit.HostId))
+		if err != nil {
+			t.logger.Error("检查Redis中套餐关闭状态失败", zap.Int("hostId", limit.HostId), zap.Error(err))
+			continue // 跳过这个,处理下一个
+		}
+		if !isClosed {
+			plansToClose = append(plansToClose, limit)
+		}
+	}
+
+	if len(plansToClose) == 0 {
+		t.logger.Info("没有新的到期套餐需要关闭")
+		return nil
+	}
+
+	// 4. 对筛选出的套餐执行关闭操作
+	t.logger.Info("开始关闭新的到期WAF服务", zap.Int("数量", len(plansToClose)))
 	var allErrors *multierror.Error
 
 	var webIds []int
-	for _, limit := range wafLimits {
+	for _, limit := range plansToClose {
 		webIds = append(webIds, limit.HostId)
 	}
 
-
 	if err := t.BanServer(ctx, webIds, false); err != nil {
-		allErrors = multierror.Append(allErrors, fmt.Errorf("关闭hostId %d 的服务失败: %w", webIds, err))
-	}
-
+		allErrors = multierror.Append(allErrors, fmt.Errorf("关闭hostId %v 的服务失败: %w", webIds, err))
+	} else {
+		// 服务关闭成功后,将这些套餐信息添加到 Redis
+		var expiredInfos []repository.ExpiredInfo
+		for _, limit := range plansToClose {
+			expiredInfos = append(expiredInfos, repository.ExpiredInfo{
+				HostID: int64(limit.HostId),
+				Expiry: time.Unix(limit.ExpiredAt, 0),
+			})
+		}
 
+		if len(expiredInfos) > 0 {
+			if err := t.expiredRep.AddClosePlans(ctx, expiredInfos...); err != nil {
+				allErrors = multierror.Append(allErrors, fmt.Errorf("添加已关闭套餐信息到Redis失败: %w", err))
+			}
+		}
+	}
 
 	return allErrors.ErrorOrNil()
 }
 //对于到期7天内续费的产品需要进行恢复操作
 
-// RecoverStopPlan 对于到期7天内续费的产品进行恢复操作
 func (t *wafTask) RecoverStopPlan(ctx context.Context) error {
 	// 1. 获取所有已过期(expired_at < now)但状态仍为 true 的 WAF 记录
 	// StopPlan 任务会禁用这些服务,但不会改变它们的 state
@@ -392,22 +425,28 @@ func (t *wafTask) RecoverStopPlan(ctx context.Context) error {
 	t.logger.Info("发现已续费、需要恢复的WAF服务", zap.Int("数量", len(renewalRequests)))
 	var allErrors *multierror.Error
 
-
 	var webIds []int
 	for _, req := range renewalRequests {
 		webIds = append(webIds, req.HostId)
 	}
 
-
 	if err := t.BanServer(ctx, webIds, true); err != nil {
-		allErrors = multierror.Append(allErrors, fmt.Errorf("恢复hostId %d: 启用服务失败: %w", webIds, err))
+		allErrors = multierror.Append(allErrors, fmt.Errorf("恢复hostId %v: 启用服务失败: %w", webIds, err))
+	} else {
+		// 服务恢复成功后,从 Redis 中移除这些套餐的关闭记录
+		planIds := make([]int64, len(webIds))
+		for i, id := range webIds {
+			planIds[i] = int64(id)
+		}
+		if err := t.expiredRep.RemoveClosePlanIds(ctx, planIds...); err != nil {
+			allErrors = multierror.Append(allErrors, fmt.Errorf("从Redis移除已恢复的套餐失败: %w", err))
+		}
 	}
 
-
-
-	for _, req := range renewalRequests {
-		if err := t.EditExpired(ctx, []RenewalRequest{req}); err != nil {
-			allErrors = multierror.Append(allErrors, fmt.Errorf("恢复hostId %d: 更新数据库状态失败: %w", req.HostId, err))
+	if len(renewalRequests) > 0 {
+		// 统一执行续费和数据库更新操作
+		if err := t.EditExpired(ctx, renewalRequests); err != nil {
+			allErrors = multierror.Append(allErrors, fmt.Errorf("批量更新已恢复服务的数据库状态失败: %w", err))
 		}
 	}