// Package rabbitmq wraps an amqp091-go connection with automatic
// reconnection and declarative, per-task queue topology setup.
package rabbitmq

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"

	"github.com/go-nunu/nunu-layout-advanced/pkg/log"
	amqp "github.com/rabbitmq/amqp091-go"
	"go.uber.org/zap"
)

var (
	// ErrClosed is returned when an operation is attempted on a client that
	// has been closed with Close.
	ErrClosed = errors.New("rabbitmq: client is closed")
)

// Config holds the broker connection settings and the per-task topology.
type Config struct {
	Host              string        `yaml:"host"`
	Port              int           `yaml:"port"`
	Username          string        `yaml:"username"`
	Password          string        `yaml:"password"`
	VHost             string        `yaml:"vhost"`
	ConnectionTimeout time.Duration `yaml:"connection_timeout"`
	// Tasks maps a task name to its queue configuration, so several tasks
	// can be configured side by side.
	Tasks map[string]TaskConfig `yaml:"tasks"`
}

// TaskConfig describes the exchange, queue and binding used by one task.
type TaskConfig struct {
	Exchange      string `mapstructure:"exchange"`
	ExchangeType  string `mapstructure:"exchange_type"`
	Queue         string `mapstructure:"queue"`
	RoutingKey    string `mapstructure:"routing_key"`
	ConsumerCount int    `mapstructure:"consumer_count"`
	PrefetchCount int    `mapstructure:"prefetch_count"`
}

// RabbitMQ is a client holding one connection and one shared channel.
// conn, ch and closed are guarded by mu.
type RabbitMQ struct {
	config Config
	conn   *amqp.Connection
	ch     *amqp.Channel
	logger *log.Logger
	mu     sync.RWMutex
	closed bool
}

// New creates a RabbitMQ client, connects to the broker, declares the
// topology for every configured task and starts the background reconnect
// loop. If topology setup fails the connection is closed and an error is
// returned.
func New(config Config, logger *log.Logger) (*RabbitMQ, error) {
	r := &RabbitMQ{
		config: config,
		logger: logger,
	}

	if err := r.Connect(); err != nil {
		return nil, err
	}

	if err := r.SetupAllTaskQueues(); err != nil {
		// Best-effort cleanup; the setup error is the one worth reporting.
		_ = r.Close()
		return nil, fmt.Errorf("failed to setup task queues: %w", err)
	}

	go r.reconnectLoop()

	return r, nil
}

// Connect dials the RabbitMQ server and opens the shared channel, replacing
// any connection that is currently open. It also clears the closed flag, so
// it can be used to reopen a client after Close.
func (r *RabbitMQ) Connect() error {
	return r.connect(true)
}

// connect is the implementation behind Connect. When allowReopen is false it
// refuses to connect a client that has been closed and returns ErrClosed;
// the reconnect loop uses this so that a concurrent Close cannot be undone
// by an in-flight reconnect attempt.
func (r *RabbitMQ) connect(allowReopen bool) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.closed && !allowReopen {
		return ErrClosed
	}

	// Tear down whatever is still open from a previous connection. Errors
	// are ignored on purpose: the handles are being discarded anyway.
	if r.conn != nil && !r.conn.IsClosed() {
		if r.ch != nil {
			_ = r.ch.Close()
		}
		_ = r.conn.Close()
	}
	// Drop the stale handles so that a failed attempt below never leaves a
	// dead channel paired with a nil connection.
	r.conn, r.ch = nil, nil

	vhost := r.config.VHost
	if vhost == "" {
		vhost = "/"
	} else if vhost[0] != '/' {
		vhost = "/" + vhost
	}

	// Build the full connection URL.
	fullURL := fmt.Sprintf("amqp://%s:%s@%s:%d%s",
		r.config.Username,
		r.config.Password,
		r.config.Host,
		r.config.Port,
		vhost,
	)
	// Same URL with the password masked: credentials must not reach the logs.
	redactedURL := fmt.Sprintf("amqp://%s:***@%s:%d%s",
		r.config.Username,
		r.config.Host,
		r.config.Port,
		vhost,
	)

	r.logger.Info("正在尝试连接到 RabbitMQ...", zap.String("url", redactedURL))

	conn, err := amqp.Dial(fullURL)
	if err != nil {
		// Record the detailed underlying error.
		r.logger.Error("连接RabbitMQ失败", zap.Error(err))
		return fmt.Errorf("连接RabbitMQ失败: %w", err)
	}

	ch, err := conn.Channel()
	if err != nil {
		_ = conn.Close()
		return fmt.Errorf("创建通道失败: %w", err)
	}

	r.conn, r.ch = conn, ch
	r.closed = false
	r.logger.Info("RabbitMQ连接成功")
	return nil
}

