连续批处理:高吞吐量推理
---
title: "连续批处理:高吞吐量推理"
description: "掌握连续批处理的原理和实现,实现LLM服务的最大化吞吐量"
tags: ["连续批处理", "动态批处理", "高吞吐量", "推理优化"]
category: "llm"
icon: "🧠"
---
连续批处理:高吞吐量推理
连续批处理简介
连续批处理(Continuous Batching)是一种动态的批处理策略,允许在序列生成过程中动态地添加和移除请求。与传统静态批处理不同,连续批处理不会等待整个批次完成,而是随时处理新请求,显著提高了GPU利用率和系统吞吐量。
连续批处理的核心优势:
- 高吞吐量:GPU始终保持忙碌状态
- 低延迟:新请求无需等待当前批次完成
- 灵活调度:支持不同长度的请求混合处理
- 资源优化:最大化GPU计算资源利用率
工作原理
传统静态批处理
class StaticBatcher:
    """Static batcher: nothing runs until a full batch has been queued.

    Requests accumulate in ``queue`` (oldest first). ``process_batch`` only
    does work once at least ``max_batch_size`` requests are waiting.
    """

    def __init__(self, max_batch_size=32):
        # Pending requests in arrival order.
        self.queue = []
        self.max_batch_size = max_batch_size

    def add_request(self, request):
        """Append ``request`` to the end of the pending queue."""
        self.queue.append(request)

    def process_batch(self):
        """Process one full batch.

        Returns:
            ``None`` while fewer than ``max_batch_size`` requests are queued;
            otherwise a list with one ``generate`` result per request, in
            queue order. The processed requests are removed from ``queue``.
        """
        size = self.max_batch_size
        if len(self.queue) < size:
            # Still waiting for the batch to fill up.
            return None

        # Split off the oldest ``size`` requests; the rest stay queued.
        batch, self.queue = self.queue[:size], self.queue[size:]

        # Every request is generated to completion before anything is returned.
        return [self.generate(item) for item in batch]

    def generate(self, request):
        """Generate tokens until ``is_finished`` reports completion.

        NOTE(review): ``request`` is not read here, and ``is_finished`` /
        ``model`` are module-level names that this snippet does not define --
        presumably placeholders supplied by the surrounding code; confirm.
        """
        sequence = []
        while True:
            if is_finished(sequence):
                return sequence
            sequence.append(model.generate_next_token(sequence))
