← 返回首页
🧠

连续批处理:高吞吐量推理

📂 llm ⏱ 4 min 643 words

--- title: "连续批处理:高吞吐量推理" description: "掌握连续批处理的原理和实现,实现LLM服务的最大化吞吐量" tags: ["连续批处理", "动态批处理", "高吞吐量", "推理优化"] category: "llm" icon: "🧠"

连续批处理:高吞吐量推理

连续批处理简介

连续批处理(Continuous Batching)是一种动态的批处理策略,允许在序列生成过程中动态地添加和移除请求。与传统静态批处理不同,连续批处理不会等待整个批次完成,而是随时处理新请求,显著提高了GPU利用率和系统吞吐量。

连续批处理的核心优势:

工作原理

传统静态批处理

class StaticBatcher:
    """传统静态批处理器"""
    
    def __init__(self, max_batch_size=32):
        self.max_batch_size = max_batch_size
        self.queue = []
    
    def add_request(self, request):
        """添加请求"""
        self.queue.append(request)
    
    def process_batch(self):
        """处理整个批次"""
        # 等待填满批次
        if len(self.queue) < self.max_batch_size:
            return None
        
        # 取出批次
        batch = self.queue[:self.max_batch_size]
        self.queue = self.queue[self.max_batch_size:]
        
        # 处理(等待最长序列完成)
        results = []
        for request in batch:
            result = self.generate(request)
            results.append(result)
        
        return results
    
    def generate(self, request):
        """生成序列"""
        # 生成直到完成
        tokens = []
        while not is_finished(tokens):
            next_token = model.generate_next_token(tokens)
            tokens.append(next_token)
        return tokens

# 问题:GPU在短序列完成后空闲等待

连续批处理

class ContinuousBatcher:
    """连续批处理器"""
    
    def __init__(self, model, tokenizer, max_batch_size=32):
        self.model = model
        self.tokenizer = tokenizer
        self.max_batch_size = max_batch_size
        
        # 活跃请求
        self.active_requests = {}
        
        # 等待队列
        self.waiting_queue = []
        
        # 请求ID计数器
        self.request_id_counter = 0
    
    def add_request(self, prompt):
        """添加新请求"""
        request_id = self.request_id_counter
        self.request_id_counter += 1
        
        # 编码提示
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        
        # 初始化请求状态
        self.waiting_queue.append({
            "id": request_id,
            "input_ids": input_ids,
            "generated_tokens": [],
            "past_key_values": None,
            "status": "waiting"
        })
        
        return request_id
    
    def step(self):
        """执行一个生成步骤"""
        # 1. 从等待队列移入活跃
        self._schedule_new_requests()
        
        # 2. 执行一步生成
        completed = self._generate_step()
        
        # 3. 移除完成的请求
        self._remove_completed(completed)
        
        return completed
    
    def _schedule_new_requests(self):
        """调度新请求"""
        while (self.waiting_queue and 
               len(self.active_requests) < self.max_batch_size):
            request = self.waiting_queue.pop(0)
            request["status"] = "active"
            self.active_requests[request["id"]] = request
    
    def _generate_step(self):
        """生成一步"""
        if not self.active_requests:
            return []
        
        completed = []
        
        # 准备批处理输入
        batch_input_ids = []
        batch_past_key_values = []
        batch_request_ids = []
        
        for req_id, request in self.active_requests.items():
            if request["generated_tokens"]:
                input_id = request["generated_tokens"][-1]
            else:
                input_id = request["input_ids"][0, -1]
            
            batch_input_ids.append(input_id)
            batch_past_key_values.append(request["past_key_values"])
            batch_request_ids.append(req_id)
        
        # 批量前向传播
        with torch.no_grad():
            outputs = self.model(
                input_ids=torch.tensor(batch_input_ids).unsqueeze(1),
                past_key_values=batch_past_key_values,
                use_cache=True
            )
        
        # 处理输出
        for i, req_id in enumerate(batch_request_ids):
            request = self.active_requests[req_id]
            
            # 获取下一个token
            logits = outputs.logits[i]
            next_token = torch.argmax(logits[:, -1], dim=-1).item()
            
            # 更新状态
            request["generated_tokens"].append(next_token)
            request["past_key_values"] = outputs.past_key_values[i]
            
            # 检查是否完成
            if next_token == self.tokenizer.eos_token_id:
                request["status"] = "completed"
                completed.append(request)
        
        return completed
    
    def _remove_completed(self, completed):
        """移除完成的请求"""
        for request in completed:
            del self.active_requests[request["id"]]

使用vLLM

基本配置

from vllm import LLM, SamplingParams

# vLLM自动使用连续批处理
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    max_model_len=4096,
    gpu_memory_utilization=0.9,
    
    # 连续批处理参数
    max_num_batched_tokens=8192,  # 最大批处理token数
    max_num_seqs=256,  # 最大并发序列数
    max_num_running_seqs=128,  # 最大运行中序列数
    
    # 调度参数
    scheduler_policy="fcfs"  # 先来先服务
)

# 推理
prompts = [f"问题{i}" for i in range(100)]
sampling_params = SamplingParams(temperature=0.7, max_tokens=256)

