LLM推理架构:批处理、流式与KV Cache优化
LLM推理架构:批处理、流式与KV Cache优化
LLM推理的核心挑战
大语言模型推理面临独特挑战:自回归生成需要串行计算、KV Cache占用大量显存、长上下文增加计算复杂度、批处理效率影响服务吞吐量。推理优化需要在延迟、吞吐量和成本之间找到最佳平衡点。
# LLM推理请求管理
import asyncio
from collections import deque
from dataclasses import dataclass, field
from typing import AsyncGenerator, List, Optional, Union
@dataclass
class LLMRequest:
    """A single text-generation request.

    The class body is indented so the definition parses; fields, order and
    defaults are unchanged.
    """

    request_id: str  # caller-supplied identifier, copied into the response
    prompt: str  # input text
    max_tokens: int = 1024  # upper bound on the number of generated tokens
    temperature: float = 1.0  # sampling parameter (not read by the code in this file)
    top_p: float = 0.9  # sampling parameter (not read by the code in this file)
    stream: bool = False  # True selects the streaming generation path
    stop: List[str] = field(default_factory=list)  # stop sequences; fresh list per instance
@dataclass
class LLMResponse:
    """Result of a completed (non-streaming) generation.

    The class body is indented so the definition parses; fields, order and
    defaults are unchanged.
    """

    request_id: str  # identifier of the request this response answers
    text: str  # generated text
    tokens_generated: int  # number of tokens in ``text``
    latency_ms: float  # time spent generating, in milliseconds
    usage: dict = field(default_factory=dict)  # extra accounting, e.g. "prompt_tokens"
class LLMInferenceEngine:
    """Simulated inference engine with a prompt-prefix cache and streaming.

    The two ``_*_generation`` helpers are placeholders that return fixed
    strings; no model is loaded or executed here.
    """

    def __init__(self, model_name: str, max_batch_size: int = 8):
        """Store the configuration and create an empty queue and cache."""
        self.model_name = model_name
        self.max_batch_size = max_batch_size
        self.request_queue = deque()
        # Maps the first 100 characters of a prompt to the text generated for it.
        self.kv_cache = {}

    async def generate(
        self, request: "LLMRequest"
    ) -> "Union[LLMResponse, AsyncGenerator]":
        """Dispatch a request to the streaming or non-streaming path.

        Returns:
            For ``request.stream`` true, an async generator of
            ``{"token": ..., "finished": ...}`` dicts to consume with
            ``async for``; otherwise a complete ``LLMResponse``.
        """
        if request.stream:
            # An async generator object is not awaitable (awaiting it raises
            # TypeError), so hand it back to the caller un-awaited.
            return self._stream_generate(request)
        return await self._batch_generate(request)

    async def _batch_generate(self, request: "LLMRequest") -> "LLMResponse":
        """Generate the whole completion at once, reusing cached output."""
        import time

        # perf_counter is monotonic, so the measured latency cannot go
        # negative or jump when the wall clock is adjusted.
        started = time.perf_counter()

        # Key on the prefix itself rather than on hash(prefix): two different
        # prefixes with equal hashes would otherwise share one cache slot.
        cache_key = request.prompt[:100]
        cached = self.kv_cache.get(cache_key)
        if cached is not None:  # an empty cached value still counts as a hit
            text = self._continue_generation(cached, request.max_tokens)
        else:
            text = self._full_generation(request.prompt, request.max_tokens)
            self.kv_cache[cache_key] = text

        latency_ms = (time.perf_counter() - started) * 1000
        return LLMResponse(
            request_id=request.request_id,
            text=text,
            # Whitespace-separated word counts stand in for token counts.
            tokens_generated=len(text.split()),
            latency_ms=latency_ms,
            usage={"prompt_tokens": len(request.prompt.split())},
        )

    async def _stream_generate(self, request: "LLMRequest") -> AsyncGenerator:
        """Yield ``request.max_tokens`` placeholder tokens one at a time.

        The last yielded dict has ``"finished": True``.
        """
        last_index = request.max_tokens - 1
        for i in range(request.max_tokens):
            yield {"token": f"token_{i}", "finished": i == last_index}
            await asyncio.sleep(0.01)

    def _full_generation(self, prompt: str, max_tokens: int) -> str:
        """Placeholder for generation from scratch; ``max_tokens`` is unused."""
        return f"Generated text from {prompt[:50]}..."

    def _continue_generation(self, cached, max_tokens: int) -> str:
        """Placeholder for generation resumed from a cache hit; arguments are unused."""
        return "Continued from cache..."
