waf.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. package task
import (
	"context"
	"fmt"
	"sync"

	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
	"github.com/go-nunu/nunu-layout-advanced/internal/service"
	"github.com/hashicorp/go-multierror"
)
  10. type WafTask interface {
  11. }
  12. func NewWafTask (
  13. webForWardingRep repository.WebForwardingRepository,
  14. tcpforwardingRep repository.TcpforwardingRepository,
  15. udpForWardingRep repository.UdpForWardingRepository,
  16. cdn service.CdnService,
  17. hostRep repository.HostRepository,
  18. task *Task,
  19. ) WafTask{
  20. return &wafTask{
  21. Task: task,
  22. webForWardingRep: webForWardingRep,
  23. tcpforwardingRep: tcpforwardingRep,
  24. udpForWardingRep: udpForWardingRep,
  25. cdn: cdn,
  26. hostRep: hostRep,
  27. }
  28. }
  29. type wafTask struct {
  30. *Task
  31. webForWardingRep repository.WebForwardingRepository
  32. tcpforwardingRep repository.TcpforwardingRepository
  33. udpForWardingRep repository.UdpForWardingRepository
  34. cdn service.CdnService
  35. hostRep repository.HostRepository
  36. }
  37. func (t wafTask) CheckExpiredTask(ctx context.Context) error {
  38. return nil
  39. }
  40. // 获取cdn web id
  41. func (t wafTask) GetCdnWebId(ctx context.Context,hostId int) ([]int, error) {
  42. tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, hostId)
  43. if err != nil {
  44. return nil, err
  45. }
  46. udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, hostId)
  47. if err != nil {
  48. return nil, err
  49. }
  50. webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, hostId)
  51. if err != nil {
  52. return nil, err
  53. }
  54. var ids []int
  55. ids = append(ids, tcpIds...)
  56. ids = append(ids, udpIds...)
  57. ids = append(ids, webIds...)
  58. return ids, nil
  59. }
  60. // 禁用网站
  61. func (t wafTask) BanServer(ctx context.Context, ids []int, isBan bool) error {
  62. var wg sync.WaitGroup
  63. errChan := make(chan error, len(ids))
  64. // 修正1:为每个 goroutine 增加 WaitGroup 的计数
  65. wg.Add(len(ids))
  66. for _, id := range ids {
  67. go func(id int) {
  68. // 修正2:确保每个 goroutine 在退出时都调用 Done()
  69. defer wg.Done()
  70. err := t.cdn.EditWebIsOn(ctx, int64(id), isBan)
  71. if err != nil {
  72. errChan <- err
  73. // 这里不需要 return,因为 defer wg.Done() 会在函数退出时执行
  74. }
  75. }(id)
  76. }
  77. // 现在 wg.Wait() 会正确地阻塞,直到所有 goroutine 都调用了 Done()
  78. wg.Wait()
  79. // 在所有 goroutine 都结束后,安全地关闭 channel
  80. close(errChan)
  81. var result error
  82. for err := range errChan {
  83. result = multierror.Append(result, err) // 将多个 error 对象合并成一个单一的 error 对象
  84. }
  85. // 修正3:返回收集到的错误,而不是 nil
  86. return result
  87. }
  88. // 获取到期时间
  89. func (t wafTask) GetAlmostExpiring(ctx context.Context,hostIds []int) ([]v1.GetAlmostExpireHostResponse,error) {
  90. // 3 天
  91. res, err := t.hostRep.GetAlmostExpired(ctx, hostIds, 259200)
  92. if err != nil {
  93. return nil,err
  94. }
  95. return res, nil
  96. }