package repository

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/go-nunu/nunu-layout-advanced/internal/model"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.mongodb.org/mongo-driver/mongo"
)

// tcpForwardingRulesCollection is the MongoDB collection that stores the
// per-forwarding rule documents (backend list, allow/deny IP lists, ...).
const tcpForwardingRulesCollection = "tcp_forwarding_rules"

// TcpforwardingRepository persists TCP forwarding entries in the SQL
// database (r.db) and their associated rule documents in MongoDB.
type TcpforwardingRepository interface {
	// GetTcpforwarding loads the forwarding row with the given primary key.
	GetTcpforwarding(ctx context.Context, id int64) (*model.Tcpforwarding, error)
	// AddTcpforwarding inserts req and returns its generated id.
	AddTcpforwarding(ctx context.Context, req *model.Tcpforwarding) (int, error)
	// EditTcpforwarding updates the row identified by req's primary key.
	EditTcpforwarding(ctx context.Context, req *model.Tcpforwarding) error
	// DeleteTcpforwarding deletes the row with the given primary key.
	DeleteTcpforwarding(ctx context.Context, id int64) error
	// GetTcpforwardingWafTcpIdById returns the waf_tcp_id column of row id.
	GetTcpforwardingWafTcpIdById(ctx context.Context, id int) (int, error)
	// GetTcpForwardingPortCountByHostId counts forwarding rows for a host.
	GetTcpForwardingPortCountByHostId(ctx context.Context, hostId int) (int64, error)
	// AddTcpforwardingIps inserts a rule document into MongoDB.
	AddTcpforwardingIps(ctx context.Context, req model.TcpForwardingRule) (primitive.ObjectID, error)
	// EditTcpforwardingIps partially updates the rule document matching req.TcpId.
	EditTcpforwardingIps(ctx context.Context, req model.TcpForwardingRule) error
	// GetTcpForwardingByID loads the rule document whose tcp_id equals tcpId.
	GetTcpForwardingByID(ctx context.Context, tcpId int) (*model.TcpForwardingRule, error)
}

// NewTcpforwardingRepository builds a TcpforwardingRepository on top of the
// shared Repository (which provides the SQL handle and the MongoDB handle).
func NewTcpforwardingRepository(
	repository *Repository,
) TcpforwardingRepository {
	return &tcpforwardingRepository{
		Repository: repository,
	}
}

type tcpforwardingRepository struct {
	*Repository
}

// GetTcpforwarding loads the forwarding row whose id matches. The error from
// First is returned unchanged, so a missing row surfaces as the ORM's
// not-found error (gorm.ErrRecordNotFound for gorm).
func (r *tcpforwardingRepository) GetTcpforwarding(ctx context.Context, id int64) (*model.Tcpforwarding, error) {
	var tcpforwarding model.Tcpforwarding
	if err := r.db.WithContext(ctx).Where("id = ?", id).First(&tcpforwarding).Error; err != nil {
		return nil, err
	}
	return &tcpforwarding, nil
}

// AddTcpforwarding inserts req and returns req.Id, which the ORM fills in
// from the generated primary key during Create.
func (r *tcpforwardingRepository) AddTcpforwarding(ctx context.Context, req *model.Tcpforwarding) (int, error) {
	if err := r.db.WithContext(ctx).Create(req).Error; err != nil {
		return 0, err
	}
	return req.Id, nil
}

// EditTcpforwarding updates the row identified by req's primary key.
//
// NOTE: Updates with a struct argument skips zero-value fields, so this
// method cannot reset a column to 0 / "" / false.
func (r *tcpforwardingRepository) EditTcpforwarding(ctx context.Context, req *model.Tcpforwarding) error {
	return r.db.WithContext(ctx).Updates(req).Error
}

// DeleteTcpforwarding deletes the forwarding row with the given id.
func (r *tcpforwardingRepository) DeleteTcpforwarding(ctx context.Context, id int64) error {
	return r.db.WithContext(ctx).Where("id = ?", id).Delete(&model.Tcpforwarding{}).Error
}

// GetTcpforwardingWafTcpIdById returns the waf_tcp_id column of row id.
//
// Find does not report a missing row as an error, so a nonexistent id yields
// (0, nil); callers must treat 0 as "not found".
func (r *tcpforwardingRepository) GetTcpforwardingWafTcpIdById(ctx context.Context, id int) (int, error) {
	var wafTcpID int
	err := r.db.WithContext(ctx).
		Model(&model.Tcpforwarding{}).
		Where("id = ?", id).
		Select("waf_tcp_id").
		Find(&wafTcpID).Error
	if err != nil {
		return 0, err
	}
	return wafTcpID, nil
}

