package waf

import (
	"context"
	"errors"
	"fmt"
	"strconv"
	"strings"
	"time"

	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
	"github.com/go-nunu/nunu-layout-advanced/internal/model"
	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
	flexCdn2 "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/flexCdn"
	"github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
	"github.com/go-nunu/nunu-layout-advanced/internal/service"
	"github.com/go-nunu/nunu-layout-advanced/internal/service/api/flexCdn"
	"github.com/mozillazg/go-pinyin"
	"github.com/spf13/viper"
	"golang.org/x/sync/errgroup"
	"gorm.io/gorm"
)

// GlobalLimitService manages the per-host "global limit" record and the
// resources provisioned alongside it: CDN user and group, gateway IPs,
// defense level, bandwidth limit, forwarding rules and allow/deny IP lists.
type GlobalLimitService interface {
	// GetGlobalLimit returns the global limit record with the given id.
	GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error)
	// AddGlobalLimit provisions a global limit for req.HostId.
	AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
	// EditGlobalLimit updates the comment and expiry of a host's global limit.
	EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
	// DeleteGlobalLimit tears down a host's global limit and its resources.
	DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
	// GlobalLimitRequire gathers the configuration and metadata needed to
	// create a global limit for req.HostId.
	GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error)
}

// NewGlobalLimitService wires a GlobalLimitService from its repositories and
// collaborating services. The crawler URL is read from the "crawler.Url"
// configuration key.
func NewGlobalLimitService(
	service *service.Service,
	globalLimitRepository waf.GlobalLimitRepository,
	duedate service.DuedateService,
	crawler service.CrawlerService,
	conf *viper.Viper,
	required service.RequiredService,
	parser service.ParserService,
	host service.HostService,
	hostRep repository.HostRepository,
	cdnService flexCdn.CdnService,
	cdnRep flexCdn2.CdnRepository,
	tcpforwardingRep waf.TcpforwardingRepository,
	udpForWardingRep waf.UdpForWardingRepository,
	webForWardingRep waf.WebForwardingRepository,
	allowAndDeny AllowAndDenyIpService,
	allowAndDenyRep waf.AllowAndDenyIpRepository,
	tcpforwarding TcpforwardingService,
	udpForWarding UdpForWardingService,
	webForWarding WebForwardingService,
	gatewayIpRep waf.GatewayipRepository,
	gatywayIp GatewayipService,
	bulidAudun BuildAudunService,
	zzyBgp ZzybgpService,
) GlobalLimitService {
	return &globalLimitService{
		Service:               service,
		globalLimitRepository: globalLimitRepository,
		duedate:               duedate,
		crawler:               crawler,
		Url:                   conf.GetString("crawler.Url"),
		required:              required,
		parser:                parser,
		host:                  host,
		hostRep:               hostRep,
		cdnService:            cdnService,
		cdnRep:                cdnRep,
		tcpforwardingRep:      tcpforwardingRep,
		udpForWardingRep:      udpForWardingRep,
		webForWardingRep:      webForWardingRep,
		allowAndDeny:          allowAndDeny,
		allowAndDenyRep:       allowAndDenyRep,
		tcpforwarding:         tcpforwarding,
		udpForWarding:         udpForWarding,
		webForWarding:         webForWarding,
		gatewayIpRep:          gatewayIpRep,
		gatewayIp:             gatywayIp,
		bulidAudun:            bulidAudun,
		zzyBgp:                zzyBgp,
	}
}

// globalLimitService is the default GlobalLimitService implementation.
type globalLimitService struct {
	*service.Service
	globalLimitRepository waf.GlobalLimitRepository
	duedate               service.DuedateService
	crawler               service.CrawlerService
	// Url is the value of the "crawler.Url" config key.
	Url              string
	required         service.RequiredService
	parser           service.ParserService
	host             service.HostService
	hostRep          repository.HostRepository
	cdnService       flexCdn.CdnService
	cdnRep           flexCdn2.CdnRepository
	tcpforwardingRep waf.TcpforwardingRepository
	udpForWardingRep waf.UdpForWardingRepository
	webForWardingRep waf.WebForwardingRepository
	allowAndDeny     AllowAndDenyIpService
	allowAndDenyRep  waf.AllowAndDenyIpRepository
	tcpforwarding    TcpforwardingService
	udpForWarding    UdpForWardingService
	webForWarding    WebForwardingService
	gatewayIpRep     waf.GatewayipRepository
	gatewayIp        GatewayipService
	bulidAudun       BuildAudunService
	zzyBgp           ZzybgpService
}

