← 返回首页
队列

消息队列

📂 python ⏱ 7 min 1337 words

消息队列

消息队列是构建分布式系统的关键组件。本文将介绍RabbitMQ、Kafka和Celery在Python中的应用，以实现异步任务处理和事件驱动架构。

RabbitMQ基础

import pika
import json
from typing import Callable, Any

class RabbitMQClient:
    """Blocking RabbitMQ client that publishes and consumes JSON messages.

    A connection and channel are opened in ``__init__``. Release them with
    :meth:`close`, or use the instance as a context manager
    (``with RabbitMQClient() as client: ...``).
    """

    def __init__(self, host: str = "localhost"):
        self.connection = pika.BlockingConnection(
            pika.ConnectionParameters(host)
        )
        self.channel = self.connection.channel()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions raised in the with-body

    def declare_queue(self, queue_name: str, durable: bool = True):
        """Declare *queue_name* on the broker (durable by default)."""
        self.channel.queue_declare(queue=queue_name, durable=durable)

    def publish(self, queue_name: str, message: dict, persistent: bool = True):
        """Publish *message* as JSON to *queue_name* through the default exchange.

        ``persistent=True`` sets ``delivery_mode=2`` (persistent); otherwise
        ``delivery_mode=1`` (transient) is used.
        """
        properties = pika.BasicProperties(
            delivery_mode=2 if persistent else 1,
            content_type="application/json"
        )

        # The default exchange ("") routes by queue name.
        self.channel.basic_publish(
            exchange="",
            routing_key=queue_name,
            body=json.dumps(message),
            properties=properties
        )

    def consume(self, queue_name: str, callback: Callable):
        """Block and pass each message on *queue_name* to ``callback(message)``.

        The body is decoded from JSON before *callback* sees it. A message is
        acked when *callback* returns. It is nacked with requeue when *callback*
        raises. It is nacked **without** requeue when the body is not valid
        JSON, because such a body can never succeed. ``prefetch_count=1``
        limits this consumer to one unacknowledged message at a time.
        """
        def wrapper(ch, method, properties, body):
            try:
                message = json.loads(body)
            except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
                # Requeueing an unparseable body would redeliver it forever.
                # Reject it for good. If the queue has a dead-letter exchange
                # configured, the broker routes it there.
                print(f"消息格式无效，已拒绝（不重新入队）: {exc}")
                ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)
                return

            try:
                callback(message)
            except Exception as exc:
                # Processing failed: reject and requeue for another attempt.
                print(f"消息处理失败，重新入队: {exc}")
                ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
            else:
                ch.basic_ack(delivery_tag=method.delivery_tag)

        self.channel.basic_qos(prefetch_count=1)
        self.channel.basic_consume(
            queue=queue_name,
            on_message_callback=wrapper
        )

        print(f"开始消费队列: {queue_name}")
        self.channel.start_consuming()

    def close(self):
        """Close the connection (and its channel); safe to call more than once."""
        if self.connection.is_open:
            self.connection.close()

# 使用示例
def producer():
    """Publish ten sample task messages to the durable ``tasks`` queue.

    The broker connection is closed even if declaring or publishing fails.
    """
    client = RabbitMQClient()
    try:
        client.declare_queue("tasks")

        for i in range(10):
            # "priority" is only a payload field here; publish() does not set
            # an AMQP message priority.
            message = {
                "task_id": i,
                "data": f"任务数据 {i}",
                "priority": i % 3
            }
            client.publish("tasks", message)
            print(f"发布任务: {i}")
    finally:
        client.close()

def consumer():
    """Consume the ``tasks`` queue until interrupted, simulating 1 s of work per task.

    The broker connection is closed when consuming stops for any reason,
    including Ctrl+C.
    """
    import time

    def process_task(message):
        print(f"处理任务: {message}")
        time.sleep(1)  # simulate processing time
        print(f"任务完成: {message['task_id']}")

    client = RabbitMQClient()
    try:
        client.declare_queue("tasks")
        client.consume("tasks", process_task)  # blocks until interrupted
    finally:
        client.close()

高级RabbitMQ模式

import pika
import json
from functools import wraps
import time

