package waf

import (
	"context"
	"fmt"

	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
	"github.com/go-nunu/nunu-layout-advanced/internal/model"
	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// GatewayipRepository provides persistence operations for gateway IP records
// (model.Gatewayip), including assigning unassigned IPs (host_id NULL or 0)
// to a host and releasing them again.
type GatewayipRepository interface {
	// GetGatewayip loads the record with the given id.
	GetGatewayip(ctx context.Context, id int64) (*model.Gatewayip, error)
	// AddGatewayip inserts a new record.
	AddGatewayip(ctx context.Context, req model.Gatewayip) error
	// EditGatewayip updates the non-zero fields of req on the row with req.Id.
	EditGatewayip(ctx context.Context, req model.Gatewayip) error
	// DeleteGatewayip deletes the row with req.Id.
	DeleteGatewayip(ctx context.Context, req model.Gatewayip) error
	// GetGatewayipByHostIdFirst returns the IP with the lowest id assigned to hostId, or "" if none.
	GetGatewayipByHostIdFirst(ctx context.Context, hostId int64) (string, error)
	// GetGatewayipByHostIdAll returns one record assigned to hostId (see implementation note).
	GetGatewayipByHostIdAll(ctx context.Context, hostId int64) (*model.Gatewayip, error)
	// UpdateGatewayipByHostId updates the non-zero fields of req on rows matching req.HostId.
	UpdateGatewayipByHostId(ctx context.Context, req model.Gatewayip) error
	// DeleteGatewayipByHostId deletes every row assigned to hostId.
	DeleteGatewayipByHostId(ctx context.Context, hostId int64) error
	// GetIpWhereHostIdNull tops up a host's IPs from the unassigned pool and returns the newly assigned IPs.
	GetIpWhereHostIdNull(ctx context.Context, req v1.GlobalLimitRequireResponse) ([]string, error)
	// CleanIPByHostId returns every IP held by the given hosts to the pool (host_id = 0).
	CleanIPByHostId(ctx context.Context, hostId []int64) error
	// GetGatewayipOnlyIpByHostIdAll returns the ip column of every row assigned to hostId.
	GetGatewayipOnlyIpByHostIdAll(ctx context.Context, hostId int64) ([]string, error)
}

// NewGatewayipRepository builds a GatewayipRepository on top of the shared repository.
func NewGatewayipRepository(
	repository *repository.Repository,
) GatewayipRepository {
	return &gatewayipRepository{
		Repository: repository,
	}
}

type gatewayipRepository struct {
	*repository.Repository
}

// GetGatewayip loads the row with the given id via GORM First. The returned
// pointer is non-nil even when err is non-nil (e.g. gorm.ErrRecordNotFound),
// matching the historical behavior callers may rely on.
func (r *gatewayipRepository) GetGatewayip(ctx context.Context, id int64) (*model.Gatewayip, error) {
	var gw model.Gatewayip
	err := r.DB(ctx).Where("id = ?", id).First(&gw).Error
	return &gw, err
}

// AddGatewayip inserts req as a new row.
func (r *gatewayipRepository) AddGatewayip(ctx context.Context, req model.Gatewayip) error {
	return r.DB(ctx).Create(&req).Error
}

// EditGatewayip updates the row with id = req.Id. Because a struct is passed to
// Updates, GORM only writes req's non-zero fields.
func (r *gatewayipRepository) EditGatewayip(ctx context.Context, req model.Gatewayip) error {
	return r.DB(ctx).Model(&model.Gatewayip{}).Where("id = ?", req.Id).Updates(req).Error
}

// DeleteGatewayip deletes the row with id = req.Id.
func (r *gatewayipRepository) DeleteGatewayip(ctx context.Context, req model.Gatewayip) error {
	// Pass a pointer: GORM needs an addressable value to write DeletedAt if the
	// model is soft-deletable (a plain struct value fails with ErrInvalidValue).
	return r.DB(ctx).Model(&model.Gatewayip{}).Where("id = ?", req.Id).Delete(&req).Error
}

// GetGatewayipByHostIdFirst returns the IP with the lowest id assigned to
// hostId, or "" (with a nil error) when the host has no IPs.
func (r *gatewayipRepository) GetGatewayipByHostIdFirst(ctx context.Context, hostId int64) (string, error) {
	// Pluck into a slice with an explicit ORDER BY/LIMIT. Plucking into a plain
	// string scans every row into the same variable, leaving the last,
	// arbitrary row rather than the first.
	var ips []string
	err := r.DB(ctx).Model(&model.Gatewayip{}).
		Where("host_id = ?", hostId).
		Order("id ASC").
		Limit(1).
		Pluck("ip", &ips).Error
	if err != nil {
		return "", err
	}
	if len(ips) == 0 {
		return "", nil
	}
	return ips[0], nil
}

// GetGatewayipByHostIdAll returns a record assigned to hostId.
//
// NOTE(review): despite the name this returns a single record. Find into a
// struct yields one row, and with no ORDER BY which row is unspecified. If no
// row matches, a zero-valued record and a nil error are returned.
func (r *gatewayipRepository) GetGatewayipByHostIdAll(ctx context.Context, hostId int64) (*model.Gatewayip, error) {
	var gw model.Gatewayip
	err := r.DB(ctx).Where("host_id = ?", hostId).Find(&gw).Error
	return &gw, err
}

// UpdateGatewayipByHostId writes req's non-zero fields to the rows whose
// host_id equals req.HostId (further narrowed by req.Id if it is set).
func (r *gatewayipRepository) UpdateGatewayipByHostId(ctx context.Context, req model.Gatewayip) error {
	return r.DB(ctx).Where("host_id = ?", req.HostId).Updates(&req).Error
}

// DeleteGatewayipByHostId deletes every row assigned to hostId.
// hostId must be positive: host_id = 0 marks unassigned pool IPs, so deleting
// by 0 would wipe the whole free pool.
func (r *gatewayipRepository) DeleteGatewayipByHostId(ctx context.Context, hostId int64) error {
	if hostId <= 0 {
		return fmt.Errorf("主机ID错误, 请联系客服")
	}
	return r.DB(ctx).Model(&model.Gatewayip{}).Where("host_id = ?", hostId).Delete(&model.Gatewayip{}).Error
}

// GetIpWhereHostIdNull ensures host req.HostId holds req.IpCount IPs. It
// assigns the missing ones from the unassigned pool (host_id NULL or 0),
// filtered by operator, UDP ban, overseas ban and node area.
//
// It returns:
//   - (nil, nil) when the host already holds at least req.IpCount IPs;
//   - the newly assigned IP strings on success;
//   - an error when the inputs are invalid, the pool lacks enough matching IPs,
//     or the update did not touch exactly the selected rows.
//
// Nothing is assigned when an error is returned: the selection and the update
// run in one transaction.
func (r *gatewayipRepository) GetIpWhereHostIdNull(ctx context.Context, req v1.GlobalLimitRequireResponse) ([]string, error) {
	if req.IpCount <= 0 {
		return nil, fmt.Errorf("套餐IP数量错误, 请联系客服")
	}
	if req.HostId <= 0 {
		return nil, fmt.Errorf("主机ID错误, 请联系客服")
	}

	// NOTE(review): this count runs outside the allocation transaction, so two
	// concurrent calls for the same host can both observe a shortfall and both
	// allocate. Confirm callers serialize requests per host.
	var count int64
	if err := r.DB(ctx).Model(&model.Gatewayip{}).Where("host_id = ?", req.HostId).Count(&count).Error; err != nil {
		return nil, err
	}
	if count >= int64(req.IpCount) {
		return nil, nil // the host already holds enough IPs; nothing to do
	}
	neededIpCount := int(int64(req.IpCount) - count)

	// Full records are needed (not just IPs): the update below targets them by id.
	var assignedIPs []model.Gatewayip

	err := r.DB(ctx).Transaction(func(tx *gorm.DB) error {
		// Select and lock (FOR UPDATE) the candidate rows so no other
		// transaction can claim the same IPs before this one commits.
		// The OR condition is parenthesized explicitly so it cannot bind to
		// the ANDed filters around it.
		if err := tx.Model(&model.Gatewayip{}).
			Clauses(clause.Locking{Strength: "UPDATE"}).
			Where("operator = ?", req.Operator).
			Where("ban_udp = ?", req.IsBanUdp).
			Where("ban_overseas = ?", req.IsBanOverseas).
			Where("node_area = ?", req.NodeArea).
			Where("(host_id IS NULL OR host_id = 0)").
			Order("id ASC").
			Limit(neededIpCount).
			Find(&assignedIPs).Error; err != nil {
			return err
		}

		// neededIpCount >= 1 here, so passing this check also guarantees a
		// non-empty selection.
		if len(assignedIPs) < neededIpCount {
			return fmt.Errorf("IP库存不足, 需要 %d 个, 实际可用 %d 个, 请联系客服补充", neededIpCount, len(assignedIPs))
		}

		ids := make([]int, 0, len(assignedIPs))
		for _, ip := range assignedIPs {
			ids = append(ids, ip.Id)
		}

		res := tx.Model(&model.Gatewayip{}).
			Where("id IN ?", ids).
			Update("host_id", req.HostId)
		if res.Error != nil {
			return res.Error
		}
		// Every selected row moves from NULL/0 to a positive host id, so each
		// must count as affected; anything else means the selection went stale.
		if res.RowsAffected != int64(len(ids)) {
			return fmt.Errorf("IP分配异常: 期望更新 %d 条记录, 实际更新了 %d 条", len(ids), res.RowsAffected)
		}
		return nil
	})
	if err != nil {
		return nil, err
	}

	ipStrings := make([]string, 0, len(assignedIPs))
	for _, ip := range assignedIPs {
		ipStrings = append(ipStrings, ip.Ip)
	}
	return ipStrings, nil
}

// CleanIPByHostId returns every IP held by the given hosts to the pool by
// setting host_id to 0. An empty hostId list is a no-op.
func (r *gatewayipRepository) CleanIPByHostId(ctx context.Context, hostId []int64) error {
	if len(hostId) == 0 {
		return nil
	}
	return r.DB(ctx).Model(&model.Gatewayip{}).Where("host_id IN ?", hostId).Update("host_id", 0).Error
}

// GetGatewayipOnlyIpByHostIdAll returns the ip column of every row assigned
// to hostId (nil when there are none).
func (r *gatewayipRepository) GetGatewayipOnlyIpByHostIdAll(ctx context.Context, hostId int64) ([]string, error) {
	// Query first, then return: reading the slice in the same return statement
	// as the call that fills it relies on unspecified evaluation order.
	var ips []string
	if err := r.DB(ctx).Model(&model.Gatewayip{}).Where("host_id = ?", hostId).Pluck("ip", &ips).Error; err != nil {
		return nil, err
	}
	return ips, nil
}