Эх сурвалжийг харах

refactor(webforwarding): 优化域名白名单添加逻辑

- 新增 GetIp 方法获取网关组对应的 IP
- 将获取 IP 的逻辑移至 GetIp 方法中,减少代码重复
- 优化域名白名单添加流程,提高代码可读性和维护性
fusu 1 сар өмнө
parent
commit
e502ac764a

+ 2 - 2
cmd/server/wire/wire_gen.go

@@ -65,7 +65,8 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	wafFormatterService := service.NewWafFormatterService(serviceService, globalLimitRepository, hostRepository, requiredService, parserService, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, hostService)
 	aoDunService := service.NewAoDunService(serviceService, viperViper)
 	gateWayGroupIpRepository := repository.NewGateWayGroupIpRepository(repositoryRepository)
-	webForwardingService := service.NewWebForwardingService(serviceService, requiredService, webForwardingRepository, crawlerService, parserService, wafFormatterService, aoDunService, rabbitMQ, gateWayGroupIpRepository)
+	gatewayGroupRepository := repository.NewGatewayGroupRepository(repositoryRepository)
+	webForwardingService := service.NewWebForwardingService(serviceService, requiredService, webForwardingRepository, crawlerService, parserService, wafFormatterService, aoDunService, rabbitMQ, gateWayGroupIpRepository, gatewayGroupRepository)
 	webForwardingHandler := handler.NewWebForwardingHandler(handlerHandler, webForwardingService)
 	webLimitRepository := repository.NewWebLimitRepository(repositoryRepository)
 	webLimitService := service.NewWebLimitService(serviceService, webLimitRepository, requiredService, parserService, crawlerService, hostService)
@@ -80,7 +81,6 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	udpLimitRepository := repository.NewUdpLimitRepository(repositoryRepository)
 	udpLimitService := service.NewUdpLimitService(serviceService, udpLimitRepository, requiredService, crawlerService, parserService, hostService)
 	udpLimitHandler := handler.NewUdpLimitHandler(handlerHandler, udpLimitService)
-	gatewayGroupRepository := repository.NewGatewayGroupRepository(repositoryRepository)
 	gatewayGroupService := service.NewGatewayGroupService(serviceService, gatewayGroupRepository, requiredService, parserService)
 	globalLimitService := service.NewGlobalLimitService(serviceService, globalLimitRepository, duedateService, crawlerService, viperViper, requiredService, parserService, hostService, tcpLimitService, udpLimitService, webLimitService, gatewayGroupService, hostRepository, gatewayGroupRepository)
 	globalLimitHandler := handler.NewGlobalLimitHandler(handlerHandler, globalLimitService)

+ 9 - 0
internal/repository/gatewaygroup.go

@@ -15,6 +15,7 @@ type GatewayGroupRepository interface {
 	DeleteGatewayGroup(ctx context.Context, req *model.GatewayGroup) error
 	GetGatewayGroupWhereHostIdNull(ctx context.Context,operator int, count int) (int, error)
 	GetGatewayGroupByHostId(ctx context.Context, hostId int64) (*[]model.GatewayGroup, error)
+	GetGatewayGroupByRuleId(ctx context.Context, ruleId int64) (*model.GatewayGroup, error)
 }
 
 func NewGatewayGroupRepository(
@@ -86,3 +87,11 @@ func (r *gatewayGroupRepository) GetGatewayGroupByHostId(ctx context.Context, ho
 	return &res, nil
 }
 
+func (r *gatewayGroupRepository) GetGatewayGroupByRuleId(ctx context.Context, ruleId int64) (*model.GatewayGroup, error) {
+	res := model.GatewayGroup{}
+	if err := r.DB(ctx).Where("rule_id = ?", ruleId).Find(&res).Error; err != nil {
+		return nil, err
+	}
+	return &res, nil
+
+}

+ 1 - 1
internal/service/globallimit.go

@@ -231,7 +231,7 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 		return err
 	}
 	err = s.gateWayGroupRep.EditGatewayGroup(ctx, &model.GatewayGroup{
-		Id: gatewayGroupId,
+		RuleId: gatewayGroupId,
 		HostId: req.HostId,
 	})
 	return nil