// reconnectLoop watches the current connection and, when it drops,
// reconnects with exponential backoff (1s doubling up to 30s) and re-declares
// the task queues. It returns once the client has been closed.
func (r *RabbitMQ) reconnectLoop() {
	const (
		idlePoll   = 5 * time.Second
		minBackoff = 1 * time.Second
		maxBackoff = 30 * time.Second
	)

	for {
		// Buffered so the library's synchronous close notification can never
		// block on this goroutine.
		closeChan := make(chan *amqp.Error, 1)

		r.mu.RLock()
		isClosed := r.closed
		conn := r.conn
		if !isClosed && conn != nil {
			conn.NotifyClose(closeChan)
		}
		r.mu.RUnlock()

		if isClosed {
			r.logger.Info("RabbitMQ客户端已关闭,停止重连循环。")
			return
		}
		if conn == nil {
			// No connection to watch yet; poll until one appears.
			time.Sleep(idlePoll)
			continue
		}

		// Blocks until the connection closes. A nil value means the channel
		// was closed without an error being sent.
		if closeErr := <-closeChan; closeErr != nil {
			r.logger.Error("RabbitMQ连接断开,将尝试重新连接", zap.Error(closeErr))
		} else {
			r.logger.Info("RabbitMQ连接正常关闭。")
		}

		if r.isClosed() {
			r.logger.Info("RabbitMQ客户端已关闭,停止重连。")
			return
		}

		backoff := minBackoff
		for {
			err := r.connect(false)
			if errors.Is(err, ErrClosed) {
				// Close was called while we were reconnecting.
				return
			}
			if err == nil {
				r.logger.Info("RabbitMQ重新连接成功")
				// Re-declare the task queues on the new connection.
				if err := r.SetupAllTaskQueues(); err != nil {
					r.logger.Error("重新设置所有任务队列失败", zap.Error(err))
				}
				break
			}

			r.logger.Error("RabbitMQ重连失败", zap.Error(err), zap.Duration("backoff", backoff))
			time.Sleep(backoff)
			backoff *= 2
			if backoff > maxBackoff {
				backoff = maxBackoff
			}
		}
	}
}

// Close marks the client as closed and closes the channel and connection.
// Calling Close on an already closed client is a no-op that returns nil.
func (r *RabbitMQ) Close() error {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.closed {
		return nil
	}
	r.closed = true

	var errs []error
	if r.ch != nil {
		if err := r.ch.Close(); err != nil {
			errs = append(errs, fmt.Errorf("关闭channel失败: %w", err))
		}
	}
	if r.conn != nil && !r.conn.IsClosed() {
		if err := r.conn.Close(); err != nil {
			errs = append(errs, fmt.Errorf("关闭connection失败: %w", err))
		}
	}

	if len(errs) > 0 {
		return fmt.Errorf("关闭RabbitMQ时发生错误: %v", errs)
	}
	return nil
}

// isClosed reports whether Close has been called on the client.
func (r *RabbitMQ) isClosed() bool {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return r.closed
}

// GetTaskConfig retrieves a specific task's configuration. The boolean is
// false when no task with that name is configured.
func (r *RabbitMQ) GetTaskConfig(name string) (TaskConfig, bool) {
	taskCfg, ok := r.config.Tasks[name]
	return taskCfg, ok
}

// withChannel runs fn with the shared channel while holding the read lock,
// so the channel cannot be swapped by a reconnect during the call. It returns
// ErrClosed for a closed client and a plain error when no usable
// connection/channel exists.
func (r *RabbitMQ) withChannel(fn func(*amqp.Channel) error) error {
	r.mu.RLock()
	defer r.mu.RUnlock()

	if r.closed {
		return ErrClosed
	}
	// conn must be nil-checked as well: it is nil after a failed dial.
	if r.ch == nil || r.conn == nil || r.conn.IsClosed() {
		return errors.New("rabbitmq: channel or connection is not available")
	}
	return fn(r.ch)
}

// Publish sends a message to the specified exchange with the given routing key.
// This is a convenience wrapper around PublishWithCh that sends body as a
// persistent text/plain message.
func (r *RabbitMQ) Publish(exchange, routingKey string, body []byte) error {
	return r.PublishWithCh(exchange, routingKey, amqp.Publishing{
		ContentType:  "text/plain",
		Body:         body,
		DeliveryMode: amqp.Persistent, // Default to persistent
	})
}

// PublishWithCh sends a message to the specified exchange with the given
// routing key using a custom amqp.Publishing struct. It opens a dedicated
// channel for each publication instead of sharing the client's channel
// between concurrent publishers, and gives the publish a 5 second timeout.
func (r *RabbitMQ) PublishWithCh(exchange, routingKey string, msg amqp.Publishing) error {
	r.mu.RLock()
	// Check that the connection is alive.
	if r.closed || r.conn == nil || r.conn.IsClosed() {
		r.mu.RUnlock()
		return fmt.Errorf("rabbitmq: connection is not available")
	}
	// Capture the connection under the lock; the lock is released before the
	// channel is opened so publishers do not serialize on it.
	conn := r.conn
	r.mu.RUnlock()

	// Create a new channel for this specific publication.
	ch, err := conn.Channel()
	if err != nil {
		return fmt.Errorf("rabbitmq: failed to open a channel: %w", err)
	}
	// The temporary channel is discarded afterwards; a close error carries no
	// useful information for the caller.
	defer func() { _ = ch.Close() }()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Publish the message using the temporary channel.
	return ch.PublishWithContext(ctx,
		exchange,
		routingKey,
		false, // mandatory
		false, // immediate
		msg,
	)
}

