waf.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. package task
  2. import (
  3. "context"
  4. "fmt"
  5. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  6. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/service"
  9. "github.com/hashicorp/go-multierror"
  10. "go.uber.org/zap"
  11. "sync"
  12. "time"
  13. )
  14. type WafTask interface {
  15. //获取到期时间小于3天的同步时间
  16. SynchronizationTime(ctx context.Context) error
  17. }
  18. func NewWafTask (
  19. webForWardingRep repository.WebForwardingRepository,
  20. tcpforwardingRep repository.TcpforwardingRepository,
  21. udpForWardingRep repository.UdpForWardingRepository,
  22. cdn service.CdnService,
  23. hostRep repository.HostRepository,
  24. globalLimitRep repository.GlobalLimitRepository,
  25. task *Task,
  26. ) WafTask{
  27. return &wafTask{
  28. Task: task,
  29. webForWardingRep: webForWardingRep,
  30. tcpforwardingRep: tcpforwardingRep,
  31. udpForWardingRep: udpForWardingRep,
  32. cdn: cdn,
  33. hostRep: hostRep,
  34. globalLimitRep: globalLimitRep,
  35. }
  36. }
  37. type wafTask struct {
  38. *Task
  39. webForWardingRep repository.WebForwardingRepository
  40. tcpforwardingRep repository.TcpforwardingRepository
  41. udpForWardingRep repository.UdpForWardingRepository
  42. cdn service.CdnService
  43. hostRep repository.HostRepository
  44. globalLimitRep repository.GlobalLimitRepository
  45. }
  46. const (
  47. // 3天后秒数
  48. OneDaysInSeconds = 3 * 24 * 60 * 60
  49. // 7天前秒数
  50. SevenDaysInSeconds = 7 * 24 * 60 * 60 * -1
  51. )
  52. // 获取cdn web id
  53. func (t wafTask) GetCdnWebId(ctx context.Context,hostId int) ([]int, error) {
  54. tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, hostId)
  55. if err != nil {
  56. return nil, err
  57. }
  58. udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, hostId)
  59. if err != nil {
  60. return nil, err
  61. }
  62. webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, hostId)
  63. if err != nil {
  64. return nil, err
  65. }
  66. var ids []int
  67. ids = append(ids, tcpIds...)
  68. ids = append(ids, udpIds...)
  69. ids = append(ids, webIds...)
  70. return ids, nil
  71. }
  72. // 启用/禁用 网站
  73. func (t wafTask) BanServer(ctx context.Context, ids []int, isBan bool) error {
  74. var wg sync.WaitGroup
  75. errChan := make(chan error, len(ids))
  76. // 修正1:为每个 goroutine 增加 WaitGroup 的计数
  77. wg.Add(len(ids))
  78. for _, id := range ids {
  79. go func(id int) {
  80. // 修正2:确保每个 goroutine 在退出时都调用 Done()
  81. defer wg.Done()
  82. err := t.cdn.EditWebIsOn(ctx, int64(id), isBan)
  83. if err != nil {
  84. errChan <- err
  85. // 这里不需要 return,因为 defer wg.Done() 会在函数退出时执行
  86. }
  87. }(id)
  88. }
  89. // 现在 wg.Wait() 会正确地阻塞,直到所有 goroutine 都调用了 Done()
  90. wg.Wait()
  91. // 在所有 goroutine 都结束后,安全地关闭 channel
  92. close(errChan)
  93. var result error
  94. for err := range errChan {
  95. result = multierror.Append(result, err) // 将多个 error 对象合并成一个单一的 error 对象
  96. }
  97. // 修正3:返回收集到的错误,而不是 nil
  98. return result
  99. }
  100. // 获取指定到期时间
  101. func (t wafTask) GetAlmostExpiring(ctx context.Context,hostIds []int,addTime int64) ([]v1.GetAlmostExpireHostResponse,error) {
  102. // 3 天
  103. res, err := t.hostRep.GetAlmostExpired(ctx, hostIds, addTime)
  104. if err != nil {
  105. return nil,err
  106. }
  107. return res, nil
  108. }
  109. // 获取waf全局到期时间
  110. func (t wafTask) GetGlobalAlmostExpiring(ctx context.Context,addTime int64) ([]model.GlobalLimit,error) {
  111. res, err := t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, addTime)
  112. if err != nil {
  113. return nil, err
  114. }
  115. return res, nil
  116. }
  117. // 修改全局续费
  118. func (t wafTask) EditGlobalExpired(ctx context.Context, req []struct{
  119. hostId int
  120. expiredAt int64
  121. }, state bool) error {
  122. var result *multierror.Error // 使用 multierror
  123. for _, v := range req {
  124. err := t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  125. HostId: v.hostId,
  126. ExpiredAt: v.expiredAt,
  127. State: state,
  128. })
  129. if err != nil {
  130. // 收集错误,而不是直接返回
  131. result = multierror.Append(result, err)
  132. }
  133. }
  134. // 返回所有收集到的错误
  135. return result.ErrorOrNil()
  136. }
  137. // 续费套餐
  138. func (t wafTask) EnablePlan(ctx context.Context, req []struct{
  139. planId int
  140. expiredAt int64
  141. }) error {
  142. var result *multierror.Error
  143. for _, v := range req {
  144. err := t.cdn.RenewPlan(ctx, v1.RenewalPlan{
  145. UserPlanId: int64(v.planId),
  146. IsFree: true,
  147. DayTo: time.Unix(v.expiredAt, 0).Format("2006-01-02"),
  148. Period: "monthly",
  149. CountPeriod: 1,
  150. PeriodDayTo: time.Unix(v.expiredAt, 0).Format("2006-01-02"),
  151. })
  152. if err != nil {
  153. result = multierror.Append(result, err)
  154. }
  155. }
  156. return result.ErrorOrNil()
  157. }
  158. // 续费操作
  159. type RenewalRequest struct {
  160. HostId int
  161. PlanId int
  162. ExpiredAt int64
  163. }
  164. // 续费操作
  165. func (t wafTask) EditExpired(ctx context.Context, reqs []RenewalRequest) error {
  166. // 如果请求为空,直接返回
  167. if len(reqs) == 0 {
  168. return nil
  169. }
  170. // 1. 准备用于更新 GlobalLimit 的数据
  171. var globalLimitUpdates []struct {
  172. hostId int
  173. expiredAt int64
  174. }
  175. for _, req := range reqs {
  176. globalLimitUpdates = append(globalLimitUpdates, struct {
  177. hostId int
  178. expiredAt int64
  179. }{req.HostId, req.ExpiredAt})
  180. }
  181. // 2. 准备用于续费套餐的数据
  182. var planRenewals []struct {
  183. planId int
  184. expiredAt int64
  185. }
  186. for _, req := range reqs {
  187. planRenewals = append(planRenewals, struct {
  188. planId int
  189. expiredAt int64
  190. }{req.PlanId, req.ExpiredAt})
  191. }
  192. var result *multierror.Error
  193. // 3. 执行更新,并收集错误
  194. if err := t.EditGlobalExpired(ctx, globalLimitUpdates, true); err != nil {
  195. result = multierror.Append(result, err)
  196. }
  197. if err := t.EnablePlan(ctx, planRenewals); err != nil {
  198. result = multierror.Append(result, err)
  199. }
  200. return result.ErrorOrNil()
  201. }
  202. // findMismatchedExpirations 检查 WAF 和 Host 的到期时间差异,并返回需要同步的请求。
  203. func (t *wafTask) findMismatchedExpirations(ctx context.Context, wafLimits []model.GlobalLimit) ([]RenewalRequest, error) {
  204. if len(wafLimits) == 0 {
  205. return nil, nil
  206. }
  207. // 2. 将 WAF 数据组织成 Map
  208. wafExpiredMap := make(map[int]int64, len(wafLimits))
  209. wafPlanMap := make(map[int]int, len(wafLimits))
  210. var hostIds []int
  211. for _, limit := range wafLimits {
  212. hostIds = append(hostIds, limit.HostId)
  213. wafExpiredMap[limit.HostId] = limit.ExpiredAt
  214. wafPlanMap[limit.HostId] = limit.RuleId
  215. }
  216. // 3. 获取对应 Host 的到期时间
  217. hostExpirations, err := t.hostRep.GetExpireTimeByHostId(ctx, hostIds)
  218. if err != nil {
  219. return nil, fmt.Errorf("获取主机到期时间失败: %w", err)
  220. }
  221. hostExpiredMap := make(map[int]int64, len(hostExpirations))
  222. for _, h := range hostExpirations {
  223. hostExpiredMap[h.HostId] = h.ExpiredAt
  224. }
  225. // 4. 找出时间不一致的记录
  226. var renewalRequests []RenewalRequest
  227. for hostId, wafExpiredTime := range wafExpiredMap {
  228. hostTime, ok := hostExpiredMap[hostId]
  229. // 如果 Host 时间与 WAF 时间不一致,则需要同步
  230. if !ok || hostTime != wafExpiredTime {
  231. planId, planOk := wafPlanMap[hostId]
  232. if !planOk {
  233. t.logger.Warn("数据不一致:在waf_limits中找不到hostId对应的套餐ID", zap.Int("hostId", hostId))
  234. continue
  235. }
  236. renewalRequests = append(renewalRequests, RenewalRequest{
  237. HostId: hostId,
  238. ExpiredAt: hostTime, // 以 WAF 表的时间为准
  239. PlanId: planId,
  240. })
  241. }
  242. }
  243. return renewalRequests, nil
  244. }
  245. //获取到期时间小于3天的同步时间
  246. func (t *wafTask) SynchronizationTime(ctx context.Context) error {
  247. // 1. 获取 WAF 全局配置中即将到期(小于3天)的数据
  248. wafLimits, err := t.GetGlobalAlmostExpiring(ctx, OneDaysInSeconds)
  249. if err != nil {
  250. return fmt.Errorf("获取全局到期配置失败: %w", err)
  251. }
  252. // 2. 找出需要同步的数据
  253. renewalRequests, err := t.findMismatchedExpirations(ctx, wafLimits)
  254. if err != nil {
  255. return err // 错误已在辅助函数中包装
  256. }
  257. // 3. 如果有需要同步的数据,执行续费操作
  258. if len(renewalRequests) > 0 {
  259. t.logger.Info("发现记录需要同步到期时间。", zap.Int("数量", len(renewalRequests)))
  260. return t.EditExpired(ctx, renewalRequests)
  261. }
  262. return nil
  263. }
  264. //获取到期的进行关闭套餐操作
  265. // 获取到期的进行关闭套餐操作
  266. func (t *wafTask) StopPlan(ctx context.Context) error {
  267. // 1. 获取 WAF 全局配置中已经到期的数据
  268. // 使用 time.Now().Unix() 表示获取所有 expired_at <= 当前时间的记录
  269. wafLimits, err := t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, time.Now().Unix())
  270. if err != nil {
  271. return fmt.Errorf("获取全局到期配置失败: %w", err)
  272. }
  273. if len(wafLimits) == 0 {
  274. return nil // 没有到期的,任务完成
  275. }
  276. // 2. (可选,但推荐)先同步任何时间不一致的数据,确保状态准确
  277. renewalRequests, err := t.findMismatchedExpirations(ctx, wafLimits)
  278. if err != nil {
  279. t.logger.Error("在关闭套餐前,同步时间失败", zap.Error(err))
  280. // 根据业务决定是否要继续,这里我们选择继续,但记录错误
  281. }
  282. if len(renewalRequests) > 0 {
  283. t.logger.Info("关闭套餐前,发现并同步不一致的时间记录", zap.Int("数量", len(renewalRequests)))
  284. if err := t.EditExpired(ctx, renewalRequests); err != nil {
  285. t.logger.Error("同步不一致的时间记录失败", zap.Error(err))
  286. }
  287. }
  288. // 3. 关闭所有已经到期的套餐
  289. t.logger.Info("开始关闭已到期的WAF服务", zap.Int("数量", len(wafLimits)))
  290. var allErrors *multierror.Error
  291. for _, limit := range wafLimits {
  292. webIds, err := t.GetCdnWebId(ctx, limit.HostId)
  293. if err != nil {
  294. allErrors = multierror.Append(allErrors, fmt.Errorf("获取hostId %d 的webId失败: %w", limit.HostId, err))
  295. continue // 继续处理下一个
  296. }
  297. if err := t.BanServer(ctx, webIds, false); err != nil {
  298. allErrors = multierror.Append(allErrors, fmt.Errorf("关闭hostId %d 的服务失败: %w", limit.HostId, err))
  299. }
  300. }
  301. return allErrors.ErrorOrNil()
  302. }
  303. //对于到期7天内续费的产品需要进行恢复操作
  304. // RecoverStopPlan 对于到期7天内续费的产品进行恢复操作
  305. func (t *wafTask) RecoverStopPlan(ctx context.Context) error {
  306. // 1. 查找在过去7天内到期,并且当前状态为“已关闭”的 WAF 记录
  307. // 这可能需要一个新的 repository 方法,例如: GetRecentlyClosedLimits
  308. // 我们先假设有这样一个方法,它返回 state=false 且 expired_at 在 (now-7天, now] 之间的记录
  309. since := time.Now().Add(-7 * 24 * time.Hour).Unix()
  310. // 假设你有一个方法 `GetClosedLimitsSince(ctx, sinceTime)`
  311. // closedLimits, err := t.globalLimitRep.GetClosedLimitsSince(ctx, since)
  312. // 为简化,我们先获取所有7天内到期的,再在逻辑里判断
  313. // 简单的实现:获取7天内到期的所有记录
  314. wafLimits, err := t.globalLimitRep.GetLimitsExpiredSince(ctx, since) // 假设有这个方法
  315. if err != nil {
  316. return fmt.Errorf("获取近期到期配置失败: %w", err)
  317. }
  318. if len(wafLimits) == 0 {
  319. return nil
  320. }
  321. // 提取 hostIds 并过滤出已关闭的记录
  322. var hostIds []int
  323. closedLimitsMap := make(map[int]model.GlobalLimit)
  324. for _, limit := range wafLimits {
  325. if !limit.State { // 只处理状态为“已关闭”的
  326. hostIds = append(hostIds, limit.HostId)
  327. closedLimitsMap[limit.HostId] = limit
  328. }
  329. }
  330. if len(hostIds) == 0 {
  331. return nil // 没有已关闭的记录需要检查
  332. }
  333. // 2. 获取这些 host 的当前到期时间
  334. hostExpirations, err := t.hostRep.GetExpireTimeByHostId(ctx, hostIds)
  335. if err != nil {
  336. return fmt.Errorf("获取主机当前到期时间失败: %w", err)
  337. }
  338. hostExpiredMap := make(map[int]int64)
  339. for _, h := range hostExpirations {
  340. hostExpiredMap[h.HostId] = h.ExpiredAt
  341. }
  342. var allErrors *multierror.Error
  343. // 3. 比较时间,找出已续费的 host,并恢复服务
  344. for hostId, closedLimit := range closedLimitsMap {
  345. currentHostExpiry, ok := hostExpiredMap[hostId]
  346. if !ok {
  347. continue // host 不存在了,跳过
  348. }
  349. // 如果 host 表的到期时间 > global_limit 表的到期时间,说明已续费
  350. if currentHostExpiry > closedLimit.ExpiredAt {
  351. t.logger.Info("发现已续费并关闭的WAF服务,准备恢复", zap.Int("hostId", hostId))
  352. // 3a. 恢复网站服务
  353. webIds, err := t.GetCdnWebId(ctx, hostId)
  354. if err != nil {
  355. allErrors = multierror.Append(allErrors, fmt.Errorf("恢复hostId %d 时获取webId失败: %w", hostId, err))
  356. continue
  357. }
  358. if err := t.BanServer(ctx, webIds, true); err != nil { // true 表示启用
  359. allErrors = multierror.Append(allErrors, fmt.Errorf("恢复hostId %d 服务失败: %w", hostId, err))
  360. continue
  361. }
  362. // 3b. 更新 global_limit 表的时间和状态
  363. var singleUpdate []struct{hostId int; expiredAt int64}
  364. singleUpdate = append(singleUpdate, struct{hostId int; expiredAt int64}{hostId: hostId, expiredAt: currentHostExpiry})
  365. if err := t.EditGlobalExpired(ctx, singleUpdate, true); err != nil { // true 表示启用
  366. allErrors = multierror.Append(allErrors, fmt.Errorf("更新hostId %d 状态为已恢复失败: %w", hostId, err))
  367. }
  368. }
  369. }
  370. return allErrors.ErrorOrNil()
  371. }
  372. //对于大于7天的药进行数据情侣操作