|
- package task
import (
	"context"
	"fmt"
	"sync"
	"time"

	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
	"github.com/go-nunu/nunu-layout-advanced/internal/model"
	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
	"github.com/go-nunu/nunu-layout-advanced/internal/service"
	"github.com/hashicorp/go-multierror"
)
// WafTask is the handle returned by NewWafTask for the WAF expiry job.
//
// NOTE(review): the interface currently declares no methods, so code outside
// this package cannot call any wafTask method through it; consider exposing
// e.g. CheckExpiredTask(ctx context.Context) error once the job is wired up.
type WafTask interface{}
- func NewWafTask (
- webForWardingRep repository.WebForwardingRepository,
- tcpforwardingRep repository.TcpforwardingRepository,
- udpForWardingRep repository.UdpForWardingRepository,
- cdn service.CdnService,
- hostRep repository.HostRepository,
- globalLimitRep repository.GlobalLimitRepository,
- task *Task,
- ) WafTask{
- return &wafTask{
- Task: task,
- webForWardingRep: webForWardingRep,
- tcpforwardingRep: tcpforwardingRep,
- udpForWardingRep: udpForWardingRep,
- cdn: cdn,
- hostRep: hostRep,
- globalLimitRep: globalLimitRep,
- }
- }
- type wafTask struct {
- *Task
- webForWardingRep repository.WebForwardingRepository
- tcpforwardingRep repository.TcpforwardingRepository
- udpForWardingRep repository.UdpForWardingRepository
- cdn service.CdnService
- hostRep repository.HostRepository
- globalLimitRep repository.GlobalLimitRepository
- }
- func (t wafTask) CheckExpiredTask(ctx context.Context) error {
- return nil
- }
- // 获取cdn web id
- func (t wafTask) GetCdnWebId(ctx context.Context,hostId int) ([]int, error) {
- tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, hostId)
- if err != nil {
- return nil, err
- }
- udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, hostId)
- if err != nil {
- return nil, err
- }
- webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, hostId)
- if err != nil {
- return nil, err
- }
- var ids []int
- ids = append(ids, tcpIds...)
- ids = append(ids, udpIds...)
- ids = append(ids, webIds...)
- return ids, nil
- }
- // 启用/禁用 网站
- func (t wafTask) BanServer(ctx context.Context, ids []int, isBan bool) error {
- var wg sync.WaitGroup
- errChan := make(chan error, len(ids))
- // 修正1:为每个 goroutine 增加 WaitGroup 的计数
- wg.Add(len(ids))
- for _, id := range ids {
- go func(id int) {
- // 修正2:确保每个 goroutine 在退出时都调用 Done()
- defer wg.Done()
- err := t.cdn.EditWebIsOn(ctx, int64(id), isBan)
- if err != nil {
- errChan <- err
- // 这里不需要 return,因为 defer wg.Done() 会在函数退出时执行
- }
- }(id)
- }
- // 现在 wg.Wait() 会正确地阻塞,直到所有 goroutine 都调用了 Done()
- wg.Wait()
- // 在所有 goroutine 都结束后,安全地关闭 channel
- close(errChan)
- var result error
- for err := range errChan {
- result = multierror.Append(result, err) // 将多个 error 对象合并成一个单一的 error 对象
- }
- // 修正3:返回收集到的错误,而不是 nil
- return result
- }
- // 获取指定到期时间
- func (t wafTask) GetAlmostExpiring(ctx context.Context,hostIds []int,addTime int64) ([]v1.GetAlmostExpireHostResponse,error) {
- // 3 天
- res, err := t.hostRep.GetAlmostExpired(ctx, hostIds, addTime)
- if err != nil {
- return nil,err
- }
- return res, nil
- }
- // 获取全局到期时间
- func (t wafTask) GetGlobalAlmostExpiring(ctx context.Context,addTime int64) ([]model.GlobalLimit,error) {
- res, err := t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, addTime)
- if err != nil {
- return nil, err
- }
- return res, nil
- }
- // 获取cdn web id
- func (t wafTask) GetGlobalAllHostId(ctx context.Context,addTime int64) (map[int]int64, error) {
- globalData, err := t.GetGlobalAlmostExpiring(ctx,addTime)
- if err != nil {
- return nil, err
- }
- var hostIds []int
- for _, v := range globalData {
- hostIds = append(hostIds, v.HostId)
- }
- globalDataMap := make(map[int]int64, len(globalData))
- planMap := make(map[int]int64, len(globalData))
- for _, v := range globalData {
- globalDataMap[v.HostId] = v.ExpiredAt
- planMap[v.HostId] = int64(v.RuleId)
- }
- hostData,err := t.GetAlmostExpiring(ctx,hostIds,addTime)
- if err != nil {
- return nil, err
- }
- hostDataMap := make(map[int]int64, len(hostData))
- for _, v := range hostData {
- hostDataMap[v.HostId] = v.ExpiredAt
- }
- editMap := make(map[int]int64)
- for k, v := range globalDataMap {
- if hostDataMap[k] != v {
- editMap[k] = v
- }
- }
- planExpireMap := make(map[int]int64)
- for k, v := range planMap {
- if _, ok := editMap[k]; ok {
- planExpireMap[k] = v
- }
- }
- return editMap, nil
- }
- // 修改全局续费
- func (t wafTask) EditGlobalExpired(ctx context.Context,req []struct{
- hostId int
- expiredAt int64
- },state bool) error {
- for _, v := range req {
- err := t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
- HostId: v.hostId,
- ExpiredAt: v.expiredAt,
- State: state,
- })
- if err != nil {
- return err
- }
- }
- return nil
- }
- // 续费套餐
- func (t wafTask) EnablePlan(ctx context.Context,req []struct{
- planId int
- expiredAt int64
- }) error {
- for _, v := range req {
- err := t.cdn.RenewPlan(ctx, v1.RenewalPlan{
- UserPlanId: int64(v.planId),
- IsFree: true,
- DayTo: time.Unix(v.expiredAt,0).Format("2006-01-02"),
- Period: "monthly",
- CountPeriod: 1,
- PeriodDayTo: time.Unix(v.expiredAt,0).Format("2006-01-02"),
- })
- if err != nil {
- return err
- }
- }
- return nil
- }
- // 续费操作
- func (t wafTask) EditExpired(ctx context.Context,req []struct {
- hostId int
- expiredAt int64
- planId int
- }) error {
- var sendData []struct {
- hostId int
- expiredAt int64
- }
- for _, v := range req {
- sendData = append(sendData, struct {
- hostId int
- expiredAt int64
- }{
- hostId: v.hostId,
- expiredAt: v.expiredAt,
- })
- }
- if err := t.EditGlobalExpired(ctx,sendData,true); err != nil {
- return err
- }
- return nil
- }
|