瀏覽代碼

feat(service): 增加 CDN 用户管理功能

- 新增 CDN 用户管理相关的接口和实现
- 增加用户信息获取、转换时间和续费计划等辅助函数
- 更新全局限制相关服务,支持 CDN 用户和套餐绑定
- 优化配置文件,增加 CDN 数据库配置
fusu 1 月之前
父節點
當前提交
19456a2e70

+ 1 - 0
api/v1/globalLimit.go

@@ -24,6 +24,7 @@ type GlobalLimitSendRequest struct {
 type GlobalLimitRequireResponse struct {
 type GlobalLimitRequireResponse struct {
 	ExpiredAt       string
 	ExpiredAt       string
 	GlobalLimitName string
 	GlobalLimitName string
+	HostName        string
 	Bps             string
 	Bps             string
 	MaxBytesMonth   string
 	MaxBytesMonth   string
 	IpCount   int
 	IpCount   int

+ 2 - 0
cmd/server/wire/wire.go

@@ -45,6 +45,7 @@ var repositorySet = wire.NewSet(
 	repository.NewGlobalLimitRepository,
 	repository.NewGlobalLimitRepository,
 	repository.NewGatewayGroupRepository,
 	repository.NewGatewayGroupRepository,
 	repository.NewGateWayGroupIpRepository,
 	repository.NewGateWayGroupIpRepository,
+	repository.NewCdnRepository,
 
 
 )
 )
 
 
@@ -76,6 +77,7 @@ var serviceSet = wire.NewSet(
 	service.NewWafFormatterService,
 	service.NewWafFormatterService,
 	service.NewGateWayGroupIpService,
 	service.NewGateWayGroupIpService,
 	service.NewRequestService,
 	service.NewRequestService,
+	service.NewCdnService,
 )
 )
 
 
 var handlerSet = wire.NewSet(
 var handlerSet = wire.NewSet(

+ 9 - 6
cmd/server/wire/wire_gen.go

@@ -31,10 +31,11 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	limiterLimiter := limiter.NewLimiter(viperViper)
 	limiterLimiter := limiter.NewLimiter(viperViper)
 	handlerFunc := middleware.NewRateLimitMiddleware(limiterLimiter)
 	handlerFunc := middleware.NewRateLimitMiddleware(limiterLimiter)
 	handlerHandler := handler.NewHandler(logger)
 	handlerHandler := handler.NewHandler(logger)
-	client := repository.NewMongoClient(viperViper)
-	database := repository.NewMongoDB(client, viperViper)
+	client := repository.NewRedis(viperViper)
+	qmgoClient := repository.NewMongoClient(viperViper)
+	database := repository.NewMongoDB(qmgoClient, viperViper)
 	rabbitMQ, cleanup := repository.NewRabbitMQ(viperViper, logger)
 	rabbitMQ, cleanup := repository.NewRabbitMQ(viperViper, logger)
-	repositoryRepository := repository.NewRepository(logger, db, client, database, rabbitMQ, syncedEnforcer)
+	repositoryRepository := repository.NewRepository(logger, db, client, qmgoClient, database, rabbitMQ, syncedEnforcer)
 	transaction := repository.NewTransaction(repositoryRepository)
 	transaction := repository.NewTransaction(repositoryRepository)
 	sidSid := sid.NewSid()
 	sidSid := sid.NewSid()
 	serviceService := service.NewService(transaction, logger, sidSid, jwtJWT)
 	serviceService := service.NewService(transaction, logger, sidSid, jwtJWT)
@@ -84,7 +85,9 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	udpLimitHandler := handler.NewUdpLimitHandler(handlerHandler, udpLimitService)
 	udpLimitHandler := handler.NewUdpLimitHandler(handlerHandler, udpLimitService)
 	requestService := service.NewRequestService(serviceService)
 	requestService := service.NewRequestService(serviceService)
 	gatewayGroupService := service.NewGatewayGroupService(serviceService, gatewayGroupRepository, requiredService, parserService, requestService)
 	gatewayGroupService := service.NewGatewayGroupService(serviceService, gatewayGroupRepository, requiredService, parserService, requestService)
-	globalLimitService := service.NewGlobalLimitService(serviceService, globalLimitRepository, duedateService, crawlerService, viperViper, requiredService, parserService, hostService, tcpLimitService, udpLimitService, webLimitService, gatewayGroupService, hostRepository, gatewayGroupRepository)
+	cdnRepository := repository.NewCdnRepository(repositoryRepository)
+	cdnService := service.NewCdnService(serviceService, viperViper, requestService, cdnRepository)
+	globalLimitService := service.NewGlobalLimitService(serviceService, globalLimitRepository, duedateService, crawlerService, viperViper, requiredService, parserService, hostService, tcpLimitService, udpLimitService, webLimitService, gatewayGroupService, hostRepository, gatewayGroupRepository, cdnService, cdnRepository)
 	globalLimitHandler := handler.NewGlobalLimitHandler(handlerHandler, globalLimitService)
 	globalLimitHandler := handler.NewGlobalLimitHandler(handlerHandler, globalLimitService)
 	adminRepository := repository.NewAdminRepository(repositoryRepository)
 	adminRepository := repository.NewAdminRepository(repositoryRepository)
 	adminService := service.NewAdminService(serviceService, adminRepository)
 	adminService := service.NewAdminService(serviceService, adminRepository)
@@ -101,9 +104,9 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 
 
 // wire.go:
 // wire.go:
 
 
-var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewCasbinEnforcer, repository.NewMongoClient, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewAdminRepository, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldPublicIpRepository, repository.NewWebForwardingRepository, repository.NewTcpforwardingRepository, repository.NewUdpForWardingRepository, repository.NewGameShieldUserIpRepository, repository.NewWebLimitRepository, repository.NewTcpLimitRepository, repository.NewUdpLimitRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldSdkIpRepository, repository.NewHostRepository, repository.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository)
+var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewCasbinEnforcer, repository.NewMongoClient, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewAdminRepository, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldPublicIpRepository, repository.NewWebForwardingRepository, repository.NewTcpforwardingRepository, repository.NewUdpForWardingRepository, repository.NewGameShieldUserIpRepository, repository.NewWebLimitRepository, repository.NewTcpLimitRepository, repository.NewUdpLimitRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldSdkIpRepository, repository.NewHostRepository, repository.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository, repository.NewCdnRepository)
 
 
-var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, service.NewUserService, service.NewAdminService, service.NewGameShieldService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewCrawlerService, service.NewWebForwardingService, service.NewTcpforwardingService, service.NewUdpForWardingService, service.NewGameShieldUserIpService, service.NewWebLimitService, service.NewTcpLimitService, service.NewUdpLimitService, service.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewHostService, service.NewGlobalLimitService, service.NewGatewayGroupService, service.NewWafFormatterService, service.NewGateWayGroupIpService, service.NewRequestService)
+var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, service.NewUserService, service.NewAdminService, service.NewGameShieldService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewCrawlerService, service.NewWebForwardingService, service.NewTcpforwardingService, service.NewUdpForWardingService, service.NewGameShieldUserIpService, service.NewWebLimitService, service.NewTcpLimitService, service.NewUdpLimitService, service.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewHostService, service.NewGlobalLimitService, service.NewGatewayGroupService, service.NewWafFormatterService, service.NewGateWayGroupIpService, service.NewRequestService, service.NewCdnService)
 
 
 var handlerSet = wire.NewSet(handler.NewHandler, handler.NewUserHandler, handler.NewAdminHandler, handler.NewGameShieldHandler, handler.NewGameShieldPublicIpHandler, handler.NewWebForwardingHandler, handler.NewTcpforwardingHandler, handler.NewUdpForWardingHandler, handler.NewGameShieldUserIpHandler, handler.NewWebLimitHandler, handler.NewTcpLimitHandler, handler.NewUdpLimitHandler, handler.NewGameShieldBackendHandler, handler.NewGameShieldSdkIpHandler, handler.NewHostHandler, handler.NewGlobalLimitHandler, handler.NewGatewayGroupHandler, handler.NewGateWayGroupIpHandler)
 var handlerSet = wire.NewSet(handler.NewHandler, handler.NewUserHandler, handler.NewAdminHandler, handler.NewGameShieldHandler, handler.NewGameShieldPublicIpHandler, handler.NewWebForwardingHandler, handler.NewTcpforwardingHandler, handler.NewUdpForWardingHandler, handler.NewGameShieldUserIpHandler, handler.NewWebLimitHandler, handler.NewTcpLimitHandler, handler.NewUdpLimitHandler, handler.NewGameShieldBackendHandler, handler.NewGameShieldSdkIpHandler, handler.NewHostHandler, handler.NewGlobalLimitHandler, handler.NewGatewayGroupHandler, handler.NewGateWayGroupIpHandler)
 
 

