manager.go 7.7 KB


  1. package sharding
  2. import (
  3. "context"
  4. "fmt"
  5. "regexp"
  6. "strconv"
  7. "strings"
  8. "time"
  9. "github.com/go-nunu/nunu-layout-advanced/pkg/log"
  10. "gorm.io/gorm"
  11. )
  // TableModel is implemented by records stored in time-sharded tables. Its two
// methods supply what ShardingManager needs to resolve a record's physical shard:
// the logical table name and the record's creation time.
type TableModel interface {
	// GetBaseTableName returns the logical (unsharded) table name, e.g. "log".
	GetBaseTableName() string
	// GetCreatedAt returns the record's creation time. ShardingManager treats a
	// zero value as "now" when choosing a write table.
	GetCreatedAt() time.Time
}
  // ThresholdConfig controls row-count based overflow of shard tables.
type ThresholdConfig struct {
	// Enabled switches the row-count check in GetOptimalWriteTable on or off.
	Enabled bool
	// MaxRows is the row limit used when GetOptimalWriteTable is called with maxRows <= 0.
	MaxRows int64
	// CheckInterval is not read by ShardingManager itself; presumably the period of a
	// background threshold check elsewhere — TODO confirm with callers.
	CheckInterval time.Duration
}
  23. // ShardingManager 分表管理器
  24. type ShardingManager struct {
  25. strategy ShardingStrategy
  26. logger *log.Logger
  27. thresholdConfig *ThresholdConfig
  28. }
  29. func NewShardingManager(strategy ShardingStrategy, logger *log.Logger) *ShardingManager {
  30. return &ShardingManager{
  31. strategy: strategy,
  32. logger: logger,
  33. }
  34. }
  35. func NewShardingManagerWithThreshold(strategy ShardingStrategy, logger *log.Logger, thresholdConfig *ThresholdConfig) *ShardingManager {
  36. return &ShardingManager{
  37. strategy: strategy,
  38. logger: logger,
  39. thresholdConfig: thresholdConfig,
  40. }
  41. }
  42. // GetWriteTableName 获取写入表名(基于记录的创建时间)
  43. func (sm *ShardingManager) GetWriteTableName(model TableModel) string {
  44. baseTableName := model.GetBaseTableName()
  45. createdAt := model.GetCreatedAt()
  46. if createdAt.IsZero() {
  47. createdAt = time.Now()
  48. }
  49. return sm.strategy.GetTableName(baseTableName, createdAt)
  50. }
  51. // GetCurrentTableName 获取当前表名(用于写入新记录)
  52. func (sm *ShardingManager) GetCurrentTableName(baseTableName string) string {
  53. return sm.strategy.GetCurrentTableName(baseTableName)
  54. }
  55. // GetQueryTableNames 获取查询需要的所有表名
  56. func (sm *ShardingManager) GetQueryTableNames(baseTableName string, start, end *time.Time) []string {
  57. if start == nil || end == nil {
  58. // 如果没有指定时间范围,默认查询最近3个月的表
  59. now := time.Now()
  60. defaultStart := now.AddDate(0, -2, 0) // 前2个月
  61. defaultEnd := now
  62. return sm.strategy.GetTableNamesByRange(baseTableName, defaultStart, defaultEnd)
  63. }
  64. return sm.strategy.GetTableNamesByRange(baseTableName, *start, *end)
  65. }
  66. // EnsureTableExists 确保表存在,不存在则创建
  67. func (sm *ShardingManager) EnsureTableExists(ctx context.Context, db *gorm.DB, tableName string, model interface{}) error {
  68. // 检查表是否存在
  69. if db.Migrator().HasTable(tableName) {
  70. return nil
  71. }
  72. sm.logger.Info(fmt.Sprintf("创建分表: %s", tableName))
  73. // 使用指定的表名创建表
  74. return db.Table(tableName).AutoMigrate(model)
  75. }
  76. // BuildUnionQuery 构建联合查询(用于跨表查询)
  77. func (sm *ShardingManager) BuildUnionQuery(ctx context.Context, db *gorm.DB, tableNames []string, baseQuery func(*gorm.DB) *gorm.DB) *gorm.DB {
  78. if len(tableNames) == 0 {
  79. return db
  80. }
  81. // 过滤存在的表
  82. var existingTables []string
  83. for _, tableName := range tableNames {
  84. if db.Migrator().HasTable(tableName) {
  85. existingTables = append(existingTables, tableName)
  86. }
  87. }
  88. if len(existingTables) == 0 {
  89. return db
  90. }
  91. // 如果只有一个表,直接查询该表
  92. if len(existingTables) == 1 {
  93. return baseQuery(db.Table(existingTables[0]))
  94. }
  95. // 多表联合查询
  96. var subQueries []string
  97. for _, tableName := range existingTables {
  98. subQueries = append(subQueries, fmt.Sprintf("SELECT * FROM %s", tableName))
  99. }
  100. unionSQL := strings.Join(subQueries, " UNION ALL ")
  101. return baseQuery(db.Table(fmt.Sprintf("(%s) as sharded_table", unionSQL)))
  102. }
  103. // GetTableNamesWithExistenceCheck 获取存在的表名列表(只返回分表,不包含原表)
  104. func (sm *ShardingManager) GetTableNamesWithExistenceCheck(db *gorm.DB, baseTableName string, start, end *time.Time) []string {
  105. allTableNames := sm.GetQueryTableNames(baseTableName, start, end)
  106. var existingTables []string
  107. for _, tableName := range allTableNames {
  108. if db.Migrator().HasTable(tableName) {
  109. existingTables = append(existingTables, tableName)
  110. }
  111. }
  112. // 还要检查动态分表(带序号的表)
  113. dynamicTables := sm.findDynamicTables(db, allTableNames)
  114. existingTables = append(existingTables, dynamicTables...)
  115. return existingTables
  116. }
  117. // findDynamicTables 查找动态分表(带序号的表)
  118. func (sm *ShardingManager) findDynamicTables(db *gorm.DB, baseTableNames []string) []string {
  119. var dynamicTables []string
  120. for _, baseTableName := range baseTableNames {
  121. // 查找类似 log_202408_01, log_202408_02 这样的表
  122. pattern := fmt.Sprintf("%s_\\d+", baseTableName)
  123. if tables := sm.findTablesByPattern(db, pattern); len(tables) > 0 {
  124. dynamicTables = append(dynamicTables, tables...)
  125. }
  126. }
  127. return dynamicTables
  128. }
  129. // findTablesByPattern 根据模式查找表
  130. func (sm *ShardingManager) findTablesByPattern(db *gorm.DB, pattern string) []string {
  131. var tables []string
  132. // 获取所有表名
  133. rows, err := db.Raw("SHOW TABLES").Rows()
  134. if err != nil {
  135. sm.logger.Error("获取表列表失败: " + err.Error())
  136. return tables
  137. }
  138. defer rows.Close()
  139. regex, err := regexp.Compile(pattern)
  140. if err != nil {
  141. sm.logger.Error("编译正则表达式失败: " + err.Error())
  142. return tables
  143. }
  144. for rows.Next() {
  145. var tableName string
  146. if err := rows.Scan(&tableName); err != nil {
  147. continue
  148. }
  149. if regex.MatchString(tableName) {
  150. tables = append(tables, tableName)
  151. }
  152. }
  153. return tables
  154. }
  155. // GetOptimalWriteTable 获取最优的写入表(考虑数据量阈值)
  156. func (sm *ShardingManager) GetOptimalWriteTable(ctx context.Context, db *gorm.DB, model TableModel, maxRows int64) (string, error) {
  157. baseTableName := model.GetBaseTableName()
  158. createdAt := model.GetCreatedAt()
  159. if createdAt.IsZero() {
  160. createdAt = time.Now()
  161. }
  162. // 先获取基础表名
  163. baseShardTableName := sm.strategy.GetTableName(baseTableName, createdAt)
  164. // 如果没有启用阈值检查,直接返回基础表名
  165. if sm.thresholdConfig == nil || !sm.thresholdConfig.Enabled {
  166. return baseShardTableName, nil
  167. }
  168. // 使用配置的maxRows,如果没有则使用默认值
  169. if maxRows <= 0 {
  170. maxRows = sm.thresholdConfig.MaxRows
  171. }
  172. // 检查当前表是否已达到阈值
  173. currentTable := baseShardTableName
  174. for {
  175. if !db.Migrator().HasTable(currentTable) {
  176. // 表不存在,可以使用
  177. return currentTable, nil
  178. }
  179. // 检查表的数据量
  180. var count int64
  181. err := db.Table(currentTable).Count(&count).Error
  182. if err != nil {
  183. sm.logger.Error(fmt.Sprintf("检查表 %s 数据量失败: %v", currentTable, err))
  184. return currentTable, nil // 出错时返回当前表
  185. }
  186. if count < maxRows {
  187. // 当前表还有空间
  188. return currentTable, nil
  189. }
  190. // 当前表已满,尝试下一个序号的表
  191. currentTable = sm.getNextSequenceTable(currentTable)
  192. sm.logger.Info(fmt.Sprintf("表 %s 已达到阈值 %d,尝试使用 %s", baseShardTableName, maxRows, currentTable))
  193. }
  194. }
  195. // getNextSequenceTable 获取下一个序号的表名
  196. func (sm *ShardingManager) getNextSequenceTable(currentTableName string) string {
  197. // 检查是否已经有序号
  198. re := regexp.MustCompile(`^(.+)_(\d+)$`)
  199. matches := re.FindStringSubmatch(currentTableName)
  200. if len(matches) == 3 {
  201. // 已有序号,递增
  202. baseName := matches[1]
  203. seq, _ := strconv.Atoi(matches[2])
  204. return fmt.Sprintf("%s_%02d", baseName, seq+1)
  205. } else {
  206. // 没有序号,添加序号01
  207. return fmt.Sprintf("%s_01", currentTableName)
  208. }
  209. }
  210. // CheckAndCreateNewTable 检查是否需要创建新表(基于时间周期)
  211. func (sm *ShardingManager) CheckAndCreateNewTable(ctx context.Context, db *gorm.DB, baseTableName string, modelExample interface{}) error {
  212. currentTime := time.Now()
  213. expectedTableName := sm.strategy.GetTableName(baseTableName, currentTime)
  214. // 检查当前期间的表是否存在
  215. if !db.Migrator().HasTable(expectedTableName) {
  216. sm.logger.Info(fmt.Sprintf("创建新周期分表: %s", expectedTableName))
  217. return sm.EnsureTableExists(ctx, db, expectedTableName, modelExample)
  218. }
  219. return nil
  220. }