Просмотр исходного кода

refactor(waf): 重构全球限制功能并移除冗余代码

- 移除了 GlobalLimitExpired 结构中的 RuleId 字段- 删除了多个与 TCP、UDP 和 Web 限制相关的未使用结构和函数
- 简化了全球限制相关的逻辑,移除了与 CDN 套餐绑定和续费相关的代码
- 新增 GetNodeClusterId 方法以获取节点集群 ID
- 更新了 WafFormatterService 以使用新的节点集群 ID 逻辑
fusu 2 недель назад
Родитель
Сommit
aba976ce31

+ 0 - 1
api/v1/globalLimit.go

@@ -48,7 +48,6 @@ type AccessRuleRules struct {
 
 type GlobalLimitExpired struct {
 	HostId int `json:"hostId" form:"hostId" gorm:"column:host_id"`
-	RuleId int `json:"ruleId" form:"ruleId" gorm:"column:rule_id"`
 	Comment string `json:"comment" form:"comment" gorm:"column:comment"`
 }
 

+ 0 - 21
api/v1/tcpLimit.go

@@ -1,21 +0,0 @@
-package v1
-
-type TcpLimitRequest struct {
-	ConnCount    int    `json:"conn_count" form:"conn_count" default:"0"`
-	ConnDuration string `json:"conn_duration" form:"conn_duration"  default:"0s"`
-	MaxConnCount int    `json:"max_conn_count" form:"max_conn_count" default:"0"`
-}
-
-type TcpLimitDeleteRequest struct {
-	WafTcpLimitId int `json:"waf_tcp_limit_id" form:"waf_tcp_limit_id"`
-}
-
-type TcpLimitSendRequest struct {
-	WafTcpLimitId int    `json:"waf_tcp_limit_id" form:"waf_tcp_limit_id"`
-	Tag           string `json:"tag" form:"tag" binding:"required"`
-	ConnCount     int    `json:"conn_count" form:"conn_count" default:"0"`
-	ConnDuration  string `json:"conn_duration" form:"conn_duration"  default:"0s"`
-	MaxConnCount  int    `json:"max_conn_count" form:"max_conn_count" default:"0"`
-	RuleId        int    `json:"waf_common_limit_id" form:"waf_common_limit_id"`
-	Comment       string `form:"comment" json:"comment"`
-}

+ 0 - 21
api/v1/udpLimit.go

@@ -1,21 +0,0 @@
-package v1
-
-type UdpLimitRequest struct {
-	QosPacketCount    int    `form:"qos_packet_count" json:"qos_packet_count"`
-	QosPacketDuration string `form:"qos_packet_duration" json:"qos_packet_duration" default:"0s"`
-	MaxConnCount      int    `form:"max_conn_count" json:"max_conn_count"`
-}
-
-type UdpLimitDeleteRequest struct {
-	WafUdpLimitId int `json:"waf_udp_limit_id" form:"waf_udp_limit_id"`
-}
-
-type UdpLimitSendRequest struct {
-	WafUdpLimitId     int    `json:"waf_udp_limit_id" form:"waf_udp_limit_id"`
-	Tag               string `json:"tag" form:"tag" binding:"required"`
-	QosPacketCount    int    `form:"qos_packet_count" json:"qos_packet_count" default:"0"`
-	QosPacketDuration string `form:"qos_packet_duration" json:"qos_packet_duration" default:"0s"`
-	MaxConnCount      int    `form:"max_conn_count" json:"max_conn_count" default:"0"`
-	RuleId            int    `json:"waf_common_limit_id" form:"waf_common_limit_id"`
-	Comment           string `form:"comment" json:"comment"`
-}

+ 0 - 19
api/v1/webLimit.go

@@ -1,19 +0,0 @@
-package v1
-
-type WebLimitRequest struct {
-	QpsCount    int    `json:"qps_count" form:"qps_count" default:"0"`
-	QpsDuration string `json:"qps_duration" form:"qps_duration"  default:"0s"`
-}
-
-type WebLimitDeleteRequest struct {
-	WafWebLimitId int `json:"waf_web_limit_id" form:"waf_web_limit_id"`
-}
-
-type WebLimitSendRequest struct {
-	WafWebLimitId int    `json:"waf_web_limit_id" form:"waf_web_limit_id"`
-	Tag           string `json:"tag" form:"tag" binding:"required"`
-	QpsCount      int    `json:"qps_count" form:"qps_count" default:"0"`
-	QpsDuration   string `json:"qps_duration" form:"qps_duration"  default:"0s"`
-	RuleId        int    `json:"waf_common_limit_id" form:"waf_common_limit_id"`
-	Comment       string `form:"comment" json:"comment"`
-}

+ 1 - 1
cmd/server/wire/wire_gen.go

@@ -80,7 +80,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	requestService := service.NewRequestService(serviceService)
 	cdnRepository := flexCdn.NewCdnRepository(repositoryRepository)
 	cdnService := flexCdn2.NewCdnService(serviceService, viperViper, requestService, cdnRepository)
-	wafFormatterService := waf2.NewWafFormatterService(serviceService, globalLimitRepository, hostRepository, requiredService, parserService, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, rabbitMQ, hostService, gatewayipRepository, gatewayipService, cdnService)
+	wafFormatterService := waf2.NewWafFormatterService(serviceService, globalLimitRepository, hostRepository, requiredService, parserService, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, rabbitMQ, hostService, gatewayipRepository, gatewayipService, cdnService, cdnRepository)
 	aoDunService := service.NewAoDunService(serviceService, viperViper)
 	proxyRepository := flexCdn.NewProxyRepository(repositoryRepository)
 	proxyService := flexCdn2.NewProxyService(serviceService, proxyRepository, cdnService)

+ 1 - 1
cmd/task/wire/wire_gen.go

@@ -72,7 +72,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	logRepository := repository.NewLogRepository(repositoryRepository)
 	logService := service.NewLogService(serviceService, logRepository)
 	gatewayipService := waf2.NewGatewayipService(serviceService, gatewayipRepository, hostService, logService)
-	wafFormatterService := waf2.NewWafFormatterService(serviceService, globalLimitRepository, hostRepository, requiredService, parserService, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, rabbitMQ, hostService, gatewayipRepository, gatewayipService, cdnService)
+	wafFormatterService := waf2.NewWafFormatterService(serviceService, globalLimitRepository, hostRepository, requiredService, parserService, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, rabbitMQ, hostService, gatewayipRepository, gatewayipService, cdnService, cdnRepository)
 	proxyRepository := flexCdn.NewProxyRepository(repositoryRepository)
 	proxyService := flexCdn2.NewProxyService(serviceService, proxyRepository, cdnService)
 	tcpforwardingService := waf2.NewTcpforwardingService(serviceService, tcpforwardingRepository, parserService, requiredService, crawlerService, globalLimitRepository, hostRepository, wafFormatterService, cdnService, proxyService)

+ 0 - 1
internal/model/globallimit.go

@@ -6,7 +6,6 @@ type GlobalLimit struct {
 	Id              int `gorm:"primary"`
 	HostId          int
 	Name            string
-	RuleId          int
 	GroupId         int
 	Uid             int
 	CdnUid          int

+ 10 - 1
internal/repository/api/flexCdn/cdn.go

@@ -13,6 +13,7 @@ type CdnRepository interface {
 	PutToken(ctx context.Context, token string) error
 	GetToken(ctx context.Context) (string, error)
 	GetUserId(ctx context.Context, username string) (int64, error)
+	GetNodeClusterId(ctx context.Context, name string) (int64, error)
 }
 
 func NewCdnRepository(
@@ -71,4 +72,12 @@ func (r *cdnRepository) GetUserId(ctx context.Context, username string) (int64,
 	}
 	return id, nil
 
-}
+}
+
+func (r *cdnRepository) GetNodeClusterId(ctx context.Context, name string) (int64, error) {
+	var id int64
+	return id, r.DBWithName(ctx,"cdn").Table("cloud_node_clusters").
+		Where("name = ?", name).
+		Select("id").
+		Find(&id).Error
+}

+ 11 - 77
internal/service/api/waf/globallimit.go

@@ -13,7 +13,6 @@ import (
 	"github.com/go-nunu/nunu-layout-advanced/internal/service/api/flexCdn"
 	"github.com/mozillazg/go-pinyin"
 	"github.com/spf13/viper"
-	"go.uber.org/zap"
 	"golang.org/x/sync/errgroup"
 	"gorm.io/gorm"
 	"strconv"
@@ -274,57 +273,19 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 	}
 
 
-	outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
-	if err != nil {
-		return err
-	}
+
 
 	// 获取套餐ID
-	maxProtection := strings.TrimSuffix(require.ConfigMaxProtection, "G")
-	if maxProtection == "" {
-		return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',数字部分为空", require.ConfigMaxProtection)
-	}
-	maxProtectionInt, err := strconv.Atoi(maxProtection)
-	if err != nil {
-		return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',无法转换为数字", require.ConfigMaxProtection)
-	}
-	var planId int64
-	maxProtectionNum := 1
-	if maxProtectionInt >= 2000 {
-		maxProtectionNum = maxProtectionInt / 1000
-	}
-	NodeAreaName := fmt.Sprintf("%s-%dT",require.NodeArea, maxProtectionNum)
-	planId, err = s.globalLimitRepository.GetNodeArea(ctx, NodeAreaName)
-	if err != nil {
-		if errors.Is(err, gorm.ErrRecordNotFound) {
-			planId = 0
-		}else {
-			return err
-		}
-	}
-	if planId == 0 {
-		// 安全冗余套餐
-		planId = 6
-		s.Logger.Warn("获取套餐Id失败",  zap.String("节点区域", NodeAreaName), zap.String("防御阈值", require.ConfigMaxProtection),zap.Int64("套餐Id", int64(req.Uid)),zap.Int64("魔方套餐Id", int64(req.HostId)))
-	}
+	//maxProtection := strings.TrimSuffix(require.ConfigMaxProtection, "G")
+	//if maxProtection == "" {
+	//	return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',数字部分为空", require.ConfigMaxProtection)
+	//}
+	//maxProtectionInt, err := strconv.Atoi(maxProtection)
+	//if err != nil {
+	//	return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',无法转换为数字", require.ConfigMaxProtection)
+	//}
 
 
-	ruleId, err := s.cdnService.BindPlan(ctx, v1.Plan{
-		UserId:    userId,
-		PlanId:    planId,
-		DayTo:     outputTimeStr,
-		Name:      require.GlobalLimitName,
-		IsFree:    true,
-		Period:    "monthly",
-		CountPeriod: 1,
-		PeriodDayTo: outputTimeStr,
-	})
-	if err != nil {
-		return err
-	}
-	if ruleId == 0 {
-		return fmt.Errorf("分配套餐失败")
-	}
 
 
 
@@ -345,7 +306,6 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 			HostId:         req.HostId,
 			Uid:            req.Uid,
 			Name:           require.GlobalLimitName,
-			RuleId: 		int(ruleId),
 			GroupId:        int(groupId),
 			CdnUid:         int(userId),
 			Comment:        req.Comment,
@@ -365,7 +325,6 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 		HostId:         req.HostId,
 		Uid:            req.Uid,
 		Name:           require.GlobalLimitName,
-		RuleId: 		int(ruleId),
 		GroupId:        int(groupId),
 		CdnUid:         int(userId),
 		Comment:        req.Comment,
@@ -384,10 +343,7 @@ func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalL
 	if err != nil {
 		return err
 	}
-	data, err :=  s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
-	if err != nil {
-		return err
-	}
+
 
 	// 如果不存在实例,创建
 	gatewayIp, err := s.gatewayIpRep.GetGatewayipByHostIdAll(ctx, int64(req.HostId))
@@ -403,21 +359,6 @@ func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalL
 
 
 
-	outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
-	if err != nil {
-		return err
-	}
-	err = s.cdnService.RenewPlan(ctx, v1.RenewalPlan{
-		UserPlanId: int64(data.RuleId),
-		DayTo:      outputTimeStr,
-		Period:     "monthly",
-		CountPeriod: 1,
-		IsFree:     true,
-		PeriodDayTo: outputTimeStr,
-	})
-	if err != nil {
-		return err
-	}
 
 	expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
 	if err != nil {
@@ -509,14 +450,7 @@ func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.Globa
 		return nil
 	})
 