outputs = llm.generate(prompts, sampling_params)

异步API

import asyncio
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams

class AsyncInferenceService:
    """异步推理服务"""
    
    def __init__(self, model_name):
        engine_args = AsyncEngineArgs(
            model=model_name,
            max_model_len=4096,
            gpu_memory_utilization=0.9,
            max_num_seqs=256
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
    
    async def generate(self, prompt, request_id):
        """异步生成"""
        sampling_params = SamplingParams(temperature=0.7, max_tokens=256)
        
        async for output in self.engine.generate(
            prompt, sampling_params, request_id=request_id
        ):
            yield output.outputs[0].text
    
    async def batch_generate(self, prompts):
        """批量异步生成"""
        tasks = []
        for i, prompt in enumerate(prompts):
            task = self.generate(prompt, request_id=i)
            tasks.append(task)
        
        results = []
        async for result in asyncio.gather(*tasks):
            results.append(result)
        
        return results

# 使用
service = AsyncInferenceService("meta-llama/Llama-2-7b-hf")
results = asyncio.run(service.batch_generate(["问题1", "问题2", "问题3"]))

调度策略

公平调度

class FairScheduler:
    """公平调度器"""
    
    def __init__(self):
        self.queues = {
            "high_priority": [],
            "normal": [],
            "low_priority": []
        }
    
    def schedule(self):
        """公平调度"""
        # 高优先级优先
        for queue_name in ["high_priority", "normal", "low_priority"]:
            if self.queues[queue_name]:
                return self.queues[queue_name].pop(0)
        
        return None
    
    def add_request(self, request, priority="normal"):
        """添加请求"""
        self.queues[priority].append(request)

贪心调度

class GreedyScheduler:
    """贪心调度器(最大化GPU利用率)"""
    
    def schedule(self, waiting_queue, max_batch_size, current_batch):
        """贪心调度"""
        # 计算剩余GPU容量
        remaining_capacity = max_batch_size - len(current_batch)
        
        if remaining_capacity <= 0:
            return []
        
        # 选择能填满容量的请求
        selected = []
        for request in waiting_queue:
            if len(selected) >= remaining_capacity:
                break
            selected.append(request)
        
        return selected

性能优化

批处理大小调优

def optimize_batch_size(model, tokenizer, test_prompts):
    """优化批处理大小"""
    results = []
    
    for batch_size in [16, 32, 64, 128, 256]:
        llm = LLM(
            model=model,
            max_num_seqs=batch_size,
            max_num_batched_tokens=batch_size * 512
        )
        
        # 测量吞吐量
        start_time = time.time()
        outputs = llm.generate(test_prompts[:batch_size], 
                              SamplingParams(max_tokens=256))
        throughput = len(test_prompts[:batch_size]) / (time.time() - start_time)
        
        results.append({
            "batch_size": batch_size,
            "throughput": throughput,
            "latency": np.mean([o.metrics.finished_time for o in outputs])
        })
    
    return results

内存优化

def optimize_memory_usage(llm_engine):
    """优化内存使用"""
    stats = llm_engine.scheduler.block_manager.get_stats()
    
    # 动态调整
    if stats["gpu_memory_usage"] > 0.9:
        # 减少批处理大小
        llm_engine.scheduler.max_num_seqs *= 0.9
    elif stats["gpu_memory_usage"] < 0.5:
        # 增加批处理大小
        llm_engine.scheduler.max_num_seqs = min(
            llm_engine.scheduler.max_num_seqs * 1.1,
            512
        )

监控指标

from prometheus_client import Histogram, Gauge, Counter

# 监控指标
REQUEST_LATENCY = Histogram(
    'llm_request_latency_seconds',
    'Request latency',
    buckets=[0.1, 0.5, 1, 2, 5, 10]
)

THROUGHPUT = Gauge(
    'llm_throughput_requests_per_second',
    'Requests per second'
)

BATCH_SIZE = Gauge(
    'llm_current_batch_size',
    'Current batch size'
)

QUEUE_SIZE = Gauge(
    'llm_queue_size',
    'Number of waiting requests'
)

COMPLETED_REQUESTS = Counter(
    'llm_completed_requests_total',
    'Total completed requests'
)

def monitor_continuous_batching(batcher):
    """监控连续批处理"""
    while True:
        # 更新指标
        BATCH_SIZE.set(len(batcher.active_requests))
        QUEUE_SIZE.set(len(batcher.waiting_queue))
        
        time.sleep(1)

与其他技术结合

# 连续批处理 + PagedAttention + Flash Attention
def create_optimized_engine(model_name):
    """创建优化的推理引擎"""
    from vllm import LLM
    
    llm = LLM(
        model=model_name,
        
        # PagedAttention配置
        block_size=16,
        gpu_memory_utilization=0.9,
        
        # 连续批处理配置
        max_num_batched_tokens=16384,
        max_num_seqs=512,
        
        # Flash Attention
        enforce_eager=False,  # 启用CUDA Graph
        
        # 张量并行
        tensor_parallel_size=1
    )
    
    return llm

连续批处理是实现高吞吐量LLM服务的关键技术,与PagedAttention和Flash Attention结合可以达到最佳性能。