// GetTcpForwardingPortCountByHostId counts the forwarding rows whose host_id
// equals hostId.
func (r *tcpforwardingRepository) GetTcpForwardingPortCountByHostId(ctx context.Context, hostId int) (int64, error) {
	var count int64
	if err := r.db.WithContext(ctx).Model(&model.Tcpforwarding{}).Where("host_id = ?", hostId).Count(&count).Error; err != nil {
		return 0, err
	}
	return count, nil
}

// AddTcpforwardingIps inserts req into the rule collection, stamping
// CreatedAt with the current time, and returns the inserted document's id.
//
// An error is returned (instead of panicking) if the id generated for the
// document is not a primitive.ObjectID.
func (r *tcpforwardingRepository) AddTcpforwardingIps(ctx context.Context, req model.TcpForwardingRule) (primitive.ObjectID, error) {
	collection := r.mongoDB.Collection(tcpForwardingRulesCollection)

	// req is received by value, so this does not modify the caller's struct.
	req.CreatedAt = time.Now()

	result, err := collection.InsertOne(ctx, req)
	if err != nil {
		return primitive.NilObjectID, fmt.Errorf("插入MongoDB失败: %w", err)
	}

	// InsertedID is an untyped interface value; its concrete type depends on
	// how the model's _id is declared, so check rather than assert blindly.
	id, ok := result.InsertedID.(primitive.ObjectID)
	if !ok {
		return primitive.NilObjectID, fmt.Errorf("MongoDB返回的插入ID类型异常: %T", result.InsertedID)
	}
	return id, nil
}

// EditTcpforwardingIps applies a partial $set to the rule document whose
// tcp_id equals req.TcpId. Only non-zero / non-empty fields of req are
// written, so this method cannot clear a list or reset a field to zero.
// updated_at is always refreshed.
//
// NOTE(review): the single-error UpdateOne signature suggests r.mongoDB is a
// qmgo-style wrapper that reports "no document matched" as
// mongo.ErrNoDocuments; that case is mapped to "记录不存在" — confirm.
func (r *tcpforwardingRepository) EditTcpforwardingIps(ctx context.Context, req model.TcpForwardingRule) error {
	collection := r.mongoDB.Collection(tcpForwardingRulesCollection)

	updateData := bson.M{}
	if req.Uid != 0 {
		updateData["uid"] = req.Uid
	}
	if req.HostId != 0 {
		updateData["host_id"] = req.HostId
	}
	if req.TcpId != 0 {
		// Same value as the filter below; rewriting it is a no-op.
		updateData["tcp_id"] = req.TcpId
	}
	if req.AccessRule != "" {
		updateData["access_rule"] = req.AccessRule
	}
	if len(req.BackendList) > 0 {
		updateData["backend_list"] = req.BackendList
	}
	if len(req.AllowIpList) > 0 {
		updateData["allow_ip_list"] = req.AllowIpList
	}
	if len(req.DenyIpList) > 0 {
		updateData["deny_ip_list"] = req.DenyIpList
	}
	// Always refresh updated_at. This also guarantees the $set document is
	// never empty, so no "nothing to update" short-circuit is needed.
	updateData["updated_at"] = time.Now()

	err := collection.UpdateOne(ctx, bson.M{"tcp_id": req.TcpId}, bson.M{"$set": updateData})
	if err != nil {
		if errors.Is(err, mongo.ErrNoDocuments) {
			return errors.New("记录不存在")
		}
		return fmt.Errorf("更新MongoDB文档失败: %w", err)
	}
	return nil
}

// GetTcpForwardingByID loads the rule document whose tcp_id (not _id) equals
// tcpId. If several documents match, the one returned is whichever the
// query yields first.
func (r *tcpforwardingRepository) GetTcpForwardingByID(ctx context.Context, tcpId int) (*model.TcpForwardingRule, error) {
	collection := r.mongoDB.Collection(tcpForwardingRulesCollection)

	var res model.TcpForwardingRule
	if err := collection.Find(ctx, bson.M{"tcp_id": tcpId}).One(&res); err != nil {
		if errors.Is(err, mongo.ErrNoDocuments) {
			return nil, errors.New("记录不存在")
		}
		return nil, fmt.Errorf("查询MongoDB失败: %w", err)
	}
	return &res, nil
}