package repository

import (
	"context"
	"fmt"
	"sort"
	"time"

	"github.com/casbin/casbin/v2"
	"github.com/casbin/casbin/v2/model"
	gormadapter "github.com/casbin/gorm-adapter/v3"
	"github.com/glebarez/sqlite"
	"github.com/go-nunu/nunu-layout-advanced/pkg/log"
	"github.com/go-nunu/nunu-layout-advanced/pkg/rabbitmq"
	"github.com/go-nunu/nunu-layout-advanced/pkg/zapgorm2"
	"github.com/qiniu/qmgo"
	"github.com/redis/go-redis/v9"
	"github.com/spf13/viper"
	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	gormlogger "gorm.io/gorm/logger"
	"gorm.io/plugin/dbresolver"
)

// txCtxKey is the private type of the context key under which the active
// transaction handle is stored. A dedicated type (rather than a bare string)
// prevents collisions with context values set by other packages.
type txCtxKey string

// ctxTxKey is the context key holding the *gorm.DB of the current transaction.
const ctxTxKey txCtxKey = "TxKey"

// Repository bundles the data-layer clients shared by concrete repositories.
type Repository struct {
	Db          *gorm.DB
	Rdb         *redis.Client
	mongoClient *qmgo.Client
	MongoDB     *qmgo.Database
	mq          *rabbitmq.RabbitMQ
	Logger      *log.Logger
	E           *casbin.SyncedEnforcer
}

// NewRepository assembles a Repository from its already-initialized clients.
func NewRepository(
	logger *log.Logger,
	db *gorm.DB,
	rdb *redis.Client,
	mongoClient *qmgo.Client,
	mongoDB *qmgo.Database,
	mq *rabbitmq.RabbitMQ,
	e *casbin.SyncedEnforcer,
) *Repository {
	return &Repository{
		Db:          db,
		Rdb:         rdb,
		mongoClient: mongoClient,
		MongoDB:     mongoDB,
		mq:          mq,
		Logger:      logger,
		E:           e,
	}
}

// Transaction runs a unit of work inside a database transaction.
type Transaction interface {
	Transaction(ctx context.Context, fn func(ctx context.Context) error) error
	// TransactionWithDB runs fn inside a transaction on the named database.
	TransactionWithDB(ctx context.Context, dbName string, fn func(ctx context.Context) error) error
}

// NewTransaction exposes r as a Transaction.
func NewTransaction(r *Repository) Transaction {
	return r
}

// DB returns the transaction stored in ctx (by Transaction or
// TransactionWithDB) if there is one, otherwise the primary connection bound
// to ctx. Repository methods must obtain their handle through DB(ctx) so
// that they participate in an enclosing transaction.
func (r *Repository) DB(ctx context.Context) *gorm.DB {
	if tx, ok := ctx.Value(ctxTxKey).(*gorm.DB); ok {
		return tx
	}
	return r.Db.WithContext(ctx)
}

// DBWithName returns a handle for the database registered under dbName in
// the dbresolver (see NewDB), or the primary connection when dbName is empty.
// If ctx already carries a transaction, that transaction is returned and
// dbName is ignored: the transaction stays on the database it was started on.
func (r *Repository) DBWithName(ctx context.Context, dbName string) *gorm.DB {
	if tx, ok := ctx.Value(ctxTxKey).(*gorm.DB); ok {
		return tx
	}
	if dbName != "" {
		return r.Db.Clauses(dbresolver.Use(dbName)).WithContext(ctx)
	}
	return r.Db.WithContext(ctx)
}

// Transaction runs fn inside a transaction on the primary connection. The
// transaction handle is stored in the context passed to fn, so repository
// calls made through DB(ctx) inside fn join it. A non-nil error or a panic
// from fn rolls the transaction back; otherwise it is committed.
//
// If ctx already carries a transaction (a nested call), fn runs inside that
// transaction under a SAVEPOINT. It does not open a second, independent
// transaction on another pooled connection. In that case an error from fn
// rolls back only to the savepoint, and the outer transaction decides the
// final commit.
func (r *Repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
	return r.DB(ctx).Transaction(func(tx *gorm.DB) error {
		return fn(context.WithValue(ctx, ctxTxKey, tx))
	})
}

// TransactionWithDB runs fn inside a transaction on the database registered
// under dbName, or on the primary connection when dbName is empty. It always
// starts a new transaction on that database, even if ctx already carries one.
// fn receives a context holding the new transaction handle.
func (r *Repository) TransactionWithDB(ctx context.Context, dbName string, fn func(ctx context.Context) error) error {
	db := r.Db
	if dbName != "" {
		db = db.Clauses(dbresolver.Use(dbName))
	}
	return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// tx is already bound to the selected database; no need to re-select it.
		return fn(context.WithValue(ctx, ctxTxKey, tx))
	})
}

// parseGormLogLevel maps the configured level name to a gorm log level. An
// unknown or empty name defaults to Warn for mysql and Info for other drivers.
func parseGormLogLevel(level, driver string) gormlogger.LogLevel {
	switch level {
	case "silent":
		return gormlogger.Silent
	case "error":
		return gormlogger.Error
	case "warn":
		return gormlogger.Warn
	case "info":
		return gormlogger.Info
	}
	// By default MySQL only logs warnings and errors.
	if driver == "mysql" {
		return gormlogger.Warn
	}
	return gormlogger.Info
}

