package repository

import (
	"context"
	"fmt"
	"time"

	"github.com/glebarez/sqlite"
	"github.com/go-nunu/nunu-layout-advanced/pkg/log"
	"github.com/go-nunu/nunu-layout-advanced/pkg/zapgorm2"
	"github.com/qiniu/qmgo"
	"github.com/redis/go-redis/v9"
	"github.com/spf13/viper"
	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	gormlogger "gorm.io/gorm/logger"
)

// txContextKey is a private type for context keys owned by this package.
// Using a named type instead of a bare string prevents collisions with keys
// set by other packages (go vet / staticcheck SA1029).
type txContextKey string

// ctxTxKey is the context key under which Transaction stores the active *gorm.DB.
const ctxTxKey txContextKey = "TxKey"

// Connection-pool limits applied to every SQL database opened in this file.
const (
	sqlMaxIdleConns    = 10
	sqlMaxOpenConns    = 100
	sqlConnMaxLifetime = time.Hour
)

// Repository bundles the database handles and logger built by the
// constructors in this file.
type Repository struct {
	db *gorm.DB // primary database connection
	// dbSecond *gorm.DB // second database connection
	// rdb *redis.Client
	mongoClient *qmgo.Client
	mongoDB     *qmgo.Database
	logger      *log.Logger
}

// NewRepository assembles a Repository from already-initialised handles.
func NewRepository(
	logger *log.Logger,
	db *gorm.DB,
	// dbSecond *gorm.DB,
	// rdb *redis.Client,
	mongoClient *qmgo.Client,
	mongoDB *qmgo.Database,
) *Repository {
	return &Repository{
		db: db,
		// dbSecond: dbSecond,
		// rdb: rdb,
		mongoClient: mongoClient,
		mongoDB:     mongoDB,
		logger:      logger,
	}
}

// Transaction runs fn inside a database transaction; see Repository.Transaction.
type Transaction interface {
	Transaction(ctx context.Context, fn func(ctx context.Context) error) error
}

// NewTransaction exposes r through the Transaction interface.
func NewTransaction(r *Repository) Transaction {
	return r
}

// DB returns the *gorm.DB that queries for ctx should use. If ctx carries a
// transaction stored by Transaction, that transaction is returned.
// Otherwise DB returns the primary connection bound to ctx.
// Queries take part in an enclosing Transaction only if they obtain their
// handle through DB(ctx).
func (r *Repository) DB(ctx context.Context) *gorm.DB {
	if tx, ok := ctx.Value(ctxTxKey).(*gorm.DB); ok {
		return tx
	}
	return r.db.WithContext(ctx)
}

// DBSecond returns the second database connection.
// Note: transactions are currently only supported on the primary database.
//func (r *Repository) DBSecond(ctx context.Context) *gorm.DB {
//	return r.dbSecond.WithContext(ctx)
//}

// Transaction runs fn in a transaction on the primary database. fn receives a
// context carrying the transaction, so repositories that use DB(ctx) inside
// fn run their queries in it. Per GORM's Transaction contract, a non-nil
// error from fn rolls the transaction back and nil commits it.
//
// If ctx already carries a transaction, the call nests into it: GORM
// creates a SAVEPOINT instead of opening an independent transaction on
// another connection.
func (r *Repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
	return r.DB(ctx).Transaction(func(tx *gorm.DB) error {
		return fn(context.WithValue(ctx, ctxTxKey, tx))
	})
}

// gormLogLevel maps a configured "logLevel" string to a GORM log level.
// Empty or unrecognised values fall back to Warn for MySQL, which then logs
// only warnings and errors, and to Info for every other driver.
func gormLogLevel(level, driver string) gormlogger.LogLevel {
	switch level {
	case "silent":
		return gormlogger.Silent
	case "error":
		return gormlogger.Error
	case "warn":
		return gormlogger.Warn
	case "info":
		return gormlogger.Info
	}
	if driver == "mysql" {
		return gormlogger.Warn
	}
	return gormlogger.Info
}

// gormDialector returns the GORM dialector for driver, or nil when driver is
// not one of "mysql", "postgres" or "sqlite".
func gormDialector(driver, dsn string) gorm.Dialector {
	switch driver {
	case "mysql":
		return mysql.Open(dsn)
	case "postgres":
		return postgres.New(postgres.Config{
			DSN:                  dsn,
			PreferSimpleProtocol: true,
		})
	case "sqlite":
		return sqlite.Open(dsn)
	default:
		return nil
	}
}