class AdvancedRabbitMQ:
    """RabbitMQ helper for exchanges, bindings, dead-lettering and delayed retry.

    Call :meth:`close` (or use ``with AdvancedRabbitMQ() as mq:``) to release
    the connection opened in ``__init__``.
    """

    def __init__(self, host: str = "localhost"):
        self.connection = pika.BlockingConnection(
            pika.ConnectionParameters(host)
        )
        self.channel = self.connection.channel()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions raised in the with-body

    def close(self):
        """Close the connection (and its channel); safe to call more than once."""
        if self.connection.is_open:
            self.connection.close()

    def setup_exchange(self, exchange_name: str, exchange_type: str = "direct"):
        """Declare a durable exchange of *exchange_type* (e.g. direct/topic/fanout)."""
        self.channel.exchange_declare(
            exchange=exchange_name,
            exchange_type=exchange_type,
            durable=True
        )

    def setup_binding(self, exchange_name: str, queue_name: str, routing_key: str):
        """Declare durable *queue_name* and bind it to *exchange_name* on *routing_key*."""
        self.channel.queue_declare(queue=queue_name, durable=True)
        self.channel.queue_bind(
            exchange=exchange_name,
            queue=queue_name,
            routing_key=routing_key
        )

    def publish_to_exchange(self, exchange_name: str, routing_key: str, message: dict):
        """Publish *message* as persistent JSON to *exchange_name* with *routing_key*."""
        self.channel.basic_publish(
            exchange=exchange_name,
            routing_key=routing_key,
            body=json.dumps(message),
            properties=pika.BasicProperties(
                delivery_mode=2,
                content_type="application/json"
            )
        )

    def dead_letter_queue(self, queue_name: str, dlx_exchange: str, dlx_routing_key: str):
        """Declare durable *queue_name* with a dead-letter exchange.

        Messages that the broker dead-letters from this queue (e.g. rejected
        without requeue, or expired) are republished to *dlx_exchange* with
        *dlx_routing_key*.
        """
        args = {
            "x-dead-letter-exchange": dlx_exchange,
            "x-dead-letter-routing-key": dlx_routing_key
        }

        self.channel.queue_declare(
            queue=queue_name,
            durable=True,
            arguments=args
        )

    def retry_with_delay(self, queue_name: str, message: dict, delay_ms: int):
        """Redeliver *message* to *queue_name* after roughly *delay_ms* milliseconds.

        The message is parked in ``<queue_name>_retry``, a queue that should
        have no consumers. Its ``x-message-ttl`` is *delay_ms*. When the
        message expires, it is dead-lettered through the default exchange
        (``""``) back to *queue_name*.

        NOTE: the TTL is an argument of the retry queue itself. RabbitMQ
        rejects re-declaring an existing queue with different arguments
        (PRECONDITION_FAILED), so use a single *delay_ms* per *queue_name*.
        """
        args = {
            "x-dead-letter-exchange": "",
            "x-dead-letter-routing-key": queue_name,
            "x-message-ttl": delay_ms
        }

        retry_queue = f"{queue_name}_retry"
        self.channel.queue_declare(queue=retry_queue, durable=True, arguments=args)

        # Same properties as publish_to_exchange, so consumers see a JSON
        # content type on retried messages too.
        self.channel.basic_publish(
            exchange="",
            routing_key=retry_queue,
            body=json.dumps(message),
            properties=pika.BasicProperties(
                delivery_mode=2,
                content_type="application/json"
            )
        )

# 工作队列模式
def worker_queue_example():
    """Run a competing-consumer worker on the durable ``work_queue``.

    ``prefetch_count=1`` means the broker sends this worker a new message
    only after it has acked the previous one. Blocks until interrupted. The
    connection is closed when consuming stops.
    """
    client = AdvancedRabbitMQ()
    try:
        # Declare the work queue.
        queue_name = "work_queue"
        client.channel.queue_declare(queue=queue_name, durable=True)

        # Fair dispatch: at most one unacked message per worker.
        client.channel.basic_qos(prefetch_count=1)

        def callback(ch, method, properties, body):
            message = json.loads(body)
            print(f"处理任务: {message['task_id']}")

            # Simulated work.
            time.sleep(message.get('duration', 1))

            # Ack only after the work is done.
            ch.basic_ack(delivery_tag=method.delivery_tag)
            print(f"任务完成: {message['task_id']}")

        client.channel.basic_consume(queue=queue_name, on_message_callback=callback)
        client.channel.start_consuming()
    finally:
        # Guard so a dropped connection does not mask the original error.
        if client.connection.is_open:
            client.connection.close()

