瀏覽代碼

refactor(service): 重构 WashDelIps 方法并优化相关逻辑- 移除了 tcpforwarding.go、udpforwarding.go 和 webforwarding.go 中的冗余常量定义
- 修改了 WashDelIps 方法签名,移除了不必要的 apiType 参数
- 采用 errgroup 实现并发查询不同协议的 IP 计数,提高效率
- 统一处理所有协议的 IP 计数逻辑,简化代码结构

fusu 1 月之前
父節點
當前提交
84af8b7985

+ 2 - 6
internal/service/tcpforwarding.go

@@ -45,10 +45,6 @@ func NewTcpforwardingService(
 	}
 }
 
-const (
-	tcp = "tcp"
-)
-
 type tcpforwardingService struct {
 	*Service
 	tcpforwardingRepository repository.TcpforwardingRepository
@@ -276,7 +272,7 @@ func (s *tcpforwardingService) EditTcpForwarding(ctx context.Context, req *v1.Tc
 
 
 	if len(removedIps) > 0 {
-		ipsToDelist, err := s.wafformatter.WashDelIps(ctx, removedIps,tcp)
+		ipsToDelist, err := s.wafformatter.WashDelIps(ctx, removedIps)
 		if err != nil {
 			return err
 		}
@@ -361,7 +357,7 @@ func (s *tcpforwardingService) DeleteTcpForwarding(ctx context.Context, req v1.D
 			return err
 		}
 		if len(ips) > 0 {
-			ipsToDelist, err := s.wafformatter.WashDelIps(ctx, ips,tcp)
+			ipsToDelist, err := s.wafformatter.WashDelIps(ctx, ips)
 			if err != nil {
 				return err
 			}

+ 2 - 5
internal/service/udpforwarding.go

@@ -45,9 +45,6 @@ func NewUdpForWardingService(
 	}
 }
 
-const (
-	udp = "udp"
-)
 
 type udpForWardingService struct {
 	*Service
@@ -274,7 +271,7 @@ func (s *udpForWardingService) EditUdpForwarding(ctx context.Context, req *v1.Ud
 
 
 	if len(removedIps) > 0 {
-		ipsToDelist, err := s.wafformatter.WashDelIps(ctx, removedIps,udp)
+		ipsToDelist, err := s.wafformatter.WashDelIps(ctx, removedIps)
 		if err != nil {
 			return err
 		}
@@ -360,7 +357,7 @@ func (s *udpForWardingService) DeleteUdpForwarding(ctx context.Context, Ids []in
 
 
 		if len(ips) > 0 {
-			ipsToDelist, err := s.wafformatter.WashDelIps(ctx, ips,udp)
+			ipsToDelist, err := s.wafformatter.WashDelIps(ctx, ips)
 			if err != nil {
 				return err
 			}

+ 43 - 25
internal/service/wafformatter.go

@@ -11,6 +11,7 @@ import (
 	amqp "github.com/rabbitmq/amqp091-go"
 	"go.uber.org/zap"
 	"golang.org/x/net/publicsuffix"
+	"golang.org/x/sync/errgroup"
 	"net"
 	"slices"
 	"strconv"
@@ -31,7 +32,7 @@ type WafFormatterService interface {
 	//cdn添加网站
 	AddOrigin(ctx context.Context, req v1.WebJson) (int64, error)
 	// 获取ip数量等于1的源站过白ip
-	WashDelIps(ctx context.Context, ips []string,apiType string) ([]string, error)
+	WashDelIps(ctx context.Context, ips []string) ([]string, error)
 }
 func NewWafFormatterService(
     service *Service,
@@ -412,44 +413,61 @@ func (s *wafFormatterService) AddOrigin(ctx context.Context, req v1.WebJson) (in
 }
 
 // 获取ip数量等于1的源站过白ip
-func (s *wafFormatterService) WashDelIps(ctx context.Context, ips []string,apiType string) ([]string, error) {
-	var ipCounts []v1.IpCountResult
-	var err error
-	switch apiType {
-	case "udp":
-		ipCounts, err = s.udpForWardingRep.GetIpCountByIp(ctx, ips)
+func (s *wafFormatterService) WashDelIps(ctx context.Context, ips []string) ([]string, error) {
+	var udpIpCounts, tcpIpCounts, webIpCounts []v1.IpCountResult
+	g, gCtx := errgroup.WithContext(ctx)
+	// 1. 查询 IP 的数量
+	g.Go(func() error {
+		var err error
+		udpIpCounts, err = s.udpForWardingRep.GetIpCountByIp(gCtx, ips)
 		if err != nil {
-			return nil, err // 数据库查询失败,直接返回错误
+			return fmt.Errorf("in udp repository: %w", err)
 		}
-	case "tcp":
-		ipCounts, err = s.tcpforwardingRep.GetIpCountByIp(ctx, ips)
+		return nil
+	})
+
+	g.Go(func() error {
+		var err error
+		tcpIpCounts, err = s.tcpforwardingRep.GetIpCountByIp(gCtx, ips)
 		if err != nil {
-			return nil, err // 数据库查询失败,直接返回错误
+			return fmt.Errorf("in tcp repository: %w", err)
 		}
-	case "web":
-		ipCounts, err = s.webForwardingRep.GetIpCountByIp(ctx, ips)
+		return nil
+	})
+
+	g.Go(func() error {
+		var err error
+		webIpCounts, err = s.webForwardingRep.GetIpCountByIp(gCtx, ips)
 		if err != nil {
-			return nil, err // 数据库查询失败,直接返回错误
+			return fmt.Errorf("in web repository: %w", err)
 		}
-		return ips, nil
-	default:
-		return nil, fmt.Errorf("invalid api type: %s", apiType)
+		return nil
+	})
+
+	if err := g.Wait(); err != nil {
+		return nil, err
 	}
 
 
-	// 2. 将聚合结果转换为 map,方便快速查找
-	countMap := make(map[string]int, len(ipCounts))
-	for _, result := range ipCounts {
-		countMap[result.Ip] = result.Count
+	// 2. 汇总所有计数结果
+	totalCountMap := make(map[string]int)
+	// 将多个 for 循环合并到一个函数中,可以显得更整洁(可选)
+	accumulateCounts := func(counts []v1.IpCountResult) {
+		for _, result := range counts {
+			totalCountMap[result.Ip] += result.Count
+		}
 	}
+	accumulateCounts(udpIpCounts)
+	accumulateCounts(tcpIpCounts)
+	accumulateCounts(webIpCounts)
 
-	// 3. 筛选出需要被移除的IP
+	// 3. 筛选出总引用数小于 2 的 IP
 	var ipsToDelist []string
 	for _, ip := range ips {
-		// 如果IP在map中存在且count < 2,或者IP根本不在map中(意味着count为0),则需要处理
-		if count, ok := countMap[ip]; !ok || count < 2 {
+		if totalCountMap[ip] < 2 {
 			ipsToDelist = append(ipsToDelist, ip)
 		}
 	}
-	return  ipsToDelist, nil
+
+	return ipsToDelist, nil
 }

+ 2 - 3
internal/service/webforwarding.go

@@ -55,7 +55,6 @@ const (
 	protocolHttps        = "https"
 	protocolHttp         = "http"
 	defaultNodeClusterId = 1
-	web                 = "web"
 )
 
 type webForwardingService struct {
@@ -557,7 +556,7 @@ func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.We
 	// IP过白
 	if len(removedIps) > 0 {
 		// 1. 一次性获取所有相关IP的数量
-		ipsToDelist, err := s.wafformatter.WashDelIps(ctx, removedIps,web)
+		ipsToDelist, err := s.wafformatter.WashDelIps(ctx, removedIps)
 		if err != nil {
 			return err
 		}
@@ -670,7 +669,7 @@ func (s *webForwardingService) DeleteWebForwarding(ctx context.Context, Ids []in
 			}
 		}
 		if len(ips) > 0 {
-			ipsToDelist, err := s.wafformatter.WashDelIps(ctx, ips,web)
+			ipsToDelist, err := s.wafformatter.WashDelIps(ctx, ips)
 			if err != nil {
 				return err
 			}