+ 1 - 1
cmd/task/wire/wire.go

@@ -19,7 +19,7 @@ import (
 
 
 var repositorySet = wire.NewSet(
 var repositorySet = wire.NewSet(
 	repository.NewDB,
 	repository.NewDB,
-	//repository.NewRedis,
+	repository.NewRedis,
 	repository.NewMongoClient,
 	repository.NewMongoClient,
 	repository.NewCasbinEnforcer,
 	repository.NewCasbinEnforcer,
 	repository.NewMongoDB,
 	repository.NewMongoDB,

+ 5 - 4
cmd/task/wire/wire_gen.go

@@ -24,11 +24,12 @@ import (
 
 
 func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), error) {
 func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), error) {
 	db := repository.NewDB(viperViper, logger)
 	db := repository.NewDB(viperViper, logger)
-	client := repository.NewMongoClient(viperViper)
-	database := repository.NewMongoDB(client, viperViper)
+	client := repository.NewRedis(viperViper)
+	qmgoClient := repository.NewMongoClient(viperViper)
+	database := repository.NewMongoDB(qmgoClient, viperViper)
 	rabbitMQ, cleanup := repository.NewRabbitMQ(viperViper, logger)
 	rabbitMQ, cleanup := repository.NewRabbitMQ(viperViper, logger)
 	syncedEnforcer := repository.NewCasbinEnforcer(viperViper, logger, db)
 	syncedEnforcer := repository.NewCasbinEnforcer(viperViper, logger, db)
-	repositoryRepository := repository.NewRepository(logger, db, client, database, rabbitMQ, syncedEnforcer)
+	repositoryRepository := repository.NewRepository(logger, db, client, qmgoClient, database, rabbitMQ, syncedEnforcer)
 	transaction := repository.NewTransaction(repositoryRepository)
 	transaction := repository.NewTransaction(repositoryRepository)
 	sidSid := sid.NewSid()
 	sidSid := sid.NewSid()
 	taskTask := task.NewTask(transaction, logger, sidSid)
 	taskTask := task.NewTask(transaction, logger, sidSid)
@@ -74,7 +75,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 
 
 // wire.go:
 // wire.go:
 
 
-var repositorySet = wire.NewSet(repository.NewDB, repository.NewMongoClient, repository.NewCasbinEnforcer, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldPublicIpRepository, repository.NewHostRepository, repository.NewGameShieldUserIpRepository, repository.NewGameShieldSdkIpRepository, repository.NewWebForwardingRepository, repository.NewTcpforwardingRepository, repository.NewUdpForWardingRepository, repository.NewWebLimitRepository, repository.NewTcpLimitRepository, repository.NewUdpLimitRepository, repository.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository)
+var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewMongoClient, repository.NewCasbinEnforcer, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldPublicIpRepository, repository.NewHostRepository, repository.NewGameShieldUserIpRepository, repository.NewGameShieldSdkIpRepository, repository.NewWebForwardingRepository, repository.NewTcpforwardingRepository, repository.NewUdpForWardingRepository, repository.NewWebLimitRepository, repository.NewTcpLimitRepository, repository.NewUdpLimitRepository, repository.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository)
 
 
 var taskSet = wire.NewSet(task.NewTask, task.NewUserTask, task.NewGameShieldTask)
 var taskSet = wire.NewSet(task.NewTask, task.NewUserTask, task.NewGameShieldTask)
 
 

+ 4 - 0
config/local.yml

@@ -24,6 +24,10 @@ data:
       dsn: admin:GhCbHDRDnMbAkZHw@tcp(110.42.96.15:3306)/admin?charset=utf8mb4&parseTime=True&loc=Local
       dsn: admin:GhCbHDRDnMbAkZHw@tcp(110.42.96.15:3306)/admin?charset=utf8mb4&parseTime=True&loc=Local
       logLevel: "info"
       logLevel: "info"
       casbin: true
       casbin: true
+    cdn:
+      driver: mysql
+      dsn: root:671119d76d73b5c9d4182d71e8e91eaa@tcp(110.42.96.120:3306)/clouds?charset=utf8mb4&parseTime=True&loc=Local
+      logLevel: "info"
   #    user:
   #    user:
   #      driver: postgres
   #      driver: postgres
   #      dsn: host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai
   #      dsn: host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai

+ 4 - 0
config/prod.yml

@@ -20,6 +20,10 @@ data:
       dsn: admin:GhCbHDRDnMbAkZHw@tcp(110.42.96.15:3306)/admin?charset=utf8mb4&parseTime=True&loc=Local
       dsn: admin:GhCbHDRDnMbAkZHw@tcp(110.42.96.15:3306)/admin?charset=utf8mb4&parseTime=True&loc=Local
       logLevel: "warn"
       logLevel: "warn"
       casbin: true
       casbin: true
+    cdn:
+      driver: mysql
+      dsn: root:671119d76d73b5c9d4182d71e8e91eaa@tcp(110.42.96.120:3306)/clouds?charset=utf8mb4&parseTime=True&loc=Local
+      logLevel: "warn"
 #    second:
 #    second:
 #      driver: mysql
 #      driver: mysql
 #      dsn: root:Mgrj9hMF3QQ3atX5hFIo@tcp(115.238.186.121:3306)/0panel?charset=utf8mb4&parseTime=True&loc=Local
 #      dsn: root:Mgrj9hMF3QQ3atX5hFIo@tcp(115.238.186.121:3306)/0panel?charset=utf8mb4&parseTime=True&loc=Local

+ 1 - 0
go.mod

@@ -105,6 +105,7 @@ require (
 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
 	github.com/modern-go/reflect2 v1.0.2 // indirect
 	github.com/modern-go/reflect2 v1.0.2 // indirect
 	github.com/montanaflynn/stats v0.7.1 // indirect
 	github.com/montanaflynn/stats v0.7.1 // indirect
+	github.com/mozillazg/go-pinyin v0.20.0 // indirect
 	github.com/pelletier/go-toml v1.9.5 // indirect
 	github.com/pelletier/go-toml v1.9.5 // indirect
 	github.com/pelletier/go-toml/v2 v2.0.9 // indirect
 	github.com/pelletier/go-toml/v2 v2.0.9 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect

+ 2 - 0
go.sum

@@ -421,6 +421,8 @@ github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3P
 github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
 github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
 github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE=
 github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE=
 github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
 github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
+github.com/mozillazg/go-pinyin v0.20.0 h1:BtR3DsxpApHfKReaPO1fCqF4pThRwH9uwvXzm+GnMFQ=
+github.com/mozillazg/go-pinyin v0.20.0/go.mod h1:iR4EnMMRXkfpFVV5FMi4FNB6wGq9NV6uDWbUuPhP4Yc=
 github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
 github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
 github.com/onsi/ginkgo v1.10.1 h1:q/mM8GF/n0shIN8SaAZ0V+jnLPzen6WIVZdiwrRlMlo=
 github.com/onsi/ginkgo v1.10.1 h1:q/mM8GF/n0shIN8SaAZ0V+jnLPzen6WIVZdiwrRlMlo=
 github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
 github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=

+ 0 - 1
internal/model/gatewaygroup.go

@@ -5,7 +5,6 @@ import "time"
 type GatewayGroup struct {
 type GatewayGroup struct {
 	Id          int `gorm:"primary" json:"id" form:"id"`
 	Id          int `gorm:"primary" json:"id" form:"id"`
 	HostId      int `gorm:"null" json:"hostId" form:"hostId"`
 	HostId      int `gorm:"null" json:"hostId" form:"hostId"`
-	RuleId      int `gorm:"not null" json:"ruleId" form:"ruleId"`
 	BanUdp       int `gorm:"null" json:"banUdp" form:"banUdp"`
 	BanUdp       int `gorm:"null" json:"banUdp" form:"banUdp"`
 	BanOverseas int `gorm:"null" json:"banOverseas" form:"banOverseas"`
 	BanOverseas int `gorm:"null" json:"banOverseas" form:"banOverseas"`
 	Name        string `gorm:"null" json:"name" form:"name"`
 	Name        string `gorm:"null" json:"name" form:"name"`

+ 2 - 0
internal/model/globallimit.go

@@ -5,7 +5,9 @@ import "time"
 type GlobalLimit struct {
 type GlobalLimit struct {
 	Id              int `gorm:"primary"`
 	Id              int `gorm:"primary"`
 	HostId          int
 	HostId          int
+	Name            string
 	RuleId          int
 	RuleId          int
+	GroupId         int
 	Uid             int
 	Uid             int
 	CdnUid          int
 	CdnUid          int
 	GatewayGroupId  int
 	GatewayGroupId  int

+ 13 - 0
internal/repository/cdn.go

@@ -11,6 +11,7 @@ type CdnRepository interface {
 	GetCdn(ctx context.Context, id int64) (*model.Cdn, error)
 	GetCdn(ctx context.Context, id int64) (*model.Cdn, error)
 	PutToken(ctx context.Context, token string) error
 	PutToken(ctx context.Context, token string) error
 	GetToken(ctx context.Context) (string, error)
 	GetToken(ctx context.Context) (string, error)
+	GetUserId(ctx context.Context, username string) (int64, error)
 }
 }
 
 
 func NewCdnRepository(
 func NewCdnRepository(
@@ -57,4 +58,16 @@ func (r *cdnRepository) GetToken(ctx context.Context) (string, error) {
 	}
 	}
 
 
 	return token, nil
 	return token, nil
+}
+
+func (r *cdnRepository) GetUserId(ctx context.Context, username string) (int64, error)  {
+	var id int64
+	if err := r.DBWithName(ctx,"cdn").Table("cloud_users").
+		Where("username = ?", username).
+		Select("id").
+		Find(&id).Error; err != nil {
+		return 0, err
+	}
+	return id, nil
+
 }
 }

+ 2 - 3
internal/repository/gatewaygroup.go

@@ -49,7 +49,7 @@ func (r *gatewayGroupRepository) AddGatewayGroup(ctx context.Context, req *model
 }
 }
 
 
 func (r *gatewayGroupRepository) EditGatewayGroup(ctx context.Context, req *model.GatewayGroup) error {
 func (r *gatewayGroupRepository) EditGatewayGroup(ctx context.Context, req *model.GatewayGroup) error {
-	if err := r.DB(ctx).Model(&model.GatewayGroup{}).Where("rule_id = ?", req.RuleId).Updates(req).Error; err != nil {
+	if err := r.DB(ctx).Model(&model.GatewayGroup{}).Where("id = ?", req.Id).Updates(req).Error; err != nil {
 		return err
 		return err
 	}
 	}
 	return nil
 	return nil
@@ -73,14 +73,13 @@ func (r *gatewayGroupRepository) GetGatewayGroupWhereHostIdNull(ctx context.Cont
 		Where("operator = ?", operator).
 		Where("operator = ?", operator).
 		Where("id IN (?)", subQuery).
 		Where("id IN (?)", subQuery).
 		Where("host_id = ?", 0).
 		Where("host_id = ?", 0).
-		Select("rule_id").First(&id).Error
+		Select("id").First(&id).Error
 	if err != nil {
 	if err != nil {
 		if errors.Is(err, gorm.ErrRecordNotFound){
 		if errors.Is(err, gorm.ErrRecordNotFound){
 			return 0, fmt.Errorf("库存不足,请联系客服补充网关组库存")
 			return 0, fmt.Errorf("库存不足,请联系客服补充网关组库存")
 		}
 		}
 		return 0, err
 		return 0, err
 	}
 	}
-
 	return id, nil
 	return id, nil
 }
 }
 
 

+ 24 - 1
internal/repository/globallimit.go

@@ -2,8 +2,11 @@ package repository
 
 
 import (
 import (
 	"context"
 	"context"
+	"errors"
+	"fmt"
 	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
 	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
+	"gorm.io/gorm"
 	"time"
 	"time"
 )
 )
 
 
@@ -18,6 +21,7 @@ type GlobalLimitRepository interface {
 	GetGlobalLimitAllHostId(ctx context.Context) ([]v1.GlobalLimitExpired, error)
 	GetGlobalLimitAllHostId(ctx context.Context) ([]v1.GlobalLimitExpired, error)
 	GetGlobalLimitFirst(ctx context.Context,uid int64) (*model.GlobalLimit, error)
 	GetGlobalLimitFirst(ctx context.Context,uid int64) (*model.GlobalLimit, error)
 	GetUserInfo(ctx context.Context, uid int64) (v1.UserInfo, error)
 	GetUserInfo(ctx context.Context, uid int64) (v1.UserInfo, error)
+	GetHostName(ctx context.Context,hostId int64) (string, error)
 }
 }
 
 
 func NewGlobalLimitRepository(
 func NewGlobalLimitRepository(
@@ -116,11 +120,30 @@ func (r *globalLimitRepository) GetGlobalLimitFirst(ctx context.Context,uid int6
 
 
 func (r *globalLimitRepository) GetUserInfo(ctx context.Context, uid int64) (v1.UserInfo, error) {
 func (r *globalLimitRepository) GetUserInfo(ctx context.Context, uid int64) (v1.UserInfo, error) {
 	var res v1.UserInfo
 	var res v1.UserInfo
-	if err := r.DB(ctx).Table("shd_user").
+	if err := r.DB(ctx).Table("shd_clients").
 		Where("id = ?", uid).
 		Where("id = ?", uid).
 		Select("username", "email", "phonenumber").
 		Select("username", "email", "phonenumber").
 		Find(&res).Error; err != nil {
 		Find(&res).Error; err != nil {
 		return v1.UserInfo{}, err
 		return v1.UserInfo{}, err
 	}
 	}
 	return res, nil
 	return res, nil
+}
+
+func (r *globalLimitRepository) GetHostName(ctx context.Context,hostId int64) (string, error)  {
+	var projectName string
+	err := r.db.WithContext(ctx).Table("shd_host").
+		Select("shd_products.name").
+		Joins("JOIN shd_products ON shd_host.productid = shd_products.id").
+		Where("shd_host.id = ?", hostId).
+		Find(&projectName).Error
+
+	if err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			return "", fmt.Errorf("未找到 hostId 为 %d 的项目名称", hostId)
+		}
+		return "", fmt.Errorf("查询 host 和 project 名称失败: %w", err)
+	}
+
+	// 如果查询成功,返回项目名称
+	return projectName, nil
 }
 }

+ 152 - 92
internal/service/cdn.go

@@ -16,45 +16,106 @@ type CdnService interface {
 	BindPlan(ctx context.Context, req v1.Plan) (int64, error)
 	BindPlan(ctx context.Context, req v1.Plan) (int64, error)
 	RenewPlan(ctx context.Context, req v1.RenewalPlan) error
 	RenewPlan(ctx context.Context, req v1.RenewalPlan) error
 	CreateWebsite(ctx context.Context, req v1.Website) (int64, error)
 	CreateWebsite(ctx context.Context, req v1.Website) (int64, error)
-	EditProtocol(ctx context.Context, req v1.ProxyJson,action string) error
+	EditProtocol(ctx context.Context, req v1.ProxyJson, action string) error
 	CreateOrigin(ctx context.Context, req v1.Origin) (int64, error)
 	CreateOrigin(ctx context.Context, req v1.Origin) (int64, error)
 	EditOrigin(ctx context.Context, req v1.Origin) error
 	EditOrigin(ctx context.Context, req v1.Origin) error
 }
 }
+
 func NewCdnService(
 func NewCdnService(
-    service *Service,
+	service *Service,
 	conf *viper.Viper,
 	conf *viper.Viper,
 	request RequestService,
 	request RequestService,
 	cdnRepository repository.CdnRepository,
 	cdnRepository repository.CdnRepository,
 ) CdnService {
 ) CdnService {
 	return &cdnService{
 	return &cdnService{
-		Service:        service,
-		Url:             conf.GetString("flexCdn.Url"),
-		AccessKeyID:     conf.GetString("flexCdn.AccessKeyID"),
-		AccessKeySecret: conf.GetString("flexCdn.AccessKeySecret"),
-		request:         request,
-		cdnRepository:   cdnRepository,
+		Service:         service,
+		Url:              conf.GetString("flexCdn.Url"),
+		AccessKeyID:      conf.GetString("flexCdn.AccessKeyID"),
+		AccessKeySecret:  conf.GetString("flexCdn.AccessKeySecret"),
+		request:          request,
+		cdnRepository:    cdnRepository,
+		maxRetryCount:    3, // 可以配置最大重试次数
+		retryDelaySeconds: 2, // 可以配置重试间隔
 	}
 	}
 }
 }
 
 
 type cdnService struct {
 type cdnService struct {
 	*Service
 	*Service
-	Url             string
-	AccessKeyID     string
-	AccessKeySecret string
-	request         RequestService
-	cdnRepository   repository.CdnRepository
+	Url              string
+	AccessKeyID      string
+	AccessKeySecret  string
+	request          RequestService
+	cdnRepository    repository.CdnRepository
+	maxRetryCount    int
+	retryDelaySeconds int
 }
 }
 
 
-func (s *cdnService) SendData(ctx context.Context,  formData map[string]interface{}, apiUrl string,) ([]byte, error)  {
-	token, err := s.Toekn(ctx)
-	if err != nil {
-		return nil, err
+// SendData 是一个通用的请求发送方法,它封装了 token 过期重试的逻辑
+func (s *cdnService) sendDataWithTokenRetry(ctx context.Context, formData map[string]interface{}, apiUrl string) ([]byte, error) {
+	var resBody []byte
+
+	for i := 0; i < s.maxRetryCount; i++ {
+		token, err := s.Token(ctx) // 确保使用最新的 token
+		if err != nil {
+			return nil, fmt.Errorf("获取或刷新 token 失败: %w", err)
+		}
+
+		resBody, err = s.request.Request(ctx, formData, apiUrl, "X-Cloud-Access-Token", token)
+		if err != nil {
+			// 检查错误是否是由于 token 无效引起的
+			if s.isTokenInvalidError(resBody, err) { // 判断是否是 token 无效错误
+				_, getTokenErr := s.GetToken(ctx)
+				if getTokenErr != nil {
+					return nil, fmt.Errorf("刷新 token 失败: %w", getTokenErr)
+				}
+				continue // 继续下一次循环,使用新的 token
+			}
+			return nil, fmt.Errorf("请求失败: %w", err)
+		}
+
+		// 成功获取到响应,处理响应体
+		var generalResponse v1.GeneralResponse[any]
+		if err := json.Unmarshal(resBody, &generalResponse); err != nil {
+			return nil, fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
+		}
+
+		// 检查 API 返回的 code 和 message
+		if generalResponse.Code == 400 && generalResponse.Message == "invalid access token" {
+			fmt.Printf("尝试 %d/%d:API 返回无效 token 错误,准备刷新并重试...\n", i+1, s.maxRetryCount)
+			_, getTokenErr := s.GetToken(ctx)
+			if getTokenErr != nil {
+				return nil, fmt.Errorf("刷新 token 失败: %w", getTokenErr)
+			}
+			continue // 继续下一次循环,使用新的 token
+		}
+
+		// 成功处理,返回结果
+		return resBody, nil
 	}
 	}
-	resBody, err := s.request.Request(ctx, formData, apiUrl, "X-Cloud-Access-Token", token)
+
+	// 如果循环结束仍未成功,则返回最终错误
+	return nil, fmt.Errorf("达到最大重试次数后请求仍然失败")
+}
+
+// isTokenInvalidError 是一个辅助函数,用于判断错误是否是由于 token 无效引起的。
+// 你需要根据你的 request.Request 实现来具体实现这个函数。
+// 例如,你可以检查 resBody 是否包含特定的错误信息。
+func (s *cdnService) isTokenInvalidError(resBody []byte, err error) bool {
+	// 示例:如果请求本身就返回了非 200 的错误,并且响应体中有特定信息
 	if err != nil {
 	if err != nil {
-		return nil, err
+		// 尝试从 resBody 中解析出错误信息,判断是否是 token 无效
+		var generalResponse v1.GeneralResponse[any]
+		if parseErr := json.Unmarshal(resBody, &generalResponse); parseErr == nil {
+			if generalResponse.Code == 400 && generalResponse.Message == "invalid access token" {
+				return true
+			}
+		}
+		// 或者检查 err 本身是否有相关的错误信息
+		// if strings.Contains(err.Error(), "invalid access token") {
+		// 	return true
+		// }
 	}
 	}
-	return resBody, nil
+	return false
 }
 }
 
 
 func (s *cdnService) GetToken(ctx context.Context) (string, error) {
 func (s *cdnService) GetToken(ctx context.Context) (string, error) {
@@ -64,6 +125,7 @@ func (s *cdnService) GetToken(ctx context.Context) (string, error) {
 		"accessKey":   s.AccessKeySecret,
 		"accessKey":   s.AccessKeySecret,
 	}
 	}
 	apiUrl := s.Url + "APIAccessTokenService/getAPIAccessToken"
 	apiUrl := s.Url + "APIAccessTokenService/getAPIAccessToken"
+
 	resBody, err := s.request.Request(ctx, formData, apiUrl, "X-Cloud-Access-Token", "")
 	resBody, err := s.request.Request(ctx, formData, apiUrl, "X-Cloud-Access-Token", "")
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
@@ -85,7 +147,7 @@ func (s *cdnService) GetToken(ctx context.Context) (string, error) {
 	return res.Data.Token, nil
 	return res.Data.Token, nil
 }
 }
 
 
-func (s *cdnService) Toekn(ctx context.Context) (string, error)  {
+func (s *cdnService) Token(ctx context.Context) (string, error) {
 	token, err := s.cdnRepository.GetToken(ctx)
 	token, err := s.cdnRepository.GetToken(ctx)
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
@@ -104,7 +166,7 @@ func (s *cdnService) AddUser(ctx context.Context, req v1.User) (int64, error) {
 	formData := map[string]interface{}{
 	formData := map[string]interface{}{
 		"id":       req.ID,
 		"id":       req.ID,
 		"username": req.Username,
 		"username": req.Username,
-		"password": "a7fKiKujgAzzsJ6",
+		"password": "a7fKiKujgAzzsJ6", // 这个密码应该被妥善管理,而不是硬编码
 		"fullname": req.Fullname,
 		"fullname": req.Fullname,
 		"mobile":   req.Mobile,
 		"mobile":   req.Mobile,
 		"tel":      req.Tel,
 		"tel":      req.Tel,
@@ -114,7 +176,7 @@ func (s *cdnService) AddUser(ctx context.Context, req v1.User) (int64, error) {
 		"nodeClusterId": 1,
 		"nodeClusterId": 1,
 	}
 	}
 	apiUrl := s.Url + "UserService/createUser"
 	apiUrl := s.Url + "UserService/createUser"
-	resBody, err := s.SendData(ctx, formData, apiUrl)
+	resBody, err := s.sendDataWithTokenRetry(ctx, formData, apiUrl)
 	if err != nil {
 	if err != nil {
 		return 0, err
 		return 0, err
 	}
 	}
@@ -123,7 +185,7 @@ func (s *cdnService) AddUser(ctx context.Context, req v1.User) (int64, error) {
 	}
 	}
 	var res v1.GeneralResponse[DataStr]
 	var res v1.GeneralResponse[DataStr]
 	if err := json.Unmarshal(resBody, &res); err != nil {
 	if err := json.Unmarshal(resBody, &res); err != nil {
-		return  0, fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
+		return 0, fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
 	}
 	}
 
 
 	if res.Code != 200 {
 	if res.Code != 200 {
@@ -141,7 +203,7 @@ func (s *cdnService) CreateGroup(ctx context.Context, req v1.Group) (int64, erro
 		"name": req.Name,
 		"name": req.Name,
 	}
 	}
 	apiUrl := s.Url + "ServerGroupService/createServerGroup"
 	apiUrl := s.Url + "ServerGroupService/createServerGroup"
-	resBody, err := s.SendData(ctx, formData, apiUrl)
+	resBody, err := s.sendDataWithTokenRetry(ctx, formData, apiUrl) // 使用封装后的方法
 	if err != nil {
 	if err != nil {
 		return 0, err
 		return 0, err
 	}
 	}
@@ -150,7 +212,7 @@ func (s *cdnService) CreateGroup(ctx context.Context, req v1.Group) (int64, erro
 	}
 	}
 	var res v1.GeneralResponse[DataStr]
 	var res v1.GeneralResponse[DataStr]
 	if err := json.Unmarshal(resBody, &res); err != nil {
 	if err := json.Unmarshal(resBody, &res); err != nil {
-		return  0, fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
+		return 0, fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
 	}
 	}
 	if res.Code != 200 {
 	if res.Code != 200 {
 		return 0, fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Message)
 		return 0, fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Message)
@@ -161,21 +223,20 @@ func (s *cdnService) CreateGroup(ctx context.Context, req v1.Group) (int64, erro
 	return res.Data.ServerGroupId, nil
 	return res.Data.ServerGroupId, nil
 }
 }
 
 
-
 //分配套餐
 //分配套餐
 func (s *cdnService) BindPlan(ctx context.Context, req v1.Plan) (int64, error) {
 func (s *cdnService) BindPlan(ctx context.Context, req v1.Plan) (int64, error) {
 	formData := map[string]interface{}{
 	formData := map[string]interface{}{
-		"userId": req.UserId,
-		"planId": req.PlanId,
-		"dayTo":  req.DayTo,
-		"period": req.Period,
-		"countPeriod": req.CountPeriod,
-		"name":   req.Name,
-		"isFree": req.IsFree,
-		"periodDayTo": req.PeriodDayTo,
+		"userId":        req.UserId,
+		"planId":        req.PlanId,
+		"dayTo":         req.DayTo,
+		"period":        req.Period,
+		"countPeriod":   req.CountPeriod,
+		"name":          req.Name,
+		"isFree":        req.IsFree,
+		"periodDayTo":   req.PeriodDayTo,
 	}
 	}
 	apiUrl := s.Url + "UserPlanService/buyUserPlan"
 	apiUrl := s.Url + "UserPlanService/buyUserPlan"
-	resBody, err := s.SendData(ctx, formData, apiUrl)
+	resBody, err := s.sendDataWithTokenRetry(ctx, formData, apiUrl) // 使用封装后的方法
 	if err != nil {
 	if err != nil {
 		return 0, err
 		return 0, err
 	}
 	}
@@ -184,7 +245,7 @@ func (s *cdnService) BindPlan(ctx context.Context, req v1.Plan) (int64, error) {
 	}
 	}
 	var res v1.GeneralResponse[DataStr]
 	var res v1.GeneralResponse[DataStr]
 	if err := json.Unmarshal(resBody, &res); err != nil {
 	if err := json.Unmarshal(resBody, &res); err != nil {
-		return  0, fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
+		return 0, fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
 	}
 	}
 	if res.Code != 200 {
 	if res.Code != 200 {
 		return 0, fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Message)
 		return 0, fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Message)
@@ -198,17 +259,17 @@ func (s *cdnService) BindPlan(ctx context.Context, req v1.Plan) (int64, error) {
 //续费套餐
 //续费套餐
 func (s *cdnService) RenewPlan(ctx context.Context, req v1.RenewalPlan) error {
 func (s *cdnService) RenewPlan(ctx context.Context, req v1.RenewalPlan) error {
 	formData := map[string]interface{}{
 	formData := map[string]interface{}{
-		"userPlanId": req.UserPlanId,
-		"dayTo":  req.DayTo,
-		"period": req.Period,
-		"countPeriod": req.CountPeriod,
-		"isFree": req.IsFree,
-		"periodDayTo": req.PeriodDayTo,
+		"userPlanId":    req.UserPlanId,
+		"dayTo":         req.DayTo,
+		"period":        req.Period,
+		"countPeriod":   req.CountPeriod,
+		"isFree":        req.IsFree,
+		"periodDayTo":   req.PeriodDayTo,
 	}
 	}
 	apiUrl := s.Url + "UserPlanService/renewUserPlan"
 	apiUrl := s.Url + "UserPlanService/renewUserPlan"
-	resBody, err := s.SendData(ctx, formData, apiUrl)
+	resBody, err := s.sendDataWithTokenRetry(ctx, formData, apiUrl) // 使用封装后的方法
 	if err != nil {
 	if err != nil {
-		return  err
+		return err
 	}
 	}
 	var res v1.GeneralResponse[any]
 	var res v1.GeneralResponse[any]
 	if err := json.Unmarshal(resBody, &res); err != nil {
 	if err := json.Unmarshal(resBody, &res); err != nil {
@@ -220,28 +281,27 @@ func (s *cdnService) RenewPlan(ctx context.Context, req v1.RenewalPlan) error {
 	return nil
 	return nil
 }
 }
 
 
-
 //创建网站
 //创建网站
 func (s *cdnService) CreateWebsite(ctx context.Context, req v1.Website) (int64, error) {
 func (s *cdnService) CreateWebsite(ctx context.Context, req v1.Website) (int64, error) {
 	formData := map[string]interface{}{
 	formData := map[string]interface{}{
-		"userId": req.UserId,
-		"type": req.Type,
-		"name": req.Name,
-		"description": req.Description,
-		"serverNamesJSON": req.ServerNamesJSON,
-		"httpJSON": req.HttpJSON,
-		"httpsJSON": req.HttpsJSON,
-		"tcpJSON": req.TcpJSON,
-		"tlsJSON": req.TlsJSON,
-		"udpJSON": req.UdpJSON,
-		"webId": req.WebId,
-		"reverseProxyJSON": req.ReverseProxyJSON,
-		"serverGroupIds": req.ServerGroupIds,
-		"userPlanId": req.UserPlanId,
-		"nodeClusterId": req.NodeClusterId,
+		"userId":             req.UserId,
+		"type":               req.Type,
+		"name":               req.Name,
+		"description":        req.Description,
+		"serverNamesJSON":    req.ServerNamesJSON,
+		"httpJSON":           req.HttpJSON,
+		"httpsJSON":          req.HttpsJSON,
+		"tcpJSON":            req.TcpJSON,
+		"tlsJSON":            req.TlsJSON,
+		"udpJSON":            req.UdpJSON,
+		"webId":              req.WebId,
+		"reverseProxyJSON":   req.ReverseProxyJSON,
+		"serverGroupIds":     req.ServerGroupIds,
+		"userPlanId":         req.UserPlanId,
+		"nodeClusterId":      req.NodeClusterId,
 	}
 	}
 	apiUrl := s.Url + "ServerService/createServer"
 	apiUrl := s.Url + "ServerService/createServer"
-	resBody, err := s.SendData(ctx, formData, apiUrl)
+	resBody, err := s.sendDataWithTokenRetry(ctx, formData, apiUrl) // 使用封装后的方法
 	if err != nil {
 	if err != nil {
 		return 0, err
 		return 0, err
 	}
 	}
@@ -250,7 +310,7 @@ func (s *cdnService) CreateWebsite(ctx context.Context, req v1.Website) (int64,
 	}
 	}
 	var res v1.GeneralResponse[DataStr]
 	var res v1.GeneralResponse[DataStr]
 	if err := json.Unmarshal(resBody, &res); err != nil {
 	if err := json.Unmarshal(resBody, &res); err != nil {
-		return  0, fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
+		return 0, fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
 	}
 	}
 	if res.Code != 200 {
 	if res.Code != 200 {
 		return 0, fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Message)
 		return 0, fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Message)
@@ -261,7 +321,7 @@ func (s *cdnService) CreateWebsite(ctx context.Context, req v1.Website) (int64,
 	return res.Data.WebsiteId, nil
 	return res.Data.WebsiteId, nil
 }
 }
 
 
-func (s *cdnService) EditProtocol(ctx context.Context, req v1.ProxyJson,action string)  error  {
+func (s *cdnService) EditProtocol(ctx context.Context, req v1.ProxyJson, action string) error {
 	formData := map[string]interface{}{
 	formData := map[string]interface{}{
 		"serverId": req.ServerId,
 		"serverId": req.ServerId,
 	}
 	}
@@ -285,7 +345,7 @@ func (s *cdnService) EditProtocol(ctx context.Context, req v1.ProxyJson,action s
 	default:
 	default:
 		return fmt.Errorf("不支持的协议类型")
 		return fmt.Errorf("不支持的协议类型")
 	}
 	}
-	resBody, err := s.SendData(ctx, formData, apiUrl)
+	resBody, err := s.sendDataWithTokenRetry(ctx, formData, apiUrl) // 使用封装后的方法
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -299,23 +359,23 @@ func (s *cdnService) EditProtocol(ctx context.Context, req v1.ProxyJson,action s
 	return nil
 	return nil
 }
 }
 
 
-func (s *cdnService) CreateOrigin(ctx context.Context, req v1.Origin) (int64, error)  {
+func (s *cdnService) CreateOrigin(ctx context.Context, req v1.Origin) (int64, error) {
 	formData := map[string]interface{}{
 	formData := map[string]interface{}{
-		"name": req.Name,
-		"addr": req.Addr,
-		"ossJSON": req.OssJSON,
-		"description": req.Description,
-		"weight": req.Weight,
-		"isOn": req.IsOn,
-		"domains": req.Domains,
-		"certRefJSON": req.CertRefJSON,
-		"host": req.Host,
-		"followPort": req.FollowPort,
-		"http2Enabled": req.Http2Enabled,
+		"name":                    req.Name,
+		"addr":                    req.Addr,
+		"ossJSON":                 req.OssJSON,
+		"description":             req.Description,
+		"weight":                  req.Weight,
+		"isOn":                    req.IsOn,
+		"domains":                 req.Domains,
+		"certRefJSON":             req.CertRefJSON,
+		"host":                    req.Host,
+		"followPort":              req.FollowPort,
+		"http2Enabled":            req.Http2Enabled,
 		"tlsSecurityVerifyMode": req.TlsSecurityVerifyMode,
 		"tlsSecurityVerifyMode": req.TlsSecurityVerifyMode,
 	}
 	}
 	apiUrl := s.Url + "OriginService/createOrigin"
 	apiUrl := s.Url + "OriginService/createOrigin"
-	resBody, err := s.SendData(ctx, formData, apiUrl)
+	resBody, err := s.sendDataWithTokenRetry(ctx, formData, apiUrl) // 使用封装后的方法
 	if err != nil {
 	if err != nil {
 		return 0, err
 		return 0, err
 	}
 	}
@@ -324,7 +384,7 @@ func (s *cdnService) CreateOrigin(ctx context.Context, req v1.Origin) (int64, er
 	}
 	}
 	var res v1.GeneralResponse[DataStr]
 	var res v1.GeneralResponse[DataStr]
 	if err := json.Unmarshal(resBody, &res); err != nil {
 	if err := json.Unmarshal(resBody, &res); err != nil {
-		return  0, fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
+		return 0, fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
 	}
 	}
 	if res.Code != 200 {
 	if res.Code != 200 {
 		return 0, fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Message)
 		return 0, fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Message)
@@ -337,22 +397,22 @@ func (s *cdnService) CreateOrigin(ctx context.Context, req v1.Origin) (int64, er
 
 
 func (s *cdnService) EditOrigin(ctx context.Context, req v1.Origin) error {
 func (s *cdnService) EditOrigin(ctx context.Context, req v1.Origin) error {
 	formData := map[string]interface{}{
 	formData := map[string]interface{}{
-		"originId": req.OriginId,
-		"name": req.Name,
-		"addr": req.Addr,
-		"ossJSON": req.OssJSON,
-		"description": req.Description,
-		"weight": req.Weight,
-		"isOn": req.IsOn,
-		"domains": req.Domains,
-		"certRefJSON": req.CertRefJSON,
-		"host": req.Host,
-		"followPort": req.FollowPort,
-		"http2Enabled": req.Http2Enabled,
+		"originId":                req.OriginId,
+		"name":                    req.Name,
+		"addr":                    req.Addr,
+		"ossJSON":                 req.OssJSON,
+		"description":             req.Description,
+		"weight":                  req.Weight,
+		"isOn":                    req.IsOn,
+		"domains":                 req.Domains,
+		"certRefJSON":             req.CertRefJSON,
+		"host":                    req.Host,
+		"followPort":              req.FollowPort,
+		"http2Enabled":            req.Http2Enabled,
 		"tlsSecurityVerifyMode": req.TlsSecurityVerifyMode,
 		"tlsSecurityVerifyMode": req.TlsSecurityVerifyMode,
 	}
 	}
 	apiUrl := s.Url + "OriginService/updateOrigin"
 	apiUrl := s.Url + "OriginService/updateOrigin"
-	resBody, err := s.SendData(ctx, formData, apiUrl)
+	resBody, err := s.sendDataWithTokenRetry(ctx, formData, apiUrl) // 使用封装后的方法
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}

+ 4 - 1
internal/service/duedate.go

@@ -30,9 +30,12 @@ type duedateService struct {
 
 
 func (service *duedateService) NextDueDate(ctx context.Context, uid int, productID int) (string, error) {
 func (service *duedateService) NextDueDate(ctx context.Context, uid int, productID int) (string, error) {
 	timeStr, err := service.gameShieldRepository.GetGameShieldNextduedate(ctx, int64(uid), productID)
 	timeStr, err := service.gameShieldRepository.GetGameShieldNextduedate(ctx, int64(uid), productID)
-	if timeStr == "0" || timeStr == "" {
+	if err != nil {
 		return "", err
 		return "", err
 	}
 	}
+	if timeStr == "0" || timeStr == "" {
+		return "", fmt.Errorf("产品不存在")
+	}
 
 
 	// 将字符串转为 int64 时间戳
 	// 将字符串转为 int64 时间戳
 	unixTime, err := strconv.ParseInt(timeStr, 10, 64)
 	unixTime, err := strconv.ParseInt(timeStr, 10, 64)

+ 0 - 3
internal/service/gatewaygroup.go

@@ -90,7 +90,6 @@ func (s *gatewayGroupService) EditGatewayGroup(ctx context.Context, req v1.AddGa
 		Name: req.Name,
 		Name: req.Name,
 		Comment: req.Comment,
 		Comment: req.Comment,
 		HostId: req.HostId,
 		HostId: req.HostId,
-		RuleId: req.RuleId,
 		BanUdp: req.BanUdp,
 		BanUdp: req.BanUdp,
 		BanOverseas: req.BanOverseas,
 		BanOverseas: req.BanOverseas,
 		Operator: req.Operator,
 		Operator: req.Operator,
@@ -124,7 +123,6 @@ func (s *gatewayGroupService) AddGatewayGroupAdmin(ctx context.Context,req v1.Ad
 		Name: req.Name,
 		Name: req.Name,
 		Comment: req.Comment,
 		Comment: req.Comment,
 		HostId: req.HostId,
 		HostId: req.HostId,
-		RuleId: req.RuleId,
 		BanUdp: req.BanUdp,
 		BanUdp: req.BanUdp,
 		BanOverseas: req.BanOverseas,
 		BanOverseas: req.BanOverseas,
 		Operator: req.Operator,
 		Operator: req.Operator,
@@ -141,7 +139,6 @@ func (s *gatewayGroupService) EditGatewayGroupAdmin(ctx context.Context, req v1.
 		Name: req.Name,
 		Name: req.Name,
 		Comment: req.Comment,
 		Comment: req.Comment,
 		HostId: req.HostId,
 		HostId: req.HostId,
-		RuleId: req.RuleId,
 		BanUdp: req.BanUdp,
 		BanUdp: req.BanUdp,
 		BanOverseas: req.BanOverseas,
 		BanOverseas: req.BanOverseas,
 		Operator: req.Operator,
 		Operator: req.Operator,

+ 113 - 26
internal/service/globallimit.go

@@ -2,13 +2,17 @@ package service
 
 
 import (
 import (
 	"context"
 	"context"
+	"errors"
 	"fmt"
 	"fmt"
 	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
 	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
+	"github.com/mozillazg/go-pinyin"
 	"github.com/spf13/viper"
 	"github.com/spf13/viper"
 	"golang.org/x/sync/errgroup"
 	"golang.org/x/sync/errgroup"
+	"gorm.io/gorm"
 	"strconv"
 	"strconv"
+	"strings"
 	"time"
 	"time"
 )
 )
 
 
@@ -36,6 +40,7 @@ func NewGlobalLimitService(
 	hostRep repository.HostRepository,
 	hostRep repository.HostRepository,
 	gateWayGroupRep repository.GatewayGroupRepository,
 	gateWayGroupRep repository.GatewayGroupRepository,
 	cdnService CdnService,
 	cdnService CdnService,
+	cdnRep repository.CdnRepository,
 ) GlobalLimitService {
 ) GlobalLimitService {
 	return &globalLimitService{
 	return &globalLimitService{
 		Service:               service,
 		Service:               service,
@@ -53,6 +58,7 @@ func NewGlobalLimitService(
 		hostRep:                hostRep,
 		hostRep:                hostRep,
 		gateWayGroupRep:       gateWayGroupRep,
 		gateWayGroupRep:       gateWayGroupRep,
 		cdnService:            cdnService,
 		cdnService:            cdnService,
+		cdnRep:                cdnRep,
 	}
 	}
 }
 }
 
 
@@ -72,12 +78,15 @@ type globalLimitService struct {
 	hostRep               repository.HostRepository
 	hostRep               repository.HostRepository
 	gateWayGroupRep       repository.GatewayGroupRepository
 	gateWayGroupRep       repository.GatewayGroupRepository
 	cdnService            CdnService
 	cdnService            CdnService
+	cdnRep                repository.CdnRepository
 }
 }
 
 
 func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
 func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
 	data, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid)
 	data, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid)
 	if err != nil {
 	if err != nil {
-		return 0, err
+		if !errors.Is(err, gorm.ErrRecordNotFound) {
+			return 0, err
+		}
 	}
 	}
 	if data != nil && data.CdnUid != 0 {
 	if data != nil && data.CdnUid != 0 {
 		return int64(data.CdnUid), nil
 		return int64(data.CdnUid), nil
@@ -86,8 +95,22 @@ func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64
 	if err != nil {
 	if err != nil {
 		return 0, err
 		return 0, err
 	}
 	}
+	// 中文转拼音
+	a := pinyin.NewArgs()
+	a.Style = pinyin.Normal
+	pinyinSlice := pinyin.LazyPinyin(userInfo.Username, a)
+	userName := strconv.Itoa(int(uid)) + "_" + strings.Join(pinyinSlice, "_")
+	// 查询用户是否存在
+	UserId,err := s.cdnRep.GetUserId(ctx, userName)
+	if err != nil {
+		return 0, err
+	}
+	if UserId != 0 {
+		return UserId, nil
+	}
+	// 注册用户
 	userId, err := s.cdnService.AddUser(ctx, v1.User{
 	userId, err := s.cdnService.AddUser(ctx, v1.User{
-		Username: userInfo.Username,
+		Username: userName,
 		Email:    userInfo.Email,
 		Email:    userInfo.Email,
 		Fullname: userInfo.Username,
 		Fullname: userInfo.Username,
 		Mobile:   userInfo.PhoneNumber,
 		Mobile:   userInfo.PhoneNumber,
@@ -132,7 +155,15 @@ func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.Glob
 	if err != nil {
 	if err != nil {
 		return v1.GlobalLimitRequireResponse{}, err
 		return v1.GlobalLimitRequireResponse{}, err
 	}
 	}
-	res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + strconv.Itoa(req.HostId) + "_" + domain
+	userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, int64(req.Uid))
+	if err != nil {
+		return v1.GlobalLimitRequireResponse{}, err
+	}
+	res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + userInfo.Username + "_" + strconv.Itoa(req.HostId) + "_" + domain
+	res.HostName, err = s.globalLimitRepository.GetHostName(ctx, int64(req.HostId))
+	if err != nil {
+		return v1.GlobalLimitRequireResponse{}, err
+	}
 	return res, nil
 	return res, nil
 }
 }
 
 
@@ -140,6 +171,29 @@ func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*mod
 	return s.globalLimitRepository.GetGlobalLimit(ctx, id)
 	return s.globalLimitRepository.GetGlobalLimit(ctx, id)
 }
 }
 
 
+func (s *globalLimitService) ConversionTime(ctx context.Context,req string) (string, error)  {
+	// 2. 将字符串解析成 time.Time 对象
+	// time.Parse 会根据你提供的布局来理解输入的字符串
+	t, err := time.Parse("2006-01-02 15:04:05", req)
+	if err != nil {
+		// 如果输入的字符串格式和布局不匹配,这里会报错
+		return "", fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
+	}
+	// 3. 定义新的输出格式 "YYYY-MM-DD"
+	outputLayout := "2006-01-02"
+	// 4. 将 time.Time 对象格式化为新的字符串
+	outputTimeStr := t.Format(outputLayout)
+	return outputTimeStr, nil
+}
+
+func (s *globalLimitService) ConversionTimeUnix(ctx context.Context,req string) (int64, error)  {
+	t, err := time.Parse("2006-01-02 15:04:05", req)
+	if err != nil {
+		return 0, fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
+	}
+	expiredAt := t.Unix()
+	return expiredAt, nil
+}
 func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
 func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
 	isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
 	isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
 	if err != nil {
 	if err != nil {
@@ -160,7 +214,7 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 	g.Go(func() error {
 	g.Go(func() error {
 		res, e := s.gateWayGroupRep.GetGatewayGroupWhereHostIdNull(gCtx, require.Operator, require.IpCount)
 		res, e := s.gateWayGroupRep.GetGatewayGroupWhereHostIdNull(gCtx, require.Operator, require.IpCount)
 		if e != nil {
 		if e != nil {
-			return e
+			return fmt.Errorf("获取网关组失败: %w", e)
 		}
 		}
 		if res == 0 {
 		if res == 0 {
 			return fmt.Errorf("获取网关组失败")
 			return fmt.Errorf("获取网关组失败")
@@ -174,7 +228,7 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 	g.Go(func() error {
 	g.Go(func() error {
 		res, e := s.GetCdnUserId(gCtx, int64(req.Uid))
 		res, e := s.GetCdnUserId(gCtx, int64(req.Uid))
 		if e != nil {
 		if e != nil {
-			return e
+			return fmt.Errorf("获取cdn用户失败: %w", e)
 		}
 		}
 		if res == 0 {
 		if res == 0 {
 			return fmt.Errorf("获取cdn用户失败")
 			return fmt.Errorf("获取cdn用户失败")
@@ -186,7 +240,7 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 	g.Go(func() error {
 	g.Go(func() error {
 		res, e := s.AddGroupId(gCtx, require.GlobalLimitName)
 		res, e := s.AddGroupId(gCtx, require.GlobalLimitName)
 		if e != nil {
 		if e != nil {
-			return e
+			return fmt.Errorf("创建规则分组失败: %w", e)
 		}
 		}
 		if res == 0 {
 		if res == 0 {
 			return fmt.Errorf("创建规则分组失败")
 			return fmt.Errorf("创建规则分组失败")
@@ -199,18 +253,53 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 		return err
 		return err
 	}
 	}
 
 
+
+	outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
+	if err != nil {
+		return err
+	}
+
 	ruleId, err := s.cdnService.BindPlan(ctx, v1.Plan{
 	ruleId, err := s.cdnService.BindPlan(ctx, v1.Plan{
 		UserId:    userId,
 		UserId:    userId,
+		PlanId: 	4,
+		DayTo:     outputTimeStr,
+		Name:      require.GlobalLimitName,
+		IsFree:    true,
+		Period:    "monthly",
+		CountPeriod: 1,
+		PeriodDayTo: outputTimeStr,
+	})
+	if err != nil {
+		return err
+	}
+	if ruleId == 0 {
+		return fmt.Errorf("分配套餐失败")
+	}
+
 
 
+
+	err = s.gateWayGroupRep.EditGatewayGroup(ctx, &model.GatewayGroup{
+		Id: gatewayGroupId,
+		HostId: req.HostId,
 	})
 	})
+	if err != nil {
+		return err
+	}
 
 
+	expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
+	if err != nil {
+		return err
+	}
 	err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
 	err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
-		HostId:    req.HostId,
-		Uid:       req.Uid,
-		RuleId:    ,
-		CdnUid:    int(userId),
-		Comment:   req.Comment,
-		ExpiredAt: require.ExpiredAt,
+		HostId:         req.HostId,
+		Uid:            req.Uid,
+		Name:           require.GlobalLimitName,
+		RuleId: 		int(ruleId),
+		GroupId:        int(groupId),
+		GatewayGroupId: gatewayGroupId,
+		CdnUid:         int(userId),
+		Comment:        req.Comment,
+		ExpiredAt:      expiredAt,
 	})
 	})
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -223,32 +312,30 @@ func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalL
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	formData := map[string]interface{}{
-		"tag":             require.GlobalLimitName,
-		"bps":             require.Bps,
-		"max_bytes_month": require.MaxBytesMonth,
-		"expired_at":      require.ExpiredAt,
-	}
 	data, err :=  s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
 	data, err :=  s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	respBody, err := s.required.SendForm(ctx, "admin/info/waf_common_limit/edit?&__goadmin_edit_pk="+strconv.Itoa(data.RuleId), "admin/edit/waf_common_limit", formData)
+	outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	res, err := s.parser.ParseAlert(string(respBody))
+	err = s.cdnService.RenewPlan(ctx, v1.RenewalPlan{
+		UserPlanId: int64(data.RuleId),
+		DayTo:      outputTimeStr,
+		Period:     "monthly",
+		CountPeriod: 1,
+		IsFree:     true,
+		PeriodDayTo: outputTimeStr,
+	})
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	if res != "" {
-		return fmt.Errorf(res)
-	}
-	t, err := time.Parse("2006-01-02 15:04:05", require.ExpiredAt)
+
+	expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	expiredAt := t.Unix()
 	if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
 	if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
 		HostId:  req.HostId,
 		HostId:  req.HostId,
 		Comment: req.Comment,
 		Comment: req.Comment,
@@ -271,7 +358,7 @@ func (s *globalLimitService) EditGlobalLimitBySnail(ctx context.Context, req v1.
 	t := time.Unix(req.ExpiredAt, 0)
 	t := time.Unix(req.ExpiredAt, 0)
 	expiredAt := t.Format("2006-01-02 15:04:05")
 	expiredAt := t.Format("2006-01-02 15:04:05")
 	formData := map[string]interface{}{
 	formData := map[string]interface{}{
-		"tag":             data.GlobalLimitName,
+		"tag":             data.Name,
 		"bps":             configCount.Bps,
 		"bps":             configCount.Bps,
 		"max_bytes_month": configCount.MaxBytesMonth,
 		"max_bytes_month": configCount.MaxBytesMonth,
 		"expired_at":      expiredAt,
 		"expired_at":      expiredAt,

+ 0 - 8
internal/service/wafformatter.go

@@ -83,14 +83,6 @@ func (s *wafFormatterService) require(ctx context.Context,req v1.GlobalRequire,c
 		return v1.GlobalRequire{}, err
 		return v1.GlobalRequire{}, err
 	}
 	}
 	req.WafGatewayGroupId = RuleIds.GatewayGroupId
 	req.WafGatewayGroupId = RuleIds.GatewayGroupId
-	switch category {
-	case "tcp":
-		req.LimitRuleId = RuleIds.TcpLimitRuleId
-	case "udp":
-		req.LimitRuleId = RuleIds.UdpLimitRuleId
-	case "web":
-		req.LimitRuleId = RuleIds.WebLimitRuleId
-	}
 	domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
 	domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
 	if err != nil {
 	if err != nil {
 		return v1.GlobalRequire{}, err
 		return v1.GlobalRequire{}, err