Apache Kafka集成

from kafka import KafkaProducer, KafkaConsumer
import json
from typing import List, Dict
import threading

class KafkaProducerClient:
    """Synchronous JSON producer: each ``produce`` call waits for the broker ack.

    ``acks='all'`` asks the broker to acknowledge a record only once all
    in-sync replicas have it. Transient send failures are retried up to 3
    times by the underlying client.
    """

    def __init__(self, bootstrap_servers: List[str] = None):
        # None instead of a shared mutable default list; resolved per instance.
        if bootstrap_servers is None:
            bootstrap_servers = ["localhost:9092"]
        self.producer = KafkaProducer(
            bootstrap_servers=bootstrap_servers,
            value_serializer=self._serialize_value,
            key_serializer=self._serialize_key,
            acks='all',
            retries=3
        )

    @staticmethod
    def _serialize_value(value) -> bytes:
        """Encode a record value as UTF-8 JSON."""
        return json.dumps(value).encode('utf-8')

    @staticmethod
    def _serialize_key(key):
        """Encode a record key as UTF-8; only ``None`` means "no key".

        An empty string is a real key and is encoded to ``b""``, so it is
        partitioned consistently like any other key instead of being treated
        as key-less.
        """
        return None if key is None else key.encode('utf-8')

    def produce(self, topic: str, message: dict, key: str = None):
        """Send *message* to *topic* and block up to 10 s for the write metadata.

        Returns:
            ``{"topic", "partition", "offset"}`` of the written record.

        Raises:
            Whatever ``future.get`` raises if the send fails or times out.
        """
        future = self.producer.send(topic, value=message, key=key)
        record_metadata = future.get(timeout=10)

        return {
            "topic": record_metadata.topic,
            "partition": record_metadata.partition,
            "offset": record_metadata.offset
        }

    def close(self):
        """Flush buffered records, then close the producer."""
        self.producer.flush()
        self.producer.close()

class KafkaConsumerClient:
    """Consumer-group wrapper that decodes JSON values and commits manually.

    Auto-commit is disabled. The offset is committed only after *callback*
    returns successfully, which gives at-least-once processing.
    """

    def __init__(self, topics: List[str], group_id: str,
                 bootstrap_servers: List[str] = None):
        # None instead of a shared mutable default list; resolved per instance.
        if bootstrap_servers is None:
            bootstrap_servers = ["localhost:9092"]
        self.consumer = KafkaConsumer(
            *topics,
            bootstrap_servers=bootstrap_servers,
            group_id=group_id,
            value_deserializer=self._deserialize_value,
            auto_offset_reset='earliest',
            enable_auto_commit=False
        )

    @staticmethod
    def _deserialize_value(raw):
        """Decode a UTF-8 JSON record value; ``None`` (e.g. a tombstone) passes through."""
        if raw is None:
            return None
        return json.loads(raw.decode('utf-8'))

    def _rewind(self, message):
        """Seek back to *message*'s offset so the next iteration fetches it again."""
        from kafka import TopicPartition

        partition = TopicPartition(message.topic, message.partition)
        # After a rebalance the partition may now belong to another consumer.
        # That consumer resumes from the last committed offset, so the
        # message is not lost and there is nothing to seek here.
        if partition in self.consumer.assignment():
            self.consumer.seek(partition, message.offset)

    def consume(self, callback):
        """Iterate messages forever, calling ``callback(message)`` for each.

        On success the offset is committed. On failure the consumer seeks
        back to the failed message so it is consumed again. Merely skipping
        ``commit()`` is not enough: the in-memory position has already
        advanced, and the next successful commit would move past the failed
        message. A message that always fails is therefore retried
        indefinitely. Stops on Ctrl+C and always closes the consumer.
        """
        try:
            for message in self.consumer:
                try:
                    callback(message)
                    self.consumer.commit()
                except Exception as e:
                    print(f"处理消息失败: {e}")
                    self._rewind(message)
        except KeyboardInterrupt:
            pass
        finally:
            self.consumer.close()

