ソースを参照

refactor(globallimit):重构全球限速功能

-移除 GatewayGroupId 字段
- 更新 GetGatewayipByHostIdAll 方法返回值类型
- 添加 CleanIPByHostId 方法以清理 IP
- 调整全球限速规则创建和编辑逻辑
- 移除与 GatewayGroup 相关的代码
fusu 3 週間 前
コミット
644a2974c0

+ 0 - 1
internal/model/globallimit.go

@@ -10,7 +10,6 @@ type GlobalLimit struct {
 	GroupId         int
 	Uid             int
 	CdnUid          int
-	GatewayGroupId  int
 	Comment         string
 	State         bool `gorm:"column:state" default:"true"`
 	ExpiredAt       int64 `gorm:"column:expired_at"`

+ 6 - 3
internal/repository/gatewayip.go

@@ -14,7 +14,7 @@ type GatewayipRepository interface {
 	EditGatewayip(ctx context.Context, req model.Gatewayip) error
 	DeleteGatewayip(ctx context.Context, req model.Gatewayip) error
 	GetGatewayipByHostId(ctx context.Context, hostId int64) (*model.Gatewayip, error)
-	GetGatewayipByHostIdAll(ctx context.Context, hostId int64) error
+	GetGatewayipByHostIdAll(ctx context.Context, hostId int64) (*model.Gatewayip, error)
 	UpdateGatewayipByHostId(ctx context.Context, req model.Gatewayip) error
 	DeleteGatewayipByHostId(ctx context.Context, hostId int64) error
 	GetIpWhereHostIdNull(ctx context.Context,req v1.GlobalLimitRequireResponse) error
@@ -55,9 +55,9 @@ func (r *gatewayipRepository) GetGatewayipByHostId(ctx context.Context, hostId i
 	return &req, r.DB(ctx).Where("host_id = ?", hostId).First(&req).Error
 }
 
-func (r *gatewayipRepository) GetGatewayipByHostIdAll(ctx context.Context, hostId int64) error {
+func (r *gatewayipRepository) GetGatewayipByHostIdAll(ctx context.Context, hostId int64) (*model.Gatewayip, error) {
 	var req model.Gatewayip
-	return r.DB(ctx).Where("host_id = ?", hostId).Find(&req).Error
+	return &req, r.DB(ctx).Where("host_id = ?", hostId).Find(&req).Error
 }
 
 func (r *gatewayipRepository) UpdateGatewayipByHostId(ctx context.Context, req model.Gatewayip) error {
@@ -73,6 +73,9 @@ func (r *gatewayipRepository) GetIpWhereHostIdNull(ctx context.Context,req v1.Gl
 	if req.IpCount <= 0 {
 		return fmt.Errorf("套餐IP数量错误, 请联系客服")
 	}
+	if req.HostId <= 0 {
+		return fmt.Errorf("主机ID错误, 请联系客服")
+	}
 
 	// 使用事务保证操作的原子性
 	return r.DB(ctx).Transaction(func(tx *gorm.DB) error {

+ 1 - 0
internal/repository/globallimit.go

@@ -50,6 +50,7 @@ func (r *globalLimitRepository) GetGlobalLimit(ctx context.Context, id int64) (*
 	return &globalLimit, nil
 }
 
+
 func (r *globalLimitRepository) AddGlobalLimit(ctx context.Context, req *model.GlobalLimit) error {
 	if err := r.DB(ctx).Create(&req).Error; err != nil {
 		return err

+ 21 - 35
internal/service/globallimit.go

@@ -33,9 +33,7 @@ func NewGlobalLimitService(
 	required RequiredService,
 	parser ParserService,
 	host HostService,
-	gateWayGroup GatewayGroupService,
 	hostRep repository.HostRepository,
-	gateWayGroupRep repository.GatewayGroupRepository,
 	cdnService CdnService,
 	cdnRep repository.CdnRepository,
 	tcpforwardingRep repository.TcpforwardingRepository,
@@ -46,6 +44,7 @@ func NewGlobalLimitService(
 	tcpforwarding TcpforwardingService,
 	udpForWarding UdpForWardingService,
 	webForWarding WebForwardingService,
+	gatewayIpRep repository.GatewayipRepository,
 ) GlobalLimitService {
 	return &globalLimitService{
 		Service:               service,
@@ -56,9 +55,7 @@ func NewGlobalLimitService(
 		required:              required,
 		parser:                parser,
 		host:                  host,
-		gateWayGroup:          gateWayGroup,
 		hostRep:                hostRep,
-		gateWayGroupRep:       gateWayGroupRep,
 		cdnService:            cdnService,
 		cdnRep:                cdnRep,
 		tcpforwardingRep:      tcpforwardingRep,
@@ -69,6 +66,7 @@ func NewGlobalLimitService(
 		tcpforwarding:         tcpforwarding,
 		udpForWarding:         udpForWarding,
 		webForWarding:         webForWarding,
+		gatewayIpRep:             gatewayIpRep,
 	}
 }
 
@@ -81,9 +79,7 @@ type globalLimitService struct {
 	required              RequiredService
 	parser                ParserService
 	host                  HostService
-	gateWayGroup          GatewayGroupService
 	hostRep               repository.HostRepository
-	gateWayGroupRep       repository.GatewayGroupRepository
 	cdnService            CdnService
 	cdnRep                repository.CdnRepository
 	tcpforwardingRep      repository.TcpforwardingRepository
@@ -94,6 +90,7 @@ type globalLimitService struct {
 	tcpforwarding         TcpforwardingService
 	udpForWarding         UdpForWardingService
 	webForWarding         WebForwardingService
+	gatewayIpRep          repository.GatewayipRepository
 }
 
 func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
@@ -169,6 +166,8 @@ func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.Glob
 	res.NodeArea = configCount.NodeArea
 	res.ConfigMaxProtection = configCount.ConfigMaxProtection
 	res.IsBanUdp = configCount.IsBanUdp
+	res.HostId = req.HostId
+
 	domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
 	if err != nil {
 		return v1.GlobalLimitRequireResponse{}, err
@@ -226,18 +225,13 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 	}
 
 	g, gCtx := errgroup.WithContext(ctx)
-	var gatewayGroupId int
 	var userId int64
 	var groupId int64
 	g.Go(func() error {
-		res, e := s.gateWayGroupRep.GetGatewayGroupWhereHostIdNull(gCtx, require)
+		e := s.gatewayIpRep.GetIpWhereHostIdNull(gCtx, require)
 		if e != nil {
 			return fmt.Errorf("获取网关组失败: %w", e)
 		}
-		if res == 0 {
-			return fmt.Errorf("获取网关组失败")
-		}
-		gatewayGroupId = res
 		return nil
 	})
 
@@ -326,14 +320,6 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 
 
 
-	err = s.gateWayGroupRep.EditGatewayGroup(ctx, &model.GatewayGroup{
-		Id: gatewayGroupId,
-		HostId: req.HostId,
-	})
-	if err != nil {
-		return err
-	}
-
 	expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
 	if err != nil {
 		return err
@@ -353,7 +339,6 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 			Name:           require.GlobalLimitName,
 			RuleId: 		int(ruleId),
 			GroupId:        int(groupId),
-			GatewayGroupId: gatewayGroupId,
 			CdnUid:         int(userId),
 			Comment:        req.Comment,
 			ExpiredAt:      expiredAt,
@@ -374,7 +359,6 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 		Name:           require.GlobalLimitName,
 		RuleId: 		int(ruleId),
 		GroupId:        int(groupId),
-		GatewayGroupId: gatewayGroupId,
 		CdnUid:         int(userId),
 		Comment:        req.Comment,
 		State:          true,
@@ -386,6 +370,7 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 	return nil
 }
 
+
 func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
 	require, err := s.GlobalLimitRequire(ctx, req)
 	if err != nil {
@@ -396,17 +381,21 @@ func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalL
 		return err
 	}
 
-	if data.GatewayGroupId == 0 {
-		gatewayGroupId, e := s.gateWayGroupRep.GetGatewayGroupWhereHostIdNull(ctx, require)
-		if e != nil {
-			return fmt.Errorf("获取网关组失败: %w", e)
-		}
-		if gatewayGroupId == 0 {
-			return fmt.Errorf("获取网关组失败")
+	// 如果不存在实例,创建
+	gatewayIp, err := s.gatewayIpRep.GetGatewayipByHostIdAll(ctx, int64(req.HostId))
+	if err != nil {
+		return err
+	}
+	if gatewayIp != nil {
+		err = s.gatewayIpRep.GetIpWhereHostIdNull(ctx, require)
+		if err != nil {
+			return fmt.Errorf("获取网关组失败: %w", err)
 		}
-		data.GatewayGroupId = gatewayGroupId
+		return nil
 	}
 
+
+
 	outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
 	if err != nil {
 		return err
@@ -430,7 +419,6 @@ func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalL
 	if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
 		HostId:  req.HostId,
 		Comment: req.Comment,
-		GatewayGroupId: data.GatewayGroupId,
 		ExpiredAt: expiredAt,
 	}); err != nil {
 		return err
@@ -470,6 +458,7 @@ func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.Globa
 		return err
 	}
 
+	// 黑白IP
 	BwIds, err := s.allowAndDenyRep.GetIpCountListId(ctx, int64(req.HostId))
 	if err != nil {
 		return err
@@ -530,10 +519,7 @@ func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.Globa
 	if err := s.globalLimitRepository.EditHostState(ctx, int64(req.HostId), false); err != nil {
 		return err
 	}
-	if err := s.gateWayGroupRep.EditGatewayGroup(ctx,&model.GatewayGroup{
-		Id: oldData.GatewayGroupId,
-		HostId: 0,
-	}); err != nil {
+	if err := s.gatewayIpRep.CleanIPByHostId(ctx, []int64{int64(req.HostId)}); err != nil {
 		return err
 	}