Ver código fonte

feat(tcp/udp forwarding): 添加获取转发列表功能

- 新增 TCP 和 UDP 转发列表获取接口和相关逻辑
- 在数据库中增加相应的查询方法
-优化了数据处理流程,使用协程提高性能
- 为后续增加的第二个数据库连接做准备
fusu 1 mês atrás
pai
commit
84ac9cf549

+ 1 - 0
cmd/server/wire/wire.go

@@ -22,6 +22,7 @@ import (
 
 var repositorySet = wire.NewSet(
 	repository.NewDB,
+	//repository.NewDBSecond,
 	//repository.NewRedis,
 	repository.NewMongoClient,
 	repository.NewMongoDB,

+ 1 - 0
cmd/task/wire/wire.go

@@ -18,6 +18,7 @@ import (
 
 var repositorySet = wire.NewSet(
 	repository.NewDB,
+	repository.NewDBSecond,
 	//repository.NewRedis,
 	repository.NewMongoClient,
 	repository.NewMongoDB,

+ 4 - 0
config/prod.yml

@@ -15,6 +15,10 @@ data:
       driver: mysql
       dsn: 183_136_132_25:xGrNJphcmGcXiajE@tcp(183.136.132.25:3306)/183_136_132_25?charset=utf8mb4&parseTime=True&loc=Local
       logLevel: "warn"
+#    second:
+#      driver: mysql
+#      dsn: second_db_user:password@tcp(second-db-host:3306)/second_db_name?charset=utf8mb4&parseTime=True&loc=Local
+#      logLevel: "warn"
   #    user:
   #      driver: sqlite
   #      dsn: storage/nunu-test.db?_busy_timeout=5000

+ 15 - 0
internal/handler/tcpforwarding.go

@@ -85,3 +85,18 @@ func (h *TcpforwardingHandler) DeleteTcpForwarding(ctx *gin.Context) {
 	}
 	v1.HandleSuccess(ctx, nil)
 }
+
+func (h *TcpforwardingHandler) GetTcpForwardingList(ctx *gin.Context) {
+	req := new(v1.GetForwardingRequest)
+	if err := ctx.ShouldBind(req); err != nil {
+		v1.HandleError(ctx, http.StatusBadRequest, v1.ErrBadRequest, err.Error())
+		return
+	}
+	defaults.SetDefaults(req)
+	res, err := h.tcpforwardingService.GetTcpForwardingAllIpsByHostId(ctx, *req)
+	if err != nil {
+		v1.HandleError(ctx, http.StatusInternalServerError, err, err.Error())
+		return
+	}
+	v1.HandleSuccess(ctx, res)
+}

+ 15 - 0
internal/handler/udpforwarding.go

@@ -85,3 +85,18 @@ func (h *UdpForWardingHandler) DeleteUdpForWarding(ctx *gin.Context) {
 	}
 	v1.HandleSuccess(ctx, nil)
 }
+
+func (h *UdpForWardingHandler) GetUdpForWardingList(ctx *gin.Context) {
+	req := new(v1.GetForwardingRequest)
+	if err := ctx.ShouldBind(req); err != nil {
+		v1.HandleError(ctx, http.StatusBadRequest, v1.ErrBadRequest, err.Error())
+		return
+	}
+	defaults.SetDefaults(req)
+	res, err := h.udpForWardingService.GetUdpForwardingWafUdpAllIps(ctx, *req)
+	if err != nil {
+		v1.HandleError(ctx, http.StatusInternalServerError, err, err.Error())
+		return
+	}
+	v1.HandleSuccess(ctx, res)
+}

+ 88 - 1
internal/repository/repository.go

@@ -19,7 +19,8 @@ import (
 const ctxTxKey = "TxKey"
 
 type Repository struct {
-	db *gorm.DB
+	db *gorm.DB         // 主数据库连接
+	//dbSecond *gorm.DB   // 第二个数据库连接
 	//rdb    *redis.Client
 	mongoClient *qmgo.Client
 	mongoDB     *qmgo.Database
@@ -29,12 +30,14 @@ type Repository struct {
 func NewRepository(
 	logger *log.Logger,
 	db *gorm.DB,
+	//dbSecond *gorm.DB,
 	// rdb *redis.Client,
 	mongoClient *qmgo.Client,
 	mongoDB *qmgo.Database,
 ) *Repository {
 	return &Repository{
 		db: db,
+		//dbSecond: dbSecond,
 		//rdb:    rdb,
 		mongoClient: mongoClient,
 		mongoDB: mongoDB,
@@ -62,6 +65,12 @@ func (r *Repository) DB(ctx context.Context) *gorm.DB {
 	return r.db.WithContext(ctx)
 }
 
+// DBSecond returns the second database connection
+// Note: Transactions are currently only supported on the primary database
+//func (r *Repository) DBSecond(ctx context.Context) *gorm.DB {
+//	return r.dbSecond.WithContext(ctx)
+//}
+
 func (r *Repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
 	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 		ctx = context.WithValue(ctx, ctxTxKey, tx)
@@ -139,6 +148,84 @@ func NewDB(conf *viper.Viper, l *log.Logger) *gorm.DB {
 	return db
 }
 
+// NewDBSecond 初始化第二个数据库连接
+func NewDBSecond(conf *viper.Viper, l *log.Logger) *gorm.DB {
+	var (
+		db  *gorm.DB
+		err error
+	)
+
+	// 从second配置项读取第二个数据库配置
+	driver := conf.GetString("data.db.second.driver")
+	dsn := conf.GetString("data.db.second.dsn")
+
+	// 如果第二个数据库没有配置,返回nil
+	if dsn == "" {
+		l.Warn("第二个数据库配置不存在或DSN为空")
+		return nil
+	}
+
+	// 读取日志级别配置
+	logLevelStr := conf.GetString("data.db.second.logLevel")
+	var logLevel gormlogger.LogLevel
+
+	switch logLevelStr {
+	case "silent":
+		logLevel = gormlogger.Silent
+	case "error":
+		logLevel = gormlogger.Error
+	case "warn":
+		logLevel = gormlogger.Warn
+	case "info":
+		logLevel = gormlogger.Info
+	default:
+		// MySQL 默认只记录警告和错误
+		if driver == "mysql" {
+			logLevel = gormlogger.Warn
+		} else {
+			logLevel = gormlogger.Info
+		}
+	}
+
+	logger := zapgorm2.New(l.Logger).LogMode(logLevel)
+
+	// 初始化第二个数据库连接
+	switch driver {
+	case "mysql":
+		db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
+			Logger: logger,
+		})
+	case "postgres":
+		db, err = gorm.Open(postgres.New(postgres.Config{
+			DSN:                  dsn,
+			PreferSimpleProtocol: true,
+		}), &gorm.Config{
+			Logger: logger,
+		})
+	case "sqlite":
+		db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{
+			Logger: logger,
+		})
+	default:
+		panic("unknown db driver for second database")
+	}
+
+	if err != nil {
+		panic("连接第二个数据库失败: " + err.Error())
+	}
+
+	// 配置连接池
+	sqlDB, err := db.DB()
+	if err != nil {
+		panic(err)
+	}
+	sqlDB.SetMaxIdleConns(10)
+	sqlDB.SetMaxOpenConns(100)
+	sqlDB.SetConnMaxLifetime(time.Hour)
+
+	return db
+}
+
 func NewRedis(conf *viper.Viper) *redis.Client {
 	rdb := redis.NewClient(&redis.Options{
 		Addr:     conf.GetString("data.redis.addr"),

+ 9 - 0
internal/repository/tcpforwarding.go

@@ -18,6 +18,7 @@ type TcpforwardingRepository interface {
 	DeleteTcpforwarding(ctx context.Context, id int64) error
 	GetTcpforwardingWafTcpIdById(ctx context.Context, id int) (int, error)
 	GetTcpForwardingPortCountByHostId(ctx context.Context, hostId int) (int64, error)
+	GetTcpForwardingAllIdsByID(ctx context.Context, hostId int) ([]int, error)
 	AddTcpforwardingIps(ctx context.Context,req model.TcpForwardingRule) (primitive.ObjectID, error)
 	EditTcpforwardingIps(ctx context.Context, req model.TcpForwardingRule) error
 	GetTcpForwardingIpsByID(ctx context.Context, tcpId int) (*model.TcpForwardingRule, error)
@@ -83,6 +84,14 @@ func (r *tcpforwardingRepository) GetTcpForwardingPortCountByHostId(ctx context.
 	return count, nil
 }
 
+func (r *tcpforwardingRepository) GetTcpForwardingAllIdsByID(ctx context.Context, hostId int) ([]int, error) {
+	var res []int
+	if err := r.db.WithContext(ctx).Model(&model.Tcpforwarding{}).Where("host_id = ?", hostId).Select("id").Find(&res).Error; err != nil {
+		return nil, err
+	}
+	return res, nil
+}
+
 //mongodb 插入
 func (r *tcpforwardingRepository) AddTcpforwardingIps(ctx context.Context,req model.TcpForwardingRule) (primitive.ObjectID, error) {
 	collection := r.mongoDB.Collection("tcp_forwarding_rules")

+ 11 - 1
internal/repository/udpforwarding.go

@@ -18,6 +18,7 @@ type UdpForWardingRepository interface {
 	DeleteUdpForwarding(ctx context.Context, id int64) error
 	GetUdpForwardingWafUdpIdById(ctx context.Context, id int) (int, error)
 	GetUdpForwardingPortCountByHostId(ctx context.Context, hostId int) (int64, error)
+	GetUdpForwardingWafUdpAllIds(ctx context.Context, udpId int) ([]int, error)
 	AddUdpForwardingIps(ctx context.Context, req model.UdpForwardingRule) (primitive.ObjectID, error)
 	EditUdpForwardingIps(ctx context.Context, req model.UdpForwardingRule) error
 	GetUdpForwardingIpsByID(ctx context.Context, udpId int) (*model.UdpForwardingRule, error)
@@ -83,6 +84,14 @@ func (r *udpForWardingRepository) GetUdpForwardingPortCountByHostId(ctx context.
 	return count, nil
 }
 
+func (r *udpForWardingRepository) GetUdpForwardingWafUdpAllIds(ctx context.Context, udpId int) ([]int, error) {
+	var res []int
+	if err:= r.db.WithContext(ctx).Model(&model.UdpForWarding{}).Where("id = ?", udpId).Select("waf_udp_id").Find(&res).Error; err != nil {
+		return nil, err
+	}
+	return res, nil
+}
+
 
 // mongodb 插入
 func (r *udpForWardingRepository) AddUdpForwardingIps(ctx context.Context, req model.UdpForwardingRule) (primitive.ObjectID, error) {
@@ -176,4 +185,5 @@ func (r *udpForWardingRepository) DeleteUdpForwardingIpsById(ctx context.Context
 	}
 	return nil
 
-}
+}
+

+ 2 - 0
internal/server/http.go

@@ -101,9 +101,11 @@ func NewHTTPServer(
 			noAuthRouter.POST("/webLimit/delete", ipAllowlistMiddleware, weblimitHandler.DeleteWebLimit)
 			noAuthRouter.POST("/tcpForward/add", ipAllowlistMiddleware, tcpForwardingHandler.AddTcpForwarding)
 			noAuthRouter.POST("/tcpForward/get", ipAllowlistMiddleware, tcpForwardingHandler.GetTcpforwarding)
+			noAuthRouter.POST("/tcpForward/getList", ipAllowlistMiddleware, tcpForwardingHandler.GetTcpForwardingList)
 			noAuthRouter.POST("/tcpForward/edit", ipAllowlistMiddleware, tcpForwardingHandler.EditTcpForwarding)
 			noAuthRouter.POST("/tcpForward/delete", ipAllowlistMiddleware, tcpForwardingHandler.DeleteTcpForwarding)
 			noAuthRouter.POST("/udpForward/get", ipAllowlistMiddleware, udpForwardingHandler.GetUdpForWarding)
+			noAuthRouter.POST("/udpForward/getList", ipAllowlistMiddleware, udpForwardingHandler.GetUdpForWardingList)
 			noAuthRouter.POST("/udpForward/add", ipAllowlistMiddleware, udpForwardingHandler.AddUdpForWarding)
 			noAuthRouter.POST("/udpForward/edit", ipAllowlistMiddleware, udpForwardingHandler.EditUdpForWarding)
 			noAuthRouter.POST("/udpForward/delete", ipAllowlistMiddleware, udpForwardingHandler.DeleteUdpForWarding)

+ 78 - 0
internal/service/tcpforwarding.go

@@ -16,6 +16,7 @@ type TcpforwardingService interface {
 	AddTcpForwarding(ctx context.Context, req *v1.TcpForwardingRequest)  error
 	EditTcpForwarding(ctx context.Context, req *v1.TcpForwardingRequest)  error
 	DeleteTcpForwarding(ctx context.Context, req v1.DeleteTcpForwardingRequest) error
+	GetTcpForwardingAllIpsByHostId(ctx context.Context, req v1.GetForwardingRequest) ([]v1.TcpForwardingDataRequest, error)
 }
 
 func NewTcpforwardingService(
@@ -279,4 +280,81 @@ func (s *tcpforwardingService) DeleteTcpForwarding(ctx context.Context, req v1.D
 		}
 	}
 	return  nil
+}
+
+func (s *tcpforwardingService) GetTcpForwardingAllIpsByHostId(ctx context.Context, req v1.GetForwardingRequest) ([]v1.TcpForwardingDataRequest, error) {
+	type CombinedResult struct {
+		Id          int
+		Forwarding  *model.Tcpforwarding
+		BackendRule *model.TcpForwardingRule
+		Err         error // 如果此ID的处理出错,则携带错误
+	}
+	g,gCtx := errgroup.WithContext(ctx)
+	ids, err := s.tcpforwardingRepository.GetTcpForwardingAllIdsByID(gCtx, req.HostId)
+	if err != nil {
+		return nil, fmt.Errorf("GetTcpForwardingAllIds failed: %w", err)
+	}
+	if len(ids) == 0 {
+		return nil, nil
+	}
+	resChan := make(chan CombinedResult, len(ids))
+	g.Go(func() error {
+		for _, idVal := range ids {
+			currentID := idVal
+			g.Go(func() error {
+				var wf *model.Tcpforwarding
+				var bk *model.TcpForwardingRule
+				var localErr error
+				wf, localErr = s.tcpforwardingRepository.GetTcpforwarding(gCtx, int64(currentID))
+				if localErr != nil {
+					resChan <- CombinedResult{Id: currentID, Err: localErr}
+					return localErr
+				}
+				bk, localErr = s.tcpforwardingRepository.GetTcpForwardingIpsByID(gCtx, currentID)
+				if localErr != nil {
+					resChan <- CombinedResult{Id: currentID, Err: localErr}
+					return localErr
+				}
+				resChan <- CombinedResult{Id: currentID, Forwarding: wf, BackendRule: bk}
+				return nil
+			})
+		}
+		return nil
+	})
+	groupErr := g.Wait()
+	close(resChan)
+	if groupErr != nil {
+		return nil, groupErr
+	}
+	res := make([]v1.TcpForwardingDataRequest, 0, len(ids))
+	for r := range resChan {
+		if r.Err != nil {
+			return nil, fmt.Errorf("received error from goroutine for ID %d: %w", r.Id, r.Err)
+		}
+		if r.Forwarding == nil {
+			return nil,fmt.Errorf("received nil forwarding from goroutine for ID %d", r.Id)
+		}
+
+		dataReq := v1.TcpForwardingDataRequest{
+			Id: r.Forwarding.Id,
+			Port: r.Forwarding.Port,
+			CcCount: r.Forwarding.CcCount,
+			CcDuration: r.Forwarding.CcDuration,
+			CcBlockCount: r.Forwarding.CcBlockCount,
+			CcBlockDuration: r.Forwarding.CcBlockDuration,
+			BackendProtocol: r.Forwarding.BackendProtocol,
+			BackendTimeout: r.Forwarding.BackendTimeout,
+			Comment: r.Forwarding.Comment,
+
+		}
+		if r.BackendRule != nil {
+			dataReq.BackendList = r.BackendRule.BackendList
+			dataReq.AllowIpList = r.BackendRule.AllowIpList
+			dataReq.DenyIpList = r.BackendRule.DenyIpList
+			dataReq.AccessRule = r.BackendRule.AccessRule
+		}
+		res = append(res, dataReq)
+	}
+	return res, nil
+
 }

+ 80 - 0
internal/service/udpforwarding.go

@@ -16,6 +16,7 @@ type UdpForWardingService interface {
 	AddUdpForwarding(ctx context.Context, req *v1.UdpForwardingRequest) error
 	EditUdpForwarding(ctx context.Context, req *v1.UdpForwardingRequest) error
 	DeleteUdpForwarding(ctx context.Context, Ids []int) error
+	GetUdpForwardingWafUdpAllIps(ctx context.Context, req v1.GetForwardingRequest) ([]v1.UdpForwardingDataRequest, error)
 }
 
 func NewUdpForWardingService(
@@ -298,3 +299,82 @@ func (s *udpForWardingService) DeleteUdpForwarding(ctx context.Context, Ids []in
 	}
 	return nil
 }
+
+func (s *udpForWardingService) GetUdpForwardingWafUdpAllIps(ctx context.Context, req v1.GetForwardingRequest) ([]v1.UdpForwardingDataRequest, error) {
+	type CombinedResult struct {
+		Id          int
+		Forwarding  *model.UdpForWarding
+		BackendRule *model.UdpForwardingRule
+		Err         error // 如果此ID的处理出错,则携带错误
+	}
+
+	g,gCtx := errgroup.WithContext(ctx)
+	ids, err := s.udpForWardingRepository.GetUdpForwardingWafUdpAllIds(gCtx, req.HostId)
+	if err != nil {
+		return nil, fmt.Errorf("GetUdpForwardingWafUdpAllIds failed: %w", err)
+	}
+	if len(ids) == 0 {
+		return nil, nil
+	}
+	resChan := make(chan CombinedResult, len(ids))
+
+	for _, idVal := range ids {
+		currentID := idVal
+		g.Go(func() error {
+			var wf *model.UdpForWarding
+			var bk *model.UdpForwardingRule
+			var localErr error
+			wf, localErr = s.udpForWardingRepository.GetUdpForWarding(gCtx, int64(currentID))
+			if localErr != nil {
+				resChan <- CombinedResult{Id: currentID, Err: localErr}
+				return localErr
+			}
+			bk, localErr = s.udpForWardingRepository.GetUdpForwardingIpsByID(gCtx, currentID)
+			if localErr != nil {
+				resChan <- CombinedResult{Id: currentID, Err: localErr}
+				return localErr
+			}
+			resChan <- CombinedResult{Id: currentID, Forwarding: wf, BackendRule: bk}
+			return nil
+		})
+	}
+	groupErr := g.Wait()
+	close(resChan)
+	if groupErr != nil {
+		return nil, groupErr
+	}
+	res := make([]v1.UdpForwardingDataRequest, 0, len(ids))
+	for r := range resChan {
+		if r.Err != nil {
+			return nil, fmt.Errorf("received error from goroutine for ID %d: %w", r.Id, r.Err)
+		}
+		if r.Forwarding == nil  {
+			return nil, fmt.Errorf("received nil forwarding from goroutine for ID %d", r.Id)
+		}
+
+		dataReq := v1.UdpForwardingDataRequest{
+			Id: r.Forwarding.Id,
+			Port: r.Forwarding.Port,
+			CcPacketCount:     r.Forwarding.CcPacketCount,
+			CcPacketDuration:  r.Forwarding.CcPacketDuration,
+			CcPacketBlockCount: r.Forwarding.CcPacketBlockCount,
+			CcPacketBlockDuration: r.Forwarding.CcPacketBlockDuration,
+			CcCount:           r.Forwarding.CcCount,
+			CcDuration:        r.Forwarding.CcDuration,
+			CcBlockCount:      r.Forwarding.CcBlockCount,
+			CcBlockDuration:   r.Forwarding.CcBlockDuration,
+			SessionTimeout:    r.Forwarding.SessionTimeout,
+			Comment:           r.Forwarding.Comment,
+		}
+
+		if r.BackendRule != nil {
+			dataReq.BackendList = r.BackendRule.BackendList
+			dataReq.AllowIpList = r.BackendRule.AllowIpList
+			dataReq.DenyIpList = r.BackendRule.DenyIpList
+			dataReq.AccessRule = r.BackendRule.AccessRule
+		}
+		res = append(res, dataReq)
+	}
+
+	return res, nil
+}

+ 0 - 4
internal/service/webforwarding.go

@@ -459,13 +459,9 @@ func (s *webForwardingService) GetWebForwardingWafWebAllIps(ctx context.Context,
 
 		dataReq := v1.WebForwardingDataRequest{
 			Id:                  res.Forwarding.Id,
-			WafWebId:            res.Forwarding.WafWebId,
-			Tag:                 res.Forwarding.Tag,
 			Port:                res.Forwarding.Port,
 			Domain:              res.Forwarding.Domain,
 			CustomHost:          res.Forwarding.CustomHost,
-			WafWebLimitId:       res.Forwarding.WebLimitRuleId,
-			WafGatewayGroupId:   res.Forwarding.WafGatewayGroupId,
 			CcCount:             res.Forwarding.CcCount,
 			CcDuration:          res.Forwarding.CcDuration,
 			CcBlockCount:        res.Forwarding.CcBlockCount,