# 使用示例
def kafka_example():
    """Produce ten ``user_created`` events, then consume ``user_events`` until Ctrl+C."""
    # Producer
    producer = KafkaProducerClient()
    try:
        for i in range(10):
            message = {
                "event_type": "user_created",
                "user_id": i,
                "data": {"name": f"User {i}"}
            }

            # Keying by user keeps one user's events on one partition.
            result = producer.produce("user_events", message, key=f"user_{i}")
            print(f"消息已发送: {result}")
    finally:
        # Flush and close even if a send fails or times out.
        producer.close()

    # Consumer
    def handle_event(message):
        print(f"收到事件: {message.topic} - {message.value}")

    consumer = KafkaConsumerClient(
        topics=["user_events"],
        group_id="user_service_group"
    )

    # consume() blocks and closes the consumer itself when it returns.
    consumer.consume(handle_event)

事件驱动架构

from typing import Dict, List, Callable, Any
from dataclasses import dataclass
from datetime import datetime
import uuid
import json

@dataclass
class Event:
    """An event dispatched through an in-process event bus.

    ``event_id`` and ``timestamp`` are filled in automatically when left
    empty: a random UUID4 string, and the current UTC time as a naive
    ISO-8601 string (e.g. ``"2024-01-01T12:00:00.123456"``).
    """
    event_type: str
    data: Dict[str, Any]
    event_id: str = None   # auto-generated when falsy
    timestamp: str = None  # auto-generated when falsy
    source: str = None     # optional name of the emitting component

    def __post_init__(self):
        if not self.event_id:
            self.event_id = str(uuid.uuid4())
        if not self.timestamp:
            # datetime.utcnow() is deprecated since Python 3.12. Take an aware
            # UTC "now" and drop tzinfo so the string keeps the same naive
            # format that existing consumers already parse.
            from datetime import timezone
            self.timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

class EventBus:
    """Synchronous in-process publish/subscribe dispatcher keyed by event type.

    Handlers run in subscription order, inside the ``publish`` call. A
    handler that raises is reported and skipped so the remaining handlers
    still run.
    """

    def __init__(self):
        # event_type -> handlers, in subscription order.
        self.handlers: Dict[str, List[Callable]] = {}

    def subscribe(self, event_type: str, handler: Callable):
        """Register *handler* to be called with every *event_type* event."""
        self.handlers.setdefault(event_type, []).append(handler)

    def publish(self, event: "Event"):
        """Call every handler subscribed to ``event.event_type`` with *event*.

        Dispatch iterates a snapshot of the handler list. A handler that
        subscribes (or re-subscribes itself) while running takes effect from
        the next publish rather than being invoked mid-dispatch, which could
        otherwise loop forever.
        """
        for handler in list(self.handlers.get(event.event_type, ())):
            try:
                handler(event)
            except Exception as e:
                print(f"事件处理失败: {e}")

    def emit(self, event_type: str, data: Dict[str, Any], source: str = None):
        """Build an ``Event`` from the arguments and publish it."""
        event = Event(
            event_type=event_type,
            data=data,
            source=source
        )
        self.publish(event)

# 使用事件总线
# Shared, module-level bus used by the handlers below.
event_bus = EventBus()


def handle_user_created(event: Event):
    """Announce a new user, then emit an ``email_welcome`` event for them."""
    user = event.data
    print(f"新用户创建: {user['name']}")
    # Chain a follow-up event; its own subscriber sends the welcome mail.
    welcome_payload = {"user_id": user["user_id"], "email": user["email"]}
    event_bus.emit("email_welcome", welcome_payload)


def handle_email_welcome(event: Event):
    """Stand-in for sending the welcome email to ``event.data["user_id"]``."""
    print(f"发送欢迎邮件给用户: {event.data['user_id']}")


# Wire handlers to event types.
event_bus.subscribe(event_type="user_created", handler=handle_user_created)
event_bus.subscribe(event_type="email_welcome", handler=handle_email_welcome)

# Publish a sample event. handle_user_created runs synchronously and in turn
# triggers handle_email_welcome.
event_bus.emit(
    event_type="user_created",
    data={
        "user_id": 1,
        "name": "Alice",
        "email": "alice@example.com",
    },
)

Celery分布式任务队列

from celery import Celery
from celery.decorators import task
from celery.utils.log import get_task_logger
import time
from typing import Dict, Any

# 配置Celery
# Celery application: Redis database 0 is the message broker and Redis
# database 1 is the result backend.
app = Celery(
    'tasks',
    broker='redis://localhost:6379/0',
    backend='redis://localhost:6379/1'
)

