repository.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. package repository
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/glebarez/sqlite"
  6. "github.com/go-nunu/nunu-layout-advanced/pkg/log"
  7. "github.com/go-nunu/nunu-layout-advanced/pkg/zapgorm2"
  8. "github.com/qiniu/qmgo"
  9. "github.com/redis/go-redis/v9"
  10. "github.com/spf13/viper"
  11. "gorm.io/driver/mysql"
  12. "gorm.io/driver/postgres"
  13. "gorm.io/gorm"
  14. gormlogger "gorm.io/gorm/logger"
  15. "time"
  16. )
  17. const ctxTxKey = "TxKey"
// Repository groups the storage handles and the logger that NewRepository
// receives: a GORM database handle, a qmgo (MongoDB) client and database, and
// the application logger.
type Repository struct {
	// db is the base GORM handle; DB and Transaction derive per-context
	// sessions from it.
	db *gorm.DB
	//rdb *redis.Client
	// mongoClient is the qmgo client handle passed to NewRepository.
	mongoClient *qmgo.Client
	// mongoDB is the qmgo database handle passed to NewRepository.
	mongoDB *qmgo.Database
	// logger is the application logger passed to NewRepository.
	logger *log.Logger
}
  25. func NewRepository(
  26. logger *log.Logger,
  27. db *gorm.DB,
  28. // rdb *redis.Client,
  29. mongoClient *qmgo.Client,
  30. mongoDB *qmgo.Database,
  31. ) *Repository {
  32. return &Repository{
  33. db: db,
  34. //rdb: rdb,
  35. mongoClient: mongoClient,
  36. mongoDB: mongoDB,
  37. logger: logger,
  38. }
  39. }
// Transaction is the transaction-running capability exposed to callers.
// In this file it is implemented by *Repository (see NewTransaction).
type Transaction interface {
	// Transaction invokes fn with a context and returns an error. The
	// *Repository implementation passes fn a context that carries the
	// transaction handle, so fn should forward that context to its data-access
	// calls.
	Transaction(ctx context.Context, fn func(ctx context.Context) error) error
}
  43. func NewTransaction(r *Repository) Transaction {
  44. return r
  45. }
  46. // DB return tx
  47. // If you need to create a Transaction, you must call DB(ctx) and Transaction(ctx,fn)
  48. func (r *Repository) DB(ctx context.Context) *gorm.DB {
  49. v := ctx.Value(ctxTxKey)
  50. if v != nil {
  51. if tx, ok := v.(*gorm.DB); ok {
  52. return tx
  53. }
  54. }
  55. return r.db.WithContext(ctx)
  56. }
  57. func (r *Repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
  58. return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
  59. ctx = context.WithValue(ctx, ctxTxKey, tx)
  60. return fn(ctx)
  61. })
  62. }
  63. func NewDB(conf *viper.Viper, l *log.Logger) *gorm.DB {
  64. var (
  65. db *gorm.DB
  66. err error
  67. )
  68. driver := conf.GetString("data.db.user.driver")
  69. dsn := conf.GetString("data.db.user.dsn")
  70. // 读取日志级别配置
  71. logLevelStr := conf.GetString("data.db.user.logLevel")
  72. var logLevel gormlogger.LogLevel
  73. switch logLevelStr {
  74. case "silent":
  75. logLevel = gormlogger.Silent
  76. case "error":
  77. logLevel = gormlogger.Error
  78. case "warn":
  79. logLevel = gormlogger.Warn
  80. case "info":
  81. logLevel = gormlogger.Info
  82. default:
  83. // MySQL 默认只记录警告和错误
  84. if driver == "mysql" {
  85. logLevel = gormlogger.Warn
  86. } else {
  87. logLevel = gormlogger.Info
  88. }
  89. }
  90. logger := zapgorm2.New(l.Logger).LogMode(logLevel)
  91. // GORM doc: https://gorm.io/docs/connecting_to_the_database.html
  92. switch driver {
  93. case "mysql":
  94. db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
  95. Logger: logger,
  96. })
  97. case "postgres":
  98. db, err = gorm.Open(postgres.New(postgres.Config{
  99. DSN: dsn,
  100. PreferSimpleProtocol: true,
  101. }), &gorm.Config{
  102. Logger: logger,
  103. })
  104. case "sqlite":
  105. db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{
  106. Logger: logger,
  107. })
  108. default:
  109. panic("unknown db driver")
  110. }
  111. if err != nil {
  112. panic(err)
  113. }
  114. // Connection Pool config
  115. sqlDB, err := db.DB()
  116. if err != nil {
  117. panic(err)
  118. }
  119. sqlDB.SetMaxIdleConns(10)
  120. sqlDB.SetMaxOpenConns(100)
  121. sqlDB.SetConnMaxLifetime(time.Hour)
  122. return db
  123. }
  124. func NewRedis(conf *viper.Viper) *redis.Client {
  125. rdb := redis.NewClient(&redis.Options{
  126. Addr: conf.GetString("data.redis.addr"),
  127. Password: conf.GetString("data.redis.password"),
  128. DB: conf.GetInt("data.redis.db"),
  129. })
  130. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  131. defer cancel()
  132. _, err := rdb.Ping(ctx).Result()
  133. if err != nil {
  134. panic(fmt.Sprintf("redis error: %s", err.Error()))
  135. }
  136. return rdb
  137. }
  138. func NewMongoClient(conf *viper.Viper) *qmgo.Client {
  139. timeout := conf.GetDuration("data.mongodb.timeout")
  140. if timeout == 0 {
  141. timeout = 10 * time.Second
  142. }
  143. maxPoolSize := conf.GetUint64("data.mongodb.max_pool_size")
  144. ctx, cancel := context.WithTimeout(context.Background(), timeout)
  145. defer cancel()
  146. // 创建连接配置
  147. clientOpts := &qmgo.Config{
  148. Uri: conf.GetString("data.mongodb.uri"),
  149. MaxPoolSize: &maxPoolSize,
  150. }
  151. // 连接到MongoDB
  152. client, err := qmgo.NewClient(ctx, clientOpts)
  153. if err != nil {
  154. panic(fmt.Sprintf("连接MongoDB失败: %s", err.Error()))
  155. }
  156. return client
  157. }
  158. func NewMongoDB(client *qmgo.Client, conf *viper.Viper) *qmgo.Database {
  159. databaseName := conf.GetString("data.mongodb.database")
  160. if databaseName == "" {
  161. panic("MongoDB数据库名不能为空")
  162. }
  163. return client.Database(databaseName)
  164. }