-	// 删除套餐
-	g.Go(func() error {
-		e := s.cdnService.DelUserPlan(gCtx, int64(oldData.RuleId))
-		if e != nil {
-			return fmt.Errorf("删除套餐失败: %w", e)
-		}
-		return nil
-	})
+
 	// 删除网站分组
 	g.Go(func() error {
 		e := s.cdnService.DelServerGroup(gCtx, int64(oldData.GroupId))

+ 2 - 2
internal/service/api/waf/tcpforwarding.go

@@ -159,6 +159,7 @@ func (s *tcpforwardingService) prepareWafData(ctx context.Context, req *v1.TcpFo
 		return RequireResponse{}, v1.WebsiteSend{}, err
 	}
 
+
 	formData := v1.WebsiteSend{
 		UserId:         int64(require.CdnUid),
 		Type:           "tcpProxy",
@@ -166,8 +167,7 @@ func (s *tcpforwardingService) prepareWafData(ctx context.Context, req *v1.TcpFo
 		Description:    req.TcpForwardingData.Comment,
 		TcpJSON:        byteData,
 		ServerGroupIds: []int64{int64(require.GroupId)},
-		UserPlanId: int64(require.RuleId),
-		NodeClusterId:  1,
+		NodeClusterId:  2,
 	}
 	return require, formData, nil
 }

+ 3 - 2
internal/service/api/waf/udpforwarding.go

@@ -152,6 +152,8 @@ func (s *udpForWardingService) prepareWafData(ctx context.Context, req *v1.UdpFo
 		return RequireResponse{}, v1.WebsiteSend{}, err
 	}
 
+
+
 	formData := v1.WebsiteSend{
 		UserId:         int64(require.CdnUid),
 		Type:           "udpProxy",
@@ -159,8 +161,7 @@ func (s *udpForWardingService) prepareWafData(ctx context.Context, req *v1.UdpFo
 		Description:    req.UdpForwardingData.Comment,
 		UdpJSON:        byteData,
 		ServerGroupIds: []int64{int64(require.GroupId)},
-		UserPlanId: int64(require.RuleId),
-		NodeClusterId:  1,
+		NodeClusterId:  2,
 	}
 	return require, formData, nil
 }

+ 24 - 0
internal/service/api/waf/wafformatter.go

@@ -7,6 +7,7 @@ import (
 	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
+	flexCdnRep "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/flexCdn"
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
 	"github.com/go-nunu/nunu-layout-advanced/internal/service"
 	"github.com/go-nunu/nunu-layout-advanced/internal/service/api/flexCdn"
@@ -42,6 +43,8 @@ type WafFormatterService interface {
 	ConvertToPunycodeIfIDN(ctx context.Context, domain string) (isIDN bool, punycodeDomain string, err error)
 	// 验证端口重复
 	VerifyPort(ctx context.Context,protocol string, id int64, port string,hostId int64,domain string) error
+	// 获取节点集群id
+	GetNodeClusterId(ctx context.Context,hostId int64) (int64, error)
 }
 
 func NewWafFormatterService(
@@ -58,6 +61,7 @@ func NewWafFormatterService(
 	gatewayIpRep waf.GatewayipRepository,
 	gatewayIp GatewayipService,
 	cdn flexCdn.CdnService,
+	cdnRep flexCdnRep.CdnRepository,
 ) WafFormatterService {
 	return &wafFormatterService{
 		Service:           service,
@@ -73,6 +77,7 @@ func NewWafFormatterService(
 		gatewayIpRep: gatewayIpRep,
 		cdn:               cdn,
 		gatewayIp : 		gatewayIp,
+		cdnRep: cdnRep,
 	}
 }
 
@@ -90,6 +95,7 @@ type wafFormatterService struct {
 	gatewayIpRep waf.GatewayipRepository
 	cdn          flexCdn.CdnService
 	gatewayIp    GatewayipService
+	cdnRep flexCdnRep.CdnRepository
 }
 
 type RequireResponse struct {
@@ -624,4 +630,22 @@ func (s *wafFormatterService) verifyUDPPort(ctx context.Context, hostId int64, p
 	}
 
 	return nil
+}
+
+// 获取节点集群id
+func (s *wafFormatterService) GetNodeClusterId(ctx context.Context,hostId int64) (int64, error) {
+	config, err := s.host.GetGlobalLimitConfig(ctx, int(hostId))
+	if err != nil {
+		return 0, err
+	}
+
+	nodeClusterId, err := s.cdnRep.GetNodeClusterId(ctx, config.NodeArea)
+	if err != nil {
+		return 0, err
+	}
+	if nodeClusterId == 0 {
+		return 0, fmt.Errorf("节点集群获取失败")
+	}
+
+	return nodeClusterId, nil
 }

+ 3 - 3
internal/service/api/waf/webforwarding.go

@@ -66,7 +66,6 @@ const (
 	isHttps              = 1
 	protocolHttps        = "https"
 	protocolHttp         = "http"
-	defaultNodeClusterId = 1
 )
 
 type webForwardingService struct {
@@ -241,6 +240,8 @@ func (s *webForwardingService) prepareWafData(ctx context.Context, req *v1.WebFo
 		}
 	}
 
+
+
 	// 3. 组装最终的 WAF 表单数据
 	formData := v1.Website{
 		UserId:          int64(require.CdnUid),
@@ -249,8 +250,7 @@ func (s *webForwardingService) prepareWafData(ctx context.Context, req *v1.WebFo
 		ServerNamesJSON: serverJson,
 		Description:     req.WebForwardingData.Comment,
 		ServerGroupIds:  []int64{int64(require.GroupId)},
-		UserPlanId:      int64(require.RuleId),
-		NodeClusterId:   defaultNodeClusterId,
+		NodeClusterId:   2,
 	}
 
 

+ 0 - 44
internal/service/host.go

@@ -16,9 +16,6 @@ type HostService interface {
 	GetHost(ctx context.Context, id int64) (*model.Host, error)
 	GetGameShieldConfig(ctx context.Context, hostId int) (v1.GameShieldHostBackendConfigResponse, error)
 	GetGlobalLimitConfig(ctx context.Context, hostId int) (v1.GlobalLimitConfigResponse, error)
-	GetTcpLimitConfig(ctx context.Context, hostId int) (v1.TcpLimitRequest, error)
-	GetUdpLimitConfig(ctx context.Context, hostId int) (v1.UdpLimitRequest, error)
-	GetWebLimitConfig(ctx context.Context, hostId int) (v1.WebLimitRequest, error)
 	// 检查是否过期 到期false 未到期true
 	CheckExpired(ctx context.Context, uid int64, hostId int64) (bool, error)
 }
@@ -282,47 +279,6 @@ func (s *hostService) GetGlobalLimitConfig(ctx context.Context, hostId int) (v1.
 }
 
 // GetTcpLimitConfig 修正返回类型,并使用新的辅助函数
-func (s *hostService) GetTcpLimitConfig(ctx context.Context, hostId int) (v1.TcpLimitRequest, error) {
-	_, err := s.getHostConfigsMap(ctx, hostId)
-	if err != nil {
-		return v1.TcpLimitRequest{}, err
-	}
-	data := v1.TcpLimitRequest{
-		ConnCount:    0,
-		ConnDuration: "0s",
-		MaxConnCount: 0,
-	}
-	return data, nil // 返回结构体
-}
-
-// GetUdpLimitConfig
-func (s *hostService) GetUdpLimitConfig(ctx context.Context, hostId int) (v1.UdpLimitRequest, error) {
-	_, err := s.getHostConfigsMap(ctx, hostId)
-	if err != nil {
-		return v1.UdpLimitRequest{}, err
-	}
-	data := v1.UdpLimitRequest{
-		QosPacketCount:    0,
-		QosPacketDuration: "0s",
-		MaxConnCount:      0,
-	}
-	return data, nil
-}
-
-// GetWebLimitConfig 修正返回类型,并使用新的辅助函数
-func (s *hostService) GetWebLimitConfig(ctx context.Context, hostId int) (v1.WebLimitRequest, error) {
-	_, err := s.getHostConfigsMap(ctx, hostId)
-	if err != nil {
-		return v1.WebLimitRequest{}, err
-	}
-	data := v1.WebLimitRequest{
-		QpsCount:    0,
-		QpsDuration: "0s",
-	}
-
-
-	return data, nil
-}
 
 // 检查是否过期 到期false 未到期true
 func (s *hostService) CheckExpired(ctx context.Context, uid int64, hostId int64) (bool, error) {

+ 1 - 13
internal/task/waf.go

@@ -175,18 +175,6 @@ func (t *wafTask) executeRenewalActions(ctx context.Context, reqs []RenewalReque
 				allErrors = multierror.Append(allErrors, err)
 				return // 如果DB更新失败,不继续调用CDN API
 			}
-			// 调用CDN API续费
-			cdnErr := t.cdn.RenewPlan(ctx, v1.RenewalPlan{
-				UserPlanId:  int64(r.PlanId),
-				IsFree:      true,
-				DayTo:       time.Unix(r.ExpiredAt, 0).Format("2006-01-02"),
-				Period:      "monthly",
-				CountPeriod: 1,
-				PeriodDayTo: time.Unix(r.ExpiredAt, 0).Format("2006-01-02"),
-			})
-			if cdnErr != nil {
-				allErrors = multierror.Append(allErrors, cdnErr)
-			}
 		}(req)
 	}
 
@@ -209,7 +197,7 @@ func (t *wafTask) findPlansNeedingSync(ctx context.Context, wafLimits []model.Gl
 	for _, limit := range wafLimits {
 		hostIds = append(hostIds, limit.HostId)
 		wafExpiredMap[limit.HostId] = limit.ExpiredAt
-		wafPlanMap[limit.HostId] = limit.RuleId
+		//wafPlanMap[limit.HostId] = limit.RuleId
 	}
 
 	hostExpirations, err := t.hostRep.GetExpireTimeByHostId(ctx, hostIds)