Browse Source

refactor(service): 重构 GatewayipService 接口并新增方法- 新增 GetGatewayipOnlyIpByHostIdAll 和 GetGatewayipByHostIdFirst 方法
- 修改 AddIpWhereHostIdNull 方法,增加对主机 ID 和用户 ID 的处理
- 更新其他服务中的 GatewayipService 调用,使用新的方法
- 优化代码结构,提高可维护性和可测试性

fusu 2 weeks ago
parent
commit
136292e091

+ 4 - 3
cmd/server/wire/wire_gen.go

@@ -65,23 +65,24 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	tcpforwardingRepository := repository.NewTcpforwardingRepository(repositoryRepository)
 	udpForWardingRepository := repository.NewUdpForWardingRepository(repositoryRepository)
 	gatewayipRepository := repository.NewGatewayipRepository(repositoryRepository)
+	gatewayipService := service.NewGatewayipService(serviceService, gatewayipRepository, hostService)
 	requestService := service.NewRequestService(serviceService)
 	cdnRepository := repository.NewCdnRepository(repositoryRepository)
 	cdnService := service.NewCdnService(serviceService, viperViper, requestService, cdnRepository)
-	wafFormatterService := service.NewWafFormatterService(serviceService, globalLimitRepository, hostRepository, requiredService, parserService, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, rabbitMQ, hostService, gatewayipRepository, cdnService)
+	wafFormatterService := service.NewWafFormatterService(serviceService, globalLimitRepository, hostRepository, requiredService, parserService, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, rabbitMQ, hostService, gatewayipRepository, gatewayipService, cdnService)
 	aoDunService := service.NewAoDunService(serviceService, viperViper)
 	proxyRepository := repository.NewProxyRepository(repositoryRepository)
 	proxyService := service.NewProxyService(serviceService, proxyRepository, cdnService)
 	sslCertService := service.NewSslCertService(serviceService, webForwardingRepository, cdnService)
 	websocketService := service.NewWebsocketService(serviceService, cdnService, webForwardingRepository)
-	webForwardingService := service.NewWebForwardingService(serviceService, requiredService, webForwardingRepository, crawlerService, parserService, wafFormatterService, aoDunService, rabbitMQ, gatewayipRepository, globalLimitRepository, cdnService, proxyService, sslCertService, websocketService)
+	webForwardingService := service.NewWebForwardingService(serviceService, requiredService, webForwardingRepository, crawlerService, parserService, wafFormatterService, aoDunService, rabbitMQ, gatewayipService, globalLimitRepository, cdnService, proxyService, sslCertService, websocketService)
 	webForwardingHandler := handler.NewWebForwardingHandler(handlerHandler, webForwardingService)
 	tcpforwardingService := service.NewTcpforwardingService(serviceService, tcpforwardingRepository, parserService, requiredService, crawlerService, globalLimitRepository, hostRepository, wafFormatterService, cdnService, proxyService)
 	tcpforwardingHandler := handler.NewTcpforwardingHandler(handlerHandler, tcpforwardingService)
 	udpForWardingService := service.NewUdpForWardingService(serviceService, udpForWardingRepository, requiredService, parserService, crawlerService, globalLimitRepository, hostRepository, wafFormatterService, cdnService, proxyService)
 	udpForWardingHandler := handler.NewUdpForWardingHandler(handlerHandler, udpForWardingService)
 	allowAndDenyIpRepository := repository.NewAllowAndDenyIpRepository(repositoryRepository)
-	allowAndDenyIpService := service.NewAllowAndDenyIpService(serviceService, allowAndDenyIpRepository, wafFormatterService, gatewayipRepository)
+	allowAndDenyIpService := service.NewAllowAndDenyIpService(serviceService, allowAndDenyIpRepository, wafFormatterService, gatewayipService)
 	globalLimitService := service.NewGlobalLimitService(serviceService, globalLimitRepository, duedateService, crawlerService, viperViper, requiredService, parserService, hostService, hostRepository, cdnService, cdnRepository, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, allowAndDenyIpService, allowAndDenyIpRepository, tcpforwardingService, udpForWardingService, webForwardingService, gatewayipRepository)
 	globalLimitHandler := handler.NewGlobalLimitHandler(handlerHandler, globalLimitService)
 	adminRepository := repository.NewAdminRepository(repositoryRepository)

+ 6 - 6
internal/service/allowanddenyip.go

@@ -19,14 +19,14 @@ func NewAllowAndDenyIpService(
     service *Service,
     allowAndDenyIpRepository repository.AllowAndDenyIpRepository,
 	wafformatter WafFormatterService,
-	gatewayIpRep repository.GatewayipRepository,
+	gatewayIp GatewayipService,
 
 ) AllowAndDenyIpService {
 	return &allowAndDenyIpService{
 		Service:        service,
 		allowAndDenyIpRepository: allowAndDenyIpRepository,
 		wafformatter : wafformatter,
-		gatewayIpRep : gatewayIpRep,
+		gatewayIp : gatewayIp,
 	}
 }
 
