Browse Source

feat(waf): 添加设置防御带宽功能

- 新增 SetDefense 结构体用于设置防御带宽请求
- 在 AoDunService 接口中添加 SetDefense 方法- 实现 SetDefense 方法,通过 HTTP 请求设置防御带宽
- 在 GlobalLimitService 和 WafTask 中集成 SetDefense 功能
- 新增 ZzybgpService接口和实现,用于设置防御带宽
- 更新 wire 配置,注入新的 ZzybgpService 依赖
fusu 1 week ago
parent
commit
81a7ae5a60

+ 5 - 0
api/v1/aodun.go

@@ -68,4 +68,9 @@ type Bandwidth struct {
 type BandwidthResponse struct {
 	Msg         string `json:"msg"`
 	Err         int `json:"err"`
+}
+
+type SetDefense struct {
+	IpAddr string `json:"ip_addr"`
+	Defense int `json:"defense"`
 }

+ 1 - 0
cmd/server/wire/wire.go

@@ -99,6 +99,7 @@ var serviceSet = wire.NewSet(
 	waf.NewCcIpListService,
 	waf.NewCdnLogService,
 	waf.NewBuildAudunService,
+	waf.NewZzybgpService,
 )
 
 var handlerSet = wire.NewSet(

+ 4 - 3
cmd/server/wire/wire_gen.go

@@ -81,7 +81,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	cdnRepository := flexCdn.NewCdnRepository(repositoryRepository)
 	cdnService := flexCdn2.NewCdnService(serviceService, viperViper, requestService, cdnRepository)
 	wafFormatterService := waf2.NewWafFormatterService(serviceService, globalLimitRepository, hostRepository, requiredService, parserService, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, rabbitMQ, hostService, gatewayipRepository, gatewayipService, cdnService, cdnRepository)
-	aoDunService := service.NewAoDunService(serviceService, viperViper)
+	aoDunService := service.NewAoDunService(serviceService, viperViper, requestService)
 	proxyRepository := flexCdn.NewProxyRepository(repositoryRepository)
 	proxyService := flexCdn2.NewProxyService(serviceService, proxyRepository, cdnService)
 	sslCertService := flexCdn2.NewSslCertService(serviceService, webForwardingRepository, cdnService)
@@ -99,7 +99,8 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	allowAndDenyIpRepository := waf.NewAllowAndDenyIpRepository(repositoryRepository)
 	allowAndDenyIpService := waf2.NewAllowAndDenyIpService(serviceService, allowAndDenyIpRepository, wafFormatterService, gatewayipService)
 	buildAudunService := waf2.NewBuildAudunService(serviceService, aoDunService, gatewayipRepository, hostService)
-	globalLimitService := waf2.NewGlobalLimitService(serviceService, globalLimitRepository, duedateService, crawlerService, viperViper, requiredService, parserService, hostService, hostRepository, cdnService, cdnRepository, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, allowAndDenyIpService, allowAndDenyIpRepository, tcpforwardingService, udpForWardingService, webForwardingService, gatewayipRepository, gatewayipService, buildAudunService)
+	zzybgpService := waf2.NewZzybgpService(serviceService, gatewayipRepository, hostService, aoDunService)
+	globalLimitService := waf2.NewGlobalLimitService(serviceService, globalLimitRepository, duedateService, crawlerService, viperViper, requiredService, parserService, hostService, hostRepository, cdnService, cdnRepository, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, allowAndDenyIpService, allowAndDenyIpRepository, tcpforwardingService, udpForWardingService, webForwardingService, gatewayipRepository, gatewayipService, buildAudunService, zzybgpService)
 	globalLimitHandler := waf3.NewGlobalLimitHandler(handlerHandler, globalLimitService)
 	adminRepository := admin.NewAdminRepository(repositoryRepository)
 	adminService := admin2.NewAdminService(serviceService, adminRepository)
@@ -123,7 +124,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 
 var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewCasbinEnforcer, repository.NewMongoClient, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, admin.NewAdminRepository, admin.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldPublicIpRepository, waf.NewWebForwardingRepository, waf.NewTcpforwardingRepository, waf.NewUdpForWardingRepository, repository.NewGameShieldUserIpRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldSdkIpRepository, repository.NewHostRepository, waf.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository, flexCdn.NewCdnRepository, waf.NewAllowAndDenyIpRepository, flexCdn.NewProxyRepository, flexCdn.NewCcRepository, repository.NewExpiredRepository, repository.NewLogRepository, waf.NewGatewayipRepository, admin.NewGatewayIpAdminRepository, flexCdn.NewCcIpListRepository)
 
-var serviceSet = wire.NewSet(service.NewService, admin2.NewUserService, admin2.NewGatewayIpAdminService, admin2.NewAdminService, gameShield.NewGameShieldService, service.NewAoDunService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewCrawlerService, waf2.NewWebForwardingService, waf2.NewTcpforwardingService, waf2.NewUdpForWardingService, service.NewGameShieldUserIpService, gameShield.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewHostService, waf2.NewGlobalLimitService, service.NewGatewayGroupService, waf2.NewWafFormatterService, service.NewGateWayGroupIpService, service.NewRequestService, flexCdn2.NewCdnService, waf2.NewAllowAndDenyIpService, flexCdn2.NewProxyService, flexCdn2.NewSslCertService, flexCdn2.NewWebsocketService, waf2.NewCcService, service.NewLogService, waf2.NewGatewayipService, waf2.NewCcIpListService, waf2.NewCdnLogService, waf2.NewBuildAudunService)
+var serviceSet = wire.NewSet(service.NewService, admin2.NewUserService, admin2.NewGatewayIpAdminService, admin2.NewAdminService, gameShield.NewGameShieldService, service.NewAoDunService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewCrawlerService, waf2.NewWebForwardingService, waf2.NewTcpforwardingService, waf2.NewUdpForWardingService, service.NewGameShieldUserIpService, gameShield.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewHostService, waf2.NewGlobalLimitService, service.NewGatewayGroupService, waf2.NewWafFormatterService, service.NewGateWayGroupIpService, service.NewRequestService, flexCdn2.NewCdnService, waf2.NewAllowAndDenyIpService, flexCdn2.NewProxyService, flexCdn2.NewSslCertService, flexCdn2.NewWebsocketService, waf2.NewCcService, service.NewLogService, waf2.NewGatewayipService, waf2.NewCcIpListService, waf2.NewCdnLogService, waf2.NewBuildAudunService, waf2.NewZzybgpService)
 
 var handlerSet = wire.NewSet(handler.NewHandler, admin3.NewUserHandler, admin3.NewAdminHandler, admin3.NewGatewayIpAdminHandler, handler.NewGameShieldHandler, handler.NewGameShieldPublicIpHandler, waf3.NewWebForwardingHandler, waf3.NewTcpforwardingHandler, waf3.NewUdpForWardingHandler, handler.NewGameShieldUserIpHandler, handler.NewGameShieldBackendHandler, handler.NewGameShieldSdkIpHandler, handler.NewHostHandler, waf3.NewGlobalLimitHandler, handler.NewGatewayGroupHandler, handler.NewGateWayGroupIpHandler, waf3.NewAllowAndDenyIpHandler, waf3.NewCcHandler, waf3.NewGatewayipHandler, waf3.NewCcIpListHandler, waf3.NewCdnLogHandler)
 

+ 1 - 0
cmd/task/wire/wire.go

@@ -99,6 +99,7 @@ var serviceSet = wire.NewSet(
 	service.NewLogService,
 	waf.NewCcIpListService,
 	waf.NewBuildAudunService,
+	waf.NewZzybgpService,
 )
 
 // build App

+ 4 - 3
cmd/task/wire/wire_gen.go

@@ -77,7 +77,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	proxyService := flexCdn2.NewProxyService(serviceService, proxyRepository, cdnService)
 	tcpforwardingService := waf2.NewTcpforwardingService(serviceService, tcpforwardingRepository, parserService, requiredService, crawlerService, globalLimitRepository, hostRepository, wafFormatterService, cdnService, proxyService)
 	udpForWardingService := waf2.NewUdpForWardingService(serviceService, udpForWardingRepository, requiredService, parserService, crawlerService, globalLimitRepository, hostRepository, wafFormatterService, cdnService, proxyService)
-	aoDunService := service.NewAoDunService(serviceService, viperViper)
+	aoDunService := service.NewAoDunService(serviceService, viperViper, requestService)
 	sslCertService := flexCdn2.NewSslCertService(serviceService, webForwardingRepository, cdnService)
 	websocketService := flexCdn2.NewWebsocketService(serviceService, cdnService, webForwardingRepository)
 	ccRepository := flexCdn.NewCcRepository(repositoryRepository)
@@ -86,7 +86,8 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	ccService := waf2.NewCcService(serviceService, ccRepository, webForwardingRepository, cdnService, ccIpListService)
 	webForwardingService := waf2.NewWebForwardingService(serviceService, requiredService, webForwardingRepository, crawlerService, parserService, wafFormatterService, aoDunService, rabbitMQ, gatewayipService, globalLimitRepository, cdnService, proxyService, sslCertService, websocketService, ccService, ccIpListService)
 	buildAudunService := waf2.NewBuildAudunService(serviceService, aoDunService, gatewayipRepository, hostService)
-	wafTask := task.NewWafTask(webForwardingRepository, tcpforwardingRepository, udpForWardingRepository, cdnService, hostRepository, globalLimitRepository, expiredRepository, taskTask, gatewayipRepository, tcpforwardingService, udpForWardingService, webForwardingService, buildAudunService)
+	zzybgpService := waf2.NewZzybgpService(serviceService, gatewayipRepository, hostService, aoDunService)
+	wafTask := task.NewWafTask(webForwardingRepository, tcpforwardingRepository, udpForWardingRepository, cdnService, hostRepository, globalLimitRepository, expiredRepository, taskTask, gatewayipRepository, tcpforwardingService, udpForWardingService, webForwardingService, buildAudunService, zzybgpService)
 	taskServer := server.NewTaskServer(logger, userTask, gameShieldTask, wafTask)
 	jobJob := job.NewJob(transaction, logger, sidSid, rabbitMQ)
 	userJob := job.NewUserJob(jobJob, userRepository)
@@ -108,7 +109,7 @@ var jobSet = wire.NewSet(job.NewJob, job.NewUserJob, job.NewWhitelistJob)
 
 var serverSet = wire.NewSet(server.NewTaskServer, server.NewJobServer)
 
-var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, gameShield.NewGameShieldService, service.NewCrawlerService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewHostService, gameShield.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewGameShieldUserIpService, waf2.NewWafFormatterService, flexCdn2.NewCdnService, service.NewRequestService, waf2.NewTcpforwardingService, waf2.NewUdpForWardingService, waf2.NewWebForwardingService, flexCdn2.NewProxyService, flexCdn2.NewSslCertService, flexCdn2.NewWebsocketService, waf2.NewCcService, waf2.NewGatewayipService, service.NewLogService, waf2.NewCcIpListService, waf2.NewBuildAudunService)
+var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, gameShield.NewGameShieldService, service.NewCrawlerService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewHostService, gameShield.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewGameShieldUserIpService, waf2.NewWafFormatterService, flexCdn2.NewCdnService, service.NewRequestService, waf2.NewTcpforwardingService, waf2.NewUdpForWardingService, waf2.NewWebForwardingService, flexCdn2.NewProxyService, flexCdn2.NewSslCertService, flexCdn2.NewWebsocketService, waf2.NewCcService, waf2.NewGatewayipService, service.NewLogService, waf2.NewCcIpListService, waf2.NewBuildAudunService, waf2.NewZzybgpService)
 
 // build App
 func newApp(task2 *server.TaskServer,

+ 40 - 1
internal/service/aodun.go

@@ -19,12 +19,20 @@ import (
 
 // AoDunService 定义了与傲盾 API 交互的服务接口
 type AoDunService interface {
+	// 添加域名到白名单
 	DomainWhiteList(ctx context.Context, domain string, ip string, apiType string) error
+	// 添加 IP 到静态白名单
 	AddWhiteStaticList(ctx context.Context, isSmall bool, req []v1.IpInfo, color string) error
+	// 根据 ID 从白名单中删除 IP
 	DelWhiteStaticList(ctx context.Context, isSmall bool, id string, color string) error
+	// 查询白名单 IP
 	GetWhiteStaticList(ctx context.Context, isSmall bool, ip string,serverIp string, color string) (int, error)
+	// 添加带宽限制
 	AddBandwidthLimit(ctx context.Context, req v1.Bandwidth) error
+	// 删除带宽限制
 	DelBandwidthLimit(ctx context.Context, req v1.Bandwidth) error
+	// 设置防御带宽
+	SetDefense(ctx context.Context, req v1.SetDefense) error
 }
 
 // aoDunService 是 AoDunService 接口的实现
@@ -32,6 +40,7 @@ type aoDunService struct {
 	*Service
 	cfg        *aoDunConfig
 	httpClient *http.Client
+	request    RequestService
 }
 
 // aoDunConfig 用于整合来自 viper 的所有配置
@@ -47,7 +56,11 @@ type aoDunConfig struct {
 }
 
 // NewAoDunService 创建一个新的 AoDunService 实例
-func NewAoDunService(service *Service, conf *viper.Viper) AoDunService {
+func NewAoDunService(
+	service *Service,
+	conf *viper.Viper,
+	request RequestService,
+	) AoDunService {
 	cfg := &aoDunConfig{
 		Url:            conf.GetString("aodun.Url"),
 		ClientID:       conf.GetString("aodun.clientID"),
@@ -75,6 +88,7 @@ func NewAoDunService(service *Service, conf *viper.Viper) AoDunService {
 		Service:    service,
 		cfg:        cfg,
 		httpClient: client,
+		request:    request,
 	}
 }
 
@@ -372,4 +386,29 @@ func (s *aoDunService) DelBandwidthLimit(ctx context.Context, req v1.Bandwidth)
 		return fmt.Errorf("API 错误: code %d, msg '%s'", res.Err, res.Msg)
 	}
 	return nil
+}
+
+// 设置防御带宽
+func (s *aoDunService) SetDefense(ctx context.Context, req v1.SetDefense) error {
+	formData := map[string]interface{}{
+		"ip_addr": req.IpAddr,
+		"defense": req.Defense,
+		"username":  s.cfg.DomainUsername,
+		"password":  s.cfg.DomainPassword,
+	}
+	resBody, err := s.request.Request(ctx,formData, "http://zapi.zzybgp.com/api/set_defense", "", "")
+	if err != nil {
+		return err
+	}
+	var res struct {
+		Code int `json:"code"`
+		Msg string `json:"msg"`
+	}
+	if err := json.Unmarshal(resBody, &res); err != nil {
+		return fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
+	}
+	if res.Code != 200 {
+		return fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Msg)
+	}
+	return nil
 }

+ 1 - 1
internal/service/api/waf/buildaudun.go

@@ -110,7 +110,7 @@ func (s *buildAudunService) Bandwidth(ctx context.Context,hostId int64, action s
 				e = fmt.Errorf("未知操作")
 			}
 			if e != nil {
-				errChan <- fmt.Errorf("清除ip %s失败: %w", ip, e)
+				errChan <- fmt.Errorf("设置限速 %s失败: %w", ip, e)
 			}
 
 		}(ip)