// Consume applies prefetchCount as the channel Qos and starts a consumer on
// queue with manual acknowledgement, returning the delivery channel. The
// consumer runs on the client's shared channel.
func (r *RabbitMQ) Consume(queue, consumer string, prefetchCount int) (<-chan amqp.Delivery, error) {
	var deliveries <-chan amqp.Delivery
	err := r.withChannel(func(ch *amqp.Channel) error {
		if err := ch.Qos(prefetchCount, 0, false); err != nil {
			return fmt.Errorf("设置Qos失败: %w", err)
		}
		var err error
		deliveries, err = ch.Consume(
			queue,
			consumer,
			false, // auto-ack: false, messages are acknowledged manually
			false, // exclusive
			false, // no-local
			false, // no-wait
			nil,   // args
		)
		return err
	})
	return deliveries, err
}

// SetupAllTaskQueues walks every task in the configuration and declares its
// queue topology. It stops at, and returns, the first failure.
func (r *RabbitMQ) SetupAllTaskQueues() error {
	if len(r.config.Tasks) == 0 {
		r.logger.Info("在配置中未找到任何任务队列定义。")
		return nil
	}

	for name, taskCfg := range r.config.Tasks {
		if err := r.setupQueue(taskCfg); err != nil {
			return fmt.Errorf("为任务 '%s' 设置队列失败: %w", name, err)
		}
	}
	return nil
}

// setupQueue declares the exchange, queue and binding for a single task.
//
// With an empty exchange name only a durable queue is declared, relying on
// the default exchange. Otherwise it declares the main exchange, the main
// queue with "<exchange>.dlx" as its dead-letter exchange, and a
// "<queue>.dlq" dead-letter queue bound to that exchange with the same
// routing key.
func (r *RabbitMQ) setupQueue(taskCfg TaskConfig) error {
	if taskCfg.Exchange == "" {
		r.logger.Warn("任务队列的交换机名称为空,将使用默认交换机。这在多任务场景下可能导致问题。", zap.String("queue", taskCfg.Queue))
		return r.withChannel(func(ch *amqp.Channel) error {
			_, err := ch.QueueDeclare(taskCfg.Queue, true, false, false, false, nil)
			if err != nil {
				return fmt.Errorf("声明队列失败 (默认交换机): %w", err)
			}
			r.logger.Info("成功声明队列并绑定到默认交换机", zap.String("queue", taskCfg.Queue))
			return nil
		})
	}

	return r.withChannel(func(ch *amqp.Channel) error {
		// Declare the main exchange.
		exchangeType := taskCfg.ExchangeType
		if exchangeType == "" {
			// Default to "direct" to stay compatible with older configs.
			exchangeType = "direct"
		}
		err := ch.ExchangeDeclare(
			taskCfg.Exchange, // name
			exchangeType,     // type
			true,             // durable
			false,            // autoDelete
			false,            // internal
			false,            // noWait
			nil,              // args
		)
		if err != nil {
			return fmt.Errorf("声明主交换机 '%s' 失败: %w", taskCfg.Exchange, err)
		}

		// Point the main queue at its dead-letter exchange.
		dlxExchange := taskCfg.Exchange + ".dlx"
		args := amqp.Table{
			"x-dead-letter-exchange": dlxExchange,
		}

		// Declare the main queue.
		_, err = ch.QueueDeclare(taskCfg.Queue, true, false, false, false, args)
		if err != nil {
			return fmt.Errorf("声明主队列 '%s' 失败: %w", taskCfg.Queue, err)
		}

		// Bind the main queue to the main exchange.
		if err := ch.QueueBind(taskCfg.Queue, taskCfg.RoutingKey, taskCfg.Exchange, false, nil); err != nil {
			return fmt.Errorf("绑定主队列失败: %w", err)
		}

		// --- Dead-letter setup ---
		// Declare the dead-letter exchange (DLX).
		if err := ch.ExchangeDeclare(dlxExchange, "direct", true, false, false, false, nil); err != nil {
			return fmt.Errorf("声明死信交换机 '%s' 失败: %w", dlxExchange, err)
		}

		// Declare the dead-letter queue (DLQ).
		dlq := taskCfg.Queue + ".dlq"
		_, err = ch.QueueDeclare(dlq, true, false, false, false, nil)
		if err != nil {
			return fmt.Errorf("声明死信队列 '%s' 失败: %w", dlq, err)
		}

		// Bind the DLQ to the DLX with the same routing key as the main queue.
		if err := ch.QueueBind(dlq, taskCfg.RoutingKey, dlxExchange, false, nil); err != nil {
			return fmt.Errorf("绑定死信队列失败: %w", err)
		}

		return nil
	})
}