Sfoglia il codice sorgente

feat(waf): 添加带宽限制功能

- 新增 Bandwidth 和 BandwidthResponse 结构体用于处理带宽限制请求和响应
- 在 AoDunService 接口中添加 AddBandwidthLimit 和 DelBandwidthLimit 方法
- 实现 AddBandwidthLimit 和 DelBandwidthLimit 方法,用于添加和删除带宽限制规则
- 在 GlobalLimitService 中集成带宽限制功能,通过 BuildAudunService 进行操作
- 优化任务处理逻辑,使用 goroutine 并发添加带宽限制
- 更新 wire 配
fusu 1 settimana fa
parent
commit
ef2b4b931c

+ 16 - 0
api/v1/aodun.go

@@ -52,4 +52,20 @@ type IpGetData struct {
 	MongoID           string `json:"_id"`     // "_id" 不符合 Go 命名规范,换个名字
 	ID                int    `json:"id"`      // 这是我们最终需要的目标字段
 	LongIntStartIP    int64  `json:"long_int_start_ip"`
+}
+
+type Bandwidth struct {
+	Action        string `json:"action"`
+	ClientIPType  string `json:"client_ip_type"`
+	Direction     string `json:"direction"`
+	Name          string `json:"name"`
+	Protocol      int64  `json:"protocol"`
+	ServerIPStart string `json:"server_ip_start"`
+	ServerIPType  string `json:"server_ip_type"`
+	SpeedlimitOut int64  `json:"speedlimit_out"`
+}
+
+type BandwidthResponse struct {
+	Msg         string `json:"msg"`
+	Err         int `json:"err"`
 }

+ 1 - 0
cmd/server/wire/wire.go

@@ -98,6 +98,7 @@ var serviceSet = wire.NewSet(
 	waf.NewGatewayipService,
 	waf.NewCcIpListService,
 	waf.NewCdnLogService,
+	waf.NewBuildAudunService,
 )
 
 var handlerSet = wire.NewSet(

+ 3 - 2
cmd/server/wire/wire_gen.go

@@ -98,7 +98,8 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	udpForWardingHandler := waf3.NewUdpForWardingHandler(handlerHandler, udpForWardingService)
 	allowAndDenyIpRepository := waf.NewAllowAndDenyIpRepository(repositoryRepository)
 	allowAndDenyIpService := waf2.NewAllowAndDenyIpService(serviceService, allowAndDenyIpRepository, wafFormatterService, gatewayipService)
-	globalLimitService := waf2.NewGlobalLimitService(serviceService, globalLimitRepository, duedateService, crawlerService, viperViper, requiredService, parserService, hostService, hostRepository, cdnService, cdnRepository, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, allowAndDenyIpService, allowAndDenyIpRepository, tcpforwardingService, udpForWardingService, webForwardingService, gatewayipRepository, gatewayipService)
+	buildAudunService := waf2.NewBuildAudunService(serviceService, aoDunService)
+	globalLimitService := waf2.NewGlobalLimitService(serviceService, globalLimitRepository, duedateService, crawlerService, viperViper, requiredService, parserService, hostService, hostRepository, cdnService, cdnRepository, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, allowAndDenyIpService, allowAndDenyIpRepository, tcpforwardingService, udpForWardingService, webForwardingService, gatewayipRepository, gatewayipService, buildAudunService)
 	globalLimitHandler := waf3.NewGlobalLimitHandler(handlerHandler, globalLimitService)
 	adminRepository := admin.NewAdminRepository(repositoryRepository)
 	adminService := admin2.NewAdminService(serviceService, adminRepository)
@@ -122,7 +123,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 
 var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewCasbinEnforcer, repository.NewMongoClient, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, admin.NewAdminRepository, admin.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldPublicIpRepository, waf.NewWebForwardingRepository, waf.NewTcpforwardingRepository, waf.NewUdpForWardingRepository, repository.NewGameShieldUserIpRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldSdkIpRepository, repository.NewHostRepository, waf.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository, flexCdn.NewCdnRepository, waf.NewAllowAndDenyIpRepository, flexCdn.NewProxyRepository, flexCdn.NewCcRepository, repository.NewExpiredRepository, repository.NewLogRepository, waf.NewGatewayipRepository, admin.NewGatewayIpAdminRepository, flexCdn.NewCcIpListRepository)
 
