wafformatter.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. package service
  import (
	"context"
	"fmt"
	"net"
	"slices"
	"strconv"

	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
	"github.com/spf13/cast"
	"golang.org/x/net/publicsuffix"
)
  12. type WafFormatterService interface {
  13. require(ctx context.Context, req v1.GlobalRequire, category string) (v1.GlobalRequire, error)
  14. sendFormData(ctx context.Context,addTokenUrl string,addSendUrl string,formData map[string]interface{}) (int, error)
  15. validateWafPortCount(ctx context.Context, hostId int) error
  16. validateWafDomainCount(ctx context.Context, req v1.GlobalRequire) error
  17. ConvertToWildcardDomain(ctx context.Context,domain string) (string, error)
  18. }
  19. func NewWafFormatterService(
  20. service *Service,
  21. globalRep repository.GlobalLimitRepository,
  22. hostRep repository.HostRepository,
  23. required RequiredService,
  24. parser ParserService,
  25. tcpforwardingRep repository.TcpforwardingRepository,
  26. udpForWardingRep repository.UdpForWardingRepository,
  27. webForwardingRep repository.WebForwardingRepository,
  28. host HostService,
  29. ) WafFormatterService {
  30. return &wafFormatterService{
  31. Service: service,
  32. globalRep: globalRep,
  33. hostRep: hostRep,
  34. required: required,
  35. parser: parser,
  36. tcpforwardingRep: tcpforwardingRep,
  37. udpForWardingRep: udpForWardingRep,
  38. webForwardingRep: webForwardingRep,
  39. host : host,
  40. }
  41. }
  42. type wafFormatterService struct {
  43. *Service
  44. globalRep repository.GlobalLimitRepository
  45. hostRep repository.HostRepository
  46. required RequiredService
  47. parser ParserService
  48. tcpforwardingRep repository.TcpforwardingRepository
  49. udpForWardingRep repository.UdpForWardingRepository
  50. webForwardingRep repository.WebForwardingRepository
  51. host HostService
  52. }
  53. func (s *wafFormatterService) require(ctx context.Context,req v1.GlobalRequire,category string) (v1.GlobalRequire, error) {
  54. RuleIds, err := s.globalRep.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  55. if err != nil {
  56. return v1.GlobalRequire{}, err
  57. }
  58. req.WafGatewayGroupId = RuleIds.GatewayGroupId
  59. switch category {
  60. case "tcp":
  61. req.LimitRuleId = RuleIds.TcpLimitRuleId
  62. case "udp":
  63. req.LimitRuleId = RuleIds.UdpLimitRuleId
  64. case "web":
  65. req.LimitRuleId = RuleIds.WebLimitRuleId
  66. }
  67. domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
  68. if err != nil {
  69. return v1.GlobalRequire{}, err
  70. }
  71. req.Tag = strconv.Itoa(req.Uid) + "_" + strconv.Itoa(req.HostId) + "_" + domain + "_" + req.Comment
  72. return req, nil
  73. }
  74. func (s *wafFormatterService) sendFormData(ctx context.Context,addTokenUrl string,addSendUrl string,formData map[string]interface{}) (int, error) {
  75. respBody, err := s.required.SendForm(ctx, addTokenUrl, addSendUrl, formData)
  76. if err != nil {
  77. return 0, err
  78. }
  79. // 解析响应内容中的 alert 消息
  80. res, err := s.parser.ParseAlert(string(respBody))
  81. if err != nil {
  82. return 0,err
  83. }
  84. if res != "" {
  85. return 0,fmt.Errorf(res)
  86. }
  87. ruleIdStr, err := s.parser.GetRuleIdByColumnName(ctx, respBody,formData["tag"].(string))
  88. if err != nil {
  89. return 0, err
  90. }
  91. ruleId, err := cast.ToIntE(ruleIdStr)
  92. if err != nil {
  93. return 0,err
  94. }
  95. return ruleId, nil
  96. }
  97. func (s *wafFormatterService) validateWafPortCount(ctx context.Context, hostId int) error {
  98. congfig, err := s.host.GetGlobalLimitConfig(ctx, hostId)
  99. if err != nil {
  100. return err
  101. }
  102. tcpCount, err := s.tcpforwardingRep.GetTcpForwardingPortCountByHostId(ctx, hostId)
  103. if err != nil {
  104. return err
  105. }
  106. udpCount, err := s.udpForWardingRep.GetUdpForwardingPortCountByHostId(ctx, hostId)
  107. if err != nil {
  108. return err
  109. }
  110. webCount, err := s.webForwardingRep.GetWebForwardingPortCountByHostId(ctx, hostId)
  111. if err != nil {
  112. return err
  113. }
  114. if int64(congfig.PortCount) > tcpCount + udpCount + webCount {
  115. return nil
  116. }
  117. return fmt.Errorf("端口数量超出套餐限制,已配置%d个端口,套餐限制为%d个端口", tcpCount+udpCount+webCount, congfig.PortCount)
  118. }
  119. func (s *wafFormatterService) validateWafDomainCount(ctx context.Context, req v1.GlobalRequire) error {
  120. congfig, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
  121. if err != nil {
  122. return err
  123. }
  124. domainCount, domainSlice, err := s.webForwardingRep.GetWebForwardingDomainCountByHostId(ctx, req.HostId)
  125. if err != nil {
  126. return err
  127. }
  128. if req.Domain != "" {
  129. if !slices.Contains(domainSlice, req.Domain) {
  130. domainCount += 1
  131. if domainCount > int64(congfig.DomainCount) {
  132. return fmt.Errorf("域名数量已达到上限,已配置%d个域名,套餐限制为%d个域名", domainCount, congfig.DomainCount)
  133. }
  134. }
  135. }
  136. return nil
  137. }
  138. func (s *wafFormatterService) ConvertToWildcardDomain(ctx context.Context, domain string) (string, error) {
  139. // 1. 使用 EffectiveTLDPlusOne 获取可注册域名部分。
  140. // 例如,对于 "www.google.com",这将返回 "google.com"。
  141. // 对于 "a.b.c.tokyo.jp",这将返回 "c.tokyo.jp"。
  142. registrableDomain, err := publicsuffix.EffectiveTLDPlusOne(domain)
  143. if err != nil {
  144. // 如果域名无效(如 IP 地址、localhost),则返回错误。
  145. return "", fmt.Errorf("无法处理 '%s': %w", domain, err)
  146. }
  147. // 2. 比较原始域名和可注册域名。
  148. // 如果它们不相等,说明原始域名包含子域名。
  149. if domain != registrableDomain {
  150. // 3. 如果存在子域名,则用 "*." 加上可注册域名来构造通配符域名。
  151. return registrableDomain, nil
  152. }
  153. // 4. 如果原始域名和可注册域名相同(例如,输入就是 "google.com"),
  154. // 则说明没有子域名可替换,直接返回原始域名。
  155. return domain, nil
  156. }