Browse Source

refactor(waf): 重构带宽限制功能

- 新增 BuildAudunService 接口和 buildAudunService 实现类
- 添加 Bandwidth 方法统一处理带宽限制的添加和删除
- 使用 goroutine 并发处理多个 IP 地址的带宽限制- 优化错误处理,使用 multierror.Append 汇总所有错误
- 移除冗余代码,提高代码可读性和维护性
fusu 1 week ago
parent
commit
7d6b0e0c42
3 changed files with 138 additions and 36 deletions
  1. 127 0
      internal/service/api/waf/buildaudun.go
  2. 1 36
      internal/service/api/waf/globallimit.go
  3. 10 0
      internal/task/waf.go

+ 127 - 0
internal/service/api/waf/buildaudun.go

@@ -0,0 +1,127 @@
+package waf
+
+import (
+	"context"
+	"fmt"
+	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
+	wafRep "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
+	"github.com/go-nunu/nunu-layout-advanced/internal/service"
+	"github.com/hashicorp/go-multierror"
+	"strconv"
+	"sync"
+)
+
+type BuildAudunService interface {
+	AddBandwidth(ctx context.Context, req v1.Bandwidth) error
+	DelBandwidth(ctx context.Context, req v1.Bandwidth) error
+	Bandwidth(ctx context.Context,hostId int64, action string) error
+}
+func NewBuildAudunService(
+    service *service.Service,
+	audun   service.AoDunService,
+	gatewayIpRep wafRep.GatewayipRepository,
+	host service.HostService,
+) BuildAudunService {
+	return &buildAudunService{
+		Service:        service,
+		audun:   		audun,
+		gatewayIpRep:   gatewayIpRep,
+		host:           host,
+
+	}
+}
+
+type buildAudunService struct {
+	*service.Service
+	audun service.AoDunService
+	gatewayIpRep wafRep.GatewayipRepository
+	host service.HostService
+}
+
+
+func (s *buildAudunService) BuildName(ip string, bandwidth string, apiName string) string {
+	return apiName + ip + "限速" + bandwidth + "M"
+}
+
+func (s *buildAudunService) AddBandwidth(ctx context.Context, req v1.Bandwidth) error {
+	err := s.audun.AddBandwidthLimit(ctx, v1.Bandwidth{
+		Action:        "limit",
+		ClientIPType:  "all",
+		Direction:     "out",
+		Name:          s.BuildName(req.ServerIPStart, strconv.FormatInt(req.SpeedlimitOut, 10), ""),
+		Protocol:      0,
+		ServerIPStart: req.ServerIPStart,
+		ServerIPType:  "single",
+		SpeedlimitOut: req.SpeedlimitOut,
+	})
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func (s *buildAudunService) DelBandwidth(ctx context.Context, req v1.Bandwidth) error {
+	err := s.audun.DelBandwidthLimit(ctx, v1.Bandwidth{
+		Name: s.BuildName(req.ServerIPStart, req.ServerIPStart, "KFW-API-RESTAPI-"),
+	})
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func (s *buildAudunService) Bandwidth(ctx context.Context,hostId int64, action string) error {
+	ips, err := s.gatewayIpRep.GetGatewayipOnlyIpByHostIdAll(ctx, hostId)
+	if err != nil {
+		return err
+	}
+	if len(ips) == 0 {
+		return nil
+	}
+	config, err := s.host.GetGlobalLimitConfig(ctx, int(hostId))
+	if err != nil {
+		return err
+	}
+	bpsInt, err := strconv.Atoi(config.Bps)
+	if err != nil {
+		return err
+	}
+	var errChan = make(chan error, len(ips))
+	var wg sync.WaitGroup
+	var allErrors error
+	wg.Add(len(ips))
+
+	for _, ip := range ips {
+		go func(ip string) {
+			var e error
+			defer wg.Done()
+			switch action {
+			case "add":
+				e = s.AddBandwidth(ctx, v1.Bandwidth{
+					ServerIPStart: ip,
+					SpeedlimitOut: int64(bpsInt),
+				})
+			case "del":
+				e = s.DelBandwidth(ctx,v1.Bandwidth{
+					Name:          ip,
+					SpeedlimitOut: int64(bpsInt),
+				})
+			default:
+				e = fmt.Errorf("未知操作")
+			}
+			if e != nil {
+				errChan <- fmt.Errorf("清除ip %s失败: %w", ip, e)
+			}
+
+		}(ip)
+	}
+	wg.Wait()
+	close(errChan)
+	for err := range errChan {
+		allErrors = multierror.Append(allErrors, err)
+	}
+	if allErrors != nil {
+		return allErrors
+	}
+	return nil
+}

+ 1 - 36
internal/service/api/waf/globallimit.go

@@ -11,14 +11,12 @@ import (
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
 	"github.com/go-nunu/nunu-layout-advanced/internal/service"
 	"github.com/go-nunu/nunu-layout-advanced/internal/service/api/flexCdn"
-	"github.com/hashicorp/go-multierror"
 	"github.com/mozillazg/go-pinyin"
 	"github.com/spf13/viper"
 	"golang.org/x/sync/errgroup"
 	"gorm.io/gorm"
 	"strconv"
 	"strings"
-	"sync"
 	"time"
 )
 
@@ -276,43 +274,10 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 
 
 	// 添加带宽限制
-	ip, err := s.gatewayIpRep.GetGatewayipOnlyIpByHostIdAll(ctx, int64(req.HostId))
+	err = s.bulidAudun.Bandwidth(ctx, int64(req.HostId), "add")
 	if err != nil {
 		return err
 	}
-	bpsInt, err := strconv.Atoi(require.Bps)
-	if err != nil {
-		return err
-	}
-	var wg sync.WaitGroup
-	wg.Add(len(ip))
-	var errChan = make(chan error, len(ip))
-	if ip != nil {
-		for _, v := range ip {
-			go func(v string) {
-				defer wg.Done()
-				err := s.bulidAudun.AddBandwidth(ctx, v1.Bandwidth{
-					Name: require.Bps,
-					ServerIPStart: v,
-					SpeedlimitOut: int64(bpsInt),
-				})
-				if err != nil {
-					errChan <- err
-				}
-			}(v)
-		}
-
-		wg.Wait()
-		close(errChan)
-		var allErrors error
-		for err := range errChan {
-			allErrors = multierror.Append(allErrors, err)
-		}
-		if allErrors != nil {
-			return allErrors
-		}
-
-	}
 
 
 

+ 10 - 0
internal/task/waf.go

@@ -46,6 +46,7 @@ func NewWafTask(
 	tcp waf.TcpforwardingService,
 	udp waf.UdpForWardingService,
 	web waf.WebForwardingService,
+	buildAoDun waf.BuildAudunService,
 ) WafTask {
 	return &wafTask{
 		Task:              task,
@@ -60,6 +61,7 @@ func NewWafTask(
 		tcp:               tcp,
 		udp:               udp,
 		web:               web,
+		buildAoDun:        buildAoDun,
 	}
 }
 
@@ -76,6 +78,7 @@ type wafTask struct {
 	tcp              waf.TcpforwardingService
 	udp waf.UdpForWardingService
 	web waf.WebForwardingService
+	buildAoDun waf.BuildAudunService
 }
 
 const (
@@ -535,6 +538,13 @@ func (t *wafTask) executeSinglePlanCleanup(ctx context.Context, limit model.Glob
 	}
 
 
+	// 清除小防火墙带宽限制
+	if err := t.buildAoDun.Bandwidth(ctx, int64(limit.HostId), "del"); err != nil {
+		allErrors = multierror.Append(allErrors, err)
+	}
+
+
+
 	// 只有在上述所有步骤都没有出错的情况下,才执行最终的数据库更新和Redis标记
 	if allErrors.ErrorOrNil() == nil {
 		err := t.gatewayIpRep.CleanIPByHostId(ctx, []int64{hostId})