فهرست منبع

feat(allowAndDenyIp): 实现 IP 黑白名单功能

- 新增 AllowAndDenyIp 相关的 API 接口和处理逻辑
- 实现 IP 数量统计和筛选功能- 优化 IP 白名单发布机制,只处理需要删除的 IP
- 新增域名数量统计功能
fusu 1 ماه پیش
والد
کامیت
8dfc5cda3c

+ 5 - 0
api/v1/allowAndDenyIp.go

@@ -13,3 +13,8 @@ type DelAllowAndDenyIpRequest struct {
 	Uid    int `json:"uid" form:"uid" validate:"required"`
 	Ids []int `json:"ids" form:"ids" validate:"required,min=1,dive,required"`
 }
+
+type IpCountResult struct {
+	Ip    string `bson:"_id"`   // MongoDB $group 的结果会放在 _id 字段
+	Count int    `bson:"count"`
+}

+ 8 - 5
cmd/server/wire/wire_gen.go

@@ -93,9 +93,12 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	adminService := service.NewAdminService(serviceService, adminRepository)
 	adminHandler := handler.NewAdminHandler(handlerHandler, adminService)
 	gatewayGroupHandler := handler.NewGatewayGroupHandler(handlerHandler, gatewayGroupService)
-	gateWayGroupIpService := service.NewGateWayGroupIpService(serviceService, gateWayGroupIpRepository, requestService)
+	gateWayGroupIpService := service.NewGateWayGroupIpService(serviceService, gateWayGroupIpRepository, gatewayGroupRepository, requestService)
 	gateWayGroupIpHandler := handler.NewGateWayGroupIpHandler(handlerHandler, gateWayGroupIpService)
-	httpServer := server.NewHTTPServer(logger, viperViper, jwtJWT, syncedEnforcer, limiterLimiter, handlerFunc, userHandler, gameShieldHandler, gameShieldBackendHandler, webForwardingHandler, webLimitHandler, tcpforwardingHandler, udpForWardingHandler, tcpLimitHandler, udpLimitHandler, globalLimitHandler, adminHandler, gatewayGroupHandler, gateWayGroupIpHandler)
+	allowAndDenyIpRepository := repository.NewAllowAndDenyIpRepository(repositoryRepository)
+	allowAndDenyIpService := service.NewAllowAndDenyIpService(serviceService, allowAndDenyIpRepository, gateWayGroupIpService, wafFormatterService)
+	allowAndDenyIpHandler := handler.NewAllowAndDenyIpHandler(handlerHandler, allowAndDenyIpService)
+	httpServer := server.NewHTTPServer(logger, viperViper, jwtJWT, syncedEnforcer, limiterLimiter, handlerFunc, userHandler, gameShieldHandler, gameShieldBackendHandler, webForwardingHandler, webLimitHandler, tcpforwardingHandler, udpForWardingHandler, tcpLimitHandler, udpLimitHandler, globalLimitHandler, adminHandler, gatewayGroupHandler, gateWayGroupIpHandler, allowAndDenyIpHandler)
 	appApp := newApp(httpServer)
 	return appApp, func() {
 		cleanup()
@@ -104,11 +107,11 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 
 // wire.go:
 
-var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewCasbinEnforcer, repository.NewMongoClient, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewAdminRepository, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldPublicIpRepository, repository.NewWebForwardingRepository, repository.NewTcpforwardingRepository, repository.NewUdpForWardingRepository, repository.NewGameShieldUserIpRepository, repository.NewWebLimitRepository, repository.NewTcpLimitRepository, repository.NewUdpLimitRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldSdkIpRepository, repository.NewHostRepository, repository.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository, repository.NewCdnRepository)
+var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewCasbinEnforcer, repository.NewMongoClient, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewAdminRepository, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldPublicIpRepository, repository.NewWebForwardingRepository, repository.NewTcpforwardingRepository, repository.NewUdpForWardingRepository, repository.NewGameShieldUserIpRepository, repository.NewWebLimitRepository, repository.NewTcpLimitRepository, repository.NewUdpLimitRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldSdkIpRepository, repository.NewHostRepository, repository.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository, repository.NewCdnRepository, repository.NewAllowAndDenyIpRepository)
 
