Browse Source

feat(aodun): 同时调用大网和小网的 API

- 新增了同时调用大网和小网 API 的逻辑
- 重构了 AoDunService 接口和实现类,增加了 isSmall 参数来区分大网和小网
- 优化了错误处理和日志记录- 使用 sync.WaitGroup 和 channel来同步并发请求
fusu 1 month ago
parent
commit
048b7a49c6

+ 53 - 10
internal/job/whitelist.go

@@ -10,6 +10,7 @@ import (
 	"go.uber.org/zap"
 	"strconv"
 	"strings"
+	"sync"
 )
 
 // taskHandler 定义了处理单个消息的函数签名
@@ -181,23 +182,65 @@ func (j *whitelistJob) handleIpMessage(ctx context.Context, logger *zap.Logger,
 			// 如果附加IP失败,记录错误并终止
 			processingErr = fmt.Errorf("为WAF准备IP列表失败: %w", err)
 		} else {
-			processingErr = j.aoDunService.AddWhiteStaticList(ctx, ips)
+			var wg sync.WaitGroup
+			errChan := make(chan error, 2)
+
+			wg.Add(2)
+			go func() {
+				defer wg.Done()
+				if err := j.aoDunService.AddWhiteStaticList(ctx, false, ips); err != nil {
+					errChan <- err
+				}
+			}()
+			go func() {
+				defer wg.Done()
+				if err := j.aoDunService.AddWhiteStaticList(ctx, true, ips); err != nil {
+					errChan <- err
+				}
+			}()
+
+			wg.Wait()
+			close(errChan)
+
+			var errs []string
+			for err := range errChan {
+				errs = append(errs, err.Error())
+			}
+			if len(errs) > 0 {
+				processingErr = fmt.Errorf("添加IP到白名单时发生错误: %s", strings.Join(errs, "; "))
+			}
 		}
 
 	case "del":
