repository.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. package repository
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/glebarez/sqlite"
  6. "github.com/go-nunu/nunu-layout-advanced/pkg/log"
  7. "github.com/go-nunu/nunu-layout-advanced/pkg/zapgorm2"
  8. "github.com/redis/go-redis/v9"
  9. "github.com/spf13/viper"
  10. "gorm.io/driver/mysql"
  11. "gorm.io/driver/postgres"
  12. "gorm.io/gorm"
  13. gormlogger "gorm.io/gorm/logger"
  14. "time"
  15. )
  16. const ctxTxKey = "TxKey"
// Repository holds the dependencies shared by this package's data-access
// code: the root GORM handle and the application logger.
type Repository struct {
	// db is the root connection handle. Query through DB(ctx) instead of
	// using this field directly, so that work joins any transaction carried
	// by ctx.
	db *gorm.DB
	// Redis client field, currently disabled.
	//rdb *redis.Client
	logger *log.Logger
}
  22. func NewRepository(
  23. logger *log.Logger,
  24. db *gorm.DB,
  25. // rdb *redis.Client,
  26. ) *Repository {
  27. return &Repository{
  28. db: db,
  29. //rdb: rdb,
  30. logger: logger,
  31. }
  32. }
// Transaction runs a unit of work inside a database transaction.
//
// *Repository implements it; NewTransaction returns a *Repository typed as
// a Transaction. fn is called with a context that carries the transaction,
// and the error fn returns is passed back to the caller.
type Transaction interface {
	Transaction(ctx context.Context, fn func(ctx context.Context) error) error
}
  36. func NewTransaction(r *Repository) Transaction {
  37. return r
  38. }
  39. // DB return tx
  40. // If you need to create a Transaction, you must call DB(ctx) and Transaction(ctx,fn)
  41. func (r *Repository) DB(ctx context.Context) *gorm.DB {
  42. v := ctx.Value(ctxTxKey)
  43. if v != nil {
  44. if tx, ok := v.(*gorm.DB); ok {
  45. return tx
  46. }
  47. }
  48. return r.db.WithContext(ctx)
  49. }
  50. func (r *Repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
  51. return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
  52. ctx = context.WithValue(ctx, ctxTxKey, tx)
  53. return fn(ctx)
  54. })
  55. }
  56. func NewDB(conf *viper.Viper, l *log.Logger) *gorm.DB {
  57. var (
  58. db *gorm.DB
  59. err error
  60. )
  61. driver := conf.GetString("data.db.user.driver")
  62. dsn := conf.GetString("data.db.user.dsn")
  63. // 读取日志级别配置
  64. logLevelStr := conf.GetString("data.db.user.logLevel")
  65. var logLevel gormlogger.LogLevel
  66. switch logLevelStr {
  67. case "silent":
  68. logLevel = gormlogger.Silent
  69. case "error":
  70. logLevel = gormlogger.Error
  71. case "warn":
  72. logLevel = gormlogger.Warn
  73. case "info":
  74. logLevel = gormlogger.Info
  75. default:
  76. // MySQL 默认只记录警告和错误
  77. if driver == "mysql" {
  78. logLevel = gormlogger.Warn
  79. } else {
  80. logLevel = gormlogger.Info
  81. }
  82. }
  83. logger := zapgorm2.New(l.Logger).LogMode(logLevel)
  84. // GORM doc: https://gorm.io/docs/connecting_to_the_database.html
  85. switch driver {
  86. case "mysql":
  87. db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
  88. Logger: logger,
  89. })
  90. case "postgres":
  91. db, err = gorm.Open(postgres.New(postgres.Config{
  92. DSN: dsn,
  93. PreferSimpleProtocol: true,
  94. }), &gorm.Config{
  95. Logger: logger,
  96. })
  97. case "sqlite":
  98. db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{
  99. Logger: logger,
  100. })
  101. default:
  102. panic("unknown db driver")
  103. }
  104. if err != nil {
  105. panic(err)
  106. }
  107. // Connection Pool config
  108. sqlDB, err := db.DB()
  109. if err != nil {
  110. panic(err)
  111. }
  112. sqlDB.SetMaxIdleConns(10)
  113. sqlDB.SetMaxOpenConns(100)
  114. sqlDB.SetConnMaxLifetime(time.Hour)
  115. return db
  116. }
  117. func NewRedis(conf *viper.Viper) *redis.Client {
  118. rdb := redis.NewClient(&redis.Options{
  119. Addr: conf.GetString("data.redis.addr"),
  120. Password: conf.GetString("data.redis.password"),
  121. DB: conf.GetInt("data.redis.db"),
  122. })
  123. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  124. defer cancel()
  125. _, err := rdb.Ping(ctx).Result()
  126. if err != nil {
  127. panic(fmt.Sprintf("redis error: %s", err.Error()))
  128. }
  129. return rdb
  130. }