-var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, service.NewUserService, service.NewAdminService, service.NewGameShieldService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewCrawlerService, service.NewWebForwardingService, service.NewTcpforwardingService, service.NewUdpForWardingService, service.NewGameShieldUserIpService, service.NewWebLimitService, service.NewTcpLimitService, service.NewUdpLimitService, service.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewHostService, service.NewGlobalLimitService, service.NewGatewayGroupService, service.NewWafFormatterService, service.NewGateWayGroupIpService, service.NewRequestService, service.NewCdnService)
+var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, service.NewUserService, service.NewAdminService, service.NewGameShieldService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewCrawlerService, service.NewWebForwardingService, service.NewTcpforwardingService, service.NewUdpForWardingService, service.NewGameShieldUserIpService, service.NewWebLimitService, service.NewTcpLimitService, service.NewUdpLimitService, service.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewHostService, service.NewGlobalLimitService, service.NewGatewayGroupService, service.NewWafFormatterService, service.NewGateWayGroupIpService, service.NewRequestService, service.NewCdnService, service.NewAllowAndDenyIpService)
 
-var handlerSet = wire.NewSet(handler.NewHandler, handler.NewUserHandler, handler.NewAdminHandler, handler.NewGameShieldHandler, handler.NewGameShieldPublicIpHandler, handler.NewWebForwardingHandler, handler.NewTcpforwardingHandler, handler.NewUdpForWardingHandler, handler.NewGameShieldUserIpHandler, handler.NewWebLimitHandler, handler.NewTcpLimitHandler, handler.NewUdpLimitHandler, handler.NewGameShieldBackendHandler, handler.NewGameShieldSdkIpHandler, handler.NewHostHandler, handler.NewGlobalLimitHandler, handler.NewGatewayGroupHandler, handler.NewGateWayGroupIpHandler)
+var handlerSet = wire.NewSet(handler.NewHandler, handler.NewUserHandler, handler.NewAdminHandler, handler.NewGameShieldHandler, handler.NewGameShieldPublicIpHandler, handler.NewWebForwardingHandler, handler.NewTcpforwardingHandler, handler.NewUdpForWardingHandler, handler.NewGameShieldUserIpHandler, handler.NewWebLimitHandler, handler.NewTcpLimitHandler, handler.NewUdpLimitHandler, handler.NewGameShieldBackendHandler, handler.NewGameShieldSdkIpHandler, handler.NewHostHandler, handler.NewGlobalLimitHandler, handler.NewGatewayGroupHandler, handler.NewGateWayGroupIpHandler, handler.NewAllowAndDenyIpHandler)
 
 // 限流器依赖集
 var limiterSet = wire.NewSet(limiter.NewLimiter, middleware.NewRateLimitMiddleware)

+ 1 - 0
internal/model/webforwarding.go

@@ -37,6 +37,7 @@ type WebForwardingRule struct {
 	UpdatedAt   time.Time          `bson:"updated_at" json:"updated_at"`
 }
 
+
 func (m *WebForwardingRule) CollectionName() string {
 	return "web_forwarding_rules"
 }

+ 39 - 0
internal/repository/tcpforwarding.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
 	"go.mongodb.org/mongo-driver/bson"
 	"go.mongodb.org/mongo-driver/bson/primitive"
@@ -23,6 +24,8 @@ type TcpforwardingRepository interface {
 	EditTcpforwardingIps(ctx context.Context, req model.TcpForwardingRule) error
 	GetTcpForwardingIpsByID(ctx context.Context, tcpId int) (*model.TcpForwardingRule, error)
 	DeleteTcpForwardingIpsById(ctx context.Context, tcpId int) error
+	// 获取IP数量等于1的IP
+	GetIpCountByIp(ctx context.Context,ips []string) ([]v1.IpCountResult, error)
 }
 
 func NewTcpforwardingRepository(
@@ -181,4 +184,40 @@ func (r *tcpforwardingRepository) DeleteTcpForwardingIpsById(ctx context.Context
 		return fmt.Errorf("删除MongoDB文档失败: %w", err)
 	}
 	return nil
+}
+
+
+// 获取IP数量等于1的IP
+func (r *tcpforwardingRepository) GetIpCountByIp(ctx context.Context,ips []string) ([]v1.IpCountResult, error) {
+	if len(ips) == 0 {
+		return []v1.IpCountResult{}, nil
+	}
+	pipeline := []bson.M{
+		{
+			"$match": bson.M{
+				"ip": bson.M{"$in": ips},
+			},
+		},
+		{
+			"$group": bson.M{
+				"_id":   "$ip",
+				"count": bson.M{"$sum": 1},
+			},
+		},
+		{
+			"$project": bson.M{
+				"_id":   0,       // 不输出默认的_id
+				"ip":    "$_id",  // 将分组的_id字段重命名为ip
+				"count": 1,       // 保留count字段
+			},
+		},
+	}
+
+	var results []v1.IpCountResult
+	// 使用 qmgo 执行聚合查询
+	err := r.mongoDB.Collection("tcp_forwarding_rules").Aggregate(ctx, pipeline).All(&results)
+	if err != nil {
+		return nil, err
+	}
+	return results, nil
 }

