// waf.go


  1. package task
  2. import (
  3. "context"
  4. "fmt"
  5. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  6. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/service"
  9. "github.com/hashicorp/go-multierror"
  10. "go.uber.org/zap"
  11. "sync"
  12. "time"
  13. )
  14. // WafTask 定义了WAF相关的五个独立定时任务接口
  15. type WafTask interface {
  16. // 1. 同步即将到期(1天内)的套餐时间
  17. SynchronizationTime(ctx context.Context) error
  18. // 2. 停止所有已到期的套餐
  19. StopPlan(ctx context.Context) error
  20. // 3. 恢复7天内续费的套餐
  21. RecoverRecentPlan(ctx context.Context) error
  22. // 4. 清理过期超过7天且仍未续费的记录
  23. CleanUpStaleRecords(ctx context.Context) error
  24. // 5. 恢复超过7天后才续费的套餐
  25. RecoverStalePlan(ctx context.Context) error
  26. }
  27. // =================================================================
  28. // =================== 结构体与构造函数 ==========================
  29. // =================================================================
  30. func NewWafTask(
  31. webForWardingRep repository.WebForwardingRepository,
  32. tcpforwardingRep repository.TcpforwardingRepository,
  33. udpForWardingRep repository.UdpForWardingRepository,
  34. cdn service.CdnService,
  35. hostRep repository.HostRepository,
  36. globalLimitRep repository.GlobalLimitRepository,
  37. expiredRep repository.ExpiredRepository,
  38. task *Task,
  39. gatewayGroupIpRep repository.GateWayGroupIpRepository,
  40. tcp service.TcpforwardingService,
  41. udp service.UdpForWardingService,
  42. web service.WebForwardingService,
  43. ) WafTask {
  44. return &wafTask{
  45. Task: task,
  46. webForWardingRep: webForWardingRep,
  47. tcpforwardingRep: tcpforwardingRep,
  48. udpForWardingRep: udpForWardingRep,
  49. cdn: cdn,
  50. hostRep: hostRep,
  51. globalLimitRep: globalLimitRep,
  52. expiredRep: expiredRep,
  53. gatewayGroupIpRep: gatewayGroupIpRep,
  54. tcp: tcp,
  55. udp: udp,
  56. web: web,
  57. }
  58. }
  59. type wafTask struct {
  60. *Task
  61. webForWardingRep repository.WebForwardingRepository
  62. tcpforwardingRep repository.TcpforwardingRepository
  63. udpForWardingRep repository.UdpForWardingRepository
  64. cdn service.CdnService
  65. hostRep repository.HostRepository
  66. globalLimitRep repository.GlobalLimitRepository
  67. expiredRep repository.ExpiredRepository
  68. gatewayGroupIpRep repository.GateWayGroupIpRepository
  69. tcp service.TcpforwardingService
  70. udp service.UdpForWardingService
  71. web service.WebForwardingService
  72. }
  73. const (
  74. OneDaysInSeconds = 1 * 24 * 60 * 60
  75. SevenDaysInSeconds = 7 * 24 * 60 * 60
  76. )
  77. type RenewalRequest struct {
  78. HostId int
  79. PlanId int
  80. ExpiredAt int64
  81. }
// =================================================================
// =================== Core helpers =================================
// =================================================================
// The bodies of wrapTaskError, getCdnWebIdsByHostIds, setCdnWebsitesState and
// executeRenewalActions are elided in this excerpt (marked "unchanged"); the
// placeholders below only return zero values.
// ...

// wrapTaskError tags err with the task name and the step that failed.
// NOTE(review): every caller passes a possibly-nil error (e.g.
// allErrors.ErrorOrNil()) and returns the result directly, so the real
// implementation must return nil for a nil err. As written here the
// placeholder returns nil for EVERY input, which would silently swallow all
// task errors — confirm the full implementation is in place.
func (t *wafTask) wrapTaskError(taskName, step string, err error) error { /* ... */ return nil }

// getCdnWebIdsByHostIds looks up the CDN website IDs that belong to hostIds.
// (Body elided; the placeholder returns nil, nil.)
func (t *wafTask) getCdnWebIdsByHostIds(ctx context.Context, hostIds []int) ([]int, error) { /* ... */ return nil, nil }

