package service import ( "bytes" "context" "crypto/tls" "encoding/json" "fmt" v1 "github.com/go-nunu/nunu-layout-advanced/api/v1" "github.com/spf13/viper" "go.uber.org/zap" "io" "net/http" "net/url" "strconv" "strings" "time" ) // AoDunService 定义了与傲盾 API 交互的服务接口 type AoDunService interface { DomainWhiteList(ctx context.Context, domain string, ip string, apiType string) error AddWhiteStaticList(ctx context.Context, isSmall bool, req []v1.IpInfo, color string) error DelWhiteStaticList(ctx context.Context, isSmall bool, id string, color string) error GetWhiteStaticList(ctx context.Context, isSmall bool, ip string,serverIp string, color string) (int, error) AddBandwidthLimit(ctx context.Context, req v1.Bandwidth) error DelBandwidthLimit(ctx context.Context, req v1.Bandwidth) error } // aoDunService 是 AoDunService 接口的实现 type aoDunService struct { *Service cfg *aoDunConfig httpClient *http.Client } // aoDunConfig 用于整合来自 viper 的所有配置 type aoDunConfig struct { Url string ClientID string Username string Password string SmallUrl string SmallClientID string DomainUsername string DomainPassword string } // NewAoDunService 创建一个新的 AoDunService 实例 func NewAoDunService(service *Service, conf *viper.Viper) AoDunService { cfg := &aoDunConfig{ Url: conf.GetString("aodun.Url"), ClientID: conf.GetString("aodun.clientID"), Username: conf.GetString("aodun.username"), Password: conf.GetString("aodun.password"), SmallUrl: conf.GetString("aodunSmall.Url"), SmallClientID: conf.GetString("aodunSmall.clientID"), DomainUsername: conf.GetString("domainWhite.username"), DomainPassword: conf.GetString("domainWhite.password"), } tr := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, ForceAttemptHTTP2: true, } client := &http.Client{ Transport: tr, Timeout: 15 * time.Second, } return &aoDunService{ Service: service, cfg: cfg, httpClient: client, } } // getApiUrl 根据 isSmall 标志返回正确的 API 基础 URL func (s *aoDunService) getApiUrl(isSmall bool) string { if isSmall { return s.cfg.SmallUrl } return s.cfg.Url } // getClientID 根据 isSmall 标志返回正确的 ClientID func (s *aoDunService) getClientID(isSmall bool) string { if isSmall { return s.cfg.SmallClientID } return s.cfg.ClientID } // executeRequest 封装了发送 HTTP POST 请求、读取响应和 JSON 解码的通用逻辑 func (s *aoDunService) executeRequest(ctx context.Context, url, tokenType, token string, requestBody, responsePayload interface{}, isSmall bool) error { jsonData, err := json.Marshal(requestBody) if err != nil { return fmt.Errorf("序列化请求数据失败 (isSmall: %t): %w", isSmall, err) } req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) if err != nil { return fmt.Errorf("创建 HTTP 请求失败 (isSmall: %t): %w", isSmall, err) } req.Header.Set("Content-Type", "application/json") if token != "" { req.Header.Set("Authorization", tokenType+" "+token) } resp, err := s.httpClient.Do(req) if err != nil { return fmt.Errorf("发送 HTTP 请求失败 (isSmall: %t): %w", isSmall, err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("读取响应体失败 (isSmall: %t): %w", isSmall, err) } if resp.StatusCode != http.StatusOK { return fmt.Errorf("HTTP 错误 (isSmall: %t): 状态码 %d, 响应: %s", isSmall, resp.StatusCode, string(body)) } if err := json.Unmarshal(body, responsePayload); err != nil { return fmt.Errorf("反序列化响应 JSON 失败 (isSmall: %t, 内容: %s): %w", isSmall, string(body), err) } return nil } // sendAuthenticatedRequest 封装了需要认证的 API 请求的通用流程 func (s *aoDunService) sendAuthenticatedRequest(ctx context.Context, isSmall bool, apiPath string, requestBody, responsePayload interface{}) error { tokenType, token, err := s.GetToken(ctx, isSmall) if err != nil { return err } apiURL := s.getApiUrl(isSmall) + apiPath return s.executeRequest(ctx, apiURL, tokenType, token, requestBody, responsePayload, isSmall) } // GetToken 获取认证令牌 func (s *aoDunService) GetToken(ctx context.Context, isSmall bool) (string, string, error) { formData := map[string]interface{}{ "ClientID": s.getClientID(isSmall), "GrantType": "password", "Username": s.cfg.Username, "Password": s.cfg.Password, } apiURL := s.getApiUrl(isSmall) + "/oauth/token" var res v1.GetTokenRespone if err := s.executeRequest(ctx, apiURL, "", "", formData, &res, isSmall); err != nil { return "", "", err } if res.Code != 0 { return "", "", fmt.Errorf("API 错误 (isSmall: %t): code %d, msg '%s'", isSmall, res.Code, res.Msg) } if res.AccessToken == "" { return "", "", fmt.Errorf("API 成功 (isSmall: %t, code 0) 但 access_token 为空", isSmall) } return res.TokenType, res.AccessToken, nil } // AddWhiteStaticList 添加 IP 到静态白名单 func (s *aoDunService) AddWhiteStaticList(ctx context.Context, isSmall bool, req []v1.IpInfo,color string) error { formData := map[string]interface{}{ "action": "add", "bwflag": color, "insert_bw_list": req, } var res v1.IpResponse err := s.sendAuthenticatedRequest(ctx, isSmall, "/v1.0/firewall/static_bw_list", formData, &res) if err != nil { return err } if res.Code != 0 { if strings.Contains(res.Msg, "操作部分成功,重复IP如下") { s.Logger.Info(res.Msg, zap.String("isSmall", strconv.FormatBool(isSmall))) return nil } return fmt.Errorf("API 错误 (isSmall: %t): color %s,code %d, msg '%s'", isSmall, color, res.Code, res.Msg) } return nil } // GetWhiteStaticList 查询白名单 IP 并返回其 ID func (s *aoDunService) GetWhiteStaticList(ctx context.Context, isSmall bool, ip string,serverIp string, color string) (int, error) { // 使用一个无限循环,直到API返回空数据页才停止 for i := 0; ; i++ { // i++ 会持续请求下一页 formData := map[string]interface{}{ "action": "get", "bwflag": color, "page": i, "ip": ip, } var res v1.IpGetResponse err := s.sendAuthenticatedRequest(ctx, isSmall, "/v1.0/firewall/static_bw_list", formData, &res) if err != nil { return 0, err // 网络或请求本身出错,直接返回 } if res.Code != 0 { // API返回了业务错误,直接返回 return 0, fmt.Errorf("API 错误 (isSmall: %t): color %s,code %d, msg '%s'", isSmall, color, res.Code, res.Msg) } // 如果当前页的数据为空,说明已经没有更多数据了,可以跳出循环。 // 这是分页查询结束的正确信号。 if len(res.Data) == 0 { break } // 在当前页的数据中查找目标记录 for _, v := range res.Data { if v.Remark == "宁波高防IP过白" && v.ServerIP == serverIp { // 找到了,立即返回ID return v.ID, nil } } // 可选:为了防止无限循环,可以加一个最大页数限制 if i > 50 { // 比如最多查100页 break } } // 如果循环正常结束(所有页都查完了),说明没有找到符合条件的记录 return 0, fmt.Errorf("未找到 IP '%s' 相关的 '%s'名单记录 (备注: 宁波高防IP过白) (isSmall: %t)", ip, color, isSmall) } // DelWhiteStaticList 根据 ID 从白名单中删除 IP func (s *aoDunService) DelWhiteStaticList(ctx context.Context, isSmall bool, id string, color string) error { formData := map[string]interface{}{ "action": "del", "bwflag": color, "flag": 0, "ids": id, } var res v1.IpResponse err := s.sendAuthenticatedRequest(ctx, isSmall, "/v1.0/firewall/static_bw_list", formData, &res) if err != nil { return err } if res.Code != 0 { return fmt.Errorf("API 错误 (isSmall: %t): color %s,code %d, msg '%s'", isSmall, color, res.Code, res.Msg) } return nil } // sendDomainFormData 处理域名白名单的 application/x-www-form-urlencoded 请求 func (s *aoDunService) sendDomainFormData(ctx context.Context, domain, ip, apiType string) ([]byte, error) { var apiURL string switch apiType { case "add": apiURL = "http://zapi.zzybgp.com/api/user/do_main" case "del": apiURL = "http://zapi.zzybgp.com/api/user/do_main/delete" default: return nil, fmt.Errorf("无效的 apiType: %s", apiType) } formData := url.Values{} formData.Set("username", s.cfg.DomainUsername) formData.Set("password", s.cfg.DomainPassword) formData.Add("do_main_list[name][]", domain) formData.Add("do_main_list[ip]", ip) req, err := http.NewRequestWithContext(ctx, "POST", apiURL, strings.NewReader(formData.Encode())) if err != nil { return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := s.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("发送 HTTP 请求失败: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("读取响应体失败: %w", err) } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("HTTP 错误: 状态码 %d, 响应: %s", resp.StatusCode, string(body)) } return body, nil } // DomainWhiteList 添加或删除域名白名单 func (s *aoDunService) DomainWhiteList(ctx context.Context, domain, ip, apiType string) error { resBody, err := s.sendDomainFormData(ctx, domain, ip, apiType) if err != nil { return err } var res v1.DomainResponse if err := json.Unmarshal(resBody, &res); err != nil { return fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err) } switch apiType { case "add": if res.Code != 200 { return fmt.Errorf("API 错误: code %d, msg '%s', info '%s'", res.Code, res.Msg, res.Info) } case "del": if res.Code != 600 { return fmt.Errorf("API 错误: code %d, msg '%s', info '%s'", res.Code, res.Msg, res.Info) } } return nil } // AddBandwidthLimit 添加带宽限制 func (s *aoDunService) AddBandwidthLimit(ctx context.Context, req v1.Bandwidth) error { var res v1.BandwidthResponse formData := map[string]interface{}{ "server_ip_type": req.ServerIPType, "server_ip_start": req.ServerIPStart, "name": req.Name, "speedlimit_out": req.SpeedlimitOut, "client_ip_type": req.ClientIPType, "action": req.Action, "direction": req.Direction, "protocol": req.Protocol, } err := s.sendAuthenticatedRequest(ctx, true, "/v1.0/firewall/add_filter_rule", formData, &res) if err != nil { return err } if res.Err != 0 { return fmt.Errorf("API 错误: code %d, msg '%s'", res.Err, res.Msg) } if res.Msg != "操作成功" { return fmt.Errorf("API 错误: code %d, msg '%s'", res.Err, res.Msg) } return nil } // DelBandwidthLimit 删除带宽限制 func (s *aoDunService) DelBandwidthLimit(ctx context.Context, req v1.Bandwidth) error { var res v1.BandwidthResponse formData := map[string]interface{}{ "name": req.Name, } err := s.sendAuthenticatedRequest(ctx, true, "/v1.0/firewall/delete_filter_rule", formData, &res) if err != nil { return err } if res.Err != 0 { return fmt.Errorf("API 错误: code %d, msg '%s'", res.Err, res.Msg) } if res.Msg != "操作成功" { return fmt.Errorf("API 错误: code %d, msg '%s'", res.Err, res.Msg) } return nil }