wafformatter.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. package service
  2. import (
  3. "context"
  4. "fmt"
  5. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  6. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  7. "github.com/spf13/cast"
  8. "slices"
  9. "strconv"
  10. )
  11. type WafFormatterService interface {
  12. require(ctx context.Context, req v1.GlobalRequire, category string) (v1.GlobalRequire, error)
  13. sendFormData(ctx context.Context,addTokenUrl string,addSendUrl string,formData map[string]interface{}) (int, error)
  14. validateWafPortCount(ctx context.Context, hostId int) error
  15. validateWafDomainCount(ctx context.Context, req v1.GlobalRequire) error
  16. }
  17. func NewWafFormatterService(
  18. service *Service,
  19. globalRep repository.GlobalLimitRepository,
  20. hostRep repository.HostRepository,
  21. required RequiredService,
  22. parser ParserService,
  23. tcpforwardingRep repository.TcpforwardingRepository,
  24. udpForWardingRep repository.UdpForWardingRepository,
  25. webForwardingRep repository.WebForwardingRepository,
  26. host HostService,
  27. ) WafFormatterService {
  28. return &wafFormatterService{
  29. Service: service,
  30. globalRep: globalRep,
  31. hostRep: hostRep,
  32. required: required,
  33. parser: parser,
  34. tcpforwardingRep: tcpforwardingRep,
  35. udpForWardingRep: udpForWardingRep,
  36. webForwardingRep: webForwardingRep,
  37. host : host,
  38. }
  39. }
  40. type wafFormatterService struct {
  41. *Service
  42. globalRep repository.GlobalLimitRepository
  43. hostRep repository.HostRepository
  44. required RequiredService
  45. parser ParserService
  46. tcpforwardingRep repository.TcpforwardingRepository
  47. udpForWardingRep repository.UdpForWardingRepository
  48. webForwardingRep repository.WebForwardingRepository
  49. host HostService
  50. }
  51. func (s *wafFormatterService) require(ctx context.Context,req v1.GlobalRequire,category string) (v1.GlobalRequire, error) {
  52. RuleIds, err := s.globalRep.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  53. if err != nil {
  54. return v1.GlobalRequire{}, err
  55. }
  56. req.WafGatewayGroupId = RuleIds.GatewayGroupId
  57. switch category {
  58. case "tcp":
  59. req.LimitRuleId = RuleIds.TcpLimitRuleId
  60. case "udp":
  61. req.LimitRuleId = RuleIds.UdpLimitRuleId
  62. case "web":
  63. req.LimitRuleId = RuleIds.WebLimitRuleId
  64. }
  65. domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
  66. if err != nil {
  67. return v1.GlobalRequire{}, err
  68. }
  69. req.Tag = strconv.Itoa(req.Uid) + "_" + strconv.Itoa(req.HostId) + "_" + domain + "_" + req.Comment
  70. return req, nil
  71. }
  72. func (s *wafFormatterService) sendFormData(ctx context.Context,addTokenUrl string,addSendUrl string,formData map[string]interface{}) (int, error) {
  73. respBody, err := s.required.SendForm(ctx, addTokenUrl, addSendUrl, formData)
  74. if err != nil {
  75. return 0, err
  76. }
  77. // 解析响应内容中的 alert 消息
  78. res, err := s.parser.ParseAlert(string(respBody))
  79. if err != nil {
  80. return 0,err
  81. }
  82. if res != "" {
  83. return 0,fmt.Errorf(res)
  84. }
  85. ruleIdStr, err := s.parser.GetRuleIdByColumnName(ctx, respBody,formData["tag"].(string))
  86. if err != nil {
  87. return 0, err
  88. }
  89. ruleId, err := cast.ToIntE(ruleIdStr)
  90. if err != nil {
  91. return 0,err
  92. }
  93. return ruleId, nil
  94. }
  95. func (s *wafFormatterService) validateWafPortCount(ctx context.Context, hostId int) error {
  96. congfig, err := s.host.GetGlobalLimitConfig(ctx, hostId)
  97. if err != nil {
  98. return err
  99. }
  100. tcpCount, err := s.tcpforwardingRep.GetTcpForwardingPortCountByHostId(ctx, hostId)
  101. if err != nil {
  102. return err
  103. }
  104. udpCount, err := s.udpForWardingRep.GetUdpForwardingPortCountByHostId(ctx, hostId)
  105. if err != nil {
  106. return err
  107. }
  108. webCount, err := s.webForwardingRep.GetWebForwardingPortCountByHostId(ctx, hostId)
  109. if err != nil {
  110. return err
  111. }
  112. if int64(congfig.PortCount) > tcpCount + udpCount + webCount {
  113. return nil
  114. }
  115. return fmt.Errorf("端口数量超出套餐限制,已配置%d个端口,套餐限制为%d个端口", tcpCount+udpCount+webCount, congfig.PortCount)
  116. }
  117. func (s *wafFormatterService) validateWafDomainCount(ctx context.Context, req v1.GlobalRequire) error {
  118. congfig, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
  119. if err != nil {
  120. return err
  121. }
  122. domainCount, domainSlice, err := s.webForwardingRep.GetWebForwardingDomainCountByHostId(ctx, req.HostId)
  123. if err != nil {
  124. return err
  125. }
  126. if req.Domain != "" {
  127. if !slices.Contains(domainSlice, req.Domain) {
  128. domainCount += 1
  129. if domainCount > int64(congfig.DomainCount) {
  130. return fmt.Errorf("域名数量已达到上限,已配置%d个域名,套餐限制为%d个域名", domainCount, congfig.DomainCount)
  131. }
  132. }
  133. }
  134. return nil
  135. }