批处理策略
---
title: "批处理策略"
description: "LLM批处理策略,包括连续批处理和动态批处理实现"
tags: ["批处理", "连续批处理", "动态批处理", "推理优化", "吞吐量"]
category: "llm"
icon: "🧠"
---
批处理策略
批处理是提升LLM推理吞吐量的关键技术。通过将多个请求组合在一起处理,可以充分利用GPU的并行计算能力,显著提高硬件利用率和系统吞吐量。不同的批处理策略适用于不同的场景。
批处理类型
静态批处理
最简单的批处理方式：请求先进入队列等待，凑满 batch_size 个请求后统一处理；若在 timeout 时间内批次仍未填满，该请求返回超时错误。
import asyncio
from typing import List, Any
from dataclasses import dataclass, field
import time
@dataclass
class StaticBatchProcessor:
    """Static batcher: hold requests until ``batch_size`` are queued, then run them together.

    A request that is not part of a full batch within ``timeout`` seconds gets
    ``{"error": "timeout"}`` and is removed from the queue. It therefore neither
    occupies a batch slot nor receives a late result. ``process_pending()``
    dispatches one batch from whatever is queued, without waiting for it to fill.
    """

    batch_size: int
    timeout: float
    queue: List[Any] = field(default_factory=list)
    # Not read or written by the processor; kept so the constructor signature is unchanged.
    results: dict = field(default_factory=dict)

    async def add_request(self, request_id: str, data: dict) -> dict:
        """Enqueue a request and wait up to ``timeout`` seconds for its batch result.

        Returns the processed result (the string ``"processed_<request_id>"``)
        on success, or ``{"error": "timeout"}`` if no full batch formed in time.
        """
        future = asyncio.get_running_loop().create_future()
        entry = {"id": request_id, "data": data, "future": future}
        self.queue.append(entry)
        if len(self.queue) >= self.batch_size:
            await self._process_batch()
        try:
            return await asyncio.wait_for(future, timeout=self.timeout)
        except asyncio.TimeoutError:
            # wait_for cancelled the future. Drop the stale entry so it does not
            # fill a batch slot, and so a later batch never calls set_result on a
            # cancelled future (which raises InvalidStateError).
            self.queue = [e for e in self.queue if e is not entry]
            return {"error": "timeout"}

    async def _process_batch(self):
        """Pop the first ``batch_size`` queued requests and resolve their futures."""
        batch = self.queue[:self.batch_size]
        self.queue = self.queue[self.batch_size:]
        for item in batch:
            future = item["future"]
            if future.done():
                # The caller already gave up (e.g. was cancelled); nothing to deliver.
                continue
            future.set_result(f"processed_{item['id']}")

    async def process_pending(self):
        """Dispatch one batch from the queued requests even if it is not full."""
        if self.queue:
            await self._process_batch()
processor = StaticBatchProcessor(batch_size=4, timeout=5.0)


async def demo():
    """Submit eight requests concurrently; with batch_size=4 they form two full batches."""
    pending = []
    for i in range(8):
        pending.append(processor.add_request(f"req-{i}", {"input": f"data-{i}"}))
    await asyncio.gather(*pending)
    print("静态批处理完成")


