Explorar el Código

feat(tcpforwarding): 重构 TCP 转发功能并添加 MongoDB 支持

- 移除原有的 mongo 包,改用 qmgo 包连接 MongoDB
- 新增 TcpForwardingRule 模型用于 MongoDB 存储
- 实现 MongoDB 的插入和更新操作
- 优化 TCP 转发数据结构和处理逻辑
- 更新 wire 配置,添加 MongoDB 客户端初始化
fusu hace 2 meses
padre
commit
3e3c2cef4a

+ 24 - 4
api/v1/tcpForwarding.go

@@ -1,10 +1,9 @@
 package v1
 
-type TcpForwardingData struct {
-	Id                int    `form:"id" json:"id"`
+type TcpForwardingDataSend struct {
 	WafTcpId          int    `form:"waf_tcp_id" json:"waf_tcp_id"`
 	Tag               string `form:"tag" json:"tag" binding:"required"`
-	Port              int    `form:"port" json:"port" binding:"required"`
+	Port              string    `form:"port" json:"port" binding:"required"`
 	WafGatewayGroupId int    `form:"waf_gateway_group_id" json:"waf_gateway_group_id"`
 	WafTcpLimitRuleId int    `form:"waf_tcp_limit_id" json:"waf_tcp_limit_id"`
 	CcCount           int    `form:"cc_count" json:"cc_count" default:"0"`
@@ -20,6 +19,27 @@ type TcpForwardingData struct {
 	Comment           string `form:"comment" json:"comment"`
 }
 
+
+type TcpForwardingDataRequest struct {
+	Id                int    `form:"id" json:"id"`
+	WafTcpId          int    `form:"waf_tcp_id" json:"waf_tcp_id"`
+	Tag               string `form:"tag" json:"tag" binding:"required"`
+	Port              string    `form:"port" json:"port" binding:"required"`
+	WafGatewayGroupId int    `form:"waf_gateway_group_id" json:"waf_gateway_group_id"`
+	WafTcpLimitRuleId int    `form:"waf_tcp_limit_id" json:"waf_tcp_limit_id"`
+	CcCount           int    `form:"cc_count" json:"cc_count" default:"0"`
+	CcDuration        string `form:"cc_duration" json:"cc_duration" default:"0s"`
+	CcBlockCount      int    `form:"cc_block_count" json:"cc_block_count" default:"0"`
+	CcBlockDuration   string `form:"cc_block_duration" json:"cc_block_duration" default:"0s"`
+	BackendProtocol   string `form:"backend_protocol" json:"backend_protocol" default:"tcp"`
+	BackendTimeout    string `form:"backend_timeout" json:"backend_timeout" default:"5s"`
+	BackendList       []string `form:"backend_list" json:"backend_list"`
+	AllowIpList       []string `form:"allow_ip_list" json:"allow_ip_list"`
+	DenyIpList        []string `form:"deny_ip_list" json:"deny_ip_list"`
+	AccessRule        string `form:"access_rule" json:"access_rule"`
+	Comment           string `form:"comment" json:"comment"`
+}
+
 type DeleteTcpForwardingRequest struct {
 	WafTcpId int `form:"waf_tcp_id" json:"waf_tcp_id" binding:"required"`
 }
@@ -28,7 +48,7 @@ type TcpForwardingRequest struct {
 	Id                int `form:"id" json:"id"`
 	HostId            int `form:"host_id" json:"host_id" binding:"required"`
 	Uid               int `form:"uid" json:"uid" binding:"required"`
-	TcpForwardingData TcpForwardingData
+	TcpForwardingData TcpForwardingDataRequest
 }
 type TcpForwardingRequire struct {
 	HostId            int    `form:"host_id" json:"host_id" binding:"required"`

+ 1 - 0
cmd/server/wire/wire.go

@@ -23,6 +23,7 @@ import (
 var repositorySet = wire.NewSet(
 	repository.NewDB,
 	//repository.NewRedis,
+	repository.NewMongoClient,
 	repository.NewMongoDB,
 	repository.NewRepository,
 	repository.NewTransaction,

+ 4 - 3
cmd/server/wire/wire_gen.go

@@ -31,8 +31,9 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	handlerFunc := middleware.NewRateLimitMiddleware(limiterLimiter)
 	handlerHandler := handler.NewHandler(logger)
 	db := repository.NewDB(viperViper, logger)
-	mongoDB := repository.NewMongoDB(viperViper)
-	repositoryRepository := repository.NewRepository(logger, db, mongoDB)
+	client := repository.NewMongoClient(viperViper)
+	database := repository.NewMongoDB(client, viperViper)
+	repositoryRepository := repository.NewRepository(logger, db, client, database)
 	transaction := repository.NewTransaction(repositoryRepository)
 	sidSid := sid.NewSid()
 	serviceService := service.NewService(transaction, logger, sidSid, jwtJWT)
@@ -92,7 +93,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 
 // wire.go:
 
-var repositorySet = wire.NewSet(repository.NewDB, repository.NewMongoDB, repository.NewRepository, repository.NewTransaction, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldPublicIpRepository, repository.NewWebForwardingRepository, repository.NewTcpforwardingRepository, repository.NewUdpForWardingRepository, repository.NewGameShieldUserIpRepository, repository.NewWebLimitRepository, repository.NewTcpLimitRepository, repository.NewUdpLimitRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldSdkIpRepository, repository.NewHostRepository, repository.NewGlobalLimitRepository, repository.NewGatewayGroupRepository)
+var repositorySet = wire.NewSet(repository.NewDB, repository.NewMongoClient, repository.NewMongoDB, repository.NewRepository, repository.NewTransaction, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldPublicIpRepository, repository.NewWebForwardingRepository, repository.NewTcpforwardingRepository, repository.NewUdpForWardingRepository, repository.NewGameShieldUserIpRepository, repository.NewWebLimitRepository, repository.NewTcpLimitRepository, repository.NewUdpLimitRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldSdkIpRepository, repository.NewHostRepository, repository.NewGlobalLimitRepository, repository.NewGatewayGroupRepository)
 
 var serviceSet = wire.NewSet(service.NewService, service.NewUserService, service.NewGameShieldService, service.NewCrawlerService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewWebForwardingService, service.NewTcpforwardingService, service.NewUdpForWardingService, service.NewGameShieldUserIpService, service.NewWebLimitService, service.NewTcpLimitService, service.NewUdpLimitService, service.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewHostService, service.NewGlobalLimitService, service.NewGatewayGroupService, service.NewWafFormatterService)
 

+ 1 - 0
cmd/task/wire/wire.go

@@ -19,6 +19,7 @@ import (
 var repositorySet = wire.NewSet(
 	repository.NewDB,
 	//repository.NewRedis,
+	repository.NewMongoClient,
 	repository.NewMongoDB,
 	repository.NewRepository,
 	repository.NewTransaction,

+ 4 - 3
cmd/task/wire/wire_gen.go

@@ -23,8 +23,9 @@ import (
 
 func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), error) {
 	db := repository.NewDB(viperViper, logger)
-	mongoDB := repository.NewMongoDB(viperViper)
-	repositoryRepository := repository.NewRepository(logger, db, mongoDB)
+	client := repository.NewMongoClient(viperViper)
+	database := repository.NewMongoDB(client, viperViper)
+	repositoryRepository := repository.NewRepository(logger, db, client, database)
 	transaction := repository.NewTransaction(repositoryRepository)
 	sidSid := sid.NewSid()
 	taskTask := task.NewTask(transaction, logger, sidSid)
@@ -57,7 +58,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 
 // wire.go:
 
-var repositorySet = wire.NewSet(repository.NewDB, repository.NewMongoDB, repository.NewRepository, repository.NewTransaction, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldPublicIpRepository, repository.NewHostRepository, repository.NewGameShieldUserIpRepository, repository.NewGameShieldSdkIpRepository)
+var repositorySet = wire.NewSet(repository.NewDB, repository.NewMongoClient, repository.NewMongoDB, repository.NewRepository, repository.NewTransaction, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldPublicIpRepository, repository.NewHostRepository, repository.NewGameShieldUserIpRepository, repository.NewGameShieldSdkIpRepository)
 
 var taskSet = wire.NewSet(task.NewTask, task.NewUserTask, task.NewGameShieldTask)
 

+ 4 - 4
internal/handler/tcpforwarding.go

@@ -51,12 +51,12 @@ func (h *TcpforwardingHandler) EditTcpForwarding(ctx *gin.Context) {
 		return
 	}
 	defaults.SetDefaults(req)
-	res, err := h.tcpforwardingService.EditTcpForwarding(ctx, req)
+	 err := h.tcpforwardingService.EditTcpForwarding(ctx, req)
 	if err != nil {
 		v1.HandleError(ctx, http.StatusInternalServerError, err, err.Error())
 		return
 	}
-	v1.HandleSuccess(ctx, res)
+	v1.HandleSuccess(ctx, nil)
 }
 
 func (h *TcpforwardingHandler) DeleteTcpForwarding(ctx *gin.Context) {
@@ -66,10 +66,10 @@ func (h *TcpforwardingHandler) DeleteTcpForwarding(ctx *gin.Context) {
 		return
 	}
 	defaults.SetDefaults(req)
-	res, err := h.tcpforwardingService.DeleteTcpForwarding(ctx, req.WafTcpId)
+	 err := h.tcpforwardingService.DeleteTcpForwarding(ctx, req.WafTcpId)
 	if err != nil {
 		v1.HandleError(ctx, http.StatusInternalServerError, err, err.Error())
 		return
 	}
-	v1.HandleSuccess(ctx, res)
+	v1.HandleSuccess(ctx, nil)
 }

+ 26 - 1
internal/model/tcpforwarding.go

@@ -1,9 +1,14 @@
 package model
 
+import (
+	"go.mongodb.org/mongo-driver/bson/primitive"
+	"time"
+)
+
 type Tcpforwarding struct {
 	Id                   int `gorm:"primary"`
 	HostId               int `gorm:"not null"`
-	RuleId               int `gorm:"not null"`
+	WafTcpId            int `gorm:"not null"`
 	Tag                  string `gorm:"null"`
 	Port                 string `gorm:"not null"`
 	WafGatewayGroupId    int `gorm:"null"`
@@ -15,8 +20,28 @@ type Tcpforwarding struct {
 	BackendProtocol      string `gorm:"default:tcp"`
 	BackendTimeout       string `gorm:"null"`
 	Comment              string `gorm:"null"`
+	CreatedAt            time.Time
+	UpdatedAt            time.Time
 }
 
 func (m *Tcpforwarding) TableName() string {
     return "shd_waf_tcp"
 }
+
+
+type TcpForwardingRule struct {
+	ID          primitive.ObjectID `bson:"_id,omitempty"`
+	Uid         int                `bson:"uid" json:"uid"`
+	HostId      int                `bson:"host_id" json:"host_id"`
+	TcpId       int                `bson:"tcp_id" json:"tcp_id"`
+	BackendList []string           `bson:"backend_list" json:"backend_list"`
+	AllowIpList []string           `bson:"allow_ip_list" json:"allow_ip_list"`
+	DenyIpList  []string           `bson:"deny_ip_list" json:"deny_ip_list"`
+	AccessRule  string             `bson:"access_rule" json:"access_rule"`
+	CreatedAt   time.Time          `bson:"created_at" json:"created_at"`
+	UpdatedAt   time.Time          `bson:"updated_at" json:"updated_at"`
+}
+
+func (m *TcpForwardingRule) CollectionName() string {
+	return "tcp_forwarding_rules"
+}

+ 35 - 13
internal/repository/repository.go

@@ -5,8 +5,8 @@ import (
 	"fmt"
 	"github.com/glebarez/sqlite"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/log"
-	"github.com/go-nunu/nunu-layout-advanced/pkg/mongo"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/zapgorm2"
+	"github.com/qiniu/qmgo"
 	"github.com/redis/go-redis/v9"
 	"github.com/spf13/viper"
 	"gorm.io/driver/mysql"
@@ -21,7 +21,8 @@ const ctxTxKey = "TxKey"
 type Repository struct {
 	db *gorm.DB
 	//rdb    *redis.Client
-	mongodb *mongo.MongoDB
+	mongoClient *qmgo.Client
+	mongoDB     *qmgo.Database
 	logger *log.Logger
 }
 
@@ -29,12 +30,14 @@ func NewRepository(
 	logger *log.Logger,
 	db *gorm.DB,
 	// rdb *redis.Client,
-	mongodb *mongo.MongoDB,
+	mongoClient *qmgo.Client,
+	mongoDB *qmgo.Database,
 ) *Repository {
 	return &Repository{
 		db: db,
 		//rdb:    rdb,
-		mongodb: mongodb,
+		mongoClient: mongoClient,
+		mongoDB: mongoDB,
 		logger: logger,
 	}
 }
@@ -154,18 +157,37 @@ func NewRedis(conf *viper.Viper) *redis.Client {
 	return rdb
 }
 
-func NewMongoDB(conf *viper.Viper) *mongo.MongoDB {
-	config := &mongo.Config{
-		URI:         conf.GetString("data.mongodb.uri"),
-		Database:    conf.GetString("data.mongodb.database"),
-		Timeout:     conf.GetDuration("data.mongodb.timeout"),
-		MaxPoolSize: conf.GetUint64("data.mongodb.max_pool_size"),
+func NewMongoClient(conf *viper.Viper) *qmgo.Client {
+	timeout := conf.GetDuration("data.mongodb.timeout")
+	if timeout == 0 {
+		timeout = 10 * time.Second
+	}
+	
+	maxPoolSize := conf.GetUint64("data.mongodb.max_pool_size")
+	
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
+	
+	// 创建连接配置
+	clientOpts := &qmgo.Config{
+		Uri:         conf.GetString("data.mongodb.uri"),
+		MaxPoolSize: &maxPoolSize,
 	}
 
-	mongoDB, err := mongo.New(config)
+	// 连接到MongoDB
+	client, err := qmgo.NewClient(ctx, clientOpts)
 	if err != nil {
-		panic(fmt.Sprintf("mongodb error: %s", err.Error()))
+		panic(fmt.Sprintf("连接MongoDB失败: %s", err.Error()))
 	}
 
-	return mongoDB
+	return client
+}
+
+func NewMongoDB(client *qmgo.Client, conf *viper.Viper) *qmgo.Database {
+	databaseName := conf.GetString("data.mongodb.database")
+	if databaseName == "" {
+		panic("MongoDB数据库名不能为空")
+	}
+	
+	return client.Database(databaseName)
 }

+ 88 - 6
internal/repository/tcpforwarding.go

@@ -2,14 +2,21 @@ package repository
 
 import (
 	"context"
+	"fmt"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
+	"go.mongodb.org/mongo-driver/bson"
+	"go.mongodb.org/mongo-driver/bson/primitive"
+	"time"
 )
 
 type TcpforwardingRepository interface {
 	GetTcpforwarding(ctx context.Context, id int64) (*model.Tcpforwarding, error)
-	AddTcpforwarding(ctx context.Context, req *model.Tcpforwarding) error
+	AddTcpforwarding(ctx context.Context, req *model.Tcpforwarding) (int, error)
 	EditTcpforwarding(ctx context.Context, req *model.Tcpforwarding) error
 	DeleteTcpforwarding(ctx context.Context, id int64) error
+	GetTcpforwardingWafTcpIdById(ctx context.Context, id int) (int, error)
+	AddTcpforwardingIps(ctx context.Context,req model.TcpForwardingRule) (primitive.ObjectID, error)
+	EditTcpforwardingIps(ctx context.Context, req model.TcpForwardingRule) error
 }
 
 func NewTcpforwardingRepository(
@@ -30,15 +37,15 @@ func (r *tcpforwardingRepository) GetTcpforwarding(ctx context.Context, id int64
 	return &tcpforwarding, nil
 }
 
-func (r *tcpforwardingRepository) AddTcpforwarding(ctx context.Context, req *model.Tcpforwarding) error {
-	if err := r.db.Create(&req).Error; err != nil {
-		return err
+func (r *tcpforwardingRepository) AddTcpforwarding(ctx context.Context, req *model.Tcpforwarding) (int, error) {
+	if err := r.db.Create(req).Error; err != nil {
+		return 0, err
 	}
-	return nil
+	return req.Id, nil
 }
 
 func (r *tcpforwardingRepository) EditTcpforwarding(ctx context.Context, req *model.Tcpforwarding) error {
-	if err := r.db.Updates(&req).Error; err != nil {
+	if err := r.db.Updates(req).Error; err != nil {
 		return err
 	}
 	return nil
@@ -50,3 +57,78 @@ func (r *tcpforwardingRepository) DeleteTcpforwarding(ctx context.Context, id in
 	}
 	return nil
 }
+
+func (r *tcpforwardingRepository) GetTcpforwardingWafTcpIdById(ctx context.Context, id int) (int, error) {
+	var WafTcpId int
+
+	if err := r.db.Model(&model.Tcpforwarding{}).Where("id = ?", id).Select("waf_tcp_id").Find(&WafTcpId).Error; err != nil {
+		return 0, err
+	}
+	return WafTcpId, nil
+
+}
+
+
+//mongodb 插入
+func (r *tcpforwardingRepository) AddTcpforwardingIps(ctx context.Context,req model.TcpForwardingRule) (primitive.ObjectID, error) {
+	collection := r.mongoDB.Collection("tcp_forwarding_rules")
+	req.CreatedAt = time.Now()
+	result, err := collection.InsertOne(ctx, req)
+	if err != nil {
+		return primitive.NilObjectID, fmt.Errorf("插入MongoDB失败: %w", err)
+	}
+
+	// 返回插入文档的ID
+	return result.InsertedID.(primitive.ObjectID), nil
+}
+
+func (r *tcpforwardingRepository) EditTcpforwardingIps(ctx context.Context, req model.TcpForwardingRule) error {
+	collection := r.mongoDB.Collection("tcp_forwarding_rules")
+	updateData := bson.M{}
+
+	if req.Uid != 0 {
+		updateData["uid"] = req.Uid
+	}
+
+	if req.HostId != 0 {
+		updateData["host_id"] = req.HostId
+	}
+
+	if req.TcpId  != 0 {
+		updateData["tcp_id"] = req.TcpId
+	}
+
+	if req.AccessRule != "" {
+		updateData["access_rule"] = req.AccessRule
+	}
+
+	if len(req.BackendList) > 0 {
+		updateData["backend_list"] = req.BackendList
+	}
+
+	if len(req.AllowIpList) > 0 {
+		updateData["allow_ip_list"] = req.AllowIpList
+	}
+
+	if len(req.DenyIpList) > 0 {
+		updateData["deny_ip_list"] = req.DenyIpList
+	}
+
+	// 始终更新更新时间
+	updateData["updated_at"] = time.Now()
+
+	// 如果没有任何字段需要更新,则直接返回
+	if len(updateData) == 0 {
+		return nil
+	}
+
+	// 执行更新
+	update := bson.M{"$set": updateData}
+	 err := collection.UpdateOne(ctx, bson.M{"tcp_id": req.TcpId}, update)
+	if err != nil {
+		return fmt.Errorf("更新MongoDB文档失败: %w", err)
+	}
+
+	return nil
+
+}

+ 88 - 28
internal/service/tcpforwarding.go

@@ -6,13 +6,14 @@ import (
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
 	"strconv"
+	"strings"
 )
 
 type TcpforwardingService interface {
 	GetTcpforwarding(ctx context.Context, id int64) (*model.Tcpforwarding, error)
 	AddTcpForwarding(ctx context.Context, req *v1.TcpForwardingRequest)  error
-	EditTcpForwarding(ctx context.Context, req *v1.TcpForwardingRequest) (string, error)
-	DeleteTcpForwarding(ctx context.Context, wafTcpId int) (string, error)
+	EditTcpForwarding(ctx context.Context, req *v1.TcpForwardingRequest)  error
+	DeleteTcpForwarding(ctx context.Context, wafTcpId int) error
 }
 
 func NewTcpforwardingService(
@@ -60,8 +61,9 @@ func (s *tcpforwardingService) require(ctx context.Context,req v1.GlobalRequire)
 	return res, nil
 }
 
-func (s *tcpforwardingService) buildWafFormData(req *v1.TcpForwardingData, require v1.GlobalRequire) map[string]interface{} {
+func (s *tcpforwardingService) buildWafFormData(req *v1.TcpForwardingDataSend, require v1.GlobalRequire) map[string]interface{} {
 	return map[string]interface{}{
+		"waf_tcp_id":           req.WafTcpId,
 		"tag":                  require.Tag,
 		"port":                 req.Port,
 		"waf_gateway_group_id": require.WafGatewayGroupId,
@@ -80,11 +82,11 @@ func (s *tcpforwardingService) buildWafFormData(req *v1.TcpForwardingData, requi
 	}
 }
 
-func (s *tcpforwardingService) buildTcpForwardingModel(req *v1.TcpForwardingData, ruleId int, require v1.GlobalRequire) *model.Tcpforwarding {
+func (s *tcpforwardingService) buildTcpForwardingModel(req *v1.TcpForwardingDataRequest, ruleId int, require v1.GlobalRequire) *model.Tcpforwarding {
 	return &model.Tcpforwarding{
 		HostId:  require.HostId,
-		RuleId: ruleId,
-		Port: strconv.Itoa(req.Port),
+		WafTcpId: ruleId,
+		Port:    req.Port,
 		Tag:     require.Tag,
 		Comment: req.Comment,
 		WafGatewayGroupId: require.WafGatewayGroupId,
@@ -97,16 +99,65 @@ func (s *tcpforwardingService) buildTcpForwardingModel(req *v1.TcpForwardingData
 	}
 }
 
-func (s *tcpforwardingService) AddTcpForwarding(ctx context.Context, req *v1.TcpForwardingRequest)  error {
+func (s *tcpforwardingService) buildTcpRuleModel(reqData *v1.TcpForwardingDataRequest, require v1.GlobalRequire, localDbId int) *model.TcpForwardingRule {
+	return &model.TcpForwardingRule{
+		Uid:         require.Uid,
+		HostId:      require.HostId,
+		TcpId:       localDbId, // 关联到本地数据库的主记录 ID
+		BackendList: reqData.BackendList,
+		AllowIpList: reqData.AllowIpList,
+		DenyIpList:  reqData.DenyIpList,
+		AccessRule:  reqData.AccessRule,
+	}
+}
+
+func (s *tcpforwardingService) prepareWafData(ctx context.Context, req *v1.TcpForwardingRequest) (v1.GlobalRequire, map[string]interface{}, error) {
+	// 1. 获取必要的全局信息
 	require, err := s.require(ctx, v1.GlobalRequire{
-		HostId: req.HostId,
-		Uid:    req.Uid,
+		HostId:  req.HostId,
+		Uid:     req.Uid,
 		Comment: req.TcpForwardingData.Comment,
 	})
 	if err != nil {
-		return  err
+		return v1.GlobalRequire{}, nil, err
+	}
+
+	// 2. 将字符串切片拼接成字符串,用于 WAF API
+	backendListStr := strings.Join(req.TcpForwardingData.BackendList, "\n")
+	allowIpListStr := strings.Join(req.TcpForwardingData.AllowIpList, "\n")
+	denyIpListStr := strings.Join(req.TcpForwardingData.DenyIpList, "\n")
+
+	// 3. 创建用于构建 WAF 表单的数据结构
+	formDataBase := v1.TcpForwardingDataSend{
+		Tag:               require.Tag,
+		WafTcpId:          req.TcpForwardingData.WafTcpId,
+		WafGatewayGroupId: require.WafGatewayGroupId,
+		WafTcpLimitRuleId: require.LimitRuleId,
+		Port:              req.TcpForwardingData.Port,
+		CcCount:           req.TcpForwardingData.CcCount,
+		CcDuration:        req.TcpForwardingData.CcDuration,
+		CcBlockCount:      req.TcpForwardingData.CcBlockCount,
+		CcBlockDuration:   req.TcpForwardingData.CcBlockDuration,
+		BackendProtocol:   req.TcpForwardingData.BackendProtocol,
+		BackendTimeout:    req.TcpForwardingData.BackendTimeout,
+		BackendList:       backendListStr,
+		AllowIpList:       allowIpListStr,
+		DenyIpList:        denyIpListStr,
+		AccessRule:        req.TcpForwardingData.AccessRule,
+		Comment:           req.TcpForwardingData.Comment,
+	}
+
+	// 4. 构建 WAF 表单数据映射
+	formData := s.buildWafFormData(&formDataBase, require)
+
+	return require, formData, nil
+}
+
+func (s *tcpforwardingService) AddTcpForwarding(ctx context.Context, req *v1.TcpForwardingRequest)  error {
+	require, formData, err := s.prepareWafData(ctx, req)
+	if err != nil {
+		return err
 	}
-	formData := s.buildWafFormData(&req.TcpForwardingData, require)
 	wafTcpId, err := s.wafformatter.sendFormData(ctx, "admin/info/waf_tcp/new", "admin/new/waf_tcp", formData)
 	if err != nil {
 		return err
@@ -114,39 +165,48 @@ func (s *tcpforwardingService) AddTcpForwarding(ctx context.Context, req *v1.Tcp
 
 	tcpModel := s.buildTcpForwardingModel(&req.TcpForwardingData, wafTcpId, require)
 
-	if err = s.tcpforwardingRepository.AddTcpforwarding(ctx, tcpModel); err != nil {
+	id, err := s.tcpforwardingRepository.AddTcpforwarding(ctx, tcpModel)
+	if err != nil {
+		return  err
+	}
+	TcpRuleModel := s.buildTcpRuleModel(&req.TcpForwardingData, require, id)
+	if _, err = s.tcpforwardingRepository.AddTcpforwardingIps(ctx, *TcpRuleModel); err != nil {
 		return err
 	}
 	return  nil
 }
 
-func (s *tcpforwardingService) EditTcpForwarding(ctx context.Context, req *v1.TcpForwardingRequest) (string, error) {
-	require, err := s.require(ctx, v1.GlobalRequire{
-		HostId: req.HostId,
-		Uid:    req.Uid,
-		Comment: req.TcpForwardingData.Comment,
-	})
+func (s *tcpforwardingService) EditTcpForwarding(ctx context.Context, req *v1.TcpForwardingRequest) error {
+	WafTcpId, err := s.tcpforwardingRepository.GetTcpforwardingWafTcpIdById(ctx, req.Id)
 	if err != nil {
-		return "", err
+		return  err
 	}
-	formData := s.buildWafFormData(&req.TcpForwardingData, require)
-	_, err = s.wafformatter.sendFormData(ctx, "admin/info/waf_tcp/edit?&__goadmin_edit_pk="+strconv.Itoa(req.TcpForwardingData.WafTcpId), "admin/edit/waf_tcp", formData)
+	req.TcpForwardingData.WafTcpId = WafTcpId
+	require, formData, err := s.prepareWafData(ctx, req)
 	if err != nil {
-		return "", err
+		return  err
 	}
 
+	_, err = s.wafformatter.sendFormData(ctx, "admin/info/waf_tcp/edit?&__goadmin_edit_pk="+strconv.Itoa(req.TcpForwardingData.WafTcpId), "admin/edit/waf_tcp", formData)
+	if err != nil {
+		return err
+	}
 	tcpModel := s.buildTcpForwardingModel(&req.TcpForwardingData, req.TcpForwardingData.WafTcpId, require)
 	tcpModel.Id = req.Id
 	if err = s.tcpforwardingRepository.EditTcpforwarding(ctx, tcpModel); err != nil {
-		return "", err
+		return  err
+	}
+	TcpRuleModel := s.buildTcpRuleModel(&req.TcpForwardingData, require, req.Id)
+	if err = s.tcpforwardingRepository.EditTcpforwardingIps(ctx, *TcpRuleModel); err != nil {
+		return err
 	}
-	return "", nil
+	return  nil
 }
 
-func (s *tcpforwardingService) DeleteTcpForwarding(ctx context.Context, wafTcpId int) (string, error) {
-	res, err := s.crawler.DeleteRule(ctx, wafTcpId, "admin/delete/waf_tcp?page=1&__pageSize=10&__sort=waf_tcp_id&__sort_type=desc")
+func (s *tcpforwardingService) DeleteTcpForwarding(ctx context.Context, wafTcpId int)  error {
+	_, err := s.crawler.DeleteRule(ctx, wafTcpId, "admin/delete/waf_tcp?page=1&__pageSize=10&__sort=waf_tcp_id&__sort_type=desc")
 	if err != nil {
-		return "", err
+		return err
 	}
-	return res, nil
+	return  nil
 }

+ 0 - 79
pkg/mongo/mongo.go

@@ -1,79 +0,0 @@
-package mongo
-
-import (
-	"context"
-	"fmt"
-	"time"
-
-	"go.mongodb.org/mongo-driver/mongo"
-	"go.mongodb.org/mongo-driver/mongo/options"
-	"go.mongodb.org/mongo-driver/mongo/readpref"
-)
-
-// Config MongoDB配置
-type Config struct {
-	URI         string        `mapstructure:"uri"`
-	Database    string        `mapstructure:"database"`
-	Timeout     time.Duration `mapstructure:"timeout"`
-	MaxPoolSize uint64        `mapstructure:"max_pool_size"`
-}
-
-// MongoDB连接管理器
-type MongoDB struct {
-	config   *Config
-	client   *mongo.Client
-	database *mongo.Database
-}
-
-// New 创建新的MongoDB客户端
-func New(config *Config) (*MongoDB, error) {
-	ctx, cancel := context.WithTimeout(context.Background(), config.Timeout)
-	defer cancel()
-
-	// 创建MongoDB客户端选项
-	clientOptions := options.Client().
-		ApplyURI(config.URI).
-		SetMaxPoolSize(config.MaxPoolSize)
-
-	// 连接到MongoDB
-	client, err := mongo.Connect(ctx, clientOptions)
-	if err != nil {
-		return nil, fmt.Errorf("连接MongoDB失败: %w", err)
-	}
-
-	// 验证连接
-	if err := client.Ping(ctx, readpref.Primary()); err != nil {
-		return nil, fmt.Errorf("MongoDB连接测试失败: %w", err)
-	}
-
-	// 获取数据库
-	database := client.Database(config.Database)
-
-	return &MongoDB{
-		config:   config,
-		client:   client,
-		database: database,
-	}, nil
-}
-
-// Close 关闭MongoDB连接
-func (m *MongoDB) Close() error {
-	ctx, cancel := context.WithTimeout(context.Background(), m.config.Timeout)
-	defer cancel()
-	return m.client.Disconnect(ctx)
-}
-
-// GetDatabase 获取数据库实例
-func (m *MongoDB) GetDatabase() *mongo.Database {
-	return m.database
-}
-
-// GetCollection 获取集合实例
-func (m *MongoDB) GetCollection(name string) *mongo.Collection {
-	return m.database.Collection(name)
-}
-
-// Client 获取原始MongoDB客户端
-func (m *MongoDB) Client() *mongo.Client {
-	return m.client
-}