Parcourir la source

feat(internal): 增加网关组相关功能并优化全局限制逻辑

- 新增网关组模型、仓库和服务层功能,包括添加、编辑、删除和查询网关组
- 更新全局限制逻辑,增加运营商和IP数量相关配置
- 修改主机配置,添加运营商和IP数量字段
fusu il y a 1 mois
Parent
commit
482ef49c1a

+ 2 - 0
api/v1/globalLimit.go

@@ -18,6 +18,8 @@ type GlobalLimitRequireResponse struct {
 	GlobalLimitName string
 	Bps             string
 	MaxBytesMonth   string
+	IpCount   int
+	Operator int
 }
 
 type GeneralLimitRequireRequest struct {

+ 2 - 0
api/v1/host.go

@@ -28,4 +28,6 @@ type GlobalLimitConfigResponse struct {
 	Bps           string `default:"0"`
 	PortCount     int
 	DomainCount   int
+	Operator      int
+	IpCount       int
 }

+ 10 - 3
internal/model/gatewaygroup.go

@@ -1,11 +1,18 @@
 package model
 
-import "gorm.io/gorm"
+import "time"
 
 type GatewayGroup struct {
-	gorm.Model
+	Id          int `gorm:"primary"`
+	HostId      int `gorm:"null"`
+	RuleId      int `gorm:"not null"`
+	Name        string `gorm:"null"`
+	Operator    int `gorm:"not null"`
+	Comment     string `gorm:"null"`
+	CreatedAt   time.Time
+	UpdatedAt   time.Time
 }
 
 func (m *GatewayGroup) TableName() string {
-	return "gateway_group"
+	return "shd_waf_gateway_group"
 }

+ 68 - 0
internal/repository/gatewaygroup.go

@@ -2,11 +2,20 @@ package repository
 
 import (
 	"context"
+	"errors"
+	"fmt"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
+	"gorm.io/gorm"
 )
 
 type GatewayGroupRepository interface {
 	GetGatewayGroup(ctx context.Context, id int64) (*model.GatewayGroup, error)
+	AddGatewayGroup(ctx context.Context, req *model.GatewayGroup) error
+	EditGatewayGroup(ctx context.Context, req *model.GatewayGroup) error
+	DeleteGatewayGroup(ctx context.Context, req *model.GatewayGroup) error
+	GetGatewayGroupWhereHostIdNull(ctx context.Context,operator int, count int) (int, error)
+	GetGatewayGroupByHostId(ctx context.Context, hostId int64) (*[]model.GatewayGroup, error)
+	GetGatewayGroupAllIds(ctx context.Context) ([]int, error)
 }
 
 func NewGatewayGroupRepository(
@@ -26,3 +35,62 @@ func (r *gatewayGroupRepository) GetGatewayGroup(ctx context.Context, id int64)
 
 	return &gatewayGroup, nil
 }
+
+func (r *gatewayGroupRepository) AddGatewayGroup(ctx context.Context, req *model.GatewayGroup) error {
+	if err := r.DB(ctx).Create(req).Error; err != nil {
+		return err
+	}
+	return nil
+}
+
+func (r *gatewayGroupRepository) EditGatewayGroup(ctx context.Context, req *model.GatewayGroup) error {
+	if err := r.DB(ctx).Model(&model.GatewayGroup{}).Where("rule_id = ?", req.RuleId).Updates(req).Error; err != nil {
+		return err
+	}
+	return nil
+}
+
+func (r *gatewayGroupRepository) DeleteGatewayGroup(ctx context.Context, req *model.GatewayGroup) error {
+	if err := r.DB(ctx).Model(&model.GatewayGroup{}).Where("id = ?", req.Id).Delete(req).Error; err != nil {
+		return err
+	}
+	return nil
+}
+
+func (r *gatewayGroupRepository) GetGatewayGroupWhereHostIdNull(ctx context.Context,operator int, count int) (int, error) {
+	var id int
+	subQuery := r.DB(ctx).Model(&model.GateWayGroupIp{}).
+		Select("gateway_group_id").
+		Group("gateway_group_id").
+		Having("COUNT(*) = ?", count)
+
+	err := r.DB(ctx).Model(&model.GatewayGroup{}).
+		Where("operator = ?", operator).
+		Where("id IN (?)", subQuery).
+		Where("host_id = ?", 0).
+		Select("rule_id").First(&id).Error
+	if err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound){
+			return 0, fmt.Errorf("库存不足,请联系客服补充网关组库存")
+		}
+		return 0, err
+	}
+
+	return id, nil
+}
+
+func (r *gatewayGroupRepository) GetGatewayGroupByHostId(ctx context.Context, hostId int64) (*[]model.GatewayGroup, error) {
+	res := []model.GatewayGroup{}
+	if err := r.DB(ctx).Where("host_id = ?", hostId).Find(&res).Error; err != nil {
+		return nil, err
+	}
+	return &res, nil
+}
+
+func (r *gatewayGroupRepository) GetGatewayGroupAllIds(ctx context.Context) ([]int, error) {
+	var res []int
+	if err := r.DB(ctx).Model(&model.GatewayGroup{}).Pluck("host_id", &res).Error; err != nil {
+		return nil, err
+	}
+	return res, nil
+}