KV Cache管理策略
KV Cache是LLM推理的关键优化,缓存已计算的Key和Value张量,避免重复计算。但KV Cache占用显存随序列长度线性增长,需要精细的内存管理策略。
# KV Cache管理器
from collections import OrderedDict
import torch
class KVCacheManager:
    """Token-budgeted cache with LRU (default) or LFU eviction.

    Entries are stored in ``self.cache`` in recency order (oldest first).
    Per-entry token and access counts are tracked internally so that eviction
    can release budget and the LFU policy has data to work with.
    """

    def __init__(self, max_cache_size: int, eviction_policy: str = "lru"):
        """Create an empty cache.

        Args:
            max_cache_size: Maximum total ``token_count`` across all entries.
            eviction_policy: ``"lfu"`` evicts the least-accessed entry; any
                other value (including the default ``"lru"``) evicts the
                least recently used entry.
        """
        self.max_cache_size = max_cache_size
        self.eviction_policy = eviction_policy
        self.cache = OrderedDict()
        self.total_tokens = 0
        self._token_counts = {}  # key -> token_count charged for that entry
        self._access_counts = {}  # key -> number of successful get() calls

    def get(self, prompt_hash: str) -> Optional[dict]:
        """Return the cached value and mark it as recently used, or ``None``."""
        if prompt_hash not in self.cache:
            return None
        self.cache.move_to_end(prompt_hash)
        self._access_counts[prompt_hash] = self._access_counts.get(prompt_hash, 0) + 1
        return self.cache[prompt_hash]

    def put(self, prompt_hash: str, kv_data: dict, token_count: int):
        """Insert an entry, evicting others until it fits in the budget.

        An existing key is only refreshed in recency order; its data is left
        as is. An entry larger than the whole budget is not cached, and
        nothing is evicted for it.
        """
        if prompt_hash in self.cache:
            self.cache.move_to_end(prompt_hash)
            return
        if token_count > self.max_cache_size:
            # It can never fit; evicting everything would only lose useful data.
            return
        while self.cache and self.total_tokens + token_count > self.max_cache_size:
            self._evict()
        self.cache[prompt_hash] = kv_data
        self._token_counts[prompt_hash] = token_count
        self._access_counts[prompt_hash] = 0
        self.total_tokens += token_count

    def _evict(self):
        """Remove one entry according to the policy and release its tokens."""
        if not self.cache:
            return
        if self.eviction_policy == "lfu":
            # min() keeps the first minimum, so ties go to the least recently
            # used entry because the mapping is kept in recency order.
            key = min(self.cache, key=lambda k: self._access_counts.get(k, 0))
            del self.cache[key]
        else:
            key, _ = self.cache.popitem(last=False)
        # Without this, total_tokens would only ever grow and put() would
        # keep evicting until the cache was empty.
        self.total_tokens -= self._token_counts.pop(key, 0)
        self._access_counts.pop(key, None)

    def get_memory_usage(self) -> dict:
        """Return token usage, entry count, capacity and fill ratio."""
        if self.max_cache_size:
            utilization = self.total_tokens / self.max_cache_size
        else:
            utilization = 0.0  # avoid dividing by a zero capacity
        return {
            "total_tokens": self.total_tokens,
            "cache_entries": len(self.cache),
            "max_capacity": self.max_cache_size,
            "utilization": utilization,
        }