// configureSQLPool applies the shared connection-pool limits to db.
// It panics if the underlying *sql.DB cannot be obtained.
func configureSQLPool(db *gorm.DB) {
	sqlDB, err := db.DB()
	if err != nil {
		panic(err)
	}
	sqlDB.SetMaxIdleConns(sqlMaxIdleConns)
	sqlDB.SetMaxOpenConns(sqlMaxOpenConns)
	sqlDB.SetConnMaxLifetime(sqlConnMaxLifetime)
}

// NewDB opens the primary database configured under data.db.user
// (driver, dsn, logLevel) and configures its connection pool.
// It panics on an unknown driver or a connection failure.
func NewDB(conf *viper.Viper, l *log.Logger) *gorm.DB {
	driver := conf.GetString("data.db.user.driver")
	dsn := conf.GetString("data.db.user.dsn")
	logger := zapgorm2.New(l.Logger).LogMode(gormLogLevel(conf.GetString("data.db.user.logLevel"), driver))

	// GORM doc: https://gorm.io/docs/connecting_to_the_database.html
	dialector := gormDialector(driver, dsn)
	if dialector == nil {
		panic("unknown db driver")
	}
	db, err := gorm.Open(dialector, &gorm.Config{
		Logger: logger,
	})
	if err != nil {
		panic(err)
	}
	configureSQLPool(db)
	return db
}

// NewDBSecond opens the optional second database configured under
// data.db.second (driver, dsn, logLevel). When no DSN is configured it logs
// a warning and returns nil. It panics on an unknown driver or a connection
// failure.
func NewDBSecond(conf *viper.Viper, l *log.Logger) *gorm.DB {
	driver := conf.GetString("data.db.second.driver")
	dsn := conf.GetString("data.db.second.dsn")
	// The second database is optional: without a DSN, skip it.
	if dsn == "" {
		l.Warn("第二个数据库配置不存在或DSN为空")
		return nil
	}
	logger := zapgorm2.New(l.Logger).LogMode(gormLogLevel(conf.GetString("data.db.second.logLevel"), driver))

	dialector := gormDialector(driver, dsn)
	if dialector == nil {
		panic("unknown db driver for second database")
	}
	db, err := gorm.Open(dialector, &gorm.Config{
		Logger: logger,
	})
	if err != nil {
		panic("连接第二个数据库失败: " + err.Error())
	}
	configureSQLPool(db)
	return db
}

// NewRedis creates a Redis client from data.redis.{addr,password,db} and
// verifies it with a PING bounded by 5 seconds. It panics if the ping fails.
func NewRedis(conf *viper.Viper) *redis.Client {
	rdb := redis.NewClient(&redis.Options{
		Addr:     conf.GetString("data.redis.addr"),
		Password: conf.GetString("data.redis.password"),
		DB:       conf.GetInt("data.redis.db"),
	})

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	_, err := rdb.Ping(ctx).Result()
	if err != nil {
		// Release the client's pool before bailing out; the close error is
		// irrelevant because we are about to panic with the ping error.
		_ = rdb.Close()
		panic(fmt.Sprintf("redis error: %s", err.Error()))
	}

	return rdb
}

// NewMongoClient connects to MongoDB at data.mongodb.uri. The connect step
// is bounded by data.mongodb.timeout (default 10s). data.mongodb.max_pool_size
// is applied only when it is configured; otherwise the driver default is kept.
// It panics if the connection fails.
func NewMongoClient(conf *viper.Viper) *qmgo.Client {
	timeout := conf.GetDuration("data.mongodb.timeout")
	if timeout == 0 {
		timeout = 10 * time.Second
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// Build the connection config.
	clientOpts := &qmgo.Config{
		Uri: conf.GetString("data.mongodb.uri"),
	}
	// Passing a pointer to 0 would make the driver's pool size unlimited, so
	// set MaxPoolSize only when the key is configured.
	if conf.IsSet("data.mongodb.max_pool_size") {
		maxPoolSize := conf.GetUint64("data.mongodb.max_pool_size")
		clientOpts.MaxPoolSize = &maxPoolSize
	}

	// Connect to MongoDB.
	client, err := qmgo.NewClient(ctx, clientOpts)
	if err != nil {
		panic(fmt.Sprintf("连接MongoDB失败: %s", err.Error()))
	}

	return client
}

// NewMongoDB returns the database named by data.mongodb.database.
// It panics when that name is empty.
func NewMongoDB(client *qmgo.Client, conf *viper.Viper) *qmgo.Database {
	databaseName := conf.GetString("data.mongodb.database")
	if databaseName == "" {
		panic("MongoDB数据库名不能为空")
	}
	return client.Database(databaseName)
}