+ 37 - 0
internal/repository/udpforwarding.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
 	"go.mongodb.org/mongo-driver/bson"
 	"go.mongodb.org/mongo-driver/bson/primitive"
@@ -23,6 +24,8 @@ type UdpForWardingRepository interface {
 	EditUdpForwardingIps(ctx context.Context, req model.UdpForwardingRule) error
 	GetUdpForwardingIpsByID(ctx context.Context, udpId int) (*model.UdpForwardingRule, error)
 	DeleteUdpForwardingIpsById(ctx context.Context, udpId int) error
+	// 获取ip数量等于1的ip
+	GetIpCountByIp(ctx context.Context, ips []string) ([]v1.IpCountResult, error)
 }
 
 func NewUdpForWardingRepository(
@@ -179,3 +182,37 @@ func (r *udpForWardingRepository) DeleteUdpForwardingIpsById(ctx context.Context
 
 }
 
+// 获取IP数量等于1的IP
+func (r *udpForWardingRepository) GetIpCountByIp(ctx context.Context, ips []string) ([]v1.IpCountResult, error) {
+	if len(ips) == 0 {
+		return []v1.IpCountResult{}, nil
+	}
+	pipeline := []bson.M{
+		{
+			"$match": bson.M{
+				"ip": bson.M{"$in": ips},
+			},
+		},
+		{
+			"$group": bson.M{
+				"_id":   "$ip",
+				"count": bson.M{"$sum": 1},
+			},
+		},
+		{
+			"$project": bson.M{
+				"_id":   0,       // 不输出默认的_id
+				"ip":    "$_id",  // 将分组的_id字段重命名为ip
+				"count": 1,       // 保留count字段
+			},
+		},
+	}
+
+	var results []v1.IpCountResult
+	// 使用 qmgo 执行聚合查询
+	err := r.mongoDB.Collection("udp_forwarding_rules").Aggregate(ctx, pipeline).All(&results)
+	if err != nil {
+		return nil, err
+	}
+	return results, nil
+}

+ 51 - 1
internal/repository/webforwarding.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
 	"github.com/qiniu/qmgo"
 	"go.mongodb.org/mongo-driver/bson"
@@ -25,6 +26,10 @@ type WebForwardingRepository interface {
 	EditWebForwardingIps(ctx context.Context, req model.WebForwardingRule) error
 	GetWebForwardingIpsByID(ctx context.Context, webId int) (*model.WebForwardingRule, error)
 	DeleteWebForwardingIpsById(ctx context.Context, webId int) error
+	// 获取域名数量
+	GetDomainCount(ctx context.Context, hostId int,domain string) (int, error)
+	// 获取IP数量等于1的IP
+	GetIpCountByIp(ctx context.Context, ips []string) ([]v1.IpCountResult, error)
 }
 
 func NewWebForwardingRepository(
@@ -202,4 +207,49 @@ func (r *webForwardingRepository) DeleteWebForwardingIpsById(ctx context.Context
 		return fmt.Errorf("删除MongoDB文档失败: %w", err)
 	}
 	return nil
-}
+}
+
+// 获取域名数量
+func (r *webForwardingRepository) GetDomainCount(ctx context.Context, hostId int,domain string) (int, error) {
+	var count int64
+	if err := r.db.Model(&model.WebForwarding{}).WithContext(ctx).Where("host_id = ? AND domain = ?", hostId,domain).Count(&count).Error; err != nil {
+		return 0, err
+	}
+	return int(count), nil
+}
+
+// 获取IP数量等于1的IP
+func (r *webForwardingRepository) GetIpCountByIp(ctx context.Context, ips []string) ([]v1.IpCountResult, error) {
+	if len(ips) == 0 {
+		return []v1.IpCountResult{}, nil
+	}
+	pipeline := []bson.M{
+		{
+			"$match": bson.M{
+				"ip": bson.M{"$in": ips},
+			},
+		},
+		{
+			"$group": bson.M{
+				"_id":   "$ip",
+				"count": bson.M{"$sum": 1},
+			},
+		},
+		{
+			"$project": bson.M{
+				"_id":   0,       // 不输出默认的_id
+				"ip":    "$_id",  // 将分组的_id字段重命名为ip
+				"count": 1,       // 保留count字段
+			},
+		},
+	}
+
+	var results []v1.IpCountResult
+	// 使用 qmgo 执行聚合查询
+	err := r.mongoDB.Collection("web_forwarding_rules").Aggregate(ctx, pipeline).All(&results)
+	if err != nil {
+		return nil, fmt.Errorf("聚合查询失败: %w", err)
+	}
+
+	return results, nil
+}

