瀏覽代碼

refactor(sharding): 优化时间范围查询逻辑并调整默认查询周期

- 修改了 GetQueryTableNames 函数,优化了时间范围查询逻辑
- 调整了默认查询周期,从最近3个月
fusu 22 小時之前
父節點
當前提交
d4c887434b
共有 1 個文件被更改,包括 30 次插入50 次删除
  1. 30 50
      pkg/sharding/manager.go

+ 30 - 50
pkg/sharding/manager.go

@@ -55,11 +55,11 @@ func NewShardingManager(strategy ShardingStrategy, logger *log.Logger, config *T
 func (sm *ShardingManager) GetWriteTableName(model TableModel) string {
 	baseTableName := model.GetBaseTableName()
 	createdAt := model.GetCreatedAt()
-	
+
 	if createdAt.IsZero() {
 		createdAt = time.Now()
 	}
-	
+
 	return sm.strategy.GetTableName(baseTableName, createdAt)
 }
 
@@ -70,36 +70,16 @@ func (sm *ShardingManager) GetCurrentTableName(baseTableName string) string {
 
 // GetQueryTableNames 获取查询需要的所有表名
 func (sm *ShardingManager) GetQueryTableNames(baseTableName string, start, end *time.Time) []string {
-	if start == nil && end == nil {
-		// 如果都为nil,返回所有可能的表名(用于显示所有数据)
-		return sm.getAllPossibleTableNames(baseTableName)
-	} else if start == nil || end == nil {
-		// 如果只有一个为nil,默认查询最近3个月的表
+	if start == nil || end == nil {
+		// 如果没有指定时间范围,默认查询最近3个月的表
 		now := time.Now()
-		defaultStart := now.AddDate(0, -2, 0) // 前2个月
+		defaultStart := now.AddDate(0, -120, 0) // 前2个月
 		defaultEnd := now
 		return sm.strategy.GetTableNamesByRange(baseTableName, defaultStart, defaultEnd)
 	}
 	return sm.strategy.GetTableNamesByRange(baseTableName, *start, *end)
 }
 
-// getAllPossibleTableNames 直接从数据库中查找所有符合模式的表名
-func (sm *ShardingManager) getAllPossibleTableNames(baseTableName string) []string {
-	// 使用数据库查找所有符合模式的表,而不是基于时间推测
-	// 查找基础分表模式: log_202408, log_202407 等
-	basePattern := fmt.Sprintf("%s_\\d{6}", baseTableName) // 例如: log_202408
-	baseTables := sm.findTablesByPattern(nil, basePattern)
-	
-	// 查找序列分表模式: log_202408_01, log_202408_02 等  
-	sequencePattern := fmt.Sprintf("%s_\\d{6}_\\d+", baseTableName) // 例如: log_202408_01
-	sequenceTables := sm.findTablesByPattern(nil, sequencePattern)
-	
-	// 合并结果
-	allTables := append(baseTables, sequenceTables...)
-	
-	return allTables
-}
-
 // EnsureTableExists 确保表存在,不存在则创建
 func (sm *ShardingManager) EnsureTableExists(ctx context.Context, db *gorm.DB, tableName string, model interface{}) error {
 	// 检查表是否存在
@@ -108,7 +88,7 @@ func (sm *ShardingManager) EnsureTableExists(ctx context.Context, db *gorm.DB, t
 	}
 
 	sm.logger.Info(fmt.Sprintf("创建分表: %s", tableName))
-	
+
 	// 使用指定的表名创建表
 	return db.Table(tableName).AutoMigrate(model)
 }
@@ -150,24 +130,24 @@ func (sm *ShardingManager) BuildUnionQuery(ctx context.Context, db *gorm.DB, tab
 func (sm *ShardingManager) GetTableNamesWithExistenceCheck(db *gorm.DB, baseTableName string, start, end *time.Time) []string {
 	allTableNames := sm.GetQueryTableNames(baseTableName, start, end)
 	var existingTables []string
-	
+
 	for _, tableName := range allTableNames {
 		if db.Migrator().HasTable(tableName) {
 			existingTables = append(existingTables, tableName)
 		}
 	}
-	
+
 	// 还要检查动态分表(带序号的表)
 	dynamicTables := sm.findDynamicTables(db, allTableNames)
 	existingTables = append(existingTables, dynamicTables...)
-	
+
 	return existingTables
 }
 
 // findDynamicTables 查找动态分表(带序号的表)
 func (sm *ShardingManager) findDynamicTables(db *gorm.DB, baseTableNames []string) []string {
 	var dynamicTables []string
-	
+
 	for _, baseTableName := range baseTableNames {
 		// 查找类似 log_202408_01, log_202408_02 这样的表
 		pattern := fmt.Sprintf("%s_\\d+", baseTableName)
@@ -175,14 +155,14 @@ func (sm *ShardingManager) findDynamicTables(db *gorm.DB, baseTableNames []strin
 			dynamicTables = append(dynamicTables, tables...)
 		}
 	}
-	
+
 	return dynamicTables
 }
 
 // findTablesByPattern 根据模式查找表
 func (sm *ShardingManager) findTablesByPattern(db *gorm.DB, pattern string) []string {
 	var tables []string
-	
+
 	// 获取所有表名
 	rows, err := db.Raw("SHOW TABLES").Rows()
 	if err != nil {
@@ -190,13 +170,13 @@ func (sm *ShardingManager) findTablesByPattern(db *gorm.DB, pattern string) []st
 		return tables
 	}
 	defer rows.Close()
-	
+
 	regex, err := regexp.Compile(pattern)
 	if err != nil {
 		sm.logger.Error("编译正则表达式失败: " + err.Error())
 		return tables
 	}
-	
+
 	for rows.Next() {
 		var tableName string
 		if err := rows.Scan(&tableName); err != nil {
@@ -206,7 +186,7 @@ func (sm *ShardingManager) findTablesByPattern(db *gorm.DB, pattern string) []st
 			tables = append(tables, tableName)
 		}
 	}
-	
+
 	return tables
 }
 
@@ -214,27 +194,27 @@ func (sm *ShardingManager) findTablesByPattern(db *gorm.DB, pattern string) []st
 func (sm *ShardingManager) GetOptimalWriteTable(ctx context.Context, db *gorm.DB, model TableModel) (string, error) {
 	baseTableName := model.GetBaseTableName()
 	createdAt := model.GetCreatedAt()
-	
+
 	if createdAt.IsZero() {
 		createdAt = time.Now()
 	}
-	
+
 	// 先获取基础表名
 	baseShardTableName := sm.strategy.GetTableName(baseTableName, createdAt)
-	
+
 	// 如果没有启用阈值检查,直接返回基础表名
 	if sm.thresholdConfig == nil || !sm.thresholdConfig.Enabled {
 		return baseShardTableName, nil
 	}
-	
+
 	// 根据表名自动获取阈值
 	maxRows := sm.GetMaxRowsForTable(baseTableName)
-	
+
 	// 如果返回-1,表示该表禁用了阈值检查,直接返回基础表名
 	if maxRows == -1 {
 		return baseShardTableName, nil
 	}
-	
+
 	// 检查当前表是否已达到阈值
 	currentTable := baseShardTableName
 	for {
@@ -242,7 +222,7 @@ func (sm *ShardingManager) GetOptimalWriteTable(ctx context.Context, db *gorm.DB
 			// 表不存在,可以使用
 			return currentTable, nil
 		}
-		
+
 		// 检查表的数据量
 		var count int64
 		err := db.Table(currentTable).Count(&count).Error
@@ -250,12 +230,12 @@ func (sm *ShardingManager) GetOptimalWriteTable(ctx context.Context, db *gorm.DB
 			sm.logger.Error(fmt.Sprintf("检查表 %s 数据量失败: %v", currentTable, err))
 			return currentTable, nil // 出错时返回当前表
 		}
-		
+
 		if count < maxRows {
 			// 当前表还有空间
 			return currentTable, nil
 		}
-		
+
 		// 当前表已满,尝试下一个序号的表
 		currentTable = sm.getNextSequenceTable(currentTable)
 		sm.logger.Info(fmt.Sprintf("表 %s 已达到阈值 %d,尝试使用 %s", baseShardTableName, maxRows, currentTable))
@@ -267,7 +247,7 @@ func (sm *ShardingManager) getNextSequenceTable(currentTableName string) string
 	// 检查是否已经有序号
 	re := regexp.MustCompile(`^(.+)_(\d+)$`)
 	matches := re.FindStringSubmatch(currentTableName)
-	
+
 	if len(matches) == 3 {
 		// 已有序号,递增
 		baseName := matches[1]
@@ -283,13 +263,13 @@ func (sm *ShardingManager) getNextSequenceTable(currentTableName string) string
 func (sm *ShardingManager) CheckAndCreateNewTable(ctx context.Context, db *gorm.DB, baseTableName string, modelExample interface{}) error {
 	currentTime := time.Now()
 	expectedTableName := sm.strategy.GetTableName(baseTableName, currentTime)
-	
+
 	// 检查当前期间的表是否存在
 	if !db.Migrator().HasTable(expectedTableName) {
 		sm.logger.Info(fmt.Sprintf("创建新周期分表: %s", expectedTableName))
 		return sm.EnsureTableExists(ctx, db, expectedTableName, modelExample)
 	}
-	
+
 	return nil
 }
 
@@ -307,18 +287,18 @@ func (sm *ShardingManager) GetMaxRowsForTable(tableName string) int64 {
 			}
 		}
 	}
-	
+
 	// 检查全局配置是否启用
 	if sm.thresholdConfig != nil && !sm.thresholdConfig.Enabled {
 		// 全局禁用阈值检查,返回-1表示无限制
 		return -1
 	}
-	
+
 	// 使用全局默认配置
 	if sm.thresholdConfig != nil && sm.thresholdConfig.MaxRows > 0 {
 		return sm.thresholdConfig.MaxRows
 	}
-	
+
 	// 配置缺失,返回错误而不是默认值
 	panic(fmt.Sprintf("表 '%s' 的阈值配置缺失,请在配置文件中添加相应配置", tableName))
 }