app.conf.update(
    task_serializer='json',
    accept_content=['json'],  # only accept JSON-serialized messages
    result_serializer='json',
    timezone='UTC',
    enable_utc=True,
    task_track_started=True,  # report a STARTED state when a worker begins a task
    task_time_limit=30 * 60,  # hard time limit: 30 minutes
    task_soft_time_limit=25 * 60,  # soft time limit: 25 minutes, i.e. before the hard limit
    worker_prefetch_multiplier=1,  # reserve one task at a time per worker process
    worker_max_tasks_per_child=100,  # replace a worker process after 100 tasks
)

# Task-aware logger shared by the task functions in this module.
logger = get_task_logger(__name__)

@app.task(bind=True, max_retries=3)
def process_data(self, data: Dict[str, Any]):
    """Process a data payload and return a summary of it.

    Any exception is logged, and the task is retried with exponential
    backoff (60 s, 120 s, 240 s, ...) up to ``max_retries=3`` times.

    Returns:
        ``{"status": "success", "processed_items": <number of entries in
        data["items"], 0 if absent>, "data": data}``
    """
    try:
        logger.info(f"开始处理数据: {data}")
        time.sleep(2)  # stand-in for real processing work
        processed = len(data.get("items", []))
        summary = {
            "status": "success",
            "processed_items": processed,
            "data": data,
        }
        logger.info(f"数据处理完成: {summary}")
        return summary
    except Exception as exc:
        logger.error(f"处理失败: {exc}")
        backoff_seconds = 60 * 2 ** self.request.retries
        raise self.retry(exc=exc, countdown=backoff_seconds)

@app.task
def send_notification(user_id: int, message: str):
    """Simulate delivering *message* to *user_id* and return a delivery receipt."""
    logger.info(f"发送通知给用户 {user_id}: {message}")
    time.sleep(1)  # stand-in for a real notification call
    receipt = {"status": "sent", "user_id": user_id}
    return receipt

@app.task(bind=True)
def long_running_task(self, task_id: str):
    """Simulate a ten-step job, reporting a custom ``PROGRESS`` state.

    The progress (``current``, ``total``, ``percent``) is published at the
    start of each step, before that step's one second of simulated work.
    """
    total = 10
    for current in range(1, total + 1):
        progress = {
            'current': current,
            'total': total,
            'percent': current / total * 100,
        }
        self.update_state(state='PROGRESS', meta=progress)
        time.sleep(1)  # one simulated unit of work
    return {"status": "completed", "task_id": task_id}

# 任务链和工作流
from celery import chain, group, chord

@app.task
def add(x, y):
    """Return ``x + y``; used by the workflow examples."""
    total = x + y
    return total

@app.task
def multiply(x, y):
    """Return ``x * y``; used by the workflow examples."""
    product = x * y
    return product

@app.task
def sum_results(results):
    """Return the sum of an iterable of numbers, e.g. a chord header's result list."""
    total = sum(results)
    return total

# Chain: each task's return value becomes the first argument of the next
# signature, so this computes multiply(add(4, 4), 8) == 64.
# Do not append sum_results.s() here. A chain passes a single value, and
# sum(64) raises TypeError because an int is not iterable.
workflow = chain(
    add.s(4, 4),
    multiply.s(8),
)

# Group: run ten add tasks in parallel; the group result's get() is a list.
parallel_tasks = group(add.s(i, i) for i in range(10))

# Chord: run the header tasks in parallel, then pass their list of results
# to the callback, i.e. sum_results([0, 2, 4, ..., 18]) == 90.
callback = sum_results.s()
chord_task = chord(
    [add.s(i, i) for i in range(10)],
    callback
)

# 使用示例
def execute_workflows():
    """Run the chain, group and chord examples and print each result.

    Each result is awaited for at most 10 seconds via ``get(timeout=10)``.
    """
    demos = (
        ("链式任务结果", workflow),
        ("并行任务结果", parallel_tasks),
        ("和弦任务结果", chord_task),
    )
    for label, signature in demos:
        outcome = signature.apply_async().get(timeout=10)
        print(f"{label}: {outcome}")

Celery监控和管理

from celery import Celery
from celery.events import EventDispatcher
from celery.utils.log import get_task_logger
import json
from datetime import datetime

