tcpforwarding.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. package repository
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  7. "go.mongodb.org/mongo-driver/bson"
  8. "go.mongodb.org/mongo-driver/bson/primitive"
  9. "go.mongodb.org/mongo-driver/mongo"
  10. "time"
  11. )
  12. type TcpforwardingRepository interface {
  13. GetTcpforwarding(ctx context.Context, id int64) (*model.Tcpforwarding, error)
  14. AddTcpforwarding(ctx context.Context, req *model.Tcpforwarding) (int, error)
  15. EditTcpforwarding(ctx context.Context, req *model.Tcpforwarding) error
  16. DeleteTcpforwarding(ctx context.Context, id int64) error
  17. GetTcpforwardingWafTcpIdById(ctx context.Context, id int) (int, error)
  18. GetTcpForwardingPortCountByHostId(ctx context.Context, hostId int) (int64, error)
  19. AddTcpforwardingIps(ctx context.Context,req model.TcpForwardingRule) (primitive.ObjectID, error)
  20. EditTcpforwardingIps(ctx context.Context, req model.TcpForwardingRule) error
  21. GetTcpForwardingByID(ctx context.Context, tcpId int) (*model.TcpForwardingRule, error)
  22. }
  23. func NewTcpforwardingRepository(
  24. repository *Repository,
  25. ) TcpforwardingRepository {
  26. return &tcpforwardingRepository{
  27. Repository: repository,
  28. }
  29. }
  30. type tcpforwardingRepository struct {
  31. *Repository
  32. }
  33. func (r *tcpforwardingRepository) GetTcpforwarding(ctx context.Context, id int64) (*model.Tcpforwarding, error) {
  34. var tcpforwarding model.Tcpforwarding
  35. if err := r.db.Where("id = ?", id).First(&tcpforwarding).Error; err != nil {
  36. return nil, err
  37. }
  38. return &tcpforwarding, nil
  39. }
  40. func (r *tcpforwardingRepository) AddTcpforwarding(ctx context.Context, req *model.Tcpforwarding) (int, error) {
  41. if err := r.db.Create(req).Error; err != nil {
  42. return 0, err
  43. }
  44. return req.Id, nil
  45. }
  46. func (r *tcpforwardingRepository) EditTcpforwarding(ctx context.Context, req *model.Tcpforwarding) error {
  47. if err := r.db.Updates(req).Error; err != nil {
  48. return err
  49. }
  50. return nil
  51. }
  52. func (r *tcpforwardingRepository) DeleteTcpforwarding(ctx context.Context, id int64) error {
  53. if err := r.db.Where("id = ?", id).Delete(&model.Tcpforwarding{}).Error; err != nil {
  54. return err
  55. }
  56. return nil
  57. }
  58. func (r *tcpforwardingRepository) GetTcpforwardingWafTcpIdById(ctx context.Context, id int) (int, error) {
  59. var WafTcpId int
  60. if err := r.db.Model(&model.Tcpforwarding{}).Where("id = ?", id).Select("waf_tcp_id").Find(&WafTcpId).Error; err != nil {
  61. return 0, err
  62. }
  63. return WafTcpId, nil
  64. }
  65. func (r *tcpforwardingRepository) GetTcpForwardingPortCountByHostId(ctx context.Context, hostId int) (int64, error) {
  66. var count int64
  67. if err := r.db.Model(&model.Tcpforwarding{}).Where("host_id = ?", hostId).Count(&count).Error; err != nil {
  68. return 0, err
  69. }
  70. return count, nil
  71. }
  72. //mongodb 插入
  73. func (r *tcpforwardingRepository) AddTcpforwardingIps(ctx context.Context,req model.TcpForwardingRule) (primitive.ObjectID, error) {
  74. collection := r.mongoDB.Collection("tcp_forwarding_rules")
  75. req.CreatedAt = time.Now()
  76. result, err := collection.InsertOne(ctx, req)
  77. if err != nil {
  78. return primitive.NilObjectID, fmt.Errorf("插入MongoDB失败: %w", err)
  79. }
  80. // 返回插入文档的ID
  81. return result.InsertedID.(primitive.ObjectID), nil
  82. }
  83. func (r *tcpforwardingRepository) EditTcpforwardingIps(ctx context.Context, req model.TcpForwardingRule) error {
  84. collection := r.mongoDB.Collection("tcp_forwarding_rules")
  85. updateData := bson.M{}
  86. if req.Uid != 0 {
  87. updateData["uid"] = req.Uid
  88. }
  89. if req.HostId != 0 {
  90. updateData["host_id"] = req.HostId
  91. }
  92. if req.TcpId != 0 {
  93. updateData["tcp_id"] = req.TcpId
  94. }
  95. if req.AccessRule != "" {
  96. updateData["access_rule"] = req.AccessRule
  97. }
  98. if len(req.BackendList) > 0 {
  99. updateData["backend_list"] = req.BackendList
  100. }
  101. if len(req.AllowIpList) > 0 {
  102. updateData["allow_ip_list"] = req.AllowIpList
  103. }
  104. if len(req.DenyIpList) > 0 {
  105. updateData["deny_ip_list"] = req.DenyIpList
  106. }
  107. // 始终更新更新时间
  108. updateData["updated_at"] = time.Now()
  109. // 如果没有任何字段需要更新,则直接返回
  110. if len(updateData) == 0 {
  111. return nil
  112. }
  113. // 执行更新
  114. update := bson.M{"$set": updateData}
  115. err := collection.UpdateOne(ctx, bson.M{"tcp_id": req.TcpId}, update)
  116. if err != nil {
  117. if errors.Is(err, mongo.ErrNoDocuments) {
  118. return fmt.Errorf("记录不存在")
  119. }
  120. return fmt.Errorf("更新MongoDB文档失败: %w", err)
  121. }
  122. return nil
  123. }
  124. func (r *tcpforwardingRepository) GetTcpForwardingByID(ctx context.Context, tcpId int) (*model.TcpForwardingRule, error) {
  125. collection := r.mongoDB.Collection("tcp_forwarding_rules")
  126. var res model.TcpForwardingRule
  127. err := collection.Find(ctx, bson.M{"tcp_id": tcpId}).One(&res)
  128. if err != nil {
  129. if errors.Is(err, mongo.ErrNoDocuments) {
  130. return nil, fmt.Errorf("记录不存在")
  131. }
  132. return nil, fmt.Errorf("查询MongoDB失败: %w", err)
  133. }
  134. return &res, nil
  135. }