package repository

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/go-nunu/nunu-layout-advanced/internal/model"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.mongodb.org/mongo-driver/mongo"
)

// UdpForWardingRepository groups the persistence operations for UDP
// forwarding records (stored through the repository's db handle) and their
// rule documents (stored in the "udp_forwarding_rules" MongoDB collection).
type UdpForWardingRepository interface {
	GetUdpForWarding(ctx context.Context, id int64) (*model.UdpForWarding, error)
	AddUdpForwarding(ctx context.Context, req *model.UdpForWarding) (int, error)
	EditUdpForwarding(ctx context.Context, req *model.UdpForWarding) error
	DeleteUdpForwarding(ctx context.Context, id int64) error
	GetUdpForwardingWafUdpIdById(ctx context.Context, id int) (int, error)
	GetUdpForwardingPortCountByHostId(ctx context.Context, hostId int) (int64, error)
	AddUdpForwardingIps(ctx context.Context, req model.UdpForwardingRule) (primitive.ObjectID, error)
	EditUdpForwardingIps(ctx context.Context, req model.UdpForwardingRule) error
	GetUdpForwardingIpsByID(ctx context.Context, udpId int) (*model.UdpForwardingRule, error)
	DeleteUdpForwardingIpsById(ctx context.Context, udpId int) error
}

// NewUdpForWardingRepository returns a UdpForWardingRepository backed by the
// given shared Repository.
func NewUdpForWardingRepository(
	repository *Repository,
) UdpForWardingRepository {
	return &udpForWardingRepository{
		Repository: repository,
	}
}

// udpForWardingRepository implements UdpForWardingRepository on top of the
// embedded Repository's db and mongoDB handles.
type udpForWardingRepository struct {
	*Repository
}

// GetUdpForWarding loads the record whose id column equals id. Any error
// reported by the query (including a missing row) is returned unchanged.
func (r *udpForWardingRepository) GetUdpForWarding(ctx context.Context, id int64) (*model.UdpForWarding, error) {
	var udpForWarding model.UdpForWarding
	// WithContext binds ctx to the query so cancellation and deadlines apply.
	if err := r.db.WithContext(ctx).Where("id = ?", id).First(&udpForWarding).Error; err != nil {
		return nil, err
	}
	return &udpForWarding, nil
}

// AddUdpForwarding inserts req and returns req.Id as it stands after the
// insert. It returns 0 and the error if the insert fails.
func (r *udpForWardingRepository) AddUdpForwarding(ctx context.Context, req *model.UdpForWarding) (int, error) {
	if err := r.db.WithContext(ctx).Create(req).Error; err != nil {
		return 0, err
	}
	return req.Id, nil
}

// EditUdpForwarding applies req through Updates and returns any resulting
// error.
func (r *udpForWardingRepository) EditUdpForwarding(ctx context.Context, req *model.UdpForWarding) error {
	return r.db.WithContext(ctx).Updates(req).Error
}

// DeleteUdpForwarding deletes the record whose id column equals id.
func (r *udpForWardingRepository) DeleteUdpForwarding(ctx context.Context, id int64) error {
	return r.db.WithContext(ctx).Where("id = ?", id).Delete(&model.UdpForWarding{}).Error
}

// GetUdpForwardingWafUdpIdById returns the waf_udp_id column of the record
// whose id column equals id.
//
// NOTE(review): the lookup uses Find, which presumably does not report a
// missing row as an error, so an unknown id would yield 0, nil — confirm.
func (r *udpForWardingRepository) GetUdpForwardingWafUdpIdById(ctx context.Context, id int) (int, error) {
	var wafUdpID int
	err := r.db.WithContext(ctx).
		Model(&model.UdpForWarding{}).
		Where("id = ?", id).
		Select("waf_udp_id").
		Find(&wafUdpID).Error
	if err != nil {
		return 0, err
	}
	return wafUdpID, nil
}

// GetUdpForwardingPortCountByHostId counts the records whose host_id column
// equals hostId.
func (r *udpForWardingRepository) GetUdpForwardingPortCountByHostId(ctx context.Context, hostId int) (int64, error) {
	var count int64
	err := r.db.WithContext(ctx).
		Model(&model.UdpForWarding{}).
		Where("host_id = ?", hostId).
		Count(&count).Error
	if err != nil {
		return 0, err
	}
	return count, nil
}

