package service

import (
	"context"
	"errors"
	"fmt"
	"strconv"

	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
	"github.com/go-nunu/nunu-layout-advanced/internal/model"
	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
	"github.com/spf13/cast"
	"github.com/spf13/viper"
	"golang.org/x/sync/errgroup"
)

// GlobalLimitService manages per-host global limit configurations: creating
// the remote WAF common-limit rule together with its TCP/UDP/Web sub-limits,
// and reading, editing and deleting the locally stored record.
type GlobalLimitService interface {
	GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error)
	AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
	EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
	DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
}

// NewGlobalLimitService builds a GlobalLimitService from its dependencies.
// The crawler base URL is read from the "crawler.Url" configuration key.
func NewGlobalLimitService(
	service *Service,
	globalLimitRepository repository.GlobalLimitRepository,
	duedate DuedateService,
	crawler CrawlerService,
	conf *viper.Viper,
	required RequiredService,
	parser ParserService,
	host HostService,
	tcpLimit TcpLimitService,
	udpLimit UdpLimitService,
	webLimit WebLimitService,
	gateWayGroup GatewayGroupService,
	hostRep repository.HostRepository,
) GlobalLimitService {
	return &globalLimitService{
		Service:               service,
		globalLimitRepository: globalLimitRepository,
		duedate:               duedate,
		crawler:               crawler,
		Url:                   conf.GetString("crawler.Url"),
		required:              required,
		parser:                parser,
		host:                  host,
		tcpLimit:              tcpLimit,
		udpLimit:              udpLimit,
		webLimit:              webLimit,
		gateWayGroup:          gateWayGroup,
		hostRep:               hostRep,
	}
}

// globalLimitService is the default GlobalLimitService implementation.
//
// NOTE(review): crawler, Url and gateWayGroup are not used by the methods in
// this file; they may be used elsewhere in the package — confirm before
// removing them.
type globalLimitService struct {
	*Service
	globalLimitRepository repository.GlobalLimitRepository
	duedate               DuedateService
	crawler               CrawlerService
	Url                   string
	required              RequiredService
	parser                ParserService
	host                  HostService
	tcpLimit              TcpLimitService
	udpLimit              UdpLimitService
	webLimit              WebLimitService
	gateWayGroup          GatewayGroupService
	hostRep               repository.HostRepository
}

// GlobalLimitRequire collects everything needed to create a global limit for
// req.HostId. It fails if the host already has a global limit, then resolves
// the next due date, the host's configured Bps / MaxBytesMonth values and the
// host's domain. GlobalLimitName is built as "<uid>_<hostId>_<domain>".
// On error the returned response is the zero value.
func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
	isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
	if err != nil {
		return v1.GlobalLimitRequireResponse{}, err
	}
	if isExist {
		return v1.GlobalLimitRequireResponse{}, errors.New("配置限制已存在")
	}

	res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
	if err != nil {
		return v1.GlobalLimitRequireResponse{}, err
	}

	configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
	if err != nil {
		return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err)
	}
	res.Bps = configCount.Bps
	res.MaxBytesMonth = configCount.MaxBytesMonth

	domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
	if err != nil {
		return v1.GlobalLimitRequireResponse{}, err
	}
	res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + strconv.Itoa(req.HostId) + "_" + domain
	return res, nil
}

// GetGlobalLimit returns the stored global limit with the given id.
func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) {
	return s.globalLimitRepository.GetGlobalLimit(ctx, id)
}

