package service import ( "context" "errors" "fmt" v1 "github.com/go-nunu/nunu-layout-advanced/api/v1" "github.com/go-nunu/nunu-layout-advanced/internal/model" "github.com/go-nunu/nunu-layout-advanced/internal/repository" "github.com/mozillazg/go-pinyin" "github.com/spf13/viper" "go.uber.org/zap" "golang.org/x/sync/errgroup" "gorm.io/gorm" "strconv" "strings" "time" ) type GlobalLimitService interface { GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error } func NewGlobalLimitService( service *Service, globalLimitRepository repository.GlobalLimitRepository, duedate DuedateService, crawler CrawlerService, conf *viper.Viper, required RequiredService, parser ParserService, host HostService, gateWayGroup GatewayGroupService, hostRep repository.HostRepository, gateWayGroupRep repository.GatewayGroupRepository, cdnService CdnService, cdnRep repository.CdnRepository, tcpforwardingRep repository.TcpforwardingRepository, udpForWardingRep repository.UdpForWardingRepository, webForWardingRep repository.WebForwardingRepository, allowAndDeny AllowAndDenyIpService, allowAndDenyRep repository.AllowAndDenyIpRepository, tcpforwarding TcpforwardingService, udpForWarding UdpForWardingService, webForWarding WebForwardingService, ) GlobalLimitService { return &globalLimitService{ Service: service, globalLimitRepository: globalLimitRepository, duedate: duedate, crawler: crawler, Url: conf.GetString("crawler.Url"), required: required, parser: parser, host: host, gateWayGroup: gateWayGroup, hostRep: hostRep, gateWayGroupRep: gateWayGroupRep, cdnService: cdnService, cdnRep: cdnRep, tcpforwardingRep: tcpforwardingRep, udpForWardingRep: udpForWardingRep, webForWardingRep: webForWardingRep, allowAndDeny: allowAndDeny, allowAndDenyRep: allowAndDenyRep, tcpforwarding: tcpforwarding, udpForWarding: udpForWarding, webForWarding: webForWarding, } } type globalLimitService struct { *Service globalLimitRepository repository.GlobalLimitRepository duedate DuedateService crawler CrawlerService Url string required RequiredService parser ParserService host HostService gateWayGroup GatewayGroupService hostRep repository.HostRepository gateWayGroupRep repository.GatewayGroupRepository cdnService CdnService cdnRep repository.CdnRepository tcpforwardingRep repository.TcpforwardingRepository udpForWardingRep repository.UdpForWardingRepository webForWardingRep repository.WebForwardingRepository allowAndDeny AllowAndDenyIpService allowAndDenyRep repository.AllowAndDenyIpRepository tcpforwarding TcpforwardingService udpForWarding UdpForWardingService webForWarding WebForwardingService } func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) { data, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid) if err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { return 0, err } } if data != nil && data.CdnUid != 0 { return int64(data.CdnUid), nil } userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, uid) if err != nil { return 0, err } // 中文转拼音 a := pinyin.NewArgs() a.Style = pinyin.Normal pinyinSlice := pinyin.LazyPinyin(userInfo.Username, a) userName := strconv.Itoa(int(uid)) + "_" + strings.Join(pinyinSlice, "_") // 查询用户是否存在 UserId,err := s.cdnRep.GetUserId(ctx, userName) if err != nil { return 0, err } if UserId != 0 { return UserId, nil } // 注册用户 userId, err := s.cdnService.AddUser(ctx, v1.User{ Username: userName, Email: userInfo.Email, Fullname: userInfo.Username, Mobile: userInfo.PhoneNumber, }) if err != nil { return 0, err } return userId, nil } func (s *globalLimitService) AddGroupId(ctx context.Context,groupName string) (int64, error) { groupId, err := s.cdnService.CreateGroup(ctx, v1.Group{ Name: groupName, }) if err != nil { return 0, err } return groupId, nil } func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) { res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId) if err != nil { return v1.GlobalLimitRequireResponse{}, err } configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId) if err != nil { return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err) } bpsInt, err := strconv.Atoi(configCount.Bps) if err != nil { return v1.GlobalLimitRequireResponse{}, err } resultFloat := float64(bpsInt) / 2.0 / 8.0 res.Bps = strconv.FormatFloat( resultFloat, 'f', -1, 64) + "M" res.MaxBytesMonth = configCount.MaxBytesMonth res.Operator = configCount.Operator res.IpCount = configCount.IpCount res.NodeArea = configCount.NodeArea res.ConfigMaxProtection = configCount.ConfigMaxProtection res.IsBanUdp = configCount.IsBanUdp domain, err := s.hostRep.GetDomainById(ctx, req.HostId) if err != nil { return v1.GlobalLimitRequireResponse{}, err } userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, int64(req.Uid)) if err != nil { return v1.GlobalLimitRequireResponse{}, err } res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + userInfo.Username + "_" + strconv.Itoa(req.HostId) + "_" + domain res.HostName, err = s.globalLimitRepository.GetHostName(ctx, int64(req.HostId)) if err != nil { return v1.GlobalLimitRequireResponse{}, err } return res, nil } func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) { return s.globalLimitRepository.GetGlobalLimit(ctx, id) } func (s *globalLimitService) ConversionTime(ctx context.Context,req string) (string, error) { // 2. 将字符串解析成 time.Time 对象 // time.Parse 会根据你提供的布局来理解输入的字符串 t, err := time.Parse("2006-01-02 15:04:05", req) if err != nil { // 如果输入的字符串格式和布局不匹配,这里会报错 return "", fmt.Errorf("输入的字符串格式和布局不匹配 %w", err) } // 3. 定义新的输出格式 "YYYY-MM-DD" outputLayout := "2006-01-02" // 4. 将 time.Time 对象格式化为新的字符串 outputTimeStr := t.Format(outputLayout) return outputTimeStr, nil } func (s *globalLimitService) ConversionTimeUnix(ctx context.Context,req string) (int64, error) { t, err := time.Parse("2006-01-02 15:04:05", req) if err != nil { return 0, fmt.Errorf("输入的字符串格式和布局不匹配 %w", err) } expiredAt := t.Unix() return expiredAt, nil } func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error { isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId)) if err != nil { return err } if isExist { return fmt.Errorf("配置限制已存在") } require, err := s.GlobalLimitRequire(ctx, req) if err != nil { return err } g, gCtx := errgroup.WithContext(ctx) var gatewayGroupId int var userId int64 var groupId int64 g.Go(func() error { res, e := s.gateWayGroupRep.GetGatewayGroupWhereHostIdNull(gCtx, require) if e != nil { return fmt.Errorf("获取网关组失败: %w", e) } if res == 0 { return fmt.Errorf("获取网关组失败") } gatewayGroupId = res return nil }) g.Go(func() error { res, e := s.GetCdnUserId(gCtx, int64(req.Uid)) if e != nil { return fmt.Errorf("获取cdn用户失败: %w", e) } if res == 0 { return fmt.Errorf("获取cdn用户失败") } userId = res return nil }) g.Go(func() error { res, e := s.AddGroupId(gCtx, require.GlobalLimitName) if e != nil { return fmt.Errorf("创建规则分组失败: %w", e) } if res == 0 { return fmt.Errorf("创建规则分组失败") } groupId = res return nil }) if err = g.Wait(); err != nil { return err } outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt) if err != nil { return err } // 获取套餐ID maxProtection := strings.TrimSuffix(require.ConfigMaxProtection, "G") if maxProtection == "" { return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',数字部分为空", require.ConfigMaxProtection) } maxProtectionInt, err := strconv.Atoi(maxProtection) if err != nil { return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',无法转换为数字", require.ConfigMaxProtection) } var planId int64 maxProtectionNum := 1 if maxProtectionInt >= 2000 { maxProtectionNum = maxProtectionInt / 1000 } NodeAreaName := fmt.Sprintf("%s-%dT",require.NodeArea, maxProtectionNum) planId, err = s.globalLimitRepository.GetNodeArea(ctx, NodeAreaName) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { planId = 0 }else { return err } } if planId == 0 { // 安全冗余套餐 planId = 6 s.logger.Warn("获取套餐Id失败", zap.String("节点区域", NodeAreaName), zap.String("防御阈值", require.ConfigMaxProtection),zap.Int64("套餐Id", int64(req.Uid)),zap.Int64("魔方套餐Id", int64(req.HostId))) } ruleId, err := s.cdnService.BindPlan(ctx, v1.Plan{ UserId: userId, PlanId: planId, DayTo: outputTimeStr, Name: require.GlobalLimitName, IsFree: true, Period: "monthly", CountPeriod: 1, PeriodDayTo: outputTimeStr, }) if err != nil { return err } if ruleId == 0 { return fmt.Errorf("分配套餐失败") } err = s.gateWayGroupRep.EditGatewayGroup(ctx, &model.GatewayGroup{ Id: gatewayGroupId, HostId: req.HostId, }) if err != nil { return err } expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt) if err != nil { return err } // 如果存在实例,恢复 oldData,err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId)) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return err } if oldData!= nil && oldData.Id != 0 { err = s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{ HostId: req.HostId, Uid: req.Uid, Name: require.GlobalLimitName, RuleId: int(ruleId), GroupId: int(groupId), GatewayGroupId: gatewayGroupId, CdnUid: int(userId), Comment: req.Comment, ExpiredAt: expiredAt, State: true, }) if err != nil { return err } return nil } // 如果不存在实例,创建 err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{ HostId: req.HostId, Uid: req.Uid, Name: require.GlobalLimitName, RuleId: int(ruleId), GroupId: int(groupId), GatewayGroupId: gatewayGroupId, CdnUid: int(userId), Comment: req.Comment, State: true, ExpiredAt: expiredAt, }) if err != nil { return err } return nil } func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error { require, err := s.GlobalLimitRequire(ctx, req) if err != nil { return err } data, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId)) if err != nil { return err } outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt) if err != nil { return err } err = s.cdnService.RenewPlan(ctx, v1.RenewalPlan{ UserPlanId: int64(data.RuleId), DayTo: outputTimeStr, Period: "monthly", CountPeriod: 1, IsFree: true, PeriodDayTo: outputTimeStr, }) if err != nil { return err } expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt) if err != nil { return err } if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{ HostId: req.HostId, Comment: req.Comment, ExpiredAt: expiredAt, }); err != nil { return err } return nil } func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error { // 检查是否过期 isExpired, err := s.host.CheckExpired(ctx, int64(req.Uid), int64(req.HostId)) if err != nil { return err } if isExpired { return fmt.Errorf("实例未过期,无法删除") } oldData, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId)) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return fmt.Errorf("实例不存在") } return err } tcpIds, err := s.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx,req.HostId) if err != nil { return err } udpIds, err := s.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx,req.HostId) if err != nil { return err } webIds, err := s.webForWardingRep.GetWebForwardingWafWebAllIds(ctx,req.HostId) if err != nil { return err } BwIds, err := s.allowAndDenyRep.GetIpCountListId(ctx, int64(req.HostId)) if err != nil { return err } // 删除网站 g, gCtx := errgroup.WithContext(ctx) g.Go(func() error { e := s.tcpforwarding.DeleteTcpForwarding(ctx,v1.DeleteTcpForwardingRequest{ Ids: tcpIds, Uid: req.Uid, HostId: req.HostId, }) if e != nil { return fmt.Errorf("删除TCP转发失败: %w", e) } return nil }) g.Go(func() error { e := s.udpForWarding.DeleteUdpForwarding(ctx,udpIds) if e != nil { return fmt.Errorf("删除UDP转发失败: %w", e) } return nil }) g.Go(func() error { e := s.webForWarding.DeleteWebForwarding(ctx,webIds) if e != nil { return fmt.Errorf("删除WEB转发失败: %w", e) } return nil }) // 删除套餐 g.Go(func() error { e := s.cdnService.DelUserPlan(gCtx, int64(oldData.RuleId)) if e != nil { return fmt.Errorf("删除套餐失败: %w", e) } return nil }) // 删除网站分组 g.Go(func() error { e := s.cdnService.DelServerGroup(gCtx, int64(oldData.GroupId)) if e != nil { return fmt.Errorf("删除网站分组失败: %w", e) } return nil }) if err = g.Wait(); err != nil { return err } if err := s.globalLimitRepository.EditHostState(ctx, int64(req.HostId), false); err != nil { return err } if err := s.gateWayGroupRep.EditGatewayGroup(ctx,&model.GatewayGroup{ Id: oldData.GatewayGroupId, HostId: 0, }); err != nil { return err } // 删除黑白名单 err = s.allowAndDeny.DeleteAllowAndDenyIps(ctx, v1.DelAllowAndDenyIpRequest{ HostId: req.HostId, Ids: BwIds, Uid: req.Uid, }) if err != nil { return err } return nil }