globallimit.go 7.3 KB


  1. package service
  2. import (
  3. "context"
  4. "fmt"
  5. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  6. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  8. "github.com/spf13/cast"
  9. "github.com/spf13/viper"
  10. "strconv"
  11. "sync"
  12. "github.com/sourcegraph/conc"
  13. )
// GlobalLimitService manages per-host global limits: reading a stored limit,
// provisioning a new one, editing it and deleting it.
type GlobalLimitService interface {
	// GetGlobalLimit returns the stored global limit with the given id.
	GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error)
	// AddGlobalLimit provisions a global limit for req.HostId.
	AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
	// EditGlobalLimit updates the stored global limit identified by req.HostId.
	EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
	// DeleteGlobalLimit deletes the stored global limit identified by req.HostId.
	DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
}
  20. func NewGlobalLimitService(
  21. service *Service,
  22. globalLimitRepository repository.GlobalLimitRepository,
  23. duedate DuedateService,
  24. crawler CrawlerService,
  25. conf *viper.Viper,
  26. required RequiredService,
  27. parser ParserService,
  28. host HostService,
  29. tcpLimit TcpLimitService,
  30. udpLimit UdpLimitService,
  31. webLimit WebLimitService,
  32. gateWayGroup GatewayGroupService,
  33. ) GlobalLimitService {
  34. return &globalLimitService{
  35. Service: service,
  36. globalLimitRepository: globalLimitRepository,
  37. duedate: duedate,
  38. crawler: crawler,
  39. Url: conf.GetString("crawler.Url"),
  40. required: required,
  41. parser: parser,
  42. host: host,
  43. tcpLimit: tcpLimit,
  44. udpLimit: udpLimit,
  45. webLimit: webLimit,
  46. gateWayGroup: gateWayGroup,
  47. }
  48. }