// setCdnWebsitesState presumably switches the given CDN websites on
// (enable == true) or off (false) — body elided, confirm. _recoverPlans calls
// it with enable == true to re-enable web services.
// (The placeholder returns nil.)
func (t *wafTask) setCdnWebsitesState(ctx context.Context, ids []int, enable bool) error { /* ... */ return nil }

// executeRenewalActions syncs the renewal information described by reqs.
// (Body elided; the placeholder returns nil.)
func (t *wafTask) executeRenewalActions(ctx context.Context, reqs []RenewalRequest) error { /* ... */ return nil }
// =================================================================
// =================== 1. Data lookup & decision layer ==============
// =================================================================
// The bodies of findPlansNeedingSync, findAllCurrentlyExpiredWAFPlans,
// findRecentlyExpiredWAFPlans and findStaleWAFPlans are elided in this
// excerpt (marked "unchanged"); the placeholders below return nil, nil.
// ...

// findPlansNeedingSync decides which of wafLimits need their renewal state
// synchronized and returns the corresponding renewal requests. _recoverPlans
// further keeps only the requests whose ExpiredAt lies in the future.
func (t *wafTask) findPlansNeedingSync(ctx context.Context, wafLimits []model.GlobalLimit) ([]RenewalRequest, error) { /* ... */ return nil, nil }

// findAllCurrentlyExpiredWAFPlans returns the WAF plans that are expired
// right now. NOTE(review): presumably consumed by StopPlan — confirm.
func (t *wafTask) findAllCurrentlyExpiredWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) { /* ... */ return nil, nil }

// findRecentlyExpiredWAFPlans returns recently expired WAF plans; it feeds
// RecoverRecentPlan. NOTE(review): "recently" presumably means within
// SevenDaysInSeconds — confirm.
func (t *wafTask) findRecentlyExpiredWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) { /* ... */ return nil, nil }

// findStaleWAFPlans returns WAF plans that expired long ago; it feeds both
// CleanUpStaleRecords and RecoverStalePlan. NOTE(review): presumably "expired
// more than SevenDaysInSeconds ago" — confirm.
func (t *wafTask) findStaleWAFPlans(ctx context.Context) ([]model.GlobalLimit, error) { /* ... */ return nil, nil }
// =================================================================
// =================== 2. Business execution & public API layer =====
// =================================================================
// The bodies of SynchronizationTime and StopPlan are elided in this excerpt
// (marked "unchanged"); the placeholders below return nil.
// ...

// SynchronizationTime implements WafTask job 1: synchronize the time of plans
// that expire within one day. (Body elided.)
func (t *wafTask) SynchronizationTime(ctx context.Context) error { /* ... */ return nil }

