waf.go 5.5 KB


  1. package task
  import (
  	"context"
  	"fmt"
  	"sync"
  	"time"

  	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  	"github.com/go-nunu/nunu-layout-advanced/internal/model"
  	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
  	"github.com/go-nunu/nunu-layout-advanced/internal/service"
  	"github.com/hashicorp/go-multierror"
  )
  // WafTask is the exported surface of the WAF background task returned by
  // NewWafTask.
  //
  // The method set must not be empty. With no methods, callers outside this
  // package (for example a scheduler) could not invoke anything on the value
  // NewWafTask returns.
  type WafTask interface {
  	// CheckExpiredTask runs the task's expiry check.
  	CheckExpiredTask(ctx context.Context) error
  }
  14. func NewWafTask (
  15. webForWardingRep repository.WebForwardingRepository,
  16. tcpforwardingRep repository.TcpforwardingRepository,
  17. udpForWardingRep repository.UdpForWardingRepository,
  18. cdn service.CdnService,
  19. hostRep repository.HostRepository,
  20. globalLimitRep repository.GlobalLimitRepository,
  21. task *Task,
  22. ) WafTask{
  23. return &wafTask{
  24. Task: task,
  25. webForWardingRep: webForWardingRep,
  26. tcpforwardingRep: tcpforwardingRep,
  27. udpForWardingRep: udpForWardingRep,
  28. cdn: cdn,
  29. hostRep: hostRep,
  30. globalLimitRep: globalLimitRep,
  31. }
  32. }
  // wafTask is the concrete implementation returned by NewWafTask. It embeds the
  // shared *Task and holds the collaborators its methods call.
  //
  // Keep the field order stable: positional composite literals, if any exist
  // elsewhere in the package, depend on it.
  type wafTask struct {
  	*Task
  	// Queried for web forwarding ids (GetWebForwardingWafWebAllIds).
  	webForWardingRep repository.WebForwardingRepository
  	// Queried for TCP forwarding ids (GetTcpForwardingAllIdsByID).
  	tcpforwardingRep repository.TcpforwardingRepository
  	// Queried for UDP forwarding ids (GetUdpForwardingWafUdpAllIds).
  	udpForWardingRep repository.UdpForWardingRepository
  	// CDN service used to toggle web servers (EditWebIsOn) and renew plans (RenewPlan).
  	cdn service.CdnService
  	// Queried for host expiry data (GetAlmostExpired).
  	hostRep repository.HostRepository
  	// Reads almost-expired global limits and updates them by host id.
  	globalLimitRep repository.GlobalLimitRepository
  }
  42. func (t wafTask) CheckExpiredTask(ctx context.Context) error {
  43. return nil
  44. }
  45. // 获取cdn web id
  46. func (t wafTask) GetCdnWebId(ctx context.Context,hostId int) ([]int, error) {
  47. tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, hostId)
  48. if err != nil {
  49. return nil, err
  50. }
  51. udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, hostId)
  52. if err != nil {
  53. return nil, err
  54. }
  55. webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, hostId)
  56. if err != nil {
  57. return nil, err
  58. }
  59. var ids []int
  60. ids = append(ids, tcpIds...)
  61. ids = append(ids, udpIds...)
  62. ids = append(ids, webIds...)
  63. return ids, nil
  64. }
  65. // 启用/禁用 网站
  66. func (t wafTask) BanServer(ctx context.Context, ids []int, isBan bool) error {
  67. var wg sync.WaitGroup
  68. errChan := make(chan error, len(ids))
  69. // 修正1:为每个 goroutine 增加 WaitGroup 的计数
  70. wg.Add(len(ids))
  71. for _, id := range ids {
  72. go func(id int) {
  73. // 修正2:确保每个 goroutine 在退出时都调用 Done()
  74. defer wg.Done()
  75. err := t.cdn.EditWebIsOn(ctx, int64(id), isBan)
  76. if err != nil {
  77. errChan <- err
  78. // 这里不需要 return,因为 defer wg.Done() 会在函数退出时执行
  79. }
  80. }(id)
  81. }
  82. // 现在 wg.Wait() 会正确地阻塞,直到所有 goroutine 都调用了 Done()
  83. wg.Wait()
  84. // 在所有 goroutine 都结束后,安全地关闭 channel
  85. close(errChan)
  86. var result error
  87. for err := range errChan {
  88. result = multierror.Append(result, err) // 将多个 error 对象合并成一个单一的 error 对象
  89. }
  90. // 修正3:返回收集到的错误,而不是 nil
  91. return result
  92. }
  93. // 获取指定到期时间
  94. func (t wafTask) GetAlmostExpiring(ctx context.Context,hostIds []int,addTime int64) ([]v1.GetAlmostExpireHostResponse,error) {
  95. // 3 天
  96. res, err := t.hostRep.GetAlmostExpired(ctx, hostIds, addTime)
  97. if err != nil {
  98. return nil,err
  99. }
  100. return res, nil
  101. }
  102. // 获取全局到期时间
  103. func (t wafTask) GetGlobalAlmostExpiring(ctx context.Context,addTime int64) ([]model.GlobalLimit,error) {
  104. res, err := t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, addTime)
  105. if err != nil {
  106. return nil, err
  107. }
  108. return res, nil
  109. }
  110. // 获取cdn web id
  111. func (t wafTask) GetGlobalAllHostId(ctx context.Context,addTime int64) (map[int]int64, error) {
  112. globalData, err := t.GetGlobalAlmostExpiring(ctx,addTime)
  113. if err != nil {
  114. return nil, err
  115. }
  116. var hostIds []int
  117. for _, v := range globalData {
  118. hostIds = append(hostIds, v.HostId)
  119. }
  120. globalDataMap := make(map[int]int64, len(globalData))
  121. planMap := make(map[int]int64, len(globalData))
  122. for _, v := range globalData {
  123. globalDataMap[v.HostId] = v.ExpiredAt
  124. planMap[v.HostId] = int64(v.RuleId)
  125. }
  126. hostData,err := t.GetAlmostExpiring(ctx,hostIds,addTime)
  127. if err != nil {
  128. return nil, err
  129. }
  130. hostDataMap := make(map[int]int64, len(hostData))
  131. for _, v := range hostData {
  132. hostDataMap[v.HostId] = v.ExpiredAt
  133. }
  134. editMap := make(map[int]int64)
  135. for k, v := range globalDataMap {
  136. if hostDataMap[k] != v {
  137. editMap[k] = v
  138. }
  139. }
  140. planExpireMap := make(map[int]int64)
  141. for k, v := range planMap {
  142. if _, ok := editMap[k]; ok {
  143. planExpireMap[k] = v
  144. }
  145. }
  146. return editMap, nil
  147. }
  148. // 修改全局续费
  149. func (t wafTask) EditGlobalExpired(ctx context.Context,req []struct{
  150. hostId int
  151. expiredAt int64
  152. },state bool) error {
  153. for _, v := range req {
  154. err := t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  155. HostId: v.hostId,
  156. ExpiredAt: v.expiredAt,
  157. State: state,
  158. })
  159. if err != nil {
  160. return err
  161. }
  162. }
  163. return nil
  164. }
  165. // 续费套餐
  166. func (t wafTask) EnablePlan(ctx context.Context,req []struct{
  167. planId int
  168. expiredAt int64
  169. }) error {
  170. for _, v := range req {
  171. err := t.cdn.RenewPlan(ctx, v1.RenewalPlan{
  172. UserPlanId: int64(v.planId),
  173. IsFree: true,
  174. DayTo: time.Unix(v.expiredAt,0).Format("2006-01-02"),
  175. Period: "monthly",
  176. CountPeriod: 1,
  177. PeriodDayTo: time.Unix(v.expiredAt,0).Format("2006-01-02"),
  178. })
  179. if err != nil {
  180. return err
  181. }
  182. }
  183. return nil
  184. }
  185. // 续费操作
  186. func (t wafTask) EditExpired(ctx context.Context,req []struct {
  187. hostId int
  188. expiredAt int64
  189. planId int
  190. }) error {
  191. var sendData []struct {
  192. hostId int
  193. expiredAt int64
  194. }
  195. for _, v := range req {
  196. sendData = append(sendData, struct {
  197. hostId int
  198. expiredAt int64
  199. }{
  200. hostId: v.hostId,
  201. expiredAt: v.expiredAt,
  202. })
  203. }
  204. if err := t.EditGlobalExpired(ctx,sendData,true); err != nil {
  205. return err
  206. }
  207. return nil
  208. }