|
@@ -1,22 +1,114 @@
|
|
package task
|
|
package task
|
|
|
|
|
|
-import "context"
|
|
|
|
|
|
+import (
|
|
|
|
+ "context"
|
|
|
|
+ v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
|
|
|
|
+ "github.com/go-nunu/nunu-layout-advanced/internal/repository"
|
|
|
|
+ "github.com/go-nunu/nunu-layout-advanced/internal/service"
|
|
|
|
+ "github.com/hashicorp/go-multierror"
|
|
|
|
+ "sync"
|
|
|
|
+)
|
|
|
|
|
|
type WafTask interface {
|
|
type WafTask interface {
|
|
}
|
|
}
|
|
|
|
|
|
func NewWafTask (
|
|
func NewWafTask (
|
|
|
|
+ webForWardingRep repository.WebForwardingRepository,
|
|
|
|
+ tcpforwardingRep repository.TcpforwardingRepository,
|
|
|
|
+ udpForWardingRep repository.UdpForWardingRepository,
|
|
|
|
+ cdn service.CdnService,
|
|
|
|
+ hostRep repository.HostRepository,
|
|
task *Task,
|
|
task *Task,
|
|
) WafTask{
|
|
) WafTask{
|
|
return &wafTask{
|
|
return &wafTask{
|
|
Task: task,
|
|
Task: task,
|
|
|
|
+ webForWardingRep: webForWardingRep,
|
|
|
|
+ tcpforwardingRep: tcpforwardingRep,
|
|
|
|
+ udpForWardingRep: udpForWardingRep,
|
|
|
|
+ cdn: cdn,
|
|
|
|
+ hostRep: hostRep,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
type wafTask struct {
|
|
type wafTask struct {
|
|
*Task
|
|
*Task
|
|
|
|
+ webForWardingRep repository.WebForwardingRepository
|
|
|
|
+ tcpforwardingRep repository.TcpforwardingRepository
|
|
|
|
+ udpForWardingRep repository.UdpForWardingRepository
|
|
|
|
+ cdn service.CdnService
|
|
|
|
+ hostRep repository.HostRepository
|
|
}
|
|
}
|
|
|
|
|
|
func (t wafTask) CheckExpiredTask(ctx context.Context) error {
|
|
func (t wafTask) CheckExpiredTask(ctx context.Context) error {
|
|
return nil
|
|
return nil
|
|
|
|
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 获取cdn web id
|
|
|
|
+func (t wafTask) GetCdnWebId(ctx context.Context,hostId int) ([]int, error) {
|
|
|
|
+ tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, hostId)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return nil, err
|
|
|
|
+ }
|
|
|
|
+ udpIds, err := t.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, hostId)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return nil, err
|
|
|
|
+ }
|
|
|
|
+ webIds, err := t.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, hostId)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return nil, err
|
|
|
|
+ }
|
|
|
|
+ var ids []int
|
|
|
|
+ ids = append(ids, tcpIds...)
|
|
|
|
+ ids = append(ids, udpIds...)
|
|
|
|
+ ids = append(ids, webIds...)
|
|
|
|
+ return ids, nil
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 禁用网站
|
|
|
|
+func (t wafTask) BanServer(ctx context.Context, ids []int, isBan bool) error {
|
|
|
|
+ var wg sync.WaitGroup
|
|
|
|
+ errChan := make(chan error, len(ids))
|
|
|
|
+
|
|
|
|
+ // 修正1:为每个 goroutine 增加 WaitGroup 的计数
|
|
|
|
+ wg.Add(len(ids))
|
|
|
|
+
|
|
|
|
+ for _, id := range ids {
|
|
|
|
+ go func(id int) {
|
|
|
|
+ // 修正2:确保每个 goroutine 在退出时都调用 Done()
|
|
|
|
+ defer wg.Done()
|
|
|
|
+
|
|
|
|
+ err := t.cdn.EditWebIsOn(ctx, int64(id), isBan)
|
|
|
|
+ if err != nil {
|
|
|
|
+ errChan <- err
|
|
|
|
+ // 这里不需要 return,因为 defer wg.Done() 会在函数退出时执行
|
|
|
|
+ }
|
|
|
|
+ }(id)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // 现在 wg.Wait() 会正确地阻塞,直到所有 goroutine 都调用了 Done()
|
|
|
|
+ wg.Wait()
|
|
|
|
+
|
|
|
|
+ // 在所有 goroutine 都结束后,安全地关闭 channel
|
|
|
|
+ close(errChan)
|
|
|
|
+
|
|
|
|
+ var result error
|
|
|
|
+ for err := range errChan {
|
|
|
|
+ result = multierror.Append(result, err) // 将多个 error 对象合并成一个单一的 error 对象
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // 修正3:返回收集到的错误,而不是 nil
|
|
|
|
+ return result
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+// 获取到期时间
|
|
|
|
+func (t wafTask) GetAlmostExpiring(ctx context.Context,hostIds []int) ([]v1.GetAlmostExpireHostResponse,error) {
|
|
|
|
+ // 3 天
|
|
|
|
+ res, err := t.hostRep.GetAlmostExpired(ctx, hostIds, 259200)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return nil,err
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return res, nil
|
|
}
|
|
}
|