udpforwarding.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. package waf
  2. import (
  3. "context"
  4. "fmt"
  5. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  6. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
  9. "github.com/go-nunu/nunu-layout-advanced/internal/service"
  10. "github.com/go-nunu/nunu-layout-advanced/internal/service/api/flexCdn"
  11. "golang.org/x/sync/errgroup"
  12. "sort"
  13. )
  14. type UdpForWardingService interface {
  15. GetUdpForWarding(ctx context.Context,req v1.GetForwardingRequest) (v1.UdpForwardingDataRequest, error)
  16. AddUdpForwarding(ctx context.Context, req *v1.UdpForwardingRequest) (int, error)
  17. EditUdpForwarding(ctx context.Context, req *v1.UdpForwardingRequest) error
  18. DeleteUdpForwarding(ctx context.Context, req v1.DeleteUdpForwardingRequest) error
  19. GetUdpForwardingWafUdpAllIps(ctx context.Context, req v1.GetForwardingRequest) ([]v1.UdpForwardingDataRequest, error)
  20. }
  21. func NewUdpForWardingService(
  22. service *service.Service,
  23. udpForWardingRepository waf.UdpForWardingRepository,
  24. required service.RequiredService,
  25. parser service.ParserService,
  26. crawler service.CrawlerService,
  27. globalRep waf.GlobalLimitRepository,
  28. hostRep repository.HostRepository,
  29. wafformatter WafFormatterService,
  30. cdn flexCdn.CdnService,
  31. proxy flexCdn.ProxyService,
  32. aidedUdp AidedUdpService,
  33. ) UdpForWardingService {
  34. return &udpForWardingService{
  35. Service: service,
  36. udpForWardingRepository: udpForWardingRepository,
  37. required: required,
  38. parser: parser,
  39. crawler: crawler,
  40. globalRep: globalRep,
  41. hostRep: hostRep,
  42. wafformatter: wafformatter,
  43. cdn: cdn,
  44. proxy: proxy,
  45. aidedUdp: aidedUdp,
  46. }
  47. }
  48. type udpForWardingService struct {
  49. *service.Service
  50. udpForWardingRepository waf.UdpForWardingRepository
  51. required service.RequiredService
  52. parser service.ParserService
  53. crawler service.CrawlerService
  54. globalRep waf.GlobalLimitRepository
  55. hostRep repository.HostRepository
  56. wafformatter WafFormatterService
  57. cdn flexCdn.CdnService
  58. proxy flexCdn.ProxyService
  59. aidedUdp AidedUdpService
  60. }
  61. // GetUdpForWarding 获取单个UDP转发配置详情
  62. // 该函数根据ID同时查询主记录和规则记录,并合并返回完整的配置信息
  63. func (s *udpForWardingService) GetUdpForWarding(ctx context.Context, req v1.GetForwardingRequest) (v1.UdpForwardingDataRequest, error) {
  64. // 参数验证
  65. if req.Id <= 0 {
  66. return v1.UdpForwardingDataRequest{}, fmt.Errorf("非法的ID参数: %d", req.Id)
  67. }
  68. var udpForWarding model.UdpForWarding
  69. var backend model.UdpForwardingRule
  70. var err error
  71. // 并发查询主记录和规则记录以提高性能
  72. g, gCtx := errgroup.WithContext(ctx)
  73. g.Go(func() error {
  74. res, e := s.udpForWardingRepository.GetUdpForWarding(gCtx, int64(req.Id))
  75. if e != nil {
  76. return fmt.Errorf("查询UDP转发主记录失败 ID:%d, %w", req.Id, e)
  77. }
  78. if res != nil {
  79. udpForWarding = *res
  80. }
  81. return nil
  82. })
  83. g.Go(func() error {
  84. res, e := s.udpForWardingRepository.GetUdpForwardingIpsByID(gCtx, req.Id)
  85. if e != nil {
  86. return fmt.Errorf("查询UDP转发规则记录失败 ID:%d, %w", req.Id, e)
  87. }
  88. if res != nil {
  89. backend = *res
  90. }
  91. return nil
  92. })
  93. if err = g.Wait(); err != nil {
  94. return v1.UdpForwardingDataRequest{}, err
  95. }
  96. // 检查是否找到主记录
  97. if udpForWarding.Id == 0 {
  98. return v1.UdpForwardingDataRequest{}, fmt.Errorf("UDP转发配置不存在 ID:%d", req.Id)
  99. }
  100. return v1.UdpForwardingDataRequest{
  101. Id: udpForWarding.Id,
  102. Port: udpForWarding.Port,
  103. BackendList: backend.BackendList,
  104. Comment: udpForWarding.Comment,
  105. Proxy: udpForWarding.Proxy,
  106. }, nil
  107. }
  108. // AddUdpForwarding 添加 UDP 转发配置
  109. // 该函数完成 UDP 转发的完整创建流程:验证、创建 CDN、添加源站、配置代理、保存数据、处理异步任务
  110. func (s *udpForWardingService) AddUdpForwarding(ctx context.Context, req *v1.UdpForwardingRequest) (int, error) {
  111. // 1. 数据准备和验证
  112. require, formData, err := s.aidedUdp.PrepareWafData(ctx, req)
  113. if err != nil {
  114. return 0, err
  115. }
  116. if err := s.aidedUdp.ValidateAddRequest(ctx, req, require); err != nil {
  117. return 0, err
  118. }
  119. // 2. 创建CDN网站
  120. udpId, err := s.aidedUdp.CreateCdnWebsite(ctx, formData)
  121. if err != nil {
  122. return 0, err
  123. }
  124. // 3. 添加源站
  125. cdnOriginIds, err := s.aidedUdp.AddOriginsToWebsite(ctx, req, udpId)
  126. if err != nil {
  127. return 0, err
  128. }
  129. // 4. 配置代理协议
  130. if err := s.aidedUdp.ConfigureProxyProtocol(ctx, req, udpId); err != nil {
  131. return 0, err
  132. }
  133. // 5. 保存到数据库
  134. id, err := s.aidedUdp.SaveToDatabase(ctx, req, require, udpId, cdnOriginIds)
  135. if err != nil {
  136. return 0, err
  137. }
  138. // 6. 处理异步任务
  139. s.aidedUdp.ProcessAsyncTasks(req)
  140. return id, nil
  141. }
  142. // EditUdpForwarding 编辑 UDP 转发配置
  143. // 该函数完成 UDP 转发的完整编辑流程:验证、更新 CDN、处理IP白名单、更新源站、更新数据库
  144. func (s *udpForWardingService) EditUdpForwarding(ctx context.Context, req *v1.UdpForwardingRequest) error {
  145. // 1. 数据准备和验证
  146. require, formData, err := s.aidedUdp.PrepareWafData(ctx, req)
  147. if err != nil {
  148. return err
  149. }
  150. oldData, err := s.udpForWardingRepository.GetUdpForWarding(ctx, int64(req.UdpForwardingData.Id))
  151. if err != nil {
  152. return fmt.Errorf("获取原始数据失败: %w", err)
  153. }
  154. if err := s.aidedUdp.ValidateEditRequest(ctx, req, require, oldData); err != nil {
  155. return err
  156. }
  157. // 2. 更新CDN配置
  158. if err := s.aidedUdp.UpdateCdnConfiguration(ctx, req, oldData, require, formData); err != nil {
  159. return err
  160. }
  161. // 3. 获取IP数据并处理白名单
  162. ipData, err := s.udpForWardingRepository.GetUdpForwardingIpsByID(ctx, req.UdpForwardingData.Id)
  163. if err != nil {
  164. return fmt.Errorf("获取IP数据失败: %w", err)
  165. }
  166. if err := s.aidedUdp.ProcessIpWhitelistChanges(ctx, req, ipData); err != nil {
  167. return err
  168. }
  169. // 4. 更新源站配置
  170. if err := s.aidedUdp.UpdateOriginServers(ctx, req, oldData, ipData); err != nil {
  171. return err
  172. }
  173. // 5. 更新数据库记录
  174. if err := s.aidedUdp.UpdateDatabaseRecords(ctx, req, oldData, require, ipData); err != nil {
  175. return err
  176. }
  177. return nil
  178. }
  179. // DeleteUdpForwarding 批量删除 UDP 转发配置
  180. // 该函数支持批量删除多个 UDP 转发配置,对每个配置都执行完整的删除流程
  181. func (s *udpForWardingService) DeleteUdpForwarding(ctx context.Context, req v1.DeleteUdpForwardingRequest) error {
  182. // 批量删除处理
  183. for _, id := range req.Ids {
  184. if err := s.deleteSingleUdpForwarding(ctx, id, req.HostId); err != nil {
  185. return fmt.Errorf("删除UDP转发配置失败 ID:%d, %w", id, err)
  186. }
  187. }
  188. return nil
  189. }
  190. // deleteSingleUdpForwarding 删除单个 UDP 转发配置
  191. // 该函数完成单个配置的完整删除流程:权限验证、删除 CDN、清理IP白名单、清理数据库
  192. func (s *udpForWardingService) deleteSingleUdpForwarding(ctx context.Context, id int, hostId int) error {
  193. // 1. 获取原始数据并验证权限
  194. oldData, err := s.udpForWardingRepository.GetUdpForWarding(ctx, int64(id))
  195. if err != nil {
  196. return fmt.Errorf("获取UDP转发数据失败: %w", err)
  197. }
  198. if err := s.aidedUdp.ValidateDeletePermission(oldData, hostId); err != nil {
  199. return err
  200. }
  201. // 2. 删除CDN服务器
  202. if err := s.aidedUdp.DeleteCdnServer(ctx, oldData.CdnWebId); err != nil {
  203. return err
  204. }
  205. // 3. 处理IP白名单清理
  206. if err := s.aidedUdp.ProcessDeleteIpWhitelist(ctx, id); err != nil {
  207. return err
  208. }
  209. // 4. 清理数据库记录
  210. if err := s.aidedUdp.CleanupDatabaseRecords(ctx, id); err != nil {
  211. return err
  212. }
  213. return nil
  214. }
  215. // GetUdpForwardingWafUdpAllIps 获取指定主机的所有 UDP 转发配置列表
  216. // 该函数使用并发查询优化性能,同时获取多个配置的详细信息并按ID降序排列
  217. func (s *udpForWardingService) GetUdpForwardingWafUdpAllIps(ctx context.Context, req v1.GetForwardingRequest) ([]v1.UdpForwardingDataRequest, error) {
  218. type CombinedResult struct {
  219. Id int
  220. Forwarding *model.UdpForWarding
  221. BackendRule *model.UdpForwardingRule
  222. Err error
  223. }
  224. g, gCtx := errgroup.WithContext(ctx)
  225. ids, err := s.udpForWardingRepository.GetUdpForwardingWafUdpAllIds(gCtx, req.HostId)
  226. if err != nil {
  227. return nil, fmt.Errorf("获取UDP转发ID列表失败: %w", err)
  228. }
  229. if len(ids) == 0 {
  230. return nil, nil
  231. }
  232. resChan := make(chan CombinedResult, len(ids))
  233. g.Go(func() error {
  234. for _, idVal := range ids {
  235. currentID := idVal
  236. g.Go(func() error {
  237. var wf *model.UdpForWarding
  238. var bk *model.UdpForwardingRule
  239. var localErr error
  240. wf, localErr = s.udpForWardingRepository.GetUdpForWarding(gCtx, int64(currentID))
  241. if localErr != nil {
  242. resChan <- CombinedResult{Id: currentID, Err: localErr}
  243. return localErr
  244. }
  245. bk, localErr = s.udpForWardingRepository.GetUdpForwardingIpsByID(gCtx, currentID)
  246. if localErr != nil {
  247. resChan <- CombinedResult{Id: currentID, Err: localErr}
  248. return localErr
  249. }
  250. resChan <- CombinedResult{Id: currentID, Forwarding: wf, BackendRule: bk}
  251. return nil
  252. })
  253. }
  254. return nil
  255. })
  256. groupErr := g.Wait()
  257. close(resChan)
  258. if groupErr != nil {
  259. return nil, groupErr
  260. }
  261. res := make([]v1.UdpForwardingDataRequest, 0, len(ids))
  262. for r := range resChan {
  263. if r.Err != nil {
  264. return nil, fmt.Errorf("处理ID %d 时出错: %w", r.Id, r.Err)
  265. }
  266. if r.Forwarding == nil {
  267. return nil, fmt.Errorf("ID %d 对应的转发配置为空", r.Id)
  268. }
  269. dataReq := v1.UdpForwardingDataRequest{
  270. Id: r.Forwarding.Id,
  271. Port: r.Forwarding.Port,
  272. Comment: r.Forwarding.Comment,
  273. Proxy: r.Forwarding.Proxy,
  274. }
  275. if r.BackendRule != nil {
  276. dataReq.BackendList = r.BackendRule.BackendList
  277. }
  278. res = append(res, dataReq)
  279. }
  280. sort.Slice(res, func(i, j int) bool {
  281. return res[i].Id > res[j].Id
  282. })
  283. return res, nil
  284. }