Browse Source

refactor(whitelist): 重构白名单任务处理逻辑

-优化了域名和 IP 白名单任务的处理流程
- 新增通用的 consume 函数封装消费者逻辑
- 引入 taskHandler 类型统一任务处理函数签名
- 改进了错误处理和日志记录机制
- 调整了 AoDunService 接口,简化了 IP 白名单相关方法
fusu 1 month ago
parent
commit
d3e2d6aa93

+ 11 - 1
cmd/task/wire/wire.go

@@ -4,9 +4,9 @@
 package wire
 
 import (
+	"github.com/go-nunu/nunu-layout-advanced/internal/job"
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
 	"github.com/go-nunu/nunu-layout-advanced/internal/server"
-	"github.com/go-nunu/nunu-layout-advanced/internal/job"
 	"github.com/go-nunu/nunu-layout-advanced/internal/service"
 	"github.com/go-nunu/nunu-layout-advanced/internal/task"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/app"
@@ -32,6 +32,15 @@ var repositorySet = wire.NewSet(
 	repository.NewHostRepository,
 	repository.NewGameShieldUserIpRepository,
 	repository.NewGameShieldSdkIpRepository,
+	repository.NewWebForwardingRepository,
+	repository.NewTcpforwardingRepository,
+	repository.NewUdpForWardingRepository,
+	repository.NewWebLimitRepository,
+	repository.NewTcpLimitRepository,
+	repository.NewUdpLimitRepository,
+	repository.NewGlobalLimitRepository,
+	repository.NewGatewayGroupRepository,
+	repository.NewGateWayGroupIpRepository,
 )
 
 var taskSet = wire.NewSet(
@@ -64,6 +73,7 @@ var serviceSet = wire.NewSet(
 	service.NewGameShieldBackendService,
 	service.NewGameShieldSdkIpService,
 	service.NewGameShieldUserIpService,
+	service.NewWafFormatterService,
 )
 
 // build App

+ 8 - 3
cmd/task/wire/wire_gen.go

@@ -56,7 +56,12 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	jobJob := job.NewJob(transaction, logger, sidSid, rabbitMQ)
 	userJob := job.NewUserJob(jobJob, userRepository)
 	aoDunService := service.NewAoDunService(serviceService, viperViper)
-	whitelistJob := job.NewWhitelistJob(jobJob, aoDunService)
+	globalLimitRepository := repository.NewGlobalLimitRepository(repositoryRepository)
+	tcpforwardingRepository := repository.NewTcpforwardingRepository(repositoryRepository)
+	udpForWardingRepository := repository.NewUdpForWardingRepository(repositoryRepository)
+	webForwardingRepository := repository.NewWebForwardingRepository(repositoryRepository)
+	wafFormatterService := service.NewWafFormatterService(serviceService, globalLimitRepository, hostRepository, requiredService, parserService, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, rabbitMQ, hostService)
+	whitelistJob := job.NewWhitelistJob(jobJob, aoDunService, wafFormatterService)
 	jobServer := server.NewJobServer(logger, userJob, whitelistJob)
 	appApp := newApp(taskServer, jobServer)
 	return appApp, func() {
@@ -66,7 +71,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 
 // wire.go:
 
-var repositorySet = wire.NewSet(repository.NewDB, repository.NewMongoClient, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldPublicIpRepository, repository.NewHostRepository, repository.NewGameShieldUserIpRepository, repository.NewGameShieldSdkIpRepository)
+var repositorySet = wire.NewSet(repository.NewDB, repository.NewMongoClient, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldPublicIpRepository, repository.NewHostRepository, repository.NewGameShieldUserIpRepository, repository.NewGameShieldSdkIpRepository, repository.NewWebForwardingRepository, repository.NewTcpforwardingRepository, repository.NewUdpForWardingRepository, repository.NewWebLimitRepository, repository.NewTcpLimitRepository, repository.NewUdpLimitRepository, repository.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository)
 
 var taskSet = wire.NewSet(task.NewTask, task.NewUserTask, task.NewGameShieldTask)
 
@@ -74,7 +79,7 @@ var jobSet = wire.NewSet(job.NewJob, job.NewUserJob, job.NewWhitelistJob)
 
 var serverSet = wire.NewSet(server.NewTaskServer, server.NewJobServer)
 
-var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, service.NewGameShieldService, service.NewCrawlerService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewHostService, service.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewGameShieldUserIpService)
+var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, service.NewGameShieldService, service.NewCrawlerService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewHostService, service.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewGameShieldUserIpService, service.NewWafFormatterService)
 
 // build App
 func newApp(task2 *server.TaskServer,

+ 2 - 3
config/prod.yml

@@ -110,10 +110,9 @@ rabbitmq:
   tasks:
     # IP白名单更新任务
     ip_white:
-      exchange: "tasks_direct_exchange_test" # 改为使用 Topic 交换机,与域名任务保持一致
-      exchange_type: "topic"              # 显式指定交换机类型
+      exchange: "tasks_direct_exchange" # 使用一个统一的direct交换机
       queue: "ip_white_queue"
-      routing_key: "task.ip_white.*"      # 使用通配符路由键,匹配 task.ip_white.add, task.ip_white.del 等
+      routing_key: "task.ip_white.update"
       consumer_count: 2
       prefetch_count: 1
 

+ 125 - 111
internal/job/whitelist.go

@@ -4,11 +4,18 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
 	"github.com/go-nunu/nunu-layout-advanced/internal/service"
+	"github.com/google/uuid"
+	"github.com/rabbitmq/amqp091-go"
 	"go.uber.org/zap"
+	"strconv"
+	"strings"
 )
 
+// taskHandler 定义了处理单个消息的函数签名
+// 它负责业务逻辑的执行,并返回一个 error 来告知调用者处理是否成功。
+type taskHandler func(ctx context.Context, logger *zap.Logger, delivery amqp091.Delivery) error
+
 // WhitelistJob 定义了处理白名单相关任务的接口
 type WhitelistJob interface {
 	// DomainConsumer 启动消费者,处理域名白名单任务
@@ -19,29 +26,42 @@ type WhitelistJob interface {
 }
 
 // NewWhitelistJob 创建一个新的 WhitelistJob
-func NewWhitelistJob(job *Job, aoDunService service.AoDunService) WhitelistJob {
+func NewWhitelistJob(job *Job,
+	aoDunService service.AoDunService,
+	wafForMatter service.WafFormatterService,
+) WhitelistJob {
 	return &whitelistJob{
 		Job:          job,
 		aoDunService: aoDunService,
+		wafForMatter: wafForMatter,
 	}
 }
 
 type whitelistJob struct {
 	*Job
 	aoDunService service.AoDunService
+	wafForMatter service.WafFormatterService
 }
 
-// DomainConsumer 是处理域名白名单任务的消费者
+// DomainConsumer 启动域名白名单消费者
 func (j *whitelistJob) DomainConsumer(ctx context.Context) {
-	taskName := "domain_whitelist"
+	j.consume(ctx, "domain_whitelist", "domain_whitelist_consumer", j.handleDomainMessage)
+}
+
+// IpConsumer 启动IP白名单消费者
+func (j *whitelistJob) IpConsumer(ctx context.Context) {
+	j.consume(ctx, "ip_white", "ip_white_consumer", j.handleIpMessage)
+}
+
+// consume 是一个通用的 RabbitMQ 消费者方法,封装了重复的逻辑
+func (j *whitelistJob) consume(ctx context.Context, taskName, consumerName string, handler taskHandler) {
 	taskCfg, ok := j.Rabbitmq.GetTaskConfig(taskName)
 	if !ok {
 		j.logger.Error(fmt.Sprintf("未找到任务 '%s' 的配置", taskName))
 		return
 	}
 
-	consumerName := "domain_whitelist_consumer"
-	j.logger.Info("正在启动域名白名单消费者...",
+	j.logger.Info("正在启动消费者...",
 		zap.String("task", taskName),
 		zap.String("queue", taskCfg.Queue),
 		zap.String("consumer", consumerName),
@@ -49,155 +69,149 @@ func (j *whitelistJob) DomainConsumer(ctx context.Context) {
 
 	msgs, err := j.Rabbitmq.Consume(taskCfg.Queue, consumerName, taskCfg.PrefetchCount)
 	if err != nil {
-		j.logger.Error("启动域名白名单消费者失败", zap.Error(err))
+		j.logger.Error("启动消费者失败", zap.String("task", taskName), zap.Error(err))
 		return
 	}
 
-	// Define the message payload structure, now including an action field
-	type domainTaskPayload struct {
-		Domain string `json:"domain"`
-		Ip     string `json:"ip"`
-		Action string `json:"action"` // "add" or "del"
-	}
-
 	for {
 		select {
 		case <-ctx.Done():
-			j.logger.Info("域名白名单消费者正在关闭...", zap.String("task", taskName))
+			j.logger.Info("消费者正在关闭...", zap.String("task", taskName))
 			return
 		case d, ok := <-msgs:
 			if !ok {
-				j.logger.Warn("消息通道已关闭,域名白名单消费者退出。", zap.String("task", taskName))
+				j.logger.Warn("消息通道已关闭,消费者退出。", zap.String("task", taskName))
 				return
 			}
 
-			// 解析消息
-			var payload domainTaskPayload
-			if err := json.Unmarshal(d.Body, &payload); err != nil {
-				j.logger.Error("解析域名白名单消息失败", zap.Error(err), zap.ByteString("body", d.Body))
-				// 消息格式错误,直接拒绝且不重新入队
-				_ = d.Nack(false, false)
-				continue
+			// 尝试从消息头获取 trace_id,如果不存在则生成一个新的
+			traceID, ok := d.Headers["trace_id"].(string)
+			if !ok || traceID == "" {
+				traceID = uuid.New().String()
 			}
 
-			j.logger.Info("收到域名白名单任务",
-				zap.String("domain", payload.Domain),
-				zap.String("routing_key", d.RoutingKey),
-			)
-
-			// Call business logic based on the action
-			var processingErr error
-			switch payload.Action {
-			case "add", "del":
-				processingErr = j.aoDunService.DomainWhiteList(ctx, payload.Domain, payload.Ip, payload.Action)
-			default:
-				processingErr = fmt.Errorf("unknown action: %s", payload.Action)
-				j.logger.Warn("在 域名 白名单任务中收到未知操作", zap.String("action", payload.Action), zap.String("domain", payload.Domain))
-			}
+			// 创建一个带有 trace_id 的 logger,用于本次任务的所有日志记录
+			scopedLogger := j.logger.With(zap.String("trace_id", traceID))
 
-			if processingErr == nil {
-				j.logger.Info("已成功处理域名白名单任务", zap.String("action", payload.Action), zap.String("domain", payload.Domain))
-			}
+			// 创建一个带有 trace_id 的 context,用于传递给下游服务
+			ctxWithTrace := context.WithValue(ctx, "trace_id", traceID)
 
-			// 在循环的最后,根据 processingErr 的状态统一处理 Ack/Nack
+			// 调用具体的业务处理器
+			processingErr := handler(ctxWithTrace, scopedLogger, d)
+
+			// 根据处理结果统一进行 Ack/Nack
 			if processingErr != nil {
-				j.logger.Error("处理域名白名单任务失败", zap.Error(processingErr), zap.String("domain", payload.Domain))
 				// 业务失败,拒绝消息并不重新入队
 				if err := d.Nack(false, false); err != nil {
-					j.logger.Error("消息 Nack 失败", zap.Error(err))
+					scopedLogger.Error("消息 Nack 失败", zap.Error(err), zap.String("task", taskName))
 				}
 			} else {
 				// 业务处理成功,手动发送确认
 				if err := d.Ack(false); err != nil {
-					j.logger.Error("域名白名单任务消息确认失败", zap.Error(err))
+					scopedLogger.Error("消息 Ack 失败", zap.Error(err), zap.String("task", taskName))
 				}
 			}
-
 		}
 	}
 }
 
+// handleDomainMessage 是域名白名单任务的具体处理器
+func (j *whitelistJob) handleDomainMessage(ctx context.Context, logger *zap.Logger, d amqp091.Delivery) error {
+	type domainTaskPayload struct {
+		Domain string `json:"domain"`
+		Ip     string `json:"ip"`
+		Action string `json:"action"` // "add" or "del"
+	}
 
-func (j *whitelistJob) IpConsumer(ctx context.Context) {
-	taskName := "ip_white"
-	taskCfg, ok := j.Rabbitmq.GetTaskConfig(taskName)
-	if !ok {
-		j.logger.Error(fmt.Sprintf("未找到任务 '%s' 的配置", taskName))
-		return
+	var payload domainTaskPayload
+	if err := json.Unmarshal(d.Body, &payload); err != nil {
+		logger.Error("解析域名白名单消息失败", zap.Error(err), zap.ByteString("body", d.Body))
+		return nil // 返回 nil 以避免消息重入队列,因为这是一个格式错误
 	}
 
-	consumerName := "ip_white_consumer"
-	j.logger.Info("正在启动IP白名单消费者...",
-		zap.String("task", taskName),
-		zap.String("queue", taskCfg.Queue),
-		zap.String("consumer", consumerName),
+	logger.Info("收到域名白名单任务",
+		zap.String("action", payload.Action),
+		zap.String("domain", payload.Domain),
+		zap.String("ip", payload.Ip),
+		zap.String("routing_key", d.RoutingKey),
 	)
 
-	msgs, err := j.Rabbitmq.Consume(taskCfg.Queue, consumerName, taskCfg.PrefetchCount)
-	if err != nil {
-		j.logger.Error("启动IP白名单消费者失败", zap.Error(err))
-		return
+	var processingErr error
+	switch payload.Action {
+	case "add", "del":
+		processingErr = j.aoDunService.DomainWhiteList(ctx, payload.Domain, payload.Ip, payload.Action)
+	default:
+		processingErr = fmt.Errorf("unknown action: %s", payload.Action)
+		logger.Warn("在域名白名单任务中收到未知操作", zap.String("action", payload.Action), zap.String("domain", payload.Domain))
 	}
 
-	// Define the message payload structure, now including an action field
+	if processingErr != nil {
+		logger.Error("处理域名白名单任务失败", zap.Error(processingErr), zap.String("domain", payload.Domain))
+	} else {
+		logger.Info("已成功处理域名白名单任务", zap.String("action", payload.Action), zap.String("domain", payload.Domain))
+	}
+
+	return processingErr
+}
+
+// handleIpMessage 是 IP 白名单任务的具体处理器
+func (j *whitelistJob) handleIpMessage(ctx context.Context, logger *zap.Logger, d amqp091.Delivery) error {
 	type ipTaskPayload struct {
-		Ips     []v1.IpInfo `json:"ips"`
-		Action string `json:"action"`
+		Ips    []string `json:"ips"`
+		Action string   `json:"action"`
 	}
 
-	for {
-		select {
-		case <-ctx.Done():
-			j.logger.Info("IP白名单消费者正在关闭...", zap.String("task", taskName))
-			return
-		case d, ok := <-msgs:
-			if !ok {
-				j.logger.Warn("消息通道已关闭,IP白名单消费者退出。", zap.String("task", taskName))
-				return
-			}
+	var payload ipTaskPayload
+	if err := json.Unmarshal(d.Body, &payload); err != nil {
+		logger.Error("解析IP白名单消息失败", zap.Error(err), zap.ByteString("body", d.Body))
+		return nil // 消息格式错误,不应重试
+	}
 
-			// 解析消息
-			var payload ipTaskPayload
-			if err := json.Unmarshal(d.Body, &payload); err != nil {
-				j.logger.Error("解析IP白名单消息失败", zap.Error(err), zap.ByteString("body", d.Body))
-				// 消息格式错误,直接拒绝且不重新入队
-				_ = d.Nack(false, false)
-				continue
-			}
+	logger.Info("收到IP白名单任务",
+		zap.String("action", payload.Action),
+		zap.Any("ips", payload.Ips),
+		zap.String("routing_key", d.RoutingKey),
+	)
 
-			j.logger.Info("收到IP白名单任务",
-				zap.Any("IP", payload.Ips),
-				zap.String("routing_key", d.RoutingKey),
-			)
-
-			// Call business logic based on the action
-			var processingErr error
-			switch payload.Action {
-			case "add":
-				processingErr = j.aoDunService.AddWhiteStaticList(ctx, payload.Ips)
-			default:
-				processingErr = fmt.Errorf("unknown action: %s", payload.Action)
-				j.logger.Warn("在 IP 白名单任务中收到未知操作", zap.String("action", payload.Action), zap.Any("IP", payload.Ips))
-			}
+	var processingErr error
+	switch payload.Action {
+	case "add":
+		ips, err := j.wafForMatter.AppendWafIp(ctx, payload.Ips)
+		if err != nil {
+			// 如果附加IP失败,记录错误并终止
+			processingErr = fmt.Errorf("为WAF准备IP列表失败: %w", err)
+		} else {
+			processingErr = j.aoDunService.AddWhiteStaticList(ctx, ips)
+		}
 
-			if processingErr == nil {
-				j.logger.Info("已成功处理IP白名单任务", zap.String("action", payload.Action), zap.Any("IP", payload.Ips))
+	case "del":
+		var errs []string
+		for _, ip := range payload.Ips {
+			id, err := j.aoDunService.GetWhiteStaticList(ctx, ip)
+			if err != nil {
+				logger.Error("获取IP白名单ID失败", zap.Error(err), zap.String("ip", ip))
+				errs = append(errs, fmt.Sprintf("获取IP '%s' 失败: %v", ip, err))
+				continue
 			}
-
-			// 在循环的最后,根据 processingErr 的状态统一处理 Ack/Nack
-			if processingErr != nil {
-				j.logger.Error("处理域名白名单任务失败", zap.Error(processingErr), zap.Any("domain", payload.Ips))
-				// 业务失败,拒绝消息并不重新入队
-				if err := d.Nack(false, false); err != nil {
-					j.logger.Error("消息 Nack 失败", zap.Error(err))
-				}
-			} else {
-				// 业务处理成功,手动发送确认
-				if err := d.Ack(false); err != nil {
-					j.logger.Error("域名白名单任务消息确认失败", zap.Error(err))
-				}
+			if err := j.aoDunService.DelWhiteStaticList(ctx, strconv.Itoa(id)); err != nil {
+				logger.Error("删除IP白名单失败", zap.Error(err), zap.String("ip", ip))
+				errs = append(errs, fmt.Sprintf("删除IP '%s' 失败: %v", ip, err))
 			}
-
 		}
+		if len(errs) > 0 {
+			processingErr = fmt.Errorf("删除IP任务中发生错误: %s", strings.Join(errs, "; "))
+		}
+
+	default:
+		processingErr = fmt.Errorf("unknown action: %s", payload.Action)
+		logger.Warn("在IP白名单任务中收到未知操作", zap.String("action", payload.Action), zap.Any("ips", payload.Ips))
+	}
+
+	if processingErr != nil {
+		logger.Error("处理IP白名单任务失败", zap.Error(processingErr), zap.Any("ips", payload.Ips))
+	} else {
+		logger.Info("已成功处理IP白名单任务", zap.String("action", payload.Action), zap.Any("ips", payload.Ips))
 	}
+
+	return processingErr
 }

+ 6 - 3
internal/service/aodun.go

@@ -6,6 +6,7 @@ import (
 	"crypto/tls"
 	"encoding/json"
 	"fmt"
+	"github.com/davecgh/go-spew/spew"
 	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
 	"github.com/spf13/viper"
 	"io"
@@ -18,7 +19,8 @@ import (
 type AoDunService interface {
 	DomainWhiteList(ctx context.Context, domain string, ip string, apiType string) error
 	AddWhiteStaticList(ctx context.Context, req []v1.IpInfo) error
-	DelWhiteStaticList(ctx context.Context, req v1.DeleteIp) error
+	DelWhiteStaticList(ctx context.Context, id string) error
+	GetWhiteStaticList(ctx context.Context,ip string) (int,error)
 }
 func NewAoDunService(
     service *Service,
@@ -249,10 +251,11 @@ func (s *aoDunService) GetWhiteStaticList(ctx context.Context,ip string) (int,er
 	// 4. 获取 ID 并返回
 	// 假设我们总是取返回结果中的第一个元素的 ID
 	id := res.Data[0].ID
+	spew.Dump(id)
 	return id, nil // 成功!返回获取到的 id 和 nil 错误
 }
 
-func (s *aoDunService) DelWhiteStaticList(ctx context.Context, req v1.DeleteIp) error {
+func (s *aoDunService) DelWhiteStaticList(ctx context.Context, id string) error {
 	tokenType, token, err := s.GetToken(ctx)
 	if err != nil {
 		return err
@@ -262,7 +265,7 @@ func (s *aoDunService) DelWhiteStaticList(ctx context.Context, req v1.DeleteIp)
 		"action": "del",
 		"bwflag": "white",
 		"flag":   0,
-		"ids":    req.Ids,
+		"ids":   id,
 	}
 
 	resBody, err := s.sendFormData(ctx, "/v1.0/firewall/static_bw_list", tokenType, token, formData)

+ 45 - 4
internal/service/wafformatter.go

@@ -23,9 +23,11 @@ type WafFormatterService interface {
 	validateWafDomainCount(ctx context.Context, req v1.GlobalRequire) error
 	ConvertToWildcardDomain(ctx context.Context,domain string) (string, error)
 	AppendWafIp(ctx context.Context, req []string) ([]v1.IpInfo, error)
-	AppendWafIpByRemovePort(ctx context.Context, req []string) ([]v1.IpInfo, error)
-	PublishIpWhitelistTask(ips []v1.IpInfo, action string)
+	WashIps(ctx context.Context, req []string) ([]string, error)
+	PublishIpWhitelistTask(ips []string, action string)
 	PublishDomainWhitelistTask(domain, ip, action string)
+	findIpDifferences(oldIps, newIps []string) ([]string, []string)
+
 }
 func NewWafFormatterService(
     service *Service,
@@ -210,6 +212,13 @@ func (s *wafFormatterService) AppendWafIpByRemovePort(ctx context.Context, req [
 
 }
 
+func (s *wafFormatterService) WashIps(ctx context.Context, req []string) ([]string, error) {
+	var res []string
+	for _, v := range req {
+		res = append(res,v)
+	}
+	return res, nil
+}
 
 // publishDomainWhitelistTask is a helper function to publish domain whitelist tasks to RabbitMQ.
 // It can handle different actions like "add" or "del".
@@ -260,10 +269,10 @@ func (s *wafFormatterService) PublishDomainWhitelistTask(domain, ip, action stri
 }
 
 
-func (s *wafFormatterService) PublishIpWhitelistTask(ips []v1.IpInfo, action string) {
+func (s *wafFormatterService) PublishIpWhitelistTask(ips []string, action string) {
 	// Define message payload, including the action
 	type ipTaskPayload struct {
-		Ips     []v1.IpInfo `json:"ips"`
+		Ips     []string `json:"ips"`
 		Action string `json:"action"`
 	}
 	payload := ipTaskPayload{
@@ -302,4 +311,36 @@ func (s *wafFormatterService) PublishIpWhitelistTask(ips []v1.IpInfo, action str
 	} else {
 		s.logger.Info("成功将 IP 白名单任务发布到 MQ", zap.String("action", action))
 	}
+}
+
+
+func (s *wafFormatterService) findIpDifferences(oldIps, newIps []string) ([]string, []string) {
+	// 使用 map 实现 set,用于快速查找
+	oldIpsSet := make(map[string]struct{}, len(oldIps))
+	for _, ip := range oldIps {
+		oldIpsSet[ip] = struct{}{}
+	}
+
+	newIpsSet := make(map[string]struct{}, len(newIps))
+	for _, ip := range newIps {
+		newIpsSet[ip] = struct{}{}
+	}
+
+	var addedIps []string
+	// 查找新增的 IP:存在于 newIpsSet 但不存在于 oldIpsSet
+	for ip := range newIpsSet {
+		if _, found := oldIpsSet[ip]; !found {
+			addedIps = append(addedIps, ip)
+		}
+	}
+
+	var removedIps []string
+	// 查找移除的 IP:存在于 oldIpsSet 但不存在于 newIpsSet
+	for ip := range oldIpsSet {
+		if _, found := newIpsSet[ip]; !found {
+			removedIps = append(removedIps, ip)
+		}
+	}
+
+	return addedIps, removedIps
 }

+ 65 - 13
internal/service/webforwarding.go

@@ -338,26 +338,16 @@ func (s *webForwardingService) AddWebForwarding(ctx context.Context, req *v1.Web
 		go s.wafformatter.PublishDomainWhitelistTask(doMain,ip, "add")
 	}
 	// IP过白
-	var ips []v1.IpInfo
+	var ips []string
 	if req.WebForwardingData.BackendList != nil {
 		for _, v := range req.WebForwardingData.BackendList {
 			ip, _, err := net.SplitHostPort(v.Addr)
 			if err != nil {
 				return err
 			}
-			ips = append(ips, v1.IpInfo{
-				FType:      "0",
-				FStartIp:   ip,
-				FEndIp:     ip,
-				FRemark:    "宁波高防IP过白",
-				FServerIp:  "",
-			})
+			ips = append(ips,ip)
 		}
-		allowIps, err := s.wafformatter.AppendWafIp(ctx, req.WebForwardingData.AllowIpList)
-		if err != nil {
-			return err
-		}
-		ips = append(ips, allowIps...)
+		ips = append(ips, req.WebForwardingData.AllowIpList...)
 		go s.wafformatter.PublishIpWhitelistTask(ips, "add")
 
 	}
@@ -411,6 +401,44 @@ func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.We
 		go s.wafformatter.PublishDomainWhitelistTask(doMain, Ip, "add")
 	}
 
+	// IP过白
+	ipData, err := s.webForwardingRepository.GetWebForwardingIpsByID(ctx, req.WebForwardingData.Id)
+	if err != nil {
+		return err
+	}
+	var oldIps []string
+	var newIps []string
+	for _, v := range ipData.BackendList {
+		ip, _, err := net.SplitHostPort(v.Addr)
+		if err != nil {
+			return err
+		}
+		oldIps = append(oldIps, ip)
+	}
+	if len(ipData.AllowIpList) > 0 {
+		oldIps = append(oldIps, ipData.AllowIpList...)
+	}
+	for _, v := range req.WebForwardingData.BackendList {
+		ip, _, err := net.SplitHostPort(v.Addr)
+		if err != nil {
+			return err
+		}
+		newIps = append(newIps, ip)
+	}
+	if len(req.WebForwardingData.AllowIpList) > 0 {
+		newIps = append(newIps, req.WebForwardingData.AllowIpList...)
+	}
+
+
+	addedIps, removedIps := s.wafformatter.findIpDifferences(oldIps, newIps)
+	if len(addedIps) > 0 {
+		go s.wafformatter.PublishIpWhitelistTask(addedIps, "add")
+	}
+	if len(removedIps) > 0 {
+		go s.wafformatter.PublishIpWhitelistTask(removedIps, "del")
+	}
+
+
 
 	webModel := s.buildWebForwardingModel(&req.WebForwardingData, req.WebForwardingData.WafWebId, require)
 	webModel.Id = req.WebForwardingData.Id
@@ -434,6 +462,7 @@ func (s *webForwardingService) DeleteWebForwarding(ctx context.Context, Ids []in
 		if err != nil {
 			return err
 		}
+		// 异步任务:将域名添加到白名单
 		if webData.Domain != "" {
 			ip , err := s.GetIp(ctx, webData.WafGatewayGroupId)
 			if err != nil {
@@ -445,6 +474,29 @@ func (s *webForwardingService) DeleteWebForwarding(ctx context.Context, Ids []in
 			}
 			go s.wafformatter.PublishDomainWhitelistTask(doMain,ip, "del")
 		}
+		// IP过白
+		ipData, err := s.webForwardingRepository.GetWebForwardingIpsByID(ctx, Id)
+		if err != nil {
+			return err
+		}
+		var ips []string
+		if len(ipData.BackendList) > 0 {
+			for _, v := range ipData.BackendList {
+				ip, _, err := net.SplitHostPort(v.Addr)
+				if err != nil {
+					return err
+				}
+				ips = append(ips, ip)
+			}
+		}
+		if len(ipData.AllowIpList) > 0 {
+			ips = append(ips, ipData.AllowIpList...)
+		}
+		if len(ips) > 0 {
+			go s.wafformatter.PublishIpWhitelistTask(ips, "del")
+		}
+
+
 		_, err = s.crawler.DeleteRule(ctx, wafWebId, "admin/delete/waf_web?page=1&__pageSize=10&__sort=waf_web_id&__sort_type=desc")
 		if err != nil {
 			return err