Переглянути джерело

refactor(task): 重构 WafTask 以支持 CDN 相关操作

- 在 WafTask 中添加 TCP、UDP 和 Web转发服务的依赖
- 实现 GetCdnWebId 方法以获取 CDN Web ID
- 修改 BanServer 方法以支持批量操作
- 优化清理流程,添加过期标记移除逻辑
- 删除计划清理时增加对 CDN 相关资源的处理
fusu 3 тижнів тому
батько
коміт
6b7092de47
3 змінених файлів з 107 додано та 12 видалено
  1. 8 1
      cmd/task/wire/wire.go
  2. 13 6
      cmd/task/wire/wire_gen.go
  3. 86 5
      internal/task/waf.go

+ 8 - 1
cmd/task/wire/wire.go

@@ -44,6 +44,8 @@ var repositorySet = wire.NewSet(
 	repository.NewGateWayGroupIpRepository,
 	repository.NewCdnRepository,
 	repository.NewExpiredRepository,
+	repository.NewProxyRepository,
+
 )
 
 var taskSet = wire.NewSet(
@@ -80,7 +82,12 @@ var serviceSet = wire.NewSet(
 	service.NewWafFormatterService,
 	service.NewCdnService,
 	service.NewRequestService,
-
+	service.NewTcpforwardingService,
+	service.NewUdpForWardingService,
+	service.NewWebForwardingService,
+	service.NewProxyService,
+	service.NewSslCertService,
+	service.NewWebsocketService,
 )
 
 // build App

+ 13 - 6
cmd/task/wire/wire_gen.go

@@ -63,13 +63,20 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	globalLimitRepository := repository.NewGlobalLimitRepository(repositoryRepository)
 	expiredRepository := repository.NewExpiredRepository(repositoryRepository)
 	gateWayGroupIpRepository := repository.NewGateWayGroupIpRepository(repositoryRepository)
-	wafTask := task.NewWafTask(webForwardingRepository, tcpforwardingRepository, udpForWardingRepository, cdnService, hostRepository, globalLimitRepository, expiredRepository, taskTask, gateWayGroupIpRepository)
+	gatewayGroupRepository := repository.NewGatewayGroupRepository(repositoryRepository)
+	wafFormatterService := service.NewWafFormatterService(serviceService, globalLimitRepository, hostRepository, requiredService, parserService, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, rabbitMQ, hostService, gatewayGroupRepository, gateWayGroupIpRepository, cdnService)
+	proxyRepository := repository.NewProxyRepository(repositoryRepository)
+	proxyService := service.NewProxyService(serviceService, proxyRepository, cdnService)
+	tcpforwardingService := service.NewTcpforwardingService(serviceService, tcpforwardingRepository, parserService, requiredService, crawlerService, globalLimitRepository, hostRepository, wafFormatterService, cdnService, proxyService)
+	udpForWardingService := service.NewUdpForWardingService(serviceService, udpForWardingRepository, requiredService, parserService, crawlerService, globalLimitRepository, hostRepository, wafFormatterService, cdnService, proxyService)
+	aoDunService := service.NewAoDunService(serviceService, viperViper)
+	sslCertService := service.NewSslCertService(serviceService, webForwardingRepository, cdnService)
+	websocketService := service.NewWebsocketService(serviceService, cdnService, webForwardingRepository)
+	webForwardingService := service.NewWebForwardingService(serviceService, requiredService, webForwardingRepository, crawlerService, parserService, wafFormatterService, aoDunService, rabbitMQ, gateWayGroupIpRepository, gatewayGroupRepository, globalLimitRepository, cdnService, proxyService, sslCertService, websocketService)
+	wafTask := task.NewWafTask(webForwardingRepository, tcpforwardingRepository, udpForWardingRepository, cdnService, hostRepository, globalLimitRepository, expiredRepository, taskTask, gateWayGroupIpRepository, tcpforwardingService, udpForWardingService, webForwardingService)
 	taskServer := server.NewTaskServer(logger, userTask, gameShieldTask, wafTask)
 	jobJob := job.NewJob(transaction, logger, sidSid, rabbitMQ)
 	userJob := job.NewUserJob(jobJob, userRepository)
-	aoDunService := service.NewAoDunService(serviceService, viperViper)
-	gatewayGroupRepository := repository.NewGatewayGroupRepository(repositoryRepository)
-	wafFormatterService := service.NewWafFormatterService(serviceService, globalLimitRepository, hostRepository, requiredService, parserService, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, rabbitMQ, hostService, gatewayGroupRepository, gateWayGroupIpRepository, cdnService)
 	whitelistJob := job.NewWhitelistJob(jobJob, aoDunService, wafFormatterService)
 	jobServer := server.NewJobServer(logger, userJob, whitelistJob)
 	appApp := newApp(taskServer, jobServer)
@@ -80,7 +87,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 
 // wire.go:
 
-var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewMongoClient, repository.NewCasbinEnforcer, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldPublicIpRepository, repository.NewHostRepository, repository.NewGameShieldUserIpRepository, repository.NewGameShieldSdkIpRepository, repository.NewWebForwardingRepository, repository.NewTcpforwardingRepository, repository.NewUdpForWardingRepository, repository.NewWebLimitRepository, repository.NewTcpLimitRepository, repository.NewUdpLimitRepository, repository.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository, repository.NewCdnRepository, repository.NewExpiredRepository)
+var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewMongoClient, repository.NewCasbinEnforcer, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldPublicIpRepository, repository.NewHostRepository, repository.NewGameShieldUserIpRepository, repository.NewGameShieldSdkIpRepository, repository.NewWebForwardingRepository, repository.NewTcpforwardingRepository, repository.NewUdpForWardingRepository, repository.NewWebLimitRepository, repository.NewTcpLimitRepository, repository.NewUdpLimitRepository, repository.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository, repository.NewCdnRepository, repository.NewExpiredRepository, repository.NewProxyRepository)
 
 var taskSet = wire.NewSet(task.NewTask, task.NewUserTask, task.NewGameShieldTask, task.NewWafTask)
 
