|
@@ -6,6 +6,7 @@ import (
|
|
|
v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
|
|
|
"github.com/go-nunu/nunu-layout-advanced/internal/model"
|
|
|
"gorm.io/gorm"
|
|
|
+ "gorm.io/gorm/clause"
|
|
|
)
|
|
|
|
|
|
type GatewayipRepository interface {
|
|
@@ -17,7 +18,7 @@ type GatewayipRepository interface {
|
|
|
GetGatewayipByHostIdAll(ctx context.Context, hostId int64) (*model.Gatewayip, error)
|
|
|
UpdateGatewayipByHostId(ctx context.Context, req model.Gatewayip) error
|
|
|
DeleteGatewayipByHostId(ctx context.Context, hostId int64) error
|
|
|
- GetIpWhereHostIdNull(ctx context.Context,req v1.GlobalLimitRequireResponse) error
|
|
|
+ GetIpWhereHostIdNull(ctx context.Context,req v1.GlobalLimitRequireResponse) ([]string,error)
|
|
|
CleanIPByHostId(ctx context.Context, hostId []int64) error
|
|
|
GetGatewayipOnlyIpByHostIdAll(ctx context.Context, hostId int64) ([]string, error)
|
|
|
}
|
|
@@ -70,67 +71,97 @@ func (r *gatewayipRepository) DeleteGatewayipByHostId(ctx context.Context, hostI
|
|
|
}
|
|
|
|
|
|
|
|
|
-func (r *gatewayipRepository) GetIpWhereHostIdNull(ctx context.Context,req v1.GlobalLimitRequireResponse) error {
|
|
|
+func (r *gatewayipRepository) GetIpWhereHostIdNull(ctx context.Context,req v1.GlobalLimitRequireResponse) ([]string,error) {
|
|
|
if req.IpCount <= 0 {
|
|
|
- return fmt.Errorf("套餐IP数量错误, 请联系客服")
|
|
|
+ return nil, fmt.Errorf("套餐IP数量错误, 请联系客服")
|
|
|
}
|
|
|
if req.HostId <= 0 {
|
|
|
- return fmt.Errorf("主机ID错误, 请联系客服")
|
|
|
+ return nil, fmt.Errorf("主机ID错误, 请联系客服")
|
|
|
}
|
|
|
|
|
|
var count int64
|
|
|
err := r.DB(ctx).Model(&model.Gatewayip{}).Where("host_id = ?", req.HostId).Count(&count).Error
|
|
|
if err != nil {
|
|
|
- return err
|
|
|
+ return nil, err
|
|
|
}
|
|
|
if count >= int64(req.IpCount) {
|
|
|
- return nil
|
|
|
+ return nil, nil // IP数量已足够,无需操作
|
|
|
}
|
|
|
|
|
|
- req.IpCount = int(int64(req.IpCount) - count)
|
|
|
+ neededIpCount := int(int64(req.IpCount) - count)
|
|
|
|
|
|
- // 使用事务保证操作的原子性
|
|
|
- return r.DB(ctx).Transaction(func(tx *gorm.DB) error {
|
|
|
- var idsToAssign []uint // 只需一个切片来接收ID
|
|
|
+ // 这个切片仍然需要是 model.Gatewayip 类型,因为它需要临时持有从数据库查出的完整对象
|
|
|
+ var assignedIPs []model.Gatewayip
|
|
|
|
|
|
- // 步骤 1: 查询所需数量的可用IP ID。使用 Limit 可以提升性能,避免捞出所有可用IP。
|
|
|
+ // 使用事务保证操作的原子性
|
|
|
+ err = r.DB(ctx).Transaction(func(tx *gorm.DB) error {
|
|
|
+ // 步骤 1: 查询并锁定所需数量的可用IP对象
|
|
|
+ // 我们仍然需要完整的对象,因为后续更新需要用到 ID
|
|
|
err := tx.Model(&model.Gatewayip{}).
|
|
|
+ Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
|
Where("operator = ?", req.Operator).
|
|
|
Where("ban_udp = ?", req.IsBanUdp).
|
|
|
Where("ban_overseas = ?", req.IsBanOverseas).
|
|
|
- Where("node_area = ?",req.NodeArea).
|
|
|
- Where("host_id IS NULL OR host_id = ?", 0).
|
|
|
+ Where("node_area = ?", req.NodeArea).
|
|
|
+ Where("host_id IS NULL OR host_id = 0").
|
|
|
Order("id ASC").
|
|
|
- Limit(req.IpCount). // 优化点:直接用Limit限制查询数量
|
|
|
- Pluck("id", &idsToAssign).Error
|
|
|
+ Limit(neededIpCount).
|
|
|
+ Find(&assignedIPs).Error
|
|
|
|
|
|
if err != nil {
|
|
|
- return err // 查询出错,事务回滚
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ // 步骤 2: 检查库存
|
|
|
+ if len(assignedIPs) < neededIpCount {
|
|
|
+ return fmt.Errorf("IP库存不足, 需要 %d 个, 实际可用 %d 个, 请联系客服补充", neededIpCount, len(assignedIPs))
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(assignedIPs) == 0 {
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
- // 步骤 2: 判断实际查到的数量是否足够
|
|
|
- if len(idsToAssign) < req.IpCount {
|
|
|
- return fmt.Errorf("库存不足, 请联系客服补充") // 数量不足,返回特定错误,事务回滚
|
|
|
+ // 步骤 3: 提取ID用于更新
|
|
|
+ var idsToUpdate []int
|
|
|
+ for _, ip := range assignedIPs {
|
|
|
+ idsToUpdate = append(idsToUpdate, ip.Id)
|
|
|
}
|
|
|
|
|
|
- // 步骤 3: 更新这些IP的 host_id
|
|
|
- // 注意:因为上面已经Limit了,所以idsToAssign的长度就是我们要更新的数量
|
|
|
+ // 步骤 4: 更新这些IP的 host_id
|
|
|
updateResult := tx.Model(&model.Gatewayip{}).
|
|
|
- Where("id IN ?", idsToAssign).
|
|
|
+ Where("id IN ?", idsToUpdate).
|
|
|
Update("host_id", req.HostId)
|
|
|
|
|
|
if updateResult.Error != nil {
|
|
|
- return updateResult.Error // 更新失败,事务回滚
|
|
|
+ return updateResult.Error
|
|
|
}
|
|
|
|
|
|
- // (可选) 健壮性检查
|
|
|
- if updateResult.RowsAffected != int64(req.IpCount) {
|
|
|
- return fmt.Errorf("IP分配异常: 期望更新 %d 条记录, 实际更新了 %d 条", req.IpCount, updateResult.RowsAffected)
|
|
|
+ if updateResult.RowsAffected != int64(len(idsToUpdate)) {
|
|
|
+ return fmt.Errorf("IP分配异常: 期望更新 %d 条记录, 实际更新了 %d 条", len(idsToUpdate), updateResult.RowsAffected)
|
|
|
}
|
|
|
|
|
|
- // 返回 nil, GORM 会提交事务
|
|
|
return nil
|
|
|
})
|
|
|
+
|
|
|
+ // 事务执行后,检查是否有错误
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ // 如果事务成功,且分配了IP (assignedIPs不为空)
|
|
|
+ // *** 核心改动点 ***
|
|
|
+ // 创建一个新的字符串切片,用于存放最终要返回的IP地址
|
|
|
+ var ipStrings []string
|
|
|
+ if len(assignedIPs) > 0 {
|
|
|
+ ipStrings = make([]string, 0, len(assignedIPs)) // 预分配容量以提高性能
|
|
|
+ for _, ip := range assignedIPs {
|
|
|
+
|
|
|
+ ipStrings = append(ipStrings, ip.Ip)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 返回IP地址字符串切片和 nil 错误
|
|
|
+ return ipStrings, nil
|
|
|
}
|
|
|
|
|
|
func (r *gatewayipRepository) CleanIPByHostId(ctx context.Context, hostId []int64) error {
|