-var serviceSet = wire.NewSet(service.NewService, admin2.NewUserService, admin2.NewGatewayIpAdminService, admin2.NewAdminService, gameShield.NewGameShieldService, service.NewAoDunService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewCrawlerService, waf2.NewWebForwardingService, waf2.NewTcpforwardingService, waf2.NewUdpForWardingService, service.NewGameShieldUserIpService, gameShield.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewHostService, waf2.NewGlobalLimitService, service.NewGatewayGroupService, waf2.NewWafFormatterService, service.NewGateWayGroupIpService, service.NewRequestService, flexCdn2.NewCdnService, waf2.NewAllowAndDenyIpService, flexCdn2.NewProxyService, flexCdn2.NewSslCertService, flexCdn2.NewWebsocketService, waf2.NewCcService, service.NewLogService, waf2.NewGatewayipService, waf2.NewCcIpListService, waf2.NewCdnLogService)
+var serviceSet = wire.NewSet(service.NewService, admin2.NewUserService, admin2.NewGatewayIpAdminService, admin2.NewAdminService, gameShield.NewGameShieldService, service.NewAoDunService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewCrawlerService, waf2.NewWebForwardingService, waf2.NewTcpforwardingService, waf2.NewUdpForWardingService, service.NewGameShieldUserIpService, gameShield.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewHostService, waf2.NewGlobalLimitService, service.NewGatewayGroupService, waf2.NewWafFormatterService, service.NewGateWayGroupIpService, service.NewRequestService, flexCdn2.NewCdnService, waf2.NewAllowAndDenyIpService, flexCdn2.NewProxyService, flexCdn2.NewSslCertService, flexCdn2.NewWebsocketService, waf2.NewCcService, service.NewLogService, waf2.NewGatewayipService, waf2.NewCcIpListService, waf2.NewCdnLogService, waf2.NewBuildAudunService)
 
 var handlerSet = wire.NewSet(handler.NewHandler, admin3.NewUserHandler, admin3.NewAdminHandler, admin3.NewGatewayIpAdminHandler, handler.NewGameShieldHandler, handler.NewGameShieldPublicIpHandler, waf3.NewWebForwardingHandler, waf3.NewTcpforwardingHandler, waf3.NewUdpForWardingHandler, handler.NewGameShieldUserIpHandler, handler.NewGameShieldBackendHandler, handler.NewGameShieldSdkIpHandler, handler.NewHostHandler, waf3.NewGlobalLimitHandler, handler.NewGatewayGroupHandler, handler.NewGateWayGroupIpHandler, waf3.NewAllowAndDenyIpHandler, waf3.NewCcHandler, waf3.NewGatewayipHandler, waf3.NewCcIpListHandler, waf3.NewCdnLogHandler)
 

+ 1 - 0
cmd/task/wire/wire.go

@@ -98,6 +98,7 @@ var serviceSet = wire.NewSet(
 	waf.NewGatewayipService,
 	service.NewLogService,
 	waf.NewCcIpListService,
+	waf.NewBuildAudunService,
 )
 
 // build App

+ 1 - 1
cmd/task/wire/wire_gen.go

@@ -107,7 +107,7 @@ var jobSet = wire.NewSet(job.NewJob, job.NewUserJob, job.NewWhitelistJob)
 
 var serverSet = wire.NewSet(server.NewTaskServer, server.NewJobServer)
 
