123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550 |
- package service
- import (
- "context"
- "errors"
- "fmt"
- v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
- "github.com/go-nunu/nunu-layout-advanced/internal/model"
- "github.com/go-nunu/nunu-layout-advanced/internal/repository"
- "github.com/mozillazg/go-pinyin"
- "github.com/spf13/viper"
- "go.uber.org/zap"
- "golang.org/x/sync/errgroup"
- "gorm.io/gorm"
- "strconv"
- "strings"
- "time"
- )
- type GlobalLimitService interface {
- GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error)
- AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
- EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
- DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
- GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error)
- }
- func NewGlobalLimitService(
- service *Service,
- globalLimitRepository repository.GlobalLimitRepository,
- duedate DuedateService,
- crawler CrawlerService,
- conf *viper.Viper,
- required RequiredService,
- parser ParserService,
- host HostService,
- hostRep repository.HostRepository,
- cdnService CdnService,
- cdnRep repository.CdnRepository,
- tcpforwardingRep repository.TcpforwardingRepository,
- udpForWardingRep repository.UdpForWardingRepository,
- webForWardingRep repository.WebForwardingRepository,
- allowAndDeny AllowAndDenyIpService,
- allowAndDenyRep repository.AllowAndDenyIpRepository,
- tcpforwarding TcpforwardingService,
- udpForWarding UdpForWardingService,
- webForWarding WebForwardingService,
- gatewayIpRep repository.GatewayipRepository,
- gatywayIp GatewayipService,
- ) GlobalLimitService {
- return &globalLimitService{
- Service: service,
- globalLimitRepository: globalLimitRepository,
- duedate: duedate,
- crawler: crawler,
- Url: conf.GetString("crawler.Url"),
- required: required,
- parser: parser,
- host: host,
- hostRep: hostRep,
- cdnService: cdnService,
- cdnRep: cdnRep,
- tcpforwardingRep: tcpforwardingRep,
- udpForWardingRep: udpForWardingRep,
- webForWardingRep: webForWardingRep,
- allowAndDeny: allowAndDeny,
- allowAndDenyRep: allowAndDenyRep,
- tcpforwarding: tcpforwarding,
- udpForWarding: udpForWarding,
- webForWarding: webForWarding,
- gatewayIpRep: gatewayIpRep,
- gatewayIp: gatywayIp,
- }
- }
- type globalLimitService struct {
- *Service
- globalLimitRepository repository.GlobalLimitRepository
- duedate DuedateService
- crawler CrawlerService
- Url string
- required RequiredService
- parser ParserService
- host HostService
- hostRep repository.HostRepository
- cdnService CdnService
- cdnRep repository.CdnRepository
- tcpforwardingRep repository.TcpforwardingRepository
- udpForWardingRep repository.UdpForWardingRepository
- webForWardingRep repository.WebForwardingRepository
- allowAndDeny AllowAndDenyIpService
- allowAndDenyRep repository.AllowAndDenyIpRepository
- tcpforwarding TcpforwardingService
- udpForWarding UdpForWardingService
- webForWarding WebForwardingService
- gatewayIpRep repository.GatewayipRepository
- gatewayIp GatewayipService
- }
- func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
- data, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid)
- if err != nil {
- if !errors.Is(err, gorm.ErrRecordNotFound) {
- return 0, err
- }
- }
- if data != nil && data.CdnUid != 0 {
- return int64(data.CdnUid), nil
- }
- userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, uid)
- if err != nil {
- return 0, err
- }
- // 中文转拼音
- a := pinyin.NewArgs()
- a.Style = pinyin.Normal
- pinyinSlice := pinyin.LazyPinyin(userInfo.Username, a)
- userName := strconv.Itoa(int(uid)) + "_" + strings.Join(pinyinSlice, "_")
- // 查询用户是否存在
- UserId,err := s.cdnRep.GetUserId(ctx, userName)
- if err != nil {
- return 0, err
- }
- if UserId != 0 {
- return UserId, nil
- }
- // 注册用户
- userId, err := s.cdnService.AddUser(ctx, v1.User{
- Username: userName,
- Email: userInfo.Email,
- Fullname: userInfo.Username,
- Mobile: userInfo.PhoneNumber,
- })
- if err != nil {
- return 0, err
- }
- return userId, nil
- }
- func (s *globalLimitService) AddGroupId(ctx context.Context,groupName string) (int64, error) {
- groupId, err := s.cdnService.CreateGroup(ctx, v1.Group{
- Name: groupName,
- })
- if err != nil {
- return 0, err
- }
- return groupId, nil
- }
- func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
- res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
- if err != nil {
- return v1.GlobalLimitRequireResponse{}, err
- }
- configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
- if err != nil {
- return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err)
- }
- bpsInt, err := strconv.Atoi(configCount.Bps)
- if err != nil {
- return v1.GlobalLimitRequireResponse{}, err
- }
- resultFloat := float64(bpsInt) / 2.0 / 8.0
- res.Bps = strconv.FormatFloat( resultFloat, 'f', -1, 64) + "M"
- res.MaxBytesMonth = configCount.MaxBytesMonth
- res.Operator = configCount.Operator
- res.IpCount = configCount.IpCount
- res.NodeArea = configCount.NodeArea
- res.ConfigMaxProtection = configCount.ConfigMaxProtection
- res.IsBanUdp = configCount.IsBanUdp
- res.HostId = req.HostId
- domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
- if err != nil {
- return v1.GlobalLimitRequireResponse{}, err
- }
- userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, int64(req.Uid))
- if err != nil {
- return v1.GlobalLimitRequireResponse{}, err
- }
- res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + userInfo.Username + "_" + strconv.Itoa(req.HostId) + "_" + domain
- res.HostName, err = s.globalLimitRepository.GetHostName(ctx, int64(req.HostId))
- if err != nil {
- return v1.GlobalLimitRequireResponse{}, err
- }
- return res, nil
- }
- func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) {
- return s.globalLimitRepository.GetGlobalLimit(ctx, id)
- }
- func (s *globalLimitService) ConversionTime(ctx context.Context,req string) (string, error) {
- // 2. 将字符串解析成 time.Time 对象
- // time.Parse 会根据你提供的布局来理解输入的字符串
- t, err := time.Parse("2006-01-02 15:04:05", req)
- if err != nil {
- // 如果输入的字符串格式和布局不匹配,这里会报错
- return "", fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
- }
- // 3. 定义新的输出格式 "YYYY-MM-DD"
- outputLayout := "2006-01-02"
- // 4. 将 time.Time 对象格式化为新的字符串
- outputTimeStr := t.Format(outputLayout)
- return outputTimeStr, nil
- }
- func (s *globalLimitService) ConversionTimeUnix(ctx context.Context,req string) (int64, error) {
- t, err := time.Parse("2006-01-02 15:04:05", req)
- if err != nil {
- return 0, fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
- }
- expiredAt := t.Unix()
- return expiredAt, nil
- }
- func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
- isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
- if err != nil {
- return err
- }
- if isExist {
- return fmt.Errorf("配置限制已存在")
- }
- require, err := s.GlobalLimitRequire(ctx, req)
- if err != nil {
- return err
- }
- g, gCtx := errgroup.WithContext(ctx)
- var userId int64
- var groupId int64
- g.Go(func() error {
- e := s.gatewayIp.AddIpWhereHostIdNull(gCtx, int64(req.HostId),int64(req.Uid))
- if e != nil {
- return fmt.Errorf("获取网关组失败: %w", e)
- }
- return nil
- })
- g.Go(func() error {
- res, e := s.GetCdnUserId(gCtx, int64(req.Uid))
- if e != nil {
- return fmt.Errorf("获取cdn用户失败: %w", e)
- }
- if res == 0 {
- return fmt.Errorf("获取cdn用户失败")
- }
- userId = res
- return nil
- })
- g.Go(func() error {
- res, e := s.AddGroupId(gCtx, require.GlobalLimitName)
- if e != nil {
- return fmt.Errorf("创建规则分组失败: %w", e)
- }
- if res == 0 {
- return fmt.Errorf("创建规则分组失败")
- }
- groupId = res
- return nil
- })
- if err = g.Wait(); err != nil {
- return err
- }
- outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
- if err != nil {
- return err
- }
- // 获取套餐ID
- maxProtection := strings.TrimSuffix(require.ConfigMaxProtection, "G")
- if maxProtection == "" {
- return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',数字部分为空", require.ConfigMaxProtection)
- }
- maxProtectionInt, err := strconv.Atoi(maxProtection)
- if err != nil {
- return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',无法转换为数字", require.ConfigMaxProtection)
- }
- var planId int64
- maxProtectionNum := 1
- if maxProtectionInt >= 2000 {
- maxProtectionNum = maxProtectionInt / 1000
- }
- NodeAreaName := fmt.Sprintf("%s-%dT",require.NodeArea, maxProtectionNum)
- planId, err = s.globalLimitRepository.GetNodeArea(ctx, NodeAreaName)
- if err != nil {
- if errors.Is(err, gorm.ErrRecordNotFound) {
- planId = 0
- }else {
- return err
- }
- }
- if planId == 0 {
- // 安全冗余套餐
- planId = 6
- s.logger.Warn("获取套餐Id失败", zap.String("节点区域", NodeAreaName), zap.String("防御阈值", require.ConfigMaxProtection),zap.Int64("套餐Id", int64(req.Uid)),zap.Int64("魔方套餐Id", int64(req.HostId)))
- }
- ruleId, err := s.cdnService.BindPlan(ctx, v1.Plan{
- UserId: userId,
- PlanId: planId,
- DayTo: outputTimeStr,
- Name: require.GlobalLimitName,
- IsFree: true,
- Period: "monthly",
- CountPeriod: 1,
- PeriodDayTo: outputTimeStr,
- })
- if err != nil {
- return err
- }
- if ruleId == 0 {
- return fmt.Errorf("分配套餐失败")
- }
- expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
- if err != nil {
- return err
- }
- // 如果存在实例,恢复
- oldData,err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
- if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
- return err
- }
- if oldData!= nil && oldData.Id != 0 {
- err = s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
- HostId: req.HostId,
- Uid: req.Uid,
- Name: require.GlobalLimitName,
- RuleId: int(ruleId),
- GroupId: int(groupId),
- CdnUid: int(userId),
- Comment: req.Comment,
- ExpiredAt: expiredAt,
- State: true,
- })
- if err != nil {
- return err
- }
- return nil
- }
- // 如果不存在实例,创建
- err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
- HostId: req.HostId,
- Uid: req.Uid,
- Name: require.GlobalLimitName,
- RuleId: int(ruleId),
- GroupId: int(groupId),
- CdnUid: int(userId),
- Comment: req.Comment,
- State: true,
- ExpiredAt: expiredAt,
- })
- if err != nil {
- return err
- }
- return nil
- }
- func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
- require, err := s.GlobalLimitRequire(ctx, req)
- if err != nil {
- return err
- }
- data, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
- if err != nil {
- return err
- }
- // 如果不存在实例,创建
- gatewayIp, err := s.gatewayIpRep.GetGatewayipByHostIdAll(ctx, int64(req.HostId))
- if err != nil {
- return err
- }
- if gatewayIp != nil {
- err = s.gatewayIp.AddIpWhereHostIdNull(ctx, int64(req.HostId), int64(req.Uid))
- if err != nil {
- return fmt.Errorf("获取网关组失败: %w", err)
- }
- return nil
- }
- outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
- if err != nil {
- return err
- }
- err = s.cdnService.RenewPlan(ctx, v1.RenewalPlan{
- UserPlanId: int64(data.RuleId),
- DayTo: outputTimeStr,
- Period: "monthly",
- CountPeriod: 1,
- IsFree: true,
- PeriodDayTo: outputTimeStr,
- })
- if err != nil {
- return err
- }
- expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
- if err != nil {
- return err
- }
- if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
- HostId: req.HostId,
- Comment: req.Comment,
- ExpiredAt: expiredAt,
- }); err != nil {
- return err
- }
- return nil
- }
- func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
- // 检查是否过期
- isExpired, err := s.host.CheckExpired(ctx, int64(req.Uid), int64(req.HostId))
- if err != nil {
- return err
- }
- if isExpired {
- return fmt.Errorf("实例未过期,无法删除")
- }
- oldData, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
- if err != nil {
- if errors.Is(err, gorm.ErrRecordNotFound) {
- return fmt.Errorf("实例不存在")
- }
- return err
- }
- tcpIds, err := s.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx,req.HostId)
- if err != nil {
- return err
- }
- udpIds, err := s.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx,req.HostId)
- if err != nil {
- return err
- }
- webIds, err := s.webForWardingRep.GetWebForwardingWafWebAllIds(ctx,req.HostId)
- if err != nil {
- return err
- }
- // 黑白IP
- BwIds, err := s.allowAndDenyRep.GetIpCountListId(ctx, int64(req.HostId))
- if err != nil {
- return err
- }
- // 删除网站
- g, gCtx := errgroup.WithContext(ctx)
- g.Go(func() error {
- e := s.tcpforwarding.DeleteTcpForwarding(ctx,v1.DeleteTcpForwardingRequest{
- Ids: tcpIds,
- Uid: req.Uid,
- HostId: req.HostId,
- })
- if e != nil {
- return fmt.Errorf("删除TCP转发失败: %w", e)
- }
- return nil
- })
- g.Go(func() error {
- e := s.udpForWarding.DeleteUdpForwarding(ctx, v1.DeleteUdpForwardingRequest{
- Ids: udpIds,
- Uid: req.Uid,
- HostId: req.HostId,
- })
- if e != nil {
- return fmt.Errorf("删除UDP转发失败: %w", e)
- }
- return nil
- })
- g.Go(func() error {
- e := s.webForWarding.DeleteWebForwarding(ctx,v1.DeleteWebForwardingRequest{
- Ids: webIds,
- Uid: req.Uid,
- HostId: req.HostId,
- })
- if e != nil {
- return fmt.Errorf("删除WEB转发失败: %w", e)
- }
- return nil
- })
- // 删除套餐
- g.Go(func() error {
- e := s.cdnService.DelUserPlan(gCtx, int64(oldData.RuleId))
- if e != nil {
- return fmt.Errorf("删除套餐失败: %w", e)
- }
- return nil
- })
- // 删除网站分组
- g.Go(func() error {
- e := s.cdnService.DelServerGroup(gCtx, int64(oldData.GroupId))
- if e != nil {
- return fmt.Errorf("删除网站分组失败: %w", e)
- }
- return nil
- })
- if err = g.Wait(); err != nil {
- return err
- }
- if err := s.globalLimitRepository.EditHostState(ctx, int64(req.HostId), false); err != nil {
- return err
- }
- if err := s.gatewayIpRep.CleanIPByHostId(ctx, []int64{int64(req.HostId)}); err != nil {
- return err
- }
- // 删除黑白名单
- err = s.allowAndDeny.DeleteAllowAndDenyIps(ctx, v1.DelAllowAndDenyIpRequest{
- HostId: req.HostId,
- Ids: BwIds,
- Uid: req.Uid,
- })
- if err != nil {
- return err
- }
- return nil
- }
|