123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267 |
- package sharding
- import (
- "context"
- "fmt"
- "regexp"
- "strconv"
- "strings"
- "time"
- "github.com/go-nunu/nunu-layout-advanced/pkg/log"
- "gorm.io/gorm"
- )
- // TableModel 支持分表的模型接口
- type TableModel interface {
- GetBaseTableName() string
- GetCreatedAt() time.Time
- }
- // ThresholdConfig 阈值配置
- type ThresholdConfig struct {
- Enabled bool
- MaxRows int64
- CheckInterval time.Duration
- }
- // ShardingManager 分表管理器
- type ShardingManager struct {
- strategy ShardingStrategy
- logger *log.Logger
- thresholdConfig *ThresholdConfig
- }
- func NewShardingManager(strategy ShardingStrategy, logger *log.Logger) *ShardingManager {
- return &ShardingManager{
- strategy: strategy,
- logger: logger,
- }
- }
- func NewShardingManagerWithThreshold(strategy ShardingStrategy, logger *log.Logger, thresholdConfig *ThresholdConfig) *ShardingManager {
- return &ShardingManager{
- strategy: strategy,
- logger: logger,
- thresholdConfig: thresholdConfig,
- }
- }
// GetWriteTableName returns the shard table model should be written to, as
// chosen by the strategy from the model's creation time. A model whose
// creation time is the zero value is routed as if it were created now.
func (sm *ShardingManager) GetWriteTableName(model TableModel) string {
	base, ts := model.GetBaseTableName(), model.GetCreatedAt()
	if !ts.IsZero() {
		return sm.strategy.GetTableName(base, ts)
	}
	return sm.strategy.GetTableName(base, time.Now())
}
- // GetCurrentTableName 获取当前表名(用于写入新记录)
- func (sm *ShardingManager) GetCurrentTableName(baseTableName string) string {
- return sm.strategy.GetCurrentTableName(baseTableName)
- }
// GetQueryTableNames returns every shard table name the strategy yields for the
// time range [*start, *end].
//
// When either bound is nil, the range defaults to the last three months: from
// the same calendar day two months ago up to now. If that day does not exist
// in the target month (for example 30 April -> February), it is clamped to the
// month's last day. time.Time.AddDate would instead overflow into the next
// month (30 April - 2 months = 2 March), so a monthly strategy would skip
// February's table entirely.
func (sm *ShardingManager) GetQueryTableNames(baseTableName string, start, end *time.Time) []string {
	if start != nil && end != nil {
		return sm.strategy.GetTableNamesByRange(baseTableName, *start, *end)
	}
	now := time.Now()
	return sm.strategy.GetTableNamesByRange(baseTableName, addMonthsClamped(now, -2), now)
}

