Bläddra i källkod

refactor(gatewayip): 重构 GetIpWhereHostIdNull 方法并添加日志记录功能

- 修改 GetIpWhereHostIdNull 方法签名,返回 []string 和 error
- 增加事务处理,确保操作原子性
- 添加 IP库存检查逻辑
- 优化查询性能,使用 Limit 和 Find替代 Pluck
- 新增日志记录功能,记录分配的 IP 信息
- 更新相关服务和接口以适应新的方法签名
fusu 2 veckor sedan
förälder
incheckning
31197d163e

+ 4 - 4
cmd/server/wire/wire_gen.go

@@ -65,7 +65,9 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	tcpforwardingRepository := repository.NewTcpforwardingRepository(repositoryRepository)
 	udpForWardingRepository := repository.NewUdpForWardingRepository(repositoryRepository)
 	gatewayipRepository := repository.NewGatewayipRepository(repositoryRepository)
-	gatewayipService := service.NewGatewayipService(serviceService, gatewayipRepository, hostService)
+	logRepository := repository.NewLogRepository(repositoryRepository)
+	logService := service.NewLogService(serviceService, logRepository)
+	gatewayipService := service.NewGatewayipService(serviceService, gatewayipRepository, hostService, logService)
 	requestService := service.NewRequestService(serviceService)
 	cdnRepository := repository.NewCdnRepository(repositoryRepository)
 	cdnService := service.NewCdnService(serviceService, viperViper, requestService, cdnRepository)
@@ -83,7 +85,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	udpForWardingHandler := handler.NewUdpForWardingHandler(handlerHandler, udpForWardingService)
 	allowAndDenyIpRepository := repository.NewAllowAndDenyIpRepository(repositoryRepository)
 	allowAndDenyIpService := service.NewAllowAndDenyIpService(serviceService, allowAndDenyIpRepository, wafFormatterService, gatewayipService)
-	globalLimitService := service.NewGlobalLimitService(serviceService, globalLimitRepository, duedateService, crawlerService, viperViper, requiredService, parserService, hostService, hostRepository, cdnService, cdnRepository, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, allowAndDenyIpService, allowAndDenyIpRepository, tcpforwardingService, udpForWardingService, webForwardingService, gatewayipRepository)
+	globalLimitService := service.NewGlobalLimitService(serviceService, globalLimitRepository, duedateService, crawlerService, viperViper, requiredService, parserService, hostService, hostRepository, cdnService, cdnRepository, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, allowAndDenyIpService, allowAndDenyIpRepository, tcpforwardingService, udpForWardingService, webForwardingService, gatewayipRepository, gatewayipService)
 	globalLimitHandler := handler.NewGlobalLimitHandler(handlerHandler, globalLimitService)
 	adminRepository := repository.NewAdminRepository(repositoryRepository)
 	adminService := service.NewAdminService(serviceService, adminRepository)
@@ -98,8 +100,6 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	ccRepository := repository.NewCcRepository(repositoryRepository)
 	ccService := service.NewCcService(serviceService, ccRepository, webForwardingRepository, cdnService)
 	ccHandler := handler.NewCcHandler(handlerHandler, ccService)
