repository.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. package repository
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/glebarez/sqlite"
  6. "github.com/go-nunu/nunu-layout-advanced/pkg/log"
  7. "github.com/go-nunu/nunu-layout-advanced/pkg/mongo"
  8. "github.com/go-nunu/nunu-layout-advanced/pkg/zapgorm2"
  9. "github.com/redis/go-redis/v9"
  10. "github.com/spf13/viper"
  11. "gorm.io/driver/mysql"
  12. "gorm.io/driver/postgres"
  13. "gorm.io/gorm"
  14. gormlogger "gorm.io/gorm/logger"
  15. "time"
  16. )
  // txContextKey is the private type of the context key under which the active
  // transaction handle is stored. Context keys are compared by type as well as
  // value, so a dedicated type guarantees that code passing the plain string
  // "TxKey" to context.WithValue or Context.Value can never collide with it.
  type txContextKey string

  // ctxTxKey is the context key that Transaction stores the *gorm.DB handle
  // under and that DB looks it up with.
  const ctxTxKey txContextKey = "TxKey"
  18. type Repository struct {
  19. db *gorm.DB
  20. //rdb *redis.Client
  21. mongodb *mongo.MongoDB
  22. logger *log.Logger
  23. }
  24. func NewRepository(
  25. logger *log.Logger,
  26. db *gorm.DB,
  27. // rdb *redis.Client,
  28. mongodb *mongo.MongoDB,
  29. ) *Repository {
  30. return &Repository{
  31. db: db,
  32. //rdb: rdb,
  33. mongodb: mongodb,
  34. logger: logger,
  35. }
  36. }
  // Transaction is implemented by types that can run a callback as one
  // transactional unit of work. *Repository is the implementation in this file.
  type Transaction interface {
  	// Transaction invokes fn and returns the resulting error. In the
  	// Repository implementation, fn receives a context that carries the
  	// transaction handle.
  	Transaction(ctx context.Context, fn func(ctx context.Context) error) error
  }
  40. func NewTransaction(r *Repository) Transaction {
  41. return r
  42. }
  43. // DB return tx
  44. // If you need to create a Transaction, you must call DB(ctx) and Transaction(ctx,fn)
  45. func (r *Repository) DB(ctx context.Context) *gorm.DB {
  46. v := ctx.Value(ctxTxKey)
  47. if v != nil {
  48. if tx, ok := v.(*gorm.DB); ok {
  49. return tx
  50. }
  51. }
  52. return r.db.WithContext(ctx)
  53. }
  54. func (r *Repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
  55. return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
  56. ctx = context.WithValue(ctx, ctxTxKey, tx)
  57. return fn(ctx)
  58. })
  59. }
  60. func NewDB(conf *viper.Viper, l *log.Logger) *gorm.DB {
  61. var (
  62. db *gorm.DB
  63. err error
  64. )
  65. driver := conf.GetString("data.db.user.driver")
  66. dsn := conf.GetString("data.db.user.dsn")
  67. // 读取日志级别配置
  68. logLevelStr := conf.GetString("data.db.user.logLevel")
  69. var logLevel gormlogger.LogLevel
  70. switch logLevelStr {
  71. case "silent":
  72. logLevel = gormlogger.Silent
  73. case "error":
  74. logLevel = gormlogger.Error
  75. case "warn":
  76. logLevel = gormlogger.Warn
  77. case "info":
  78. logLevel = gormlogger.Info
  79. default:
  80. // MySQL 默认只记录警告和错误
  81. if driver == "mysql" {
  82. logLevel = gormlogger.Warn
  83. } else {
  84. logLevel = gormlogger.Info
  85. }
  86. }
  87. logger := zapgorm2.New(l.Logger).LogMode(logLevel)
  88. // GORM doc: https://gorm.io/docs/connecting_to_the_database.html
  89. switch driver {
  90. case "mysql":
  91. db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
  92. Logger: logger,
  93. })
  94. case "postgres":
  95. db, err = gorm.Open(postgres.New(postgres.Config{
  96. DSN: dsn,
  97. PreferSimpleProtocol: true,
  98. }), &gorm.Config{
  99. Logger: logger,
  100. })
  101. case "sqlite":
  102. db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{
  103. Logger: logger,
  104. })
  105. default:
  106. panic("unknown db driver")
  107. }
  108. if err != nil {
  109. panic(err)
  110. }
  111. // Connection Pool config
  112. sqlDB, err := db.DB()
  113. if err != nil {
  114. panic(err)
  115. }
  116. sqlDB.SetMaxIdleConns(10)
  117. sqlDB.SetMaxOpenConns(100)
  118. sqlDB.SetConnMaxLifetime(time.Hour)
  119. return db
  120. }
  121. func NewRedis(conf *viper.Viper) *redis.Client {
  122. rdb := redis.NewClient(&redis.Options{
  123. Addr: conf.GetString("data.redis.addr"),
  124. Password: conf.GetString("data.redis.password"),
  125. DB: conf.GetInt("data.redis.db"),
  126. })
  127. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  128. defer cancel()
  129. _, err := rdb.Ping(ctx).Result()
  130. if err != nil {
  131. panic(fmt.Sprintf("redis error: %s", err.Error()))
  132. }
  133. return rdb
  134. }
  135. func NewMongoDB(conf *viper.Viper) *mongo.MongoDB {
  136. config := &mongo.Config{
  137. URI: conf.GetString("data.mongodb.uri"),
  138. Database: conf.GetString("data.mongodb.database"),
  139. Timeout: conf.GetDuration("data.mongodb.timeout"),
  140. MaxPoolSize: conf.GetUint64("data.mongodb.max_pool_size"),
  141. }
  142. mongoDB, err := mongo.New(config)
  143. if err != nil {
  144. panic(fmt.Sprintf("mongodb error: %s", err.Error()))
  145. }
  146. return mongoDB
  147. }