← 返回首页
🧠

批处理策略

📂 llm ⏱ 3 min 462 words

---
title: "批处理策略"
description: "LLM批处理策略,包括连续批处理和动态批处理实现"
tags: ["批处理", "连续批处理", "动态批处理", "推理优化", "吞吐量"]
category: "llm"
icon: "🧠"
---

批处理策略

批处理是提升LLM推理吞吐量的关键技术。通过将多个请求组合在一起处理,可以充分利用GPU的并行计算能力,显著提高硬件利用率和系统吞吐量。不同的批处理策略适用于不同的场景。

批处理类型

静态批处理

最简单的批处理方式:等待请求凑满一个批次后统一处理。批次迟迟凑不满时,请求会在超时后返回错误;可调用 `process_pending()` 强制处理队列中剩余的请求。

import asyncio
from typing import List, Any
from dataclasses import dataclass, field
import time

@dataclass
class StaticBatchProcessor:
    """Static batching: hold requests until ``batch_size`` have queued, then run them together.

    A request whose batch does not fill within ``timeout`` seconds receives
    ``{"error": "timeout"}`` and is removed from the queue. Call
    :meth:`process_pending` to push out a partially filled batch.
    """

    batch_size: int
    timeout: float
    queue: List[Any] = field(default_factory=list)
    # Not read or written by the methods below; kept so the constructor signature is unchanged.
    results: dict = field(default_factory=dict)

    async def add_request(self, request_id: str, data: dict) -> Any:
        """Enqueue a request and wait for its result.

        Returns the processed result string, or ``{"error": "timeout"}`` if no
        batch containing this request ran within ``self.timeout`` seconds.
        """
        future = asyncio.get_running_loop().create_future()
        entry = {"id": request_id, "data": data, "future": future}
        self.queue.append(entry)

        # The request that completes a batch runs it on behalf of everyone queued.
        if len(self.queue) >= self.batch_size:
            await self._process_batch()

        try:
            return await asyncio.wait_for(future, timeout=self.timeout)
        except asyncio.TimeoutError:
            # wait_for has cancelled the future. Drop the stale entry so it does not
            # occupy a slot in a later batch, where set_result would raise
            # InvalidStateError on the cancelled future.
            self.queue = [item for item in self.queue if item is not entry]
            return {"error": "timeout"}

    async def _process_batch(self):
        """Run up to ``batch_size`` queued requests (FIFO) and resolve their futures."""
        batch = self.queue[:self.batch_size]
        self.queue = self.queue[self.batch_size:]

        for item in batch:
            future = item["future"]
            # A caller cancelled while queued has nobody left to receive a result.
            if future.done():
                continue
            future.set_result(f"processed_{item['id']}")

    async def process_pending(self):
        """Flush one (possibly partial) batch of queued requests, if any."""
        if self.queue:
            await self._process_batch()

processor = StaticBatchProcessor(batch_size=4, timeout=5.0)


async def demo():
    """Submit eight requests concurrently and wait for every result."""
    await asyncio.gather(
        *(processor.add_request(f"req-{i}", {"input": f"data-{i}"}) for i in range(8))
    )
    print("静态批处理完成")


asyncio.run(demo())

动态批处理

根据请求到达情况动态决定批次大小:当排队请求数达到 `max_batch_size`,或批次中最早的请求已等待 `max_wait_time` 时,立即发出当前批次,从而在吞吐量与等待时间之间取得平衡。

import asyncio
import time
from typing import List
from dataclasses import dataclass, field

@dataclass
class DynamicBatchProcessor:
    """Dynamic batching: dispatch when the batch is full or its oldest request has waited long enough.

    A batch is sent as soon as ``max_batch_size`` requests are queued, or
    ``max_wait_time`` seconds after the first request of the batch arrived,
    whichever happens first. A background deadline task enforces the time
    limit, so a partially filled batch is flushed even if no further request
    ever arrives.
    """

    max_batch_size: int
    max_wait_time: float
    queue: List[dict] = field(default_factory=list)
    batch_start_time: float = 0
    # Pending deadline task for the current batch window (None when no window is open).
    _timer: "asyncio.Task | None" = field(default=None, init=False, repr=False, compare=False)

    async def add_request(self, request_id: str, data: dict) -> str:
        """Enqueue a request and wait until the batch containing it has been processed."""
        future = asyncio.get_running_loop().create_future()
        self.queue.append({"id": request_id, "data": data, "future": future})

        if not self.batch_start_time:
            # First request of a new window: start the clock and arm the deadline.
            self.batch_start_time = time.time()
            self._arm_timer()

        wait_time = time.time() - self.batch_start_time
        should_process = (
            len(self.queue) >= self.max_batch_size or
            wait_time >= self.max_wait_time
        )

        if should_process:
            await self._process_batch()
            self._reset_window()

        return await future

    async def _process_batch(self):
        """Run up to ``max_batch_size`` queued requests (FIFO) and resolve their futures."""
        batch = self.queue[:self.max_batch_size]
        self.queue = self.queue[self.max_batch_size:]

        for item in batch:
            future = item["future"]
            # Skip callers that were cancelled while queued; set_result would raise.
            if not future.done():
                future.set_result(f"processed_{item['id']}")

    async def flush(self):
        """Immediately dispatch one batch of whatever is queued, if anything."""
        if self.queue:
            await self._process_batch()
            self._reset_window()

    def _arm_timer(self):
        """Schedule a flush ``max_wait_time`` seconds from now, replacing any pending one."""
        self._cancel_timer()
        self._timer = asyncio.get_running_loop().create_task(self._flush_on_deadline())

    def _cancel_timer(self):
        """Cancel the pending deadline task, if any."""
        if self._timer is not None:
            self._timer.cancel()
            self._timer = None

    def _reset_window(self):
        """Close the current batch window; open a fresh one if requests are still queued."""
        self._cancel_timer()
        if self.queue:
            # Leftover requests start a new window measured from now.
            self.batch_start_time = time.time()
            self._arm_timer()
        else:
            self.batch_start_time = 0

    async def _flush_on_deadline(self):
        """Background task: flush the current batch once ``max_wait_time`` elapses."""
        await asyncio.sleep(self.max_wait_time)
        # Detach first so flush() -> _reset_window() does not cancel this running task.
        self._timer = None
        await self.flush()