// addMonthsClamped moves t by months calendar months (negative moves back),
// clamping the day of month to the last day of the target month instead of
// overflowing into the following month.
func addMonthsClamped(t time.Time, months int) time.Time {
	// Day 1 always exists, so this reliably lands in the intended month;
	// time.Date normalizes month values outside 1..12 across year boundaries.
	first := time.Date(t.Year(), t.Month()+time.Month(months), 1, 0, 0, 0, 0, t.Location())
	// Day 0 of the following month is the last day of the target month.
	lastDay := time.Date(first.Year(), first.Month()+1, 0, 0, 0, 0, 0, t.Location()).Day()
	day := t.Day()
	if day > lastDay {
		day = lastDay
	}
	return time.Date(first.Year(), first.Month(), day, t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), t.Location())
}
// EnsureTableExists creates tableName using model's schema if it does not
// exist yet, and does nothing if it does.
//
// ctx is attached to the existence check and the migration, so both honor the
// caller's cancellation and deadline. If the migration fails but the table
// exists afterwards (for example because another process created it between
// the check and the migration), the failure is treated as success. Any other
// failure is returned wrapped with the table name.
func (sm *ShardingManager) EnsureTableExists(ctx context.Context, db *gorm.DB, tableName string, model interface{}) error {
	db = db.WithContext(ctx)
	if db.Migrator().HasTable(tableName) {
		return nil
	}
	sm.logger.Info(fmt.Sprintf("创建分表: %s", tableName))

	// Scoping AutoMigrate with Table() creates the table under tableName
	// rather than the model's default table name.
	if err := db.Table(tableName).AutoMigrate(model); err != nil {
		// Another writer may have won the race to create the same shard.
		if db.Migrator().HasTable(tableName) {
			return nil
		}
		return fmt.Errorf("create shard table %s: %w", tableName, err)
	}
	return nil
}
- // BuildUnionQuery 构建联合查询(用于跨表查询)
- func (sm *ShardingManager) BuildUnionQuery(ctx context.Context, db *gorm.DB, tableNames []string, baseQuery func(*gorm.DB) *gorm.DB) *gorm.DB {
- if len(tableNames) == 0 {
- return db
- }
- // 过滤存在的表
- var existingTables []string
- for _, tableName := range tableNames {
- if db.Migrator().HasTable(tableName) {
- existingTables = append(existingTables, tableName)
- }
- }
- if len(existingTables) == 0 {
- return db
- }
- // 如果只有一个表,直接查询该表
- if len(existingTables) == 1 {
- return baseQuery(db.Table(existingTables[0]))
- }
- // 多表联合查询
- var subQueries []string
- for _, tableName := range existingTables {
- subQueries = append(subQueries, fmt.Sprintf("SELECT * FROM %s", tableName))
- }
- unionSQL := strings.Join(subQueries, " UNION ALL ")
- return baseQuery(db.Table(fmt.Sprintf("(%s) as sharded_table", unionSQL)))
- }
// GetTableNamesWithExistenceCheck returns the shard tables for the requested
// range (see GetQueryTableNames for the default when a bound is nil) that
// actually exist. They are followed by any sequence-numbered overflow tables
// found for those shards. The un-sharded base table itself is not included.
func (sm *ShardingManager) GetTableNamesWithExistenceCheck(db *gorm.DB, baseTableName string, start, end *time.Time) []string {
	candidates := sm.GetQueryTableNames(baseTableName, start, end)

	var found []string
	for _, name := range candidates {
		if !db.Migrator().HasTable(name) {
			continue
		}
		found = append(found, name)
	}

	// Overflow tables (e.g. <shard>_01) are not produced by the strategy, so
	// they are discovered separately from the live table list.
	return append(found, sm.findDynamicTables(db, candidates)...)
}
// findDynamicTables returns the existing sequence-numbered overflow tables
// (for example log_202408_01, log_202408_02) that belong to the given shard
// table names.
//
// A table belongs to a shard only if its whole name is "<shard>_<digits>".
// The shard name is regex-escaped and the pattern is anchored, so unrelated
// tables that merely contain the shard name (xlog_202408_01,
// log_202408_01_bak) are not returned. The table list is fetched once for all
// shards. Results are grouped in the order of baseTableNames, and within a
// group they keep the order of the table listing.
func (sm *ShardingManager) findDynamicTables(db *gorm.DB, baseTableNames []string) []string {
	if len(baseTableNames) == 0 {
		return nil
	}

	alternatives := make([]string, len(baseTableNames))
	matchers := make([]*regexp.Regexp, len(baseTableNames))
	for i, baseTableName := range baseTableNames {
		quoted := regexp.QuoteMeta(baseTableName)
		alternatives[i] = quoted
		// QuoteMeta output is always a valid literal, so this cannot panic.
		matchers[i] = regexp.MustCompile(fmt.Sprintf(`^%s_\d+$`, quoted))
	}

	// A single listing covers every shard; the per-shard matchers below
	// restore the grouping.
	candidates := sm.findTablesByPattern(db, fmt.Sprintf(`^(?:%s)_\d+$`, strings.Join(alternatives, "|")))

	var dynamicTables []string
	for _, re := range matchers {
		for _, tableName := range candidates {
			if re.MatchString(tableName) {
				dynamicTables = append(dynamicTables, tableName)
			}
		}
	}
	return dynamicTables
}
// findTablesByPattern returns the names of tables in the current database that
// match the regular expression pattern. Matching uses regexp.MatchString
// semantics, i.e. it is unanchored unless the pattern itself uses ^ and $.
//
// Tables are listed with the MySQL-style "SHOW TABLES" statement. The lookup
// is best effort: errors are logged rather than returned. An invalid pattern
// or a failed query yields no tables. An unreadable row is skipped. An error
// during iteration yields the tables collected up to that point.
func (sm *ShardingManager) findTablesByPattern(db *gorm.DB, pattern string) []string {
	var tables []string

	// Validate the pattern before spending a database round trip.
	regex, err := regexp.Compile(pattern)
	if err != nil {
		sm.logger.Error("编译正则表达式失败: " + err.Error())
		return tables
	}

	rows, err := db.Raw("SHOW TABLES").Rows()
	if err != nil {
		sm.logger.Error("获取表列表失败: " + err.Error())
		return tables
	}
	defer rows.Close()

	for rows.Next() {
		var tableName string
		if err := rows.Scan(&tableName); err != nil {
			sm.logger.Error("读取表名失败: " + err.Error())
			continue
		}
		if regex.MatchString(tableName) {
			tables = append(tables, tableName)
		}
	}
	// Next() returning false can mean an error rather than the end of the rows.
	if err := rows.Err(); err != nil {
		sm.logger.Error("遍历表列表失败: " + err.Error())
	}

	return tables
}
// GetOptimalWriteTable returns the table a new record for model should be
// written to, taking the row-count threshold into account.
//
// The time-based shard (strategy.GetTableName on the model's creation time,
// or on now when that is zero) is considered first. It is returned as-is when
// threshold checking is disabled or no positive row limit is available.
// Otherwise, while the candidate table already holds maxRows rows or more,
// the overflow tables <shard>_01, <shard>_02, ... are tried in order. The
// first candidate that does not exist yet, or has fewer than maxRows rows, is
// returned.
//
// maxRows overrides ThresholdConfig.MaxRows when positive. ctx is attached to
// the existence checks and row counts. A failure while counting rows is
// logged, and the table being inspected is returned (best effort). The error
// result is currently always nil.
func (sm *ShardingManager) GetOptimalWriteTable(ctx context.Context, db *gorm.DB, model TableModel, maxRows int64) (string, error) {
	baseTableName := model.GetBaseTableName()
	createdAt := model.GetCreatedAt()
	if createdAt.IsZero() {
		createdAt = time.Now()
	}
	baseShardTableName := sm.strategy.GetTableName(baseTableName, createdAt)

	if sm.thresholdConfig == nil || !sm.thresholdConfig.Enabled {
		return baseShardTableName, nil
	}
	if maxRows <= 0 {
		maxRows = sm.thresholdConfig.MaxRows
	}
	// Without a positive limit every existing table would count as full and
	// each write would be pushed to a brand-new overflow table; treat a
	// non-positive limit as "no limit" instead.
	if maxRows <= 0 {
		return baseShardTableName, nil
	}

	db = db.WithContext(ctx)
	currentTable := baseShardTableName
	for seq := 1; ; seq++ {
		if !db.Migrator().HasTable(currentTable) {
			// Not created yet, so it has room; the caller is expected to create it.
			return currentTable, nil
		}

		var count int64
		if err := db.Table(currentTable).Count(&count).Error; err != nil {
			sm.logger.Error(fmt.Sprintf("检查表 %s 数据量失败: %v", currentTable, err))
			return currentTable, nil // best effort: fall back to the table being inspected
		}
		if count < maxRows {
			return currentTable, nil
		}

		// Overflow names are derived from the time shard itself. Incrementing a
		// trailing number in the shard name (log_202408 -> log_202409) would
		// land in the next period's table instead of log_202408_01.
		fullTable := currentTable
		currentTable = fmt.Sprintf("%s_%02d", baseShardTableName, seq)
		sm.logger.Info(fmt.Sprintf("表 %s 已达到阈值 %d,尝试使用 %s", fullTable, maxRows, currentTable))
	}
}
// sequenceSuffixPattern splits "name_<digits>" into the prefix and the
// trailing number. It is compiled once instead of on every call.
var sequenceSuffixPattern = regexp.MustCompile(`^(.+)_(\d+)$`)