@@ -34,7 +34,7 @@ type allowAndDenyIpService struct {
 	*Service
 	allowAndDenyIpRepository repository.AllowAndDenyIpRepository
 	wafformatter WafFormatterService
-	gatewayIpRep repository.GatewayipRepository
+	gatewayIp GatewayipService
 }
 
 func (s *allowAndDenyIpService) GetAllowAndDenyIp(ctx context.Context, id int64) (*model.AllowAndDenyIp, error) {
@@ -61,7 +61,7 @@ func (s *allowAndDenyIpService) AddAllowAndDenyIps(ctx context.Context, req v1.A
 	}
 
 
-	gatewayGroupIps, err := s.gatewayIpRep.GetGatewayipOnlyIpByHostIdAll(ctx, int64(req.HostId))
+	gatewayGroupIps, err := s.gatewayIp.GetGatewayipOnlyIpByHostIdAll(ctx, int64(req.HostId), int64(req.Uid))
 	if err != nil {
 		return err
 	}
@@ -94,7 +94,7 @@ func (s *allowAndDenyIpService) EditAllowAndDenyIps(ctx context.Context, req v1.
 	}
 
 
-	gatewayGroupIps, err := s.gatewayIpRep.GetGatewayipOnlyIpByHostIdAll(ctx, int64(req.HostId))
+	gatewayGroupIps, err := s.gatewayIp.GetGatewayipOnlyIpByHostIdAll(ctx, int64(req.HostId), int64(req.Uid))
 	if err != nil {
 		return err
 	}
@@ -133,7 +133,7 @@ func (s *allowAndDenyIpService) EditAllowAndDenyIps(ctx context.Context, req v1.
 func (s *allowAndDenyIpService) DeleteAllowAndDenyIps(ctx context.Context, req v1.DelAllowAndDenyIpRequest) error {
 
 	for _, id := range req.Ids {
-		gatewayGroupIps, err := s.gatewayIpRep.GetGatewayipOnlyIpByHostIdAll(ctx, int64(req.HostId))
+		gatewayGroupIps, err := s.gatewayIp.GetGatewayipOnlyIpByHostIdAll(ctx, int64(req.HostId), int64(req.Uid))
 		if err != nil {
 			return err
 		}

+ 70 - 1
internal/service/gatewayip.go

@@ -1,29 +1,98 @@
 package service
 
 import (
-    "context"
+	"context"
+	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
 )
 
 type GatewayipService interface {
 	GetGatewayip(ctx context.Context, id int64) (*model.Gatewayip, error)
+	GetGatewayipOnlyIpByHostIdAll(ctx context.Context, hostId int64,uid int64) ([]string, error)
+	GetGatewayipByHostIdFirst(ctx context.Context, hostId int64,uid int64) (string, error)
 }
 func NewGatewayipService(
     service *Service,
     gatewayipRepository repository.GatewayipRepository,
+	host HostService,
 ) GatewayipService {
 	return &gatewayipService{
 		Service:        service,
 		gatewayipRepository: gatewayipRepository,
+		host : host,
 	}
 }
 
 type gatewayipService struct {
 	*Service
 	gatewayipRepository repository.GatewayipRepository
+	host HostService
 }
 
 func (s *gatewayipService) GetGatewayip(ctx context.Context, id int64) (*model.Gatewayip, error) {
 	return s.gatewayipRepository.GetGatewayip(ctx, id)
 }
+
+func (s *gatewayipService) AddIpWhereHostIdNull(ctx context.Context, hostId int64,uid int64) error {
+	config, err := s.host.GetGlobalLimitConfig(ctx, int(hostId))
+	if err != nil {
+		return  err
+	}
+
+	if err := s.gatewayipRepository.GetIpWhereHostIdNull(ctx, v1.GlobalLimitRequireResponse{
+		HostId:              int(hostId),
+		Bps:                 config.Bps,
+		MaxBytesMonth:       config.MaxBytesMonth,
+		IpCount:             config.IpCount,
+		Operator:            config.Operator,
+		NodeArea:            config.NodeArea,
+		ConfigMaxProtection: config.ConfigMaxProtection,
+		IsBanUdp:            config.IsBanUdp,
+	}); err != nil {
+		return  err
+	}
+
+	return nil
+}
+
+func (s *gatewayipService) GetGatewayipOnlyIpByHostIdAll(ctx context.Context, hostId int64,uid int64) ([]string, error) {
+	gatewayIps, err := s.gatewayipRepository.GetGatewayipOnlyIpByHostIdAll(ctx, hostId)
+	if err != nil {
+		return nil, err
+	}
+
+	if len(gatewayIps) == 0 {
+		err = s.AddIpWhereHostIdNull(ctx, hostId,uid)
+		if err != nil {
+			return nil, err
+		}
+		gatewayIps, err = s.gatewayipRepository.GetGatewayipOnlyIpByHostIdAll(ctx, hostId)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return gatewayIps, nil
+}
+
+
+func (s *gatewayipService) GetGatewayipByHostIdFirst(ctx context.Context, hostId int64,uid int64) (string, error) {
+	gatewayIps, err := s.gatewayipRepository.GetGatewayipByHostIdFirst(ctx, hostId)
+	if err != nil {
+		return "", err
+	}
+
+	if len(gatewayIps) == 0 {
+		err = s.AddIpWhereHostIdNull(ctx, hostId,uid)
+		if err != nil {
+			return "", err
+		}
+		gatewayIps, err = s.gatewayipRepository.GetGatewayipByHostIdFirst(ctx, hostId)
+		if err != nil {
+			return "", err
+		}
+	}
+
+	return gatewayIps, nil
+}

+ 1 - 0
internal/service/globallimit.go

@@ -22,6 +22,7 @@ type GlobalLimitService interface {
 	AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
 	EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
 	DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
+	GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error)
 }
 
 func NewGlobalLimitService(

+ 4 - 1
internal/service/wafformatter.go

@@ -53,6 +53,7 @@ func NewWafFormatterService(
 	mq *rabbitmq.RabbitMQ,
 	host HostService,
 	gatewayIpRep repository.GatewayipRepository,
+	gatewayIp GatewayipService,
 	cdn CdnService,
 ) WafFormatterService {
 	return &wafFormatterService{
@@ -68,6 +69,7 @@ func NewWafFormatterService(
 		mq:                mq,
 		gatewayIpRep: gatewayIpRep,
 		cdn:               cdn,
+		gatewayIp : 		gatewayIp,
 	}
 }
 
@@ -84,6 +86,7 @@ type wafFormatterService struct {
 	mq                *rabbitmq.RabbitMQ
 	gatewayIpRep      repository.GatewayipRepository
 	cdn               CdnService
+	gatewayIp GatewayipService
 }
 
 type RequireResponse struct {
@@ -109,7 +112,7 @@ func (s *wafFormatterService) Require(ctx context.Context, req v1.GlobalRequire)
 	}
 	res.Tag = strconv.Itoa(req.Uid) + "_" + strconv.Itoa(req.HostId) + "_" + domain + "_" + req.Comment
 
-	res.GatewayIps, err = s.gatewayIpRep.GetGatewayipOnlyIpByHostIdAll(ctx, int64(req.HostId))
+	res.GatewayIps, err = s.gatewayIp.GetGatewayipOnlyIpByHostIdAll(ctx, int64(req.HostId), int64(req.Uid))
 	if err != nil {
 		return RequireResponse{}, err
 	}

+ 6 - 6
internal/service/webforwarding.go

@@ -31,7 +31,7 @@ func NewWebForwardingService(
 	wafformatter WafFormatterService,
 	aoDun AoDunService,
 	mq *rabbitmq.RabbitMQ,
-	gatewayIpRep repository.GatewayipRepository,
+	gatewayIp GatewayipService,
 	globalLimitRep repository.GlobalLimitRepository,
 	cdn CdnService,
 	proxy ProxyService,
@@ -47,7 +47,7 @@ func NewWebForwardingService(
 		wafformatter:            wafformatter,
 		aoDun:                   aoDun,
 		mq:                      mq,
-		gatewayIpRep:            gatewayIpRep,
+		gatewayIp:               gatewayIp,
 		cdn:                     cdn,
 		globalLimitRep:          globalLimitRep,
 		proxy:                   proxy,
@@ -72,7 +72,7 @@ type webForwardingService struct {
 	wafformatter            WafFormatterService
 	aoDun                   AoDunService
 	mq                      *rabbitmq.RabbitMQ
-	gatewayIpRep            repository.GatewayipRepository
+	gatewayIp               GatewayipService
 	cdn                     CdnService
 	globalLimitRep          repository.GlobalLimitRepository
 	proxy                   ProxyService
@@ -507,7 +507,7 @@ func (s *webForwardingService) AddWebForwarding(ctx context.Context, req *v1.Web
 		if len(require.GatewayIps) == 0 {
 			return fmt.Errorf("网关组不存在")
 		}
-		firstIp, err := s.gatewayIpRep.GetGatewayipByHostIdFirst(ctx, int64(require.HostId))
+		firstIp, err := s.gatewayIp.GetGatewayipByHostIdFirst(ctx, int64(require.HostId), int64(require.Uid))
 		if err != nil {
 			return err
 		}
@@ -701,7 +701,7 @@ func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.We
 
 	// 异步任务:将域名添加到白名单
 	if webData.Domain != req.WebForwardingData.Domain {
-		firstIp, err := s.gatewayIpRep.GetGatewayipByHostIdFirst(ctx, int64(req.HostId))
+		firstIp, err := s.gatewayIp.GetGatewayipByHostIdFirst(ctx, int64(req.HostId), int64(req.Uid))
 		if err != nil {
 			return err
 		}
@@ -840,7 +840,7 @@ func (s *webForwardingService) DeleteWebForwarding(ctx context.Context, req v1.D
 		}
 
 		// 异步任务:将域名添加到白名单
-		firstIp, err := s.gatewayIpRep.GetGatewayipByHostIdFirst(ctx, int64(oldData.HostId))
+		firstIp, err := s.gatewayIp.GetGatewayipByHostIdFirst(ctx, int64(oldData.HostId), int64(req.Uid))
 		if err != nil {
 			return err
 		}