// GetCdnUserId returns the CDN user id for the platform user uid.
//
// Resolution order:
//  1. the non-zero CdnUid stored on the user's first global limit record;
//  2. an existing CDN user named "<uid>_<pinyin of username joined by _>";
//  3. otherwise a new CDN user is registered under that name.
func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
	// A missing record is expected for first-time users; only other errors abort.
	data, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid)
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		return 0, err
	}
	if data != nil && data.CdnUid != 0 {
		return int64(data.CdnUid), nil
	}

	userInfo, err := s.globalLimitRepository.GetUserInfo(ctx, uid)
	if err != nil {
		return 0, err
	}

	// Convert Chinese characters in the username to tone-less pinyin so the
	// CDN username is "<uid>_<syllable>_<syllable>...".
	args := pinyin.NewArgs()
	args.Style = pinyin.Normal
	syllables := pinyin.LazyPinyin(userInfo.Username, args)
	userName := strconv.Itoa(int(uid)) + "_" + strings.Join(syllables, "_")

	// Reuse a CDN user that already has this name.
	existingId, err := s.cdnRep.GetUserId(ctx, userName)
	if err != nil {
		return 0, err
	}
	if existingId != 0 {
		return existingId, nil
	}

	// Otherwise register a new CDN user.
	userId, err := s.cdnService.AddUser(ctx, v1.User{
		Username: userName,
		Email:    userInfo.Email,
		Fullname: userInfo.Username,
		Mobile:   userInfo.PhoneNumber,
	})
	if err != nil {
		return 0, err
	}
	return userId, nil
}

// AddGroupId creates a CDN group named groupName and returns its id.
func (s *globalLimitService) AddGroupId(ctx context.Context, groupName string) (int64, error) {
	groupId, err := s.cdnService.CreateGroup(ctx, v1.Group{
		Name: groupName,
	})
	if err != nil {
		return 0, err
	}
	return groupId, nil
}

// GlobalLimitRequire collects everything needed to create a global limit for
// req.HostId:
//   - the next due date (ExpiredAt);
//   - the host's limit configuration;
//   - the generated name "<uid>_<username>_<hostId>_<domain>";
//   - the host name.
//
// The zero response is returned on any error.
func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
	res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
	if err != nil {
		return v1.GlobalLimitRequireResponse{}, err
	}

	configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
	if err != nil {
		return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err)
	}
	res.MaxBytesMonth = configCount.MaxBytesMonth
	res.Operator = configCount.Operator
	res.IpCount = configCount.IpCount
	res.NodeArea = configCount.NodeArea
	res.ConfigMaxProtection = configCount.ConfigMaxProtection
	res.IsBanUdp = configCount.IsBanUdp
	res.HostId = req.HostId
	res.Bps = configCount.Bps

	domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
	if err != nil {
		return v1.GlobalLimitRequireResponse{}, err
	}

	userInfo, err := s.globalLimitRepository.GetUserInfo(ctx, int64(req.Uid))
	if err != nil {
		return v1.GlobalLimitRequireResponse{}, err
	}
	res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + userInfo.Username + "_" + strconv.Itoa(req.HostId) + "_" + domain

	res.HostName, err = s.globalLimitRepository.GetHostName(ctx, int64(req.HostId))
	if err != nil {
		return v1.GlobalLimitRequireResponse{}, err
	}
	return res, nil
}

// GetGlobalLimit returns the global limit record with the given id.
func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) {
	return s.globalLimitRepository.GetGlobalLimit(ctx, id)
}

// ConversionTime reformats a "YYYY-MM-DD hh:mm:ss" timestamp as "YYYY-MM-DD".
// It returns an error if req does not match the input layout.
func (s *globalLimitService) ConversionTime(ctx context.Context, req string) (string, error) {
	t, err := time.Parse("2006-01-02 15:04:05", req)
	if err != nil {
		return "", fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
	}
	return t.Format("2006-01-02"), nil
}