+ 1 - 1
internal/repository/host.go

@@ -64,4 +64,4 @@ func (r *hostRepository) GetDomainById(ctx context.Context, id int) (string, err
 		return "", err
 	}
 	return res, nil
-}
+}

+ 27 - 0
internal/service/gatewaygroup.go

@@ -12,6 +12,9 @@ import (
 type GatewayGroupService interface {
 	GetGatewayGroup(ctx context.Context, id int64) (*model.GatewayGroup, error)
 	AddGatewayGroup(ctx context.Context, req v1.AddGateWayGroupRequest) (int, error)
+	EditGatewayGroup(ctx context.Context, group model.GatewayGroup) error
+	DeleteGatewayGroup(ctx context.Context, group model.GatewayGroup) error
+	GetGatewayGroupByHostId(ctx context.Context, hostId int) ([]model.GatewayGroup, error)
 }
 func NewGatewayGroupService(
     service *Service,
@@ -63,4 +66,28 @@ func (s *gatewayGroupService) AddGatewayGroup(ctx context.Context, req v1.AddGat
 		return 0,  err
 	}
 	return gateWayGroupId, nil
+}
+
+func (s *gatewayGroupService) GetGatewayGroupByHostId(ctx context.Context, hostId int) ([]model.GatewayGroup, error) {
+	res, err := s.gatewayGroupRepository.GetGatewayGroupByHostId(ctx, int64(hostId))
+	if err != nil {
+		return nil, err
+	}
+	return *res, nil
+}
+
+func (s *gatewayGroupService) EditGatewayGroup(ctx context.Context, ip model.GatewayGroup) error {
+	if err := s.gatewayGroupRepository.EditGatewayGroup(ctx, &ip); err != nil {
+		return err
+	}
+	return nil
+
+}
+
+func (s *gatewayGroupService) DeleteGatewayGroup(ctx context.Context, ip model.GatewayGroup) error {
+	if err := s.gatewayGroupRepository.DeleteGatewayGroup(ctx, &ip); err != nil {
+		return err
+	}
+	return nil
+
 }

+ 15 - 1
internal/service/globallimit.go

@@ -33,6 +33,7 @@ func NewGlobalLimitService(
 	webLimit WebLimitService,
 	gateWayGroup GatewayGroupService,
 	hostRep repository.HostRepository,
+	gateWayGroupRep repository.GatewayGroupRepository,
 ) GlobalLimitService {
 	return &globalLimitService{
 		Service:               service,
@@ -48,6 +49,7 @@ func NewGlobalLimitService(
 		webLimit:              webLimit,
 		gateWayGroup:          gateWayGroup,
 		hostRep:                hostRep,
+		gateWayGroupRep:       gateWayGroupRep,
 	}
 }
 
@@ -65,6 +67,7 @@ type globalLimitService struct {
 	webLimit              WebLimitService
 	gateWayGroup          GatewayGroupService
 	hostRep               repository.HostRepository
+	gateWayGroupRep       repository.GatewayGroupRepository
 }
 
 func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
@@ -86,6 +89,8 @@ func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.Glob
 
 	res.Bps = configCount.Bps
 	res.MaxBytesMonth = configCount.MaxBytesMonth
+	res.Operator = configCount.Operator
+	res.IpCount = configCount.IpCount
 	domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
 	if err != nil {
 		return v1.GlobalLimitRequireResponse{}, err
@@ -103,6 +108,10 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 	if err != nil {
 		return err
 	}
+	gatewayGroupId, err := s.gateWayGroupRep.GetGatewayGroupWhereHostIdNull(ctx, require.Operator, require.IpCount)
+	if err != nil {
+		return err
+	}
 	formData := map[string]interface{}{
 		"tag":             require.GlobalLimitName,
 		"bps":             require.Bps,
@@ -193,6 +202,7 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 	if err := g.Wait(); err != nil {
 		return err
 	}
+
 	err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
 		HostId:          req.HostId,
 		RuleId:          cast.ToInt(ruleId),
@@ -201,11 +211,15 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 		TcpLimitRuleId:  tcpLimitRuleId,
 		UdpLimitRuleId:  udpLimitRuleId,
 		WebLimitRuleId:  webLimitRuleId,
-		GatewayGroupId:  5,// TODO: 临时写死
+		GatewayGroupId:  gatewayGroupId,
 	})
 	if err != nil {
 		return err
 	}
+	err = s.gateWayGroupRep.EditGatewayGroup(ctx, &model.GatewayGroup{
+		RuleId: gatewayGroupId,
+		HostId: req.HostId,
+	})
 	return nil
 }
 

+ 19 - 0
internal/service/host.go

@@ -40,6 +40,8 @@ const (
 	ConfigMaxBytesMonth        = "高防防护能力"
 	ConfigPortCount 		   = "防御端口数量"
 	ConfigDomainCount          = "防御域名(需要备案)"
+	ConfigOperator             = "高防线路"
+	ConfigIpCount              = "高防节点IP"
 )
 
 // unitSuffixMap 存储需要去除的单位后缀
@@ -47,6 +49,7 @@ var unitSuffixMap = map[string]string{
 	ConfigOnlineDevices: "个",
 	ConfigRuleEntries:   "个",
 	ConfigMaxBandwidth:  "条",
+	ConfigIpCount:      "个",
 }
 
 type hostService struct {
@@ -205,6 +208,8 @@ func (s *hostService) GetGlobalLimitConfig(ctx context.Context, hostId int) (v1.
 	data := v1.GlobalLimitConfigResponse{
 		MaxBytesMonth: "0",
 		Bps:           "0",
+		IpCount:       0,
+		Operator:      0,
 	}
 	if val, ok := configsMap[ConfigBps]; ok {
 		data.Bps = val
@@ -224,6 +229,20 @@ func (s *hostService) GetGlobalLimitConfig(ctx context.Context, hostId int) (v1.
 			return data, err
 		}
 	}
+	if val, ok := configsMap[ConfigOperator]; ok {
+		if val == "电信" {
+			data.IpCount = 1
+		}
+		if val == "BGP" {
+			data.IpCount = 2
+		}
+	}
+	if val, ok := configsMap[ConfigIpCount]; ok {
+		data.Operator, err = cast.ToIntE(val)
+		if err != nil {
+			return data, err
+		}
+	}
 	return data, nil
 }