# PagedAttention实现
class PagedKVCache:
    """KV cache bookkeeping modelled on operating-system paging.

    Only page indices are tracked here; no tensor storage is allocated.
    """

    def __init__(self, page_size: int = 16, total_pages: int = 1000):
        """Create a pool of ``total_pages`` free pages."""
        self.page_size = page_size
        self.total_pages = total_pages
        # Kept under a private name: an instance attribute called
        # ``free_pages`` would shadow the ``free_pages`` method and make it
        # uncallable on instances.
        self._free_page_ids = list(range(total_pages))
        self.page_table = {}  # sequence_id -> [page_indices]

    def allocate_pages(self, sequence_id: str, num_pages: int) -> list:
        """Reserve ``num_pages`` pages for a sequence.

        Pages are appended to any the sequence already holds.

        Returns:
            The newly allocated page indices.

        Raises:
            ValueError: If ``num_pages`` is negative.
            MemoryError: If fewer than ``num_pages`` pages are free.
        """
        if num_pages < 0:
            raise ValueError("num_pages must be non-negative")
        if len(self._free_page_ids) < num_pages:
            raise MemoryError("Not enough KV cache pages")
        allocated = self._free_page_ids[:num_pages]
        del self._free_page_ids[:num_pages]
        # Extend rather than overwrite so pages from an earlier allocation
        # for the same sequence are not leaked.
        self.page_table.setdefault(sequence_id, []).extend(allocated)
        return allocated

    def free_pages(self, sequence_id: str):
        """Return all pages held by a sequence to the pool; unknown ids are ignored."""
        pages = self.page_table.pop(sequence_id, None)
        if pages:
            self._free_page_ids.extend(pages)

    def get_utilization(self) -> float:
        """Return the fraction of pages in use (0.0 for an empty pool)."""
        if self.total_pages <= 0:
            return 0.0
        used = self.total_pages - len(self._free_page_ids)
        return used / self.total_pages
连续批处理与调度
连续批处理(Continuous Batching)允许不同长度的请求动态加入和离开批次,相比静态批处理大幅提升GPU利用率。vLLM和TensorRT-LLM都实现了这一技术。
# 连续批处理调度器
from typing import Dict, Set
import heapq
class ContinuousBatchScheduler:
    """Shortest-request-first scheduler with a per-batch token budget."""

    def __init__(self, max_batch_size: int, max_tokens_per_batch: int):
        """Create an empty scheduler.

        Args:
            max_batch_size: Maximum number of requests returned per batch.
            max_tokens_per_batch: Maximum sum of ``max_tokens`` per batch.
        """
        self.max_batch_size = max_batch_size
        self.max_tokens = max_tokens_per_batch
        self.pending_queue = []  # heap of (max_tokens, sequence_no, request)
        self.active_batch: Dict[str, dict] = {}
        self.completed: Set[str] = set()
        self._next_seq = 0  # tie-breaker so the heap never compares requests

    def submit_request(self, request: "LLMRequest"):
        """Queue a request; smaller ``max_tokens`` is scheduled first."""
        # heapq is a min-heap, so the positive max_tokens puts short requests
        # first. The sequence number keeps equal-priority requests in FIFO
        # order and stops tuple comparison from reaching the request object,
        # which does not define ordering.
        heapq.heappush(
            self.pending_queue, (request.max_tokens, self._next_seq, request)
        )
        self._next_seq += 1

    def schedule_next_batch(self) -> "List[LLMRequest]":
        """Move pending requests into the active set and return them.

        Requests are taken shortest first until the batch is full or the next
        request would exceed the token budget. Requests that do not fit stay
        queued. A request whose ``max_tokens`` alone exceeds the budget is
        never scheduled and remains pending.
        """
        batch = []
        current_tokens = 0
        while self.pending_queue and len(batch) < self.max_batch_size:
            request = self.pending_queue[0][-1]
            if current_tokens + request.max_tokens > self.max_tokens:
                # The heap is ordered by max_tokens, so nothing behind this
                # request can fit either. Leave it queued instead of popping
                # and losing it.
                break
            heapq.heappop(self.pending_queue)
            batch.append(request)
            self.active_batch[request.request_id] = {
                "request": request,
                "tokens_generated": 0,
            }
            current_tokens += request.max_tokens
        return batch

    def update_progress(
        self, request_id: str, tokens_generated: int, generated_text: str = ""
    ):
        """Record newly generated tokens and retire the request when done.

        Args:
            request_id: Identifier of an active request; others are ignored.
            tokens_generated: Number of tokens produced since the last call.
            generated_text: Text to scan for the request's stop sequences.
                With the default empty string only the token limit can
                complete a request.
        """
        entry = self.active_batch.get(request_id)
        if entry is None:
            return
        entry["tokens_generated"] += tokens_generated
        request = entry["request"]
        hit_limit = entry["tokens_generated"] >= request.max_tokens
        # Empty stop strings are skipped: "" is a substring of every text and
        # would end the request immediately.
        hit_stop = any(stop and stop in generated_text for stop in request.stop)
        if hit_limit or hit_stop:
            self.completed.add(request_id)
            del self.active_batch[request_id]

    def get_batch_status(self) -> dict:
        """Return queue sizes and the active-requests-to-batch-size ratio."""
        if self.max_batch_size:
            utilization = len(self.active_batch) / self.max_batch_size
        else:
            utilization = 0.0  # avoid dividing by a zero batch size
        return {
            "pending": len(self.pending_queue),
            "active": len(self.active_batch),
            "completed": len(self.completed),
            "gpu_utilization": utilization,
        }
