repository.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. package repository
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/casbin/casbin/v2"
  6. "github.com/casbin/casbin/v2/model"
  7. gormadapter "github.com/casbin/gorm-adapter/v3"
  8. "github.com/glebarez/sqlite"
  9. "github.com/go-nunu/nunu-layout-advanced/pkg/log"
  10. "github.com/go-nunu/nunu-layout-advanced/pkg/rabbitmq"
  11. "github.com/go-nunu/nunu-layout-advanced/pkg/zapgorm2"
  12. "github.com/qiniu/qmgo"
  13. "github.com/redis/go-redis/v9"
  14. "github.com/spf13/viper"
  15. "gorm.io/driver/mysql"
  16. "gorm.io/driver/postgres"
  17. "gorm.io/gorm"
  18. gormlogger "gorm.io/gorm/logger"
  19. "gorm.io/plugin/dbresolver"
  20. "time"
  21. )
  // txContextKey is the private type of the context key under which the active
  // transaction handle is stored. A dedicated type (instead of a bare string)
  // guarantees the key cannot collide with a "TxKey" string key that any other
  // package might place on the same context.
  type txContextKey string

  // ctxTxKey is the context key used by Transaction and TransactionWithDB to
  // publish the *gorm.DB transaction handle, and by DB and DBWithName to look
  // it up again.
  const ctxTxKey txContextKey = "TxKey"
  23. type Repository struct {
  24. db *gorm.DB
  25. mongoClient *qmgo.Client
  26. mongoDB *qmgo.Database
  27. mq *rabbitmq.RabbitMQ
  28. logger *log.Logger
  29. e *casbin.SyncedEnforcer
  30. }
  31. func NewRepository(
  32. logger *log.Logger,
  33. db *gorm.DB,
  34. mongoClient *qmgo.Client,
  35. mongoDB *qmgo.Database,
  36. mq *rabbitmq.RabbitMQ,
  37. e *casbin.SyncedEnforcer,
  38. ) *Repository {
  39. return &Repository{
  40. db: db,
  41. mongoClient: mongoClient,
  42. mongoDB: mongoDB,
  43. mq: mq,
  44. logger: logger,
  45. e: e,
  46. }
  47. }
  // Transaction is the transactional entry point exposed by the repository layer.
  // In both methods fn receives a derived context that carries the transaction
  // handle; data access inside fn must use that context to take part in the
  // transaction.
  type Transaction interface {
  	// Transaction runs fn inside a transaction started on the base database handle.
  	Transaction(ctx context.Context, fn func(ctx context.Context) error) error

  	// TransactionWithDB runs fn inside a transaction on the database registered
  	// under dbName.
  	TransactionWithDB(ctx context.Context, dbName string, fn func(ctx context.Context) error) error
  }
  53. func NewTransaction(r *Repository) Transaction {
  54. return r
  55. }
  56. // DB return tx
  57. // If you need to create a Transaction, you must call DB(ctx) and Transaction(ctx,fn)
  58. func (r *Repository) DB(ctx context.Context) *gorm.DB {
  59. v := ctx.Value(ctxTxKey)
  60. if v != nil {
  61. if tx, ok := v.(*gorm.DB); ok {
  62. return tx
  63. }
  64. }
  65. return r.db.WithContext(ctx)
  66. }
  67. // DBWithName 使用特定名称的数据库连接
  68. func (r *Repository) DBWithName(ctx context.Context, dbName string) *gorm.DB {
  69. // 先检查上下文中是否已存在事务
  70. v := ctx.Value(ctxTxKey)
  71. if v != nil {
  72. if tx, ok := v.(*gorm.DB); ok {
  73. // 如果事务中已经指定了数据库,则直接返回
  74. return tx
  75. }
  76. }
  77. // 使用指定名称的数据库连接
  78. if dbName != "" {
  79. return r.db.Clauses(dbresolver.Use(dbName)).WithContext(ctx)
  80. }
  81. return r.db.WithContext(ctx)
  82. }
  83. func (r *Repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
  84. return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
  85. ctxWithTx := context.WithValue(ctx, ctxTxKey, tx)
  86. return fn(ctxWithTx)
  87. })
  88. }
  89. // TransactionWithDB 在特定数据库上执行事务
  90. func (r *Repository) TransactionWithDB(ctx context.Context, dbName string, fn func(ctx context.Context) error) error {
  91. // 使用特定的数据库连接
  92. db := r.db
  93. if dbName != "" {
  94. db = db.Clauses(dbresolver.Use(dbName))
  95. }
  96. return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
  97. // tx已经是针对特定数据库的事务句柄,无需再次指定数据库
  98. ctxWithTx := context.WithValue(ctx, ctxTxKey, tx)
  99. return fn(ctxWithTx)
  100. })
  101. }
  102. func NewDB(conf *viper.Viper, l *log.Logger) *gorm.DB {
  103. var (
  104. db *gorm.DB
  105. err error
  106. )
  107. // 获取主数据库键名
  108. primaryDBKey := conf.GetString("data.primary_db_key")
  109. if primaryDBKey == "" {
  110. // 默认使用user作为主数据库键名(向后兼容)
  111. primaryDBKey = "user"
  112. }
  113. // 从配置中获取主数据库配置
  114. driver := conf.GetString(fmt.Sprintf("data.db.%s.driver", primaryDBKey))
  115. if driver == "" {
  116. panic("主数据库驱动配置不能为空")
  117. }
  118. dsn := conf.GetString(fmt.Sprintf("data.db.%s.dsn", primaryDBKey))
  119. if dsn == "" {
  120. panic("主数据库连接字符串不能为空")
  121. }
  122. // 读取日志级别配置
  123. logLevelStr := conf.GetString(fmt.Sprintf("data.db.%s.logLevel", primaryDBKey))
  124. var logLevel gormlogger.LogLevel
  125. switch logLevelStr {
  126. case "silent":
  127. logLevel = gormlogger.Silent
  128. case "error":
  129. logLevel = gormlogger.Error
  130. case "warn":
  131. logLevel = gormlogger.Warn
  132. case "info":
  133. logLevel = gormlogger.Info
  134. default:
  135. // MySQL 默认只记录警告和错误
  136. if driver == "mysql" {
  137. logLevel = gormlogger.Warn
  138. } else {
  139. logLevel = gormlogger.Info
  140. }
  141. }
  142. logger := zapgorm2.New(l.Logger).LogMode(logLevel)
  143. // 连接主数据库
  144. switch driver {
  145. case "mysql":
  146. db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
  147. Logger: logger,
  148. })
  149. case "postgres":
  150. db, err = gorm.Open(postgres.New(postgres.Config{
  151. DSN: dsn,
  152. PreferSimpleProtocol: true,
  153. }), &gorm.Config{
  154. Logger: logger,
  155. })
  156. case "sqlite":
  157. db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{
  158. Logger: logger,
  159. })
  160. default:
  161. panic("不支持的数据库驱动类型: " + driver)
  162. }
  163. if err != nil {
  164. panic(fmt.Sprintf("连接主数据库失败: %s", err.Error()))
  165. }
  166. // 创建 dbresolver 实例
  167. resolver := dbresolver.Register(dbresolver.Config{})
  168. // 获取所有配置的数据库列表
  169. databases := conf.GetStringMap("data.db")
  170. // 遍历所有数据库配置(跳过主数据库,因为已经连接)
  171. for dbKey, _ := range databases {
  172. // 跳过主数据库(已经直接连接了)
  173. if dbKey == primaryDBKey {
  174. continue
  175. }
  176. // 检查该键是否确实是一个数据库配置对象
  177. dbDriver := conf.GetString(fmt.Sprintf("data.db.%s.driver", dbKey))
  178. dbDSN := conf.GetString(fmt.Sprintf("data.db.%s.dsn", dbKey))
  179. if dbDriver != "" && dbDSN != "" {
  180. // 构建数据库连接器
  181. var dialector gorm.Dialector
  182. switch dbDriver {
  183. case "mysql":
  184. dialector = mysql.Open(dbDSN)
  185. case "postgres":
  186. dialector = postgres.New(postgres.Config{
  187. DSN: dbDSN,
  188. PreferSimpleProtocol: true,
  189. })
  190. case "sqlite":
  191. dialector = sqlite.Open(dbDSN)
  192. default:
  193. l.Warn(fmt.Sprintf("跳过不支持的数据库驱动类型: %s (dbKey: %s)", dbDriver, dbKey))
  194. continue
  195. }
  196. // 注册到resolver
  197. resolver.Register(dbresolver.Config{
  198. Sources: []gorm.Dialector{dialector},
  199. Replicas: []gorm.Dialector{dialector},
  200. Policy: dbresolver.RandomPolicy{},
  201. }, dbKey) // 使用配置键作为数据库名称
  202. l.Info(fmt.Sprintf("成功配置数据库连接: %s", dbKey))
  203. }
  204. }
  205. // 设置连接池参数
  206. resolver.SetConnMaxIdleTime(time.Hour).
  207. SetConnMaxLifetime(24 * time.Hour).
  208. SetMaxIdleConns(10).
  209. SetMaxOpenConns(100)
  210. // 应用配置好的 dbresolver 到 db
  211. err = db.Use(resolver)
  212. if err != nil {
  213. panic(fmt.Sprintf("应用数据库连接配置失败: %s", err.Error()))
  214. }
  215. // 主数据库连接池配置
  216. sqlDB, err := db.DB()
  217. if err != nil {
  218. panic(err)
  219. }
  220. sqlDB.SetMaxIdleConns(10)
  221. sqlDB.SetMaxOpenConns(100)
  222. sqlDB.SetConnMaxLifetime(time.Hour)
  223. return db
  224. }
  225. func NewRedis(conf *viper.Viper) *redis.Client {
  226. rdb := redis.NewClient(&redis.Options{
  227. Addr: conf.GetString("data.redis.addr"),
  228. Password: conf.GetString("data.redis.password"),
  229. DB: conf.GetInt("data.redis.db"),
  230. })
  231. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  232. defer cancel()
  233. _, err := rdb.Ping(ctx).Result()
  234. if err != nil {
  235. panic(fmt.Sprintf("redis error: %s", err.Error()))
  236. }
  237. return rdb
  238. }
  239. func NewMongoClient(conf *viper.Viper) *qmgo.Client {
  240. timeout := conf.GetDuration("data.mongodb.timeout")
  241. if timeout == 0 {
  242. timeout = 10 * time.Second
  243. }
  244. maxPoolSize := conf.GetUint64("data.mongodb.max_pool_size")
  245. ctx, cancel := context.WithTimeout(context.Background(), timeout)
  246. defer cancel()
  247. // 创建连接配置
  248. clientOpts := &qmgo.Config{
  249. Uri: conf.GetString("data.mongodb.uri"),
  250. MaxPoolSize: &maxPoolSize,
  251. }
  252. // 连接到MongoDB
  253. client, err := qmgo.NewClient(ctx, clientOpts)
  254. if err != nil {
  255. panic(fmt.Sprintf("连接MongoDB失败: %s", err.Error()))
  256. }
  257. return client
  258. }
  259. func NewMongoDB(client *qmgo.Client, conf *viper.Viper) *qmgo.Database {
  260. databaseName := conf.GetString("data.mongodb.database")
  261. if databaseName == "" {
  262. panic("MongoDB数据库名不能为空")
  263. }
  264. return client.Database(databaseName)
  265. }
  266. func NewRabbitMQ(conf *viper.Viper, logger *log.Logger) (*rabbitmq.RabbitMQ, func()) {
  267. var cfg rabbitmq.Config
  268. if err := conf.UnmarshalKey("rabbitmq", &cfg); err != nil {
  269. panic(fmt.Sprintf("unmarshal rabbitmq config error: %s", err.Error()))
  270. }
  271. mq, err := rabbitmq.New(cfg, logger)
  272. if err != nil {
  273. panic(fmt.Sprintf("init rabbitmq error: %s", err.Error()))
  274. }
  275. // Setup task queue
  276. if err := mq.SetupAllTaskQueues(); err != nil {
  277. panic(fmt.Sprintf("failed to setup rabbitmq task queues: %v", err))
  278. }
  279. cleanup := func() {
  280. logger.Info("Closing RabbitMQ connection")
  281. _ = mq.Close()
  282. }
  283. return mq, cleanup
  284. }
  285. func NewCasbinEnforcer(conf *viper.Viper, l *log.Logger, db *gorm.DB) *casbin.SyncedEnforcer {
  286. a, _ := gormadapter.NewAdapterByDB(db)
  287. m, err := model.NewModelFromString(`
  288. [request_definition]
  289. r = sub, obj, act
  290. [policy_definition]
  291. p = sub, obj, act
  292. [role_definition]
  293. g = _, _
  294. [policy_effect]
  295. e = some(where (p.eft == allow))
  296. [matchers]
  297. m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act
  298. `)
  299. if err != nil {
  300. panic(err)
  301. }
  302. e, _ := casbin.NewSyncedEnforcer(m, a)
  303. // 每10秒自动加载策略,防止启动多服务进程策略不一致
  304. // 如果不想用轮询DB的方式,你也可以使用Casbin Watchers来同步策略,该方式需要基于Redis、Etcd等存储中间件
  305. // Watchers相关文档:https://casbin.org/zh/docs/watchers
  306. e.StartAutoLoadPolicy(10 * time.Second)
  307. // Enable Logger, decide whether to show it in terminal
  308. //e.EnableLog(true)
  309. // Save the policy back to DB.
  310. e.EnableAutoSave(true)
  311. return e
  312. }