asyncio.run(demo())
动态批处理
根据请求到达情况动态决定批次大小：当排队请求数达到 max_batch_size，或本批次中最早的请求已等待超过 max_wait_time 时，立即处理当前批次，从而在吞吐量与等待时间之间折中。
import asyncio
import time
from typing import List
from dataclasses import dataclass, field
@dataclass
class DynamicBatchProcessor:
    """Dynamic batcher: dispatch when the batch is full or its wait window expires.

    A batch window opens when the first request lands in an empty queue. Queued
    requests are dispatched as soon as ``max_batch_size`` of them are waiting,
    or ``max_wait_time`` seconds after the window opened, whichever comes first.
    The deadline is enforced by an event-loop timer, so a partial batch is
    dispatched even when no further requests arrive. Checking the deadline only
    on arrival would leave the last requests waiting forever.
    ``flush()`` dispatches everything immediately.
    """

    max_batch_size: int
    max_wait_time: float
    queue: List[dict] = field(default_factory=list)
    batch_start_time: float = 0
    # Deadline timer for the currently open window; None when no window is open.
    _timer: "asyncio.TimerHandle | None" = field(
        default=None, init=False, repr=False, compare=False
    )

    async def add_request(self, request_id: str, data: dict) -> dict:
        """Enqueue a request and wait for its batch result (``"processed_<request_id>"``)."""
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        self.queue.append({"id": request_id, "data": data, "future": future})
        if not self.batch_start_time:
            # First request of a new window: start the clock and arm the deadline.
            self._start_window(loop)
        wait_time = time.time() - self.batch_start_time
        if len(self.queue) >= self.max_batch_size or wait_time >= self.max_wait_time:
            self._dispatch()
        return await future

    async def _process_batch(self):
        """Dispatch up to ``max_batch_size`` queued requests (async wrapper kept for callers)."""
        self._process_batch_now()

    async def flush(self):
        """Dispatch every queued request now and close the current window."""
        self._cancel_timer()
        while self.queue:
            self._process_batch_now()
        self.batch_start_time = 0

    def _process_batch_now(self) -> None:
        """Pop one batch from the head of the queue and resolve its futures."""
        # Guard against max_batch_size <= 0, which would otherwise never make progress.
        size = max(1, self.max_batch_size)
        batch = self.queue[:size]
        self.queue = self.queue[size:]
        for item in batch:
            future = item["future"]
            if future.done():
                # The caller already gave up (e.g. was cancelled); nothing to deliver.
                continue
            future.set_result(f"processed_{item['id']}")

    def _dispatch(self) -> None:
        """Process one batch, then open a fresh window for any leftover requests."""
        self._cancel_timer()
        self._process_batch_now()
        self.batch_start_time = 0
        if self.queue:
            self._start_window(asyncio.get_running_loop())

    def _start_window(self, loop: asyncio.AbstractEventLoop) -> None:
        """Record the window start time and schedule the deadline callback."""
        self.batch_start_time = time.time()
        self._timer = loop.call_later(self.max_wait_time, self._on_deadline)

    def _on_deadline(self) -> None:
        """Timer callback: the window expired, so dispatch whatever is queued."""
        self._timer = None
        if self.queue:
            self._dispatch()
        else:
            self.batch_start_time = 0

    def _cancel_timer(self) -> None:
        """Cancel the pending deadline timer, if any."""
        if self._timer is not None:
            self._timer.cancel()
            self._timer = None
processor = DynamicBatchProcessor(max_batch_size=8, max_wait_time=0.1)


async def demo():
    """Submit 5 requests (fewer than max_batch_size) and flush the partial batch.

    The requests are started as tasks and flushed *before* gathering. Awaiting
    gather() first would block forever on an unfilled batch, which made the
    old trailing flush() unreachable.
    """
    tasks = [
        asyncio.ensure_future(processor.add_request(f"req-{i}", {"input": f"data-{i}"}))
        for i in range(5)
    ]
    # Yield once so every task runs up to its `await future` and is enqueued.
    await asyncio.sleep(0)
    await processor.flush()
    await asyncio.gather(*tasks)
    print("动态批处理完成")


