udpforwarding.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. package repository
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  7. "go.mongodb.org/mongo-driver/bson"
  8. "go.mongodb.org/mongo-driver/bson/primitive"
  9. "go.mongodb.org/mongo-driver/mongo"
  10. "time"
  11. )
  12. type UdpForWardingRepository interface {
  13. GetUdpForWarding(ctx context.Context, id int64) (*model.UdpForWarding, error)
  14. AddUdpForwarding(ctx context.Context, req *model.UdpForWarding) (int, error)
  15. EditUdpForwarding(ctx context.Context, req *model.UdpForWarding) error
  16. DeleteUdpForwarding(ctx context.Context, id int64) error
  17. GetUdpForwardingWafUdpIdById(ctx context.Context, id int) (int, error)
  18. GetUdpForwardingPortCountByHostId(ctx context.Context, hostId int) (int64, error)
  19. AddUdpForwardingIps(ctx context.Context, req model.UdpForwardingRule) (primitive.ObjectID, error)
  20. EditUdpForwardingIps(ctx context.Context, req model.UdpForwardingRule) error
  21. GetTcpForwardingByID(ctx context.Context, udpId int) (*model.UdpForwardingRule, error)
  22. }
  23. func NewUdpForWardingRepository(
  24. repository *Repository,
  25. ) UdpForWardingRepository {
  26. return &udpForWardingRepository{
  27. Repository: repository,
  28. }
  29. }
  30. type udpForWardingRepository struct {
  31. *Repository
  32. }
  33. func (r *udpForWardingRepository) GetUdpForWarding(ctx context.Context, id int64) (*model.UdpForWarding, error) {
  34. var udpForWarding model.UdpForWarding
  35. if err := r.db.Where("id = ?", id).First(&udpForWarding).Error; err != nil {
  36. return nil, err
  37. }
  38. return &udpForWarding, nil
  39. }
  40. func (r *udpForWardingRepository) AddUdpForwarding(ctx context.Context, req *model.UdpForWarding) (int, error) {
  41. if err := r.db.Create(req).Error; err != nil {
  42. return 0, err
  43. }
  44. return req.Id, nil
  45. }
  46. func (r *udpForWardingRepository) EditUdpForwarding(ctx context.Context, req *model.UdpForWarding) error {
  47. if err := r.db.Updates(req).Error; err != nil {
  48. return err
  49. }
  50. return nil
  51. }
  52. func (r *udpForWardingRepository) DeleteUdpForwarding(ctx context.Context, id int64) error {
  53. if err := r.db.Where("id = ?", id).Delete(&model.UdpForWarding{}).Error; err != nil {
  54. return err
  55. }
  56. return nil
  57. }
  58. func (r *udpForWardingRepository) GetUdpForwardingWafUdpIdById(ctx context.Context, id int) (int, error) {
  59. var WafUdpId int
  60. if err := r.db.Model(&model.UdpForWarding{}).Where("id = ?", id).Select("waf_udp_id").Find(&WafUdpId).Error; err != nil {
  61. return 0, err
  62. }
  63. return WafUdpId, nil
  64. }
  65. func (r *udpForWardingRepository) GetUdpForwardingPortCountByHostId(ctx context.Context, hostId int) (int64, error) {
  66. var count int64
  67. if err := r.db.Model(&model.UdpForWarding{}).Where("host_id = ?", hostId).Count(&count).Error; err != nil {
  68. return 0, err
  69. }
  70. return count, nil
  71. }
  72. // mongodb 插入
  73. func (r *udpForWardingRepository) AddUdpForwardingIps(ctx context.Context, req model.UdpForwardingRule) (primitive.ObjectID, error) {
  74. collection := r.mongoDB.Collection("udp_forwarding_rules")
  75. req.CreatedAt = time.Now()
  76. result, err := collection.InsertOne(ctx, req)
  77. if err != nil {
  78. return primitive.NilObjectID, fmt.Errorf("插入MongoDB失败: %w", err)
  79. }
  80. // 返回插入文档的ID
  81. return result.InsertedID.(primitive.ObjectID), nil
  82. }
  83. func (r *udpForWardingRepository) EditUdpForwardingIps(ctx context.Context, req model.UdpForwardingRule) error {
  84. collection := r.mongoDB.Collection("udp_forwarding_rules")
  85. updateData := bson.M{}
  86. if req.Uid != 0 {
  87. updateData["uid"] = req.Uid
  88. }
  89. if req.HostId != 0 {
  90. updateData["host_id"] = req.HostId
  91. }
  92. if req.UdpId != 0 {
  93. updateData["udp_id"] = req.UdpId
  94. }
  95. if req.AccessRule != "" {
  96. updateData["access_rule"] = req.AccessRule
  97. }
  98. if len(req.BackendList) > 0 {
  99. updateData["backend_list"] = req.BackendList
  100. }
  101. if len(req.AllowIpList) > 0 {
  102. updateData["allow_ip_list"] = req.AllowIpList
  103. }
  104. if len(req.DenyIpList) > 0 {
  105. updateData["deny_ip_list"] = req.DenyIpList
  106. }
  107. // 始终更新更新时间
  108. updateData["updated_at"] = time.Now()
  109. // 如果没有任何字段需要更新,则直接返回
  110. if len(updateData) == 0 {
  111. return nil
  112. }
  113. // 执行更新
  114. update := bson.M{"$set": updateData}
  115. err := collection.UpdateOne(ctx, bson.M{"udp_id": req.UdpId}, update)
  116. if err != nil {
  117. return fmt.Errorf("更新MongoDB文档失败: %w", err)
  118. }
  119. return nil
  120. }
  121. func (r *udpForWardingRepository) GetTcpForwardingByID(ctx context.Context, udpId int) (*model.UdpForwardingRule, error) {
  122. collection := r.mongoDB.Collection("udp_forwarding_rules")
  123. var result model.UdpForwardingRule
  124. err := collection.Find(ctx, bson.M{"udp_id": udpId}).One(&result)
  125. if err != nil {
  126. if errors.Is(err, mongo.ErrNoDocuments) {
  127. return nil, nil
  128. }
  129. return nil, fmt.Errorf("从MongoDB中获取文档失败: %w", err)
  130. }
  131. return &result, nil
  132. }