浏览代码

feat(task): 添加 WAF 任务相关功能

- 新增 GetAlmostExpired 方法获取即将到期的主机信息
- 实现 CheckExpiredTask 方法检查过期任务
- 添加 GetCdnWebId 方法获取 CDN Web ID
- 实现 BanServer 方法禁用网站
- 优化错误处理和并发执行逻辑
fusu 4 周之前
父节点
当前提交
20c75af7aa
共有 3 个文件被更改,包括 114 次插入1 次删除
  1. 5 0
      api/v1/host.go
  2. 16 0
      internal/repository/host.go
  3. 93 1
      internal/task/waf.go

+ 5 - 0
api/v1/host.go

@@ -35,3 +35,8 @@ type GlobalLimitConfigResponse struct {
 	IsBanUdp      int
 	IsBanOverseas int
 }
+
+type GetAlmostExpireHostResponse struct {
+	HostId   int
+	ExpiredAt int64
+}

+ 16 - 0
internal/repository/host.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
+	"time"
 )
 
 type HostRepository interface {
@@ -14,6 +15,8 @@ type HostRepository interface {
 	GetDomainById(ctx context.Context, id int) (string, error)
 	// 获取到期时间
 	GetExpireTime(ctx context.Context, uid int64, hostId int64) (string, error)
+	// 获取指定到期时间
+	GetAlmostExpired(ctx context.Context, hostId []int,addTime int64) ([]v1.GetAlmostExpireHostResponse, error)
 }
 
 func NewHostRepository(
@@ -82,4 +85,17 @@ func (r *hostRepository) GetExpireTime(ctx context.Context, uid int64, hostId in
 	}
 
 	return nextDueDate, nil
+}
+
+// 获取指定到期时间
+func (r *hostRepository) GetAlmostExpired(ctx context.Context, hostId []int,addTime int64) ([]v1.GetAlmostExpireHostResponse, error) {
+	var res []v1.GetAlmostExpireHostResponse
+	expiredTime := time.Now().Unix() + addTime
+	if err := r.DB(ctx).Table("shd_host").
+		Where("id IN ?", hostId).
+		Where("nextduedate < ?", expiredTime).
+		Find(&res).Error; err != nil {
+		return nil, err
+	}
+	return res, nil
 }

+ 93 - 1
internal/task/waf.go

@@ -1,22 +1,114 @@
 package task
 
-import "context"
+import (
+	"context"
+	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
+	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
+	"github.com/go-nunu/nunu-layout-advanced/internal/service"
+	"github.com/hashicorp/go-multierror"
+	"sync"
+)
 
 type WafTask interface {
 }
 
 func NewWafTask (
+	webForWardingRep repository.WebForwardingRepository,
+	tcpforwardingRep repository.TcpforwardingRepository,
+	udpForWardingRep repository.UdpForWardingRepository,
+	cdn service.CdnService,
+	hostRep repository.HostRepository,
 	task *Task,
 	) WafTask{
 	return &wafTask{
 		Task: task,
+		webForWardingRep: webForWardingRep,
+		tcpforwardingRep: tcpforwardingRep,
+		udpForWardingRep: udpForWardingRep,
+		cdn: cdn,
+		hostRep: hostRep,
 	}
 }
 type wafTask struct {
 	*Task
+	webForWardingRep repository.WebForwardingRepository
+	tcpforwardingRep repository.TcpforwardingRepository
+	udpForWardingRep repository.UdpForWardingRepository
+	cdn service.CdnService
+	hostRep repository.HostRepository
 }
 
 func (t wafTask) CheckExpiredTask(ctx context.Context) error {
 	return nil
 
+}
+
+// 获取cdn web id
+func (t wafTask) GetCdnWebId(ctx context.Context,hostId int) ([]int, error) {
+	tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, hostId)
+	if err != nil {
+		return nil, err
+	}
+	udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, hostId)
+	if err != nil {
+		return nil, err
+	}
+	webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, hostId)
+	if err != nil {
+		return nil, err
+	}
+	var ids []int
+	ids = append(ids, tcpIds...)
+	ids = append(ids, udpIds...)
+	ids = append(ids, webIds...)
+	return ids, nil
+}
+
+// 禁用网站
+func (t wafTask) BanServer(ctx context.Context, ids []int, isBan bool) error {
+	var wg sync.WaitGroup
+	errChan := make(chan error, len(ids))
+
+	// 修正1:为每个 goroutine 增加 WaitGroup 的计数
+	wg.Add(len(ids))
+
+	for _, id := range ids {
+		go func(id int) {
+			// 修正2:确保每个 goroutine 在退出时都调用 Done()
+			defer wg.Done()
+
+			err := t.cdn.EditWebIsOn(ctx, int64(id), isBan)
+			if err != nil {
+				errChan <- err
+				// 这里不需要 return,因为 defer wg.Done() 会在函数退出时执行
+			}
+		}(id)
+	}
+
+	// 现在 wg.Wait() 会正确地阻塞,直到所有 goroutine 都调用了 Done()
+	wg.Wait()
+
+	// 在所有 goroutine 都结束后,安全地关闭 channel
+	close(errChan)
+
+	var result error
+	for err := range errChan {
+		result = multierror.Append(result, err)  // 将多个 error 对象合并成一个单一的 error 对象
+	}
+
+	// 修正3:返回收集到的错误,而不是 nil
+	return result
+}
+
+
+
+// 获取到期时间
+func (t wafTask) GetAlmostExpiring(ctx context.Context,hostIds []int) ([]v1.GetAlmostExpireHostResponse,error) {
+	// 3 天
+	res, err := t.hostRep.GetAlmostExpired(ctx, hostIds, 259200)
+	if err != nil {
+		return nil,err
+	}
+
+	return res, nil
 }