repository.go 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. package repository
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/go-nunu/nunu-layout-advanced/pkg/log"
  6. "github.com/go-nunu/nunu-layout-advanced/pkg/zapgorm2"
  7. "github.com/redis/go-redis/v9"
  8. "github.com/spf13/viper"
  9. "gorm.io/driver/mysql"
  10. "gorm.io/gorm"
  11. "time"
  12. )
  // txContextKey is the unexported type of the context key under which
  // Repository.Transaction stores the active *gorm.DB transaction handle.
  // A dedicated named type (instead of a bare string) means no other package
  // can read or overwrite this value by using the same string as a key.
  type txContextKey string

  // ctxTxKey is the context key for the in-flight transaction handle.
  const ctxTxKey txContextKey = "TxKey"
  14. type Repository struct {
  15. db *gorm.DB
  16. rdb *redis.Client
  17. logger *log.Logger
  18. }
  19. func NewRepository(db *gorm.DB, rdb *redis.Client, logger *log.Logger) *Repository {
  20. return &Repository{
  21. db: db,
  22. rdb: rdb,
  23. logger: logger,
  24. }
  25. }
  // Transaction abstracts running fn inside a database transaction.
  // Implementations hand fn a context that carries the transaction, so that
  // repository calls made with that context take part in it. The error
  // produced by fn is reported back to the caller.
  type Transaction interface {
  	Transaction(ctx context.Context, fn func(txCtx context.Context) error) error
  }
  29. func NewTransaction(r *Repository) Transaction {
  30. return r
  31. }
  32. // DB return tx
  33. // If you need to create a Transaction, you must call DB(ctx) and Transaction(ctx,fn)
  34. func (r *Repository) DB(ctx context.Context) *gorm.DB {
  35. v := ctx.Value(ctxTxKey)
  36. if v != nil {
  37. if tx, ok := v.(*gorm.DB); ok {
  38. return tx
  39. }
  40. }
  41. return r.db.WithContext(ctx)
  42. }
  43. func (r *Repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
  44. return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
  45. ctx = context.WithValue(ctx, ctxTxKey, tx)
  46. return fn(ctx)
  47. })
  48. }
  49. func NewDB(conf *viper.Viper, l *log.Logger) *gorm.DB {
  50. logger := zapgorm2.New(l.Logger)
  51. db, err := gorm.Open(mysql.Open(conf.GetString("data.mysql.user")), &gorm.Config{Logger: logger})
  52. if err != nil {
  53. panic(err)
  54. }
  55. db = db.Debug()
  56. return db
  57. }
  58. func NewRedis(conf *viper.Viper) *redis.Client {
  59. rdb := redis.NewClient(&redis.Options{
  60. Addr: conf.GetString("data.redis.addr"),
  61. Password: conf.GetString("data.redis.password"),
  62. DB: conf.GetInt("data.redis.db"),
  63. })
  64. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  65. defer cancel()
  66. _, err := rdb.Ping(ctx).Result()
  67. if err != nil {
  68. panic(fmt.Sprintf("redis error: %s", err.Error()))
  69. }
  70. return rdb
  71. }