# 投机解码
class SpeculativeDecoder:
    """Simulated speculative decoding loop.

    The draft and verify steps are placeholders: drafts are fixed
    ``draft_<i>`` strings and verification accepts every drafted token. The
    two model objects are stored but not called.
    """

    def __init__(self, draft_model, target_model, gamma: int = 5):
        """Store the models and the speculation length.

        Args:
            draft_model: Model used for drafting (only stored).
            target_model: Model used for verification (only stored).
            gamma: Number of tokens drafted per round; must be at least 1.

        Raises:
            ValueError: If ``gamma`` is less than 1.
        """
        if gamma < 1:
            raise ValueError("gamma must be at least 1")
        self.draft_model = draft_model
        self.target_model = target_model
        self.gamma = gamma  # speculation length

    def generate(self, prompt: str, max_tokens: int) -> str:
        """Return the prompt tokens followed by up to ``max_tokens`` new tokens.

        Tokens are whitespace-separated words, joined with single spaces.
        """
        tokens = prompt.split()
        produced = 0
        while produced < max_tokens:
            # The draft model proposes gamma tokens.
            draft_tokens = self._draft_generate(tokens, self.gamma)
            # The target model verifies them.
            verified = self._verify_tokens(tokens + draft_tokens)
            if not verified:
                break  # no progress this round; stop rather than loop forever
            # Trim the last round so the output never exceeds max_tokens,
            # including when max_tokens is not a multiple of gamma.
            verified = verified[: max_tokens - produced]
            tokens.extend(verified)
            produced += len(verified)
        return " ".join(tokens)

    def _draft_generate(self, tokens: list, gamma: int) -> list:
        """Return ``gamma`` placeholder draft tokens; ``tokens`` is unused."""
        return [f"draft_{i}" for i in range(gamma)]

    def _verify_tokens(self, tokens: list) -> list:
        """Return the accepted tokens: the last ``gamma`` entries, i.e. the drafts."""
        # Slice by gamma rather than a fixed 5 so that exactly the drafted
        # tokens are returned for any speculation length.
        return tokens[-self.gamma:]
流式响应架构
流式响应是LLM服务的关键特性,需要支持Server-Sent Events(SSE)、WebSocket和gRPC Streaming。架构需要处理背压、连接管理和错误恢复。
# 流式响应处理器
import asyncio
from typing import AsyncGenerator, Callable
import json
class StreamingHandler:
    """Formats generated tokens as Server-Sent Events, flushed in groups."""

    def __init__(self, buffer_size: int = 10):
        """Store the number of tokens to collect before each flush."""
        self.buffer_size = buffer_size
        # No longer used by stream_tokens; the attribute is still created so
        # existing code that reads it keeps working.
        self.token_buffer = []
        self.subscribers = []

    async def stream_tokens(self, request: "LLMRequest") -> AsyncGenerator:
        """Yield SSE strings for ``request.max_tokens`` placeholder tokens.

        Tokens are sent in groups of ``buffer_size``, followed by any
        remainder and then a final event with ``"finished": true``.
        """
        # The buffer is local to this call. One shared on the handler would
        # mix tokens from concurrent streams and carry leftovers from an
        # abandoned stream into the next one.
        pending = []
        for i in range(request.max_tokens):
            pending.append(f"token_{i}")
            if len(pending) >= self.buffer_size:
                yield self._format_sse(pending)
                pending = []
            await asyncio.sleep(0.01)
        # Send whatever did not fill a complete group.
        if pending:
            yield self._format_sse(pending)
        # Send the completion signal.
        yield self._format_sse([], finished=True)

    def _format_sse(self, tokens: list, finished: bool = False) -> str:
        """Serialize tokens and the finished flag as one SSE ``data:`` event."""
        data = {
            "tokens": tokens,
            "finished": finished,
        }
        return f"data: {json.dumps(data)}\n\n"