-	logRepository := repository.NewLogRepository(repositoryRepository)
-	logService := service.NewLogService(serviceService, logRepository)
 	httpServer := server.NewHTTPServer(logger, viperViper, jwtJWT, syncedEnforcer, limiterLimiter, handlerFunc, userHandler, gameShieldHandler, gameShieldBackendHandler, webForwardingHandler, tcpforwardingHandler, udpForWardingHandler, globalLimitHandler, adminHandler, gatewayGroupHandler, gateWayGroupIpHandler, allowAndDenyIpHandler, ccHandler, logService)
 	appApp := newApp(httpServer)
 	return appApp, func() {

+ 2 - 0
cmd/task/wire/wire.go

@@ -43,6 +43,7 @@ var repositorySet = wire.NewSet(
 	repository.NewExpiredRepository,
 	repository.NewProxyRepository,
 	repository.NewGatewayipRepository,
+	repository.NewLogRepository,
 )
 
 var taskSet = wire.NewSet(
@@ -86,6 +87,7 @@ var serviceSet = wire.NewSet(
 	service.NewSslCertService,
 	service.NewWebsocketService,
 	service.NewGatewayipService,
+	service.NewLogService,
 )
 
 // build App

+ 5 - 3
cmd/task/wire/wire_gen.go

@@ -63,7 +63,9 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	globalLimitRepository := repository.NewGlobalLimitRepository(repositoryRepository)
 	expiredRepository := repository.NewExpiredRepository(repositoryRepository)
 	gatewayipRepository := repository.NewGatewayipRepository(repositoryRepository)
-	gatewayipService := service.NewGatewayipService(serviceService, gatewayipRepository, hostService)
+	logRepository := repository.NewLogRepository(repositoryRepository)
+	logService := service.NewLogService(serviceService, logRepository)
+	gatewayipService := service.NewGatewayipService(serviceService, gatewayipRepository, hostService, logService)
 	wafFormatterService := service.NewWafFormatterService(serviceService, globalLimitRepository, hostRepository, requiredService, parserService, tcpforwardingRepository, udpForWardingRepository, webForwardingRepository, rabbitMQ, hostService, gatewayipRepository, gatewayipService, cdnService)
 	proxyRepository := repository.NewProxyRepository(repositoryRepository)
 	proxyService := service.NewProxyService(serviceService, proxyRepository, cdnService)
@@ -87,7 +89,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 
 // wire.go:
 
-var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewMongoClient, repository.NewCasbinEnforcer, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldPublicIpRepository, repository.NewHostRepository, repository.NewGameShieldUserIpRepository, repository.NewGameShieldSdkIpRepository, repository.NewWebForwardingRepository, repository.NewTcpforwardingRepository, repository.NewUdpForWardingRepository, repository.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository, repository.NewCdnRepository, repository.NewExpiredRepository, repository.NewProxyRepository, repository.NewGatewayipRepository)
+var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewMongoClient, repository.NewCasbinEnforcer, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldPublicIpRepository, repository.NewHostRepository, repository.NewGameShieldUserIpRepository, repository.NewGameShieldSdkIpRepository, repository.NewWebForwardingRepository, repository.NewTcpforwardingRepository, repository.NewUdpForWardingRepository, repository.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository, repository.NewCdnRepository, repository.NewExpiredRepository, repository.NewProxyRepository, repository.NewGatewayipRepository, repository.NewLogRepository)
 
 var taskSet = wire.NewSet(task.NewTask, task.NewUserTask, task.NewGameShieldTask, task.NewWafTask)
 
@@ -95,7 +97,7 @@ var jobSet = wire.NewSet(job.NewJob, job.NewUserJob, job.NewWhitelistJob)
 
 var serverSet = wire.NewSet(server.NewTaskServer, server.NewJobServer)
 
