repository.go 6.2 KB


  1. package repository
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/glebarez/sqlite"
  6. "github.com/go-nunu/nunu-layout-advanced/pkg/log"
  7. "github.com/go-nunu/nunu-layout-advanced/pkg/zapgorm2"
  8. "github.com/qiniu/qmgo"
  9. "github.com/redis/go-redis/v9"
  10. "github.com/spf13/viper"
  11. "gorm.io/driver/mysql"
  12. "gorm.io/driver/postgres"
  13. "gorm.io/gorm"
  14. gormlogger "gorm.io/gorm/logger"
  15. "time"
  16. )
  // contextKey is an unexported key type for values this package stores in a
  // context.Context. Unlike a plain string, a dedicated type cannot collide
  // with keys chosen by other packages (see the context.WithValue docs).
  type contextKey string

  // ctxTxKey is the context key under which Repository.Transaction stores the
  // active *gorm.DB transaction and from which Repository.DB reads it back.
  const ctxTxKey contextKey = "TxKey"
  18. type Repository struct {
  19. db *gorm.DB // 主数据库连接
  20. //dbSecond *gorm.DB // 第二个数据库连接
  21. //rdb *redis.Client
  22. mongoClient *qmgo.Client
  23. mongoDB *qmgo.Database
  24. logger *log.Logger
  25. }
  26. func NewRepository(
  27. logger *log.Logger,
  28. db *gorm.DB,
  29. //dbSecond *gorm.DB,
  30. // rdb *redis.Client,
  31. mongoClient *qmgo.Client,
  32. mongoDB *qmgo.Database,
  33. ) *Repository {
  34. return &Repository{
  35. db: db,
  36. //dbSecond: dbSecond,
  37. //rdb: rdb,
  38. mongoClient: mongoClient,
  39. mongoDB: mongoDB,
  40. logger: logger,
  41. }
  42. }
  43. type Transaction interface {
  44. Transaction(ctx context.Context, fn func(ctx context.Context) error) error
  45. }
  46. func NewTransaction(r *Repository) Transaction {
  47. return r
  48. }
  // DB returns the *gorm.DB that repository code should use for ctx.
  //
  // If ctx was produced by Repository.Transaction it carries the active
  // transaction; that transaction is returned re-bound to ctx, so queries honor
  // the deadline/cancellation of the context actually passed in (which may be
  // derived from the one the transaction was started with). Otherwise a session
  // of the primary connection bound to ctx is returned.
  //
  // Code that must take part in transactions has to obtain its handle here and
  // run inside Transaction(ctx, fn).
  func (r *Repository) DB(ctx context.Context) *gorm.DB {
  	// A failed type assertion (no value, or a value of another type) yields ok == false.
  	if tx, ok := ctx.Value(ctxTxKey).(*gorm.DB); ok && tx != nil {
  		// WithContext starts a new session that keeps the transaction's
  		// connection but swaps in ctx.
  		return tx.WithContext(ctx)
  	}
  	return r.db.WithContext(ctx)
  }
  60. // DBSecond returns the second database connection
  61. // Note: Transactions are currently only supported on the primary database
  62. //func (r *Repository) DBSecond(ctx context.Context) *gorm.DB {
  63. // return r.dbSecond.WithContext(ctx)
  64. //}
  // Transaction runs fn inside a database transaction on the primary connection.
  //
  // fn receives a context carrying the transaction, so repository methods that
  // obtain their handle through DB(ctx) execute inside it. Following
  // gorm.DB.Transaction semantics, the transaction is committed when fn returns
  // nil and rolled back when it returns an error; that error is returned.
  //
  // The transaction is started from r.DB(ctx). When ctx already carries a
  // transaction (a nested call), GORM therefore runs fn inside the outer
  // transaction using a savepoint, unless nested transactions are disabled in
  // the GORM config. It no longer opens a second, independent transaction on
  // another connection.
  func (r *Repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
  	return r.DB(ctx).Transaction(func(tx *gorm.DB) error {
  		// Derive a new context instead of reassigning the captured ctx.
  		return fn(context.WithValue(ctx, ctxTxKey, tx))
  	})
  }
  71. func NewDB(conf *viper.Viper, l *log.Logger) *gorm.DB {
  72. var (
  73. db *gorm.DB
  74. err error
  75. )
  76. driver := conf.GetString("data.db.user.driver")
  77. dsn := conf.GetString("data.db.user.dsn")
  78. // 读取日志级别配置
  79. logLevelStr := conf.GetString("data.db.user.logLevel")
  80. var logLevel gormlogger.LogLevel
  81. switch logLevelStr {
  82. case "silent":
  83. logLevel = gormlogger.Silent
  84. case "error":
  85. logLevel = gormlogger.Error
  86. case "warn":
  87. logLevel = gormlogger.Warn
  88. case "info":
  89. logLevel = gormlogger.Info
  90. default:
  91. // MySQL 默认只记录警告和错误
  92. if driver == "mysql" {
  93. logLevel = gormlogger.Warn
  94. } else {
  95. logLevel = gormlogger.Info
  96. }
  97. }
  98. logger := zapgorm2.New(l.Logger).LogMode(logLevel)
  99. // GORM doc: https://gorm.io/docs/connecting_to_the_database.html
  100. switch driver {
  101. case "mysql":
  102. db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
  103. Logger: logger,
  104. })
  105. case "postgres":
  106. db, err = gorm.Open(postgres.New(postgres.Config{
  107. DSN: dsn,
  108. PreferSimpleProtocol: true,
  109. }), &gorm.Config{
  110. Logger: logger,
  111. })
  112. case "sqlite":
  113. db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{
  114. Logger: logger,
  115. })
  116. default:
  117. panic("unknown db driver")
  118. }
  119. if err != nil {
  120. panic(err)
  121. }
  122. // Connection Pool config
  123. sqlDB, err := db.DB()
  124. if err != nil {
  125. panic(err)
  126. }
  127. sqlDB.SetMaxIdleConns(10)
  128. sqlDB.SetMaxOpenConns(100)
  129. sqlDB.SetConnMaxLifetime(time.Hour)
  130. return db
  131. }
  132. // NewDBSecond 初始化第二个数据库连接
  133. func NewDBSecond(conf *viper.Viper, l *log.Logger) *gorm.DB {
  134. var (
  135. db *gorm.DB
  136. err error
  137. )
  138. // 从second配置项读取第二个数据库配置
  139. driver := conf.GetString("data.db.second.driver")
  140. dsn := conf.GetString("data.db.second.dsn")
  141. // 如果第二个数据库没有配置,返回nil
  142. if dsn == "" {
  143. l.Warn("第二个数据库配置不存在或DSN为空")
  144. return nil
  145. }
  146. // 读取日志级别配置
  147. logLevelStr := conf.GetString("data.db.second.logLevel")
  148. var logLevel gormlogger.LogLevel
  149. switch logLevelStr {
  150. case "silent":
  151. logLevel = gormlogger.Silent
  152. case "error":
  153. logLevel = gormlogger.Error
  154. case "warn":
  155. logLevel = gormlogger.Warn
  156. case "info":
  157. logLevel = gormlogger.Info
  158. default:
  159. // MySQL 默认只记录警告和错误
  160. if driver == "mysql" {
  161. logLevel = gormlogger.Warn
  162. } else {
  163. logLevel = gormlogger.Info
  164. }
  165. }
  166. logger := zapgorm2.New(l.Logger).LogMode(logLevel)
  167. // 初始化第二个数据库连接
  168. switch driver {
  169. case "mysql":
  170. db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
  171. Logger: logger,
  172. })
  173. case "postgres":
  174. db, err = gorm.Open(postgres.New(postgres.Config{
  175. DSN: dsn,
  176. PreferSimpleProtocol: true,
  177. }), &gorm.Config{
  178. Logger: logger,
  179. })
  180. case "sqlite":
  181. db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{
  182. Logger: logger,
  183. })
  184. default:
  185. panic("unknown db driver for second database")
  186. }
  187. if err != nil {
  188. panic("连接第二个数据库失败: " + err.Error())
  189. }
  190. // 配置连接池
  191. sqlDB, err := db.DB()
  192. if err != nil {
  193. panic(err)
  194. }
  195. sqlDB.SetMaxIdleConns(10)
  196. sqlDB.SetMaxOpenConns(100)
  197. sqlDB.SetConnMaxLifetime(time.Hour)
  198. return db
  199. }
  // NewRedis builds a Redis client from data.redis.addr, data.redis.password
  // and data.redis.db, then checks connectivity with a PING bounded by a
  // 5-second timeout. It panics if the PING fails.
  func NewRedis(conf *viper.Viper) *redis.Client {
  	opts := &redis.Options{
  		Addr:     conf.GetString("data.redis.addr"),
  		Password: conf.GetString("data.redis.password"),
  		DB:       conf.GetInt("data.redis.db"),
  	}
  	client := redis.NewClient(opts)

  	pingCtx, cancelPing := context.WithTimeout(context.Background(), 5*time.Second)
  	defer cancelPing()
  	if err := client.Ping(pingCtx).Err(); err != nil {
  		panic(fmt.Sprintf("redis error: %s", err.Error()))
  	}
  	return client
  }
  // NewMongoClient connects to MongoDB using the data.mongodb.* config keys and
  // panics if qmgo.NewClient fails.
  //
  //   - data.mongodb.uri: connection URI passed to qmgo.
  //   - data.mongodb.timeout: deadline for qmgo.NewClient; 10s when unset or
  //     not positive.
  //   - data.mongodb.max_pool_size: maximum connection-pool size. When unset
  //     or 0 the option is left nil, so the qmgo/driver default applies.
  func NewMongoClient(conf *viper.Viper) *qmgo.Client {
  	timeout := conf.GetDuration("data.mongodb.timeout")
  	if timeout <= 0 {
  		timeout = 10 * time.Second
  	}

  	clientOpts := &qmgo.Config{
  		Uri: conf.GetString("data.mongodb.uri"),
  	}
  	// Only override the pool size when it is configured. An explicit 0 means
  	// "unbounded" to the MongoDB Go driver, not "use the default".
  	if maxPoolSize := conf.GetUint64("data.mongodb.max_pool_size"); maxPoolSize > 0 {
  		clientOpts.MaxPoolSize = &maxPoolSize
  	}

  	ctx, cancel := context.WithTimeout(context.Background(), timeout)
  	defer cancel()
  	client, err := qmgo.NewClient(ctx, clientOpts)
  	if err != nil {
  		panic(fmt.Sprintf("连接MongoDB失败: %s", err.Error()))
  	}
  	return client
  }
  234. func NewMongoDB(client *qmgo.Client, conf *viper.Viper) *qmgo.Database {
  235. databaseName := conf.GetString("data.mongodb.database")
  236. if databaseName == "" {
  237. panic("MongoDB数据库名不能为空")
  238. }
  239. return client.Database(databaseName)
  240. }