+ 14 - 1
internal/service/api/waf/globallimit.go

@@ -51,6 +51,7 @@ func NewGlobalLimitService(
 	gatewayIpRep waf.GatewayipRepository,
 	gatywayIp GatewayipService,
 	bulidAudun BuildAudunService,
+	zzyBgp ZzybgpService,
 ) GlobalLimitService {
 	return &globalLimitService{
 		Service:               service,
@@ -75,6 +76,7 @@ func NewGlobalLimitService(
 		gatewayIpRep:             gatewayIpRep,
 		gatewayIp: 				gatywayIp,
 		bulidAudun: bulidAudun,
+		zzyBgp: zzyBgp,
 	}
 }
 
@@ -101,6 +103,7 @@ type globalLimitService struct {
 	gatewayIpRep          waf.GatewayipRepository
 	gatewayIp             GatewayipService
 	bulidAudun             BuildAudunService
+	zzyBgp                ZzybgpService
 }
 
 func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
@@ -271,7 +274,11 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 	}
 
 
-
+	// 添加防护
+	err = s.zzyBgp.SetDefense(ctx, int64(req.HostId), 0)
+	if err != nil {
+		return err
+	}
 
 	// 添加带宽限制
 	err = s.bulidAudun.Bandwidth(ctx, int64(req.HostId), "add")
