package service import ( "context" "fmt" v1 "github.com/go-nunu/nunu-layout-advanced/api/v1" "github.com/go-nunu/nunu-layout-advanced/internal/model" "github.com/go-nunu/nunu-layout-advanced/internal/repository" "github.com/spf13/cast" "github.com/spf13/viper" "strconv" "sync" "github.com/sourcegraph/conc" ) type GlobalLimitService interface { GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error } func NewGlobalLimitService( service *Service, globalLimitRepository repository.GlobalLimitRepository, duedate DuedateService, crawler CrawlerService, conf *viper.Viper, required RequiredService, parser ParserService, host HostService, tcpLimit TcpLimitService, udpLimit UdpLimitService, webLimit WebLimitService, gateWayGroup GatewayGroupService, hostRep repository.HostRepository, ) GlobalLimitService { return &globalLimitService{ Service: service, globalLimitRepository: globalLimitRepository, duedate: duedate, crawler: crawler, Url: conf.GetString("crawler.Url"), required: required, parser: parser, host: host, tcpLimit: tcpLimit, udpLimit: udpLimit, webLimit: webLimit, gateWayGroup: gateWayGroup, hostRep: hostRep, } } type globalLimitService struct { *Service globalLimitRepository repository.GlobalLimitRepository duedate DuedateService crawler CrawlerService Url string required RequiredService parser ParserService host HostService tcpLimit TcpLimitService udpLimit UdpLimitService webLimit WebLimitService gateWayGroup GatewayGroupService hostRep repository.HostRepository } func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) { isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId)) if err != nil { return v1.GlobalLimitRequireResponse{}, err } if isExist { return v1.GlobalLimitRequireResponse{}, fmt.Errorf("配置限制已存在") } res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId) if err != nil { return v1.GlobalLimitRequireResponse{}, err } configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId) if err != nil { return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err) } res.Bps = configCount.Bps res.MaxBytesMonth = configCount.MaxBytesMonth domain, err := s.hostRep.GetDomainById(ctx, req.HostId) if err != nil { return v1.GlobalLimitRequireResponse{}, err } res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + strconv.Itoa(req.HostId) + "_" + domain return res, nil } func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) { return s.globalLimitRepository.GetGlobalLimit(ctx, id) } func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error { require, err := s.GlobalLimitRequire(ctx, req) if err != nil { return err } formData := map[string]interface{}{ "tag": require.GlobalLimitName, "bps": require.Bps, "max_bytes_month": require.MaxBytesMonth, "expired_at": require.ExpiredAt, } respBody, err := s.required.SendForm(ctx, "admin/info/waf_common_limit/new", "admin/new/waf_common_limit", formData) if err != nil { return err } ruleIdBase, err := s.parser.GetRuleIdByColumnName(ctx, respBody, require.GlobalLimitName) if err != nil { return err } if ruleIdBase == "" { res, err := s.parser.ParseAlert(string(respBody)) if err != nil { return err } return fmt.Errorf(res) } ruleId, err := cast.ToIntE(ruleIdBase) if err != nil { return err } // 使用conc库并发执行API调用 var tcpLimitRuleId, udpLimitRuleId, webLimitRuleId int var mu sync.Mutex // 用于保护共享变量 // 为每个并发调用创建独立的请求参数(深拷贝) // 避免共享同一个指针可能导致的数据竞争 // 创建一个WaitGroup来协调多个并发任务 wg := conc.NewWaitGroup() // 启动tcpLimit调用 - 使用独立的请求参数副本 wg.Go(func() { // 为该goroutine创建独立的请求参数副本 tcpLimitReq := &v1.GeneralLimitRequireRequest{ Tag: require.GlobalLimitName, HostId: req.HostId, RuleId: ruleId, Uid: req.Uid, } result, e := s.tcpLimit.AddTcpLimit(ctx, tcpLimitReq) if e != nil { // 只在修改共享的错误变量时加锁 mu.Lock() err = e mu.Unlock() } else { // 不需要加锁,因为tcpLimitRuleId只被这一个goroutine修改 tcpLimitRuleId = result } }) // 启动udpLimit调用 - 使用独立的请求参数副本 wg.Go(func() { // 为该goroutine创建独立的请求参数副本 udpLimitReq := &v1.GeneralLimitRequireRequest{ Tag: require.GlobalLimitName, HostId: req.HostId, RuleId: ruleId, Uid: req.Uid, } result, e := s.udpLimit.AddUdpLimit(ctx, udpLimitReq) if e != nil { // 只在修改共享的错误变量时加锁 mu.Lock() err = e mu.Unlock() } else { // 不需要加锁,因为udpLimitRuleId只被这一个goroutine修改 udpLimitRuleId = result } }) // 启动webLimit调用 - 使用独立的请求参数副本 wg.Go(func() { // 为该goroutine创建独立的请求参数副本 webLimitReq := &v1.GeneralLimitRequireRequest{ Tag: require.GlobalLimitName, HostId: req.HostId, RuleId: ruleId, Uid: req.Uid, } result, e := s.webLimit.AddWebLimit(ctx, webLimitReq) if e != nil { // 只在修改共享的错误变量时加锁 mu.Lock() err = e mu.Unlock() } else { // 不需要加锁,因为webLimitRuleId只被这一个goroutine修改 webLimitRuleId = result } }) // 等待所有调用完成 wg.Wait() // 检查是否有错误发生 if err != nil { return err } err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{ HostId: req.HostId, RuleId: cast.ToInt(ruleId), GlobalLimitName: require.GlobalLimitName, Comment: req.Comment, TcpLimitRuleId: tcpLimitRuleId, UdpLimitRuleId: udpLimitRuleId, WebLimitRuleId: webLimitRuleId, GatewayGroupId: 5, }) if err != nil { return err } return nil } func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error { if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{ HostId: req.HostId, Comment: req.Comment, }); err != nil { return err } return nil } func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error { if err := s.globalLimitRepository.DeleteGlobalLimitByHostId(ctx, int64(req.HostId)); err != nil { return err } return nil }