// globalLimitService is the GlobalLimitService implementation returned by
// NewGlobalLimitService. It combines the global-limit repository with the
// helper services it is constructed with.
type globalLimitService struct {
	*Service
	// globalLimitRepository stores and looks up GlobalLimit records.
	globalLimitRepository repository.GlobalLimitRepository
	// duedate supplies the expiry date for new limits (NextDueDate).
	duedate DuedateService
	// NOTE(review): crawler and Url are not referenced by the methods visible
	// in this file section; confirm whether they are still needed.
	crawler CrawlerService
	// Url holds the "crawler.Url" configuration value read at construction.
	Url string
	// required submits admin forms (SendForm).
	required RequiredService
	// parser extracts rule IDs and alert text from admin responses.
	parser ParserService
	// host provides per-host limit configuration (GetGlobalLimitConfig).
	host HostService
	// tcpLimit, udpLimit and webLimit create the per-protocol limits.
	tcpLimit TcpLimitService
	udpLimit UdpLimitService
	webLimit WebLimitService
	// gateWayGroup creates the gateway group associated with a global limit.
	gateWayGroup GatewayGroupService
}
  63. func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
  64. isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
  65. if err != nil {
  66. return v1.GlobalLimitRequireResponse{}, err
  67. }
  68. if isExist {
  69. return v1.GlobalLimitRequireResponse{}, fmt.Errorf("配置限制已存在")
  70. }
  71. res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
  72. if err != nil {
  73. return v1.GlobalLimitRequireResponse{}, err
  74. }
  75. configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
  76. if err != nil {
  77. return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err)
  78. }
  79. res.Bps = configCount.Bps
  80. res.MaxBytesMonth = configCount.MaxBytesMonth
  81. res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + strconv.Itoa(req.HostId) + "_" + req.Domain
  82. return res, nil
  83. }
  84. func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) {
  85. return s.globalLimitRepository.GetGlobalLimit(ctx, id)
  86. }
  87. func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  88. require, err := s.GlobalLimitRequire(ctx, req)
  89. if err != nil {
  90. return err
  91. }
  92. formData := map[string]interface{}{
  93. "tag": require.GlobalLimitName,
  94. "bps": require.Bps,
  95. "max_bytes_month": require.MaxBytesMonth,
  96. "expired_at": require.ExpiredAt,
  97. }
  98. respBody, err := s.required.SendForm(ctx, "admin/info/waf_common_limit/new", "admin/new/waf_common_limit", formData)
  99. if err != nil {
  100. return err
  101. }
  102. ruleIdBase, err := s.parser.GetRuleIdByColumnName(ctx, respBody, require.GlobalLimitName)
  103. if err != nil {
  104. return err
  105. }
  106. if ruleIdBase == "" {
  107. res, err := s.parser.ParseAlert(string(respBody))
  108. if err != nil {
  109. return err
  110. }
  111. return fmt.Errorf(res)
  112. }
  113. ruleId, err := cast.ToIntE(ruleIdBase)
  114. if err != nil {
  115. return err
  116. }
  117. // 使用conc库并发执行API调用
  118. var tcpLimitRuleId, udpLimitRuleId, webLimitRuleId, gateWayGroupId int
  119. var mu sync.Mutex // 用于保护共享变量
  120. // 为每个并发调用创建独立的请求参数(深拷贝)
  121. // 避免共享同一个指针可能导致的数据竞争
  122. // 创建网关组请求参数
  123. gateWayReq := v1.AddGateWayGroupRequest{
  124. Name: require.GlobalLimitName,
  125. Comment: req.Comment,
  126. }
  127. // 创建一个WaitGroup来协调多个并发任务
  128. wg := conc.NewWaitGroup()
  129. // 启动tcpLimit调用 - 使用独立的请求参数副本
  130. wg.Go(func() {
  131. // 为该goroutine创建独立的请求参数副本
  132. tcpLimitReq := &v1.GeneralLimitRequireRequest{
  133. Tag: require.GlobalLimitName,
  134. HostId: req.HostId,
  135. RuleId: ruleId,
  136. Uid: req.Uid,
  137. }
  138. result, e := s.tcpLimit.AddTcpLimit(ctx, tcpLimitReq)
  139. if e != nil {
  140. // 只在修改共享的错误变量时加锁
  141. mu.Lock()
  142. err = e
  143. mu.Unlock()
  144. } else {
  145. // 不需要加锁,因为tcpLimitRuleId只被这一个goroutine修改
  146. tcpLimitRuleId = result
  147. }
  148. })
  149. // 启动udpLimit调用 - 使用独立的请求参数副本
  150. wg.Go(func() {
  151. // 为该goroutine创建独立的请求参数副本
  152. udpLimitReq := &v1.GeneralLimitRequireRequest{
  153. Tag: require.GlobalLimitName,
  154. HostId: req.HostId,
  155. RuleId: ruleId,
  156. Uid: req.Uid,
  157. }
  158. result, e := s.udpLimit.AddUdpLimit(ctx, udpLimitReq)
  159. if e != nil {
  160. // 只在修改共享的错误变量时加锁
  161. mu.Lock()
  162. err = e
  163. mu.Unlock()
  164. } else {
  165. // 不需要加锁,因为udpLimitRuleId只被这一个goroutine修改
  166. udpLimitRuleId = result
  167. }
  168. })
  169. // 启动webLimit调用 - 使用独立的请求参数副本
  170. wg.Go(func() {
  171. // 为该goroutine创建独立的请求参数副本
  172. webLimitReq := &v1.GeneralLimitRequireRequest{
  173. Tag: require.GlobalLimitName,
  174. HostId: req.HostId,
  175. RuleId: ruleId,
  176. Uid: req.Uid,
  177. }
  178. result, e := s.webLimit.AddWebLimit(ctx, webLimitReq)
  179. if e != nil {
  180. // 只在修改共享的错误变量时加锁
  181. mu.Lock()
  182. err = e
  183. mu.Unlock()
  184. } else {
  185. // 不需要加锁,因为webLimitRuleId只被这一个goroutine修改
  186. webLimitRuleId = result
  187. }
  188. })
  189. // 启动gatewayGroup调用
  190. wg.Go(func() {
  191. result, e := s.gateWayGroup.AddGatewayGroup(ctx, gateWayReq)
  192. if e != nil {
  193. // 只在修改共享的错误变量时加锁
  194. mu.Lock()
  195. err = e
  196. mu.Unlock()
  197. } else {
  198. // 不需要加锁,因为gateWayGroupId只被这一个goroutine修改
  199. gateWayGroupId = result
  200. }
  201. })
  202. // 等待所有调用完成
  203. wg.Wait()
  204. // 检查是否有错误发生
  205. if err != nil {
  206. return err
  207. }
  208. err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
  209. HostId: req.HostId,
  210. RuleId: cast.ToInt(ruleId),
  211. GlobalLimitName: require.GlobalLimitName,
  212. Comment: req.Comment,
  213. TcpLimitRuleId: tcpLimitRuleId,
  214. UdpLimitRuleId: udpLimitRuleId,
  215. WebLimitRuleId: webLimitRuleId,
  216. GatewayGroupId: gateWayGroupId,
  217. })
  218. if err != nil {
  219. return err
  220. }
  221. return nil
  222. }
  223. func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  224. if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  225. HostId: req.HostId,
  226. Comment: req.Comment,
  227. }); err != nil {
  228. return err
  229. }
  230. return nil
  231. }
  232. func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  233. if err := s.globalLimitRepository.DeleteGlobalLimitByHostId(ctx, int64(req.HostId)); err != nil {
  234. return err
  235. }
  236. return nil
  237. }