gatewayip.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. package waf
  2. import (
  3. "context"
  4. "fmt"
  5. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  6. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  8. "gorm.io/gorm"
  9. "gorm.io/gorm/clause"
  10. )
  11. type GatewayipRepository interface {
  12. GetGatewayip(ctx context.Context, id int64) (*model.Gatewayip, error)
  13. AddGatewayip(ctx context.Context, req model.Gatewayip) error
  14. EditGatewayip(ctx context.Context, req model.Gatewayip) error
  15. DeleteGatewayip(ctx context.Context, req model.Gatewayip) error
  16. GetGatewayipByHostIdFirst(ctx context.Context, hostId int64) (string, error)
  17. GetGatewayipByHostIdAll(ctx context.Context, hostId int64) (*model.Gatewayip, error)
  18. UpdateGatewayipByHostId(ctx context.Context, req model.Gatewayip) error
  19. DeleteGatewayipByHostId(ctx context.Context, hostId int64) error
  20. GetIpWhereHostIdNull(ctx context.Context,req v1.GlobalLimitRequireResponse) ([]string,error)
  21. CleanIPByHostId(ctx context.Context, hostId []int64) error
  22. GetGatewayipOnlyIpByHostIdAll(ctx context.Context, hostId int64) ([]string, error)
  23. }
  24. func NewGatewayipRepository(
  25. repository *repository.Repository,
  26. ) GatewayipRepository {
  27. return &gatewayipRepository{
  28. Repository: repository,
  29. }
  30. }
  31. type gatewayipRepository struct {
  32. *repository.Repository
  33. }
  34. func (r *gatewayipRepository) GetGatewayip(ctx context.Context, id int64) (*model.Gatewayip, error) {
  35. var req model.Gatewayip
  36. return &req, r.DB(ctx).Where("id = ?", id).First(&req).Error
  37. }
  38. func (r *gatewayipRepository) AddGatewayip(ctx context.Context, req model.Gatewayip) error {
  39. return r.DB(ctx).Create(&req).Error
  40. }
  41. func (r *gatewayipRepository) EditGatewayip(ctx context.Context, req model.Gatewayip) error {
  42. return r.DB(ctx).Model(&model.Gatewayip{}).Where("id = ?", req.Id).Updates(req).Error
  43. }
  44. func (r *gatewayipRepository) DeleteGatewayip(ctx context.Context, req model.Gatewayip) error {
  45. return r.DB(ctx).Model(&model.Gatewayip{}).Where("id = ?", req.Id).Delete(req).Error
  46. }
  47. func (r *gatewayipRepository) GetGatewayipByHostIdFirst(ctx context.Context, hostId int64) (string, error) {
  48. var req string
  49. return req, r.DB(ctx).Model(&model.Gatewayip{}).Where("host_id = ?", hostId).Pluck("ip", &req).Error
  50. }
  51. func (r *gatewayipRepository) GetGatewayipByHostIdAll(ctx context.Context, hostId int64) (*model.Gatewayip, error) {
  52. var req model.Gatewayip
  53. return &req, r.DB(ctx).Where("host_id = ?", hostId).Find(&req).Error
  54. }
  55. func (r *gatewayipRepository) UpdateGatewayipByHostId(ctx context.Context, req model.Gatewayip) error {
  56. return r.DB(ctx).Where("host_id = ?", req.HostId).Updates(&req).Error
  57. }
  58. func (r *gatewayipRepository) DeleteGatewayipByHostId(ctx context.Context, hostId int64) error {
  59. return r.DB(ctx).Model(&model.Gatewayip{}).Where("host_id = ?", hostId).Delete(&model.Gatewayip{}).Error
  60. }
  61. func (r *gatewayipRepository) GetIpWhereHostIdNull(ctx context.Context,req v1.GlobalLimitRequireResponse) ([]string,error) {
  62. if req.IpCount <= 0 {
  63. return nil, fmt.Errorf("套餐IP数量错误, 请联系客服")
  64. }
  65. if req.HostId <= 0 {
  66. return nil, fmt.Errorf("主机ID错误, 请联系客服")
  67. }
  68. var count int64
  69. err := r.DB(ctx).Model(&model.Gatewayip{}).Where("host_id = ?", req.HostId).Count(&count).Error
  70. if err != nil {
  71. return nil, err
  72. }
  73. if count >= int64(req.IpCount) {
  74. return nil, nil // IP数量已足够,无需操作
  75. }
  76. neededIpCount := int(int64(req.IpCount) - count)
  77. // 这个切片仍然需要是 model.Gatewayip 类型,因为它需要临时持有从数据库查出的完整对象
  78. var assignedIPs []model.Gatewayip
  79. // 使用事务保证操作的原子性
  80. err = r.DB(ctx).Transaction(func(tx *gorm.DB) error {
  81. // 步骤 1: 查询并锁定所需数量的可用IP对象
  82. // 我们仍然需要完整的对象,因为后续更新需要用到 ID
  83. err := tx.Model(&model.Gatewayip{}).
  84. Clauses(clause.Locking{Strength: "UPDATE"}).
  85. Where("operator = ?", req.Operator).
  86. Where("ban_udp = ?", req.IsBanUdp).
  87. Where("ban_overseas = ?", req.IsBanOverseas).
  88. Where("node_area = ?", req.NodeArea).
  89. Where("host_id IS NULL OR host_id = 0").
  90. Order("id ASC").
  91. Limit(neededIpCount).
  92. Find(&assignedIPs).Error
  93. if err != nil {
  94. return err
  95. }
  96. // 步骤 2: 检查库存
  97. if len(assignedIPs) < neededIpCount {
  98. return fmt.Errorf("IP库存不足, 需要 %d 个, 实际可用 %d 个, 请联系客服补充", neededIpCount, len(assignedIPs))
  99. }
  100. if len(assignedIPs) == 0 {
  101. return nil
  102. }
  103. // 步骤 3: 提取ID用于更新
  104. var idsToUpdate []int
  105. for _, ip := range assignedIPs {
  106. idsToUpdate = append(idsToUpdate, ip.Id)
  107. }
  108. // 步骤 4: 更新这些IP的 host_id
  109. updateResult := tx.Model(&model.Gatewayip{}).
  110. Where("id IN ?", idsToUpdate).
  111. Update("host_id", req.HostId)
  112. if updateResult.Error != nil {
  113. return updateResult.Error
  114. }
  115. if updateResult.RowsAffected != int64(len(idsToUpdate)) {
  116. return fmt.Errorf("IP分配异常: 期望更新 %d 条记录, 实际更新了 %d 条", len(idsToUpdate), updateResult.RowsAffected)
  117. }
  118. return nil
  119. })
  120. // 事务执行后,检查是否有错误
  121. if err != nil {
  122. return nil, err
  123. }
  124. // 如果事务成功,且分配了IP (assignedIPs不为空)
  125. // *** 核心改动点 ***
  126. // 创建一个新的字符串切片,用于存放最终要返回的IP地址
  127. var ipStrings []string
  128. if len(assignedIPs) > 0 {
  129. ipStrings = make([]string, 0, len(assignedIPs)) // 预分配容量以提高性能
  130. for _, ip := range assignedIPs {
  131. ipStrings = append(ipStrings, ip.Ip)
  132. }
  133. }
  134. // 返回IP地址字符串切片和 nil 错误
  135. return ipStrings, nil
  136. }
  137. func (r *gatewayipRepository) CleanIPByHostId(ctx context.Context, hostId []int64) error {
  138. return r.DB(ctx).Model(&model.Gatewayip{}).Where("host_id IN ?", hostId).Update("host_id", 0).Error
  139. }
  140. func (r *gatewayipRepository) GetGatewayipOnlyIpByHostIdAll(ctx context.Context, hostId int64) ([]string, error) {
  141. var req []string
  142. return req, r.DB(ctx).Model(&model.Gatewayip{}).Where("host_id = ?", hostId).Pluck("ip", &req).Error
  143. }