|
@@ -55,11 +55,11 @@ func NewShardingManager(strategy ShardingStrategy, logger *log.Logger, config *T
|
|
func (sm *ShardingManager) GetWriteTableName(model TableModel) string {
|
|
func (sm *ShardingManager) GetWriteTableName(model TableModel) string {
|
|
baseTableName := model.GetBaseTableName()
|
|
baseTableName := model.GetBaseTableName()
|
|
createdAt := model.GetCreatedAt()
|
|
createdAt := model.GetCreatedAt()
|
|
-
|
|
|
|
|
|
+
|
|
if createdAt.IsZero() {
|
|
if createdAt.IsZero() {
|
|
createdAt = time.Now()
|
|
createdAt = time.Now()
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+
|
|
return sm.strategy.GetTableName(baseTableName, createdAt)
|
|
return sm.strategy.GetTableName(baseTableName, createdAt)
|
|
}
|
|
}
|
|
|
|
|
|
@@ -70,36 +70,16 @@ func (sm *ShardingManager) GetCurrentTableName(baseTableName string) string {
|
|
|
|
|
|
// GetQueryTableNames 获取查询需要的所有表名
|
|
// GetQueryTableNames 获取查询需要的所有表名
|
|
func (sm *ShardingManager) GetQueryTableNames(baseTableName string, start, end *time.Time) []string {
|
|
func (sm *ShardingManager) GetQueryTableNames(baseTableName string, start, end *time.Time) []string {
|
|
- if start == nil && end == nil {
|
|
|
|
- // 如果都为nil,返回所有可能的表名(用于显示所有数据)
|
|
|
|
- return sm.getAllPossibleTableNames(baseTableName)
|
|
|
|
- } else if start == nil || end == nil {
|
|
|
|
- // 如果只有一个为nil,默认查询最近3个月的表
|
|
|
|
|
|
+ if start == nil || end == nil {
|
|
|
|
+ // 如果没有指定时间范围,默认查询最近3个月的表
|
|
now := time.Now()
|
|
now := time.Now()
|
|
- defaultStart := now.AddDate(0, -2, 0) // 前2个月
|
|
|
|
|
|
+ defaultStart := now.AddDate(0, -120, 0) // 前2个月
|
|
defaultEnd := now
|
|
defaultEnd := now
|
|
return sm.strategy.GetTableNamesByRange(baseTableName, defaultStart, defaultEnd)
|
|
return sm.strategy.GetTableNamesByRange(baseTableName, defaultStart, defaultEnd)
|
|
}
|
|
}
|
|
return sm.strategy.GetTableNamesByRange(baseTableName, *start, *end)
|
|
return sm.strategy.GetTableNamesByRange(baseTableName, *start, *end)
|
|
}
|
|
}
|
|
|
|
|
|
-// getAllPossibleTableNames 直接从数据库中查找所有符合模式的表名
|
|
|
|
-func (sm *ShardingManager) getAllPossibleTableNames(baseTableName string) []string {
|
|
|
|
- // 使用数据库查找所有符合模式的表,而不是基于时间推测
|
|
|
|
- // 查找基础分表模式: log_202408, log_202407 等
|
|
|
|
- basePattern := fmt.Sprintf("%s_\\d{6}", baseTableName) // 例如: log_202408
|
|
|
|
- baseTables := sm.findTablesByPattern(nil, basePattern)
|
|
|
|
-
|
|
|
|
- // 查找序列分表模式: log_202408_01, log_202408_02 等
|
|
|
|
- sequencePattern := fmt.Sprintf("%s_\\d{6}_\\d+", baseTableName) // 例如: log_202408_01
|
|
|
|
- sequenceTables := sm.findTablesByPattern(nil, sequencePattern)
|
|
|
|
-
|
|
|
|
- // 合并结果
|
|
|
|
- allTables := append(baseTables, sequenceTables...)
|
|
|
|
-
|
|
|
|
- return allTables
|
|
|
|
-}
|
|
|
|
-
|
|
|
|
// EnsureTableExists 确保表存在,不存在则创建
|
|
// EnsureTableExists 确保表存在,不存在则创建
|
|
func (sm *ShardingManager) EnsureTableExists(ctx context.Context, db *gorm.DB, tableName string, model interface{}) error {
|
|
func (sm *ShardingManager) EnsureTableExists(ctx context.Context, db *gorm.DB, tableName string, model interface{}) error {
|
|
// 检查表是否存在
|
|
// 检查表是否存在
|
|
@@ -108,7 +88,7 @@ func (sm *ShardingManager) EnsureTableExists(ctx context.Context, db *gorm.DB, t
|
|
}
|
|
}
|
|
|
|
|
|
sm.logger.Info(fmt.Sprintf("创建分表: %s", tableName))
|
|
sm.logger.Info(fmt.Sprintf("创建分表: %s", tableName))
|
|
-
|
|
|
|
|
|
+
|
|
// 使用指定的表名创建表
|
|
// 使用指定的表名创建表
|
|
return db.Table(tableName).AutoMigrate(model)
|
|
return db.Table(tableName).AutoMigrate(model)
|
|
}
|
|
}
|
|
@@ -150,24 +130,24 @@ func (sm *ShardingManager) BuildUnionQuery(ctx context.Context, db *gorm.DB, tab
|
|
func (sm *ShardingManager) GetTableNamesWithExistenceCheck(db *gorm.DB, baseTableName string, start, end *time.Time) []string {
|
|
func (sm *ShardingManager) GetTableNamesWithExistenceCheck(db *gorm.DB, baseTableName string, start, end *time.Time) []string {
|
|
allTableNames := sm.GetQueryTableNames(baseTableName, start, end)
|
|
allTableNames := sm.GetQueryTableNames(baseTableName, start, end)
|
|
var existingTables []string
|
|
var existingTables []string
|
|
-
|
|
|
|
|
|
+
|
|
for _, tableName := range allTableNames {
|
|
for _, tableName := range allTableNames {
|
|
if db.Migrator().HasTable(tableName) {
|
|
if db.Migrator().HasTable(tableName) {
|
|
existingTables = append(existingTables, tableName)
|
|
existingTables = append(existingTables, tableName)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+
|
|
// 还要检查动态分表(带序号的表)
|
|
// 还要检查动态分表(带序号的表)
|
|
dynamicTables := sm.findDynamicTables(db, allTableNames)
|
|
dynamicTables := sm.findDynamicTables(db, allTableNames)
|
|
existingTables = append(existingTables, dynamicTables...)
|
|
existingTables = append(existingTables, dynamicTables...)
|
|
-
|
|
|
|
|
|
+
|
|
return existingTables
|
|
return existingTables
|
|
}
|
|
}
|
|
|
|
|
|
// findDynamicTables 查找动态分表(带序号的表)
|
|
// findDynamicTables 查找动态分表(带序号的表)
|
|
func (sm *ShardingManager) findDynamicTables(db *gorm.DB, baseTableNames []string) []string {
|
|
func (sm *ShardingManager) findDynamicTables(db *gorm.DB, baseTableNames []string) []string {
|
|
var dynamicTables []string
|
|
var dynamicTables []string
|
|
-
|
|
|
|
|
|
+
|
|
for _, baseTableName := range baseTableNames {
|
|
for _, baseTableName := range baseTableNames {
|
|
// 查找类似 log_202408_01, log_202408_02 这样的表
|
|
// 查找类似 log_202408_01, log_202408_02 这样的表
|
|
pattern := fmt.Sprintf("%s_\\d+", baseTableName)
|
|
pattern := fmt.Sprintf("%s_\\d+", baseTableName)
|
|
@@ -175,14 +155,14 @@ func (sm *ShardingManager) findDynamicTables(db *gorm.DB, baseTableNames []strin
|
|
dynamicTables = append(dynamicTables, tables...)
|
|
dynamicTables = append(dynamicTables, tables...)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+
|
|
return dynamicTables
|
|
return dynamicTables
|
|
}
|
|
}
|
|
|
|
|
|
// findTablesByPattern 根据模式查找表
|
|
// findTablesByPattern 根据模式查找表
|
|
func (sm *ShardingManager) findTablesByPattern(db *gorm.DB, pattern string) []string {
|
|
func (sm *ShardingManager) findTablesByPattern(db *gorm.DB, pattern string) []string {
|
|
var tables []string
|
|
var tables []string
|
|
-
|
|
|
|
|
|
+
|
|
// 获取所有表名
|
|
// 获取所有表名
|
|
rows, err := db.Raw("SHOW TABLES").Rows()
|
|
rows, err := db.Raw("SHOW TABLES").Rows()
|
|
if err != nil {
|
|
if err != nil {
|
|
@@ -190,13 +170,13 @@ func (sm *ShardingManager) findTablesByPattern(db *gorm.DB, pattern string) []st
|
|
return tables
|
|
return tables
|
|
}
|
|
}
|
|
defer rows.Close()
|
|
defer rows.Close()
|
|
-
|
|
|
|
|
|
+
|
|
regex, err := regexp.Compile(pattern)
|
|
regex, err := regexp.Compile(pattern)
|
|
if err != nil {
|
|
if err != nil {
|
|
sm.logger.Error("编译正则表达式失败: " + err.Error())
|
|
sm.logger.Error("编译正则表达式失败: " + err.Error())
|
|
return tables
|
|
return tables
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+
|
|
for rows.Next() {
|
|
for rows.Next() {
|
|
var tableName string
|
|
var tableName string
|
|
if err := rows.Scan(&tableName); err != nil {
|
|
if err := rows.Scan(&tableName); err != nil {
|
|
@@ -206,7 +186,7 @@ func (sm *ShardingManager) findTablesByPattern(db *gorm.DB, pattern string) []st
|
|
tables = append(tables, tableName)
|
|
tables = append(tables, tableName)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+
|
|
return tables
|
|
return tables
|
|
}
|
|
}
|
|
|
|
|
|
@@ -214,27 +194,27 @@ func (sm *ShardingManager) findTablesByPattern(db *gorm.DB, pattern string) []st
|
|
func (sm *ShardingManager) GetOptimalWriteTable(ctx context.Context, db *gorm.DB, model TableModel) (string, error) {
|
|
func (sm *ShardingManager) GetOptimalWriteTable(ctx context.Context, db *gorm.DB, model TableModel) (string, error) {
|
|
baseTableName := model.GetBaseTableName()
|
|
baseTableName := model.GetBaseTableName()
|
|
createdAt := model.GetCreatedAt()
|
|
createdAt := model.GetCreatedAt()
|
|
-
|
|
|
|
|
|
+
|
|
if createdAt.IsZero() {
|
|
if createdAt.IsZero() {
|
|
createdAt = time.Now()
|
|
createdAt = time.Now()
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+
|
|
// 先获取基础表名
|
|
// 先获取基础表名
|
|
baseShardTableName := sm.strategy.GetTableName(baseTableName, createdAt)
|
|
baseShardTableName := sm.strategy.GetTableName(baseTableName, createdAt)
|
|
-
|
|
|
|
|
|
+
|
|
// 如果没有启用阈值检查,直接返回基础表名
|
|
// 如果没有启用阈值检查,直接返回基础表名
|
|
if sm.thresholdConfig == nil || !sm.thresholdConfig.Enabled {
|
|
if sm.thresholdConfig == nil || !sm.thresholdConfig.Enabled {
|
|
return baseShardTableName, nil
|
|
return baseShardTableName, nil
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+
|
|
// 根据表名自动获取阈值
|
|
// 根据表名自动获取阈值
|
|
maxRows := sm.GetMaxRowsForTable(baseTableName)
|
|
maxRows := sm.GetMaxRowsForTable(baseTableName)
|
|
-
|
|
|
|
|
|
+
|
|
// 如果返回-1,表示该表禁用了阈值检查,直接返回基础表名
|
|
// 如果返回-1,表示该表禁用了阈值检查,直接返回基础表名
|
|
if maxRows == -1 {
|
|
if maxRows == -1 {
|
|
return baseShardTableName, nil
|
|
return baseShardTableName, nil
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+
|
|
// 检查当前表是否已达到阈值
|
|
// 检查当前表是否已达到阈值
|
|
currentTable := baseShardTableName
|
|
currentTable := baseShardTableName
|
|
for {
|
|
for {
|
|
@@ -242,7 +222,7 @@ func (sm *ShardingManager) GetOptimalWriteTable(ctx context.Context, db *gorm.DB
|
|
// 表不存在,可以使用
|
|
// 表不存在,可以使用
|
|
return currentTable, nil
|
|
return currentTable, nil
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+
|
|
// 检查表的数据量
|
|
// 检查表的数据量
|
|
var count int64
|
|
var count int64
|
|
err := db.Table(currentTable).Count(&count).Error
|
|
err := db.Table(currentTable).Count(&count).Error
|
|
@@ -250,12 +230,12 @@ func (sm *ShardingManager) GetOptimalWriteTable(ctx context.Context, db *gorm.DB
|
|
sm.logger.Error(fmt.Sprintf("检查表 %s 数据量失败: %v", currentTable, err))
|
|
sm.logger.Error(fmt.Sprintf("检查表 %s 数据量失败: %v", currentTable, err))
|
|
return currentTable, nil // 出错时返回当前表
|
|
return currentTable, nil // 出错时返回当前表
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+
|
|
if count < maxRows {
|
|
if count < maxRows {
|
|
// 当前表还有空间
|
|
// 当前表还有空间
|
|
return currentTable, nil
|
|
return currentTable, nil
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+
|
|
// 当前表已满,尝试下一个序号的表
|
|
// 当前表已满,尝试下一个序号的表
|
|
currentTable = sm.getNextSequenceTable(currentTable)
|
|
currentTable = sm.getNextSequenceTable(currentTable)
|
|
sm.logger.Info(fmt.Sprintf("表 %s 已达到阈值 %d,尝试使用 %s", baseShardTableName, maxRows, currentTable))
|
|
sm.logger.Info(fmt.Sprintf("表 %s 已达到阈值 %d,尝试使用 %s", baseShardTableName, maxRows, currentTable))
|
|
@@ -267,7 +247,7 @@ func (sm *ShardingManager) getNextSequenceTable(currentTableName string) string
|
|
// 检查是否已经有序号
|
|
// 检查是否已经有序号
|
|
re := regexp.MustCompile(`^(.+)_(\d+)$`)
|
|
re := regexp.MustCompile(`^(.+)_(\d+)$`)
|
|
matches := re.FindStringSubmatch(currentTableName)
|
|
matches := re.FindStringSubmatch(currentTableName)
|
|
-
|
|
|
|
|
|
+
|
|
if len(matches) == 3 {
|
|
if len(matches) == 3 {
|
|
// 已有序号,递增
|
|
// 已有序号,递增
|
|
baseName := matches[1]
|
|
baseName := matches[1]
|
|
@@ -283,13 +263,13 @@ func (sm *ShardingManager) getNextSequenceTable(currentTableName string) string
|
|
func (sm *ShardingManager) CheckAndCreateNewTable(ctx context.Context, db *gorm.DB, baseTableName string, modelExample interface{}) error {
|
|
func (sm *ShardingManager) CheckAndCreateNewTable(ctx context.Context, db *gorm.DB, baseTableName string, modelExample interface{}) error {
|
|
currentTime := time.Now()
|
|
currentTime := time.Now()
|
|
expectedTableName := sm.strategy.GetTableName(baseTableName, currentTime)
|
|
expectedTableName := sm.strategy.GetTableName(baseTableName, currentTime)
|
|
-
|
|
|
|
|
|
+
|
|
// 检查当前期间的表是否存在
|
|
// 检查当前期间的表是否存在
|
|
if !db.Migrator().HasTable(expectedTableName) {
|
|
if !db.Migrator().HasTable(expectedTableName) {
|
|
sm.logger.Info(fmt.Sprintf("创建新周期分表: %s", expectedTableName))
|
|
sm.logger.Info(fmt.Sprintf("创建新周期分表: %s", expectedTableName))
|
|
return sm.EnsureTableExists(ctx, db, expectedTableName, modelExample)
|
|
return sm.EnsureTableExists(ctx, db, expectedTableName, modelExample)
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
@@ -307,18 +287,18 @@ func (sm *ShardingManager) GetMaxRowsForTable(tableName string) int64 {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+
|
|
// 检查全局配置是否启用
|
|
// 检查全局配置是否启用
|
|
if sm.thresholdConfig != nil && !sm.thresholdConfig.Enabled {
|
|
if sm.thresholdConfig != nil && !sm.thresholdConfig.Enabled {
|
|
// 全局禁用阈值检查,返回-1表示无限制
|
|
// 全局禁用阈值检查,返回-1表示无限制
|
|
return -1
|
|
return -1
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+
|
|
// 使用全局默认配置
|
|
// 使用全局默认配置
|
|
if sm.thresholdConfig != nil && sm.thresholdConfig.MaxRows > 0 {
|
|
if sm.thresholdConfig != nil && sm.thresholdConfig.MaxRows > 0 {
|
|
return sm.thresholdConfig.MaxRows
|
|
return sm.thresholdConfig.MaxRows
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+
|
|
// 配置缺失,返回错误而不是默认值
|
|
// 配置缺失,返回错误而不是默认值
|
|
panic(fmt.Sprintf("表 '%s' 的阈值配置缺失,请在配置文件中添加相应配置", tableName))
|
|
panic(fmt.Sprintf("表 '%s' 的阈值配置缺失,请在配置文件中添加相应配置", tableName))
|
|
}
|
|
}
|