package service import ( "context" "fmt" v1 "github.com/go-nunu/nunu-layout-advanced/api/v1" "github.com/go-nunu/nunu-layout-advanced/internal/model" "github.com/go-nunu/nunu-layout-advanced/internal/repository" "github.com/spf13/cast" "github.com/spf13/viper" "golang.org/x/sync/errgroup" "strconv" "time" ) type GlobalLimitService interface { GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error EditGlobalLimitBySnail(ctx context.Context, req v1.GlobalLimitEditRequest) error } func NewGlobalLimitService( service *Service, globalLimitRepository repository.GlobalLimitRepository, duedate DuedateService, crawler CrawlerService, conf *viper.Viper, required RequiredService, parser ParserService, host HostService, tcpLimit TcpLimitService, udpLimit UdpLimitService, webLimit WebLimitService, gateWayGroup GatewayGroupService, hostRep repository.HostRepository, gateWayGroupRep repository.GatewayGroupRepository, ) GlobalLimitService { return &globalLimitService{ Service: service, globalLimitRepository: globalLimitRepository, duedate: duedate, crawler: crawler, Url: conf.GetString("crawler.Url"), required: required, parser: parser, host: host, tcpLimit: tcpLimit, udpLimit: udpLimit, webLimit: webLimit, gateWayGroup: gateWayGroup, hostRep: hostRep, gateWayGroupRep: gateWayGroupRep, } } type globalLimitService struct { *Service globalLimitRepository repository.GlobalLimitRepository duedate DuedateService crawler CrawlerService Url string required RequiredService parser ParserService host HostService tcpLimit TcpLimitService udpLimit UdpLimitService webLimit WebLimitService gateWayGroup GatewayGroupService hostRep repository.HostRepository gateWayGroupRep repository.GatewayGroupRepository } func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) { res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId) if err != nil { return v1.GlobalLimitRequireResponse{}, err } configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId) if err != nil { return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err) } bpsInt, err := strconv.Atoi(configCount.Bps) if err != nil { return v1.GlobalLimitRequireResponse{}, err } resultFloat := float64(bpsInt) / 2.0 / 8.0 res.Bps = strconv.FormatFloat( resultFloat, 'f', -1, 64) + "M" res.MaxBytesMonth = configCount.MaxBytesMonth res.Operator = configCount.Operator res.IpCount = configCount.IpCount domain, err := s.hostRep.GetDomainById(ctx, req.HostId) if err != nil { return v1.GlobalLimitRequireResponse{}, err } res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + strconv.Itoa(req.HostId) + "_" + domain return res, nil } func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) { return s.globalLimitRepository.GetGlobalLimit(ctx, id) } func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error { isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId)) if err != nil { return err } if isExist { return fmt.Errorf("配置限制已存在") } require, err := s.GlobalLimitRequire(ctx, req) if err != nil { return err } gatewayGroupId, err := s.gateWayGroupRep.GetGatewayGroupWhereHostIdNull(ctx, require.Operator, require.IpCount) if err != nil { return err } formData := map[string]interface{}{ "tag": require.GlobalLimitName, "bps": require.Bps, "max_bytes_month": require.MaxBytesMonth, "expired_at": require.ExpiredAt, } respBody, err := s.required.SendForm(ctx, "admin/info/waf_common_limit/new", "admin/new/waf_common_limit", formData) if err != nil { return err } ruleIdBase, err := s.parser.GetRuleIdByColumnName(ctx, respBody, require.GlobalLimitName) if err != nil { return err } if ruleIdBase == "" { res, err := s.parser.ParseAlert(string(respBody)) if err != nil { return err } return fmt.Errorf(res) } ruleId, err := cast.ToIntE(ruleIdBase) if err != nil { return err } var tcpLimitRuleId, udpLimitRuleId, webLimitRuleId int g, gCtx := errgroup.WithContext(ctx) // 启动tcpLimit调用 - 使用独立的请求参数副本 g.Go(func() error { tcpLimitReq := &v1.GeneralLimitRequireRequest{ Tag: require.GlobalLimitName, HostId: req.HostId, RuleId: ruleId, Uid: req.Uid, } result, e := s.tcpLimit.AddTcpLimit(gCtx, tcpLimitReq) if e != nil { return fmt.Errorf("tcpLimit调用失败: %w", e) } if result != 0 { tcpLimitRuleId = result return nil } return fmt.Errorf("tcpLimit调用失败,Id为 %d", result) }) // 启动udpLimit调用 - 使用独立的请求参数副本 g.Go(func() error { udpLimitReq := &v1.GeneralLimitRequireRequest{ Tag: require.GlobalLimitName, HostId: req.HostId, RuleId: ruleId, Uid: req.Uid, } result, e := s.udpLimit.AddUdpLimit(gCtx, udpLimitReq) if e != nil { return fmt.Errorf("udpLimit调用失败: %w", e) } if result != 0 { udpLimitRuleId = result return nil } return fmt.Errorf("udpLimit调用失败,Id为 %d", result) }) // 启动webLimit调用 - 使用独立的请求参数副本 g.Go(func() error { webLimitReq := &v1.GeneralLimitRequireRequest{ Tag: require.GlobalLimitName, HostId: req.HostId, RuleId: ruleId, Uid: req.Uid, } result, e := s.webLimit.AddWebLimit(gCtx, webLimitReq) if e != nil { return fmt.Errorf("webLimit调用失败: %w", e) } if result != 0 { webLimitRuleId = result return nil } return fmt.Errorf("webLimit调用失败,Id为 %d", result) }) if err := g.Wait(); err != nil { return err } t, err := time.Parse("2006-01-02 15:04:05", require.ExpiredAt) if err != nil { return err } expiredAt := t.Unix() err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{ HostId: req.HostId, RuleId: cast.ToInt(ruleId), Uid: req.Uid, GlobalLimitName: require.GlobalLimitName, Comment: req.Comment, TcpLimitRuleId: tcpLimitRuleId, UdpLimitRuleId: udpLimitRuleId, WebLimitRuleId: webLimitRuleId, GatewayGroupId: gatewayGroupId, ExpiredAt: expiredAt, }) if err != nil { return err } err = s.gateWayGroupRep.EditGatewayGroup(ctx, &model.GatewayGroup{ RuleId: gatewayGroupId, HostId: req.HostId, }) return nil } func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error { require, err := s.GlobalLimitRequire(ctx, req) if err != nil { return err } formData := map[string]interface{}{ "tag": require.GlobalLimitName, "bps": require.Bps, "max_bytes_month": require.MaxBytesMonth, "expired_at": require.ExpiredAt, } data, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId)) if err != nil { return err } respBody, err := s.required.SendForm(ctx, "admin/info/waf_common_limit/edit?&__goadmin_edit_pk="+strconv.Itoa(data.RuleId), "admin/edit/waf_common_limit", formData) if err != nil { return err } res, err := s.parser.ParseAlert(string(respBody)) if err != nil { return err } if res != "" { return fmt.Errorf(res) } t, err := time.Parse("2006-01-02 15:04:05", require.ExpiredAt) if err != nil { return err } expiredAt := t.Unix() if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{ HostId: req.HostId, Comment: req.Comment, ExpiredAt: expiredAt, }); err != nil { return err } return nil } func (s *globalLimitService) EditGlobalLimitBySnail(ctx context.Context, req v1.GlobalLimitEditRequest) error { configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId) if err != nil { return fmt.Errorf("获取配置限制失败: %w", err) } data, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId)) if err != nil { return err } t := time.Unix(req.ExpiredAt, 0) expiredAt := t.Format("2006-01-02 15:04:05") formData := map[string]interface{}{ "tag": data.GlobalLimitName, "bps": configCount.Bps, "max_bytes_month": configCount.MaxBytesMonth, "expired_at": expiredAt, } respBody, err := s.required.SendForm(ctx, "admin/info/waf_common_limit/edit?&__goadmin_edit_pk="+strconv.Itoa(req.RuleId), "admin/edit/waf_common_limit", formData) if err != nil { return err } if respBody == nil { return nil } return nil } func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error { if err := s.globalLimitRepository.DeleteGlobalLimitByHostId(ctx, int64(req.HostId)); err != nil { return err } return nil }