class CeleryMonitor:
    """Thin helper around Celery's remote-control and result APIs."""

    def __init__(self, app: Celery):
        self.app = app
        self.logger = get_task_logger(__name__)

    def get_active_tasks(self):
        """Return one summary dict per task currently executing on any worker.

        Returns an empty list when ``inspect().active()`` yields nothing
        (e.g. no worker replied).
        """
        snapshot = self.app.control.inspect().active()
        if not snapshot:
            return []

        return [
            {
                "worker": worker_name,
                "task_id": info["id"],
                "name": info["name"],
                "args": info["args"],
                "kwargs": info["kwargs"],
                "started": info["time_start"],
            }
            for worker_name, running in snapshot.items()
            for info in running
        ]

    def _summarize_result(self, task_id):
        """Build the state/result/info summary for a single task id."""
        async_result = self.app.AsyncResult(task_id)
        return {
            "task_id": task_id,
            "state": async_result.state,
            "result": async_result.result if async_result.ready() else None,
            "info": async_result.info,
        }

    def get_task_results(self, task_ids: list):
        """Return a summary dict for each id in *task_ids*, in the same order."""
        return [self._summarize_result(tid) for tid in task_ids]

    def revoke_task(self, task_id: str, terminate: bool = False):
        """Revoke *task_id*; with ``terminate=True`` a running task is sent SIGTERM."""
        self.app.control.revoke(
            task_id,
            terminate=terminate,
            signal='SIGTERM'
        )

    def rate_limit(self, task_name: str, rate_limit: str):
        """Set a rate limit such as ``"100/hour"`` for *task_name* on the workers."""
        self.app.control.rate_limit(task_name, rate_limit)

# 使用监控
# Monitoring demo. NOTE(review): these calls contact workers through the
# broker when the module runs; "some-task-id" is a placeholder id.
monitor = CeleryMonitor(app=app)

# Count the tasks currently executing across all workers.
active_tasks = monitor.get_active_tasks()
print(f"活跃任务: {len(active_tasks)}")

# Revoke a task, terminating it if it is already running.
monitor.revoke_task(task_id="some-task-id", terminate=True)

# Apply a "100/hour" rate limit to tasks.process_data (Celery enforces rate
# limits per worker, not cluster-wide).
monitor.rate_limit(task_name="tasks.process_data", rate_limit="100/hour")

消息队列最佳实践

# 1. 消息序列化
import json
from datetime import datetime
from typing import Any

class MessageSerializer:
    """JSON <-> UTF-8 bytes helpers for queue payloads, with datetime support."""

    @staticmethod
    def default_serializer(obj: Any) -> Any:
        """``json.dumps`` fallback that renders datetimes as ISO-8601 strings.

        Raises:
            TypeError: for any other type that JSON cannot encode natively.
        """
        if not isinstance(obj, datetime):
            raise TypeError(f"无法序列化类型: {type(obj)}")
        return obj.isoformat()

    @staticmethod
    def serialize(data: dict) -> bytes:
        """Encode *data* as JSON, then as UTF-8 bytes."""
        text = json.dumps(data, default=MessageSerializer.default_serializer)
        return text.encode("utf-8")

    @staticmethod
    def deserialize(data: bytes) -> dict:
        """Decode UTF-8 *data* and parse it as JSON."""
        text = data.decode("utf-8")
        return json.loads(text)

# 2. 幂等性处理
class IdempotentProcessor:
    """Run each message's processor at most once per remembered message id.

    Ids are remembered in insertion order. When more than ``max_size`` are
    stored, the oldest are forgotten until only the ``keep`` most recent
    remain. An id is recorded only after its processor returns, so a
    processor that raises can be retried with the same id.
    """

    def __init__(self, max_size: int = 10000, keep: int = 5000):
        from collections import deque

        # Fast membership test for ids already processed.
        self.processed_messages = set()
        # Same ids in insertion order. A set has no order, so slicing
        # list(set) cannot tell which ids are the most recent.
        self._order = deque()
        self.max_size = max_size
        self.keep = keep

    def process(self, message_id: str, processor):
        """Call ``processor()`` unless *message_id* was already processed."""
        if message_id in self.processed_messages:
            print(f"消息已处理,跳过: {message_id}")
            return

        processor()
        self.processed_messages.add(message_id)
        self._order.append(message_id)

        # Bound memory: evict the oldest ids, keeping the most recent ones.
        if len(self._order) > self.max_size:
            while len(self._order) > self.keep:
                self.processed_messages.discard(self._order.popleft())

