package task

import (
	"context"
	"sync"
	"time"

	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
	"github.com/go-nunu/nunu-layout-advanced/internal/model"
	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
	"github.com/go-nunu/nunu-layout-advanced/internal/service"
	"github.com/hashicorp/go-multierror"
)

// WafTask is the handle returned by NewWafTask.
//
// NOTE(review): the interface declares no methods, so callers cannot reach
// wafTask's methods through it without a type assertion. Confirm whether
// methods such as BanServer or GetGlobalAllHostId should be listed here.
type WafTask interface{}

// NewWafTask wires the forwarding, host and global-limit repositories plus the
// CDN service into a WafTask backed by *wafTask.
func NewWafTask(
	webForWardingRep repository.WebForwardingRepository,
	tcpforwardingRep repository.TcpforwardingRepository,
	udpForWardingRep repository.UdpForWardingRepository,
	cdn service.CdnService,
	hostRep repository.HostRepository,
	globalLimitRep repository.GlobalLimitRepository,
	task *Task,
) WafTask {
	return &wafTask{
		Task:             task,
		webForWardingRep: webForWardingRep,
		tcpforwardingRep: tcpforwardingRep,
		udpForWardingRep: udpForWardingRep,
		cdn:              cdn,
		hostRep:          hostRep,
		globalLimitRep:   globalLimitRep,
	}
}

// wafTask implements the WAF expiry/renewal helpers on top of the injected
// repositories and the CDN service.
type wafTask struct {
	*Task
	webForWardingRep repository.WebForwardingRepository
	tcpforwardingRep repository.TcpforwardingRepository
	udpForWardingRep repository.UdpForWardingRepository
	cdn              service.CdnService
	hostRep          repository.HostRepository
	globalLimitRep   repository.GlobalLimitRepository
}

// CheckExpiredTask is currently a no-op that always returns nil.
func (t wafTask) CheckExpiredTask(ctx context.Context) error {
	return nil
}

// GetCdnWebId returns the CDN web IDs associated with hostId.
//
// The result is the TCP forwarding IDs, then the UDP forwarding IDs, then the
// web forwarding IDs, concatenated in that order. It stops at and returns the
// first repository error. The slice is nil when all three lookups are empty.
func (t wafTask) GetCdnWebId(ctx context.Context, hostId int) ([]int, error) {
	tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, hostId)
	if err != nil {
		return nil, err
	}
	udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, hostId)
	if err != nil {
		return nil, err
	}
	webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, hostId)
	if err != nil {
		return nil, err
	}

	var ids []int
	ids = append(ids, tcpIds...)
	ids = append(ids, udpIds...)
	ids = append(ids, webIds...)
	return ids, nil
}

// BanServer enables or disables the CDN sites identified by ids.
//
// It calls cdn.EditWebIsOn(ctx, id, isBan) concurrently, one goroutine per ID,
// and waits for all of them to finish. It returns every failure combined into
// a single multierror, or nil if all calls succeeded.
//
// NOTE(review): isBan is forwarded unchanged as EditWebIsOn's "is on"
// argument. Confirm the intended polarity against CdnService.EditWebIsOn.
func (t wafTask) BanServer(ctx context.Context, ids []int, isBan bool) error {
	var wg sync.WaitGroup
	// Buffered to len(ids) so no goroutine ever blocks on send, even though
	// the channel is only drained after wg.Wait returns.
	errCh := make(chan error, len(ids))

	wg.Add(len(ids))
	for _, id := range ids {
		go func(id int) {
			defer wg.Done()
			if err := t.cdn.EditWebIsOn(ctx, int64(id), isBan); err != nil {
				errCh <- err
			}
		}(id)
	}

	// All senders have exited once Wait returns, so closing here is safe.
	wg.Wait()
	close(errCh)

	var result error
	for err := range errCh {
		result = multierror.Append(result, err)
	}
	return result
}

// GetAlmostExpiring returns the host expiry records that
// hostRep.GetAlmostExpired reports for hostIds and addTime.
//
// NOTE(review): the original author annotated this call with "3 days".
// Presumably addTime is the look-ahead window; confirm its units at the
// call site.
func (t wafTask) GetAlmostExpiring(ctx context.Context, hostIds []int, addTime int64) ([]v1.GetAlmostExpireHostResponse, error) {
	res, err := t.hostRep.GetAlmostExpired(ctx, hostIds, addTime)
	if err != nil {
		return nil, err
	}
	return res, nil
}

