tcpforwarding.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510
  1. package service
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  9. "golang.org/x/sync/errgroup"
  10. "maps"
  11. "net"
  12. "sort"
  13. )
  14. type TcpforwardingService interface {
  15. GetTcpforwarding(ctx context.Context, req v1.GetForwardingRequest) (v1.TcpForwardingDataRequest, error)
  16. AddTcpForwarding(ctx context.Context, req *v1.TcpForwardingRequest) error
  17. EditTcpForwarding(ctx context.Context, req *v1.TcpForwardingRequest) error
  18. DeleteTcpForwarding(ctx context.Context, req v1.DeleteTcpForwardingRequest) error
  19. GetTcpForwardingAllIpsByHostId(ctx context.Context, req v1.GetForwardingRequest) ([]v1.TcpForwardingDataRequest, error)
  20. }
  21. func NewTcpforwardingService(
  22. service *Service,
  23. tcpforwardingRepository repository.TcpforwardingRepository,
  24. parser ParserService,
  25. required RequiredService,
  26. crawler CrawlerService,
  27. globalRep repository.GlobalLimitRepository,
  28. hostRep repository.HostRepository,
  29. wafformatter WafFormatterService,
  30. cdn CdnService,
  31. proxy ProxyService,
  32. gatewayIpRep repository.GatewayipRepository,
  33. ) TcpforwardingService {
  34. return &tcpforwardingService{
  35. Service: service,
  36. tcpforwardingRepository: tcpforwardingRepository,
  37. parser: parser,
  38. required: required,
  39. crawler: crawler,
  40. globalRep: globalRep,
  41. hostRep: hostRep,
  42. wafformatter: wafformatter,
  43. cdn: cdn,
  44. proxy: proxy,
  45. gatewayIpRep: gatewayIpRep,
  46. }
  47. }
  48. type tcpforwardingService struct {
  49. *Service
  50. tcpforwardingRepository repository.TcpforwardingRepository
  51. parser ParserService
  52. required RequiredService
  53. crawler CrawlerService
  54. globalRep repository.GlobalLimitRepository
  55. hostRep repository.HostRepository
  56. wafformatter WafFormatterService
  57. cdn CdnService
  58. proxy ProxyService
  59. gatewayIpRep repository.GatewayipRepository
  60. }
  61. func (s *tcpforwardingService) GetTcpforwarding(ctx context.Context, req v1.GetForwardingRequest) (v1.TcpForwardingDataRequest, error) {
  62. var tcpForwarding model.Tcpforwarding
  63. var backend model.TcpForwardingRule
  64. var err error
  65. g, gCtx := errgroup.WithContext(ctx)
  66. g.Go(func() error {
  67. res, e := s.tcpforwardingRepository.GetTcpforwarding(gCtx, int64(req.Id))
  68. if e != nil {
  69. return fmt.Errorf("GetTcpforwarding failed: %w", e)
  70. }
  71. if res != nil {
  72. tcpForwarding = *res
  73. }
  74. return nil
  75. })
  76. g.Go(func() error {
  77. res, e := s.tcpforwardingRepository.GetTcpForwardingIpsByID(gCtx, req.Id)
  78. if e != nil {
  79. return fmt.Errorf("GetTcpforwardingIps failed: %w", e)
  80. }
  81. if res != nil {
  82. backend = *res
  83. }
  84. return nil
  85. })
  86. if err = g.Wait(); err != nil {
  87. return v1.TcpForwardingDataRequest{}, err
  88. }
  89. return v1.TcpForwardingDataRequest{
  90. Id: tcpForwarding.Id,
  91. Port: tcpForwarding.Port,
  92. Comment: tcpForwarding.Comment,
  93. Proxy: tcpForwarding.Proxy,
  94. BackendList: backend.BackendList,
  95. }, nil
  96. }
  97. func (s *tcpforwardingService) buildTcpForwardingModel(req *v1.TcpForwardingDataRequest, ruleId int, require RequireResponse) *model.Tcpforwarding {
  98. return &model.Tcpforwarding{
  99. HostId: require.HostId,
  100. CdnWebId: ruleId,
  101. Port: req.Port,
  102. Comment: req.Comment,
  103. Proxy: req.Proxy,
  104. }
  105. }
  106. func (s *tcpforwardingService) buildTcpRuleModel(reqData *v1.TcpForwardingDataRequest, require RequireResponse, localDbId int, cdnOriginIds map[string]int64) *model.TcpForwardingRule {
  107. return &model.TcpForwardingRule{
  108. Uid: require.Uid,
  109. HostId: require.HostId,
  110. TcpId: localDbId, // 关联到本地数据库的主记录 ID
  111. CdnOriginIds: cdnOriginIds,
  112. BackendList: reqData.BackendList,
  113. }
  114. }
  115. func (s *tcpforwardingService) prepareWafData(ctx context.Context, req *v1.TcpForwardingRequest) (RequireResponse, v1.WebsiteSend, error) {
  116. // 1. 获取必要的全局信息
  117. require, err := s.wafformatter.Require(ctx, v1.GlobalRequire{
  118. HostId: req.HostId,
  119. Uid: req.Uid,
  120. Comment: req.TcpForwardingData.Comment,
  121. })
  122. if err != nil {
  123. return RequireResponse{}, v1.WebsiteSend{}, err
  124. }
  125. gatewayIps, err := s.gatewayIpRep.GetGatewayipOnlyIpByHostIdAll(ctx, int64(req.HostId))
  126. if err != nil {
  127. return RequireResponse{}, v1.WebsiteSend{}, err
  128. }
  129. require.GatewayIps = gatewayIps
  130. if require.GatewayIps == nil || require.Uid == 0 {
  131. return RequireResponse{}, v1.WebsiteSend{}, fmt.Errorf("请先配置实例")
  132. }
  133. var jsonData v1.TypeJSON
  134. jsonData.IsOn = true
  135. for _, v := range require.GatewayIps {
  136. jsonData.Listen = append(jsonData.Listen, v1.Listen{
  137. Protocol: "tcp",
  138. Host: v,
  139. Port: req.TcpForwardingData.Port,
  140. })
  141. }
  142. byteData, err := json.Marshal(jsonData)
  143. if err != nil {
  144. return RequireResponse{}, v1.WebsiteSend{}, err
  145. }
  146. formData := v1.WebsiteSend{
  147. UserId: int64(require.CdnUid),
  148. Type: "tcpProxy",
  149. Name: require.Tag,
  150. Description: req.TcpForwardingData.Comment,
  151. TcpJSON: byteData,
  152. ServerGroupIds: []int64{int64(require.GroupId)},
  153. UserPlanId: int64(require.RuleId),
  154. NodeClusterId: 1,
  155. }
  156. return require, formData, nil
  157. }
  158. func (s *tcpforwardingService) AddTcpForwarding(ctx context.Context, req *v1.TcpForwardingRequest) error {
  159. require, formData, err := s.prepareWafData(ctx, req)
  160. if err != nil {
  161. return err
  162. }
  163. err = s.wafformatter.validateWafPortCount(ctx, require.HostId)
  164. if err != nil {
  165. return err
  166. }
  167. // 验证端口重复
  168. err = s.wafformatter.VerifyPort(ctx, "tcp", int64(req.TcpForwardingData.Id),req.TcpForwardingData.Port, int64(require.HostId), "")
  169. if err != nil {
  170. return err
  171. }
  172. tcpId, err := s.cdn.CreateWebsite(ctx, formData)
  173. if err != nil {
  174. return err
  175. }
  176. // 添加源站
  177. cdnOriginIds := make(map[string]int64)
  178. for _, v := range req.TcpForwardingData.BackendList{
  179. id, err := s.wafformatter.AddOrigin(ctx, v1.WebJson{
  180. ApiType: "tcp",
  181. BackendList: v,
  182. Comment: req.TcpForwardingData.Comment,
  183. })
  184. if err != nil {
  185. return err
  186. }
  187. cdnOriginIds[v] = id
  188. }
  189. // 添加源站到网站
  190. for _, v := range cdnOriginIds {
  191. err = s.cdn.AddServerOrigin(ctx, tcpId, v)
  192. if err != nil {
  193. return err
  194. }
  195. }
  196. // 开启proxy
  197. if req.TcpForwardingData.Proxy {
  198. err = s.proxy.EditProxy(ctx,tcpId, v1.ProxyProtocolJSON{
  199. IsOn: true,
  200. Version: 1,
  201. })
  202. if err != nil {
  203. return err
  204. }
  205. }
  206. tcpModel := s.buildTcpForwardingModel(&req.TcpForwardingData, int(tcpId), require)
  207. id, err := s.tcpforwardingRepository.AddTcpforwarding(ctx, tcpModel)
  208. if err != nil {
  209. return err
  210. }
  211. TcpRuleModel := s.buildTcpRuleModel(&req.TcpForwardingData, require, id, cdnOriginIds)
  212. if _, err = s.tcpforwardingRepository.AddTcpforwardingIps(ctx, *TcpRuleModel); err != nil {
  213. return err
  214. }
  215. // 异步任务:将源站IP添加到白名单
  216. var ips []string
  217. if req.TcpForwardingData.BackendList != nil {
  218. for _, v := range req.TcpForwardingData.BackendList {
  219. ip, _, err := net.SplitHostPort(v)
  220. if err != nil {
  221. return err
  222. }
  223. ips = append(ips, ip)
  224. }
  225. go s.wafformatter.PublishIpWhitelistTask(ips, "add","","white")
  226. }
  227. return nil
  228. }
  229. func (s *tcpforwardingService) EditTcpForwarding(ctx context.Context, req *v1.TcpForwardingRequest) error {
  230. require, formData, err := s.prepareWafData(ctx, req)
  231. if err != nil {
  232. return err
  233. }
  234. oldData, err := s.tcpforwardingRepository.GetTcpforwarding(ctx, int64(req.TcpForwardingData.Id))
  235. if err != nil {
  236. return err
  237. }
  238. // 验证端口重复
  239. if oldData.Port != req.TcpForwardingData.Port {
  240. err = s.wafformatter.VerifyPort(ctx, "tcp", int64(req.TcpForwardingData.Id), req.TcpForwardingData.Port, int64(require.HostId), "")
  241. if err != nil {
  242. return err
  243. }
  244. }
  245. //修改网站端口
  246. if oldData.Port != req.TcpForwardingData.Port {
  247. err = s.cdn.EditServerType(ctx, v1.EditWebsite{
  248. Id: int64(oldData.CdnWebId),
  249. TypeJSON: formData.TcpJSON,
  250. }, "tcp")
  251. if err != nil {
  252. return err
  253. }
  254. }
  255. //修改网站名字
  256. if oldData.Comment != req.TcpForwardingData.Comment {
  257. nodeId, err := s.globalRep.GetNodeId(ctx, oldData.CdnWebId)
  258. err = s.cdn.EditServerBasic(ctx, int64(oldData.CdnWebId), require.Tag, nodeId)
  259. if err != nil {
  260. return err
  261. }
  262. }
  263. //修改Proxy
  264. if oldData.Proxy != req.TcpForwardingData.Proxy {
  265. err = s.proxy.EditProxy(ctx, int64(oldData.CdnWebId), v1.ProxyProtocolJSON{
  266. IsOn: req.TcpForwardingData.Proxy,
  267. Version: 1,
  268. })
  269. if err != nil {
  270. return err
  271. }
  272. }
  273. // 异步任务:将IP添加到白名单
  274. ipData, err := s.tcpforwardingRepository.GetTcpForwardingIpsByID(ctx, req.TcpForwardingData.Id)
  275. if err != nil {
  276. return err
  277. }
  278. addedIps, removedIps, err := s.wafformatter.WashEditWafIp(ctx,req.TcpForwardingData.BackendList, ipData.BackendList)
  279. if err != nil {
  280. return err
  281. }
  282. if len(addedIps) > 0 {
  283. go s.wafformatter.PublishIpWhitelistTask(addedIps, "add","","white")
  284. }
  285. if len(removedIps) > 0 {
  286. ipsToDelist, err := s.wafformatter.WashDelIps(ctx, removedIps)
  287. if err != nil {
  288. return err
  289. }
  290. // 4. 如果有需要处理的IP,则批量发布一次任务
  291. if len(ipsToDelist) > 0 {
  292. go s.wafformatter.PublishIpWhitelistTask(ipsToDelist, "del", "0", "white")
  293. }
  294. }
  295. //修改源站
  296. addOrigins, delOrigins := s.wafformatter.findIpDifferences(ipData.BackendList, req.TcpForwardingData.BackendList)
  297. addedIds := make(map[string]int64)
  298. for _, v := range addOrigins {
  299. id, err := s.wafformatter.AddOrigin(ctx,v1.WebJson{
  300. ApiType: "tcp",
  301. BackendList: v,
  302. Comment: req.TcpForwardingData.Comment,
  303. })
  304. if err != nil {
  305. return err
  306. }
  307. addedIds[v] = id
  308. }
  309. for _, v := range addedIds {
  310. err = s.cdn.AddServerOrigin(ctx, int64(oldData.CdnWebId), v)
  311. if err != nil {
  312. return err
  313. }
  314. }
  315. maps.Copy(ipData.CdnOriginIds, addedIds)
  316. for k, v := range ipData.CdnOriginIds {
  317. for _, ip := range delOrigins {
  318. if k == ip {
  319. err = s.cdn.DelServerOrigin(ctx, int64(oldData.CdnWebId), v)
  320. if err != nil {
  321. return err
  322. }
  323. delete(ipData.CdnOriginIds, k)
  324. }
  325. }
  326. }
  327. tcpModel := s.buildTcpForwardingModel(&req.TcpForwardingData,oldData.CdnWebId, require)
  328. tcpModel.Id = req.TcpForwardingData.Id
  329. if err = s.tcpforwardingRepository.EditTcpforwarding(ctx, tcpModel); err != nil {
  330. return err
  331. }
  332. TcpRuleModel := s.buildTcpRuleModel(&req.TcpForwardingData, require, req.TcpForwardingData.Id, ipData.CdnOriginIds)
  333. if err = s.tcpforwardingRepository.EditTcpforwardingIps(ctx, *TcpRuleModel); err != nil {
  334. return err
  335. }
  336. return nil
  337. }
  338. func (s *tcpforwardingService) DeleteTcpForwarding(ctx context.Context, req v1.DeleteTcpForwardingRequest) error {
  339. for _, Id := range req.Ids {
  340. oldData, err := s.tcpforwardingRepository.GetTcpforwarding(ctx, int64(Id))
  341. if err != nil {
  342. return err
  343. }
  344. err = s.cdn.DelServer(ctx, int64(oldData.CdnWebId))
  345. if err != nil {
  346. return err
  347. }
  348. // 删除白名单
  349. var ips []string
  350. ipData, err := s.tcpforwardingRepository.GetTcpForwardingIpsByID(ctx, Id)
  351. if err != nil {
  352. return err
  353. }
  354. ips, err = s.wafformatter.WashDeleteWafIp(ctx, ipData.BackendList)
  355. if err != nil {
  356. return err
  357. }
  358. if len(ips) > 0 {
  359. ipsToDelist, err := s.wafformatter.WashDelIps(ctx, ips)
  360. if err != nil {
  361. return err
  362. }
  363. // 4. 如果有需要处理的IP,则批量发布一次任务
  364. if len(ipsToDelist) > 0 {
  365. go s.wafformatter.PublishIpWhitelistTask(ipsToDelist, "del", "0", "white")
  366. }
  367. }
  368. if err = s.tcpforwardingRepository.DeleteTcpforwarding(ctx, int64(Id)); err != nil {
  369. return err
  370. }
  371. if err = s.tcpforwardingRepository.DeleteTcpForwardingIpsById(ctx, Id); err != nil {
  372. return err
  373. }
  374. }
  375. return nil
  376. }
  377. func (s *tcpforwardingService) GetTcpForwardingAllIpsByHostId(ctx context.Context, req v1.GetForwardingRequest) ([]v1.TcpForwardingDataRequest, error) {
  378. type CombinedResult struct {
  379. Id int
  380. Forwarding *model.Tcpforwarding
  381. BackendRule *model.TcpForwardingRule
  382. Err error // 如果此ID的处理出错,则携带错误
  383. }
  384. g,gCtx := errgroup.WithContext(ctx)
  385. ids, err := s.tcpforwardingRepository.GetTcpForwardingAllIdsByID(gCtx, req.HostId)
  386. if err != nil {
  387. return nil, fmt.Errorf("GetTcpForwardingAllIds failed: %w", err)
  388. }
  389. if len(ids) == 0 {
  390. return nil, nil
  391. }
  392. resChan := make(chan CombinedResult, len(ids))
  393. g.Go(func() error {
  394. for _, idVal := range ids {
  395. currentID := idVal
  396. g.Go(func() error {
  397. var wf *model.Tcpforwarding
  398. var bk *model.TcpForwardingRule
  399. var localErr error
  400. wf, localErr = s.tcpforwardingRepository.GetTcpforwarding(gCtx, int64(currentID))
  401. if localErr != nil {
  402. resChan <- CombinedResult{Id: currentID, Err: localErr}
  403. return localErr
  404. }
  405. bk, localErr = s.tcpforwardingRepository.GetTcpForwardingIpsByID(gCtx, currentID)
  406. if localErr != nil {
  407. resChan <- CombinedResult{Id: currentID, Err: localErr}
  408. return localErr
  409. }
  410. resChan <- CombinedResult{Id: currentID, Forwarding: wf, BackendRule: bk}
  411. return nil
  412. })
  413. }
  414. return nil
  415. })
  416. groupErr := g.Wait()
  417. close(resChan)
  418. if groupErr != nil {
  419. return nil, groupErr
  420. }
  421. res := make([]v1.TcpForwardingDataRequest, 0, len(ids))
  422. for r := range resChan {
  423. if r.Err != nil {
  424. return nil, fmt.Errorf("received error from goroutine for ID %d: %w", r.Id, r.Err)
  425. }
  426. if r.Forwarding == nil {
  427. return nil,fmt.Errorf("received nil forwarding from goroutine for ID %d", r.Id)
  428. }
  429. dataReq := v1.TcpForwardingDataRequest{
  430. Id: r.Forwarding.Id,
  431. Port: r.Forwarding.Port,
  432. Comment: r.Forwarding.Comment,
  433. Proxy: r.Forwarding.Proxy,
  434. }
  435. if r.BackendRule != nil {
  436. dataReq.BackendList = r.BackendRule.BackendList
  437. }
  438. res = append(res, dataReq)
  439. }
  440. sort.Slice(res, func(i, j int) bool {
  441. return res[i].Id > res[j].Id
  442. })
  443. return res, nil
  444. }