-var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, service.NewGameShieldService, service.NewCrawlerService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewHostService, service.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewGameShieldUserIpService, service.NewWafFormatterService, service.NewCdnService, service.NewRequestService, service.NewTcpforwardingService, service.NewUdpForWardingService, service.NewWebForwardingService, service.NewProxyService, service.NewSslCertService, service.NewWebsocketService, service.NewGatewayipService)
+var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, service.NewGameShieldService, service.NewCrawlerService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewHostService, service.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewGameShieldUserIpService, service.NewWafFormatterService, service.NewCdnService, service.NewRequestService, service.NewTcpforwardingService, service.NewUdpForWardingService, service.NewWebForwardingService, service.NewProxyService, service.NewSslCertService, service.NewWebsocketService, service.NewGatewayipService, service.NewLogService)
 
 // build App
 func newApp(task2 *server.TaskServer,

+ 58 - 27
internal/repository/gatewayip.go

@@ -6,6 +6,7 @@ import (
 	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
 	"gorm.io/gorm"
+	"gorm.io/gorm/clause"
 )
 
 type GatewayipRepository interface {
@@ -17,7 +18,7 @@ type GatewayipRepository interface {
 	GetGatewayipByHostIdAll(ctx context.Context, hostId int64) (*model.Gatewayip, error)
 	UpdateGatewayipByHostId(ctx context.Context, req model.Gatewayip) error
 	DeleteGatewayipByHostId(ctx context.Context, hostId int64) error
-	GetIpWhereHostIdNull(ctx context.Context,req v1.GlobalLimitRequireResponse) error
+	GetIpWhereHostIdNull(ctx context.Context,req v1.GlobalLimitRequireResponse) ([]string,error)
 	CleanIPByHostId(ctx context.Context, hostId []int64) error
 	GetGatewayipOnlyIpByHostIdAll(ctx context.Context, hostId int64) ([]string, error)
 }
@@ -70,67 +71,97 @@ func (r *gatewayipRepository) DeleteGatewayipByHostId(ctx context.Context, hostI
 }
 
 
-func (r *gatewayipRepository) GetIpWhereHostIdNull(ctx context.Context,req v1.GlobalLimitRequireResponse) error {
+func (r *gatewayipRepository) GetIpWhereHostIdNull(ctx context.Context,req v1.GlobalLimitRequireResponse) ([]string,error) {
 	if req.IpCount <= 0 {
-		return fmt.Errorf("套餐IP数量错误, 请联系客服")
+		return nil, fmt.Errorf("套餐IP数量错误, 请联系客服")
 	}
 	if req.HostId <= 0 {
-		return fmt.Errorf("主机ID错误, 请联系客服")
+		return nil, fmt.Errorf("主机ID错误, 请联系客服")
 	}
 
 	var count int64
 	err := r.DB(ctx).Model(&model.Gatewayip{}).Where("host_id = ?", req.HostId).Count(&count).Error
 	if err != nil {
-		return err
+		return nil, err
 	}
 	if count >= int64(req.IpCount) {
-		return nil
+		return nil, nil // IP数量已足够,无需操作
 	}
 
-	req.IpCount = int(int64(req.IpCount) - count)
+	neededIpCount := int(int64(req.IpCount) - count)
 
-	// 使用事务保证操作的原子性
-	return r.DB(ctx).Transaction(func(tx *gorm.DB) error {
-		var idsToAssign []uint // 只需一个切片来接收ID
+	// 这个切片仍然需要是 model.Gatewayip 类型,因为它需要临时持有从数据库查出的完整对象
+	var assignedIPs []model.Gatewayip
 
-		// 步骤 1: 查询所需数量的可用IP ID。使用 Limit 可以提升性能,避免捞出所有可用IP。
+	// 使用事务保证操作的原子性
+	err = r.DB(ctx).Transaction(func(tx *gorm.DB) error {
+		// 步骤 1: 查询并锁定所需数量的可用IP对象
+		// 我们仍然需要完整的对象,因为后续更新需要用到 ID
 		err := tx.Model(&model.Gatewayip{}).
+			Clauses(clause.Locking{Strength: "UPDATE"}).
 			Where("operator = ?", req.Operator).
 			Where("ban_udp = ?", req.IsBanUdp).
 			Where("ban_overseas = ?", req.IsBanOverseas).
-			Where("node_area = ?",req.NodeArea).
-			Where("host_id IS NULL OR host_id = ?", 0).
+			Where("node_area = ?", req.NodeArea).
+			Where("host_id IS NULL OR host_id = 0").
 			Order("id ASC").
-			Limit(req.IpCount). // 优化点:直接用Limit限制查询数量
-			Pluck("id", &idsToAssign).Error
+			Limit(neededIpCount).
+			Find(&assignedIPs).Error
 
 		if err != nil {
-			return err // 查询出错,事务回滚
+			return err
+		}
+
+		// 步骤 2: 检查库存
+		if len(assignedIPs) < neededIpCount {
+			return fmt.Errorf("IP库存不足, 需要 %d 个, 实际可用 %d 个, 请联系客服补充", neededIpCount, len(assignedIPs))
+		}
+
+		if len(assignedIPs) == 0 {
+			return nil
 		}
 
-		// 步骤 2: 判断实际查到的数量是否足够
-		if len(idsToAssign) < req.IpCount {
-			return fmt.Errorf("库存不足, 请联系客服补充") // 数量不足,返回特定错误,事务回滚
+		// 步骤 3: 提取ID用于更新
+		var idsToUpdate []int
+		for _, ip := range assignedIPs {
+			idsToUpdate = append(idsToUpdate, ip.Id)
 		}
 
-		// 步骤 3: 更新这些IP的 host_id
-		// 注意:因为上面已经Limit了,所以idsToAssign的长度就是我们要更新的数量
+		// 步骤 4: 更新这些IP的 host_id
 		updateResult := tx.Model(&model.Gatewayip{}).
-			Where("id IN ?", idsToAssign).
+			Where("id IN ?", idsToUpdate).
 			Update("host_id", req.HostId)
 
 		if updateResult.Error != nil {
-			return updateResult.Error // 更新失败,事务回滚
+			return updateResult.Error
 		}
 
-		// (可选) 健壮性检查
-		if updateResult.RowsAffected != int64(req.IpCount) {
-			return fmt.Errorf("IP分配异常: 期望更新 %d 条记录, 实际更新了 %d 条", req.IpCount, updateResult.RowsAffected)
+		if updateResult.RowsAffected != int64(len(idsToUpdate)) {
+			return fmt.Errorf("IP分配异常: 期望更新 %d 条记录, 实际更新了 %d 条", len(idsToUpdate), updateResult.RowsAffected)
 		}
 
-		// 返回 nil, GORM 会提交事务
 		return nil
 	})
+
+	// 事务执行后,检查是否有错误
+	if err != nil {
+		return nil, err
+	}
+
+	// 如果事务成功,且分配了IP (assignedIPs不为空)
+	// *** 核心改动点 ***
+	// 创建一个新的字符串切片,用于存放最终要返回的IP地址
+	var ipStrings []string
+	if len(assignedIPs) > 0 {
+		ipStrings = make([]string, 0, len(assignedIPs)) // 预分配容量以提高性能
+		for _, ip := range assignedIPs {
+
+			ipStrings = append(ipStrings, ip.Ip)
+		}
+	}
+
+	// 返回IP地址字符串切片和 nil 错误
+	return ipStrings, nil
 }
 
 func (r *gatewayipRepository) CleanIPByHostId(ctx context.Context, hostId []int64) error {

+ 22 - 2
internal/service/gatewayip.go

@@ -2,6 +2,7 @@ package service
 
 import (
 	"context"
+	"encoding/json"
 	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
@@ -11,16 +12,19 @@ type GatewayipService interface {
 	GetGatewayip(ctx context.Context, id int64) (*model.Gatewayip, error)
 	GetGatewayipOnlyIpByHostIdAll(ctx context.Context, hostId int64,uid int64) ([]string, error)
 	GetGatewayipByHostIdFirst(ctx context.Context, hostId int64,uid int64) (string, error)
+	AddIpWhereHostIdNull(ctx context.Context, hostId int64,uid int64) error
 }
 func NewGatewayipService(
     service *Service,
     gatewayipRepository repository.GatewayipRepository,
 	host HostService,
+	log LogService,
 ) GatewayipService {
 	return &gatewayipService{
 		Service:        service,
 		gatewayipRepository: gatewayipRepository,
 		host : host,
+		log : log,
 	}
 }
 
@@ -28,6 +32,7 @@ type gatewayipService struct {
 	*Service
 	gatewayipRepository repository.GatewayipRepository
 	host HostService
+	log LogService
 }
 
 func (s *gatewayipService) GetGatewayip(ctx context.Context, id int64) (*model.Gatewayip, error) {
@@ -40,7 +45,7 @@ func (s *gatewayipService) AddIpWhereHostIdNull(ctx context.Context, hostId int6
 		return  err
 	}
 
-	if err := s.gatewayipRepository.GetIpWhereHostIdNull(ctx, v1.GlobalLimitRequireResponse{
+	 ips, err := s.gatewayipRepository.GetIpWhereHostIdNull(ctx, v1.GlobalLimitRequireResponse{
 		HostId:              int(hostId),
 		Bps:                 config.Bps,
 		MaxBytesMonth:       config.MaxBytesMonth,
@@ -49,10 +54,25 @@ func (s *gatewayipService) AddIpWhereHostIdNull(ctx context.Context, hostId int6
 		NodeArea:            config.NodeArea,
 		ConfigMaxProtection: config.ConfigMaxProtection,
 		IsBanUdp:            config.IsBanUdp,
-	}); err != nil {
+	});
+	 if err != nil {
 		return  err
 	}
 
+	ipsJson, err := json.Marshal(ips)
+	if err != nil {
+		return err
+	}
+
+	if err = s.log.AddLog(ctx, &model.Log{
+		Uid: uid,
+		Api: "AddIpWhereHostIdNull",
+		Message: "分配网关组IP",
+		ExtraData: ipsJson,
+	}); err != nil {
+		return err
+	}
+
 	return nil
 }
 

+ 5 - 2
internal/service/globallimit.go

@@ -46,6 +46,7 @@ func NewGlobalLimitService(
 	udpForWarding UdpForWardingService,
 	webForWarding WebForwardingService,
 	gatewayIpRep repository.GatewayipRepository,
+	gatywayIp GatewayipService,
 ) GlobalLimitService {
 	return &globalLimitService{
 		Service:               service,
@@ -68,6 +69,7 @@ func NewGlobalLimitService(
 		udpForWarding:         udpForWarding,
 		webForWarding:         webForWarding,
 		gatewayIpRep:             gatewayIpRep,
+		gatewayIp: 				gatywayIp,
 	}
 }
 
@@ -92,6 +94,7 @@ type globalLimitService struct {
 	udpForWarding         UdpForWardingService
 	webForWarding         WebForwardingService
 	gatewayIpRep          repository.GatewayipRepository
+	gatewayIp 			GatewayipService
 }
 
 func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
@@ -229,7 +232,7 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 	var userId int64
 	var groupId int64
 	g.Go(func() error {
-		e := s.gatewayIpRep.GetIpWhereHostIdNull(gCtx, require)
+		e := s.gatewayIp.AddIpWhereHostIdNull(gCtx, int64(req.HostId),int64(req.Uid))
 		if e != nil {
 			return fmt.Errorf("获取网关组失败: %w", e)
 		}
@@ -388,7 +391,7 @@ func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalL
 		return err
 	}
 	if gatewayIp != nil {
-		err = s.gatewayIpRep.GetIpWhereHostIdNull(ctx, require)
+		err = s.gatewayIp.AddIpWhereHostIdNull(ctx, int64(req.HostId), int64(req.Uid))
 		if err != nil {
 			return fmt.Errorf("获取网关组失败: %w", err)
 		}