-var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, gameShield.NewGameShieldService, service.NewCrawlerService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewHostService, gameShield.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewGameShieldUserIpService, waf2.NewWafFormatterService, flexCdn2.NewCdnService, service.NewRequestService, waf2.NewTcpforwardingService, waf2.NewUdpForWardingService, waf2.NewWebForwardingService, flexCdn2.NewProxyService, flexCdn2.NewSslCertService, flexCdn2.NewWebsocketService, waf2.NewCcService, waf2.NewGatewayipService, service.NewLogService, waf2.NewCcIpListService)
+var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, gameShield.NewGameShieldService, service.NewCrawlerService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewHostService, gameShield.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewGameShieldUserIpService, waf2.NewWafFormatterService, flexCdn2.NewCdnService, service.NewRequestService, waf2.NewTcpforwardingService, waf2.NewUdpForWardingService, waf2.NewWebForwardingService, flexCdn2.NewProxyService, flexCdn2.NewSslCertService, flexCdn2.NewWebsocketService, waf2.NewCcService, waf2.NewGatewayipService, service.NewLogService, waf2.NewCcIpListService, waf2.NewBuildAudunService)
 
 // build App
 func newApp(task2 *server.TaskServer,

+ 41 - 0
internal/service/aodun.go

@@ -23,6 +23,8 @@ type AoDunService interface {
 	AddWhiteStaticList(ctx context.Context, isSmall bool, req []v1.IpInfo, color string) error
 	DelWhiteStaticList(ctx context.Context, isSmall bool, id string, color string) error
 	GetWhiteStaticList(ctx context.Context, isSmall bool, ip string,serverIp string, color string) (int, error)
+	AddBandwidthLimit(ctx context.Context, req v1.Bandwidth) error
+	DelBandwidthLimit(ctx context.Context, req v1.Bandwidth) error
 }
 
 // aoDunService 是 AoDunService 接口的实现
@@ -326,3 +328,42 @@ func (s *aoDunService) DomainWhiteList(ctx context.Context, domain, ip, apiType
 
 	return nil
 }
+
+// AddBandwidthLimit 添加带宽限制
+func (s *aoDunService) AddBandwidthLimit(ctx context.Context, req v1.Bandwidth) error {
+	var res v1.BandwidthResponse
+	formData := map[string]interface{}{
+		"server_ip_type": req.ServerIPType,
+		"server_ip_start": req.ServerIPStart,
+		"name": req.Name,
+		"speedlimit_out": req.SpeedlimitOut,
+		"client_ip_type": req.ClientIPType,
+		"action": req.Action,
+		"direction": req.Direction,
+		"protocol": req.Protocol,
+	}
+	err := s.sendAuthenticatedRequest(ctx, true, "v1.0/firewall/add_filter_rule", formData, &res)
+	if err != nil {
+		return err
+	}
+	if res.Err != 0 {
+		return fmt.Errorf("API 错误: code %d, msg '%s'", res.Err, res.Msg)
+	}
+	return nil
+}
+
+// DelBandwidthLimit 删除带宽限制
+func (s *aoDunService) DelBandwidthLimit(ctx context.Context, req v1.Bandwidth) error {
+	var res v1.BandwidthResponse
+	formData := map[string]interface{}{
+		"name": req.Name,
+	}
+	err := s.sendAuthenticatedRequest(ctx, true, "v1.0/firewall/delete_filter_rule", formData, &res)
+	if err != nil {
+		return err
+	}
+	if res.Err != 0 {
+		return fmt.Errorf("API 错误: code %d, msg '%s'", res.Err, res.Msg)
+	}
+	return nil
+}

+ 38 - 9
internal/service/api/waf/globallimit.go

@@ -11,12 +11,14 @@ import (
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
 	"github.com/go-nunu/nunu-layout-advanced/internal/service"
 	"github.com/go-nunu/nunu-layout-advanced/internal/service/api/flexCdn"
+	"github.com/hashicorp/go-multierror"
 	"github.com/mozillazg/go-pinyin"
 	"github.com/spf13/viper"
 	"golang.org/x/sync/errgroup"
 	"gorm.io/gorm"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 )
 
@@ -50,6 +52,7 @@ func NewGlobalLimitService(
 	webForWarding WebForwardingService,
 	gatewayIpRep waf.GatewayipRepository,
 	gatywayIp GatewayipService,
+	bulidAudun BuildAudunService,
 ) GlobalLimitService {
 	return &globalLimitService{
 		Service:               service,
@@ -73,6 +76,7 @@ func NewGlobalLimitService(
 		webForWarding:         webForWarding,
 		gatewayIpRep:             gatewayIpRep,
 		gatewayIp: 				gatywayIp,
+		bulidAudun: bulidAudun,
 	}
 }
 
@@ -98,6 +102,7 @@ type globalLimitService struct {
 	webForWarding         WebForwardingService
 	gatewayIpRep          waf.GatewayipRepository
 	gatewayIp             GatewayipService
+	bulidAudun             BuildAudunService
 }
 
 func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
@@ -275,15 +280,39 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 
 
 
-	// 获取套餐ID
-	//maxProtection := strings.TrimSuffix(require.ConfigMaxProtection, "G")
-	//if maxProtection == "" {
-	//	return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',数字部分为空", require.ConfigMaxProtection)
-	//}
-	//maxProtectionInt, err := strconv.Atoi(maxProtection)
-	//if err != nil {
-	//	return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',无法转换为数字", require.ConfigMaxProtection)
-	//}
+	// 添加带宽限制
+	ip, err := s.gatewayIpRep.GetGatewayipOnlyIpByHostIdAll(ctx, int64(req.HostId))
+	if err != nil {
+		return err
+	}
+	var wg sync.WaitGroup
+	wg.Add(len(ip))
+	var errChan = make(chan error, len(ip))
+	if ip != nil {
+		for _, v := range ip {
+			go func(v string) {
+				defer wg.Done()
+				err := s.bulidAudun.AddBandwidth(ctx, v1.Bandwidth{
+					Name: require.Bps,
+					ServerIPStart: v,
+				})
+				if err != nil {
+					errChan <- err
+				}
+			}(v)
+		}
+
+		wg.Wait()
+		close(errChan)
+		var allErrors error
+		for err := range errChan {
+			allErrors = multierror.Append(allErrors, err)
+		}
+		if allErrors != nil {
+			return allErrors
+		}
+
+	}
 
 
 

+ 3 - 1
internal/task/waf.go

@@ -164,14 +164,16 @@ func (t *wafTask) executeRenewalActions(ctx context.Context, reqs []RenewalReque
 	var allErrors *multierror.Error
 	var wg sync.WaitGroup
 	wg.Add(len(reqs))
-
+	var mu sync.Mutex
 	for _, req := range reqs {
 		go func(r RenewalRequest) {
 			defer wg.Done()
 			// 更新数据库状态
 			err := t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{HostId: r.HostId, ExpiredAt: r.ExpiredAt, State: true})
 			if err != nil {
+				mu.Lock() // 在修改前加锁
 				allErrors = multierror.Append(allErrors, err)
+				mu.Unlock() // 修改后解锁
 				return // 如果DB更新失败,不继续调用CDN API
 			}
 		}(req)