package waf

import (
	"context"
	"fmt"
	"strconv"
	"sync"

	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
	wafRep "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
	"github.com/go-nunu/nunu-layout-advanced/internal/service"
	"github.com/hashicorp/go-multierror"
)

// BuildAudunService creates and removes per-IP outbound bandwidth-limit rules
// through service.AoDunService.
type BuildAudunService interface {
	// AddBandwidth creates an outbound ("out") limit rule for the single IP
	// req.ServerIPStart at rate req.SpeedlimitOut.
	AddBandwidth(ctx context.Context, req v1.Bandwidth) error
	// DelBandwidth deletes the limit rule whose name is derived from
	// req.ServerIPStart and req.SpeedlimitOut.
	DelBandwidth(ctx context.Context, req v1.Bandwidth) error
	// Bandwidth applies action ("add" or "del") to every gateway IP of the
	// given host, using the host's global limit config (Bps) as the rate.
	Bandwidth(ctx context.Context, hostID int64, action string) error
}

// NewBuildAudunService wires a BuildAudunService from its dependencies.
func NewBuildAudunService(
	service *service.Service,
	audun service.AoDunService,
	gatewayIpRep wafRep.GatewayipRepository,
	host service.HostService,
) BuildAudunService {
	return &buildAudunService{
		Service:      service,
		audun:        audun,
		gatewayIpRep: gatewayIpRep,
		host:         host,
	}
}

// buildAudunService is the default BuildAudunService implementation.
type buildAudunService struct {
	*service.Service
	audun        service.AoDunService         // client used to add/delete limit rules
	gatewayIpRep wafRep.GatewayipRepository   // source of a host's gateway IPs
	host         service.HostService          // source of a host's global limit config
}

// BuildName returns the rule name apiName + ip + "限速" + bandwidth + "M"
// ("限速" means "rate limit"). For example, BuildName("1.2.3.4", "100", "")
// yields "1.2.3.4限速100M".
func (s *buildAudunService) BuildName(ip string, bandwidth string, apiName string) string {
	return apiName + ip + "限速" + bandwidth + "M"
}

// AddBandwidth creates a "limit" rule for all client IPs on outbound traffic
// of the single server IP req.ServerIPStart, capped at req.SpeedlimitOut.
// Only ServerIPStart and SpeedlimitOut are read from req; the rule name is
// built by BuildName with an empty prefix.
func (s *buildAudunService) AddBandwidth(ctx context.Context, req v1.Bandwidth) error {
	return s.audun.AddBandwidthLimit(ctx, v1.Bandwidth{
		Action:        "limit",
		ClientIPType:  "all",
		Direction:     "out",
		Name:          s.BuildName(req.ServerIPStart, strconv.FormatInt(req.SpeedlimitOut, 10), ""),
		Protocol:      0,
		ServerIPStart: req.ServerIPStart,
		ServerIPType:  "single",
		SpeedlimitOut: req.SpeedlimitOut,
	})
}

// DelBandwidth deletes the limit rule for req.ServerIPStart/req.SpeedlimitOut
// by name.
//
// NOTE(review): the rule is created with an empty name prefix in AddBandwidth
// but deleted with the "KFW-API-RESTAPI-" prefix; presumably the remote side
// prepends this prefix to rules created through its REST API — confirm.
func (s *buildAudunService) DelBandwidth(ctx context.Context, req v1.Bandwidth) error {
	return s.audun.DelBandwidthLimit(ctx, v1.Bandwidth{
		Name: s.BuildName(req.ServerIPStart, strconv.FormatInt(req.SpeedlimitOut, 10), "KFW-API-RESTAPI-"),
	})
}

// Bandwidth applies action to every gateway IP of host hostID.
//
// Supported actions are "add" (AddBandwidth) and "del" (DelBandwidth); any
// other value is rejected before the repository is queried. The rate comes
// from the host's global limit config (its Bps field, parsed as a base-10
// integer). A host with no gateway IPs is a no-op. IPs are processed
// concurrently, one goroutine per IP, and every per-IP failure is collected
// into a single multierror; nil is returned only if all IPs succeeded.
func (s *buildAudunService) Bandwidth(ctx context.Context, hostID int64, action string) error {
	// Fail fast on an unsupported action so it is reported exactly once,
	// independent of how many IPs the host has (or whether it has any).
	if action != "add" && action != "del" {
		return fmt.Errorf("未知操作: %s", action)
	}

	ips, err := s.gatewayIpRep.GetGatewayipOnlyIpByHostIdAll(ctx, hostID)
	if err != nil {
		return err
	}
	if len(ips) == 0 {
		return nil
	}

	config, err := s.host.GetGlobalLimitConfig(ctx, int(hostID))
	if err != nil {
		return err
	}
	bps, err := strconv.ParseInt(config.Bps, 10, 64)
	if err != nil {
		return fmt.Errorf("解析全局限速配置 Bps %q 失败: %w", config.Bps, err)
	}

	// Buffered to len(ips) so no worker ever blocks on send; the channel is
	// closed only after every worker has called Done.
	errChan := make(chan error, len(ips))
	var wg sync.WaitGroup
	wg.Add(len(ips))
	for _, ip := range ips {
		go func(ip string) {
			defer wg.Done()
			if e := s.applyBandwidth(ctx, action, ip, bps); e != nil {
				errChan <- e
			}
		}(ip)
	}
	wg.Wait()
	close(errChan)

	var result *multierror.Error
	for e := range errChan {
		result = multierror.Append(result, e)
	}
	// ErrorOrNil returns a true nil interface when nothing was appended.
	return result.ErrorOrNil()
}

// applyBandwidth performs one add or del for a single IP at rate bps and
// wraps any failure with the IP and the operation that failed.
func (s *buildAudunService) applyBandwidth(ctx context.Context, action, ip string, bps int64) error {
	req := v1.Bandwidth{ServerIPStart: ip, SpeedlimitOut: bps}
	switch action {
	case "add":
		if err := s.AddBandwidth(ctx, req); err != nil {
			return fmt.Errorf("添加ip %s限速失败: %w", ip, err)
		}
	case "del":
		if err := s.DelBandwidth(ctx, req); err != nil {
			return fmt.Errorf("清除ip %s失败: %w", ip, err)
		}
	default:
		// Defensive: Bandwidth validates action before calling this helper.
		return fmt.Errorf("未知操作: %s", action)
	}
	return nil
}