@@ -400,6 +407,12 @@ func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.Globa
 		return err
 	}
 
+	// 重置防护
+	err = s.zzyBgp.SetDefense(ctx, int64(req.HostId), 10)
+	if err != nil {
+		return err
+	}
+
 	// 删除带宽限制
 	err = s.bulidAudun.Bandwidth(ctx, int64(req.HostId), "del")
 	if err != nil {

+ 117 - 0
internal/service/api/waf/zzybgp.go

@@ -0,0 +1,117 @@
+package waf
+
+import (
+	"context"
+	"fmt"
+	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
+	wafRep "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
+	"github.com/go-nunu/nunu-layout-advanced/internal/service"
+	"strconv"
+	"strings"
+	"sync"
+)
+
+type ZzybgpService interface {
+	SetDefense(ctx context.Context, hostId int64,defense int) error
+}
+func NewZzybgpService(
+    service *service.Service,
+	gatewayIpRep wafRep.GatewayipRepository,
+	host service.HostService,
+	aoDun service.AoDunService,
+) ZzybgpService {
+	return &zzybgpService{
+		Service:        service,
+		gatewayIpRep:   gatewayIpRep,
+		host: host,
+		aoDun: aoDun,
+	}
+}
+
+type zzybgpService struct {
+	*service.Service
+	gatewayIpRep wafRep.GatewayipRepository
+	host service.HostService
+	aoDun service.AoDunService
+}
+
+func (s *zzybgpService) SetDefense(ctx context.Context, hostId int64, defense int) error {
+
+	ips, err := s.gatewayIpRep.GetGatewayipOnlyIpByHostIdAll(ctx, hostId)
+	if err != nil {
+		return fmt.Errorf("通过hostId获取IP列表失败: %w", err)
+	}
+	if len(ips) == 0 {
+		return nil
+	}
+
+	// 2.【修复BUG】使用一个明确的变量来存储最终生效的防御值
+	effectiveDefense := defense
+
+	// 如果传入的防御值为0,则从全局配置中获取
+	if effectiveDefense == 0 {
+		config, err := s.host.GetGlobalLimitConfig(ctx, int(hostId))
+		if err != nil {
+			return fmt.Errorf("获取全局限制配置失败: %w", err)
+		}
+
+		// 从配置中解析防御值
+		defenseStr := strings.TrimSuffix(config.ConfigMaxProtection, "G")
+		defenseInt, err := strconv.Atoi(defenseStr)
+		if err != nil {
+			// 如果解析失败,返回包含原始值的错误信息,便于排查
+			return fmt.Errorf("解析防御配置 '%s' 失败: %w", config.ConfigMaxProtection, err)
+		}
+
+		// 如果配置的防御值也为0,则无需操作
+		if defenseInt == 0 {
+			return nil
+		}
+		// 将从配置中获取的值赋给 effectiveDefense
+		effectiveDefense = defenseInt
+	}
+
+	var wg sync.WaitGroup
+	// 创建一个足够大的 channel 来收集所有可能发生的错误
+	errChan := make(chan error, len(ips))
+
+	wg.Add(len(ips))
+	for _, ip := range ips {
+
+		go func(ipAddr string) {
+			defer wg.Done()
+
+			e := s.aoDun.SetDefense(ctx, v1.SetDefense{
+				IpAddr:  ipAddr,
+				Defense: effectiveDefense,
+			})
+
+			if e != nil {
+				// 3.【优化】为错误添加IP地址信息,便于定位问题
+				errChan <- fmt.Errorf("IP [%s] 设置防御带宽失败: %w", ipAddr, e)
+			}
+		}(ip)
+	}
+
+	wg.Wait()
+
+	close(errChan)
+
+
+	var allErrors []error
+	for e := range errChan {
+		allErrors = append(allErrors, e)
+	}
+
+
+	if len(allErrors) > 0 {
+		// 将多个错误信息拼接成一个字符串
+		var errStrings []string
+		for _, singleErr := range allErrors {
+			errStrings = append(errStrings, singleErr.Error())
+		}
+		return fmt.Errorf("设置防御时发生多个错误: %s", strings.Join(errStrings, "; "))
+	}
+
+	return nil
+}

+ 12 - 1
internal/task/waf.go

@@ -47,6 +47,7 @@ func NewWafTask(
 	udp waf.UdpForWardingService,
 	web waf.WebForwardingService,
 	buildAoDun waf.BuildAudunService,
+	zzyBgp waf.ZzybgpService,
 ) WafTask {
 	return &wafTask{
 		Task:              task,
@@ -62,6 +63,7 @@ func NewWafTask(
 		udp:               udp,
 		web:               web,
 		buildAoDun:        buildAoDun,
+		zzyBgp :           zzyBgp,
 	}
 }
 
@@ -79,6 +81,7 @@ type wafTask struct {
 	udp waf.UdpForWardingService
 	web waf.WebForwardingService
 	buildAoDun waf.BuildAudunService
+	zzyBgp waf.ZzybgpService
 }
 
 const (
@@ -538,13 +541,21 @@ func (t *wafTask) executeSinglePlanCleanup(ctx context.Context, limit model.Glob
 	}
 
 
+	// 重置防护
+	err = t.zzyBgp.SetDefense(ctx, hostId, 10)
+	if err != nil {
+		return err
+	}
+
 	// 清除小防火墙带宽限制
-	if err := t.buildAoDun.Bandwidth(ctx, int64(limit.HostId), "del"); err != nil {
+	if err := t.buildAoDun.Bandwidth(ctx, hostId, "del"); err != nil {
 		allErrors = multierror.Append(allErrors, err)
 	}
 
 
 
+
+
 	// 只有在上述所有步骤都没有出错的情况下,才执行最终的数据库更新和Redis标记
 	if allErrors.ErrorOrNil() == nil {
 		err := t.gatewayIpRep.CleanIPByHostId(ctx, []int64{hostId})