@@ -88,7 +95,7 @@ var jobSet = wire.NewSet(job.NewJob, job.NewUserJob, job.NewWhitelistJob)
 
 var serverSet = wire.NewSet(server.NewTaskServer, server.NewJobServer)
 
-var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, service.NewGameShieldService, service.NewCrawlerService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewHostService, service.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewGameShieldUserIpService, service.NewWafFormatterService, service.NewCdnService, service.NewRequestService)
+var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, service.NewGameShieldService, service.NewCrawlerService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewHostService, service.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewGameShieldUserIpService, service.NewWafFormatterService, service.NewCdnService, service.NewRequestService, service.NewTcpforwardingService, service.NewUdpForWardingService, service.NewWebForwardingService, service.NewProxyService, service.NewSslCertService, service.NewWebsocketService)
 
 // build App
 func newApp(task2 *server.TaskServer,

+ 86 - 5
internal/task/waf.go

@@ -37,6 +37,9 @@ func NewWafTask(
 	expiredRep repository.ExpiredRepository,
 	task *Task,
 	gatewayGroupIpRep repository.GateWayGroupIpRepository,
+	tcp service.TcpforwardingService,
+	udp service.UdpForWardingService,
+	web service.WebForwardingService,
 ) WafTask {
 	return &wafTask{
 		Task:             task,
@@ -48,6 +51,9 @@ func NewWafTask(
 		globalLimitRep:   globalLimitRep,
 		expiredRep:       expiredRep,
 		gatewayGroupIpRep: gatewayGroupIpRep,
+		tcp:              tcp,
+		udp:              udp,
+		web:              web,
 	}
 }
 
@@ -61,6 +67,9 @@ type wafTask struct {
 	globalLimitRep   repository.GlobalLimitRepository
 	expiredRep       repository.ExpiredRepository
 	gatewayGroupIpRep repository.GateWayGroupIpRepository
+	tcp              service.TcpforwardingService
+	udp              service.UdpForWardingService
+	web              service.WebForwardingService
 }
 
 const (
@@ -81,6 +90,27 @@ type RenewalRequest struct {
 // =================== 原始辅助函数 (Helpers) =====================
 // =================================================================
 
+// 获取cdn web id
+func (t wafTask) GetCdnWebId(ctx context.Context,hostId int) ([]int, error) {
+	tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, hostId)
+	if err != nil {
+		return nil, err
+	}
+	udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, hostId)
+	if err != nil {
+		return nil, err
+	}
+	webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, hostId)
+	if err != nil {
+		return nil, err
+	}
+	var ids []int
+	ids = append(ids, tcpIds...)
+	ids = append(ids, udpIds...)
+	ids = append(ids, webIds...)
+	return ids, nil
+}
+
 // BanServer 启用/禁用 网站 (并发执行)
 func (t wafTask) BanServer(ctx context.Context, ids []int, isBan bool) error {
 	if len(ids) == 0 { return nil }
@@ -234,11 +264,21 @@ func (t *wafTask) executePlanRecovery(ctx context.Context, renewalRequests []Ren
 		hostIds = append(hostIds, req.HostId)
 	}
 
-	if err := t.BanServer(ctx, hostIds, true); err != nil {
-		return fmt.Errorf("执行[%s]-启用服务失败: %w", taskName, err)
-	}
 
 	var allErrors *multierror.Error
+
+	for _, v := range renewalRequests {
+		webIds, err := t.GetCdnWebId(ctx, v.HostId)
+		if err != nil {
+			allErrors = multierror.Append(allErrors, fmt.Errorf("执行[%s]-获取webId失败: %w", taskName, err))
+		}
+		if err := t.BanServer(ctx, webIds, true); err != nil {
+			allErrors = multierror.Append(allErrors, fmt.Errorf("执行[%s]-封禁webId失败: %w", taskName, err))
+		}
+	}
+
+
+
 	if err := t.EditExpired(ctx, renewalRequests); err != nil {
 		allErrors = multierror.Append(allErrors, fmt.Errorf("执行[%s]-同步续费信息失败: %w", taskName, err))
 	}
@@ -308,9 +348,16 @@ func (t *wafTask) StopPlan(ctx context.Context) error {
 	for _, limit := range plansToClose {
 		hostIds = append(hostIds, limit.HostId)
 	}
-	if err := t.BanServer(ctx, hostIds, false); err != nil {
-		return fmt.Errorf("执行[停止]-禁用服务失败: %w", err)
+
+	for _, hostId := range hostIds {
+		webIds, err := t.GetCdnWebId(ctx, hostId)
+		if err != nil { return fmt.Errorf("执行[停止]-获取cdn_web_id失败: %w", err) }
+		if err := t.BanServer(ctx, webIds, false); err != nil {
+			return fmt.Errorf("执行[停止]-禁用服务失败: %w", err)
+		}
 	}
+
+
 	closedPlanIds := make([]int64, len(hostIds))
 	for i, id := range hostIds { closedPlanIds[i] = int64(id) }
 	if err := t.expiredRep.AddPlans(ctx, repository.ClosedPlansList, closedPlanIds...); err != nil {
@@ -357,6 +404,9 @@ func (t *wafTask) CleanUpStaleRecords(ctx context.Context) error {
 	if err := t.expiredRep.RemovePlans(ctx, repository.ClosedPlansList, planIdsToClean...); err != nil {
 		return fmt.Errorf("执行[清理]-从Redis移除关闭标记失败: %w", err)
 	}
+	if err := t.expiredRep.AddPlans(ctx, repository.ExpiringSoonPlansList, planIdsToClean...); err != nil {
+		return fmt.Errorf("执行[清理]-从Redis移除过期标记失败: %w", err)
+	}
 	// 在这里可以添加从数据库删除或调用CDN API彻底删除的逻辑
 	for _, limit := range plansToClean {
 		err = t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
@@ -370,6 +420,37 @@ func (t *wafTask) CleanUpStaleRecords(ctx context.Context) error {
 
 
 
+		tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, limit.HostId)
+		if err != nil {
+			return err
+		}
+		udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, limit.HostId)
+		if err != nil {
+			return err
+		}
+		webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, limit.HostId)
+		if err != nil {
+			return err
+		}
+
+		err = t.tcp.DeleteTcpForwarding(ctx, v1.DeleteTcpForwardingRequest{
+			Ids: tcpIds,
+			Uid: 0,
+			HostId: limit.HostId,
+		})
+		if err != nil {
+			return err
+		}
+		err = t.udp.DeleteUdpForwarding(ctx, udpIds)
+		if err != nil {
+			return err
+		}
+		err = t.web.DeleteWebForwarding(ctx, webIds)
+		if err != nil {
+			return err
+		}
+
+
 	}
 
 	return nil