waf.go 15 KB


  1. package task
  2. import (
  3. "context"
  4. "fmt"
  5. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  6. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/service"
  9. "github.com/hashicorp/go-multierror"
  10. "go.uber.org/zap"
  11. "sync"
  12. "time"
  13. )
  // WafTask groups the five independent scheduled jobs that keep WAF
// (global limit) records in step with the hosts they belong to.
type WafTask interface {
	// SynchronizationTime syncs the expiration time of plans that expire
	// within the next day.
	SynchronizationTime(context.Context) error
	// StopPlan stops every plan that has already expired.
	StopPlan(context.Context) error
	// RecoverRecentPlan restores plans that were renewed within seven days
	// of expiring.
	RecoverRecentPlan(context.Context) error
	// CleanUpStaleRecords cleans up records that expired more than seven
	// days ago and are still not renewed.
	CleanUpStaleRecords(context.Context) error
	// RecoverStalePlan restores plans that were renewed only after being
	// expired for more than seven days.
	RecoverStalePlan(context.Context) error
}
  27. func NewWafTask(
  28. webForWardingRep repository.WebForwardingRepository,
  29. tcpforwardingRep repository.TcpforwardingRepository,
  30. udpForWardingRep repository.UdpForWardingRepository,
  31. cdn service.CdnService,
  32. hostRep repository.HostRepository,
  33. globalLimitRep repository.GlobalLimitRepository,
  34. expiredRep repository.ExpiredRepository,
  35. task *Task,
  36. gatewayGroupIpRep repository.GateWayGroupIpRepository,
  37. ) WafTask {
  38. return &wafTask{
  39. Task: task,
  40. webForWardingRep: webForWardingRep,
  41. tcpforwardingRep: tcpforwardingRep,
  42. udpForWardingRep: udpForWardingRep,
  43. cdn: cdn,
  44. hostRep: hostRep,
  45. globalLimitRep: globalLimitRep,
  46. expiredRep: expiredRep,
  47. gatewayGroupIpRep: gatewayGroupIpRep,
  48. }
  49. }
  50. type wafTask struct {
  51. *Task
  52. webForWardingRep repository.WebForwardingRepository
  53. tcpforwardingRep repository.TcpforwardingRepository
  54. udpForWardingRep repository.UdpForWardingRepository
  55. cdn service.CdnService
  56. hostRep repository.HostRepository
  57. globalLimitRep repository.GlobalLimitRepository
  58. expiredRep repository.ExpiredRepository
  59. gatewayGroupIpRep repository.GateWayGroupIpRepository
  60. }
  // Durations in seconds, used as offsets for the expiration-window queries.
const (
	// OneDaysInSeconds is the number of seconds in one day (86400).
	OneDaysInSeconds = 24 * 60 * 60
	// SevenDaysInSeconds is the number of seconds in seven days (604800).
	SevenDaysInSeconds = 7 * OneDaysInSeconds
)
  // RenewalRequest describes a single renewal to apply:
