package job import ( "context" "encoding/json" "fmt" "github.com/go-nunu/nunu-layout-advanced/internal/service" "github.com/google/uuid" "github.com/rabbitmq/amqp091-go" "go.uber.org/zap" "strconv" "strings" "sync" ) // taskHandler 定义了处理单个消息的函数签名 // 它负责业务逻辑的执行,并返回一个 error 来告知调用者处理是否成功。 type taskHandler func(ctx context.Context, logger *zap.Logger, delivery amqp091.Delivery) error // WhitelistJob 定义了处理白名单相关任务的接口 type WhitelistJob interface { // DomainConsumer 启动消费者,处理域名白名单任务 DomainConsumer(ctx context.Context) // IpConsumer 启动消费者,处理 IP 白名单任务 IpConsumer(ctx context.Context) } // NewWhitelistJob 创建一个新的 WhitelistJob func NewWhitelistJob(job *Job, aoDunService service.AoDunService, wafForMatter service.WafFormatterService, ) WhitelistJob { return &whitelistJob{ Job: job, aoDunService: aoDunService, wafForMatter: wafForMatter, } } type whitelistJob struct { *Job aoDunService service.AoDunService wafForMatter service.WafFormatterService } // DomainConsumer 启动域名白名单消费者 func (j *whitelistJob) DomainConsumer(ctx context.Context) { j.consume(ctx, "domain_whitelist", "domain_whitelist_consumer", j.handleDomainMessage) } // IpConsumer 启动IP白名单消费者 func (j *whitelistJob) IpConsumer(ctx context.Context) { j.consume(ctx, "ip_white", "ip_white_consumer", j.handleIpMessage) } // consume 是一个通用的 RabbitMQ 消费者方法,封装了重复的逻辑 func (j *whitelistJob) consume(ctx context.Context, taskName, consumerName string, handler taskHandler) { taskCfg, ok := j.Rabbitmq.GetTaskConfig(taskName) if !ok { j.logger.Error(fmt.Sprintf("未找到任务 '%s' 的配置", taskName)) return } j.logger.Info("正在启动消费者...", zap.String("task", taskName), zap.String("queue", taskCfg.Queue), zap.String("consumer", consumerName), ) msgs, err := j.Rabbitmq.Consume(taskCfg.Queue, consumerName, taskCfg.PrefetchCount) if err != nil { j.logger.Error("启动消费者失败", zap.String("task", taskName), zap.Error(err)) return } for { select { case <-ctx.Done(): j.logger.Info("消费者正在关闭...", zap.String("task", taskName)) return case d, ok := <-msgs: if !ok { j.logger.Warn("消息通道已关闭,消费者退出。", zap.String("task", taskName)) return } // 尝试从消息头获取 trace_id,如果不存在则生成一个新的 traceID, ok := d.Headers["trace_id"].(string) if !ok || traceID == "" { traceID = uuid.New().String() } // 创建一个带有 trace_id 的 logger,用于本次任务的所有日志记录 scopedLogger := j.logger.With(zap.String("trace_id", traceID)) // 创建一个带有 trace_id 的 context,用于传递给下游服务 ctxWithTrace := context.WithValue(ctx, "trace_id", traceID) // 调用具体的业务处理器 processingErr := handler(ctxWithTrace, scopedLogger, d) // 根据处理结果统一进行 Ack/Nack if processingErr != nil { // 业务失败,拒绝消息并不重新入队 if err := d.Nack(false, false); err != nil { scopedLogger.Error("消息 Nack 失败", zap.Error(err), zap.String("task", taskName)) } } else { // 业务处理成功,手动发送确认 if err := d.Ack(false); err != nil { scopedLogger.Error("消息 Ack 失败", zap.Error(err), zap.String("task", taskName)) } } } } } // handleDomainMessage 是域名白名单任务的具体处理器 func (j *whitelistJob) handleDomainMessage(ctx context.Context, logger *zap.Logger, d amqp091.Delivery) error { type domainTaskPayload struct { Domain string `json:"domain"` Ip string `json:"ip"` Action string `json:"action"` // "add" or "del" } var payload domainTaskPayload if err := json.Unmarshal(d.Body, &payload); err != nil { logger.Error("解析域名白名单消息失败", zap.Error(err), zap.ByteString("body", d.Body)) return nil // 返回 nil 以避免消息重入队列,因为这是一个格式错误 } logger.Info("收到域名白名单任务", zap.String("action", payload.Action), zap.String("domain", payload.Domain), zap.String("ip", payload.Ip), zap.String("routing_key", d.RoutingKey), ) var processingErr error switch payload.Action { case "add", "del": processingErr = j.aoDunService.DomainWhiteList(ctx, payload.Domain, payload.Ip, payload.Action) default: processingErr = fmt.Errorf("unknown action: %s", payload.Action) logger.Warn("在域名白名单任务中收到未知操作", zap.String("action", payload.Action), zap.String("domain", payload.Domain)) } if processingErr != nil { logger.Error("处理域名白名单任务失败", zap.Error(processingErr), zap.String("domain", payload.Domain)) } else { logger.Info("已成功处理域名白名单任务", zap.String("action", payload.Action), zap.String("domain", payload.Domain)) } return processingErr } // handleIpMessage 是 IP 白名单任务的具体处理器 func (j *whitelistJob) handleIpMessage(ctx context.Context, logger *zap.Logger, d amqp091.Delivery) error { type ipTaskPayload struct { Ips []string `json:"ips"` Action string `json:"action"` ReturnSourceIp string `json:"return_source_ip"` } var payload ipTaskPayload if err := json.Unmarshal(d.Body, &payload); err != nil { logger.Error("解析IP白名单消息失败", zap.Error(err), zap.ByteString("body", d.Body)) return nil // 消息格式错误,不应重试 } logger.Info("收到IP白名单任务", zap.String("action", payload.Action), zap.Any("ips", payload.Ips), zap.String("routing_key", d.RoutingKey), ) var processingErr error switch payload.Action { case "add": ips, err := j.wafForMatter.AppendWafIp(ctx, payload.Ips, payload.ReturnSourceIp) if err != nil { // 如果附加IP失败,记录错误并终止 processingErr = fmt.Errorf("为WAF准备IP列表失败: %w", err) } else { var wg sync.WaitGroup errChan := make(chan error, 2) wg.Add(2) go func() { defer wg.Done() if err := j.aoDunService.AddWhiteStaticList(ctx, false, ips); err != nil { errChan <- err } }() go func() { defer wg.Done() if err := j.aoDunService.AddWhiteStaticList(ctx, true, ips); err != nil { errChan <- err } }() wg.Wait() close(errChan) var errs []string for err := range errChan { errs = append(errs, err.Error()) } if len(errs) > 0 { processingErr = fmt.Errorf("添加IP到白名单时发生错误: %s", strings.Join(errs, "; ")) } } case "del": var wg sync.WaitGroup errChan := make(chan error, len(payload.Ips)*2) deleteFromWall := func(isSmall bool, ip string) { defer wg.Done() id, err := j.aoDunService.GetWhiteStaticList(ctx, isSmall, ip) if err != nil { errChan <- fmt.Errorf("获取IP '%s' (isSmall: %t) ID失败: %w", ip, isSmall, err) return } if err := j.aoDunService.DelWhiteStaticList(ctx, isSmall, strconv.Itoa(id)); err != nil { errChan <- fmt.Errorf("删除IP '%s' (isSmall: %t, id: %d) 失败: %w", ip, isSmall, id, err) } } for _, ip := range payload.Ips { wg.Add(2) go deleteFromWall(false, ip) go deleteFromWall(true, ip) } wg.Wait() close(errChan) var errs []string for err := range errChan { logger.Error("删除IP白名单过程中发生错误", zap.Error(err)) errs = append(errs, err.Error()) } if len(errs) > 0 { processingErr = fmt.Errorf("删除IP任务中发生错误: %s", strings.Join(errs, "; ")) } default: processingErr = fmt.Errorf("unknown action: %s", payload.Action) logger.Warn("在IP白名单任务中收到未知操作", zap.String("action", payload.Action), zap.Any("ips", payload.Ips)) } if processingErr != nil { logger.Error("处理IP白名单任务失败", zap.Error(processingErr), zap.Any("ips", payload.Ips)) } else { logger.Info("已成功处理IP白名单任务", zap.String("action", payload.Action), zap.Any("ips", payload.Ips)) } return processingErr }