// getNextSequenceTable returns the name that follows currentTableName in an
// overflow sequence. A trailing "_<n>" is incremented and written back with at
// least two digits (t_01 -> t_02, t_9 -> t_10). A name without such a suffix,
// or whose suffix does not fit in an int, gets "_01" appended.
//
// NOTE: any trailing "_<digits>" counts as a sequence number, so a period
// shard such as log_202408 yields log_202409 rather than log_202408_01.
// Callers starting from a shard name should build overflow names from the
// shard directly.
func (sm *ShardingManager) getNextSequenceTable(currentTableName string) string {
	matches := sequenceSuffixPattern.FindStringSubmatch(currentTableName)
	if len(matches) != 3 {
		return fmt.Sprintf("%s_01", currentTableName)
	}

	seq, err := strconv.Atoi(matches[2])
	if err != nil {
		// The suffix overflows int. Previously this silently restarted at
		// "<prefix>_01"; treat the whole name as unsequenced instead.
		return fmt.Sprintf("%s_01", currentTableName)
	}
	return fmt.Sprintf("%s_%02d", matches[1], seq+1)
}
// CheckAndCreateNewTable makes sure the shard for the current time period
// exists. If it is missing, the creation is logged and delegated to
// EnsureTableExists with modelExample as the schema.
func (sm *ShardingManager) CheckAndCreateNewTable(ctx context.Context, db *gorm.DB, baseTableName string, modelExample interface{}) error {
	tableName := sm.strategy.GetTableName(baseTableName, time.Now())
	if db.Migrator().HasTable(tableName) {
		return nil
	}

	sm.logger.Info(fmt.Sprintf("创建新周期分表: %s", tableName))
	return sm.EnsureTableExists(ctx, db, tableName, modelExample)
}
|