// AddGlobalLimit creates a global limit for req.HostId. It:
//  1. collects the requirements via GlobalLimitRequire,
//  2. submits the WAF common-limit form and extracts the new rule ID from
//     the response (or returns the page's alert text as the error),
//  3. concurrently creates the TCP, UDP and Web sub-limits bound to that
//     rule ID; the first failure cancels the shared context,
//  4. stores the rule ID and the three sub-limit IDs locally.
//
// NOTE(review): steps 2–4 are not transactional. If a sub-limit call or the
// final insert fails, rules already created remotely are left behind; a
// compensating cleanup may be needed.
func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
	require, err := s.GlobalLimitRequire(ctx, req)
	if err != nil {
		return err
	}

	formData := map[string]interface{}{
		"tag":             require.GlobalLimitName,
		"bps":             require.Bps,
		"max_bytes_month": require.MaxBytesMonth,
		"expired_at":      require.ExpiredAt,
	}
	respBody, err := s.required.SendForm(ctx, "admin/info/waf_common_limit/new", "admin/new/waf_common_limit", formData)
	if err != nil {
		return err
	}

	ruleIdBase, err := s.parser.GetRuleIdByColumnName(ctx, respBody, require.GlobalLimitName)
	if err != nil {
		return err
	}
	if ruleIdBase == "" {
		// No rule ID in the response: report the alert message instead.
		alert, err := s.parser.ParseAlert(string(respBody))
		if err != nil {
			return err
		}
		// errors.New rather than fmt.Errorf: the alert text comes from the
		// remote response and may contain '%' characters, which fmt would
		// interpret as formatting verbs.
		return errors.New(alert)
	}
	ruleId, err := cast.ToIntE(ruleIdBase)
	if err != nil {
		return err
	}

	// Each sub-limit call gets its own freshly built request so the
	// goroutines share no mutable request state.
	newSubLimitReq := func() *v1.GeneralLimitRequireRequest {
		return &v1.GeneralLimitRequireRequest{
			Tag:    require.GlobalLimitName,
			HostId: req.HostId,
			RuleId: ruleId,
			Uid:    req.Uid,
		}
	}
	// checkRuleId converts a sub-limit call result into its rule ID; a call
	// error or a zero ID is reported as a failure of that sub-limit kind.
	checkRuleId := func(kind string, id int, callErr error) (int, error) {
		if callErr != nil {
			return 0, fmt.Errorf("%s调用失败: %w", kind, callErr)
		}
		if id == 0 {
			return 0, fmt.Errorf("%s调用失败,Id为 %d", kind, id)
		}
		return id, nil
	}

	// Each goroutine writes only its own variable; g.Wait() orders those
	// writes before the reads below.
	var tcpLimitRuleId, udpLimitRuleId, webLimitRuleId int
	g, gCtx := errgroup.WithContext(ctx)
	g.Go(func() error {
		id, e := s.tcpLimit.AddTcpLimit(gCtx, newSubLimitReq())
		tcpLimitRuleId, e = checkRuleId("tcpLimit", id, e)
		return e
	})
	g.Go(func() error {
		id, e := s.udpLimit.AddUdpLimit(gCtx, newSubLimitReq())
		udpLimitRuleId, e = checkRuleId("udpLimit", id, e)
		return e
	})
	g.Go(func() error {
		id, e := s.webLimit.AddWebLimit(gCtx, newSubLimitReq())
		webLimitRuleId, e = checkRuleId("webLimit", id, e)
		return e
	})
	if err := g.Wait(); err != nil {
		return err
	}

	return s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
		HostId:          req.HostId,
		RuleId:          ruleId,
		GlobalLimitName: require.GlobalLimitName,
		Comment:         req.Comment,
		TcpLimitRuleId:  tcpLimitRuleId,
		UdpLimitRuleId:  udpLimitRuleId,
		WebLimitRuleId:  webLimitRuleId,
		GatewayGroupId:  5, // TODO: temporarily hard-coded
	})
}

// EditGlobalLimit updates the stored global limit for req.HostId, passing
// only HostId and Comment to the repository.
//
// NOTE(review): whether zero-valued fields are ignored or overwritten depends
// on UpdateGlobalLimitByHostId — confirm in the repository.
func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
	return s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
		HostId:  req.HostId,
		Comment: req.Comment,
	})
}

// DeleteGlobalLimit deletes the stored global limit for req.HostId.
//
// NOTE(review): only the repository record is removed here; the remote WAF
// rule and its TCP/UDP/Web sub-limits are not touched by this method.
func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
	return s.globalLimitRepository.DeleteGlobalLimitByHostId(ctx, int64(req.HostId))
}