aodun.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. package service
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/tls"
  6. "encoding/json"
  7. "fmt"
  8. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  9. "github.com/spf13/viper"
  10. "go.uber.org/zap"
  11. "io"
  12. "net/http"
  13. "net/url"
  14. "strconv"
  15. "strings"
  16. "time"
  17. )
  18. // AoDunService 定义了与傲盾 API 交互的服务接口
  19. type AoDunService interface {
  20. // 添加域名到白名单
  21. DomainWhiteList(ctx context.Context, domain string, ip string, apiType string) error
  22. // 添加 IP 到静态白名单
  23. AddWhiteStaticList(ctx context.Context, isSmall bool, req []v1.IpInfo, color string) error
  24. // 根据 ID 从白名单中删除 IP
  25. DelWhiteStaticList(ctx context.Context, isSmall bool, id string, color string) error
  26. // 查询白名单 IP
  27. GetWhiteStaticList(ctx context.Context, isSmall bool, ip string,serverIp string, color string) (int, error)
  28. // 添加带宽限制
  29. AddBandwidthLimit(ctx context.Context, req v1.Bandwidth) error
  30. // 删除带宽限制
  31. DelBandwidthLimit(ctx context.Context, req v1.Bandwidth) error
  32. // 设置防御带宽
  33. SetDefense(ctx context.Context, req v1.SetDefense) error
  34. }
  35. // aoDunService 是 AoDunService 接口的实现
  36. type aoDunService struct {
  37. *Service
  38. cfg *aoDunConfig
  39. httpClient *http.Client
  40. request RequestService
  41. }
  // aoDunConfig collects every viper setting the AoDun service depends on.
  type aoDunConfig struct {
  	// Primary AoDun cluster: base URL, OAuth client ID and login credentials.
  	Url      string
  	ClientID string
  	Username string
  	Password string

  	// Secondary ("small") AoDun cluster. It has its own base URL and client ID
  	// but logs in with Username/Password above.
  	SmallUrl      string
  	SmallClientID string

  	// Credentials for the domain-whitelist and set_defense endpoints.
  	DomainUsername string
  	DomainPassword string
  }
  53. // NewAoDunService 创建一个新的 AoDunService 实例
  54. func NewAoDunService(
  55. service *Service,
  56. conf *viper.Viper,
  57. request RequestService,
  58. ) AoDunService {
  59. cfg := &aoDunConfig{
  60. Url: conf.GetString("aodun.Url"),
  61. ClientID: conf.GetString("aodun.clientID"),
  62. Username: conf.GetString("aodun.username"),
  63. Password: conf.GetString("aodun.password"),
  64. SmallUrl: conf.GetString("aodunSmall.Url"),
  65. SmallClientID: conf.GetString("aodunSmall.clientID"),
  66. DomainUsername: conf.GetString("domainWhite.username"),
  67. DomainPassword: conf.GetString("domainWhite.password"),
  68. }
  69. tr := &http.Transport{
  70. TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
  71. MaxIdleConns: 100,
  72. IdleConnTimeout: 90 * time.Second,
  73. ForceAttemptHTTP2: true,
  74. }
  75. client := &http.Client{
  76. Transport: tr,
  77. Timeout: 15 * time.Second,
  78. }
  79. return &aoDunService{
  80. Service: service,
  81. cfg: cfg,
  82. httpClient: client,
  83. request: request,
  84. }
  85. }
  86. // getApiUrl 根据 isSmall 标志返回正确的 API 基础 URL
  87. func (s *aoDunService) getApiUrl(isSmall bool) string {
  88. if isSmall {
  89. return s.cfg.SmallUrl
  90. }
  91. return s.cfg.Url
  92. }
  93. // getClientID 根据 isSmall 标志返回正确的 ClientID
  94. func (s *aoDunService) getClientID(isSmall bool) string {
  95. if isSmall {
  96. return s.cfg.SmallClientID
  97. }
  98. return s.cfg.ClientID
  99. }
  100. // executeRequest 封装了发送 HTTP POST 请求、读取响应和 JSON 解码的通用逻辑
  101. func (s *aoDunService) executeRequest(ctx context.Context, url, tokenType, token string, requestBody, responsePayload interface{}, isSmall bool) error {
  102. jsonData, err := json.Marshal(requestBody)
  103. if err != nil {
  104. return fmt.Errorf("序列化请求数据失败 (isSmall: %t): %w", isSmall, err)
  105. }
  106. req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
  107. if err != nil {
  108. return fmt.Errorf("创建 HTTP 请求失败 (isSmall: %t): %w", isSmall, err)
  109. }
  110. req.Header.Set("Content-Type", "application/json")
  111. if token != "" {
  112. req.Header.Set("Authorization", tokenType+" "+token)
  113. }
  114. resp, err := s.httpClient.Do(req)
  115. if err != nil {
  116. return fmt.Errorf("发送 HTTP 请求失败 (isSmall: %t): %w", isSmall, err)
  117. }
  118. defer resp.Body.Close()
  119. body, err := io.ReadAll(resp.Body)
  120. if err != nil {
  121. return fmt.Errorf("读取响应体失败 (isSmall: %t): %w", isSmall, err)
  122. }
  123. if resp.StatusCode != http.StatusOK {
  124. return fmt.Errorf("HTTP 错误 (isSmall: %t): 状态码 %d, 响应: %s", isSmall, resp.StatusCode, string(body))
  125. }
  126. if err := json.Unmarshal(body, responsePayload); err != nil {
  127. return fmt.Errorf("反序列化响应 JSON 失败 (isSmall: %t, 内容: %s): %w", isSmall, string(body), err)
  128. }
  129. return nil
  130. }
  131. // sendAuthenticatedRequest 封装了需要认证的 API 请求的通用流程
  132. func (s *aoDunService) sendAuthenticatedRequest(ctx context.Context, isSmall bool, apiPath string, requestBody, responsePayload interface{}) error {
  133. tokenType, token, err := s.GetToken(ctx, isSmall)
  134. if err != nil {
  135. return err
  136. }
  137. apiURL := s.getApiUrl(isSmall) + apiPath
  138. return s.executeRequest(ctx, apiURL, tokenType, token, requestBody, responsePayload, isSmall)
  139. }
  140. // GetToken 获取认证令牌
  141. func (s *aoDunService) GetToken(ctx context.Context, isSmall bool) (string, string, error) {
  142. formData := map[string]interface{}{
  143. "ClientID": s.getClientID(isSmall),
  144. "GrantType": "password",
  145. "Username": s.cfg.Username,
  146. "Password": s.cfg.Password,
  147. }
  148. apiURL := s.getApiUrl(isSmall) + "/oauth/token"
  149. var res v1.GetTokenRespone
  150. if err := s.executeRequest(ctx, apiURL, "", "", formData, &res, isSmall); err != nil {
  151. return "", "", err
  152. }
  153. if res.Code != 0 {
  154. return "", "", fmt.Errorf("API 错误 (isSmall: %t): code %d, msg '%s'", isSmall, res.Code, res.Msg)
  155. }
  156. if res.AccessToken == "" {
  157. return "", "", fmt.Errorf("API 成功 (isSmall: %t, code 0) 但 access_token 为空", isSmall)
  158. }
  159. return res.TokenType, res.AccessToken, nil
  160. }
  161. // AddWhiteStaticList 添加 IP 到静态白名单
  162. func (s *aoDunService) AddWhiteStaticList(ctx context.Context, isSmall bool, req []v1.IpInfo,color string) error {
  163. formData := map[string]interface{}{
  164. "action": "add",
  165. "bwflag": color,
  166. "insert_bw_list": req,
  167. }
  168. var res v1.IpResponse
  169. err := s.sendAuthenticatedRequest(ctx, isSmall, "/v1.0/firewall/static_bw_list", formData, &res)
  170. if err != nil {
  171. return err
  172. }
  173. if res.Code != 0 {
  174. if strings.Contains(res.Msg, "操作部分成功,重复IP如下") {
  175. s.Logger.Info(res.Msg, zap.String("isSmall", strconv.FormatBool(isSmall)))
  176. return nil
  177. }
  178. return fmt.Errorf("API 错误 (isSmall: %t): color %s,code %d, msg '%s'", isSmall, color, res.Code, res.Msg)
  179. }
  180. return nil
  181. }
  182. // GetWhiteStaticList 查询白名单 IP 并返回其 ID
  183. func (s *aoDunService) GetWhiteStaticList(ctx context.Context, isSmall bool, ip string,serverIp string, color string) (int, error) {
  184. // 使用一个无限循环,直到API返回空数据页才停止
  185. for i := 0; ; i++ { // i++ 会持续请求下一页
  186. formData := map[string]interface{}{
  187. "action": "get",
  188. "bwflag": color,
  189. "page": i,
  190. "ip": ip,
  191. }
  192. var res v1.IpGetResponse
  193. err := s.sendAuthenticatedRequest(ctx, isSmall, "/v1.0/firewall/static_bw_list", formData, &res)
  194. if err != nil {
  195. return 0, err // 网络或请求本身出错,直接返回
  196. }
  197. if res.Code != 0 {
  198. // API返回了业务错误,直接返回
  199. return 0, fmt.Errorf("API 错误 (isSmall: %t): color %s,code %d, msg '%s'", isSmall, color, res.Code, res.Msg)
  200. }
  201. // 如果当前页的数据为空,说明已经没有更多数据了,可以跳出循环。
  202. // 这是分页查询结束的正确信号。
  203. if len(res.Data) == 0 {
  204. break
  205. }
  206. // 在当前页的数据中查找目标记录
  207. for _, v := range res.Data {
  208. if v.Remark == "宁波高防IP过白" && v.ServerIP == serverIp {
  209. // 找到了,立即返回ID
  210. return v.ID, nil
  211. }
  212. }
  213. // 可选:为了防止无限循环,可以加一个最大页数限制
  214. if i > 50 { // 比如最多查100页
  215. break
  216. }
  217. }
  218. // 如果循环正常结束(所有页都查完了),说明没有找到符合条件的记录
  219. return 0, fmt.Errorf("未找到 IP '%s' 相关的 '%s'名单记录 (备注: 宁波高防IP过白) (isSmall: %t)", ip, color, isSmall)
  220. }
  221. // DelWhiteStaticList 根据 ID 从白名单中删除 IP
  222. func (s *aoDunService) DelWhiteStaticList(ctx context.Context, isSmall bool, id string, color string) error {
  223. formData := map[string]interface{}{
  224. "action": "del",
  225. "bwflag": color,
  226. "flag": 0,
  227. "ids": id,
  228. }
  229. var res v1.IpResponse
  230. err := s.sendAuthenticatedRequest(ctx, isSmall, "/v1.0/firewall/static_bw_list", formData, &res)
  231. if err != nil {
  232. return err
  233. }
  234. if res.Code != 0 {
  235. return fmt.Errorf("API 错误 (isSmall: %t): color %s,code %d, msg '%s'", isSmall, color, res.Code, res.Msg)
  236. }
  237. return nil
  238. }
  239. // sendDomainFormData 处理域名白名单的 application/x-www-form-urlencoded 请求
  240. func (s *aoDunService) sendDomainFormData(ctx context.Context, domain, ip, apiType string) ([]byte, error) {
  241. var apiURL string
  242. switch apiType {
  243. case "add":
  244. apiURL = "http://zapi.zzybgp.com/api/user/do_main"
  245. case "del":
  246. apiURL = "http://zapi.zzybgp.com/api/user/do_main/delete"
  247. default:
  248. return nil, fmt.Errorf("无效的 apiType: %s", apiType)
  249. }
  250. formData := url.Values{}
  251. formData.Set("username", s.cfg.DomainUsername)
  252. formData.Set("password", s.cfg.DomainPassword)
  253. formData.Add("do_main_list[name][]", domain)
  254. formData.Add("do_main_list[ip]", ip)
  255. req, err := http.NewRequestWithContext(ctx, "POST", apiURL, strings.NewReader(formData.Encode()))
  256. if err != nil {
  257. return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err)
  258. }
  259. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  260. resp, err := s.httpClient.Do(req)
  261. if err != nil {
  262. return nil, fmt.Errorf("发送 HTTP 请求失败: %w", err)
  263. }
  264. defer resp.Body.Close()
  265. body, err := io.ReadAll(resp.Body)
  266. if err != nil {
  267. return nil, fmt.Errorf("读取响应体失败: %w", err)
  268. }
  269. if resp.StatusCode != http.StatusOK {
  270. return nil, fmt.Errorf("HTTP 错误: 状态码 %d, 响应: %s", resp.StatusCode, string(body))
  271. }
  272. return body, nil
  273. }
  274. // DomainWhiteList 添加或删除域名白名单
  275. func (s *aoDunService) DomainWhiteList(ctx context.Context, domain, ip, apiType string) error {
  276. resBody, err := s.sendDomainFormData(ctx, domain, ip, apiType)
  277. if err != nil {
  278. return err
  279. }
  280. var res v1.DomainResponse
  281. if err := json.Unmarshal(resBody, &res); err != nil {
  282. return fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
  283. }
  284. switch apiType {
  285. case "add":
  286. if res.Code != 200 {
  287. return fmt.Errorf("API 错误: code %d, msg '%s', info '%s'", res.Code, res.Msg, res.Info)
  288. }
  289. case "del":
  290. if res.Code != 600 {
  291. return fmt.Errorf("API 错误: code %d, msg '%s', info '%s'", res.Code, res.Msg, res.Info)
  292. }
  293. }
  294. return nil
  295. }
  296. // AddBandwidthLimit 添加带宽限制
  297. func (s *aoDunService) AddBandwidthLimit(ctx context.Context, req v1.Bandwidth) error {
  298. var res v1.BandwidthResponse
  299. formData := map[string]interface{}{
  300. "server_ip_type": req.ServerIPType,
  301. "server_ip_start": req.ServerIPStart,
  302. "name": req.Name,
  303. "speedlimit_out": req.SpeedlimitOut,
  304. "client_ip_type": req.ClientIPType,
  305. "action": req.Action,
  306. "direction": req.Direction,
  307. "protocol": req.Protocol,
  308. }
  309. err := s.sendAuthenticatedRequest(ctx, true, "/v1.0/firewall/add_filter_rule", formData, &res)
  310. if err != nil {
  311. return err
  312. }
  313. if res.Err != 0 {
  314. return fmt.Errorf("API 错误: code %d, msg '%s'", res.Err, res.Msg)
  315. }
  316. if res.Msg != "操作成功" {
  317. return fmt.Errorf("API 错误: code %d, msg '%s'", res.Err, res.Msg)
  318. }
  319. return nil
  320. }
  321. // DelBandwidthLimit 删除带宽限制
  322. func (s *aoDunService) DelBandwidthLimit(ctx context.Context, req v1.Bandwidth) error {
  323. var res v1.BandwidthResponse
  324. formData := map[string]interface{}{
  325. "name": req.Name,
  326. }
  327. err := s.sendAuthenticatedRequest(ctx, true, "/v1.0/firewall/delete_filter_rule", formData, &res)
  328. if err != nil {
  329. return err
  330. }
  331. if res.Err != 0 {
  332. return fmt.Errorf("API 错误: code %d, msg '%s'", res.Err, res.Msg)
  333. }
  334. if res.Msg != "操作成功" {
  335. return fmt.Errorf("API 错误: code %d, msg '%s'", res.Err, res.Msg)
  336. }
  337. return nil
  338. }
  339. // 设置防御带宽
  340. func (s *aoDunService) SetDefense(ctx context.Context, req v1.SetDefense) error {
  341. formData := map[string]interface{}{
  342. "ip_addr": req.IpAddr,
  343. "defense": req.Defense,
  344. "username": s.cfg.DomainUsername,
  345. "password": s.cfg.DomainPassword,
  346. }
  347. resBody, err := s.request.Request(ctx,formData, "http://zapi.zzybgp.com/api/set_defense", "", "")
  348. if err != nil {
  349. return err
  350. }
  351. var res struct {
  352. Code int `json:"code"`
  353. Msg string `json:"msg"`
  354. }
  355. if err := json.Unmarshal(resBody, &res); err != nil {
  356. return fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
  357. }
  358. if res.Msg == "当前ip已是此防御值" {
  359. return nil
  360. }
  361. if res.Code != 200 {
  362. return fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Msg)
  363. }
  364. return nil
  365. }