+ 22 - 5
internal/service/webforwarding.go

@@ -34,6 +34,7 @@ func NewWebForwardingService(
 	aoDun AoDunService,
 	mq *rabbitmq.RabbitMQ,
 	gatewayGroupIpRep repository.GateWayGroupIpRepository,
+	gatewayGroupRep repository.GatewayGroupRepository,
 ) WebForwardingService {
 	return &webForwardingService{
 		Service:                 service,
@@ -45,6 +46,7 @@ func NewWebForwardingService(
 		aoDun:                   aoDun,
 		mq:                      mq,
 		gatewayGroupIpRep:        gatewayGroupIpRep,
+		gatewayGroupRep:           gatewayGroupRep,
 	}
 }
 
@@ -58,6 +60,7 @@ type webForwardingService struct {
 	aoDun                   AoDunService
 	mq                      *rabbitmq.RabbitMQ
 	gatewayGroupIpRep        repository.GateWayGroupIpRepository
+	gatewayGroupRep           repository.GatewayGroupRepository
 }
 
 func (s *webForwardingService) require(ctx context.Context,req v1.GlobalRequire) (v1.GlobalRequire, error) {
@@ -297,6 +300,18 @@ func (s *webForwardingService) prepareWafData(ctx context.Context, req *v1.WebFo
 	return require, formData, nil
 }
 
+func (s *webForwardingService) GetIp(ctx context.Context, gatewayGroupId int) (string, error) {
+	WafGatewayGroupRuleId, err := s.gatewayGroupRep.GetGatewayGroupByRuleId(ctx, int64(gatewayGroupId))
+	if err != nil {
+		return "", err
+	}
+	ip, err := s.gatewayGroupIpRep.GetGateWayGroupFirstIpByGatewayGroupId(ctx, WafGatewayGroupRuleId.Id)
+	if err != nil {
+		return "", err
+	}
+	return ip, nil
+}
+
 func (s *webForwardingService) AddWebForwarding(ctx context.Context, req *v1.WebForwardingRequest) error {
 	require, formData, err := s.prepareWafData(ctx, req)
 	if err != nil {
@@ -311,15 +326,16 @@ func (s *webForwardingService) AddWebForwarding(ctx context.Context, req *v1.Web
 		return err
 	}
 	if req.WebForwardingData.Domain != "" {
-		// 异步任务:将域名添加到白名单
-		doMain, err := s.wafformatter.ConvertToWildcardDomain(ctx, req.WebForwardingData.Domain)
+		ip, err := s.GetIp(ctx, require.WafGatewayGroupId)
 		if err != nil {
 			return err
 		}
-		ip, err := s.gatewayGroupIpRep.GetGateWayGroupFirstIpByGatewayGroupId(ctx, require.WafGatewayGroupId)
+		// 异步任务:将域名添加到白名单
+		doMain, err := s.wafformatter.ConvertToWildcardDomain(ctx, req.WebForwardingData.Domain)
 		if err != nil {
 			return err
 		}
+
 		go s.publishDomainWhitelistTask(doMain,ip, "add")
 	}
 
@@ -368,8 +384,9 @@ func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.We
 		return err
 	}
 	if webData.Domain != req.WebForwardingData.Domain {
-		Ip, err := s.gatewayGroupIpRep.GetGateWayGroupFirstIpByGatewayGroupId(ctx, webData.WafGatewayGroupId)
-			if err != nil {
+		Ip, err := s.GetIp(ctx, webData.WafGatewayGroupId)
+		if err != nil {
+			return err
 		}
 	// 异步任务:将域名添加到白名单
 		doMain, err := s.wafformatter.ConvertToWildcardDomain(ctx, req.WebForwardingData.Domain)