-		var errs []string
-		for _, ip := range payload.Ips {
-			id, err := j.aoDunService.GetWhiteStaticList(ctx, ip)
+		var wg sync.WaitGroup
+		errChan := make(chan error, len(payload.Ips)*2)
+
+		deleteFromWall := func(isSmall bool, ip string) {
+			defer wg.Done()
+			id, err := j.aoDunService.GetWhiteStaticList(ctx, isSmall, ip)
 			if err != nil {
-				logger.Error("获取IP白名单ID失败", zap.Error(err), zap.String("ip", ip))
-				errs = append(errs, fmt.Sprintf("获取IP '%s' 失败: %v", ip, err))
-				continue
+				errChan <- fmt.Errorf("获取IP '%s' (isSmall: %t) ID失败: %w", ip, isSmall, err)
+				return
 			}
-			if err := j.aoDunService.DelWhiteStaticList(ctx, strconv.Itoa(id)); err != nil {
-				logger.Error("删除IP白名单失败", zap.Error(err), zap.String("ip", ip))
-				errs = append(errs, fmt.Sprintf("删除IP '%s' 失败: %v", ip, err))
+			if err := j.aoDunService.DelWhiteStaticList(ctx, isSmall, strconv.Itoa(id)); err != nil {
+				errChan <- fmt.Errorf("删除IP '%s' (isSmall: %t, id: %d) 失败: %w", ip, isSmall, id, err)
 			}
 		}
+
+		for _, ip := range payload.Ips {
+			wg.Add(2)
+			go deleteFromWall(false, ip)
+			go deleteFromWall(true, ip)
+		}
+
+		wg.Wait()
+		close(errChan)
+
+		var errs []string
+		for err := range errChan {
+			logger.Error("删除IP白名单过程中发生错误", zap.Error(err))
+			errs = append(errs, err.Error())
+		}
 		if len(errs) > 0 {
 			processingErr = fmt.Errorf("删除IP任务中发生错误: %s", strings.Join(errs, "; "))
 		}

+ 171 - 164
internal/service/aodun.go

@@ -6,219 +6,194 @@ import (
 	"crypto/tls"
 	"encoding/json"
 	"fmt"
-	"github.com/davecgh/go-spew/spew"
 	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
 	"github.com/spf13/viper"
+	"go.uber.org/zap"
 	"io"
 	"net/http"
 	"net/url"
+	"strconv"
 	"strings"
 	"time"
 )
 
+// AoDunService 定义了与傲盾 API 交互的服务接口
 type AoDunService interface {
 	DomainWhiteList(ctx context.Context, domain string, ip string, apiType string) error
-	AddWhiteStaticList(ctx context.Context, req []v1.IpInfo) error
-	DelWhiteStaticList(ctx context.Context, id string) error
-	GetWhiteStaticList(ctx context.Context,ip string) (int,error)
+	AddWhiteStaticList(ctx context.Context, isSmall bool, req []v1.IpInfo) error
+	DelWhiteStaticList(ctx context.Context, isSmall bool, id string) error
+	GetWhiteStaticList(ctx context.Context, isSmall bool, ip string) (int, error)
 }
-func NewAoDunService(
-	service *Service,
-	conf *viper.Viper,
-) AoDunService {
-	// 1. 创建一个可复用的 Transport,并配置好 TLS 和其他参数
+
+// aoDunService 是 AoDunService 接口的实现
+type aoDunService struct {
+	*Service
+	cfg        *aoDunConfig
+	httpClient *http.Client
+}
+
+// aoDunConfig 用于整合来自 viper 的所有配置
+type aoDunConfig struct {
+	Url            string
+	ClientID       string
+	Username       string
+	Password       string
+	SmallUrl       string
+	SmallClientID  string
+	DomainUsername string
+	DomainPassword string
+}
+
+// NewAoDunService 创建一个新的 AoDunService 实例
+func NewAoDunService(service *Service, conf *viper.Viper) AoDunService {
+	cfg := &aoDunConfig{
+		Url:            conf.GetString("aodun.Url"),
+		ClientID:       conf.GetString("aodun.clientID"),
+		Username:       conf.GetString("aodun.username"),
+		Password:       conf.GetString("aodun.password"),
+		SmallUrl:       conf.GetString("aodunSmall.Url"),
+		SmallClientID:  conf.GetString("aodunSmall.clientID"),
+		DomainUsername: conf.GetString("domainWhite.username"),
+		DomainPassword: conf.GetString("domainWhite.password"),
+	}
+
 	tr := &http.Transport{
-		TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // 忽略 SSL 验证
-		MaxIdleConns:    100,                               // 最大空闲连接数
-		IdleConnTimeout: 90 * time.Second,                // 空闲连接超时时间
+		TLSClientConfig:   &tls.Config{InsecureSkipVerify: true},
+		MaxIdleConns:      100,
+		IdleConnTimeout:   90 * time.Second,
+		ForceAttemptHTTP2: true,
 	}
 
-	// 2. 基于该 Transport 创建一个可复用的 http.Client
 	client := &http.Client{
 		Transport: tr,
-		Timeout:   15 * time.Second, // 设置所有请求的默认超时时间
+		Timeout:   15 * time.Second,
 	}
 
 	return &aoDunService{
-		Service:        service,
-		Url:            conf.GetString("aodun.Url"),
-		clientID:       conf.GetString("aodun.clientID"),
-		username:       conf.GetString("aodun.username"),
-		password:       conf.GetString("aodun.password"),
-		IPusername:     conf.GetString("aodunIp.username"),
-		IPpassword:     conf.GetString("aodunIp.password"),
-		domainUserName: conf.GetString("domainWhite.username"),
-		domainPassword: conf.GetString("domainWhite.password"),
-		httpClient:     client, // 存储共享的 client
+		Service:    service,
+		cfg:        cfg,
+		httpClient: client,
 	}
 }
 
-type aoDunService struct {
-	*Service
-	Url            string
-	clientID       string
-	username       string
-	password       string
-	IPusername     string
-	IPpassword     string
-	domainUserName string
-	domainPassword string
-	httpClient     *http.Client // <--- 新增 http client 字段
+// getApiUrl 根据 isSmall 标志返回正确的 API 基础 URL
+func (s *aoDunService) getApiUrl(isSmall bool) string {
+	if isSmall {
+		return s.cfg.SmallUrl
+	}
+	return s.cfg.Url
 }
 
+// getClientID 根据 isSmall 标志返回正确的 ClientID
+func (s *aoDunService) getClientID(isSmall bool) string {
+	if isSmall {
+		return s.cfg.SmallClientID
+	}
+	return s.cfg.ClientID
+}
 
-func (s *aoDunService) sendFormData(ctx context.Context, apiUrl string, tokenType string, token string, formData map[string]interface{}) ([]byte, error) {
-	URL := s.Url + apiUrl
-	jsonData, err := json.Marshal(formData)
+// executeRequest 封装了发送 HTTP POST 请求、读取响应和 JSON 解码的通用逻辑
+func (s *aoDunService) executeRequest(ctx context.Context, url, tokenType, token string, requestBody, responsePayload interface{}, isSmall bool) error {
+	jsonData, err := json.Marshal(requestBody)
 	if err != nil {
-		return nil, fmt.Errorf("序列化请求数据失败: %w", err)
+		return fmt.Errorf("序列化请求数据失败 (isSmall: %t): %w", isSmall, err)
 	}
 
-	// 使用带有 context 的请求,以便上游可以控制请求的取消
-	req, err := http.NewRequestWithContext(ctx, "POST", URL, bytes.NewBuffer(jsonData))
+	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
 	if err != nil {
-		return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err)
+		return fmt.Errorf("创建 HTTP 请求失败 (isSmall: %t): %w", isSmall, err)
 	}
 
-	// 设置请求头
 	req.Header.Set("Content-Type", "application/json")
-	// 修正逻辑:当 token 不为空时才设置 Authorization
 	if token != "" {
 		req.Header.Set("Authorization", tokenType+" "+token)
 	}
 
-	// 使用结构体中共享的 httpClient 实例发送请求
 	resp, err := s.httpClient.Do(req)
 	if err != nil {
-		return nil, fmt.Errorf("发送 HTTP 请求失败: %w", err)
+		return fmt.Errorf("发送 HTTP 请求失败 (isSmall: %t): %w", isSmall, err)
 	}
 	defer resp.Body.Close()
 
-	// 6. 读取响应体内容
 	body, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return nil, fmt.Errorf("读取响应体失败: %w", err)
-	}
-	return body, nil
-}
-
-
-func (s *aoDunService) sendDomainFormData(ctx context.Context, domain string, ip string, apiType string) ([]byte, error) {
-	var URL string
-	if apiType == "add" {
-		URL = "http://zapi.zzybgp.com/api/user/do_main"
-	} else {
-		URL = "http://zapi.zzybgp.com/api/user/do_main/delete"
+		return fmt.Errorf("读取响应体失败 (isSmall: %t): %w", isSmall, err)
 	}
-	formData := url.Values{}
-	formData.Set("username", s.domainUserName)
-	formData.Set("password", s.domainPassword)
-	formData.Add("do_main_list[name][]", domain)
-	formData.Add("do_main_list[ip]", ip)
-	encodedData := formData.Encode()
 
-	// 使用带有 context 的请求
-	req, err := http.NewRequestWithContext(ctx, "POST", URL, bytes.NewBuffer([]byte(encodedData)))
-	if err != nil {
-		return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err)
+	if resp.StatusCode != http.StatusOK {
+		return fmt.Errorf("HTTP 错误 (isSmall: %t): 状态码 %d, 响应: %s", isSmall, resp.StatusCode, string(body))
 	}
 
-	// 设置请求头
-	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
-
-	// 使用共享的 httpClient 实例发送请求
-	resp, err := s.httpClient.Do(req)
-	if err != nil {
-		return nil, fmt.Errorf("发送 HTTP 请求失败: %w", err)
+	if err := json.Unmarshal(body, responsePayload); err != nil {
+		return fmt.Errorf("反序列化响应 JSON 失败 (isSmall: %t, 内容: %s): %w", isSmall, string(body), err)
 	}
-	defer resp.Body.Close()
 
-	// 6. 读取响应体内容
-	body, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return nil, fmt.Errorf("读取响应体失败: %w", err)
-	}
-	return body, nil
+	return nil
 }
 
-
-// sendAuthenticatedRequest 封装了需要认证的API请求的通用流程:获取token -> 发送请求。
-func (s *aoDunService) sendAuthenticatedRequest(ctx context.Context, apiPath string, formData map[string]interface{}) ([]byte, error) {
-	tokenType, token, err := s.GetToken(ctx)
+// sendAuthenticatedRequest 封装了需要认证的 API 请求的通用流程
+func (s *aoDunService) sendAuthenticatedRequest(ctx context.Context, isSmall bool, apiPath string, requestBody, responsePayload interface{}) error {
+	tokenType, token, err := s.GetToken(ctx, isSmall)
 	if err != nil {
-		// 如果获取token失败,直接返回错误
-		return nil, err
+		return err
 	}
 
-	// 使用获取到的token发送请求
-	return s.sendFormData(ctx, apiPath, tokenType, token, formData)
+	apiURL := s.getApiUrl(isSmall) + apiPath
+	return s.executeRequest(ctx, apiURL, tokenType, token, requestBody, responsePayload, isSmall)
 }
 
-func (s *aoDunService) GetToken(ctx context.Context) (string, string, error) {
-
+// GetToken 获取认证令牌
+func (s *aoDunService) GetToken(ctx context.Context, isSmall bool) (string, string, error) {
 	formData := map[string]interface{}{
-		"ClientID":  s.clientID,
+		"ClientID":  s.getClientID(isSmall),
 		"GrantType": "password",
-		"Username":  s.IPusername,
-		"Password":  s.IPpassword,
+		"Username":  s.cfg.Username,
+		"Password":  s.cfg.Password,
 	}
 
-	resBody, err := s.sendFormData(ctx,"/oauth/token","","",formData)
-	if err != nil {
+	apiURL := s.getApiUrl(isSmall) + "/oauth/token"
+	var res v1.GetTokenRespone
+	if err := s.executeRequest(ctx, apiURL, "", "", formData, &res, isSmall); err != nil {
 		return "", "", err
 	}
-	// 7. 将响应体 JSON 数据反序列化到 ResponsePayload 结构体
-	var responsePayload v1.GetTokenRespone
-	if err := json.Unmarshal(resBody, &responsePayload); err != nil {
-		// 如果反序列化失败,可能是响应格式不符合预期
-		return "", "", fmt.Errorf("反序列化响应 JSON 失败 ( 内容: %s): %w", string(resBody), err)
-	}
 
-	// 8. 检查 API 返回的操作结果代码
-	if responsePayload.Code != 0 {
-		return "", "", fmt.Errorf("API 错误: code %d, msg '%s', remote_ip '%s'",
-			responsePayload.Code, responsePayload.Msg, responsePayload.RemoteIP)
+	if res.Code != 0 {
+		return "", "", fmt.Errorf("API 错误 (isSmall: %t): code %d, msg '%s'", isSmall, res.Code, res.Msg)
 	}
-
-	// 9. 成功:返回 access_token
-	if responsePayload.AccessToken == "" {
-		// 理论上 code 为 0 时应该有 access_token,这是一个额外的健壮性检查
-		return "", "", fmt.Errorf("API 成功 (code 0) 但 access_token 为空")
+	if res.AccessToken == "" {
+		return "", "", fmt.Errorf("API 成功 (isSmall: %t, code 0) 但 access_token 为空", isSmall)
 	}
-	return responsePayload.TokenType,responsePayload.AccessToken, nil
+
+	return res.TokenType, res.AccessToken, nil
 }
 
-func (s *aoDunService) AddWhiteStaticList(ctx context.Context, req []v1.IpInfo) error {
+// AddWhiteStaticList 添加 IP 到静态白名单
+func (s *aoDunService) AddWhiteStaticList(ctx context.Context, isSmall bool, req []v1.IpInfo) error {
 	formData := map[string]interface{}{
 		"action":         "add",
 		"bwflag":         "white",
 		"insert_bw_list": req,
 	}
 
-	// 使用封装好的方法发送认证请求
-	resBody, err := s.sendAuthenticatedRequest(ctx, "/v1.0/firewall/static_bw_list", formData)
+	var res v1.IpResponse
+	err := s.sendAuthenticatedRequest(ctx, isSmall, "/v1.0/firewall/static_bw_list", formData, &res)
 	if err != nil {
 		return err
 	}
-	// 7. 将响应体 JSON 数据反序列化到 ResponsePayload 结构体
-	var res v1.IpResponse
-	if err := json.Unmarshal(resBody, &res); err != nil {
-		// 如果反序列化失败,可能是响应格式不符合预期
-		return  fmt.Errorf("反序列化响应 JSON 失败 ( 内容: %s): %w", string(resBody), err)
-	}
+
 	if res.Code != 0 {
-		if strings.Contains(res.Msg,"操作部分成功,重复IP如下") {
-			s.logger.Info(res.Msg)
+		if strings.Contains(res.Msg, "操作部分成功,重复IP如下") {
+			s.logger.Info(res.Msg, zap.String("isSmall", strconv.FormatBool(isSmall)))
 			return nil
 		}
-		return  fmt.Errorf("API 错误: code %d, msg '%s'",
-			res.Code, res.Msg)
+		return fmt.Errorf("API 错误 (isSmall: %t): code %d, msg '%s'", isSmall, res.Code, res.Msg)
 	}
 
 	return nil
 }
 
-func (s *aoDunService) GetWhiteStaticList(ctx context.Context, ip string) (int, error) {
+// GetWhiteStaticList 查询白名单 IP 并返回其 ID
+func (s *aoDunService) GetWhiteStaticList(ctx context.Context, isSmall bool, ip string) (int, error) {
 	formData := map[string]interface{}{
 		"action": "get",
 		"bwflag": "white",
@@ -226,38 +201,24 @@ func (s *aoDunService) GetWhiteStaticList(ctx context.Context, ip string) (int,
 		"ids":    ip,
 	}
 
-	// 使用封装好的方法发送认证请求
-	resBody, err := s.sendAuthenticatedRequest(ctx, "/v1.0/firewall/static_bw_list", formData)
+	var res v1.IpGetResponse
+	err := s.sendAuthenticatedRequest(ctx, isSmall, "/v1.0/firewall/static_bw_list", formData, &res)
 	if err != nil {
 		return 0, err
 	}
-	// 7. 将响应体 JSON 数据反序列化到 ResponsePayload 结构体
-	var res v1.IpGetResponse // 使用我们定义的 IpResponse 结构体
-	if err := json.Unmarshal(resBody, &res); err != nil {
-		// 如果反序列化失败,说明响应格式不符合预期
-		return 0, fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
-	}
 
-	// 2. 检查 API 返回的 code,这是处理业务失败的关键
 	if res.Code != 0 {
-		// API 返回了错误码,例如 IP 不存在、参数错误等
-		return 0, fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Msg)
+		return 0, fmt.Errorf("API 错误 (isSmall: %t): code %d, msg '%s'", isSmall, res.Code, res.Msg)
 	}
-
-	// 3. 检查 data 数组是否为空
-	// 即使 code 为 0,也可能因为没有匹配的数据而返回一个空数组
 	if len(res.Data) == 0 {
-		return 0, fmt.Errorf("API 调用成功,但未找到 IP '%s' 相关的记录", ip)
+		return 0, fmt.Errorf("未找到 IP '%s' 相关的白名单记录 (isSmall: %t)", ip, isSmall)
 	}
 
-	// 4. 获取 ID 并返回
-	// 假设我们总是取返回结果中的第一个元素的 ID
-	id := res.Data[0].ID
-	spew.Dump(id)
-	return id, nil // 成功!返回获取到的 id 和 nil 错误
+	return res.Data[0].ID, nil
 }
 
-func (s *aoDunService) DelWhiteStaticList(ctx context.Context, id string) error {
+// DelWhiteStaticList 根据 ID 从白名单中删除 IP
+func (s *aoDunService) DelWhiteStaticList(ctx context.Context, isSmall bool, id string) error {
 	formData := map[string]interface{}{
 		"action": "del",
 		"bwflag": "white",
@@ -265,36 +226,82 @@ func (s *aoDunService) DelWhiteStaticList(ctx context.Context, id string) error
 		"ids":    id,
 	}
 
-	// 使用封装好的方法发送认证请求
-	resBody, err := s.sendAuthenticatedRequest(ctx, "/v1.0/firewall/static_bw_list", formData)
+	var res v1.IpResponse
+	err := s.sendAuthenticatedRequest(ctx, isSmall, "/v1.0/firewall/static_bw_list", formData, &res)
 	if err != nil {
 		return err
 	}
-	var res v1.IpResponse
-	if err := json.Unmarshal(resBody, &res); err != nil {
-		return fmt.Errorf("反序列化响应 JSON 失败 ( 内容: %s): %w", string(resBody), err)
-	}
+
 	if res.Code != 0 {
-		return fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Msg)
+		return fmt.Errorf("API 错误 (isSmall: %t): code %d, msg '%s'", isSmall, res.Code, res.Msg)
 	}
 	return nil
 }
 
-func (s *aoDunService) DomainWhiteList(ctx context.Context, domain string, ip string, apiType string) error {
-	resBody, err := s.sendDomainFormData(ctx,domain,ip,apiType)
+// sendDomainFormData 处理域名白名单的 application/x-www-form-urlencoded 请求
+func (s *aoDunService) sendDomainFormData(ctx context.Context, domain, ip, apiType string) ([]byte, error) {
+	var apiURL string
+	switch apiType {
+	case "add":
+		apiURL = "http://zapi.zzybgp.com/api/user/do_main"
+	case "del":
+		apiURL = "http://zapi.zzybgp.com/api/user/do_main/delete"
+	default:
+		return nil, fmt.Errorf("无效的 apiType: %s", apiType)
+	}
+
+	formData := url.Values{}
+	formData.Set("username", s.cfg.DomainUsername)
+	formData.Set("password", s.cfg.DomainPassword)
+	formData.Add("do_main_list[name][]", domain)
+	formData.Add("do_main_list[ip]", ip)
+
+	req, err := http.NewRequestWithContext(ctx, "POST", apiURL, strings.NewReader(formData.Encode()))
+	if err != nil {
+		return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err)
+	}
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+	resp, err := s.httpClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("发送 HTTP 请求失败: %w", err)
+	}
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("读取响应体失败: %w", err)
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("HTTP 错误: 状态码 %d, 响应: %s", resp.StatusCode, string(body))
+	}
+
+	return body, nil
+}
+
+// DomainWhiteList 添加或删除域名白名单
+func (s *aoDunService) DomainWhiteList(ctx context.Context, domain, ip, apiType string) error {
+	resBody, err := s.sendDomainFormData(ctx, domain, ip, apiType)
 	if err != nil {
 		return err
 	}
+
 	var res v1.DomainResponse
 	if err := json.Unmarshal(resBody, &res); err != nil {
-		return fmt.Errorf("反序列化响应 JSON 失败 ( 内容: %s): %w", string(resBody), err)
+		return fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
 	}
 
-	if res.Code != 200 && apiType == "add" {
-		return fmt.Errorf("API 错误: code %d, msg '%s', data '%s", res.Code, res.Msg, res.Info)
-	}
-	if res.Code != 600 && apiType == "del" {
-		return fmt.Errorf("API 错误: code %d, msg '%s', data '%s", res.Code, res.Msg, res.Info)
+	switch apiType {
+	case "add":
+		if res.Code != 200 {
+			return fmt.Errorf("API 错误: code %d, msg '%s', info '%s'", res.Code, res.Msg, res.Info)
+		}
+	case "del":
+		if res.Code != 600 {
+			return fmt.Errorf("API 错误: code %d, msg '%s', info '%s'", res.Code, res.Msg, res.Info)
+		}
 	}
+
 	return nil
 }

+ 6 - 33
internal/service/tcpforwarding.go

@@ -266,34 +266,14 @@ func (s *tcpforwardingService) EditTcpForwarding(ctx context.Context, req *v1.Tc
 	}
 
 	// 异步任务:将IP添加到白名单
-	var oldIps []string
-	var newIps []string
 	ipData, err := s.tcpforwardingRepository.GetTcpForwardingIpsByID(ctx, req.TcpForwardingData.Id)
 	if err != nil {
 		return err
 	}
-	for _, v := range ipData.BackendList {
-		ip, _, err := net.SplitHostPort(v)
-		if err != nil {
-			return err
-		}
-		oldIps = append(oldIps, ip)
-	}
-	if ipData.AllowIpList != nil {
-		oldIps = append(oldIps, ipData.AllowIpList...)
-	}
-	if req.TcpForwardingData.BackendList != nil {
-		for _, v := range req.TcpForwardingData.BackendList {
-			ip, _, err := net.SplitHostPort(v)
-			if err != nil {
-				return err
-			}
-			newIps = append(newIps, ip)
-		}
-		newIps = append(newIps, req.TcpForwardingData.AllowIpList...)
+	addedIps, removedIps, err := s.wafformatter.WashEditWafIp(ctx,req.TcpForwardingData.BackendList,req.TcpForwardingData.AllowIpList,ipData.BackendList,ipData.AllowIpList)
+	if err != nil {
+		return err
 	}
-	addedIps, removedIps := s.wafformatter.findIpDifferences(oldIps, newIps)
-
 	if len(addedIps) > 0 {
 		go s.wafformatter.PublishIpWhitelistTask(addedIps, "add")
 	}
@@ -332,16 +312,9 @@ func (s *tcpforwardingService) DeleteTcpForwarding(ctx context.Context, req v1.D
 		if err != nil {
 			return err
 		}
-
-		if ipData.BackendList != nil {
-			for _, v := range ipData.BackendList {
-				ip, _, err := net.SplitHostPort(v)
-				if err != nil {
-					return err
-				}
-				ips = append(ips, ip)
-			}
-			ips = append(ips, ipData.AllowIpList...)
+		ips, err = s.wafformatter.WashDeleteWafIp(ctx, ipData.BackendList, ipData.AllowIpList)
+		if err != nil {
+			return err
 		}
 		if len(ips) > 0 {
 			go s.wafformatter.PublishIpWhitelistTask(ips, "del")

+ 6 - 33
internal/service/udpforwarding.go

@@ -284,35 +284,14 @@ func (s *udpForWardingService) EditUdpForwarding(ctx context.Context, req *v1.Ud
 	}
 
 	// 异步任务:将IP添加到白名单
-	var oldIps []string
-	var newIps []string
 	ipData, err := s.udpForWardingRepository.GetUdpForwardingIpsByID(ctx, req.UdpForwardingData.Id)
 	if err != nil {
 		return err
 	}
-	for _, v := range ipData.BackendList {
-		ip, _, err := net.SplitHostPort(v)
-		if err != nil {
-			return err
-		}
-		oldIps = append(oldIps, ip)
-	}
-	if ipData.AllowIpList != nil {
-		oldIps = append(oldIps, ipData.AllowIpList...)
-	}
-
-	if req.UdpForwardingData.BackendList != nil {
-		for _, v := range req.UdpForwardingData.BackendList {
-			ip, _, err := net.SplitHostPort(v)
-			if err != nil {
-				return err
-			}
-			newIps = append(newIps, ip)
-		}
-		newIps = append(newIps, req.UdpForwardingData.AllowIpList...)
+	addedIps, removedIps, err := s.wafformatter.WashEditWafIp(ctx,req.UdpForwardingData.BackendList,req.UdpForwardingData.AllowIpList,ipData.BackendList,ipData.AllowIpList)
+	if err != nil {
+		return err
 	}
-	addedIps, removedIps := s.wafformatter.findIpDifferences(oldIps, newIps)
-
 	if len(addedIps) > 0 {
 		go s.wafformatter.PublishIpWhitelistTask(addedIps, "add")
 	}
@@ -351,15 +330,9 @@ func (s *udpForWardingService) DeleteUdpForwarding(ctx context.Context, Ids []in
 			return err
 		}
 		var ips []string
-		if ipData.BackendList != nil {
-			for _, v := range ipData.BackendList {
-				ip, _, err := net.SplitHostPort(v)
-				if err != nil {
-					return err
-				}
-				ips = append(ips, ip)
-			}
-			ips = append(ips, ipData.AllowIpList...)
+		ips, err = s.wafformatter.WashDeleteWafIp(ctx, ipData.BackendList, ipData.AllowIpList)
+		if err != nil {
+			return err
 		}
 		if len(ips) > 0 {
 			go s.wafformatter.PublishIpWhitelistTask(ips, "del")

+ 43 - 0
internal/service/wafformatter.go

@@ -27,6 +27,8 @@ type WafFormatterService interface {
 	PublishIpWhitelistTask(ips []string, action string)
 	PublishDomainWhitelistTask(domain, ip, action string)
 	findIpDifferences(oldIps, newIps []string) ([]string, []string)
+	WashDeleteWafIp(ctx context.Context, backendList []string,allowIpList []string) ([]string, error)
+	WashEditWafIp(ctx context.Context, newBackendList []string,newAllowIpList []string,oldBackendList []string,oldAllowIpList []string) ([]string, []string, error)
 
 }
 func NewWafFormatterService(
@@ -343,4 +345,45 @@ func (s *wafFormatterService) findIpDifferences(oldIps, newIps []string) ([]stri
 	}
 
 	return addedIps, removedIps
+}
+
+func (s *wafFormatterService) WashDeleteWafIp(ctx context.Context, backendList []string,allowIpList []string) ([]string, error) {
+	var res []string
+	for _, v := range backendList {
+		ip, _, err := net.SplitHostPort(v)
+		if err != nil {
+			return nil, err
+		}
+		res = append(res, ip)
+	}
+	res = append(res, allowIpList...)
+	return res, nil
+}
+
+func (s *wafFormatterService) WashEditWafIp(ctx context.Context, newBackendList []string,newAllowIpList []string,oldBackendList []string,oldAllowIpList []string) ([]string, []string, error) {
+	var oldIps []string
+	var newIps []string
+	for _, v := range oldBackendList {
+		ip, _, err := net.SplitHostPort(v)
+		if err != nil {
+			return nil, nil, err
+		}
+		oldIps = append(oldIps, ip)
+	}
+	if oldAllowIpList != nil {
+		oldIps = append(oldIps, oldAllowIpList...)
+	}
+	if newBackendList != nil {
+		for _, v := range newBackendList {
+			ip, _, err := net.SplitHostPort(v)
+			if err != nil {
+				return nil, nil, err
+			}
+			newIps = append(newIps, ip)
+		}
+		newIps = append(newIps, newAllowIpList...)
+	}
+	addedIps, removedIps := s.findIpDifferences(oldIps, newIps)
+
+	return addedIps, removedIps , nil
 }