# 3. 死信队列处理
class DeadLetterHandler:
    """Triage messages taken from a dead-letter queue.

    Routing substring-matches the message's ``"error"`` field:
    ``"timeout"`` -> retry, ``"validation"`` -> log and drop, anything else
    (including a missing error) -> escalate for manual handling.
    ``log_invalid_message`` and ``escalate_to_human`` are hooks with
    print-based defaults; override them to plug in real logging or alerting.
    """

    def __init__(self, rabbitmq_client, retry_queue: str = "tasks", max_retries: int = 3):
        # Any object exposing declare_queue(name) and publish(queue, message).
        self.client = rabbitmq_client
        self.dlq_name = "dead_letter_queue"
        self.retry_queue = retry_queue
        self.max_retries = max_retries

    def setup(self):
        """Declare the dead-letter queue."""
        self.client.declare_queue(self.dlq_name)

    def handle_dead_letter(self, message: dict):
        """Route one dead-lettered *message* according to its ``"error"`` text."""
        print(f"处理死信消息: {message}")

        error = str(message.get("error", "unknown"))

        if "timeout" in error:
            # Likely transient: retry.
            self.retry_message(message)
        elif "validation" in error:
            # Will never succeed: record and drop.
            self.log_invalid_message(message)
        else:
            # Unknown failure: needs manual handling.
            self.escalate_to_human(message)

    def retry_message(self, message: dict):
        """Republish *message* to ``retry_queue`` with ``retry_count`` incremented.

        Escalates instead once ``max_retries`` retries have been made. The
        caller's dict is left unmodified; a copy is republished.
        """
        retry_count = message.get("retry_count", 0)
        if retry_count < self.max_retries:
            self.client.publish(self.retry_queue, {**message, "retry_count": retry_count + 1})
        else:
            self.escalate_to_human(message)

    def log_invalid_message(self, message: dict):
        """Hook: record a message that failed validation (it is not retried)."""
        print(f"无效消息，已记录并丢弃: {message}")

    def escalate_to_human(self, message: dict):
        """Hook: hand a message over for manual handling (default: print it)."""
        print(f"需要人工处理的死信消息: {message}")

# 4. 性能监控
class QueueMonitor:
    """In-memory counters for publish/consume activity and processing time.

    Only the most recent ``max_samples`` processing times are kept in
    ``metrics["processing_time"]``, so memory stays bounded on long-running
    consumers. The average reported by ``get_stats`` still covers every
    consumed message, via a running total.
    """

    def __init__(self, max_samples: int = 1000):
        from collections import deque

        self.metrics = {
            "messages_published": 0,
            "messages_consumed": 0,
            # Bounded window of recent samples (max_samples=None keeps all).
            "processing_time": deque(maxlen=max_samples),
            "errors": 0
        }
        # Sum over *all* consumed messages, for an exact overall average.
        self._total_processing_time = 0.0

    def record_publish(self):
        """Count one published message."""
        self.metrics["messages_published"] += 1

    def record_consume(self, processing_time: float):
        """Count one consumed message and record how long it took to process."""
        self.metrics["messages_consumed"] += 1
        self.metrics["processing_time"].append(processing_time)
        self._total_processing_time += processing_time

    def record_error(self):
        """Count one processing error."""
        self.metrics["errors"] += 1

    def get_stats(self) -> dict:
        """Return totals, the mean processing time, and errors per consumed message.

        ``avg_processing_time`` is 0 before anything is consumed.
        ``error_rate`` divides by at least 1 to avoid division by zero.
        """
        consumed = self.metrics["messages_consumed"]
        return {
            "published": self.metrics["messages_published"],
            "consumed": consumed,
            "avg_processing_time": self._total_processing_time / consumed if consumed else 0,
            "error_rate": self.metrics["errors"] / max(consumed, 1)
        }

常见问题解决

# 1. 消息积压处理
class BackpressureHandler:
    """Suggest a consumption-rate adjustment from the current queue depth.

    Above ``max_queue_size`` -> slow down (factor 0.5). Below 30% of it ->
    speed up (factor 1.5). Otherwise -> normal.
    """

    def __init__(self, max_queue_size: int = 1000):
        self.max_queue_size = max_queue_size

    def handle_backpressure(self, queue_size: int):
        """Return an action dict for *queue_size*."""
        high_water = self.max_queue_size
        low_water = high_water * 0.3

        if queue_size > high_water:
            return {"action": "slow_down", "factor": 0.5}
        if queue_size < low_water:
            return {"action": "speed_up", "factor": 1.5}
        return {"action": "normal"}