asyncio.run(demo())
连续批处理
vLLM 等推理引擎采用连续批处理（迭代级调度）策略：每生成一步都重新调度，已完成的请求立即离开批次，空出的位置由等待队列中的新请求填补，从而允许在生成过程中动态插入新请求。
import asyncio
from dataclasses import dataclass, field
from typing import List, Dict
@dataclass
class ContinuousBatchScheduler:
    """Toy iteration-level scheduler illustrating continuous batching.

    Each call to ``schedule_step`` first admits waiting requests (FIFO) into
    free slots, with at most ``max_batch_size`` running at once. It then
    advances every running request by one token. Requests that reach
    ``max_tokens`` leave the batch immediately, freeing their slot for the
    next step.
    """

    max_batch_size: int
    active_requests: Dict[str, dict] = field(default_factory=dict)
    waiting_queue: List[dict] = field(default_factory=list)
    completed: List[str] = field(default_factory=list)

    def submit(self, request_id: str, prompt: str, max_tokens: int):
        """Append a new request, in the "waiting" state, to the tail of the queue."""
        self.waiting_queue.append(
            dict(
                id=request_id,
                prompt=prompt,
                max_tokens=max_tokens,
                generated_tokens=0,
                status="waiting",
            )
        )

    def schedule_step(self) -> List[str]:
        """Run one decode iteration and return the ids of requests that finished in it."""
        free_slots = self.max_batch_size - len(self.active_requests)
        while free_slots > 0 and self.waiting_queue:
            incoming = self.waiting_queue.pop(0)
            incoming["status"] = "running"
            self.active_requests[incoming["id"]] = incoming
            free_slots -= 1

        finished_now: List[str] = []
        # Iterate over a snapshot because finished requests are deleted mid-loop.
        for state in tuple(self.active_requests.values()):
            state["generated_tokens"] += 1
            if state["generated_tokens"] < state["max_tokens"]:
                continue
            rid = state["id"]
            self.completed.append(rid)
            del self.active_requests[rid]
            finished_now.append(rid)
        return finished_now

    def get_stats(self) -> dict:
        """Return counts of active, waiting and completed requests."""
        return dict(
            active=len(self.active_requests),
            waiting=len(self.waiting_queue),
            completed=len(self.completed),
        )
scheduler = ContinuousBatchScheduler(max_batch_size=4)
scheduler.submit("req-0", "什么是机器学习", 10)
scheduler.submit("req-1", "解释量子计算", 15)
scheduler.submit("req-2", "用Python排序", 12)
for step in range(20):
    completed = scheduler.schedule_step()
    stats = scheduler.get_stats()
    # Log every 5th step AND every step where a request finished. With these
    # token budgets requests finish on steps 9, 11 and 14 (none a multiple of
    # 5), so logging only every 5th step would never show a completion.
    if step % 5 == 0 or completed:
        print(f"Step {step}: {stats}, completed: {completed}")
    if stats["active"] == 0 and stats["waiting"] == 0:
        break
print(f"最终统计: {scheduler.get_stats()}")
vLLM配置
from vllm import LLM, SamplingParams
# Offline batched generation with vLLM. The engine batches internally, so all
# prompts are handed to a single generate() call.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    # Scheduler limits (see vLLM engine arguments): per-step token budget and
    # maximum number of concurrently scheduled sequences.
    max_num_batched_tokens=8192,
    max_num_seqs=64,
    gpu_memory_utilization=0.9,
    enable_chunked_prefill=True,
)

sampling_params = SamplingParams(temperature=0.7, max_tokens=256)
prompts = [f"请用Python实现第{i+1}个排序算法" for i in range(10)]

outputs = llm.generate(prompts, sampling_params)
for request_output in outputs:
    first_completion = request_output.outputs[0]
    print(first_completion.text[:50])
性能对比
| 策略 | 吞吐量 | 延迟 | 适用场景 |
|---|---|---|---|
| 无批处理 | 低 | 最低 | 实时交互 |
| 静态批处理 | 高 | 较高 | 离线处理 |
| 动态批处理 | 中高 | 中等 | 混合负载 |
| 连续批处理 | 最高 | 低 | 生产环境 |
选择合适的批处理策略需要权衡吞吐量和延迟需求。实时交互场景优先考虑延迟,离线处理场景优先考虑吞吐量。