//   - HostId: the host whose WAF (global limit) record is updated;
//   - PlanId: the CDN user-plan ID passed to RenewPlan;
//   - ExpiredAt: the new expiration as a Unix timestamp in seconds.
type RenewalRequest struct {
	HostId, PlanId int
	ExpiredAt      int64
}
  73. // =================================================================
  74. // =================== 原始辅助函数 (Helpers) =====================
  75. // =================================================================
  76. // BanServer 启用/禁用 网站 (并发执行)
  77. func (t wafTask) BanServer(ctx context.Context, ids []int, isBan bool) error {
  78. if len(ids) == 0 { return nil }
  79. var wg sync.WaitGroup
  80. errChan := make(chan error, len(ids))
  81. wg.Add(len(ids))
  82. for _, id := range ids {
  83. go func(id int) {
  84. defer wg.Done()
  85. if err := t.cdn.EditWebIsOn(ctx, int64(id), isBan); err != nil {
  86. errChan <- err
  87. }
  88. }(id)
  89. }
  90. wg.Wait()
  91. close(errChan)
  92. var result error
  93. for err := range errChan {
  94. result = multierror.Append(result, err)
  95. }
  96. return result
  97. }
  98. // EditExpired 统一的续费操作入口
  99. func (t wafTask) EditExpired(ctx context.Context, reqs []RenewalRequest) error {
  100. if len(reqs) == 0 { return nil }
  101. var globalLimitUpdates []struct { hostId int; expiredAt int64 }
  102. var planRenewals []struct { planId int; expiredAt int64 }
  103. for _, req := range reqs {
  104. globalLimitUpdates = append(globalLimitUpdates, struct{ hostId int; expiredAt int64 }{req.HostId, req.ExpiredAt})
  105. planRenewals = append(planRenewals, struct{ planId int; expiredAt int64 }{req.PlanId, req.ExpiredAt})
  106. }
  107. var result *multierror.Error
  108. if err := t.editGlobalLimitState(ctx, globalLimitUpdates, true); err != nil {
  109. result = multierror.Append(result, err)
  110. }
  111. if err := t.renewCdnPlan(ctx, planRenewals); err != nil {
  112. result = multierror.Append(result, err)
  113. }
  114. return result.ErrorOrNil()
  115. }
  116. // editGlobalLimitState 内部函数,用于更新数据库中的状态和时间
  117. func (t wafTask) editGlobalLimitState(ctx context.Context, req []struct { hostId int; expiredAt int64 }, state bool) error {
  118. var result *multierror.Error
  119. for _, v := range req {
  120. err := t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{HostId: v.hostId, ExpiredAt: v.expiredAt, State: state})
  121. if err != nil { result = multierror.Append(result, err) }
  122. }
  123. return result.ErrorOrNil()
  124. }
  125. // renewCdnPlan 内部函数,用于调用CDN服务进行续费
  126. func (t wafTask) renewCdnPlan(ctx context.Context, req []struct { planId int; expiredAt int64 }) error {
  127. var result *multierror.Error
  128. for _, v := range req {
  129. err := t.cdn.RenewPlan(ctx, v1.RenewalPlan{
  130. UserPlanId: int64(v.planId), IsFree: true, DayTo: time.Unix(v.expiredAt, 0).Format("2006-01-02"),
  131. Period: "monthly", CountPeriod: 1, PeriodDayTo: time.Unix(v.expiredAt, 0).Format("2006-01-02"),
  132. })
  133. if err != nil { result = multierror.Append(result, err) }
  134. }
  135. return result.ErrorOrNil()
  136. }
  137. // =================================================================
  138. // =================== 1. 数据查找层 (Finders) =====================
  139. // =================================================================
  140. // findMismatchedExpirations 检查 WAF 和 Host 的到期时间差异。这是决策的核心。
  141. func (t *wafTask) findMismatchedExpirations(ctx context.Context, wafLimits []model.GlobalLimit) ([]RenewalRequest, error) {
  142. if len(wafLimits) == 0 { return nil, nil }
  143. wafExpiredMap := make(map[int]int64, len(wafLimits))
  144. wafPlanMap := make(map[int]int, len(wafLimits))
  145. var hostIds []int
  146. for _, limit := range wafLimits {
  147. hostIds = append(hostIds, limit.HostId)
  148. wafExpiredMap[limit.HostId] = limit.ExpiredAt
  149. wafPlanMap[limit.HostId] = limit.RuleId
  150. }
  151. hostExpirations, err := t.hostRep.GetExpireTimeByHostId(ctx, hostIds)
  152. if err != nil { return nil, fmt.Errorf("获取主机到期时间失败: %w", err) }
  153. hostExpiredMap := make(map[int]int64, len(hostExpirations))
  154. for _, h := range hostExpirations { hostExpiredMap[h.HostId] = h.ExpiredAt }
  155. var renewalRequests []RenewalRequest
  156. for hostId, wafExpiredTime := range wafExpiredMap {
  157. hostTime, ok := hostExpiredMap[hostId]
  158. if !ok || hostTime != wafExpiredTime {
  159. planId, planOk := wafPlanMap[hostId]
  160. if !planOk {
  161. t.logger.Warn("数据不一致:在waf_limits中找不到hostId对应的套餐ID", zap.Int("hostId", hostId))
  162. continue
  163. }
  164. renewalRequests = append(renewalRequests, RenewalRequest{HostId: hostId, ExpiredAt: hostTime, PlanId: planId})
  165. }
  166. }
  167. return renewalRequests, nil
  168. }
  169. // findAllCurrentlyExpiredPlans 查找所有当前时间点已经到期的WAF记录。
  170. func (t *wafTask) findAllCurrentlyExpiredPlans(ctx context.Context) ([]model.GlobalLimit, error) {
  171. return t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, 0)
  172. }
  173. // findRecentlyExpiredPlans (精确查找) 查找在过去7天内到期的WAF记录。
  174. func (t *wafTask) findRecentlyExpiredPlans(ctx context.Context) ([]model.GlobalLimit, error) {
  175. sevenDaysAgo := time.Now().Add(-7 * 24 * time.Hour).Unix()
  176. now := time.Now().Unix()
  177. return t.globalLimitRep.GetGlobalLimitsByExpirationRange(ctx, sevenDaysAgo, now)
  178. }
  179. // findStaleExpiredPlans (精确查找) 查找7天前或更早就已到期的WAF记录。
  180. func (t *wafTask) findStaleExpiredPlans(ctx context.Context) ([]model.GlobalLimit, error) {
  181. sevenDaysAgoOffset := int64(-1 * SevenDaysInSeconds)
  182. return t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, sevenDaysAgoOffset)
  183. }
  184. // =================================================================
  185. // =================== 2. 业务决策层 (Filters) =====================
  186. // =================================================================
  187. // filterCleanablePlans (精确决策) 从长期过期的列表中,筛选出确认未续费且需要被清理的记录。
  188. func (t *wafTask) filterCleanablePlans(ctx context.Context, staleLimits []model.GlobalLimit) ([]model.GlobalLimit, error) {
  189. renewedStalePlans, err := t.findMismatchedExpirations(ctx, staleLimits)
  190. if err != nil {
  191. return nil, fmt.Errorf("决策[清理]: 检查续费状态失败: %w", err)
  192. }
  193. renewedHostIds := make(map[int]struct{}, len(renewedStalePlans))
  194. for _, req := range renewedStalePlans {
  195. renewedHostIds[req.HostId] = struct{}{}
  196. }
  197. var plansToClean []model.GlobalLimit
  198. for _, limit := range staleLimits {
  199. if _, found := renewedHostIds[limit.HostId]; !found {
  200. plansToClean = append(plansToClean, limit)
  201. }
  202. }
  203. return plansToClean, nil
  204. }
  205. // =================================================================
  206. // ============== 3. 业务执行层 (Executors & Public API) =============
  207. // =================================================================
  208. // executePlanRecovery (可重用) 负责恢复套餐的所有步骤。
  209. func (t *wafTask) executePlanRecovery(ctx context.Context, renewalRequests []RenewalRequest, taskName string,key repository.PlanListType) error {
  210. t.logger.Info(fmt.Sprintf("开始执行[%s]套餐恢复流程", taskName), zap.Int("数量", len(renewalRequests)))
  211. var hostIds []int
  212. for _, req := range renewalRequests {
  213. hostIds = append(hostIds, req.HostId)
  214. }
  215. if err := t.BanServer(ctx, hostIds, true); err != nil {
  216. return fmt.Errorf("执行[%s]-启用服务失败: %w", taskName, err)
  217. }
  218. var allErrors *multierror.Error
  219. if err := t.EditExpired(ctx, renewalRequests); err != nil {
  220. allErrors = multierror.Append(allErrors, fmt.Errorf("执行[%s]-同步续费信息失败: %w", taskName, err))
  221. }
  222. planIdsToRecover := make([]int64, len(hostIds))
  223. for i, id := range hostIds { planIdsToRecover[i] = int64(id) }
  224. if err := t.expiredRep.RemovePlans(ctx, key, planIdsToRecover...); err != nil {
  225. allErrors = multierror.Append(allErrors, fmt.Errorf("执行[%s]-移除Redis关闭标记失败: %w", taskName, err))
  226. }
  227. return allErrors.ErrorOrNil()
  228. }
  229. // 1. SynchronizationTime 同步即将到期(1天内)的套餐时间
  230. func (t *wafTask) SynchronizationTime(ctx context.Context) error {
  231. wafLimits, err := t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, OneDaysInSeconds)
  232. if err != nil { return fmt.Errorf("执行[同步]-查找失败: %w", err) }
  233. renewalRequests, err := t.findMismatchedExpirations(ctx, wafLimits)
  234. if err != nil { return fmt.Errorf("执行[同步]-决策失败: %w", err) }
  235. if len(renewalRequests) > 0 {
  236. t.logger.Info("发现记录需要同步到期时间。", zap.Int("数量", len(renewalRequests)))
  237. return t.EditExpired(ctx, renewalRequests)
  238. }
  239. return nil
  240. }
  241. // 2. StopPlan (已优化) 停止所有已到期的套餐
  242. func (t *wafTask) StopPlan(ctx context.Context) error {
  243. // 1. 查找所有理论上已到期的记录
  244. expiredLimits, err := t.findAllCurrentlyExpiredPlans(ctx)
  245. if err != nil { return fmt.Errorf("执行[停止]-查找失败: %w", err) }
  246. if len(expiredLimits) == 0 { return nil }
  247. // 2. 决策 - 第1步:检查这些记录中是否已有续费但未同步的
  248. renewalRequests, err := t.findMismatchedExpirations(ctx, expiredLimits)
  249. if err != nil { return fmt.Errorf("执行[停止]-决策检查续费失败: %w", err) }
  250. renewedHostIds := make(map[int]struct{}, len(renewalRequests))
  251. for _, req := range renewalRequests {
  252. renewedHostIds[req.HostId] = struct{}{}
  253. }
  254. // 2. 决策 - 第2步:筛选出真正需要停止的记录
  255. var plansToClose []model.GlobalLimit
  256. for _, limit := range expiredLimits {
  257. if _, found := renewedHostIds[limit.HostId]; found {
  258. t.logger.Info("发现已到期但刚续费的套餐,跳过停止操作", zap.Int("hostId", limit.HostId))
  259. continue
  260. }
  261. isClosed, err := t.expiredRep.IsPlanInList(ctx, repository.ClosedPlansList, int64(limit.HostId))
  262. if err != nil {
  263. t.logger.Error("决策[停止]:检查套餐是否已关闭失败", zap.Int("hostId", limit.HostId), zap.Error(err))
  264. continue
  265. }
  266. if !isClosed {
  267. plansToClose = append(plansToClose, limit)
  268. }
  269. }
  270. if len(plansToClose) == 0 {
  271. t.logger.Info("没有需要停止的套餐(可能均已续费或已关闭)")
  272. return nil
  273. }
  274. // 3. 执行停止操作
  275. t.logger.Info("开始关闭到期的WAF服务", zap.Int("数量", len(plansToClose)))
  276. var hostIds []int
  277. for _, limit := range plansToClose {
  278. hostIds = append(hostIds, limit.HostId)
  279. }
  280. if err := t.BanServer(ctx, hostIds, false); err != nil {
  281. return fmt.Errorf("执行[停止]-禁用服务失败: %w", err)
  282. }
  283. closedPlanIds := make([]int64, len(hostIds))
  284. for i, id := range hostIds { closedPlanIds[i] = int64(id) }
  285. if err := t.expiredRep.AddPlans(ctx, repository.ClosedPlansList, closedPlanIds...); err != nil {
  286. return fmt.Errorf("执行[停止]-标记为已关闭失败: %w", err)
  287. }
  288. return nil
  289. }
  290. // 3. RecoverRecentPlan 恢复7天内续费的套餐
  291. func (t *wafTask) RecoverRecentPlan(ctx context.Context) error {
  292. recentlyExpiredLimits, err := t.findRecentlyExpiredPlans(ctx)
  293. if err != nil { return fmt.Errorf("执行[近期恢复]-查找失败: %w", err) }
  294. if len(recentlyExpiredLimits) == 0 { return nil }
  295. renewalRequests, err := t.findMismatchedExpirations(ctx, recentlyExpiredLimits)
  296. if err != nil { return fmt.Errorf("执行[近期恢复]-决策失败: %w", err) }
  297. if len(renewalRequests) == 0 {
  298. t.logger.Info("在近期过期的套餐中,没有发现已续费的")
  299. return nil
  300. }
  301. return t.executePlanRecovery(ctx, renewalRequests, "近期恢复",repository.ClosedPlansList)
  302. }
  303. // 4. CleanUpStaleRecords 清理过期超过7天且仍未续费的记录
  304. func (t *wafTask) CleanUpStaleRecords(ctx context.Context) error {
  305. staleLimits, err := t.findStaleExpiredPlans(ctx)
  306. if err != nil { return fmt.Errorf("执行[清理]-查找失败: %w", err) }
  307. if len(staleLimits) == 0 { return nil }
  308. plansToClean, err := t.filterCleanablePlans(ctx, staleLimits)
  309. if err != nil { return fmt.Errorf("执行[清理]-决策失败: %w", err) }
  310. if len(plansToClean) == 0 {
  311. t.logger.Info("没有长期未续费的记录需要清理")
  312. return nil
  313. }
  314. t.logger.Info("开始清理长期未续费的WAF记录", zap.Int("数量", len(plansToClean)))
  315. var planIdsToClean []int64
  316. for _, limit := range plansToClean {
  317. planIdsToClean = append(planIdsToClean, int64(limit.HostId))
  318. }
  319. if err := t.expiredRep.RemovePlans(ctx, repository.ClosedPlansList, planIdsToClean...); err != nil {
  320. return fmt.Errorf("执行[清理]-从Redis移除关闭标记失败: %w", err)
  321. }
  322. // 在这里可以添加从数据库删除或调用CDN API彻底删除的逻辑
  323. for _, limit := range plansToClean {
  324. err = t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  325. HostId: limit.HostId,
  326. GatewayGroupId: limit.GatewayGroupId,
  327. State: true,
  328. })
  329. if err != nil {
  330. return fmt.Errorf("执行[清理]-更新套餐状态失败: %w", err)
  331. }
  332. }
  333. return nil
  334. }
  335. // 5. RecoverStalePlan 恢复超过7天后才续费的套餐
  336. func (t *wafTask) RecoverStalePlan(ctx context.Context) error {
  337. staleLimits, err := t.findStaleExpiredPlans(ctx)
  338. if err != nil { return fmt.Errorf("执行[长期恢复]-查找失败: %w", err) }
  339. if len(staleLimits) == 0 { return nil }
  340. renewalRequests, err := t.findMismatchedExpirations(ctx, staleLimits)
  341. if err != nil { return fmt.Errorf("执行[长期恢复]-决策失败: %w", err) }
  342. if len(renewalRequests) == 0 {
  343. t.logger.Info("在长期过期的套餐中,没有发现已续费的")
  344. return nil
  345. }
  346. return t.executePlanRecovery(ctx, renewalRequests, "长期恢复",repository.ExpiringSoonPlansList)
  347. }