# Drawback: the GPU sits idle once short sequences have finished, waiting for
# the longest sequence in the batch.
连续批处理
class ContinuousBatcher:
    """Iteration-level (continuous) batching scheduler for a causal LM.

    Every call to ``step`` admits waiting requests into free batch slots, runs
    ONE batched forward pass that produces one new token per active request,
    and retires the requests that finished. New requests therefore join the
    batch between decoding steps instead of waiting for the batch to drain.

    Simplification: the KV caches of requests that joined at different times
    have different lengths and cannot be stacked into one tensor without a
    paged/ragged attention kernel. To keep the batched forward pass correct
    this reference implementation re-encodes each request's full sequence
    (prompt + generated tokens) every step with ``use_cache=False``. The
    ``"past_key_values"`` entry of a request is kept for schema compatibility
    and always stays ``None``.

    Args:
        model: Callable accepting ``input_ids``, ``attention_mask`` and
            ``use_cache`` keyword arguments and returning an object whose
            ``logits`` has shape ``(batch, seq, vocab)`` -- assumed to be a
            Hugging Face style causal LM; confirm for other model types.
        tokenizer: Object providing ``encode(prompt, return_tensors="pt")``
            and ``eos_token_id`` (optionally ``pad_token_id``).
        max_batch_size: Maximum number of simultaneously active requests.
        max_new_tokens: Per-request cap on generated tokens so a request that
            never emits EOS still terminates. ``None`` disables the cap.
    """

    def __init__(self, model, tokenizer, max_batch_size=32, max_new_tokens=256):
        self.model = model
        self.tokenizer = tokenizer
        self.max_batch_size = max_batch_size
        self.max_new_tokens = max_new_tokens
        # Requests currently being decoded, keyed by request id.
        self.active_requests = {}
        # Requests not yet admitted, in arrival order.
        self.waiting_queue = []
        # Monotonically increasing id source.
        self.request_id_counter = 0

    def add_request(self, prompt):
        """Tokenize ``prompt`` and enqueue it.

        Returns:
            The integer id assigned to the new request.

        Raises:
            ValueError: If the prompt encodes to zero tokens (there would be
                no position to predict the first token from).
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        if input_ids.shape[-1] == 0:
            raise ValueError("prompt encodes to zero tokens")

        request_id = self.request_id_counter
        self.request_id_counter += 1
        self.waiting_queue.append({
            "id": request_id,
            "input_ids": input_ids,
            "generated_tokens": [],
            "past_key_values": None,
            "status": "waiting",
        })
        return request_id

    def step(self):
        """Run one scheduling + decoding iteration.

        Returns:
            The list of request dicts that completed during this step.
        """
        self._schedule_new_requests()
        completed = self._generate_step()
        self._remove_completed(completed)
        return completed

    def _schedule_new_requests(self):
        """Move waiting requests into the free slots of the active batch."""
        free_slots = self.max_batch_size - len(self.active_requests)
        if free_slots <= 0 or not self.waiting_queue:
            return
        # One slice + one delete instead of repeated O(n) pop(0) calls.
        admitted = self.waiting_queue[:free_slots]
        del self.waiting_queue[:free_slots]
        for request in admitted:
            request["status"] = "active"
            self.active_requests[request["id"]] = request

    def _pad_token_id(self):
        """Pick a filler id for padding; its value never influences results."""
        for name in ("pad_token_id", "eos_token_id"):
            token_id = getattr(self.tokenizer, name, None)
            if token_id is not None:
                return token_id
        return 0

    def _generate_step(self):
        """Produce one token for every active request in a single forward pass.

        Returns:
            The requests that finished on this step (EOS or token cap).
        """
        if not self.active_requests:
            return []

        import torch  # local import keeps the snippet self-contained

        batch = list(self.active_requests.values())
        sequences = [
            request["input_ids"][0].tolist() + request["generated_tokens"]
            for request in batch
        ]
        lengths = [len(tokens) for tokens in sequences]
        width = max(lengths)
        pad_id = self._pad_token_id()

        # Right-pad: real tokens keep positions 0..len-1, and under causal
        # attention they never attend to the padding that follows them.
        input_ids = torch.tensor(
            [tokens + [pad_id] * (width - len(tokens)) for tokens in sequences],
            dtype=torch.long,
        )
        attention_mask = torch.tensor(
            [[1] * n + [0] * (width - n) for n in lengths],
            dtype=torch.long,
        )
        device = getattr(self.model, "device", None)
        if device is not None:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_cache=False,
            )

        # Read each row's logits at its LAST REAL token (not the last column,
        # which is padding for the shorter rows), then argmax over the vocab.
        logits = outputs.logits
        rows = torch.arange(len(batch), device=logits.device)
        last = torch.tensor([n - 1 for n in lengths], device=logits.device)
        next_tokens = logits[rows, last].argmax(dim=-1).tolist()

        eos_id = self.tokenizer.eos_token_id
        completed = []
        for request, token in zip(batch, next_tokens):
            generated = request["generated_tokens"]
            generated.append(token)
            hit_eos = eos_id is not None and token == eos_id
            hit_cap = (
                self.max_new_tokens is not None
                and len(generated) >= self.max_new_tokens
            )
            if hit_eos or hit_cap:
                request["status"] = "completed"
                completed.append(request)
        return completed

    def _remove_completed(self, completed):
        """Drop finished requests from the active set, freeing their slots."""
        for request in completed:
            del self.active_requests[request["id"]]
使用vLLM
基本配置
from vllm import LLM, SamplingParams
# vLLM applies continuous batching automatically; there is no opt-in flag.
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    max_model_len=4096,
    gpu_memory_utilization=0.9,
    # Continuous-batching limits
    max_num_batched_tokens=8192,  # max tokens scheduled in one engine iteration
    max_num_seqs=256,  # max sequences scheduled in one engine iteration
    # NOTE: the previously listed `max_num_running_seqs` and `scheduler_policy`
    # are not vLLM engine arguments and are rejected as unknown keywords, so
    # they were removed. `max_num_seqs` already bounds the running sequences.
    # Recent vLLM releases expose `scheduling_policy` ("fcfs" is the default)
    # -- verify against the installed version before setting it.
)
# Inference: all prompts are submitted at once and the engine batches them.
prompts = [f"问题{i}" for i in range(100)]
sampling_params = SamplingParams(temperature=0.7, max_tokens=256)
outputs = llm.generate(prompts, sampling_params)
异步API
import asyncio
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
class AsyncInferenceService:
    """Asyncio-facing wrapper around a vLLM ``AsyncLLMEngine``.

    Args:
        model_name: Model identifier passed to ``AsyncEngineArgs(model=...)``.
    """

    def __init__(self, model_name):
        engine_args = AsyncEngineArgs(
            model=model_name,
            max_model_len=4096,
            gpu_memory_utilization=0.9,
            max_num_seqs=256,
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

    async def generate(self, prompt, request_id):
        """Stream the generated text for one prompt.

        This is an async generator: iterate it with ``async for``. Each item
        is ``output.outputs[0].text`` of an engine update.

        Args:
            prompt: Prompt text.
            request_id: Identifier for the request; converted to ``str``
                because the engine expects string request ids.
        """
        sampling_params = SamplingParams(temperature=0.7, max_tokens=256)
        async for output in self.engine.generate(
            prompt, sampling_params, request_id=str(request_id)
        ):
            yield output.outputs[0].text

    async def _final_text(self, prompt, request_id):
        """Drain ``generate`` for one prompt and return the last text it yielded."""
        text = ""
        async for text in self.generate(prompt, request_id):
            pass
        return text

    async def batch_generate(self, prompts):
        """Generate for all prompts concurrently.

        ``generate`` is an async generator and cannot be awaited or gathered
        directly, so each prompt is wrapped in a coroutine that drains it.
        Submitting the coroutines together lets the engine batch the requests.

        Returns:
            A list with one final text per prompt, in the order of ``prompts``
            (assumes each streamed item is the text so far, which is vLLM's
            default cumulative output mode -- confirm if ``output_kind`` is
            changed). Empty input yields an empty list.
        """
        pending = [
            self._final_text(prompt, request_id)
            for request_id, prompt in enumerate(prompts)
        ]
        return list(await asyncio.gather(*pending))
# Usage: build the service, then run batch_generate() to completion on a new
# event loop via asyncio.run().
service = AsyncInferenceService("meta-llama/Llama-2-7b-hf")
results = asyncio.run(service.batch_generate(["问题1", "问题2", "问题3"]))
调度策略
公平调度
class FairScheduler:
    """Scheduler with three FIFO queues drained in strict priority order.

    ``schedule`` always serves ``high_priority`` before ``normal`` and
    ``normal`` before ``low_priority``; a lower level is only served while
    every higher level is empty.
    """

    # Queue names, highest priority first.
    _LEVELS = ("high_priority", "normal", "low_priority")

    def __init__(self):
        self.queues = {level: [] for level in self._LEVELS}

    def schedule(self):
        """Pop and return the oldest request of the highest non-empty level.

        Returns ``None`` when every queue is empty.
        """
        pending = next(
            (self.queues[level] for level in self._LEVELS if self.queues[level]),
            None,
        )
        if pending is None:
            return None
        return pending.pop(0)

    def add_request(self, request, priority="normal"):
        """Enqueue ``request`` at ``priority``.

        Raises:
            KeyError: If ``priority`` is not one of the queue names.
        """
        bucket = self.queues[priority]
        bucket.append(request)
贪心调度
class GreedyScheduler:
    """Greedy scheduler: fill every free batch slot from the waiting queue."""

    def schedule(self, waiting_queue, max_batch_size, current_batch):
        """Pick the requests that should join the running batch.

        Args:
            waiting_queue: Iterable of waiting requests; it is not modified.
            max_batch_size: Capacity of the batch.
            current_batch: Requests already running (only its length is used).

        Returns:
            A new list holding the first waiting requests, in order, up to the
            number of free slots; empty when the batch is already full.
        """
        free_slots = max_batch_size - len(current_batch)
        chosen = []
        if free_slots > 0:
            for position, candidate in enumerate(waiting_queue):
                if position >= free_slots:
                    break
                chosen.append(candidate)
        return chosen
性能优化
批处理大小调优
def optimize_batch_size(model, tokenizer, test_prompts):
    """Benchmark several ``max_num_seqs`` settings and report throughput/latency.

    For each candidate batch size a fresh ``LLM`` engine is built, the first
    ``batch_size`` prompts are generated, and the run is timed.

    Args:
        model: Model identifier passed to ``LLM(model=...)``.
        tokenizer: Unused; kept so existing callers keep working.
        test_prompts: Sequence of prompts. If it holds fewer prompts than a
            candidate batch size, that run simply uses all of them.

    Returns:
        A list of dicts with keys ``"batch_size"``, ``"throughput"``
        (requests per second) and ``"latency"`` (mean seconds from arrival to
        finish per request; ``nan`` when the engine reports no metrics).
    """
    import gc
    import time

    results = []
    for batch_size in [16, 32, 64, 128, 256]:
        prompts = test_prompts[:batch_size]
        llm = LLM(
            model=model,
            max_num_seqs=batch_size,
            max_num_batched_tokens=batch_size * 512,
        )

        # perf_counter is monotonic, so the interval cannot be skewed by
        # wall-clock adjustments.
        started = time.perf_counter()
        outputs = llm.generate(prompts, SamplingParams(max_tokens=256))
        elapsed = time.perf_counter() - started

        # finished_time is an absolute timestamp; latency is the difference
        # to the request's arrival time, not the timestamp itself.
        latencies = [
            o.metrics.finished_time - o.metrics.arrival_time
            for o in outputs
            if o.metrics is not None and o.metrics.finished_time is not None
        ]
        mean_latency = sum(latencies) / len(latencies) if latencies else float("nan")

        results.append({
            "batch_size": batch_size,
            "throughput": len(prompts) / elapsed,
            "latency": mean_latency,
        })

        # Drop the engine before building the next one so two engines do not
        # compete for GPU memory. NOTE(review): whether this fully releases
        # GPU memory depends on the vLLM version -- confirm.
        del llm, outputs
        gc.collect()
    return results
内存优化
def optimize_memory_usage(
    llm_engine,
    *,
    high_watermark=0.9,
    low_watermark=0.5,
    max_seqs_cap=512,
):
    """Nudge the scheduler's ``max_num_seqs`` based on reported GPU memory use.

    Reads ``llm_engine.scheduler.block_manager.get_stats()["gpu_memory_usage"]``
    and adjusts ``llm_engine.scheduler.max_num_seqs`` in place:

    * usage above ``high_watermark``: shrink by about 10%, never below 1;
    * usage below ``low_watermark``: grow by about 10% (at least +1), capped
      at ``max_seqs_cap``;
    * otherwise: leave the value untouched.

    The written value is always an ``int``; a sequence count must not become
    fractional.

    NOTE(review): the ``scheduler.block_manager.get_stats()`` shape is taken
    as given by this snippet -- confirm it matches the engine object in use.

    Args:
        llm_engine: Engine exposing the attributes described above.
        high_watermark: Usage fraction above which the batch is shrunk.
        low_watermark: Usage fraction below which the batch is grown.
        max_seqs_cap: Upper bound applied when growing.
    """
    scheduler = llm_engine.scheduler
    usage = scheduler.block_manager.get_stats()["gpu_memory_usage"]

    if usage > high_watermark:
        current = int(scheduler.max_num_seqs)
        scheduler.max_num_seqs = max(1, int(current * 0.9))
    elif usage < low_watermark:
        current = int(scheduler.max_num_seqs)
        # max(current + 1, ...) guarantees progress for small values, where
        # truncating current * 1.1 would return the same number forever.
        grown = max(current + 1, int(current * 1.1))
        scheduler.max_num_seqs = min(grown, max_seqs_cap)
监控指标
from prometheus_client import Histogram, Gauge, Counter
# Prometheus metrics exported by the serving process.

# Per-request latency distribution, in seconds.
REQUEST_LATENCY = Histogram(
    name='llm_request_latency_seconds',
    documentation='Request latency',
    buckets=[0.1, 0.5, 1, 2, 5, 10],
)

# Requests per second.
THROUGHPUT = Gauge(
    name='llm_throughput_requests_per_second',
    documentation='Requests per second',
)

# Size of the current batch.
BATCH_SIZE = Gauge(
    name='llm_current_batch_size',
    documentation='Current batch size',
)

# Number of requests still waiting.
QUEUE_SIZE = Gauge(
    name='llm_queue_size',
    documentation='Number of waiting requests',
)

# Running total of completed requests.
COMPLETED_REQUESTS = Counter(
    name='llm_completed_requests_total',
    documentation='Total completed requests',
)
def monitor_continuous_batching(batcher, interval=1.0, stop_event=None):
    """Periodically publish the batcher's batch and queue sizes.

    On every tick ``BATCH_SIZE`` is set to ``len(batcher.active_requests)``
    and ``QUEUE_SIZE`` to ``len(batcher.waiting_queue)``. The call blocks, so
    run it in a dedicated thread.

    Args:
        batcher: Object exposing ``active_requests`` and ``waiting_queue``.
        interval: Seconds to wait between updates.
        stop_event: Optional object with a ``wait(timeout)`` method that
            returns true once stopping was requested (e.g. a
            ``threading.Event``). Without it the loop runs forever, as before.
    """
    import time  # local import: the snippet has no module-level `import time`

    while True:
        BATCH_SIZE.set(len(batcher.active_requests))
        QUEUE_SIZE.set(len(batcher.waiting_queue))

        if stop_event is None:
            time.sleep(interval)
        elif stop_event.wait(interval):
            # wait() doubles as the sleep and returns early when set.
            return
与其他技术结合
# 连续批处理 + PagedAttention + Flash Attention
def create_optimized_engine(model_name):
    """Build a vLLM ``LLM`` engine with the tuning options used in this guide.

    Args:
        model_name: Model identifier passed to ``LLM(model=...)``.

    Returns:
        The constructed ``LLM`` instance.
    """
    from vllm import LLM

    engine_options = {
        # KV-cache paging (PagedAttention): tokens per cache block and the
        # fraction of GPU memory the engine may use.
        "block_size": 16,
        "gpu_memory_utilization": 0.9,
        # Continuous-batching limits per engine iteration.
        "max_num_batched_tokens": 16384,
        "max_num_seqs": 512,
        # enforce_eager=False keeps CUDA graph execution enabled. No argument
        # here configures Flash Attention explicitly.
        "enforce_eager": False,
        # Tensor parallelism: a single GPU.
        "tensor_parallel_size": 1,
    }
    return LLM(model=model_name, **engine_options)
连续批处理是实现高吞吐量LLM服务的关键技术,与PagedAttention和Flash Attention结合可以达到最佳性能。