// dialectorFor builds the gorm dialector for a supported driver name
// ("mysql", "postgres", "sqlite"). It returns nil for any other driver.
func dialectorFor(driver, dsn string) gorm.Dialector {
	switch driver {
	case "mysql":
		return mysql.Open(dsn)
	case "postgres":
		return postgres.New(postgres.Config{
			DSN:                  dsn,
			PreferSimpleProtocol: true,
		})
	case "sqlite":
		return sqlite.Open(dsn)
	default:
		return nil
	}
}

// NewDB opens the primary database configured at
// data.db.<data.primary_db_key> (the key defaults to "user" for backward
// compatibility). It registers every other data.db.<key> entry that has both
// a driver and a dsn as a named dbresolver connection, selectable with
// dbresolver.Use(<key>). It panics on any configuration or connection error.
func NewDB(conf *viper.Viper, l *log.Logger) *gorm.DB {
	primaryDBKey := conf.GetString("data.primary_db_key")
	if primaryDBKey == "" {
		primaryDBKey = "user"
	}

	driver := conf.GetString(fmt.Sprintf("data.db.%s.driver", primaryDBKey))
	if driver == "" {
		panic("主数据库驱动配置不能为空")
	}
	dsn := conf.GetString(fmt.Sprintf("data.db.%s.dsn", primaryDBKey))
	if dsn == "" {
		panic("主数据库连接字符串不能为空")
	}

	logLevel := parseGormLogLevel(conf.GetString(fmt.Sprintf("data.db.%s.logLevel", primaryDBKey)), driver)
	logger := zapgorm2.New(l.Logger).LogMode(logLevel)

	dialector := dialectorFor(driver, dsn)
	if dialector == nil {
		panic("不支持的数据库驱动类型: " + driver)
	}
	db, err := gorm.Open(dialector, &gorm.Config{
		Logger: logger,
	})
	if err != nil {
		panic(fmt.Sprintf("连接主数据库失败: %s", err.Error()))
	}

	// An empty config registers the primary connection as the global resolver;
	// named secondary databases are added below.
	resolver := dbresolver.Register(dbresolver.Config{})

	for dbKey := range conf.GetStringMap("data.db") {
		// The primary database is already connected directly.
		if dbKey == primaryDBKey {
			continue
		}
		// Only keys that carry both a driver and a dsn are database entries.
		dbDriver := conf.GetString(fmt.Sprintf("data.db.%s.driver", dbKey))
		dbDSN := conf.GetString(fmt.Sprintf("data.db.%s.dsn", dbKey))
		if dbDriver == "" || dbDSN == "" {
			continue
		}

		secondary := dialectorFor(dbDriver, dbDSN)
		if secondary == nil {
			l.Warn(fmt.Sprintf("跳过不支持的数据库驱动类型: %s (dbKey: %s)", dbDriver, dbKey))
			continue
		}

		// Only Sources is set: dbresolver then reuses the same pool for reads.
		// Listing the dialector again under Replicas would open a second pool
		// to the same DSN (a different database for sqlite ":memory:").
		resolver.Register(dbresolver.Config{
			Sources: []gorm.Dialector{secondary},
			Policy:  dbresolver.RandomPolicy{},
		}, dbKey) // the config key doubles as the resolver name
		l.Info(fmt.Sprintf("成功配置数据库连接: %s", dbKey))
	}

	// Pool settings for the connections managed by the resolver.
	resolver.SetConnMaxIdleTime(time.Hour).
		SetConnMaxLifetime(24 * time.Hour).
		SetMaxIdleConns(10).
		SetMaxOpenConns(100)

	if err = db.Use(resolver); err != nil {
		panic(fmt.Sprintf("应用数据库连接配置失败: %s", err.Error()))
	}

	// Pool settings for the primary connection.
	sqlDB, err := db.DB()
	if err != nil {
		panic(err)
	}
	sqlDB.SetMaxIdleConns(10)
	sqlDB.SetMaxOpenConns(100)
	sqlDB.SetConnMaxLifetime(time.Hour)

	return db
}

// NewRedis creates a Redis client from data.redis.* and verifies it with a
// PING bounded by 5 seconds. It panics if the PING fails.
func NewRedis(conf *viper.Viper) *redis.Client {
	rdb := redis.NewClient(&redis.Options{
		Addr:     conf.GetString("data.redis.addr"),
		Password: conf.GetString("data.redis.password"),
		DB:       conf.GetInt("data.redis.db"),
	})

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	if _, err := rdb.Ping(ctx).Result(); err != nil {
		panic(fmt.Sprintf("redis error: %s", err.Error()))
	}
	return rdb
}

