Browse Source

refactor(waf): 优化 IP 过滤性能并移除未使用的方法

- 移除了未使用的 ValidateWebForwardingRequest 方法- 优化了 WashDifferentIp 方法,添加了并发 IP 过滤功能
- 新增了 filterValidIpsConcurrently 和 filterValidIpsSequentially 方法用于 IP 过滤
fusu 1 day ago
parent
commit
62d1b5d9e2
1 changed files with 62 additions and 40 deletions
  1. 62 40
      internal/service/api/waf/aidedweb.go

+ 62 - 40
internal/service/api/waf/aidedweb.go

@@ -74,7 +74,6 @@ type AidedWebService interface {
 	
 	// 废弃的方法(保持向后兼容)
 	Require(ctx context.Context, req v1.GlobalRequire) (v1.GlobalRequire, error)
-	ValidateWebForwardingRequest(ctx context.Context, req *v1.WebForwardingRequest, require RequireResponse) error
 	CreateOriginServers(ctx context.Context, req *v1.WebForwardingRequest) (map[string]int64, error)
 }
 
@@ -368,27 +367,75 @@ func (s *aidedWebService) FindDifferenceList(oldList, newList []v1.BackendList)
 	return added, removed
 }
 
-// WashDifferentIp 清洗IP差异
+// WashDifferentIp 清洗IP差异 - 并发版本
 func (s *aidedWebService) WashDifferentIp(newIpList []string, oldIpList []string) (addedDenyIps []string, removedDenyIps []string) {
-	var newAllowIps []string
-	var oldAllowIps []string
+	// 并发验证并过滤有效IP
+	oldAllowIps := s.filterValidIpsConcurrently(oldIpList)
+	newAllowIps := s.filterValidIpsConcurrently(newIpList)
 
-	// 获取旧IP列表
-	for _, v := range oldIpList {
-		if net.ParseIP(v) != nil {
-			oldAllowIps = append(oldAllowIps, v)
-		}
+	addedDenyIps, removedDenyIps = s.wafformatter.findIpDifferences(oldAllowIps, newAllowIps)
+	return addedDenyIps, removedDenyIps
+}
+
+// filterValidIpsConcurrently 并发过滤有效IP地址
+func (s *aidedWebService) filterValidIpsConcurrently(ipList []string) []string {
+	if len(ipList) == 0 {
+		return nil
 	}
 
-	// 获取新IP列表
-	for _, v := range newIpList {
-		if net.ParseIP(v) != nil {
-			newAllowIps = append(newAllowIps, v)
+	// 小于10个IP时不使用并发,避免overhead
+	if len(ipList) < 10 {
+		return s.filterValidIpsSequentially(ipList)
+	}
+
+	type ipResult struct {
+		ip    string
+		valid bool
+		index int
+	}
+
+	resultChan := make(chan ipResult, len(ipList))
+	semaphore := make(chan struct{}, 20) // 限制并发数为20
+
+	// 启动goroutine验证IP
+	for i, ip := range ipList {
+		go func(ip string, index int) {
+			semaphore <- struct{}{} // 获取信号量
+			defer func() { <-semaphore }() // 释放信号量
+
+			valid := net.ParseIP(ip) != nil
+			resultChan <- ipResult{ip: ip, valid: valid, index: index}
+		}(ip, i)
+	}
+
+	// 收集结果并保持原始顺序
+	results := make([]ipResult, len(ipList))
+	for i := 0; i < len(ipList); i++ {
+		result := <-resultChan
+		results[result.index] = result
+	}
+	close(resultChan)
+
+	// 按原始顺序提取有效IP
+	var validIps []string
+	for _, result := range results {
+		if result.valid {
+			validIps = append(validIps, result.ip)
 		}
 	}
 
-	addedDenyIps, removedDenyIps = s.wafformatter.findIpDifferences(oldAllowIps, newAllowIps)
-	return addedDenyIps, removedDenyIps
+	return validIps
+}
+
+// filterValidIpsSequentially 顺序过滤有效IP地址(用于小数据集)
+func (s *aidedWebService) filterValidIpsSequentially(ipList []string) []string {
+	var validIps []string
+	for _, ip := range ipList {
+		if net.ParseIP(ip) != nil {
+			validIps = append(validIps, ip)
+		}
+	}
+	return validIps
 }
 
 // EditLog 修改日志配置
@@ -448,31 +495,6 @@ func (s *aidedWebService) BulidFormData(ctx context.Context, formData v1.Website
 	return formDataSend, nil
 }
 
-// ValidateWebForwardingRequest 验证Web转发请求
-func (s *aidedWebService) ValidateWebForwardingRequest(ctx context.Context, req *v1.WebForwardingRequest, require RequireResponse) error {
-	// 验证域名限制
-	if err := s.wafformatter.validateWafDomainCount(ctx, v1.GlobalRequire{
-		HostId:  req.HostId,
-		Domain:  req.WebForwardingData.Domain,
-		Comment: req.WebForwardingData.Comment,
-		Uid:     req.Uid,
-	}); err != nil {
-		return fmt.Errorf("域名数量验证失败: %w", err)
-	}
-
-	// 验证端口数量限制
-	if err := s.wafformatter.validateWafPortCount(ctx, require.HostId); err != nil {
-		return fmt.Errorf("端口数量验证失败: %w", err)
-	}
-
-	// 验证端口重复
-	protocol := s.GetProtocolType(req.WebForwardingData.IsHttps)
-	if err := s.wafformatter.VerifyPort(ctx, protocol, int64(req.WebForwardingData.Id), req.WebForwardingData.Port, int64(require.HostId), req.WebForwardingData.Domain); err != nil {
-		return fmt.Errorf("端口 %d 验证失败: %w", req.WebForwardingData.Port, err)
-	}
-
-	return nil
-}
 
 // ProcessSSLCertificate 处理SSL证书
 func (s *aidedWebService) ProcessSSLCertificate(ctx context.Context, req *v1.WebForwardingRequest, cdnUid int) error {