Procházet zdrojové kódy

feat(service): 新增 CDN 相关功能并重构全局限制逻辑- 在 CdnService 接口中添加了多个新方法,用于处理用户、组和网站创建等操作
- 新增 GetCdnUserId 和 AddGroupId 方法,用于获取 CDN 用户 ID 和添加组 ID
- 重构了 GlobalLimitService 中的 GlobalLimitRequire 方法,集成了 CDN 相关操作
- 在 GlobalLimit 模型中添加了 GatewayGroupId 字段
- 新增了 GetGlobalLimitFirst 和 GetUserInfo 方法,用于获取全局限制和用户信息

fusu před 1 měsícem
rodič
revize
718d96b9e9

+ 7 - 1
api/v1/globalLimit.go

@@ -64,4 +64,10 @@ type GlobalLimitExpiredByHost struct {
 type GlobalLimitExpiredBySnail struct {
 	HostId int `json:"host_id" form:"host_id" gorm:"column:waf_common_limit_id"`
 	ExpiredAt int64 `json:"expired_at" form:"expired_at" gorm:"column:expired_at"`
-}
+}
+
+type UserInfo struct {
+	Username string `json:"username" form:"username"`
+	Email    string `json:"email" form:"email"`
+	PhoneNumber    string `json:"phonenumber" form:"phonenumber"`
+}

+ 1 - 0
internal/model/globallimit.go