// ConversionTimeUnix parses a "YYYY-MM-DD hh:mm:ss" timestamp and returns it
// as Unix seconds.
//
// NOTE(review): time.Parse interprets a zone-less string as UTC. If the
// due-date strings are in server-local time, the result is shifted by the
// zone offset. Confirm whether time.ParseInLocation is intended.
func (s *globalLimitService) ConversionTimeUnix(ctx context.Context, req string) (int64, error) {
	t, err := time.Parse("2006-01-02 15:04:05", req)
	if err != nil {
		return 0, fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
	}
	return t.Unix(), nil
}

// AddGlobalLimit provisions a global limit for req.HostId. It runs these steps:
//  1. Reject the request if IsGlobalLimitExistByHostId reports one already.
//  2. Gather the requirements (GlobalLimitRequire).
//  3. Concurrently:
//     - assign gateway IPs;
//     - resolve or create the CDN user;
//     - create a CDN group named after the global limit.
//  4. Set the defense level to 0 and add the bandwidth limit.
//  5. Reactivate the host's existing record if one is found, or insert a new one.
func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
	isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
	if err != nil {
		return err
	}
	if isExist {
		return fmt.Errorf("配置限制已存在")
	}

	require, err := s.GlobalLimitRequire(ctx, req)
	if err != nil {
		return err
	}

	// Independent provisioning steps run in parallel; each worker writes only
	// its own result variable, which is read after Wait returns.
	g, gCtx := errgroup.WithContext(ctx)
	var userId int64
	var groupId int64
	g.Go(func() error {
		e := s.gatewayIp.AddIpWhereHostIdNull(gCtx, int64(req.HostId), int64(req.Uid))
		if e != nil {
			return fmt.Errorf("获取网关组失败: %w", e)
		}
		return nil
	})
	g.Go(func() error {
		res, e := s.GetCdnUserId(gCtx, int64(req.Uid))
		if e != nil {
			return fmt.Errorf("获取cdn用户失败: %w", e)
		}
		if res == 0 {
			return fmt.Errorf("获取cdn用户失败")
		}
		userId = res
		return nil
	})
	g.Go(func() error {
		res, e := s.AddGroupId(gCtx, require.GlobalLimitName)
		if e != nil {
			return fmt.Errorf("创建规则分组失败: %w", e)
		}
		if res == 0 {
			return fmt.Errorf("创建规则分组失败")
		}
		groupId = res
		return nil
	})
	if err = g.Wait(); err != nil {
		return err
	}

	// Apply defense level 0.
	err = s.zzyBgp.SetDefense(ctx, int64(req.HostId), 0)
	if err != nil {
		return err
	}

	// Add the bandwidth limit.
	err = s.bulidAudun.Bandwidth(ctx, int64(req.HostId), "add")
	if err != nil {
		return err
	}

	expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
	if err != nil {
		return err
	}

	// DeleteGlobalLimit only flips the host state to false (EditHostState), so a
	// record for this host may still exist; if so, reactivate it in place.
	// NOTE(review): presumably IsGlobalLimitExistByHostId ignores such inactive
	// records — confirm.
	oldData, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		return err
	}
	if oldData != nil && oldData.Id != 0 {
		return s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
			HostId:    req.HostId,
			Uid:       req.Uid,
			Name:      require.GlobalLimitName,
			GroupId:   int(groupId),
			CdnUid:    int(userId),
			Comment:   req.Comment,
			ExpiredAt: expiredAt,
			State:     true,
		})
	}

	// No previous record: create one.
	return s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
		HostId:    req.HostId,
		Uid:       req.Uid,
		Name:      require.GlobalLimitName,
		GroupId:   int(groupId),
		CdnUid:    int(userId),
		Comment:   req.Comment,
		State:     true,
		ExpiredAt: expiredAt,
	})
}

// EditGlobalLimit refreshes the comment and expiry of req.HostId's global limit.
// It also allocates gateway IPs if the host currently has none.
func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
	require, err := s.GlobalLimitRequire(ctx, req)
	if err != nil {
		return err
	}

	// Allocate gateway IPs when none are assigned to this host yet.
	gatewayIp, err := s.gatewayIpRep.GetGatewayipByHostIdAll(ctx, int64(req.HostId))
	if err != nil {
		return err
	}
	if gatewayIp == nil {
		err = s.gatewayIp.AddIpWhereHostIdNull(ctx, int64(req.HostId), int64(req.Uid))
		if err != nil {
			return fmt.Errorf("获取网关组失败: %w", err)
		}
	}

	expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
	if err != nil {
		return err
	}

	// Only HostId, Comment and ExpiredAt are set here.
	// NOTE(review): this assumes the repository skips zero-valued fields (as
	// gorm's Updates does with a struct) — confirm.
	return s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
		HostId:    req.HostId,
		Comment:   req.Comment,
		ExpiredAt: expiredAt,
	})
}