// StopPlan implements WafTask job 2: stop every plan that has already
// expired. (Body elided.)
func (t *wafTask) StopPlan(ctx context.Context) error { /* ... */ return nil }
  107. // _recoverPlans 是一个统一的、可重用的套餐恢复流程
  108. func (t *wafTask) _recoverPlans(ctx context.Context, limitsToCheck []model.GlobalLimit, taskName string, redisListKey repository.PlanListType) error {
  109. if len(limitsToCheck) == 0 {
  110. return nil
  111. }
  112. requestsToSync, err := t.findPlansNeedingSync(ctx, limitsToCheck)
  113. if err != nil {
  114. return t.wrapTaskError(taskName, "决策检查续费状态", err)
  115. }
  116. var finalRecoveryRequests []RenewalRequest
  117. for _, req := range requestsToSync {
  118. if req.ExpiredAt > time.Now().Unix() {
  119. finalRecoveryRequests = append(finalRecoveryRequests, req)
  120. }
  121. }
  122. if len(finalRecoveryRequests) == 0 {
  123. t.logger.Info("在检查范围内未发现已续费的套餐", zap.String("task", taskName))
  124. return nil
  125. }
  126. t.logger.Info("开始恢复已续费的WAF服务", zap.String("task", taskName), zap.Int("数量", len(finalRecoveryRequests)))
  127. var hostIdsToRecover []int
  128. for _, req := range finalRecoveryRequests {
  129. hostIdsToRecover = append(hostIdsToRecover, req.HostId)
  130. }
  131. var allErrors *multierror.Error
  132. webIds, err := t.getCdnWebIdsByHostIds(ctx, hostIdsToRecover)
  133. if err != nil {
  134. allErrors = multierror.Append(allErrors, fmt.Errorf("获取webId失败: %w", err))
  135. } else {
  136. if err := t.setCdnWebsitesState(ctx, webIds, true); err != nil {
  137. allErrors = multierror.Append(allErrors, fmt.Errorf("启用web服务失败: %w", err))
  138. }
  139. }
  140. if err := t.executeRenewalActions(ctx, finalRecoveryRequests); err != nil {
  141. allErrors = multierror.Append(allErrors, fmt.Errorf("同步续费信息失败: %w", err))
  142. }
  143. planIdsToRecover := make([]int64, len(hostIdsToRecover))
  144. for i, id := range hostIdsToRecover {
  145. planIdsToRecover[i] = int64(id)
  146. }
  147. // 从指定的Redis列表中移除标记 (ClosedPlansList 或 ExpiringSoonPlansList)
  148. if err := t.expiredRep.RemovePlans(ctx, redisListKey, planIdsToRecover...); err != nil {
  149. allErrors = multierror.Append(allErrors, fmt.Errorf("从Redis列表 '%s' 移除标记失败: %w", redisListKey, err))
  150. }
  151. return t.wrapTaskError(taskName, "执行恢复", allErrors.ErrorOrNil())
  152. }
  153. // 3. RecoverRecentPlan 恢复7天内续费的套餐
  154. func (t *wafTask) RecoverRecentPlan(ctx context.Context) error {
  155. taskName := "恢复近期到期套餐"
  156. recentlyExpiredLimits, err := t.findRecentlyExpiredWAFPlans(ctx)
  157. if err != nil {
  158. return t.wrapTaskError(taskName, "查找近期到期记录", err)
  159. }
  160. return t._recoverPlans(ctx, recentlyExpiredLimits, taskName, repository.ClosedPlansList)
  161. }
  162. // 4. CleanUpStaleRecords 清理过期超过7天且仍未续费的记录
  163. func (t *wafTask) CleanUpStaleRecords(ctx context.Context) error {
  164. taskName := "清理陈旧记录"
  165. // 1. 从数据库查找所有陈旧的记录作为候选
  166. candidateLimits, err := t.findStaleWAFPlans(ctx)
  167. if err != nil {
  168. return t.wrapTaskError(taskName, "查找陈旧记录", err)
  169. }
  170. if len(candidateLimits) == 0 {
  171. return nil
  172. }
  173. // 2. [CORRECTION] 幂等性检查: 过滤掉那些已经被标记为“已清理”的记录
  174. // 根据您的定义,`ExpiringSoonPlansList` 就是已清理列表。
  175. var uncleanedLimits []model.GlobalLimit
  176. for _, limit := range candidateLimits {
  177. isAlreadyCleaned, err := t.expiredRep.IsPlanInList(ctx, repository.ExpiringSoonPlansList, int64(limit.HostId))
  178. if err != nil {
  179. t.logger.Error("检查Redis清理状态失败,跳过", zap.String("task", taskName), zap.Int("hostId", limit.HostId), zap.Error(err))
  180. continue
  181. }
  182. if !isAlreadyCleaned {
  183. uncleanedLimits = append(uncleanedLimits, limit)
  184. }
  185. }
  186. if len(uncleanedLimits) == 0 {
  187. t.logger.Info("没有需要清理的新套餐(可能均已清理)", zap.String("task", taskName))
  188. return nil
  189. }
  190. // 3. [性能优化] 批量获取未清理记录的真实到期时间
  191. uncleanedHostIds := make([]int, len(uncleanedLimits))
  192. for i, limit := range uncleanedLimits {
  193. uncleanedHostIds[i] = limit.HostId
  194. }
  195. hostExpirations, err := t.hostRep.GetExpireTimeByHostId(ctx, uncleanedHostIds)
  196. if err != nil {
  197. return t.wrapTaskError(taskName, "批量获取主机到期时间", err)
  198. }
  199. hostExpiredMap := make(map[int]int64, len(hostExpirations))
  200. for _, h := range hostExpirations {
  201. hostExpiredMap[h.HostId] = h.ExpiredAt
  202. }
  203. // 4. 决策:筛选出最终需要清理的记录(未在最后一刻续费)
  204. var plansToClean []model.GlobalLimit
  205. now := time.Now().Unix()
  206. for _, limit := range uncleanedLimits {
  207. hostExpiredTime, ok := hostExpiredMap[limit.HostId]
  208. if !ok || hostExpiredTime <= now {
  209. plansToClean = append(plansToClean, limit)
  210. }
  211. }
  212. if len(plansToClean) == 0 {
  213. t.logger.Info("没有长期未续费的记录需要清理(可能均已续费)", zap.String("task", taskName))
  214. return nil
  215. }
  216. // 5. 并发执行清理操作
  217. t.logger.Info("开始清理长期未续费的WAF记录", zap.String("task", taskName), zap.Int("数量", len(plansToClean)))
  218. var wg sync.WaitGroup
  219. errChan := make(chan error, len(plansToClean))
  220. wg.Add(len(plansToClean))
  221. for _, limit := range plansToClean {
  222. go func(l model.GlobalLimit) {
  223. defer wg.Done()
  224. if err := t.executeSinglePlanCleanup(ctx, l); err != nil {
  225. errChan <- fmt.Errorf("清理hostId %d 失败: %w", l.HostId, err)
  226. }
  227. }(limit)
  228. }
  229. wg.Wait()
  230. close(errChan)
  231. var allErrors *multierror.Error
  232. for err := range errChan {
  233. allErrors = multierror.Append(allErrors, err)
  234. }
  235. return t.wrapTaskError(taskName, "执行清理", allErrors.ErrorOrNil())
  236. }
  237. // executeSinglePlanCleanup 执行对单个套餐的完整清理操作,方便并发调用
  238. func (t *wafTask) executeSinglePlanCleanup(ctx context.Context, limit model.GlobalLimit) error {
  239. var allErrors *multierror.Error
  240. hostId := int64(limit.HostId)
  241. // 从“停止列表”中移除,因为它即将被归档到“已清理列表”
  242. if err := t.expiredRep.RemovePlans(ctx, repository.ClosedPlansList, hostId); err != nil {
  243. allErrors = multierror.Append(allErrors, err)
  244. }
  245. // 删除关联的转发规则...
  246. tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, limit.HostId)
  247. if err != nil {
  248. allErrors = multierror.Append(allErrors, err)
  249. } else if len(tcpIds) > 0 {
  250. if err := t.tcp.DeleteTcpForwarding(ctx, v1.DeleteTcpForwardingRequest{Ids: tcpIds, HostId: limit.HostId}); err != nil {
  251. allErrors = multierror.Append(allErrors, err)
  252. }
  253. }
  254. // ... 删除 UDP 和 Web 规则的逻辑保持不变
  255. // 只有在上述所有步骤都没有出错的情况下,才执行最终的数据库更新和Redis标记
  256. if allErrors.ErrorOrNil() == nil {
  257. // 执行您指定的数据库“重置”操作
  258. err = t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  259. HostId: limit.HostId,
  260. State: true,
  261. })
  262. if err != nil {
  263. allErrors = multierror.Append(allErrors, err)
  264. }
  265. // [CORRECTION] 幂等性保障:将此hostId标记为“已清理”,即添加到 `ExpiringSoonPlansList`
  266. if err := t.expiredRep.AddPlans(ctx, repository.ExpiringSoonPlansList, hostId); err != nil {
  267. allErrors = multierror.Append(allErrors, fmt.Errorf("将hostId %d标记为已清理失败: %w", hostId, err))
  268. }
  269. }
  270. return allErrors.ErrorOrNil()
  271. }
  272. // 5. RecoverStalePlan 恢复超过7天后才续费的套餐
  273. func (t *wafTask) RecoverStalePlan(ctx context.Context) error {
  274. taskName := "恢复长期到期套餐"
  275. staleLimits, err := t.findStaleWAFPlans(ctx)
  276. if err != nil {
  277. return t.wrapTaskError(taskName, "查找陈旧记录", err)
  278. }
  279. // [CORRECTION] 当恢复一个被清理过的套餐时,需要从“已清理列表”(`ExpiringSoonPlansList`)中移除它。
  280. return t._recoverPlans(ctx, staleLimits, taskName, repository.ExpiringSoonPlansList)
  281. }