// AddUdpForwardingIps inserts req into the "udp_forwarding_rules" MongoDB
// collection after stamping CreatedAt with the current time, and returns the
// ObjectID of the inserted document.
//
// An error is returned if the insert fails or if the inserted ID is not a
// primitive.ObjectID.
func (r *udpForWardingRepository) AddUdpForwardingIps(ctx context.Context, req model.UdpForwardingRule) (primitive.ObjectID, error) {
	collection := r.mongoDB.Collection("udp_forwarding_rules")
	req.CreatedAt = time.Now()

	result, err := collection.InsertOne(ctx, req)
	if err != nil {
		return primitive.NilObjectID, fmt.Errorf("插入MongoDB失败: %w", err)
	}

	// Use a checked assertion: InsertedID is an interface value, and an
	// unchecked assertion would panic if the document carried a non-ObjectID _id.
	id, ok := result.InsertedID.(primitive.ObjectID)
	if !ok {
		return primitive.NilObjectID, fmt.Errorf("插入MongoDB失败: 非预期的ID类型 %T", result.InsertedID)
	}
	return id, nil
}

// EditUdpForwardingIps updates the document whose udp_id equals req.UdpId in
// the "udp_forwarding_rules" collection. Only non-zero scalar fields and
// non-empty lists of req are written; updated_at is set to the current time
// whenever an update is issued. If req carries nothing to write, no update is
// issued and nil is returned.
func (r *udpForWardingRepository) EditUdpForwardingIps(ctx context.Context, req model.UdpForwardingRule) error {
	collection := r.mongoDB.Collection("udp_forwarding_rules")

	updateData := bson.M{}
	if req.Uid != 0 {
		updateData["uid"] = req.Uid
	}
	if req.HostId != 0 {
		updateData["host_id"] = req.HostId
	}
	if req.UdpId != 0 {
		updateData["udp_id"] = req.UdpId
	}
	if req.AccessRule != "" {
		updateData["access_rule"] = req.AccessRule
	}
	if len(req.BackendList) > 0 {
		updateData["backend_list"] = req.BackendList
	}
	if len(req.AllowIpList) > 0 {
		updateData["allow_ip_list"] = req.AllowIpList
	}
	if len(req.DenyIpList) > 0 {
		updateData["deny_ip_list"] = req.DenyIpList
	}

	// Nothing to change: return before touching the database. This check must
	// run before updated_at is added, otherwise the map is never empty.
	if len(updateData) == 0 {
		return nil
	}

	// Always refresh the modification timestamp when an update is issued.
	updateData["updated_at"] = time.Now()

	update := bson.M{"$set": updateData}
	if err := collection.UpdateOne(ctx, bson.M{"udp_id": req.UdpId}, update); err != nil {
		return fmt.Errorf("更新MongoDB文档失败: %w", err)
	}
	return nil
}

// GetUdpForwardingIpsByID fetches one document whose udp_id equals udpId from
// the "udp_forwarding_rules" collection. It returns (nil, nil) when the
// lookup reports mongo.ErrNoDocuments, and a wrapped error for any other
// failure.
func (r *udpForWardingRepository) GetUdpForwardingIpsByID(ctx context.Context, udpId int) (*model.UdpForwardingRule, error) {
	collection := r.mongoDB.Collection("udp_forwarding_rules")

	var result model.UdpForwardingRule
	if err := collection.Find(ctx, bson.M{"udp_id": udpId}).One(&result); err != nil {
		if errors.Is(err, mongo.ErrNoDocuments) {
			return nil, nil
		}
		return nil, fmt.Errorf("从MongoDB中获取文档失败: %w", err)
	}
	return &result, nil
}

// DeleteUdpForwardingIpsById removes a document whose udp_id equals udpId
// from the "udp_forwarding_rules" collection. It returns a "record not found"
// error when the removal reports mongo.ErrNoDocuments, and a wrapped error
// for any other failure.
func (r *udpForWardingRepository) DeleteUdpForwardingIpsById(ctx context.Context, udpId int) error {
	collection := r.mongoDB.Collection("udp_forwarding_rules")

	if err := collection.Remove(ctx, bson.M{"udp_id": udpId}); err != nil {
		if errors.Is(err, mongo.ErrNoDocuments) {
			return errors.New("记录不存在")
		}
		return fmt.Errorf("删除MongoDB文档失败: %w", err)
	}
	return nil
}