|
@@ -11,6 +11,7 @@ import (
|
|
|
amqp "github.com/rabbitmq/amqp091-go"
|
|
|
"go.uber.org/zap"
|
|
|
"golang.org/x/net/publicsuffix"
|
|
|
+ "golang.org/x/sync/errgroup"
|
|
|
"net"
|
|
|
"slices"
|
|
|
"strconv"
|
|
@@ -31,7 +32,7 @@ type WafFormatterService interface {
|
|
|
//cdn添加网站
|
|
|
AddOrigin(ctx context.Context, req v1.WebJson) (int64, error)
|
|
|
// 获取ip数量等于1的源站过白ip
|
|
|
- WashDelIps(ctx context.Context, ips []string,apiType string) ([]string, error)
|
|
|
+ WashDelIps(ctx context.Context, ips []string) ([]string, error)
|
|
|
}
|
|
|
func NewWafFormatterService(
|
|
|
service *Service,
|
|
@@ -412,44 +413,61 @@ func (s *wafFormatterService) AddOrigin(ctx context.Context, req v1.WebJson) (in
|
|
|
}
|
|
|
|
|
|
// 获取ip数量等于1的源站过白ip
|
|
|
-func (s *wafFormatterService) WashDelIps(ctx context.Context, ips []string,apiType string) ([]string, error) {
|
|
|
- var ipCounts []v1.IpCountResult
|
|
|
- var err error
|
|
|
- switch apiType {
|
|
|
- case "udp":
|
|
|
- ipCounts, err = s.udpForWardingRep.GetIpCountByIp(ctx, ips)
|
|
|
+func (s *wafFormatterService) WashDelIps(ctx context.Context, ips []string) ([]string, error) {
|
|
|
+ var udpIpCounts, tcpIpCounts, webIpCounts []v1.IpCountResult
|
|
|
+ g, gCtx := errgroup.WithContext(ctx)
|
|
|
+ // 1. 查询 IP 的数量
|
|
|
+ g.Go(func() error {
|
|
|
+ var err error
|
|
|
+ udpIpCounts, err = s.udpForWardingRep.GetIpCountByIp(gCtx, ips)
|
|
|
if err != nil {
|
|
|
- return nil, err // 数据库查询失败,直接返回错误
|
|
|
+ return fmt.Errorf("in udp repository: %w", err)
|
|
|
}
|
|
|
- case "tcp":
|
|
|
- ipCounts, err = s.tcpforwardingRep.GetIpCountByIp(ctx, ips)
|
|
|
+ return nil
|
|
|
+ })
|
|
|
+
|
|
|
+ g.Go(func() error {
|
|
|
+ var err error
|
|
|
+ tcpIpCounts, err = s.tcpforwardingRep.GetIpCountByIp(gCtx, ips)
|
|
|
if err != nil {
|
|
|
- return nil, err // 数据库查询失败,直接返回错误
|
|
|
+ return fmt.Errorf("in tcp repository: %w", err)
|
|
|
}
|
|
|
- case "web":
|
|
|
- ipCounts, err = s.webForwardingRep.GetIpCountByIp(ctx, ips)
|
|
|
+ return nil
|
|
|
+ })
|
|
|
+
|
|
|
+ g.Go(func() error {
|
|
|
+ var err error
|
|
|
+ webIpCounts, err = s.webForwardingRep.GetIpCountByIp(gCtx, ips)
|
|
|
if err != nil {
|
|
|
- return nil, err // 数据库查询失败,直接返回错误
|
|
|
+ return fmt.Errorf("in web repository: %w", err)
|
|
|
}
|
|
|
- return ips, nil
|
|
|
- default:
|
|
|
- return nil, fmt.Errorf("invalid api type: %s", apiType)
|
|
|
+ return nil
|
|
|
+ })
|
|
|
+
|
|
|
+ if err := g.Wait(); err != nil {
|
|
|
+ return nil, err
|
|
|
}
|
|
|
|
|
|
|
|
|
- // 2. 将聚合结果转换为 map,方便快速查找
|
|
|
- countMap := make(map[string]int, len(ipCounts))
|
|
|
- for _, result := range ipCounts {
|
|
|
- countMap[result.Ip] = result.Count
|
|
|
+ // 2. 汇总所有计数结果
|
|
|
+ totalCountMap := make(map[string]int)
|
|
|
+ // 将多个 for 循环合并到一个函数中,可以显得更整洁(可选)
|
|
|
+ accumulateCounts := func(counts []v1.IpCountResult) {
|
|
|
+ for _, result := range counts {
|
|
|
+ totalCountMap[result.Ip] += result.Count
|
|
|
+ }
|
|
|
}
|
|
|
+ accumulateCounts(udpIpCounts)
|
|
|
+ accumulateCounts(tcpIpCounts)
|
|
|
+ accumulateCounts(webIpCounts)
|
|
|
|
|
|
- // 3. 筛选出需要被移除的IP
|
|
|
+ // 3. 筛选出总引用数小于 2 的 IP
|
|
|
var ipsToDelist []string
|
|
|
for _, ip := range ips {
|
|
|
- // 如果IP在map中存在且count < 2,或者IP根本不在map中(意味着count为0),则需要处理
|
|
|
- if count, ok := countMap[ip]; !ok || count < 2 {
|
|
|
+ if totalCountMap[ip] < 2 {
|
|
|
ipsToDelist = append(ipsToDelist, ip)
|
|
|
}
|
|
|
}
|
|
|
- return ipsToDelist, nil
|
|
|
+
|
|
|
+ return ipsToDelist, nil
|
|
|
}
|