// webforwarding.go (12 KB)

  1. package service
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  9. "golang.org/x/sync/errgroup"
  10. "strconv"
  11. "strings"
  12. )
  13. type WebForwardingService interface {
  14. GetWebForwarding(ctx context.Context, req v1.GetForwardingRequest) (v1.WebForwardingDataRequest, error)
  15. AddWebForwarding(ctx context.Context, req *v1.WebForwardingRequest) error
  16. EditWebForwarding(ctx context.Context, req *v1.WebForwardingRequest) error
  17. DeleteWebForwarding(ctx context.Context, Ids []int) error
  18. }
  19. func NewWebForwardingService(
  20. service *Service,
  21. required RequiredService,
  22. webForwardingRepository repository.WebForwardingRepository,
  23. crawler CrawlerService,
  24. parser ParserService,
  25. wafformatter WafFormatterService,
  26. ) WebForwardingService {
  27. return &webForwardingService{
  28. Service: service,
  29. webForwardingRepository: webForwardingRepository,
  30. required: required,
  31. parser: parser,
  32. crawler: crawler,
  33. wafformatter: wafformatter,
  34. }
  35. }
  36. type webForwardingService struct {
  37. *Service
  38. webForwardingRepository repository.WebForwardingRepository
  39. required RequiredService
  40. parser ParserService
  41. crawler CrawlerService
  42. wafformatter WafFormatterService
  43. }
  44. func (s *webForwardingService) require(ctx context.Context,req v1.GlobalRequire) (v1.GlobalRequire, error) {
  45. var err error
  46. var res v1.GlobalRequire
  47. g, gCtx := errgroup.WithContext(ctx)
  48. g.Go(func() error {
  49. result, e := s.wafformatter.require(gCtx, req, "web")
  50. if e != nil {
  51. return e
  52. }
  53. res = result
  54. return nil
  55. })
  56. g.Go(func() error {
  57. e := s.wafformatter.validateWafDomainCount(gCtx, req)
  58. if e != nil {
  59. return e
  60. }
  61. return nil
  62. })
  63. if err = g.Wait(); err != nil {
  64. return v1.GlobalRequire{}, err
  65. }
  66. return res, nil
  67. }
  68. func (s *webForwardingService) GetWebForwarding(ctx context.Context, req v1.GetForwardingRequest) (v1.WebForwardingDataRequest, error) {
  69. var webForwarding model.WebForwarding
  70. var backend model.WebForwardingRule
  71. g, gCtx := errgroup.WithContext(ctx)
  72. g.Go(func() error {
  73. res, e := s.webForwardingRepository.GetWebForwarding(gCtx, int64(req.Id))
  74. if e != nil {
  75. // 直接返回错误,errgroup 会捕获它
  76. return fmt.Errorf("GetWebForwarding failed: %w", e)
  77. }
  78. if res != nil {
  79. webForwarding = *res
  80. }
  81. return nil
  82. })
  83. g.Go(func() error {
  84. res, e := s.webForwardingRepository.GetWebForwardingIpsByID(ctx, req.Id)
  85. if e != nil {
  86. return fmt.Errorf("GetWebForwardingByID failed: %w", e)
  87. }
  88. if res != nil {
  89. backend = *res
  90. }
  91. return nil
  92. })
  93. if err := g.Wait(); err != nil {
  94. return v1.WebForwardingDataRequest{}, err
  95. }
  96. return v1.WebForwardingDataRequest{
  97. Id: webForwarding.Id,
  98. WafWebId: webForwarding.WafWebId,
  99. Tag: webForwarding.Tag,
  100. Port: webForwarding.Port,
  101. Domain: webForwarding.Domain,
  102. CustomHost: webForwarding.CustomHost,
  103. WafWebLimitId: webForwarding.WebLimitRuleId,
  104. WafGatewayGroupId: webForwarding.WafGatewayGroupId,
  105. CcCount: webForwarding.CcCount,
  106. CcDuration: webForwarding.CcDuration,
  107. CcBlockCount: webForwarding.CcBlockCount,
  108. CcBlockDuration: webForwarding.CcBlockDuration,
  109. Cc4xxCount: webForwarding.Cc4xxCount,
  110. Cc4xxDuration: webForwarding.Cc4xxDuration,
  111. Cc4xxBlockCount: webForwarding.Cc4xxBlockCount,
  112. Cc4xxBlockDuration: webForwarding.Cc4xxBlockDuration,
  113. Cc5xxCount: webForwarding.Cc5xxCount,
  114. Cc5xxDuration: webForwarding.Cc5xxDuration,
  115. Cc5xxBlockCount: webForwarding.Cc5xxBlockCount,
  116. Cc5xxBlockDuration: webForwarding.Cc5xxBlockDuration,
  117. IsHttps: webForwarding.IsHttps,
  118. Comment: webForwarding.Comment,
  119. BackendList: backend.BackendList,
  120. AllowIpList: backend.AllowIpList,
  121. DenyIpList: backend.DenyIpList,
  122. AccessRule: backend.AccessRule,
  123. HttpsKey: webForwarding.HttpsKey,
  124. HttpsCert: webForwarding.HttpsCert,
  125. }, nil
  126. }
  127. // buildWafFormData 辅助函数,用于构建通用的 formData
  128. func (s *webForwardingService) buildWafFormData(req *v1.WebForwardingDataSend, require v1.GlobalRequire) map[string]interface{} {
  129. // 将BackendList序列化为JSON字符串
  130. backendJSON, err := json.MarshalIndent(req.BackendList, "", " ")
  131. var backendStr interface{}
  132. if err != nil {
  133. // 如果序列化失败,使用空数组
  134. backendStr = "[]"
  135. } else {
  136. // 成功序列化后,使用JSON字符串
  137. backendStr = string(backendJSON)
  138. }
  139. return map[string]interface{}{
  140. "waf_web_id": req.WafWebId,
  141. "tag": require.Tag,
  142. "port": req.Port,
  143. "domain": req.Domain,
  144. "custom_host": req.CustomHost,
  145. "waf_gateway_group_id": require.WafGatewayGroupId,
  146. "waf_web_limit_id": require.LimitRuleId,
  147. "cc_count": req.CcCount,
  148. "cc_duration": req.CcDuration,
  149. "cc_block_count": req.CcBlockCount,
  150. "cc_block_duration": req.CcBlockDuration,
  151. "cc_4xx_count": req.Cc4xxCount,
  152. "cc_4xx_duration": req.Cc4xxDuration,
  153. "cc_4xx_block_count": req.Cc4xxBlockCount,
  154. "cc_4xx_block_duration": req.Cc4xxBlockDuration,
  155. "cc_5xx_count": req.Cc5xxCount,
  156. "cc_5xx_duration": req.Cc5xxDuration,
  157. "cc_5xx_block_count": req.Cc5xxBlockCount,
  158. "cc_5xx_block_duration": req.Cc5xxBlockDuration,
  159. "backend": backendStr,
  160. "allow_ip_list": req.AllowIpList,
  161. "deny_ip_list": req.DenyIpList,
  162. "access_rule": req.AccessRule,
  163. "is_https": req.IsHttps,
  164. "comment": req.Comment,
  165. "https_cert": req.HttpsCert,
  166. "https_key": req.HttpsKey,
  167. }
  168. }
  169. // buildWebForwardingModel 辅助函数,用于构建通用的 WebForwarding 模型
  170. // ruleId 是从 WAF 系统获取的 ID
  171. func (s *webForwardingService) buildWebForwardingModel(req *v1.WebForwardingDataRequest,ruleId int, require v1.GlobalRequire) *model.WebForwarding {
  172. return &model.WebForwarding{
  173. HostId: require.HostId,
  174. WafWebId: ruleId,
  175. Tag: require.Tag,
  176. Port: req.Port,
  177. Domain: req.Domain,
  178. CustomHost: req.CustomHost,
  179. WafGatewayGroupId: require.WafGatewayGroupId,
  180. WebLimitRuleId: require.LimitRuleId,
  181. CcCount: req.CcCount,
  182. CcDuration: req.CcDuration,
  183. CcBlockCount: req.CcBlockCount,
  184. CcBlockDuration: req.CcBlockDuration,
  185. Cc4xxCount: req.Cc4xxCount,
  186. Cc4xxDuration: req.Cc4xxDuration,
  187. Cc4xxBlockCount: req.Cc4xxBlockCount,
  188. Cc4xxBlockDuration: req.Cc4xxBlockDuration,
  189. Cc5xxCount: req.Cc5xxCount,
  190. Cc5xxDuration: req.Cc5xxDuration,
  191. Cc5xxBlockCount: req.Cc5xxBlockCount,
  192. Cc5xxBlockDuration: req.Cc5xxBlockDuration,
  193. IsHttps: req.IsHttps,
  194. Comment: req.Comment,
  195. HttpsCert: req.HttpsCert,
  196. HttpsKey: req.HttpsKey,
  197. }
  198. }
  199. func (s *webForwardingService) buildWebRuleModel(reqData *v1.WebForwardingDataRequest, require v1.GlobalRequire, localDbId int) *model.WebForwardingRule {
  200. return &model.WebForwardingRule{
  201. Uid: require.Uid,
  202. HostId: require.HostId,
  203. WebId: localDbId, // 关联到本地数据库的主记录 ID
  204. BackendList: reqData.BackendList,
  205. AllowIpList: reqData.AllowIpList,
  206. DenyIpList: reqData.DenyIpList,
  207. AccessRule: reqData.AccessRule,
  208. }
  209. }
  210. func (s *webForwardingService) prepareWafData(ctx context.Context, req *v1.WebForwardingRequest) (v1.GlobalRequire, map[string]interface{}, error) {
  211. // 1. 获取必要的全局信息
  212. require, err := s.require(ctx, v1.GlobalRequire{
  213. HostId: req.HostId,
  214. Uid: req.Uid,
  215. Comment: req.WebForwardingData.Comment,
  216. Domain: req.WebForwardingData.Domain,
  217. })
  218. if err != nil {
  219. return v1.GlobalRequire{}, nil, err
  220. }
  221. if require.WafGatewayGroupId == 0 || require.LimitRuleId == 0 {
  222. return v1.GlobalRequire{}, nil, fmt.Errorf("请先配置实例")
  223. }
  224. // 2. 将字符串切片拼接成字符串,用于 WAF API
  225. allowIpListStr := strings.Join(req.WebForwardingData.AllowIpList, "\n")
  226. denyIpListStr := strings.Join(req.WebForwardingData.DenyIpList, "\n")
  227. PortInt, err := strconv.Atoi(req.WebForwardingData.Port)
  228. if err != nil {
  229. return v1.GlobalRequire{}, nil, err
  230. }
  231. // 3. 创建用于构建 WAF 表单的数据结构
  232. formDataBase := v1.WebForwardingDataSend{
  233. Tag: require.Tag,
  234. WafWebId: req.WebForwardingData.WafWebId,
  235. WafGatewayGroupId: require.WafGatewayGroupId,
  236. WafWebLimitId: require.LimitRuleId,
  237. Port: PortInt,
  238. Domain: req.WebForwardingData.Domain,
  239. CustomHost: req.WebForwardingData.CustomHost,
  240. CcCount: req.WebForwardingData.CcCount,
  241. CcDuration: req.WebForwardingData.CcDuration,
  242. CcBlockCount: req.WebForwardingData.CcBlockCount,
  243. CcBlockDuration: req.WebForwardingData.CcBlockDuration,
  244. Cc4xxCount: req.WebForwardingData.Cc4xxCount,
  245. Cc4xxDuration: req.WebForwardingData.Cc4xxDuration,
  246. Cc4xxBlockCount: req.WebForwardingData.Cc4xxBlockCount,
  247. Cc4xxBlockDuration: req.WebForwardingData.Cc4xxBlockDuration,
  248. Cc5xxCount: req.WebForwardingData.Cc5xxCount,
  249. Cc5xxDuration: req.WebForwardingData.Cc5xxDuration,
  250. Cc5xxBlockCount: req.WebForwardingData.Cc5xxBlockCount,
  251. Cc5xxBlockDuration: req.WebForwardingData.Cc5xxBlockDuration,
  252. IsHttps: req.WebForwardingData.IsHttps,
  253. BackendList: req.WebForwardingData.BackendList,
  254. AllowIpList: allowIpListStr,
  255. DenyIpList: denyIpListStr,
  256. AccessRule: req.WebForwardingData.AccessRule,
  257. Comment: req.WebForwardingData.Comment,
  258. HttpsCert: req.WebForwardingData.HttpsCert,
  259. HttpsKey: req.WebForwardingData.HttpsKey,
  260. }
  261. // 4. 构建 WAF 表单数据映射
  262. formData := s.buildWafFormData(&formDataBase, require)
  263. return require, formData, nil
  264. }
  265. func (s *webForwardingService) AddWebForwarding(ctx context.Context, req *v1.WebForwardingRequest) error {
  266. require, formData, err := s.prepareWafData(ctx, req)
  267. if err != nil {
  268. return err
  269. }
  270. err = s.wafformatter.validateWafPortCount(ctx, require.HostId)
  271. if err != nil {
  272. return err
  273. }
  274. wafWebId, err := s.wafformatter.sendFormData(ctx, "admin/info/waf_web/new", "admin/new/waf_web", formData)
  275. if err != nil {
  276. return err
  277. }
  278. webModel := s.buildWebForwardingModel(&req.WebForwardingData, wafWebId, require)
  279. id, err := s.webForwardingRepository.AddWebForwarding(ctx, webModel)
  280. if err != nil {
  281. return err
  282. }
  283. webRuleModel := s.buildWebRuleModel(&req.WebForwardingData, require, id)
  284. if _, err = s.webForwardingRepository.AddWebForwardingIps(ctx, *webRuleModel); err != nil {
  285. return err
  286. }
  287. return nil
  288. }
  289. func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.WebForwardingRequest) error {
  290. WafWebId, err := s.webForwardingRepository.GetWebForwardingWafWebIdById(ctx, req.WebForwardingData.Id)
  291. if err != nil {
  292. return err
  293. }
  294. req.WebForwardingData.WafWebId = WafWebId
  295. require, formData, err := s.prepareWafData(ctx, req)
  296. if err != nil {
  297. return err
  298. }
  299. _, err = s.wafformatter.sendFormData(ctx, "admin/info/waf_web/edit?&__goadmin_edit_pk="+strconv.Itoa(req.WebForwardingData.WafWebId), "admin/edit/waf_web", formData)
  300. if err != nil {
  301. return err
  302. }
  303. webModel := s.buildWebForwardingModel(&req.WebForwardingData, req.WebForwardingData.WafWebId, require)
  304. webModel.Id = req.WebForwardingData.Id
  305. if err = s.webForwardingRepository.EditWebForwarding(ctx, webModel); err != nil {
  306. return err
  307. }
  308. webRuleModel := s.buildWebRuleModel(&req.WebForwardingData, require, req.WebForwardingData.Id)
  309. if err = s.webForwardingRepository.EditWebForwardingIps(ctx, *webRuleModel); err != nil {
  310. return err
  311. }
  312. return nil
  313. }
  314. func (s *webForwardingService) DeleteWebForwarding(ctx context.Context, Ids []int) error {
  315. for _, Id := range Ids {
  316. wafWebId, err := s.webForwardingRepository.GetWebForwardingWafWebIdById(ctx, Id)
  317. if err != nil {
  318. return err
  319. }
  320. _, err = s.crawler.DeleteRule(ctx, wafWebId, "admin/delete/waf_web?page=1&__pageSize=10&__sort=waf_web_id&__sort_type=desc")
  321. if err != nil {
  322. return err
  323. }
  324. if err = s.webForwardingRepository.DeleteWebForwarding(ctx, int64(Id)); err != nil {
  325. return err
  326. }
  327. if err = s.webForwardingRepository.DeleteWebForwardingIpsById(ctx, Id); err != nil {
  328. return err
  329. }
  330. }
  331. return nil
  332. }