processor = DynamicBatchProcessor(max_batch_size=8, max_wait_time=0.1)


async def demo():
    """Submit five requests (fewer than max_batch_size), await them, then flush leftovers."""
    pending = (
        processor.add_request(f"req-{i}", {"input": f"data-{i}"})
        for i in range(5)
    )
    await asyncio.gather(*pending)
    await processor.flush()
    print("动态批处理完成")


asyncio.run(demo())

连续批处理

vLLM采用的连续批处理(continuous batching)策略以单个生成步为调度粒度:每一步结束后,已完成的请求立即释放其在批次中的位置,等待队列中的新请求随即补入,无需等待整批请求全部生成完毕。

import asyncio
from dataclasses import dataclass, field
from typing import List, Dict

@dataclass
class ContinuousBatchScheduler:
    """Toy iteration-level (continuous) batching scheduler.

    Each call to :meth:`schedule_step` first tops up the running set from the
    FIFO waiting queue (up to ``max_batch_size``). It then advances every
    running request by one generated token and retires those that have
    reached ``max_tokens``.
    """

    max_batch_size: int
    active_requests: Dict[str, dict] = field(default_factory=dict)
    waiting_queue: List[dict] = field(default_factory=list)
    completed: List[str] = field(default_factory=list)

    def submit(self, request_id: str, prompt: str, max_tokens: int):
        """Append a new request record to the end of the waiting queue."""
        record = dict(
            id=request_id,
            prompt=prompt,
            max_tokens=max_tokens,
            generated_tokens=0,
            status="waiting",
        )
        self.waiting_queue.append(record)

    def schedule_step(self) -> List[str]:
        """Run one decode step; return the ids of requests that finished during it."""
        # Fill whatever batch slots are free with the oldest waiting requests.
        free_slots = self.max_batch_size - len(self.active_requests)
        while free_slots > 0 and self.waiting_queue:
            admitted = self.waiting_queue.pop(0)
            admitted["status"] = "running"
            self.active_requests[admitted["id"]] = admitted
            free_slots -= 1

        # Every running request produces one token this step.
        for running in self.active_requests.values():
            running["generated_tokens"] += 1

        finished = [
            req["id"]
            for req in self.active_requests.values()
            if req["generated_tokens"] >= req["max_tokens"]
        ]
        for req_id in finished:
            del self.active_requests[req_id]
        self.completed.extend(finished)
        return finished

    def get_stats(self) -> dict:
        """Return counts of active, waiting and completed requests."""
        active, waiting, done = map(
            len, (self.active_requests, self.waiting_queue, self.completed)
        )
        return {"active": active, "waiting": waiting, "completed": done}

scheduler = ContinuousBatchScheduler(max_batch_size=4)
scheduler.submit("req-0", "什么是机器学习", 10)
scheduler.submit("req-1", "解释量子计算", 15)
scheduler.submit("req-2", "用Python排序", 12)

# Log every 5th step AND every step on which a request finishes. Sampling only
# steps 0/5/10/15 would never show a completion: with one token per step the
# three requests finish at steps 9, 11 and 14.
for step in range(20):
    completed = scheduler.schedule_step()
    stats = scheduler.get_stats()
    if step % 5 == 0 or completed:
        print(f"Step {step}: {stats}, completed: {completed}")
    if stats["active"] == 0 and stats["waiting"] == 0:
        break

print(f"最终统计: {scheduler.get_stats()}")

vLLM配置

from vllm import LLM, SamplingParams

# Engine-level batching knobs; see the vLLM engine-argument docs for full semantics.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    max_num_batched_tokens=8192,  # max tokens batched into one scheduler iteration
    max_num_seqs=64,              # max sequences per scheduler iteration
    gpu_memory_utilization=0.9,   # fraction of GPU memory the engine may use
    enable_chunked_prefill=True,  # let prefills be chunked to fit the token budget
)

sampling_params = SamplingParams(temperature=0.7, max_tokens=256)
prompts = [f"请用Python实现第{i+1}个排序算法" for i in range(10)]

# Submit all prompts in one call so the engine can batch them together.
outputs = llm.generate(prompts, sampling_params)
for result in outputs:
    first_completion = result.outputs[0]
    print(first_completion.text[:50])

性能对比

| 策略 | 吞吐量 | 延迟 | 适用场景 |
|------|--------|------|----------|
| 无批处理 | 最低 | 最低 | 实时交互 |
| 静态批处理 | 较高 | 较高 | 离线处理 |
| 动态批处理 | 中高 | 中等 | 混合负载 |
| 连续批处理 | 最高 | 较低 | 生产环境 |

选择合适的批处理策略需要权衡吞吐量和延迟需求。实时交互场景优先考虑延迟,离线处理场景优先考虑吞吐量。