Browse Source

refactor(repository): 重构分表逻辑并升级依赖

- 移除各 repository 内的 getShardingManager 方法,使用依赖注入的 Manager
- 更新 go.mod 和 go.sum 文件,升级 google/wire到 v0.6.0,添加 subcommands 等新依赖
- 修改 wire 配置,增加 NewShardingManager 的注入
fusu 23 hours ago
parent
commit
12627f8506

+ 1 - 0
cmd/server/wire/wire.go

@@ -42,6 +42,7 @@ var repositorySet = wire.NewSet(
 	repository.NewRabbitMQ,
 	repository.NewRepository,
 	repository.NewTransaction,
+	repository.NewShardingManager,
 	adminRep.NewAdminRepository,
 	adminRep.NewUserRepository,
 	repository.NewGameShieldRepository,

+ 3 - 2
cmd/server/wire/wire_gen.go

@@ -49,7 +49,8 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	qmgoClient := repository.NewMongoClient(viperViper)
 	database := repository.NewMongoDB(qmgoClient, viperViper)
 	rabbitMQ, cleanup := repository.NewRabbitMQ(viperViper, logger)
-	repositoryRepository := repository.NewRepository(logger, db, client, qmgoClient, database, rabbitMQ, syncedEnforcer)
+	shardingManager := repository.NewShardingManager(logger)
+	repositoryRepository := repository.NewRepository(logger, db, client, qmgoClient, database, rabbitMQ, syncedEnforcer, shardingManager)
 	transaction := repository.NewTransaction(repositoryRepository)
 	sidSid := sid.NewSid()
 	serviceService := service.NewService(transaction, logger, sidSid, jwtJWT)
@@ -143,7 +144,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 
 // wire.go:
 
-var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewCasbinEnforcer, repository.NewMongoClient, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, admin.NewAdminRepository, admin.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldPublicIpRepository, waf.NewWebForwardingRepository, waf.NewTcpforwardingRepository, waf.NewUdpForWardingRepository, repository.NewGameShieldUserIpRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldSdkIpRepository, repository.NewHostRepository, waf.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository, flexCdn.NewCdnRepository, waf.NewAllowAndDenyIpRepository, flexCdn.NewProxyRepository, flexCdn.NewCcRepository, repository.NewExpiredRepository, repository.NewLogRepository, waf.NewGatewayipRepository, admin.NewGatewayIpAdminRepository, flexCdn.NewCcIpListRepository, admin.NewLogRepository, admin.NewWafLogRepository, admin.NewWafManageRepository)
+var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewCasbinEnforcer, repository.NewMongoClient, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewShardingManager, admin.NewAdminRepository, admin.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldPublicIpRepository, waf.NewWebForwardingRepository, waf.NewTcpforwardingRepository, waf.NewUdpForWardingRepository, repository.NewGameShieldUserIpRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldSdkIpRepository, repository.NewHostRepository, waf.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository, flexCdn.NewCdnRepository, waf.NewAllowAndDenyIpRepository, flexCdn.NewProxyRepository, flexCdn.NewCcRepository, repository.NewExpiredRepository, repository.NewLogRepository, waf.NewGatewayipRepository, admin.NewGatewayIpAdminRepository, flexCdn.NewCcIpListRepository, admin.NewLogRepository, admin.NewWafLogRepository, admin.NewWafManageRepository)
 
 var serviceSet = wire.NewSet(service.NewService, admin2.NewUserService, admin2.NewGatewayIpAdminService, admin2.NewAdminService, gameShield.NewGameShieldService, service.NewAoDunService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewCrawlerService, web.NewWebForwardingService, web.NewAidedWebService, tcp.NewAidedTcpService, tcp.NewTcpforwardingService, udp.NewAidedUdpService, udp.NewUdpForWardingService, service.NewGameShieldUserIpService, gameShield.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewHostService, globallimit.NewGlobalLimitService, service.NewGatewayGroupService, common.NewWafFormatterService, service.NewGateWayGroupIpService, service.NewRequestService, flexCdn2.NewCdnService, common.NewAllowAndDenyIpService, flexCdn2.NewProxyService, flexCdn2.NewSslCertService, flexCdn2.NewWebsocketService, waf2.NewCcService, service.NewLogService, common.NewGatewayipService, waf2.NewCcIpListService, waf2.NewCdnLogService, waf2.NewBuildAudunService, waf2.NewZzybgpService, waf2.NewWaflogService, admin2.NewLogService, admin2.NewWafLogService, admin2.NewWafLogDataCleanService, admin2.NewWafManageService, admin2.NewWafOperationsService, service.NewShardingService)
 

