Selaa lähdekoodia

feat(sharding): 重构分表逻辑并支持配置文件

-重构了 ThresholdConfig 结构,支持更灵活的表配置
- 新增 TableConfig 结构,用于单表的阈值配置
- 修改了 GetOptimalWriteTable 方法,自动根据表名获取阈值- 增加了对配置文件中分表配置的支持
- 优化了日志和错误处理
fusu 17 tuntia sitten
vanhempi
sitoutus
8e49b01de4

+ 1 - 1
cmd/server/wire/wire_gen.go

@@ -49,7 +49,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	qmgoClient := repository.NewMongoClient(viperViper)
 	database := repository.NewMongoDB(qmgoClient, viperViper)
 	rabbitMQ, cleanup := repository.NewRabbitMQ(viperViper, logger)
-	shardingManager := repository.NewShardingManager(logger)
+	shardingManager := repository.NewShardingManager(logger, viperViper)
 	repositoryRepository := repository.NewRepository(logger, db, client, qmgoClient, database, rabbitMQ, syncedEnforcer, shardingManager)
 	transaction := repository.NewTransaction(repositoryRepository)
 	sidSid := sid.NewSid()

+ 4 - 6
internal/repository/admin/waflog.go

@@ -272,8 +272,8 @@ func (r *wafLogRepository) AddWafLog(ctx context.Context, log *model.WafLog) err
 	}
 
 	
-	// 获取最优的写入表(考虑数据量阈值)
-	tableName, err := r.Manager.GetOptimalWriteTable(ctx, r.DBWithName(ctx, "admin"), log, r.Manager.GetMaxRowsForTable("waf_log"))
+	// 获取最优的写入表(自动根据表名获取阈值)
+	tableName, err := r.Manager.GetOptimalWriteTable(ctx, r.DBWithName(ctx, "admin"), log)
 	if err != nil {
 		return fmt.Errorf("获取写入表失败: %v", err)
 	}
@@ -295,8 +295,6 @@ func (r *wafLogRepository) BatchAddWafLog(ctx context.Context, logs []*model.Waf
 		return nil
 	}
 
-	maxRows := r.Manager.GetMaxRowsForTable("waf_log")
-	
 	// 按表名分组
 	tableGroups := make(map[string][]*model.WafLog)
 	
@@ -306,8 +304,8 @@ func (r *wafLogRepository) BatchAddWafLog(ctx context.Context, logs []*model.Waf
 			log.CreatedAt = time.Now()
 		}
 		
-		// 获取最优的写入表(考虑数据量阈值)
-		tableName, err := r.Manager.GetOptimalWriteTable(ctx, r.DBWithName(ctx, "admin"), log, maxRows)
+		// 获取最优的写入表(自动根据表名获取阈值)
+		tableName, err := r.Manager.GetOptimalWriteTable(ctx, r.DBWithName(ctx, "admin"), log)
 		if err != nil {
 			return fmt.Errorf("获取写入表失败: %v", err)
 		}

+ 2 - 2
internal/repository/log.go

@@ -94,8 +94,8 @@ func (r *logRepository) AddLog(ctx context.Context, log *model.Log) error {
 	}
 
 	
-	// 获取最优的写入表(考虑数据量阈值)
-	tableName, err := r.Manager.GetOptimalWriteTable(ctx, r.DBWithName(ctx, "admin"), log, r.Manager.GetMaxRowsForTable("log"))
+	// 获取最优的写入表(自动根据表名获取阈值)
+	tableName, err := r.Manager.GetOptimalWriteTable(ctx, r.DBWithName(ctx, "admin"), log)
 	if err != nil {
 		return fmt.Errorf("获取写入表失败: %v", err)
 	}

+ 8 - 10
internal/repository/repository.go

@@ -414,18 +414,16 @@ m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act
 }
 
 // NewShardingManager creates a ShardingManager with threshold support for dependency injection
