globallimit.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. package service
  2. import (
  3. "context"
  4. "fmt"
  5. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  6. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  8. "github.com/spf13/cast"
  9. "github.com/spf13/viper"
  10. "golang.org/x/sync/errgroup"
  11. "strconv"
  12. )
  13. type GlobalLimitService interface {
  14. GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error)
  15. AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  16. EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  17. DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  18. }
  19. func NewGlobalLimitService(
  20. service *Service,
  21. globalLimitRepository repository.GlobalLimitRepository,
  22. duedate DuedateService,
  23. crawler CrawlerService,
  24. conf *viper.Viper,
  25. required RequiredService,
  26. parser ParserService,
  27. host HostService,
  28. tcpLimit TcpLimitService,
  29. udpLimit UdpLimitService,
  30. webLimit WebLimitService,
  31. gateWayGroup GatewayGroupService,
  32. hostRep repository.HostRepository,
  33. gateWayGroupRep repository.GatewayGroupRepository,
  34. ) GlobalLimitService {
  35. return &globalLimitService{
  36. Service: service,
  37. globalLimitRepository: globalLimitRepository,
  38. duedate: duedate,
  39. crawler: crawler,
  40. Url: conf.GetString("crawler.Url"),
  41. required: required,
  42. parser: parser,
  43. host: host,
  44. tcpLimit: tcpLimit,
  45. udpLimit: udpLimit,
  46. webLimit: webLimit,
  47. gateWayGroup: gateWayGroup,
  48. hostRep: hostRep,
  49. gateWayGroupRep: gateWayGroupRep,
  50. }
  51. }
  52. type globalLimitService struct {
  53. *Service
  54. globalLimitRepository repository.GlobalLimitRepository
  55. duedate DuedateService
  56. crawler CrawlerService
  57. Url string
  58. required RequiredService
  59. parser ParserService
  60. host HostService
  61. tcpLimit TcpLimitService
  62. udpLimit UdpLimitService
  63. webLimit WebLimitService
  64. gateWayGroup GatewayGroupService
  65. hostRep repository.HostRepository
  66. gateWayGroupRep repository.GatewayGroupRepository
  67. }
  68. func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
  69. isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
  70. if err != nil {
  71. return v1.GlobalLimitRequireResponse{}, err
  72. }
  73. if isExist {
  74. return v1.GlobalLimitRequireResponse{}, fmt.Errorf("配置限制已存在")
  75. }
  76. res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
  77. if err != nil {
  78. return v1.GlobalLimitRequireResponse{}, err
  79. }
  80. configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
  81. if err != nil {
  82. return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err)
  83. }
  84. res.Bps = configCount.Bps
  85. res.MaxBytesMonth = configCount.MaxBytesMonth
  86. res.Operator = configCount.Operator
  87. res.IpCount = configCount.IpCount
  88. domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
  89. if err != nil {
  90. return v1.GlobalLimitRequireResponse{}, err
  91. }
  92. res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + strconv.Itoa(req.HostId) + "_" + domain
  93. return res, nil
  94. }
  95. func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) {
  96. return s.globalLimitRepository.GetGlobalLimit(ctx, id)
  97. }
  98. func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  99. require, err := s.GlobalLimitRequire(ctx, req)
  100. if err != nil {
  101. return err
  102. }
  103. gatewayGroupId, err := s.gateWayGroupRep.GetGatewayGroupWhereHostIdNull(ctx, require.Operator, require.IpCount)
  104. if err != nil {
  105. return err
  106. }
  107. formData := map[string]interface{}{
  108. "tag": require.GlobalLimitName,
  109. "bps": require.Bps,
  110. "max_bytes_month": require.MaxBytesMonth,
  111. "expired_at": require.ExpiredAt,
  112. }
  113. respBody, err := s.required.SendForm(ctx, "admin/info/waf_common_limit/new", "admin/new/waf_common_limit", formData)
  114. if err != nil {
  115. return err
  116. }
  117. ruleIdBase, err := s.parser.GetRuleIdByColumnName(ctx, respBody, require.GlobalLimitName)
  118. if err != nil {
  119. return err
  120. }
  121. if ruleIdBase == "" {
  122. res, err := s.parser.ParseAlert(string(respBody))
  123. if err != nil {
  124. return err
  125. }
  126. return fmt.Errorf(res)
  127. }
  128. ruleId, err := cast.ToIntE(ruleIdBase)
  129. if err != nil {
  130. return err
  131. }
  132. var tcpLimitRuleId, udpLimitRuleId, webLimitRuleId int
  133. g, gCtx := errgroup.WithContext(ctx)
  134. // 启动tcpLimit调用 - 使用独立的请求参数副本
  135. g.Go(func() error {
  136. tcpLimitReq := &v1.GeneralLimitRequireRequest{
  137. Tag: require.GlobalLimitName,
  138. HostId: req.HostId,
  139. RuleId: ruleId,
  140. Uid: req.Uid,
  141. }
  142. result, e := s.tcpLimit.AddTcpLimit(gCtx, tcpLimitReq)
  143. if e != nil {
  144. return fmt.Errorf("tcpLimit调用失败: %w", e)
  145. }
  146. if result != 0 {
  147. tcpLimitRuleId = result
  148. return nil
  149. }
  150. return fmt.Errorf("tcpLimit调用失败,Id为 %d", result)
  151. })
  152. // 启动udpLimit调用 - 使用独立的请求参数副本
  153. g.Go(func() error {
  154. udpLimitReq := &v1.GeneralLimitRequireRequest{
  155. Tag: require.GlobalLimitName,
  156. HostId: req.HostId,
  157. RuleId: ruleId,
  158. Uid: req.Uid,
  159. }
  160. result, e := s.udpLimit.AddUdpLimit(gCtx, udpLimitReq)
  161. if e != nil {
  162. return fmt.Errorf("udpLimit调用失败: %w", e)
  163. }
  164. if result != 0 {
  165. udpLimitRuleId = result
  166. return nil
  167. }
  168. return fmt.Errorf("udpLimit调用失败,Id为 %d", result)
  169. })
  170. // 启动webLimit调用 - 使用独立的请求参数副本
  171. g.Go(func() error {
  172. webLimitReq := &v1.GeneralLimitRequireRequest{
  173. Tag: require.GlobalLimitName,
  174. HostId: req.HostId,
  175. RuleId: ruleId,
  176. Uid: req.Uid,
  177. }
  178. result, e := s.webLimit.AddWebLimit(gCtx, webLimitReq)
  179. if e != nil {
  180. return fmt.Errorf("webLimit调用失败: %w", e)
  181. }
  182. if result != 0 {
  183. webLimitRuleId = result
  184. return nil
  185. }
  186. return fmt.Errorf("webLimit调用失败,Id为 %d", result)
  187. })
  188. if err := g.Wait(); err != nil {
  189. return err
  190. }
  191. err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
  192. HostId: req.HostId,
  193. RuleId: cast.ToInt(ruleId),
  194. GlobalLimitName: require.GlobalLimitName,
  195. Comment: req.Comment,
  196. TcpLimitRuleId: tcpLimitRuleId,
  197. UdpLimitRuleId: udpLimitRuleId,
  198. WebLimitRuleId: webLimitRuleId,
  199. GatewayGroupId: gatewayGroupId,
  200. })
  201. if err != nil {
  202. return err
  203. }
  204. err = s.gateWayGroupRep.EditGatewayGroup(ctx, &model.GatewayGroup{
  205. RuleId: gatewayGroupId,
  206. HostId: req.HostId,
  207. })
  208. return nil
  209. }
  210. func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  211. if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  212. HostId: req.HostId,
  213. Comment: req.Comment,
  214. }); err != nil {
  215. return err
  216. }
  217. return nil
  218. }
  219. func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  220. if err := s.globalLimitRepository.DeleteGlobalLimitByHostId(ctx, int64(req.HostId)); err != nil {
  221. return err
  222. }
  223. return nil
  224. }