repository.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. package repository
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/go-nunu/nunu-layout-advanced/pkg/log"
  6. "github.com/redis/go-redis/v9"
  7. "github.com/spf13/viper"
  8. "gorm.io/driver/mysql"
  9. "gorm.io/gorm"
  10. "moul.io/zapgorm2"
  11. "time"
  12. )
  13. const ctxTxKey = "TxKey"
  14. type Repository struct {
  15. db *gorm.DB
  16. rdb *redis.Client
  17. logger *log.Logger
  18. }
  19. func NewRepository(db *gorm.DB, rdb *redis.Client, logger *log.Logger) *Repository {
  20. return &Repository{
  21. db: db,
  22. rdb: rdb,
  23. logger: logger,
  24. }
  25. }
  26. type Transaction interface {
  27. Transaction(ctx context.Context, fn func(ctx context.Context) error) error
  28. }
  29. func NewTransaction(r *Repository) Transaction {
  30. return r
  31. }
  32. // DB return tx
  33. // If you need to create a Transaction, you must call DB(ctx) and Transaction(ctx,fn)
  34. func (r *Repository) DB(ctx context.Context) *gorm.DB {
  35. v := ctx.Value(ctxTxKey)
  36. if v != nil {
  37. if tx, ok := v.(*gorm.DB); ok {
  38. return tx
  39. }
  40. }
  41. return r.db.WithContext(ctx)
  42. }
  43. func (r *Repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
  44. return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
  45. ctx = context.WithValue(ctx, ctxTxKey, tx)
  46. return fn(ctx)
  47. })
  48. }
  49. func NewDB(conf *viper.Viper, l *log.Logger) *gorm.DB {
  50. logger := zapgorm2.New(l.Logger)
  51. logger.SetAsDefault()
  52. db, err := gorm.Open(mysql.Open(conf.GetString("data.mysql.user")), &gorm.Config{Logger: logger})
  53. if err != nil {
  54. panic(err)
  55. }
  56. db = db.Debug()
  57. return db
  58. }
  59. func NewRedis(conf *viper.Viper) *redis.Client {
  60. rdb := redis.NewClient(&redis.Options{
  61. Addr: conf.GetString("data.redis.addr"),
  62. Password: conf.GetString("data.redis.password"),
  63. DB: conf.GetInt("data.redis.db"),
  64. })
  65. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  66. defer cancel()
  67. _, err := rdb.Ping(ctx).Result()
  68. if err != nil {
  69. panic(fmt.Sprintf("redis error: %s", err.Error()))
  70. }
  71. return rdb
  72. }