-func NewShardingManager(logger *log.Logger) *sharding.ShardingManager {
+func NewShardingManager(logger *log.Logger, conf *viper.Viper) *sharding.ShardingManager {
 	strategy := sharding.NewMonthlyShardingStrategy()
 	
-	// 配置阈值参数 - 统一管理所有表的阈值
-	thresholdConfig := &sharding.ThresholdConfig{
-		Enabled: true,
-		MaxRows: 3000000, // 默认300万条
-		TableThresholds: map[string]int64{
-			"log":     3000000, // log表300万条
-			"waf_log": 5000000, // waf_log表500万条
-		},
+	// 从配置文件读取分表配置
+	var thresholdConfig sharding.ThresholdConfig
+	if err := conf.UnmarshalKey("data.sharding.threshold", &thresholdConfig); err != nil {
+		logger.Error("分表阈值配置读取失败,请检查配置文件: " + err.Error())
+		// 不提供默认配置,强制用户配置正确的配置文件
+		panic("分表配置不存在,请在配置文件中添加 data.sharding.threshold 配置")
 	}
 	
-	return sharding.NewShardingManagerWithThreshold(strategy, logger, thresholdConfig)
+	return sharding.NewShardingManager(strategy, logger, &thresholdConfig)
 }

+ 9 - 1
internal/service/sharding.go

@@ -37,7 +37,15 @@ func NewShardingService(
 ) ShardingService {
 	// 根据配置创建分表策略
 	strategy := createShardingStrategy(config)
-	shardingMgr := sharding.NewShardingManager(strategy, logger)
+	
+	// 从配置文件读取阈值配置
+	var thresholdConfig sharding.ThresholdConfig
+	if err := config.UnmarshalKey("data.sharding.threshold", &thresholdConfig); err != nil {
+		logger.Error("分表阈值配置读取失败: " + err.Error())
+		panic("分表配置不存在,请在配置文件中添加 data.sharding.threshold 配置")
+	}
+	
+	shardingMgr := sharding.NewShardingManager(strategy, logger, &thresholdConfig)
 
 	return &shardingService{
 		Service:     service,

+ 43 - 26
pkg/sharding/manager.go

@@ -20,11 +20,17 @@ type TableModel interface {
 
 // ThresholdConfig 阈值配置
 type ThresholdConfig struct {
-	Enabled       bool
-	MaxRows       int64
-	CheckInterval time.Duration
-	// 不同表的阈值配置
-	TableThresholds map[string]int64
+	Enabled       bool              `mapstructure:"enabled"`
+	MaxRows       int64             `mapstructure:"max_rows"`
+	CheckInterval time.Duration     `mapstructure:"check_interval"`
+	Tables        []TableConfig     `mapstructure:"tables"`
+}
+
+// TableConfig 单表配置
+type TableConfig struct {
+	Name    string `mapstructure:"name"`
+	Enabled bool   `mapstructure:"enabled"`
+	MaxRows int64  `mapstructure:"max_rows"`
 }
 
 // ShardingManager 分表管理器
@@ -34,21 +40,17 @@ type ShardingManager struct {
 	thresholdConfig *ThresholdConfig
 }
 
-func NewShardingManager(strategy ShardingStrategy, logger *log.Logger) *ShardingManager {
-	return &ShardingManager{
-		strategy: strategy,
-		logger:   logger,
-	}
-}
 
-func NewShardingManagerWithThreshold(strategy ShardingStrategy, logger *log.Logger, thresholdConfig *ThresholdConfig) *ShardingManager {
+// NewShardingManager 从配置创建ShardingManager
+func NewShardingManager(strategy ShardingStrategy, logger *log.Logger, config *ThresholdConfig) *ShardingManager {
 	return &ShardingManager{
 		strategy:        strategy,
 		logger:          logger,
-		thresholdConfig: thresholdConfig,
+		thresholdConfig: config,
 	}
 }
 
+
 // GetWriteTableName 获取写入表名(基于记录的创建时间)
 func (sm *ShardingManager) GetWriteTableName(model TableModel) string {
 	baseTableName := model.GetBaseTableName()
@@ -188,8 +190,8 @@ func (sm *ShardingManager) findTablesByPattern(db *gorm.DB, pattern string) []st
 	return tables
 }
 
-// GetOptimalWriteTable 获取最优的写入表(考虑数据量阈值)
-func (sm *ShardingManager) GetOptimalWriteTable(ctx context.Context, db *gorm.DB, model TableModel, maxRows int64) (string, error) {
+// GetOptimalWriteTable 获取最优的写入表(根据model自动获取阈值)
+func (sm *ShardingManager) GetOptimalWriteTable(ctx context.Context, db *gorm.DB, model TableModel) (string, error) {
 	baseTableName := model.GetBaseTableName()
 	createdAt := model.GetCreatedAt()
 	
@@ -205,9 +207,12 @@ func (sm *ShardingManager) GetOptimalWriteTable(ctx context.Context, db *gorm.DB
 		return baseShardTableName, nil
 	}
 	
-	// 使用配置的maxRows,如果没有则使用默认值
-	if maxRows <= 0 {
-		maxRows = sm.thresholdConfig.MaxRows
+	// 根据表名自动获取阈值
+	maxRows := sm.GetMaxRowsForTable(baseTableName)
+	
+	// 如果返回-1,表示该表禁用了阈值检查,直接返回基础表名
+	if maxRows == -1 {
+		return baseShardTableName, nil
 	}
 	
 	// 检查当前表是否已达到阈值
@@ -270,18 +275,30 @@ func (sm *ShardingManager) CheckAndCreateNewTable(ctx context.Context, db *gorm.
 
 // GetMaxRowsForTable 获取指定表的最大行数配置
 func (sm *ShardingManager) GetMaxRowsForTable(tableName string) int64 {
-	// 优先使用表级配置
-	if sm.thresholdConfig != nil && sm.thresholdConfig.TableThresholds != nil {
-		if maxRows, exists := sm.thresholdConfig.TableThresholds[tableName]; exists {
-			return maxRows
+	// 检查表级配置
+	if sm.thresholdConfig != nil && sm.thresholdConfig.Tables != nil {
+		for _, tableConfig := range sm.thresholdConfig.Tables {
+			if tableConfig.Name == tableName {
+				if !tableConfig.Enabled {
+					// 表级别禁用分表,返回-1表示无限制
+					return -1
+				}
+				return tableConfig.MaxRows
+			}
 		}
 	}
 	
-	// 使用默认配置
-	if sm.thresholdConfig != nil {
+	// 检查全局配置是否启用
+	if sm.thresholdConfig != nil && !sm.thresholdConfig.Enabled {
+		// 全局禁用阈值检查,返回-1表示无限制
+		return -1
+	}
+	
+	// 使用全局默认配置
+	if sm.thresholdConfig != nil && sm.thresholdConfig.MaxRows > 0 {
 		return sm.thresholdConfig.MaxRows
 	}
 	
-	// 最终默认值
-	return 3000000
+	// 配置缺失,返回错误而不是默认值
+	panic(fmt.Sprintf("表 '%s' 的阈值配置缺失,请在配置文件中添加相应配置", tableName))
 }