package sharding import ( "context" "fmt" "regexp" "strconv" "strings" "time" "github.com/go-nunu/nunu-layout-advanced/pkg/log" "gorm.io/gorm" ) // TableModel 支持分表的模型接口 type TableModel interface { GetBaseTableName() string GetCreatedAt() time.Time } // ThresholdConfig 阈值配置 type ThresholdConfig struct { Enabled bool MaxRows int64 CheckInterval time.Duration } // ShardingManager 分表管理器 type ShardingManager struct { strategy ShardingStrategy logger *log.Logger thresholdConfig *ThresholdConfig } func NewShardingManager(strategy ShardingStrategy, logger *log.Logger) *ShardingManager { return &ShardingManager{ strategy: strategy, logger: logger, } } func NewShardingManagerWithThreshold(strategy ShardingStrategy, logger *log.Logger, thresholdConfig *ThresholdConfig) *ShardingManager { return &ShardingManager{ strategy: strategy, logger: logger, thresholdConfig: thresholdConfig, } } // GetWriteTableName 获取写入表名(基于记录的创建时间) func (sm *ShardingManager) GetWriteTableName(model TableModel) string { baseTableName := model.GetBaseTableName() createdAt := model.GetCreatedAt() if createdAt.IsZero() { createdAt = time.Now() } return sm.strategy.GetTableName(baseTableName, createdAt) } // GetCurrentTableName 获取当前表名(用于写入新记录) func (sm *ShardingManager) GetCurrentTableName(baseTableName string) string { return sm.strategy.GetCurrentTableName(baseTableName) } // GetQueryTableNames 获取查询需要的所有表名 func (sm *ShardingManager) GetQueryTableNames(baseTableName string, start, end *time.Time) []string { if start == nil || end == nil { // 如果没有指定时间范围,默认查询最近3个月的表 now := time.Now() defaultStart := now.AddDate(0, -2, 0) // 前2个月 defaultEnd := now return sm.strategy.GetTableNamesByRange(baseTableName, defaultStart, defaultEnd) } return sm.strategy.GetTableNamesByRange(baseTableName, *start, *end) } // EnsureTableExists 确保表存在,不存在则创建 func (sm *ShardingManager) EnsureTableExists(ctx context.Context, db *gorm.DB, tableName string, model interface{}) error { // 检查表是否存在 if db.Migrator().HasTable(tableName) { return nil } sm.logger.Info(fmt.Sprintf("创建分表: %s", tableName)) // 使用指定的表名创建表 return db.Table(tableName).AutoMigrate(model) } // BuildUnionQuery 构建联合查询(用于跨表查询) func (sm *ShardingManager) BuildUnionQuery(ctx context.Context, db *gorm.DB, tableNames []string, baseQuery func(*gorm.DB) *gorm.DB) *gorm.DB { if len(tableNames) == 0 { return db } // 过滤存在的表 var existingTables []string for _, tableName := range tableNames { if db.Migrator().HasTable(tableName) { existingTables = append(existingTables, tableName) } } if len(existingTables) == 0 { return db } // 如果只有一个表,直接查询该表 if len(existingTables) == 1 { return baseQuery(db.Table(existingTables[0])) } // 多表联合查询 var subQueries []string for _, tableName := range existingTables { subQueries = append(subQueries, fmt.Sprintf("SELECT * FROM %s", tableName)) } unionSQL := strings.Join(subQueries, " UNION ALL ") return baseQuery(db.Table(fmt.Sprintf("(%s) as sharded_table", unionSQL))) } // GetTableNamesWithExistenceCheck 获取存在的表名列表(只返回分表,不包含原表) func (sm *ShardingManager) GetTableNamesWithExistenceCheck(db *gorm.DB, baseTableName string, start, end *time.Time) []string { allTableNames := sm.GetQueryTableNames(baseTableName, start, end) var existingTables []string for _, tableName := range allTableNames { if db.Migrator().HasTable(tableName) { existingTables = append(existingTables, tableName) } } // 还要检查动态分表(带序号的表) dynamicTables := sm.findDynamicTables(db, allTableNames) existingTables = append(existingTables, dynamicTables...) return existingTables } // findDynamicTables 查找动态分表(带序号的表) func (sm *ShardingManager) findDynamicTables(db *gorm.DB, baseTableNames []string) []string { var dynamicTables []string for _, baseTableName := range baseTableNames { // 查找类似 log_202408_01, log_202408_02 这样的表 pattern := fmt.Sprintf("%s_\\d+", baseTableName) if tables := sm.findTablesByPattern(db, pattern); len(tables) > 0 { dynamicTables = append(dynamicTables, tables...) } } return dynamicTables } // findTablesByPattern 根据模式查找表 func (sm *ShardingManager) findTablesByPattern(db *gorm.DB, pattern string) []string { var tables []string // 获取所有表名 rows, err := db.Raw("SHOW TABLES").Rows() if err != nil { sm.logger.Error("获取表列表失败: " + err.Error()) return tables } defer rows.Close() regex, err := regexp.Compile(pattern) if err != nil { sm.logger.Error("编译正则表达式失败: " + err.Error()) return tables } for rows.Next() { var tableName string if err := rows.Scan(&tableName); err != nil { continue } if regex.MatchString(tableName) { tables = append(tables, tableName) } } return tables } // GetOptimalWriteTable 获取最优的写入表(考虑数据量阈值) func (sm *ShardingManager) GetOptimalWriteTable(ctx context.Context, db *gorm.DB, model TableModel, maxRows int64) (string, error) { baseTableName := model.GetBaseTableName() createdAt := model.GetCreatedAt() if createdAt.IsZero() { createdAt = time.Now() } // 先获取基础表名 baseShardTableName := sm.strategy.GetTableName(baseTableName, createdAt) // 如果没有启用阈值检查,直接返回基础表名 if sm.thresholdConfig == nil || !sm.thresholdConfig.Enabled { return baseShardTableName, nil } // 使用配置的maxRows,如果没有则使用默认值 if maxRows <= 0 { maxRows = sm.thresholdConfig.MaxRows } // 检查当前表是否已达到阈值 currentTable := baseShardTableName for { if !db.Migrator().HasTable(currentTable) { // 表不存在,可以使用 return currentTable, nil } // 检查表的数据量 var count int64 err := db.Table(currentTable).Count(&count).Error if err != nil { sm.logger.Error(fmt.Sprintf("检查表 %s 数据量失败: %v", currentTable, err)) return currentTable, nil // 出错时返回当前表 } if count < maxRows { // 当前表还有空间 return currentTable, nil } // 当前表已满,尝试下一个序号的表 currentTable = sm.getNextSequenceTable(currentTable) sm.logger.Info(fmt.Sprintf("表 %s 已达到阈值 %d,尝试使用 %s", baseShardTableName, maxRows, currentTable)) } } // getNextSequenceTable 获取下一个序号的表名 func (sm *ShardingManager) getNextSequenceTable(currentTableName string) string { // 检查是否已经有序号 re := regexp.MustCompile(`^(.+)_(\d+)$`) matches := re.FindStringSubmatch(currentTableName) if len(matches) == 3 { // 已有序号,递增 baseName := matches[1] seq, _ := strconv.Atoi(matches[2]) return fmt.Sprintf("%s_%02d", baseName, seq+1) } else { // 没有序号,添加序号01 return fmt.Sprintf("%s_01", currentTableName) } } // CheckAndCreateNewTable 检查是否需要创建新表(基于时间周期) func (sm *ShardingManager) CheckAndCreateNewTable(ctx context.Context, db *gorm.DB, baseTableName string, modelExample interface{}) error { currentTime := time.Now() expectedTableName := sm.strategy.GetTableName(baseTableName, currentTime) // 检查当前期间的表是否存在 if !db.Migrator().HasTable(expectedTableName) { sm.logger.Info(fmt.Sprintf("创建新周期分表: %s", expectedTableName)) return sm.EnsureTableExists(ctx, db, expectedTableName, modelExample) } return nil }