manager.go 8.2 KB


  1. package sharding
  2. import (
  3. "context"
  4. "fmt"
  5. "regexp"
  6. "strconv"
  7. "strings"
  8. "time"
  9. "github.com/go-nunu/nunu-layout-advanced/pkg/log"
  10. "gorm.io/gorm"
  11. )
  // TableModel is implemented by models whose rows are stored in time-based
// shard tables. GetWriteTableName and GetOptimalWriteTable use it to derive
// the shard a record belongs to.
type TableModel interface {
	// GetBaseTableName returns the unsharded table name that the sharding
	// strategy turns into a concrete shard name.
	GetBaseTableName() string

	// GetCreatedAt returns the record's creation time, which selects the
	// shard. The write helpers in this file substitute time.Now() for a
	// zero value.
	GetCreatedAt() time.Time
}
  // ThresholdConfig configures row-count based overflow of shard tables.
// When Enabled, GetOptimalWriteTable moves writes to numbered overflow
// tables once a shard reaches its row limit.
type ThresholdConfig struct {
	// Enabled switches threshold checking on.
	Enabled bool

	// MaxRows is the default per-table row limit.
	MaxRows int64

	// CheckInterval is not read by the code in this file; presumably it
	// schedules periodic threshold checks elsewhere — TODO confirm.
	CheckInterval time.Duration

	// TableThresholds holds per-table row limits (keyed by the name passed
	// to GetMaxRowsForTable) that take precedence over MaxRows there.
	TableThresholds map[string]int64
}
  25. // ShardingManager 分表管理器
  26. type ShardingManager struct {
  27. strategy ShardingStrategy
  28. logger *log.Logger
  29. thresholdConfig *ThresholdConfig
  30. }
  31. func NewShardingManager(strategy ShardingStrategy, logger *log.Logger) *ShardingManager {
  32. return &ShardingManager{
  33. strategy: strategy,
  34. logger: logger,
  35. }
  36. }
  37. func NewShardingManagerWithThreshold(strategy ShardingStrategy, logger *log.Logger, thresholdConfig *ThresholdConfig) *ShardingManager {
  38. return &ShardingManager{
  39. strategy: strategy,
  40. logger: logger,
  41. thresholdConfig: thresholdConfig,
  42. }
  43. }
  44. // GetWriteTableName 获取写入表名(基于记录的创建时间)
  45. func (sm *ShardingManager) GetWriteTableName(model TableModel) string {
  46. baseTableName := model.GetBaseTableName()
  47. createdAt := model.GetCreatedAt()
  48. if createdAt.IsZero() {
  49. createdAt = time.Now()
  50. }
  51. return sm.strategy.GetTableName(baseTableName, createdAt)
  52. }
  53. // GetCurrentTableName 获取当前表名(用于写入新记录)
  54. func (sm *ShardingManager) GetCurrentTableName(baseTableName string) string {
  55. return sm.strategy.GetCurrentTableName(baseTableName)
  56. }
  57. // GetQueryTableNames 获取查询需要的所有表名
  58. func (sm *ShardingManager) GetQueryTableNames(baseTableName string, start, end *time.Time) []string {
  59. if start == nil || end == nil {
  60. // 如果没有指定时间范围,默认查询最近3个月的表
  61. now := time.Now()
  62. defaultStart := now.AddDate(0, -2, 0) // 前2个月
  63. defaultEnd := now
  64. return sm.strategy.GetTableNamesByRange(baseTableName, defaultStart, defaultEnd)
  65. }
  66. return sm.strategy.GetTableNamesByRange(baseTableName, *start, *end)
  67. }
  68. // EnsureTableExists 确保表存在,不存在则创建
  69. func (sm *ShardingManager) EnsureTableExists(ctx context.Context, db *gorm.DB, tableName string, model interface{}) error {
  70. // 检查表是否存在
  71. if db.Migrator().HasTable(tableName) {
  72. return nil
  73. }
  74. sm.logger.Info(fmt.Sprintf("创建分表: %s", tableName))
  75. // 使用指定的表名创建表
  76. return db.Table(tableName).AutoMigrate(model)
  77. }
  78. // BuildUnionQuery 构建联合查询(用于跨表查询)
  79. func (sm *ShardingManager) BuildUnionQuery(ctx context.Context, db *gorm.DB, tableNames []string, baseQuery func(*gorm.DB) *gorm.DB) *gorm.DB {
  80. if len(tableNames) == 0 {
  81. return db
  82. }
  83. // 过滤存在的表
  84. var existingTables []string
  85. for _, tableName := range tableNames {
  86. if db.Migrator().HasTable(tableName) {
  87. existingTables = append(existingTables, tableName)
  88. }
  89. }
  90. if len(existingTables) == 0 {
  91. return db
  92. }
  93. // 如果只有一个表,直接查询该表
  94. if len(existingTables) == 1 {
  95. return baseQuery(db.Table(existingTables[0]))
  96. }
  97. // 多表联合查询
  98. var subQueries []string
  99. for _, tableName := range existingTables {
  100. subQueries = append(subQueries, fmt.Sprintf("SELECT * FROM %s", tableName))
  101. }
  102. unionSQL := strings.Join(subQueries, " UNION ALL ")
  103. return baseQuery(db.Table(fmt.Sprintf("(%s) as sharded_table", unionSQL)))
  104. }
  105. // GetTableNamesWithExistenceCheck 获取存在的表名列表(只返回分表,不包含原表)
  106. func (sm *ShardingManager) GetTableNamesWithExistenceCheck(db *gorm.DB, baseTableName string, start, end *time.Time) []string {
  107. allTableNames := sm.GetQueryTableNames(baseTableName, start, end)
  108. var existingTables []string
  109. for _, tableName := range allTableNames {
  110. if db.Migrator().HasTable(tableName) {
  111. existingTables = append(existingTables, tableName)
  112. }
  113. }
  114. // 还要检查动态分表(带序号的表)
  115. dynamicTables := sm.findDynamicTables(db, allTableNames)
  116. existingTables = append(existingTables, dynamicTables...)
  117. return existingTables
  118. }
  119. // findDynamicTables 查找动态分表(带序号的表)
  120. func (sm *ShardingManager) findDynamicTables(db *gorm.DB, baseTableNames []string) []string {
  121. var dynamicTables []string
  122. for _, baseTableName := range baseTableNames {
  123. // 查找类似 log_202408_01, log_202408_02 这样的表
  124. pattern := fmt.Sprintf("%s_\\d+", baseTableName)
  125. if tables := sm.findTablesByPattern(db, pattern); len(tables) > 0 {
  126. dynamicTables = append(dynamicTables, tables...)
  127. }
  128. }
  129. return dynamicTables
  130. }
  131. // findTablesByPattern 根据模式查找表
  132. func (sm *ShardingManager) findTablesByPattern(db *gorm.DB, pattern string) []string {
  133. var tables []string
  134. // 获取所有表名
  135. rows, err := db.Raw("SHOW TABLES").Rows()
  136. if err != nil {
  137. sm.logger.Error("获取表列表失败: " + err.Error())
  138. return tables
  139. }
  140. defer rows.Close()
  141. regex, err := regexp.Compile(pattern)
  142. if err != nil {
  143. sm.logger.Error("编译正则表达式失败: " + err.Error())
  144. return tables
  145. }
  146. for rows.Next() {
  147. var tableName string
  148. if err := rows.Scan(&tableName); err != nil {
  149. continue
  150. }
  151. if regex.MatchString(tableName) {
  152. tables = append(tables, tableName)
  153. }
  154. }
  155. return tables
  156. }
  157. // GetOptimalWriteTable 获取最优的写入表(考虑数据量阈值)
  158. func (sm *ShardingManager) GetOptimalWriteTable(ctx context.Context, db *gorm.DB, model TableModel, maxRows int64) (string, error) {
  159. baseTableName := model.GetBaseTableName()
  160. createdAt := model.GetCreatedAt()
  161. if createdAt.IsZero() {
  162. createdAt = time.Now()
  163. }
  164. // 先获取基础表名
  165. baseShardTableName := sm.strategy.GetTableName(baseTableName, createdAt)
  166. // 如果没有启用阈值检查,直接返回基础表名
  167. if sm.thresholdConfig == nil || !sm.thresholdConfig.Enabled {
  168. return baseShardTableName, nil
  169. }
  170. // 使用配置的maxRows,如果没有则使用默认值
  171. if maxRows <= 0 {
  172. maxRows = sm.thresholdConfig.MaxRows
  173. }
  174. // 检查当前表是否已达到阈值
  175. currentTable := baseShardTableName
  176. for {
  177. if !db.Migrator().HasTable(currentTable) {
  178. // 表不存在,可以使用
  179. return currentTable, nil
  180. }
  181. // 检查表的数据量
  182. var count int64
  183. err := db.Table(currentTable).Count(&count).Error
  184. if err != nil {
  185. sm.logger.Error(fmt.Sprintf("检查表 %s 数据量失败: %v", currentTable, err))
  186. return currentTable, nil // 出错时返回当前表
  187. }
  188. if count < maxRows {
  189. // 当前表还有空间
  190. return currentTable, nil
  191. }
  192. // 当前表已满,尝试下一个序号的表
  193. currentTable = sm.getNextSequenceTable(currentTable)
  194. sm.logger.Info(fmt.Sprintf("表 %s 已达到阈值 %d,尝试使用 %s", baseShardTableName, maxRows, currentTable))
  195. }
  196. }
  197. // getNextSequenceTable 获取下一个序号的表名
  198. func (sm *ShardingManager) getNextSequenceTable(currentTableName string) string {
  199. // 检查是否已经有序号
  200. re := regexp.MustCompile(`^(.+)_(\d+)$`)
  201. matches := re.FindStringSubmatch(currentTableName)
  202. if len(matches) == 3 {
  203. // 已有序号,递增
  204. baseName := matches[1]
  205. seq, _ := strconv.Atoi(matches[2])
  206. return fmt.Sprintf("%s_%02d", baseName, seq+1)
  207. } else {
  208. // 没有序号,添加序号01
  209. return fmt.Sprintf("%s_01", currentTableName)
  210. }
  211. }
  212. // CheckAndCreateNewTable 检查是否需要创建新表(基于时间周期)
  213. func (sm *ShardingManager) CheckAndCreateNewTable(ctx context.Context, db *gorm.DB, baseTableName string, modelExample interface{}) error {
  214. currentTime := time.Now()
  215. expectedTableName := sm.strategy.GetTableName(baseTableName, currentTime)
  216. // 检查当前期间的表是否存在
  217. if !db.Migrator().HasTable(expectedTableName) {
  218. sm.logger.Info(fmt.Sprintf("创建新周期分表: %s", expectedTableName))
  219. return sm.EnsureTableExists(ctx, db, expectedTableName, modelExample)
  220. }
  221. return nil
  222. }
  223. // GetMaxRowsForTable 获取指定表的最大行数配置
  224. func (sm *ShardingManager) GetMaxRowsForTable(tableName string) int64 {
  225. // 优先使用表级配置
  226. if sm.thresholdConfig != nil && sm.thresholdConfig.TableThresholds != nil {
  227. if maxRows, exists := sm.thresholdConfig.TableThresholds[tableName]; exists {
  228. return maxRows
  229. }
  230. }
  231. // 使用默认配置
  232. if sm.thresholdConfig != nil {
  233. return sm.thresholdConfig.MaxRows
  234. }
  235. // 最终默认值
  236. return 3000000
  237. }