package waf import ( "context" "fmt" v1 "github.com/go-nunu/nunu-layout-advanced/api/v1" wafRep "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf" "github.com/go-nunu/nunu-layout-advanced/internal/service" "strconv" "strings" "sync" ) type ZzybgpService interface { SetDefense(ctx context.Context, hostId int64,defense int) error } func NewZzybgpService( service *service.Service, gatewayIpRep wafRep.GatewayipRepository, host service.HostService, aoDun service.AoDunService, ) ZzybgpService { return &zzybgpService{ Service: service, gatewayIpRep: gatewayIpRep, host: host, aoDun: aoDun, } } type zzybgpService struct { *service.Service gatewayIpRep wafRep.GatewayipRepository host service.HostService aoDun service.AoDunService } func (s *zzybgpService) SetDefense(ctx context.Context, hostId int64, defense int) error { ips, err := s.gatewayIpRep.GetGatewayipOnlyIpByHostIdAll(ctx, hostId) if err != nil { return fmt.Errorf("通过hostId获取IP列表失败: %w", err) } if len(ips) == 0 { return nil } // 2.【修复BUG】使用一个明确的变量来存储最终生效的防御值 effectiveDefense := defense // 如果传入的防御值为0,则从全局配置中获取 if effectiveDefense == 0 { config, err := s.host.GetGlobalLimitConfig(ctx, int(hostId)) if err != nil { return fmt.Errorf("获取全局限制配置失败: %w", err) } // 从配置中解析防御值 defenseStr := strings.TrimSuffix(config.ConfigMaxProtection, "G") defenseInt, err := strconv.Atoi(defenseStr) if err != nil { // 如果解析失败,返回包含原始值的错误信息,便于排查 return fmt.Errorf("解析防御配置 '%s' 失败: %w", config.ConfigMaxProtection, err) } // 如果配置的防御值也为0,则无需操作 if defenseInt == 0 { return nil } // 将从配置中获取的值赋给 effectiveDefense effectiveDefense = defenseInt } var wg sync.WaitGroup // 创建一个足够大的 channel 来收集所有可能发生的错误 errChan := make(chan error, len(ips)) wg.Add(len(ips)) for _, ip := range ips { go func(ipAddr string) { defer wg.Done() e := s.aoDun.SetDefense(ctx, v1.SetDefense{ IpAddr: ipAddr, Defense: effectiveDefense, }) if e != nil { // 3.【优化】为错误添加IP地址信息,便于定位问题 errChan <- fmt.Errorf("IP [%s] 设置防御带宽失败: %w", ipAddr, e) } }(ip) } wg.Wait() close(errChan) var allErrors []error for e := range errChan { allErrors = append(allErrors, e) } if len(allErrors) > 0 { // 将多个错误信息拼接成一个字符串 var errStrings []string for _, singleErr := range allErrors { errStrings = append(errStrings, singleErr.Error()) } return fmt.Errorf("设置防御时发生多个错误: %s", strings.Join(errStrings, "; ")) } return nil }