// GetGlobalAlmostExpiring returns the global limit records that
// globalLimitRep.GetGlobalLimitAlmostExpired reports for addTime.
func (t wafTask) GetGlobalAlmostExpiring(ctx context.Context, addTime int64) ([]model.GlobalLimit, error) {
	res, err := t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, addTime)
	if err != nil {
		return nil, err
	}
	return res, nil
}

// GetGlobalAllHostId compares global-limit expiry with host expiry and reports
// the hosts whose expiry needs editing.
//
// It loads the global limits returned by GetGlobalAlmostExpiring(addTime),
// then the host records returned by GetAlmostExpiring for those hosts. The
// result maps hostID to the global limit's ExpiredAt. It contains every host
// whose host-record ExpiredAt differs from its global limit's ExpiredAt. A
// host missing from the host-record result compares as 0 and is therefore
// included. If several global limits share a HostId, the last one wins.
//
// NOTE(review): the plan (RuleId) of each edited host is not returned; callers
// that need it must look it up separately.
func (t wafTask) GetGlobalAllHostId(ctx context.Context, addTime int64) (map[int]int64, error) {
	globalData, err := t.GetGlobalAlmostExpiring(ctx, addTime)
	if err != nil {
		return nil, err
	}
	// Nothing is expiring: skip the host query instead of issuing it with an
	// empty host ID list.
	if len(globalData) == 0 {
		return make(map[int]int64), nil
	}

	hostIds := make([]int, 0, len(globalData))
	globalExpiry := make(map[int]int64, len(globalData))
	for _, g := range globalData {
		hostIds = append(hostIds, g.HostId)
		globalExpiry[g.HostId] = g.ExpiredAt
	}

	hostData, err := t.GetAlmostExpiring(ctx, hostIds, addTime)
	if err != nil {
		return nil, err
	}
	hostExpiry := make(map[int]int64, len(hostData))
	for _, h := range hostData {
		hostExpiry[h.HostId] = h.ExpiredAt
	}

	editMap := make(map[int]int64)
	for hostId, expiredAt := range globalExpiry {
		if hostExpiry[hostId] != expiredAt {
			editMap[hostId] = expiredAt
		}
	}
	return editMap, nil
}

// EditGlobalExpired writes each (hostId, expiredAt) pair in req to the global
// limit table and sets State to state.
//
// It calls globalLimitRep.UpdateGlobalLimitByHostId once per item. It stops
// at the first error; updates already applied in this call are not undone
// here.
func (t wafTask) EditGlobalExpired(ctx context.Context, req []struct {
	hostId    int
	expiredAt int64
}, state bool) error {
	for _, v := range req {
		err := t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
			HostId:    v.hostId,
			ExpiredAt: v.expiredAt,
			State:     state,
		})
		if err != nil {
			return err
		}
	}
	return nil
}

// EnablePlan renews each user plan in req through cdn.RenewPlan.
//
// Each renewal is free, one "monthly" period long, and uses the item's
// expiredAt (Unix seconds), formatted as a local-time "2006-01-02" date, as
// both DayTo and PeriodDayTo. It stops at the first error.
func (t wafTask) EnablePlan(ctx context.Context, req []struct {
	planId    int
	expiredAt int64
}) error {
	for _, v := range req {
		dayTo := time.Unix(v.expiredAt, 0).Format("2006-01-02")
		err := t.cdn.RenewPlan(ctx, v1.RenewalPlan{
			UserPlanId:  int64(v.planId),
			IsFree:      true,
			DayTo:       dayTo,
			Period:      "monthly",
			CountPeriod: 1,
			PeriodDayTo: dayTo,
		})
		if err != nil {
			return err
		}
	}
	return nil
}

// EditExpired performs the renewal bookkeeping for req.
//
// It copies each item's hostId/expiredAt pair into the global limit table
// through EditGlobalExpired, with state set to true.
//
// NOTE(review): planId is currently ignored; no plan is renewed here
// (EnablePlan is not called). Confirm whether that is intentional.
func (t wafTask) EditExpired(ctx context.Context, req []struct {
	hostId    int
	expiredAt int64
	planId    int
}) error {
	sendData := make([]struct {
		hostId    int
		expiredAt int64
	}, 0, len(req))
	for _, v := range req {
		sendData = append(sendData, struct {
			hostId    int
			expiredAt int64
		}{
			hostId:    v.hostId,
			expiredAt: v.expiredAt,
		})
	}
	return t.EditGlobalExpired(ctx, sendData, true)
}