123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457 |
- package task
- import (
- "context"
- "fmt"
- v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
- "github.com/go-nunu/nunu-layout-advanced/internal/model"
- "github.com/go-nunu/nunu-layout-advanced/internal/repository"
- "github.com/go-nunu/nunu-layout-advanced/internal/service"
- "github.com/hashicorp/go-multierror"
- "go.uber.org/zap"
- "sync"
- "time"
- )
- type WafTask interface {
- //获取到期时间小于3天的同步时间
- SynchronizationTime(ctx context.Context) error
- StopPlan(ctx context.Context) error
- RecoverStopPlan(ctx context.Context) error
- }
- func NewWafTask (
- webForWardingRep repository.WebForwardingRepository,
- tcpforwardingRep repository.TcpforwardingRepository,
- udpForWardingRep repository.UdpForWardingRepository,
- cdn service.CdnService,
- hostRep repository.HostRepository,
- globalLimitRep repository.GlobalLimitRepository,
- expiredRep repository.ExpiredRepository,
- task *Task,
- ) WafTask{
- return &wafTask{
- Task: task,
- webForWardingRep: webForWardingRep,
- tcpforwardingRep: tcpforwardingRep,
- udpForWardingRep: udpForWardingRep,
- cdn: cdn,
- hostRep: hostRep,
- globalLimitRep: globalLimitRep,
- expiredRep: expiredRep,
- }
- }
- type wafTask struct {
- *Task
- webForWardingRep repository.WebForwardingRepository
- tcpforwardingRep repository.TcpforwardingRepository
- udpForWardingRep repository.UdpForWardingRepository
- cdn service.CdnService
- hostRep repository.HostRepository
- globalLimitRep repository.GlobalLimitRepository
- expiredRep repository.ExpiredRepository
- }
- const (
- // 1天后秒数
- OneDaysInSeconds = 1 * 24 * 60 * 60
- // 7天前秒数
- SevenDaysInSeconds = 7 * 24 * 60 * 60 * -1
- )
- // 获取cdn web id
- func (t wafTask) GetCdnWebId(ctx context.Context,hostId []int) ([]int, error) {
- tcpIds, err := t.tcpforwardingRep.GetTcpAll(ctx, hostId)
- if err != nil {
- return nil, err
- }
- udpIds, err := t.udpForWardingRep.GetUdpAll(ctx, hostId)
- if err != nil {
- return nil, err
- }
- webIds, err := t.webForWardingRep.GetWebAll(ctx, hostId)
- if err != nil {
- return nil, err
- }
- var ids []int
- ids = append(ids, tcpIds...)
- ids = append(ids, udpIds...)
- ids = append(ids, webIds...)
- return ids, nil
- }
- // 启用/禁用 网站
- func (t wafTask) BanServer(ctx context.Context, ids []int, isBan bool) error {
- var wg sync.WaitGroup
- errChan := make(chan error, len(ids))
- // 修正1:为每个 goroutine 增加 WaitGroup 的计数
- wg.Add(len(ids))
- for _, id := range ids {
- go func(id int) {
- // 修正2:确保每个 goroutine 在退出时都调用 Done()
- defer wg.Done()
- err := t.cdn.EditWebIsOn(ctx, int64(id), isBan)
- if err != nil {
- errChan <- err
- // 这里不需要 return,因为 defer wg.Done() 会在函数退出时执行
- }
- }(id)
- }
- // 现在 wg.Wait() 会正确地阻塞,直到所有 goroutine 都调用了 Done()
- wg.Wait()
- // 在所有 goroutine 都结束后,安全地关闭 channel
- close(errChan)
- var result error
- for err := range errChan {
- result = multierror.Append(result, err) // 将多个 error 对象合并成一个单一的 error 对象
- }
- // 修正3:返回收集到的错误,而不是 nil
- return result
- }
- // 获取指定到期时间
- func (t wafTask) GetAlmostExpiring(ctx context.Context,hostIds []int,addTime int64) ([]v1.GetAlmostExpireHostResponse,error) {
- // 3 天
- res, err := t.hostRep.GetAlmostExpired(ctx, hostIds, addTime)
- if err != nil {
- return nil,err
- }
- return res, nil
- }
- // 获取waf全局到期时间
- func (t wafTask) GetGlobalAlmostExpiring(ctx context.Context,addTime int64) ([]model.GlobalLimit,error) {
- res, err := t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, addTime)
- if err != nil {
- return nil, err
- }
- return res, nil
- }
- // 修改全局续费
- func (t wafTask) EditGlobalExpired(ctx context.Context, req []struct{
- hostId int
- expiredAt int64
- }, state bool) error {
- var result *multierror.Error // 使用 multierror
- for _, v := range req {
- err := t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
- HostId: v.hostId,
- ExpiredAt: v.expiredAt,
- State: state,
- })
- if err != nil {
- // 收集错误,而不是直接返回
- result = multierror.Append(result, err)
- }
- }
- // 返回所有收集到的错误
- return result.ErrorOrNil()
- }
- // 续费套餐
- func (t wafTask) EnablePlan(ctx context.Context, req []struct{
- planId int
- expiredAt int64
- }) error {
- var result *multierror.Error
- for _, v := range req {
- err := t.cdn.RenewPlan(ctx, v1.RenewalPlan{
- UserPlanId: int64(v.planId),
- IsFree: true,
- DayTo: time.Unix(v.expiredAt, 0).Format("2006-01-02"),
- Period: "monthly",
- CountPeriod: 1,
- PeriodDayTo: time.Unix(v.expiredAt, 0).Format("2006-01-02"),
- })
- if err != nil {
- result = multierror.Append(result, err)
- }
- }
- return result.ErrorOrNil()
- }
- // 续费操作
- type RenewalRequest struct {
- HostId int
- PlanId int
- ExpiredAt int64
- }
- // 续费操作
- func (t wafTask) EditExpired(ctx context.Context, reqs []RenewalRequest) error {
- // 如果请求为空,直接返回
- if len(reqs) == 0 {
- return nil
- }
- // 1. 准备用于更新 GlobalLimit 的数据
- var globalLimitUpdates []struct {
- hostId int
- expiredAt int64
- }
- for _, req := range reqs {
- globalLimitUpdates = append(globalLimitUpdates, struct {
- hostId int
- expiredAt int64
- }{req.HostId, req.ExpiredAt})
- }
- // 2. 准备用于续费套餐的数据
- var planRenewals []struct {
- planId int
- expiredAt int64
- }
- for _, req := range reqs {
- planRenewals = append(planRenewals, struct {
- planId int
- expiredAt int64
- }{req.PlanId, req.ExpiredAt})
- }
- var result *multierror.Error
- // 3. 执行更新,并收集错误
- if err := t.EditGlobalExpired(ctx, globalLimitUpdates, true); err != nil {
- result = multierror.Append(result, err)
- }
- if err := t.EnablePlan(ctx, planRenewals); err != nil {
- result = multierror.Append(result, err)
- }
- return result.ErrorOrNil()
- }
- // findMismatchedExpirations 检查 WAF 和 Host 的到期时间差异,并返回需要同步的请求。
- func (t *wafTask) findMismatchedExpirations(ctx context.Context, wafLimits []model.GlobalLimit) ([]RenewalRequest, error) {
- if len(wafLimits) == 0 {
- return nil, nil
- }
- // 2. 将 WAF 数据组织成 Map
- wafExpiredMap := make(map[int]int64, len(wafLimits))
- wafPlanMap := make(map[int]int, len(wafLimits))
- var hostIds []int
- for _, limit := range wafLimits {
- hostIds = append(hostIds, limit.HostId)
- wafExpiredMap[limit.HostId] = limit.ExpiredAt
- wafPlanMap[limit.HostId] = limit.RuleId
- }
- // 3. 获取对应 Host 的到期时间
- hostExpirations, err := t.hostRep.GetExpireTimeByHostId(ctx, hostIds)
- if err != nil {
- return nil, fmt.Errorf("获取主机到期时间失败: %w", err)
- }
- hostExpiredMap := make(map[int]int64, len(hostExpirations))
- for _, h := range hostExpirations {
- hostExpiredMap[h.HostId] = h.ExpiredAt
- }
- // 4. 找出时间不一致的记录
- var renewalRequests []RenewalRequest
- for hostId, wafExpiredTime := range wafExpiredMap {
- hostTime, ok := hostExpiredMap[hostId]
- // 如果 Host 时间与 WAF 时间不一致,则需要同步
- if !ok || hostTime != wafExpiredTime {
- planId, planOk := wafPlanMap[hostId]
- if !planOk {
- t.logger.Warn("数据不一致:在waf_limits中找不到hostId对应的套餐ID", zap.Int("hostId", hostId))
- continue
- }
- renewalRequests = append(renewalRequests, RenewalRequest{
- HostId: hostId,
- ExpiredAt: hostTime, // 以 host 表的时间为准
- PlanId: planId,
- })
- }
- }
- return renewalRequests, nil
- }
- //获取同步到期时间小于1天的套餐
- func (t *wafTask) SynchronizationTime(ctx context.Context) error {
- // 1. 获取 WAF 全局配置中即将到期(小于3天)的数据
- wafLimits, err := t.GetGlobalAlmostExpiring(ctx, OneDaysInSeconds)
- if err != nil {
- return fmt.Errorf("获取全局到期配置失败: %w", err)
- }
- // 2. 找出需要同步的数据
- renewalRequests, err := t.findMismatchedExpirations(ctx, wafLimits)
- if err != nil {
- return err // 错误已在辅助函数中包装
- }
- // 3. 如果有需要同步的数据,执行续费操作
- if len(renewalRequests) > 0 {
- t.logger.Info("发现记录需要同步到期时间。", zap.Int("数量", len(renewalRequests)))
- return t.EditExpired(ctx, renewalRequests)
- }
- return nil
- }
- // 获取到期的进行关闭套餐操作
- func (t *wafTask) StopPlan(ctx context.Context) error {
- // 1. 获取 WAF 全局配置中已经到期的数据
- // 使用 time.Now().Unix() 表示获取所有 expired_at <= 当前时间的记录
- wafLimits, err := t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, 0)
- if err != nil {
- return fmt.Errorf("获取全局到期配置失败: %w", err)
- }
- if len(wafLimits) == 0 {
- return nil // 没有到期的,任务完成
- }
- // 2. (可选,但推荐)先同步任何时间不一致的数据,确保状态准确
- renewalRequests, err := t.findMismatchedExpirations(ctx, wafLimits)
- if err != nil {
- t.logger.Error("在关闭套餐前,同步时间失败", zap.Error(err))
- // 根据业务决定是否要继续,这里我们选择继续,但记录错误
- }
- if len(renewalRequests) > 0 {
- t.logger.Info("关闭套餐前,发现并同步不一致的时间记录", zap.Int("数量", len(renewalRequests)))
- if err := t.EditExpired(ctx, renewalRequests); err != nil {
- t.logger.Error("同步不一致的时间记录失败", zap.Error(err))
- }
- }
- // 3. 筛选出尚未被关闭的套餐
- var plansToClose []model.GlobalLimit
- for _, limit := range wafLimits {
- isClosed, err := t.expiredRep.IsPlanClosed(ctx, int64(limit.HostId))
- if err != nil {
- t.logger.Error("检查Redis中套餐关闭状态失败", zap.Int("hostId", limit.HostId), zap.Error(err))
- continue // 跳过这个,处理下一个
- }
- if !isClosed {
- plansToClose = append(plansToClose, limit)
- }
- }
- if len(plansToClose) == 0 {
- t.logger.Info("没有新的到期套餐需要关闭")
- return nil
- }
- // 4. 对筛选出的套餐执行关闭操作
- t.logger.Info("开始关闭新的到期WAF服务", zap.Int("数量", len(plansToClose)))
- var allErrors *multierror.Error
- var webIds []int
- for _, limit := range plansToClose {
- webIds = append(webIds, limit.HostId)
- }
- if err := t.BanServer(ctx, webIds, false); err != nil {
- allErrors = multierror.Append(allErrors, fmt.Errorf("关闭hostId %v 的服务失败: %w", webIds, err))
- } else {
- // 服务关闭成功后,将这些套餐信息添加到 Redis
- var expiredInfos []repository.ExpiredInfo
- for _, limit := range plansToClose {
- expiredInfos = append(expiredInfos, repository.ExpiredInfo{
- HostID: int64(limit.HostId),
- Expiry: time.Unix(limit.ExpiredAt, 0),
- })
- }
- if len(expiredInfos) > 0 {
- if err := t.expiredRep.AddClosePlans(ctx, expiredInfos...); err != nil {
- allErrors = multierror.Append(allErrors, fmt.Errorf("添加已关闭套餐信息到Redis失败: %w", err))
- }
- }
- }
- return allErrors.ErrorOrNil()
- }
- //对于到期7天内续费的产品需要进行恢复操作
- func (t *wafTask) RecoverStopPlan(ctx context.Context) error {
- // 1. 获取所有已过期(expired_at < now)但状态仍为 true 的 WAF 记录
- // StopPlan 任务会禁用这些服务,但不会改变它们的 state
- wafLimits, err := t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, SevenDaysInSeconds) // addTime=0 表示获取所有当前时间之前到期的
- if err != nil {
- return fmt.Errorf("获取过期WAF配置失败: %w", err)
- }
- if len(wafLimits) == 0 {
- t.logger.Info("没有已过期且需要检查恢复的服务")
- return nil
- }
- // 2. 检查这些记录对应的 host 是否已续费
- // findMismatchedExpirations 会比较 waf.expired_at 和 host.nextduedate
- renewalRequests, err := t.findMismatchedExpirations(ctx, wafLimits)
- if err != nil {
- return fmt.Errorf("检查续费状态失败: %w", err)
- }
- if len(renewalRequests) == 0 {
- t.logger.Info("在已过期的服务中,没有发现已续费且需要恢复的服务")
- return nil
- }
- // 3. 对已续费的服务执行恢复操作
- t.logger.Info("发现已续费、需要恢复的WAF服务", zap.Int("数量", len(renewalRequests)))
- var allErrors *multierror.Error
- var webIds []int
- for _, req := range renewalRequests {
- webIds = append(webIds, req.HostId)
- }
- if err := t.BanServer(ctx, webIds, true); err != nil {
- allErrors = multierror.Append(allErrors, fmt.Errorf("恢复hostId %v: 启用服务失败: %w", webIds, err))
- } else {
- // 服务恢复成功后,从 Redis 中移除这些套餐的关闭记录
- planIds := make([]int64, len(webIds))
- for i, id := range webIds {
- planIds[i] = int64(id)
- }
- if err := t.expiredRep.RemoveClosePlanIds(ctx, planIds...); err != nil {
- allErrors = multierror.Append(allErrors, fmt.Errorf("从Redis移除已恢复的套餐失败: %w", err))
- }
- }
- if len(renewalRequests) > 0 {
- // 统一执行续费和数据库更新操作
- if err := t.EditExpired(ctx, renewalRequests); err != nil {
- allErrors = multierror.Append(allErrors, fmt.Errorf("批量更新已恢复服务的数据库状态失败: %w", err))
- }
- }
- return allErrors.ErrorOrNil()
- }
- //对于大于7天的药进行数据情侣操作
|