// NewMongoClient connects to MongoDB at data.mongodb.uri.
// data.mongodb.timeout (default 10s) bounds the context passed to
// qmgo.NewClient. data.mongodb.max_pool_size caps the connection pool. When
// it is unset or 0 the driver's default is kept, because an explicit 0 would
// make the pool unbounded.
// It panics if the client cannot be created.
func NewMongoClient(conf *viper.Viper) *qmgo.Client {
	timeout := conf.GetDuration("data.mongodb.timeout")
	if timeout == 0 {
		timeout = 10 * time.Second
	}

	clientOpts := &qmgo.Config{
		Uri: conf.GetString("data.mongodb.uri"),
	}
	if maxPoolSize := conf.GetUint64("data.mongodb.max_pool_size"); maxPoolSize > 0 {
		clientOpts.MaxPoolSize = &maxPoolSize
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	client, err := qmgo.NewClient(ctx, clientOpts)
	if err != nil {
		panic(fmt.Sprintf("连接MongoDB失败: %s", err.Error()))
	}
	return client
}

// NewMongoDB returns the database named by data.mongodb.database. It panics
// if the name is empty.
func NewMongoDB(client *qmgo.Client, conf *viper.Viper) *qmgo.Database {
	databaseName := conf.GetString("data.mongodb.database")
	if databaseName == "" {
		panic("MongoDB数据库名不能为空")
	}
	return client.Database(databaseName)
}

// NewRabbitMQ builds the RabbitMQ client from the "rabbitmq" config section
// and sets up all task queues. It returns the client together with a cleanup
// function that closes it. It panics on any initialization error.
func NewRabbitMQ(conf *viper.Viper, logger *log.Logger) (*rabbitmq.RabbitMQ, func()) {
	var cfg rabbitmq.Config
	if err := conf.UnmarshalKey("rabbitmq", &cfg); err != nil {
		panic(fmt.Sprintf("unmarshal rabbitmq config error: %s", err.Error()))
	}

	mq, err := rabbitmq.New(cfg, logger)
	if err != nil {
		panic(fmt.Sprintf("init rabbitmq error: %s", err.Error()))
	}

	if err := mq.SetupAllTaskQueues(); err != nil {
		panic(fmt.Sprintf("failed to setup rabbitmq task queues: %v", err))
	}

	cleanup := func() {
		logger.Info("Closing RabbitMQ connection")
		// Best effort during shutdown; the close error is intentionally ignored.
		_ = mq.Close()
	}
	return mq, cleanup
}

// NewCasbinEnforcer creates a synced Casbin enforcer backed by gorm-adapter.
// The adapter uses the first data.db.<key> entry (in sorted key order) that
// has `casbin: true`, and otherwise falls back to the primary connection. The
// enforcer reloads policies every 10 seconds and has auto-save enabled.
// It panics on any initialization error.
func NewCasbinEnforcer(conf *viper.Viper, l *log.Logger, db *gorm.DB) *casbin.SyncedEnforcer {
	casbinDb := db // default: the primary connection

	// Map iteration order is random; sort the keys so the same entry is chosen
	// on every start if several databases carry the casbin flag.
	dbSettings := conf.GetStringMap("data.db")
	dbKeys := make([]string, 0, len(dbSettings))
	for dbKey := range dbSettings {
		dbKeys = append(dbKeys, dbKey)
	}
	sort.Strings(dbKeys)

	foundSpecialDb := false
	for _, dbKey := range dbKeys {
		if !conf.GetBool(fmt.Sprintf("data.db.%s.casbin", dbKey)) {
			continue
		}
		l.Info(fmt.Sprintf("检测到Casbin专用数据库配置: '%s'。Enforcer将使用此数据库连接。", dbKey))
		// Select the named connection registered in NewDB's resolver.
		casbinDb = db.Clauses(dbresolver.Use(dbKey))
		foundSpecialDb = true
		break
	}
	if !foundSpecialDb {
		l.Warn("未找到Casbin的专用数据库配置 (缺少 'casbin: true' 标志),Enforcer将回退使用主数据库连接。")
	}

	// Give the enforcer its own Warn-level gorm logger so the periodic policy
	// reloads do not flood the logs. This session-scoped logger leaves the
	// logging of every other database handle unchanged.
	enforcerLogger := zapgorm2.New(l.Logger).LogMode(gormlogger.Warn)
	dbForEnforcer := casbinDb.Session(&gorm.Session{Logger: enforcerLogger})

	adapter, err := gormadapter.NewAdapterByDB(dbForEnforcer)
	if err != nil {
		panic(fmt.Sprintf("创建Casbin gorm-adapter失败: %s", err))
	}

	m, err := model.NewModelFromString(`
[request_definition]
r = sub, obj, act

[policy_definition]
p = sub, obj, act

[role_definition]
g = _, _

[policy_effect]
e = some(where (p.eft == allow))

[matchers]
m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act
`)
	if err != nil {
		panic(fmt.Sprintf("创建Casbin模型失败: %s", err))
	}

	e, err := casbin.NewSyncedEnforcer(m, adapter)
	if err != nil {
		panic(fmt.Sprintf("创建Casbin Enforcer失败: %s", err))
	}

	// Reload policies from storage every 10 seconds.
	e.StartAutoLoadPolicy(10 * time.Second)
	// Persist policy changes through the adapter automatically.
	e.EnableAutoSave(true)

	return e
}