+ 1 - 0
cmd/task/wire/wire.go

@@ -37,6 +37,7 @@ var repositorySet = wire.NewSet(
 	repository.NewRabbitMQ,
 	repository.NewRepository,
 	repository.NewTransaction,
+	repository.NewShardingManager,
 	admin.NewUserRepository,
 	repository.NewGameShieldRepository,
 	repository.NewGameShieldBackendRepository,

+ 3 - 2
cmd/task/wire/wire_gen.go

@@ -40,7 +40,8 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	database := repository.NewMongoDB(qmgoClient, viperViper)
 	rabbitMQ, cleanup := repository.NewRabbitMQ(viperViper, logger)
 	syncedEnforcer := repository.NewCasbinEnforcer(viperViper, logger, db)
-	repositoryRepository := repository.NewRepository(logger, db, client, qmgoClient, database, rabbitMQ, syncedEnforcer)
+	shardingManager := repository.NewShardingManager(logger)
+	repositoryRepository := repository.NewRepository(logger, db, client, qmgoClient, database, rabbitMQ, syncedEnforcer, shardingManager)
 	transaction := repository.NewTransaction(repositoryRepository)
 	sidSid := sid.NewSid()
 	taskTask := task.NewTask(transaction, logger, sidSid)
@@ -115,7 +116,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 
 // wire.go:
 
-var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewMongoClient, repository.NewCasbinEnforcer, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, admin.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldPublicIpRepository, repository.NewHostRepository, repository.NewGameShieldUserIpRepository, repository.NewGameShieldSdkIpRepository, waf.NewWebForwardingRepository, waf.NewTcpforwardingRepository, waf.NewUdpForWardingRepository, waf.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository, flexCdn.NewCdnRepository, repository.NewExpiredRepository, flexCdn.NewProxyRepository, waf.NewGatewayipRepository, repository.NewLogRepository, flexCdn.NewCcRepository, flexCdn.NewCcIpListRepository, admin.NewWafLogRepository)
+var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewMongoClient, repository.NewCasbinEnforcer, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewShardingManager, admin.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldPublicIpRepository, repository.NewHostRepository, repository.NewGameShieldUserIpRepository, repository.NewGameShieldSdkIpRepository, waf.NewWebForwardingRepository, waf.NewTcpforwardingRepository, waf.NewUdpForWardingRepository, waf.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository, flexCdn.NewCdnRepository, repository.NewExpiredRepository, flexCdn.NewProxyRepository, waf.NewGatewayipRepository, repository.NewLogRepository, flexCdn.NewCcRepository, flexCdn.NewCcIpListRepository, admin.NewWafLogRepository)
 
 var taskSet = wire.NewSet(task.NewTask, task.NewUserTask, task.NewGameShieldTask, task.NewWafTask)
 

+ 3 - 1
go.mod