+ 7 - 1
internal/server/http.go

@@ -36,7 +36,7 @@ func NewHTTPServer(
 	adminHandler *handler.AdminHandler,
 	gatewayHandler *handler.GatewayGroupHandler,
 	gatewayIpHandler *handler.GateWayGroupIpHandler,
-	allowAnddenyHandler *handler.AllowAndDenyHandler,
+	allowAnddenyHandler *handler.AllowAndDenyIpHandler,
 ) *http.Server {
 	gin.SetMode(gin.DebugMode)
 	s := http.NewServer(
@@ -128,6 +128,12 @@ func NewHTTPServer(
 			noAuthRouter.POST("/globalLimit/edit", ipAllowlistMiddleware, globalLimitHandler.EditGlobalLimit)
 			noAuthRouter.POST("/globalLimit/delete", ipAllowlistMiddleware, globalLimitHandler.DeleteGlobalLimit)
 
+			noAuthRouter.POST("/allowAndDeny/get", ipAllowlistMiddleware, allowAnddenyHandler.GetAllowAndDenyIp)
+			noAuthRouter.POST("/allowAndDeny/getList", ipAllowlistMiddleware, allowAnddenyHandler.GetAllowAndDenyIpList)
+			noAuthRouter.POST("/allowAndDeny/add", ipAllowlistMiddleware, allowAnddenyHandler.AddAllowAndDenyIp)
+			noAuthRouter.POST("/allowAndDeny/edit", ipAllowlistMiddleware, allowAnddenyHandler.EditAllowAndDenyIp)
+			noAuthRouter.POST("/allowAndDeny/delete", ipAllowlistMiddleware, allowAnddenyHandler.DeleteAllowAndDenyIp)
+
 		}
 		// Non-strict permission routing group
 		//noStrictAuthRouter := v1.Group("/").Use(middleware.NoStrictAuth(jwt, logger))

+ 2 - 2
internal/service/allowanddenyip.go

@@ -17,7 +17,7 @@ type AllowAndDenyIpService interface {
 func NewAllowAndDenyIpService(
     service *Service,
     allowAndDenyIpRepository repository.AllowAndDenyIpRepository,
-	gatewayGroupIp gateWayGroupIpService,
+	gatewayGroupIp GateWayGroupIpService,
 	wafformatter WafFormatterService,
 
 ) AllowAndDenyIpService {
@@ -32,7 +32,7 @@ func NewAllowAndDenyIpService(
 type allowAndDenyIpService struct {
 	*Service
 	allowAndDenyIpRepository repository.AllowAndDenyIpRepository
-	gatewayGroupIp gateWayGroupIpService
+	gatewayGroupIp GateWayGroupIpService
 	wafformatter WafFormatterService
 }
 

+ 26 - 2
internal/service/tcpforwarding.go

@@ -45,6 +45,10 @@ func NewTcpforwardingService(
 	}
 }
 
+const (
+	tcp = "tcp"
+)
+
 type tcpforwardingService struct {
 	*Service
 	tcpforwardingRepository repository.TcpforwardingRepository
@@ -57,6 +61,8 @@ type tcpforwardingService struct {
 	cdn CdnService
 }
 
+
+
 func (s *tcpforwardingService) GetTcpforwarding(ctx context.Context, req v1.GetForwardingRequest) (v1.TcpForwardingDataRequest, error) {
 	var tcpForwarding model.Tcpforwarding
 	var backend model.TcpForwardingRule
@@ -267,8 +273,18 @@ func (s *tcpforwardingService) EditTcpForwarding(ctx context.Context, req *v1.Tc
 	if len(addedIps) > 0 {
 		go s.wafformatter.PublishIpWhitelistTask(addedIps, "add","","white")
 	}
+
+
 	if len(removedIps) > 0 {
-		go s.wafformatter.PublishIpWhitelistTask(removedIps, "del","0","white")
+		ipsToDelist, err := s.wafformatter.WashDelIps(ctx, removedIps,tcp)
+		if err != nil {
+			return err
+		}
+
+		// 4. 如果有需要处理的IP,则批量发布一次任务
+		if len(ipsToDelist) > 0 {
+			go s.wafformatter.PublishIpWhitelistTask(ipsToDelist, "del", "0", "white")
+		}
 	}
 
 
@@ -345,7 +361,15 @@ func (s *tcpforwardingService) DeleteTcpForwarding(ctx context.Context, req v1.D
 			return err
 		}
 		if len(ips) > 0 {
-			go s.wafformatter.PublishIpWhitelistTask(ips, "del","0","white")
+			ipsToDelist, err := s.wafformatter.WashDelIps(ctx, ips,tcp)
+			if err != nil {
+				return err
+			}
+
+			// 4. 如果有需要处理的IP,则批量发布一次任务
+			if len(ipsToDelist) > 0 {
+				go s.wafformatter.PublishIpWhitelistTask(ipsToDelist, "del", "0", "white")
+			}
 		}
 
 

+ 30 - 3
internal/service/udpforwarding.go

@@ -45,6 +45,10 @@ func NewUdpForWardingService(
 	}
 }
 
+const (
+	udp = "udp"
+)
+
 type udpForWardingService struct {
 	*Service
 	udpForWardingRepository repository.UdpForWardingRepository
@@ -58,6 +62,7 @@ type udpForWardingService struct {
 }
 
 
+
 func (s *udpForWardingService) GetUdpForWarding(ctx context.Context,req v1.GetForwardingRequest) (v1.UdpForwardingDataRequest, error) {
 	var udpForWarding model.UdpForWarding
 	var backend model.UdpForwardingRule
@@ -265,8 +270,19 @@ func (s *udpForWardingService) EditUdpForwarding(ctx context.Context, req *v1.Ud
 	if len(addedIps) > 0 {
 		go s.wafformatter.PublishIpWhitelistTask(addedIps, "add","","white")
 	}
+
+
+
 	if len(removedIps) > 0 {
-		go s.wafformatter.PublishIpWhitelistTask(removedIps, "del","0","white")
+		ipsToDelist, err := s.wafformatter.WashDelIps(ctx, removedIps,udp)
+		if err != nil {
+			return err
+		}
+
+		// 4. 如果有需要处理的IP,则批量发布一次任务
+		if len(ipsToDelist) > 0 {
+			go s.wafformatter.PublishIpWhitelistTask(ipsToDelist, "del", "0", "white")
+		}
 	}
 
 
@@ -341,8 +357,18 @@ func (s *udpForWardingService) DeleteUdpForwarding(ctx context.Context, Ids []in
 		if err != nil {
 			return err
 		}
+
+
 		if len(ips) > 0 {
-			go s.wafformatter.PublishIpWhitelistTask(ips, "del","0","white")
+			ipsToDelist, err := s.wafformatter.WashDelIps(ctx, ips,udp)
+			if err != nil {
+				return err
+			}
+
+			// 4. 如果有需要处理的IP,则批量发布一次任务
+			if len(ipsToDelist) > 0 {
+				go s.wafformatter.PublishIpWhitelistTask(ipsToDelist, "del", "0", "white")
+			}
 		}
 
 
@@ -427,4 +453,5 @@ func (s *udpForWardingService) GetUdpForwardingWafUdpAllIps(ctx context.Context,
 	})
 
 	return res, nil
-}
+}
+

+ 44 - 0
internal/service/wafformatter.go

@@ -30,6 +30,8 @@ type WafFormatterService interface {
 	WashEditWafIp(ctx context.Context, newBackendList []string,oldBackendList []string) ([]string, []string, error)
 	//cdn添加网站
 	AddOrigin(ctx context.Context, req v1.WebJson) (int64, error)
+	// 获取ip数量等于1的源站过白ip
+	WashDelIps(ctx context.Context, ips []string,apiType string) ([]string, error)
 }
 func NewWafFormatterService(
     service *Service,
@@ -409,3 +411,45 @@ func (s *wafFormatterService) AddOrigin(ctx context.Context, req v1.WebJson) (in
 	return id, nil
 }
 
+// 获取ip数量等于1的源站过白ip
+func (s *wafFormatterService) WashDelIps(ctx context.Context, ips []string,apiType string) ([]string, error) {
+	var ipCounts []v1.IpCountResult
+	var err error
+	switch apiType {
+	case "udp":
+		ipCounts, err = s.udpForWardingRep.GetIpCountByIp(ctx, ips)
+		if err != nil {
+			return nil, err // 数据库查询失败,直接返回错误
+		}
+	case "tcp":
+		ipCounts, err = s.tcpforwardingRep.GetIpCountByIp(ctx, ips)
+		if err != nil {
+			return nil, err // 数据库查询失败,直接返回错误
+		}
+	case "web":
+		ipCounts, err = s.webForwardingRep.GetIpCountByIp(ctx, ips)
+		if err != nil {
+			return nil, err // 数据库查询失败,直接返回错误
+		}
+		return ips, nil
+	default:
+		return nil, fmt.Errorf("invalid api type: %s", apiType)
+	}
+
+
+	// 2. 将聚合结果转换为 map,方便快速查找
+	countMap := make(map[string]int, len(ipCounts))
+	for _, result := range ipCounts {
+		countMap[result.Ip] = result.Count
+	}
+
+	// 3. 筛选出需要被移除的IP
+	var ipsToDelist []string
+	for _, ip := range ips {
+		// 如果IP在map中存在且count < 2,或者IP根本不在map中(意味着count为0),则需要处理
+		if count, ok := countMap[ip]; !ok || count < 2 {
+			ipsToDelist = append(ipsToDelist, ip)
+		}
+	}
+	return  ipsToDelist, nil
+}

+ 39 - 11
internal/service/webforwarding.go

@@ -50,6 +50,14 @@ func NewWebForwardingService(
 	}
 }
 
+const (
+	isHttps         = 1
+	protocolHttps        = "https"
+	protocolHttp         = "http"
+	defaultNodeClusterId = 1
+	web                 = "web"
+)
+
 type webForwardingService struct {
 	*Service
 	webForwardingRepository repository.WebForwardingRepository
@@ -66,12 +74,7 @@ type webForwardingService struct {
 
 
 
-const (
-	isHttps         = 1
-	protocolHttps        = "https"
-	protocolHttp         = "http"
-	defaultNodeClusterId = 1
-)
+
 func (s *webForwardingService) require(ctx context.Context,req v1.GlobalRequire) (v1.GlobalRequire, error) {
 	var err error
 	var res v1.GlobalRequire
@@ -512,7 +515,15 @@ func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.We
 		if len(require.GatewayIps) == 0 {
 			return fmt.Errorf("网关组不存在")
 		}
-		go s.wafformatter.PublishDomainWhitelistTask(oldDomain, firstIp, "del")
+
+		// 查找域名数量,如果数量小于2,删除旧域名
+		count, err := s.webForwardingRepository.GetDomainCount(ctx, req.HostId, webData.Domain)
+		if err != nil {
+			return err
+		}
+		if count < 2 {
+			go s.wafformatter.PublishDomainWhitelistTask(oldDomain, firstIp, "del")
+		}
 		go s.wafformatter.PublishDomainWhitelistTask(doMain, firstIp, "add")
 	}
 
@@ -542,11 +553,20 @@ func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.We
 	if len(addedIps) > 0 {
 		go s.wafformatter.PublishIpWhitelistTask(addedIps, "add","","white")
 	}
-	if len(removedIps) > 0 {
-		go s.wafformatter.PublishIpWhitelistTask(removedIps, "del","0","white")
-	}
 
+	// IP过白
+	if len(removedIps) > 0 {
+		// 1. 一次性获取所有相关IP的数量
+		ipsToDelist, err := s.wafformatter.WashDelIps(ctx, removedIps,web)
+		if err != nil {
+			return err
+		}
 
+		// 4. 如果有需要处理的IP,则批量发布一次任务
+		if len(ipsToDelist) > 0 {
+			go s.wafformatter.PublishIpWhitelistTask(ipsToDelist, "del", "0", "white")
+		}
+	}
 
 
 
@@ -650,7 +670,15 @@ func (s *webForwardingService) DeleteWebForwarding(ctx context.Context, Ids []in
 			}
 		}
 		if len(ips) > 0 {
-			go s.wafformatter.PublishIpWhitelistTask(ips, "del","0","white")
+			ipsToDelist, err := s.wafformatter.WashDelIps(ctx, ips,web)
+			if err != nil {
+				return err
+			}
+
+			// 4. 如果有需要处理的IP,则批量发布一次任务
+			if len(ipsToDelist) > 0 {
+				go s.wafformatter.PublishIpWhitelistTask(ipsToDelist, "del", "0", "white")
+			}
 		}