class WebSocketStreamingManager:
    """Tracks open WebSocket connections and streams tokens over them."""

    def __init__(self):
        """Create an empty connection registry."""
        self.connections = {}  # request_id -> websocket

    async def handle_connection(self, websocket, request_id: str):
        """Serve one connection until its message stream ends.

        Each incoming message is parsed as JSON and answered with the chunks
        produced by ``stream_response``. The connection is unregistered on
        exit, whether normal or by exception.
        """
        self.connections[request_id] = websocket
        try:
            async for message in websocket:
                request = json.loads(message)
                async for chunk in self.stream_response(request):
                    await websocket.send(chunk)
        finally:
            # Remove the entry only if it still refers to this socket. An
            # unconditional ``del`` raises KeyError when the entry is already
            # gone (hiding the original exception) and would unregister a
            # newer connection that reused the same id.
            if self.connections.get(request_id) is websocket:
                del self.connections[request_id]

    async def stream_response(self, request: dict) -> AsyncGenerator:
        """Yield one JSON string ``{"token": ...}`` per generated token."""
        for token in self._generate_tokens(request):
            yield json.dumps({"token": token})

    def _generate_tokens(self, request: dict):
        """Placeholder token source; ``request`` is unused."""
        return ["hello", " ", "world"]
推理成本优化
LLM推理成本主要来自GPU算力。优化策略包括:量化(INT8/INT4)、模型蒸馏、批处理合并、KV Cache复用、请求优先级调度。监控成本指标确保服务经济性。
# 推理成本管理
@dataclass
class CostMetrics:
    """Cost summary for a set of processed requests.

    The class body is indented so the definition parses; fields and their
    order are unchanged.
    """

    gpu_hours: float  # estimated GPU time consumed, in hours
    tokens_processed: int  # total number of tokens across the requests
    cost_per_token: float  # cost attributed to one token
    total_cost: float  # overall cost for the requests
class CostOptimizer:
    """Estimates GPU cost and suggests a batch size for a latency target."""

    def __init__(
        self,
        gpu_cost_per_hour: float = 3.0,
        tokens_per_hour: float = 1_000_000,
    ):
        """Store the pricing assumptions.

        Args:
            gpu_cost_per_hour: Price of one GPU hour.
            tokens_per_hour: Assumed throughput of one GPU; must be positive.

        Raises:
            ValueError: If ``tokens_per_hour`` is not positive.
        """
        if tokens_per_hour <= 0:
            raise ValueError("tokens_per_hour must be positive")
        self.gpu_cost_per_hour = gpu_cost_per_hour
        self.tokens_per_hour = tokens_per_hour

    def calculate_cost(self, requests: list) -> "CostMetrics":
        """Estimate GPU hours and cost for a list of request dicts.

        Each dict contributes its ``"tokens"`` value (0 when absent).
        """
        total_tokens = sum(r.get("tokens", 0) for r in requests)
        estimated_hours = total_tokens / self.tokens_per_hour
        return CostMetrics(
            gpu_hours=estimated_hours,
            tokens_processed=total_tokens,
            cost_per_token=self.gpu_cost_per_hour / self.tokens_per_hour,
            total_cost=estimated_hours * self.gpu_cost_per_hour,
        )

    def optimize_batching(
        self,
        requests: list,
        target_latency_ms: float,
        max_batch_size: int = 32,
    ) -> dict:
        """Suggest a batch size that keeps latency within the target.

        Args:
            requests: Dicts whose ``"latency_ms"`` value (0 when absent) is
                averaged to estimate per-request latency.
            target_latency_ms: Latency budget in milliseconds.
            max_batch_size: Upper bound on the suggested batch size.

        Returns:
            A dict with ``optimal_batch_size`` (always at least 1),
            ``estimated_latency_ms`` and ``throughput_improvement``.
        """
        # Average the observed request latencies.
        latencies = [r.get("latency_ms", 0) for r in requests]
        avg_latency = sum(latencies) / len(latencies) if latencies else 0
        # Averages below 1 ms are treated as 1 ms to avoid dividing by zero.
        fitting = int(target_latency_ms / max(avg_latency, 1))
        # Never suggest a batch of zero (or fewer) requests, which happens
        # when the target is below the average latency.
        optimal_batch = max(1, min(max_batch_size, fitting))
        return {
            "optimal_batch_size": optimal_batch,
            "estimated_latency_ms": avg_latency * optimal_batch,
            "throughput_improvement": optimal_batch,
        }