# 2. 消息顺序保证
class OrderedProcessor:
    """Re-sequence messages per partition key and process them strictly in order.

    Sequences start at 0 for each partition key. Messages that arrive early
    are buffered until the gap is filled. Messages whose sequence was
    already processed (duplicates or redeliveries) are dropped. Processing
    is done by ``handler(message)``; alternatively, subclass and override
    ``_process_message``.
    """

    def __init__(self, handler=None):
        # partition_key -> {sequence: message} for messages that arrived early.
        self.pending_messages = {}
        # partition_key -> next sequence number expected.
        self.next_sequence = {}
        # Optional callable invoked as handler(message) for in-order messages.
        self.handler = handler

    def process(self, partition_key: str, sequence: int, message: dict):
        """Accept *message* with *sequence* for *partition_key*; process what is now in order."""
        expected = self.next_sequence.setdefault(partition_key, 0)

        if sequence < expected:
            # Already processed. Buffering it would leak memory, because the
            # drain loop only ever looks forward.
            return

        if sequence > expected:
            # Out of order: buffer until the missing sequences arrive.
            self.pending_messages.setdefault(partition_key, {})[sequence] = message
            return

        # In order: process it, then drain any buffered successors.
        self._process_message(message)
        self.next_sequence[partition_key] = expected + 1
        self._process_pending(partition_key)

    def _process_pending(self, partition_key: str):
        """Process buffered messages for *partition_key* while they are contiguous."""
        pending = self.pending_messages.get(partition_key)
        if not pending:
            return
        while self.next_sequence[partition_key] in pending:
            message = pending.pop(self.next_sequence[partition_key])
            self._process_message(message)
            self.next_sequence[partition_key] += 1
        if not pending:
            # Drop the empty bucket so idle keys do not accumulate.
            del self.pending_messages[partition_key]

    def _process_message(self, message: dict):
        """Process one in-order message via ``self.handler``."""
        if self.handler is None:
            raise NotImplementedError(
                "pass handler= to OrderedProcessor or override _process_message()"
            )
        self.handler(message)

# 3. 消息确认机制
class ReliableConsumer:
    """Track each message until it is explicitly confirmed or rejected.

    Processing is done by ``handler(message)``; alternatively, subclass and
    override ``_process``. ``_confirm`` and ``_reject`` are where broker
    acknowledgements / requeue / dead-lettering would be sent.
    """

    def __init__(self, handler=None):
        # message_id -> message for messages received but not yet settled.
        self.unconfirmed = {}
        # Optional callable invoked as handler(message).
        self.handler = handler

    def consume(self, message):
        """Process *message* (which must have an ``"id"`` key).

        Returns:
            True if it was processed and confirmed; False if processing
            raised and it was rejected.
        """
        message_id = message["id"]
        self.unconfirmed[message_id] = message

        try:
            self._process(message)
        except Exception as e:
            self._reject(message_id, str(e))
            return False
        # Confirm outside the try: a failing confirm must not also trigger a
        # reject for a message that was already removed.
        self._confirm(message_id)
        return True

    def _process(self, message):
        """Process one message via ``self.handler``."""
        if self.handler is None:
            raise NotImplementedError(
                "pass handler= to ReliableConsumer or override _process()"
            )
        self.handler(message)

    def _confirm(self, message_id: str):
        """Mark *message_id* as settled successfully."""
        del self.unconfirmed[message_id]
        # TODO: send the acknowledgement to the message queue here.

    def _reject(self, message_id: str, reason: str):
        """Mark *message_id* as failed and return the message."""
        message = self.unconfirmed.pop(message_id)
        # TODO: requeue `message` or send it (with `reason`) to a dead-letter queue.
        return message

总结

消息队列是构建分布式系统的核心组件。RabbitMQ适合复杂路由和可靠投递，Kafka适合高吞吐量的事件流，Celery则构建在消息代理（如Redis或RabbitMQ）之上，简化了分布式任务的调度与重试。应根据具体场景选择合适的方案，实现异步、解耦、可扩展的系统架构。