@@ -8,6 +8,7 @@ type GlobalLimit struct {
 	RuleId          int
 	Uid             int
 	CdnUid          int
+	GatewayGroupId  int
 	Comment         string
 	ExpiredAt       int64
 	createdAt       time.Time

+ 21 - 0
internal/repository/globallimit.go

@@ -16,6 +16,8 @@ type GlobalLimitRepository interface {
 	GetGlobalLimitByHostId(ctx context.Context, hostId int64) (*model.GlobalLimit, error)
 	GetGlobalLimitAllExpired(ctx context.Context,ids []int) ([]v1.GlobalLimitExpiredByHost, error)
 	GetGlobalLimitAllHostId(ctx context.Context) ([]v1.GlobalLimitExpired, error)
+	GetGlobalLimitFirst(ctx context.Context,uid int64) (*model.GlobalLimit, error)
+	GetUserInfo(ctx context.Context, uid int64) (v1.UserInfo, error)
 }
 
 func NewGlobalLimitRepository(
@@ -102,4 +104,23 @@ func (r *globalLimitRepository) GetGlobalLimitAllHostId(ctx context.Context) ([]
 		return nil, err
 	}
 	return res, nil
+}
+
+func (r *globalLimitRepository) GetGlobalLimitFirst(ctx context.Context,uid int64) (*model.GlobalLimit, error)  {
+	var req model.GlobalLimit
+	if err := r.DB(ctx).Where("uid = ?", uid).First(&req).Error; err != nil {
+		return nil, err
+	}
+	return &req, nil
+}
+
+func (r *globalLimitRepository) GetUserInfo(ctx context.Context, uid int64) (v1.UserInfo, error) {
+	var res v1.UserInfo
+	if err := r.DB(ctx).Table("shd_user").
+		Where("id = ?", uid).
+		Select("username", "email", "phonenumber").
+		Find(&res).Error; err != nil {
+		return v1.UserInfo{}, err
+	}
+	return res, nil
 }

+ 9 - 1
internal/service/cdn.go

@@ -11,6 +11,14 @@ import (
 
 type CdnService interface {
 	GetToken(ctx context.Context) (string, error)
+	AddUser(ctx context.Context, req v1.User) (int64, error)
+	CreateGroup(ctx context.Context, req v1.Group) (int64, error)
+	BindPlan(ctx context.Context, req v1.Plan) (int64, error)
+	RenewPlan(ctx context.Context, req v1.RenewalPlan) error
+	CreateWebsite(ctx context.Context, req v1.Website) (int64, error)
+	EditProtocol(ctx context.Context, req v1.ProxyJson,action string) error
+	CreateOrigin(ctx context.Context, req v1.Origin) (int64, error)
+	EditOrigin(ctx context.Context, req v1.Origin) error
 }
 func NewCdnService(
     service *Service,
@@ -253,7 +261,7 @@ func (s *cdnService) CreateWebsite(ctx context.Context, req v1.Website) (int64,
 	return res.Data.WebsiteId, nil
 }
 
-func (s *cdnService) EditTcpProtocol(ctx context.Context, req v1.ProxyJson,action string)  error  {
+func (s *cdnService) EditProtocol(ctx context.Context, req v1.ProxyJson,action string)  error  {
 	formData := map[string]interface{}{
 		"serverId": req.ServerId,
 	}

+ 70 - 92
internal/service/globallimit.go

@@ -6,7 +6,6 @@ import (
 	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
-	"github.com/spf13/cast"
 	"github.com/spf13/viper"
 	"golang.org/x/sync/errgroup"
 	"strconv"
@@ -36,6 +35,7 @@ func NewGlobalLimitService(
 	gateWayGroup GatewayGroupService,
 	hostRep repository.HostRepository,
 	gateWayGroupRep repository.GatewayGroupRepository,
+	cdnService CdnService,
 ) GlobalLimitService {
 	return &globalLimitService{
 		Service:               service,
@@ -52,6 +52,7 @@ func NewGlobalLimitService(
 		gateWayGroup:          gateWayGroup,
 		hostRep:                hostRep,
 		gateWayGroupRep:       gateWayGroupRep,
+		cdnService:            cdnService,
 	}
 }
 
@@ -70,6 +71,41 @@ type globalLimitService struct {
 	gateWayGroup          GatewayGroupService
 	hostRep               repository.HostRepository
 	gateWayGroupRep       repository.GatewayGroupRepository
+	cdnService            CdnService
+}
+
+func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
+	data, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid)
+	if err != nil {
+		return 0, err
+	}
+	if data != nil && data.CdnUid != 0 {
+		return int64(data.CdnUid), nil
+	}
+	userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, uid)
+	if err != nil {
+		return 0, err
+	}
+	userId, err := s.cdnService.AddUser(ctx, v1.User{
+		Username: userInfo.Username,
+		Email:    userInfo.Email,
+		Fullname: userInfo.Username,
+		Mobile:   userInfo.PhoneNumber,
+	})
+	if err != nil {
+		return 0, err
+	}
+	return userId, nil
+}
+
+func (s *globalLimitService) AddGroupId(ctx context.Context,groupName string) (int64, error)  {
+	groupId, err := s.cdnService.CreateGroup(ctx, v1.Group{
+		Name: groupName,
+	})
+	if err != nil {
+		return 0, err
+	}
+	return groupId, nil
 }
 
 func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
@@ -116,124 +152,66 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 	if err != nil {
 		return err
 	}
-	gatewayGroupId, err := s.gateWayGroupRep.GetGatewayGroupWhereHostIdNull(ctx, require.Operator, require.IpCount)
-	if err != nil {
-		return err
-	}
-	formData := map[string]interface{}{
-		"tag":             require.GlobalLimitName,
-		"bps":             require.Bps,
-		"max_bytes_month": require.MaxBytesMonth,
-		"expired_at":      require.ExpiredAt,
-	}
-	respBody, err := s.required.SendForm(ctx, "admin/info/waf_common_limit/new", "admin/new/waf_common_limit", formData)
-	if err != nil {
-		return err
-	}
-	ruleIdBase, err := s.parser.GetRuleIdByColumnName(ctx, respBody, require.GlobalLimitName)
-	if err != nil {
-		return err
-	}
-	if ruleIdBase == "" {
-		res, err := s.parser.ParseAlert(string(respBody))
-		if err != nil {
-			return err
-		}
-		return fmt.Errorf(res)
-	}
-	ruleId, err := cast.ToIntE(ruleIdBase)
-	if err != nil {
-		return err
-	}
-	var tcpLimitRuleId, udpLimitRuleId, webLimitRuleId int
 
 	g, gCtx := errgroup.WithContext(ctx)
-
-	// 启动tcpLimit调用 - 使用独立的请求参数副本
+	var gatewayGroupId int
+	var userId int64
+	var groupId int64
 	g.Go(func() error {
-		tcpLimitReq := &v1.GeneralLimitRequireRequest{
-			Tag:    require.GlobalLimitName,
-			HostId: req.HostId,
-			RuleId: ruleId,
-			Uid:    req.Uid,
-		}
-		result, e := s.tcpLimit.AddTcpLimit(gCtx, tcpLimitReq)
+		gatewayGroupId, e := s.gateWayGroupRep.GetGatewayGroupWhereHostIdNull(gCtx, require.Operator, require.IpCount)
 		if e != nil {
-			return fmt.Errorf("tcpLimit调用失败: %w", e)
+			return e
 		}
-		if result != 0 {
-			tcpLimitRuleId = result
-			return nil
+		if gatewayGroupId == 0 {
+			return fmt.Errorf("获取网关组失败")
 		}
-		return fmt.Errorf("tcpLimit调用失败,Id为 %d", result)
+		return nil
 	})
 
-	// 启动udpLimit调用 - 使用独立的请求参数副本
+
+
 	g.Go(func() error {
-		udpLimitReq := &v1.GeneralLimitRequireRequest{
-			Tag:    require.GlobalLimitName,
-			HostId: req.HostId,
-			RuleId: ruleId,
-			Uid:    req.Uid,
-		}
-		result, e := s.udpLimit.AddUdpLimit(gCtx, udpLimitReq)
+		res, e := s.GetCdnUserId(gCtx, int64(req.Uid))
 		if e != nil {
-			return fmt.Errorf("udpLimit调用失败: %w", e)
+			return e
 		}
-		if result != 0 {
-			udpLimitRuleId = result
-			return nil
+		if res == 0 {
+			return fmt.Errorf("获取cdn用户失败")
 		}
-		return fmt.Errorf("udpLimit调用失败,Id为 %d", result)
+		userId = res
+		return nil
 	})
 
-
-	// 启动webLimit调用 - 使用独立的请求参数副本
 	g.Go(func() error {
-		webLimitReq := &v1.GeneralLimitRequireRequest{
-			Tag:    require.GlobalLimitName,
-			HostId: req.HostId,
-			RuleId: ruleId,
-			Uid:    req.Uid,
-		}
-		result, e := s.webLimit.AddWebLimit(gCtx, webLimitReq)
+		res, e := s.AddGroupId(gCtx, require.GlobalLimitName)
 		if e != nil {
-			return fmt.Errorf("webLimit调用失败: %w", e)
+			return e
 		}
-		if result != 0 {
-			webLimitRuleId = result
-			return nil
+		if res == 0 {
+			return fmt.Errorf("创建规则分组失败")
 		}
-		return fmt.Errorf("webLimit调用失败,Id为 %d", result)
+		return nil
 	})
 
-	if err := g.Wait(); err != nil {
-		return err
-	}
-	t, err := time.Parse("2006-01-02 15:04:05", require.ExpiredAt)
-	if err != nil {
+	if err = g.Wait(); err != nil {
 		return err
 	}
-	expiredAt := t.Unix()
+
+	ruleId, err := s.cdnService.BindPlan(ctx, v1.Plan{
+
+	})
+
 	err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
-		HostId:          req.HostId,
-		RuleId:          cast.ToInt(ruleId),
-		Uid:             req.Uid,
-		GlobalLimitName: require.GlobalLimitName,
-		Comment:         req.Comment,
-		TcpLimitRuleId:  tcpLimitRuleId,
-		UdpLimitRuleId:  udpLimitRuleId,
-		WebLimitRuleId:  webLimitRuleId,
-		GatewayGroupId:  gatewayGroupId,
-		ExpiredAt:       expiredAt,
+		HostId:    req.HostId,
+		Uid:       req.Uid,
+		RuleId:    ,
+		CdnUid:    int(userId),
+		Comment:   req.Comment,
+		ExpiredAt: require.ExpiredAt,
 	})
 	if err != nil {
 		return err
 	}
-	err = s.gateWayGroupRep.EditGatewayGroup(ctx, &model.GatewayGroup{
-		RuleId: gatewayGroupId,
-		HostId: req.HostId,
-	})
 	return nil
 }