// DeleteGlobalLimit tears down req.HostId's global limit. It runs these steps:
//  1. Reset the defense level to 10 and remove the bandwidth limit.
//  2. Concurrently delete the TCP, UDP and web forwarding rules and the CDN
//     server group.
//  3. Mark the host state inactive and clean its gateway IPs.
//  4. Delete its allow/deny IP entries.
func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
	// NOTE(review): this rejects the request with "实例未过期" ("instance not
	// expired") exactly when isExpired is true, which looks inverted. Confirm
	// CheckExpired's return semantics before changing it.
	isExpired, err := s.host.CheckExpired(ctx, int64(req.Uid), int64(req.HostId))
	if err != nil {
		return err
	}
	if isExpired {
		return fmt.Errorf("实例未过期,无法删除")
	}

	oldData, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return fmt.Errorf("实例不存在")
		}
		return err
	}

	// Collect the ids of everything that has to be removed.
	tcpIds, err := s.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, req.HostId)
	if err != nil {
		return err
	}
	udpIds, err := s.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, req.HostId)
	if err != nil {
		return err
	}
	webIds, err := s.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, req.HostId)
	if err != nil {
		return err
	}

	// Reset the defense level to 10.
	err = s.zzyBgp.SetDefense(ctx, int64(req.HostId), 10)
	if err != nil {
		return err
	}

	// Remove the bandwidth limit.
	err = s.bulidAudun.Bandwidth(ctx, int64(req.HostId), "del")
	if err != nil {
		return err
	}

	// Allow/deny IP entries are collected now and deleted at the end.
	bwIds, err := s.allowAndDenyRep.GetIpCountListId(ctx, int64(req.HostId))
	if err != nil {
		return err
	}

	// Delete forwarding rules and the CDN server group in parallel. Every
	// worker uses gCtx so that the group's cancel-on-first-error propagates.
	g, gCtx := errgroup.WithContext(ctx)
	g.Go(func() error {
		e := s.tcpforwarding.DeleteTcpForwarding(gCtx, v1.DeleteTcpForwardingRequest{
			Ids:    tcpIds,
			Uid:    req.Uid,
			HostId: req.HostId,
		})
		if e != nil {
			return fmt.Errorf("删除TCP转发失败: %w", e)
		}
		return nil
	})
	g.Go(func() error {
		e := s.udpForWarding.DeleteUdpForwarding(gCtx, v1.DeleteUdpForwardingRequest{
			Ids:    udpIds,
			Uid:    req.Uid,
			HostId: req.HostId,
		})
		if e != nil {
			return fmt.Errorf("删除UDP转发失败: %w", e)
		}
		return nil
	})
	g.Go(func() error {
		e := s.webForWarding.DeleteWebForwarding(gCtx, v1.DeleteWebForwardingRequest{
			Ids:    webIds,
			Uid:    req.Uid,
			HostId: req.HostId,
		})
		if e != nil {
			return fmt.Errorf("删除WEB转发失败: %w", e)
		}
		return nil
	})
	g.Go(func() error {
		e := s.cdnService.DelServerGroup(gCtx, int64(oldData.GroupId))
		if e != nil {
			return fmt.Errorf("删除网站分组失败: %w", e)
		}
		return nil
	})
	if err = g.Wait(); err != nil {
		return err
	}

	// gCtx is cancelled once Wait returns; the remaining steps use ctx.
	if err := s.globalLimitRepository.EditHostState(ctx, int64(req.HostId), false); err != nil {
		return err
	}
	if err := s.gatewayIpRep.CleanIPByHostId(ctx, []int64{int64(req.HostId)}); err != nil {
		return err
	}

	// Delete the allow/deny IP entries.
	return s.allowAndDeny.DeleteAllowAndDenyIps(ctx, v1.DelAllowAndDenyIpRequest{
		HostId: req.HostId,
		Ids:    bwIds,
		Uid:    req.Uid,
	})
}