repository.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. package repository
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/glebarez/sqlite"
  6. "github.com/go-nunu/nunu-layout-advanced/pkg/log"
  7. "github.com/go-nunu/nunu-layout-advanced/pkg/zapgorm2"
  8. "github.com/redis/go-redis/v9"
  9. "github.com/spf13/viper"
  10. "gorm.io/driver/mysql"
  11. "gorm.io/driver/postgres"
  12. "gorm.io/gorm"
  13. "time"
  14. )
  15. const ctxTxKey = "TxKey"
  16. type Repository struct {
  17. db *gorm.DB
  18. //rdb *redis.Client
  19. logger *log.Logger
  20. }
  21. func NewRepository(
  22. logger *log.Logger,
  23. db *gorm.DB,
  24. // rdb *redis.Client,
  25. ) *Repository {
  26. return &Repository{
  27. db: db,
  28. //rdb: rdb,
  29. logger: logger,
  30. }
  31. }
  32. type Transaction interface {
  33. Transaction(ctx context.Context, fn func(ctx context.Context) error) error
  34. }
  35. func NewTransaction(r *Repository) Transaction {
  36. return r
  37. }
  38. // DB return tx
  39. // If you need to create a Transaction, you must call DB(ctx) and Transaction(ctx,fn)
  40. func (r *Repository) DB(ctx context.Context) *gorm.DB {
  41. v := ctx.Value(ctxTxKey)
  42. if v != nil {
  43. if tx, ok := v.(*gorm.DB); ok {
  44. return tx
  45. }
  46. }
  47. return r.db.WithContext(ctx)
  48. }
  49. func (r *Repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
  50. return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
  51. ctx = context.WithValue(ctx, ctxTxKey, tx)
  52. return fn(ctx)
  53. })
  54. }
  55. func NewDB(conf *viper.Viper, l *log.Logger) *gorm.DB {
  56. var (
  57. db *gorm.DB
  58. err error
  59. )
  60. logger := zapgorm2.New(l.Logger)
  61. driver := conf.GetString("data.db.user.driver")
  62. dsn := conf.GetString("data.db.user.dsn")
  63. // GORM doc: https://gorm.io/docs/connecting_to_the_database.html
  64. switch driver {
  65. case "mysql":
  66. db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
  67. Logger: logger,
  68. })
  69. case "postgres":
  70. db, err = gorm.Open(postgres.New(postgres.Config{
  71. DSN: dsn,
  72. PreferSimpleProtocol: true, // disables implicit prepared statement usage
  73. }), &gorm.Config{})
  74. case "sqlite":
  75. db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{})
  76. default:
  77. panic("unknown db driver")
  78. }
  79. if err != nil {
  80. panic(err)
  81. }
  82. db = db.Debug()
  83. // Connection Pool config
  84. sqlDB, err := db.DB()
  85. if err != nil {
  86. panic(err)
  87. }
  88. sqlDB.SetMaxIdleConns(10)
  89. sqlDB.SetMaxOpenConns(100)
  90. sqlDB.SetConnMaxLifetime(time.Hour)
  91. return db
  92. }
  93. func NewRedis(conf *viper.Viper) *redis.Client {
  94. rdb := redis.NewClient(&redis.Options{
  95. Addr: conf.GetString("data.redis.addr"),
  96. Password: conf.GetString("data.redis.password"),
  97. DB: conf.GetInt("data.redis.db"),
  98. })
  99. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  100. defer cancel()
  101. _, err := rdb.Ping(ctx).Result()
  102. if err != nil {
  103. panic(fmt.Sprintf("redis error: %s", err.Error()))
  104. }
  105. return rdb
  106. }