@@ -21,7 +21,7 @@ require (
 	github.com/golang-jwt/jwt/v5 v5.2.2
 	github.com/golang/mock v1.6.0
 	github.com/google/uuid v1.6.0
-	github.com/google/wire v0.5.0
+	github.com/google/wire v0.6.0
 	github.com/hashicorp/go-multierror v1.0.0
 	github.com/jinzhu/copier v0.4.0
 	github.com/mcuadros/go-defaults v1.2.0
@@ -85,6 +85,7 @@ require (
 	github.com/golang/protobuf v1.5.4 // indirect
 	github.com/golang/snappy v1.0.0 // indirect
 	github.com/google/go-querystring v1.1.0 // indirect
+	github.com/google/subcommands v1.2.0 // indirect
 	github.com/gorilla/websocket v1.4.2 // indirect
 	github.com/hashicorp/errwrap v1.0.0 // indirect
 	github.com/hashicorp/hcl v1.0.0 // indirect
@@ -147,6 +148,7 @@ require (
 	go.uber.org/multierr v1.11.0 // indirect
 	golang.org/x/arch v0.3.0 // indirect
 	golang.org/x/exp v0.0.0-20221208152030-732eee02a75a // indirect
+	golang.org/x/mod v0.25.0 // indirect
 	golang.org/x/sys v0.34.0 // indirect
 	golang.org/x/text v0.27.0 // indirect
 	golang.org/x/tools v0.34.0 // indirect

+ 10 - 0
go.sum

@@ -281,6 +281,8 @@ github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbu
 github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
 github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
 github.com/google/subcommands v1.0.1/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
+github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
+github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
 github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -288,6 +290,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
 github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/google/wire v0.5.0 h1:I7ELFeVBr3yfPIcc8+MWvrjk+3VjbcSzoXm3JVa+jD8=
 github.com/google/wire v0.5.0/go.mod h1:ngWDr9Qvq3yZA10YrxfyGELY/AFWGVpy9c1LTRi1EoU=
+github.com/google/wire v0.6.0 h1:HBkoIh4BdSxoyo9PveV8giw7ZsaBOvzWKfcg/6MrVwI=
+github.com/google/wire v0.6.0/go.mod h1:F4QhpQ9EDIdJ1Mbop/NZBRB+5yrR6qg3BnctaoUk6NA=
 github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
 github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
 github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
@@ -625,6 +629,7 @@ golang.org/x/crypto v0.0.0-20220511200225-c6db032c6c88/go.mod h1:IxCIyHEi3zRg3s0
 golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
 golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
 golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
+golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
 golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
 golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
 golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
@@ -674,6 +679,7 @@ golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
 golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
 golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
+golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
 golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
 golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
 golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
@@ -724,6 +730,7 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
 golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
 golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
 golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
+golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
 golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
 golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
 golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
@@ -819,6 +826,7 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
@@ -832,6 +840,7 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX
 golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
 golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
 golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
+golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
 golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
 golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=
 golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
@@ -918,6 +927,7 @@ golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
 golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
 golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
 golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
+golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps=
 golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
 golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
 golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=

+ 4 - 11
internal/repository/admin/log.go

@@ -11,7 +11,6 @@ import (
 	admin "github.com/go-nunu/nunu-layout-advanced/api/v1/admin"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
-	"github.com/go-nunu/nunu-layout-advanced/pkg/sharding"
 	"gorm.io/gorm"
 )
 
@@ -35,8 +34,8 @@ type logRepository struct {
 func (r *logRepository) GetLog(ctx context.Context, id int64) (*model.Log, error) {
 	var res model.Log
 	
-	// 获取分表管理器
-	shardingMgr := r.getShardingManager()
+	// 使用依赖注入的分表管理器
+	shardingMgr := r.Manager
 	
 	// 获取存在的分表
 	existingTables := shardingMgr.GetTableNamesWithExistenceCheck(r.DBWithName(ctx, "admin"), "log", nil, nil)
@@ -54,8 +53,8 @@ func (r *logRepository) GetLog(ctx context.Context, id int64) (*model.Log, error
 }
 
 func (r *logRepository) GetLogList(ctx context.Context, req admin.SearchLogParams) (*v1.PaginatedResponse[model.Log], error) {
-	// 获取分表管理器
-	shardingMgr := r.getShardingManager()
+	// 使用依赖注入的分表管理器
+	shardingMgr := r.Manager
 	
 	// 解析时间范围(如果有的话)
 	var startTime, endTime *time.Time
@@ -251,9 +250,3 @@ func (r *logRepository) applyFilters(query *gorm.DB, req admin.SearchLogParams)
 	return query
 }
 
-// getShardingManager 获取分表管理器
-func (r *logRepository) getShardingManager() *sharding.ShardingManager {
-	// 使用月度分表策略
-	strategy := sharding.NewMonthlyShardingStrategy()
-	return sharding.NewShardingManager(strategy, r.Logger)
-}

+ 10 - 39
internal/repository/admin/waflog.go

@@ -11,7 +11,6 @@ import (
 	adminApi "github.com/go-nunu/nunu-layout-advanced/api/v1/admin"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
-	"github.com/go-nunu/nunu-layout-advanced/pkg/sharding"
 	"gorm.io/gorm"
 )
 
@@ -82,12 +81,9 @@ func (r *wafLogRepository) buildExportQuery(ctx context.Context, req adminApi.Ex
 
 func (r *wafLogRepository) GetWafLog(ctx context.Context, id int64) (*model.WafLog, error) {
 	var res model.WafLog
-	
-	// 获取分表管理器
-	shardingMgr := r.getShardingManager()
-	
+
 	// 获取存在的分表
-	existingTables := shardingMgr.GetTableNamesWithExistenceCheck(r.DBWithName(ctx, "admin"), "waf_log", nil, nil)
+	existingTables := r.Manager.GetTableNamesWithExistenceCheck(r.DBWithName(ctx, "admin"), "waf_log", nil, nil)
 	
 	// 在各个分表中查找
 	for _, tableName := range existingTables {
@@ -102,8 +98,7 @@ func (r *wafLogRepository) GetWafLog(ctx context.Context, id int64) (*model.WafL
 }
 
 func (r *wafLogRepository) GetWafLogList(ctx context.Context, req adminApi.SearchWafLogParams) (*v1.PaginatedResponse[model.WafLog], error) {
-	// 获取分表管理器
-	shardingMgr := r.getShardingManager()
+
 	
 	// 解析时间范围(如果有的话)
 	var startTime, endTime *time.Time
@@ -111,7 +106,7 @@ func (r *wafLogRepository) GetWafLogList(ctx context.Context, req adminApi.Searc
 	// 暂时查询最近3个月的数据
 	
 	// 获取需要查询的表
-	existingTables := shardingMgr.GetTableNamesWithExistenceCheck(r.DBWithName(ctx, "admin"), "waf_log", startTime, endTime)
+	existingTables := r.Manager.GetTableNamesWithExistenceCheck(r.DBWithName(ctx, "admin"), "waf_log", startTime, endTime)
 	
 	if len(existingTables) == 0 {
 		// 没有分表,返回空结果
@@ -275,12 +270,10 @@ func (r *wafLogRepository) AddWafLog(ctx context.Context, log *model.WafLog) err
 	if log.CreatedAt.IsZero() {
 		log.CreatedAt = time.Now()
 	}
-	
-	// 获取分表管理器
-	shardingMgr := r.getShardingManagerWithThreshold()
+
 	
 	// 获取最优的写入表(考虑数据量阈值)
-	tableName, err := shardingMgr.GetOptimalWriteTable(ctx, r.DBWithName(ctx, "admin"), log, r.getMaxRowsForTable("waf_log"))
+	tableName, err := r.Manager.GetOptimalWriteTable(ctx, r.DBWithName(ctx, "admin"), log, r.getMaxRowsForTable("waf_log"))
 	if err != nil {
 		return fmt.Errorf("获取写入表失败: %v", err)
 	}
@@ -288,7 +281,7 @@ func (r *wafLogRepository) AddWafLog(ctx context.Context, log *model.WafLog) err
 	log.SetTableName(tableName)
 	
 	// 确保表存在
-	err = shardingMgr.EnsureTableExists(ctx, r.DBWithName(ctx, "admin"), tableName, &model.WafLog{})
+	err = r.Manager.EnsureTableExists(ctx, r.DBWithName(ctx, "admin"), tableName, &model.WafLog{})
 	if err != nil {
 		return err
 	}
@@ -301,9 +294,7 @@ func (r *wafLogRepository) BatchAddWafLog(ctx context.Context, logs []*model.Waf
 	if len(logs) == 0 {
 		return nil
 	}
-	
-	// 获取带阈值的分表管理器
-	shardingMgr := r.getShardingManagerWithThreshold()
+
 	maxRows := r.getMaxRowsForTable("waf_log")
 	
 	// 按表名分组
@@ -316,7 +307,7 @@ func (r *wafLogRepository) BatchAddWafLog(ctx context.Context, logs []*model.Waf
 		}
 		
 		// 获取最优的写入表(考虑数据量阈值)
-		tableName, err := shardingMgr.GetOptimalWriteTable(ctx, r.DBWithName(ctx, "admin"), log, maxRows)
+		tableName, err := r.Manager.GetOptimalWriteTable(ctx, r.DBWithName(ctx, "admin"), log, maxRows)
 		if err != nil {
 			return fmt.Errorf("获取写入表失败: %v", err)
 		}
@@ -330,7 +321,7 @@ func (r *wafLogRepository) BatchAddWafLog(ctx context.Context, logs []*model.Waf
 	// 为每个表批量插入
 	for tableName, tableLogs := range tableGroups {
 		// 确保表存在
-		err := shardingMgr.EnsureTableExists(ctx, r.DBWithName(ctx, "admin"), tableName, &model.WafLog{})
+		err := r.Manager.EnsureTableExists(ctx, r.DBWithName(ctx, "admin"), tableName, &model.WafLog{})
 		if err != nil {
 			return err
 		}
@@ -399,26 +390,6 @@ func (r *wafLogRepository) GetWafLogExportCount(ctx context.Context, req adminAp
 	return int(count), nil
 }
 
-// getShardingManager 获取分表管理器
-func (r *wafLogRepository) getShardingManager() *sharding.ShardingManager {
-	// 使用月度分表策略
-	strategy := sharding.NewMonthlyShardingStrategy()
-	return sharding.NewShardingManager(strategy, r.Logger)
-}
-
-// getShardingManagerWithThreshold 获取带阈值配置的分表管理器
-func (r *wafLogRepository) getShardingManagerWithThreshold() *sharding.ShardingManager {
-	strategy := sharding.NewMonthlyShardingStrategy()
-	
-	// 阈值配置(这里可以从配置文件读取,暂时硬编码)
-	thresholdConfig := &sharding.ThresholdConfig{
-		Enabled:       true,
-		MaxRows:       5000000, // waf_log表默认500万条
-		CheckInterval: time.Hour,
-	}
-	
-	return sharding.NewShardingManagerWithThreshold(strategy, r.Logger, thresholdConfig)
-}
 
 // getMaxRowsForTable 获取指定表的最大行数配置
 func (r *wafLogRepository) getMaxRowsForTable(tableName string) int64 {

+ 10 - 38
internal/repository/log.go

@@ -6,7 +6,6 @@ import (
 	"time"
 
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
-	"github.com/go-nunu/nunu-layout-advanced/pkg/sharding"
 )
 
 type LogRepository interface {
@@ -30,12 +29,10 @@ type logRepository struct {
 
 func (r *logRepository) GetLog(ctx context.Context, id int64) (*model.Log, error) {
 	var log model.Log
-	
-	// 获取分表管理器
-	shardingMgr := r.getShardingManager()
+
 	
 	// 获取可能的表名(查询最近3个月)
-	existingTables := shardingMgr.GetTableNamesWithExistenceCheck(r.DBWithName(ctx, "admin"), "log", nil, nil)
+	existingTables := r.Manager.GetTableNamesWithExistenceCheck(r.DBWithName(ctx, "admin"), "log", nil, nil)
 	
 	// 在各个分表中查找
 	for _, tableName := range existingTables {
@@ -52,11 +49,10 @@ func (r *logRepository) GetLog(ctx context.Context, id int64) (*model.Log, error
 func (r *logRepository) GetLogsByTimeRange(ctx context.Context, start, end *time.Time) ([]*model.Log, error) {
 	var logs []*model.Log
 	
-	// 获取分表管理器
-	shardingMgr := r.getShardingManager()
+
 	
 	// 检查存在的表
-	existingTables := shardingMgr.GetTableNamesWithExistenceCheck(r.DBWithName(ctx, "admin"), "log", start, end)
+	existingTables := r.Manager.GetTableNamesWithExistenceCheck(r.DBWithName(ctx, "admin"), "log", start, end)
 	
 	if len(existingTables) == 0 {
 		return logs, nil // 没有分表,返回空结果
@@ -96,12 +92,10 @@ func (r *logRepository) AddLog(ctx context.Context, log *model.Log) error {
 	if log.CreatedAt.IsZero() {
 		log.CreatedAt = time.Now()
 	}
-	
-	// 获取分表管理器
-	shardingMgr := r.getShardingManagerWithThreshold()
+
 	
 	// 获取最优的写入表(考虑数据量阈值)
-	tableName, err := shardingMgr.GetOptimalWriteTable(ctx, r.DBWithName(ctx, "admin"), log, r.getMaxRowsForTable("log"))
+	tableName, err := r.Manager.GetOptimalWriteTable(ctx, r.DBWithName(ctx, "admin"), log, r.getMaxRowsForTable("log"))
 	if err != nil {
 		return fmt.Errorf("获取写入表失败: %v", err)
 	}
@@ -109,7 +103,7 @@ func (r *logRepository) AddLog(ctx context.Context, log *model.Log) error {
 	log.SetTableName(tableName)
 	
 	// 确保表存在
-	err = shardingMgr.EnsureTableExists(ctx, r.DBWithName(ctx, "admin"), tableName, &model.Log{})
+	err = r.Manager.EnsureTableExists(ctx, r.DBWithName(ctx, "admin"), tableName, &model.Log{})
 	if err != nil {
 		return err
 	}
@@ -123,16 +117,14 @@ func (r *logRepository) EditLog(ctx context.Context, log *model.Log) error {
 	if log.TableName() != "log" {
 		return r.DBWithName(ctx, "admin").Table(log.TableName()).Updates(log).Error
 	}
-	
-	// 获取分表管理器
-	shardingMgr := r.getShardingManager()
+
 	
 	// 确定表名
-	tableName := shardingMgr.GetWriteTableName(log)
+	tableName := r.Manager.GetWriteTableName(log)
 	log.SetTableName(tableName)
 	
 	// 确保表存在
-	err := shardingMgr.EnsureTableExists(ctx, r.DBWithName(ctx, "admin"), tableName, &model.Log{})
+	err := r.Manager.EnsureTableExists(ctx, r.DBWithName(ctx, "admin"), tableName, &model.Log{})
 	if err != nil {
 		return err
 	}
@@ -140,26 +132,6 @@ func (r *logRepository) EditLog(ctx context.Context, log *model.Log) error {
 	return r.DBWithName(ctx, "admin").Table(tableName).Updates(log).Error
 }
 
-// getShardingManager 获取分表管理器
-func (r *logRepository) getShardingManager() *sharding.ShardingManager {
-	// 使用月度分表策略
-	strategy := sharding.NewMonthlyShardingStrategy()
-	return sharding.NewShardingManager(strategy, r.Logger)
-}
-
-// getShardingManagerWithThreshold 获取带阈值配置的分表管理器
-func (r *logRepository) getShardingManagerWithThreshold() *sharding.ShardingManager {
-	strategy := sharding.NewMonthlyShardingStrategy()
-	
-	// 阈值配置(这里可以从配置文件读取,暂时硬编码)
-	thresholdConfig := &sharding.ThresholdConfig{
-		Enabled:       true,
-		MaxRows:       3000000, // log表默认300万条
-		CheckInterval: time.Hour,
-	}
-	
-	return sharding.NewShardingManagerWithThreshold(strategy, r.Logger, thresholdConfig)
-}
 
 // getMaxRowsForTable 获取指定表的最大行数配置
 func (r *logRepository) getMaxRowsForTable(tableName string) int64 {

+ 10 - 0
internal/repository/repository.go

@@ -9,6 +9,7 @@ import (
 	"github.com/glebarez/sqlite"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/log"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/rabbitmq"
+	"github.com/go-nunu/nunu-layout-advanced/pkg/sharding"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/zapgorm2"
 	"github.com/qiniu/qmgo"
 	"github.com/redis/go-redis/v9"
@@ -31,6 +32,7 @@ type Repository struct {
 	mq     *rabbitmq.RabbitMQ
 	Logger *log.Logger
 	E      *casbin.SyncedEnforcer
+	Manager *sharding.ShardingManager
 }
 
 func NewRepository(
@@ -41,6 +43,7 @@ func NewRepository(
 	mongoDB *qmgo.Database,
 	mq *rabbitmq.RabbitMQ,
 	e *casbin.SyncedEnforcer,
+	manager *sharding.ShardingManager,
 ) *Repository {
 	return &Repository{
 		Db:          db,
@@ -50,6 +53,7 @@ func NewRepository(
 		mq:          mq,
 		Logger:      logger,
 		E:           e,
+		Manager:     manager,
 	}
 }
 
@@ -407,4 +411,10 @@ m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act
 	e.EnableAutoSave(true)
 
 	return e
+}
+
+// new creates a ShardingManager for dependency injection
+func NewShardingManager(logger *log.Logger) *sharding.ShardingManager {
+	strategy := sharding.NewMonthlyShardingStrategy()
+	return sharding.NewShardingManager(strategy, logger)
 }