← 返回首页
🤖

LLM推理架构:批处理、流式与KV Cache优化

📂 architecture ⏱ 5 min 846 words

LLM推理架构:批处理、流式与KV Cache优化

LLM推理的核心挑战

大语言模型推理面临独特挑战:自回归生成需要串行计算、KV Cache占用大量显存、长上下文增加计算复杂度、批处理效率影响服务吞吐量。推理优化需要在延迟、吞吐量和成本之间找到最佳平衡点。

# LLM推理请求管理
import asyncio
from collections import deque
from dataclasses import dataclass, field
from typing import AsyncGenerator, List, Optional, Union

@dataclass
class LLMRequest:
    """A single text-generation request submitted to the inference layer.

    Carries the prompt plus sampling options; ``stop`` receives a fresh
    list for every instance.
    """

    # Caller-assigned identifier echoed back in responses.
    request_id: str
    # Input text to generate from.
    prompt: str
    # Upper bound on the number of tokens to generate.
    max_tokens: int = 1024
    # Sampling temperature.
    temperature: float = 1.0
    # Nucleus (top-p) sampling probability mass.
    top_p: float = 0.9
    # True when the caller wants incremental (streamed) output.
    stream: bool = False
    # Stop strings; default_factory avoids a mutable default shared across instances.
    stop: List[str] = field(default_factory=list)

@dataclass
class LLMResponse:
    """Result of a completed (non-streaming) generation request."""

    # Identifier of the request this response answers.
    request_id: str
    # Generated output text.
    text: str
    # Number of generated tokens as reported by the producer.
    tokens_generated: int
    # Time spent producing the response, in milliseconds.
    latency_ms: float
    # Free-form usage accounting; every instance gets its own dict.
    usage: dict = field(default_factory=dict)

class LLMInferenceEngine:
    """Toy LLM inference front end with a prompt-prefix result cache.

    ``generate`` dispatches to a streaming path (an async generator of token
    dicts) or a non-streaming path that returns an ``LLMResponse``. The model
    calls are placeholders (``_full_generation`` / ``_continue_generation``).
    """

    def __init__(self, model_name: str, max_batch_size: int = 8):
        self.model_name = model_name
        self.max_batch_size = max_batch_size
        # Not consumed by any method of this class; kept for callers/subclasses.
        self.request_queue = deque()
        # hash(prompt[:100]) -> generated text. Grows without bound.
        self.kv_cache = {}

    async def generate(
        self, request: "LLMRequest"
    ) -> Union["LLMResponse", AsyncGenerator[dict, None]]:
        """Run one generation request.

        Returns:
            If ``request.stream`` is true, an async generator yielding
            ``{"token": str, "finished": bool}`` dicts (consume it with
            ``async for``). Otherwise, an ``LLMResponse``.
        """
        if request.stream:
            # _stream_generate is an async generator function: calling it
            # creates the generator, and a generator cannot be awaited.
            return self._stream_generate(request)
        return await self._batch_generate(request)

    async def _batch_generate(self, request: "LLMRequest") -> "LLMResponse":
        """Produce a complete response, reusing cached output for a known prompt prefix."""
        import time

        # perf_counter is monotonic; time.time() can jump if the wall clock changes.
        start = time.perf_counter()

        # Cache keyed on the first 100 prompt characters only.
        # NOTE(review): prompts sharing a 100-char prefix share one entry, and
        # the cached value is generated text rather than real K/V tensors.
        cache_key = hash(request.prompt[:100])
        cached = self.kv_cache.get(cache_key)

        if cached is not None:
            text = self._continue_generation(cached, request.max_tokens)
        else:
            text = self._full_generation(request.prompt, request.max_tokens)
            self.kv_cache[cache_key] = text

        latency_ms = (time.perf_counter() - start) * 1000

        return LLMResponse(
            request_id=request.request_id,
            text=text,
            # Whitespace-separated word counts stand in for token counts here.
            tokens_generated=len(text.split()),
            latency_ms=latency_ms,
            usage={"prompt_tokens": len(request.prompt.split())},
        )

    async def _stream_generate(self, request: "LLMRequest") -> AsyncGenerator[dict, None]:
        """Yield ``max_tokens`` placeholder tokens one at a time; the last has finished=True."""
        last = request.max_tokens - 1
        for i in range(request.max_tokens):
            yield {"token": f"token_{i}", "finished": i == last}
            # Simulated per-token decode latency; also hands control back to the loop.
            await asyncio.sleep(0.01)

    def _full_generation(self, prompt: str, max_tokens: int) -> str:
        """Placeholder for a full forward pass over ``prompt``."""
        return f"Generated text from {prompt[:50]}..."

    def _continue_generation(self, cached, max_tokens: int) -> str:
        """Placeholder for decoding that resumes from cached state."""
        return "Continued from cache..."

KV Cache管理策略

KV Cache是LLM推理的关键优化:缓存已计算token的Key和Value张量,使每一步解码只需为新token计算K/V,避免对整个前缀重复计算。但KV Cache占用的显存随序列长度(以及批大小、模型层数)线性增长,需要精细的内存管理策略。

# KV Cache管理器
from collections import OrderedDict
import torch

class KVCacheManager:
    """Token-budgeted cache of per-prompt KV data with LRU or LFU eviction.

    Args:
        max_cache_size: Token budget. The sum of ``token_count`` over all
            cached entries never exceeds it.
        eviction_policy: ``"lru"`` (default) or ``"lfu"``. Any other value
            falls back to LRU.
    """

    def __init__(self, max_cache_size: int, eviction_policy: str = "lru"):
        self.max_cache_size = max_cache_size
        self.eviction_policy = eviction_policy
        # key -> kv_data, ordered from least to most recently used.
        self.cache = OrderedDict()
        self.total_tokens = 0
        # Per-entry bookkeeping so eviction releases the right token budget and
        # LFU does not depend on the layout of the caller's kv_data.
        self._entry_tokens = {}
        self._access_counts = {}

    def get(self, prompt_hash: str) -> Optional[dict]:
        """Return the cached kv_data for ``prompt_hash`` (or None) and mark it as used."""
        if prompt_hash not in self.cache:
            return None
        self.cache.move_to_end(prompt_hash)
        self._access_counts[prompt_hash] = self._access_counts.get(prompt_hash, 0) + 1
        return self.cache[prompt_hash]

    def put(self, prompt_hash: str, kv_data: dict, token_count: int):
        """Insert an entry, evicting others until it fits in the token budget.

        An existing key is only refreshed (moved to most recent); its data is
        not replaced. An entry larger than the whole budget is not cached.
        """
        if prompt_hash in self.cache:
            self.cache.move_to_end(prompt_hash)
            return
        if token_count > self.max_cache_size:
            # Even an empty cache could not hold it; skip rather than evicting
            # every entry for something that cannot be stored.
            return

        while self.cache and self.total_tokens + token_count > self.max_cache_size:
            self._evict()

        self.cache[prompt_hash] = kv_data
        self._entry_tokens[prompt_hash] = token_count
        self._access_counts[prompt_hash] = 0
        self.total_tokens += token_count

    def _evict(self):
        """Remove one entry according to the eviction policy and release its tokens."""
        if self.eviction_policy == "lfu":
            # min() returns the first minimum in iteration order, so ties go to
            # the least recently used entry.
            key = min(self.cache, key=lambda k: self._access_counts.get(k, 0))
            del self.cache[key]
        else:
            # "lru" and any unrecognized policy: drop the oldest entry.
            key, _ = self.cache.popitem(last=False)
        self.total_tokens -= self._entry_tokens.pop(key, 0)
        self._access_counts.pop(key, None)

    def get_memory_usage(self) -> dict:
        """Report token usage, entry count, capacity and utilization (0.0 when capacity is 0)."""
        return {
            "total_tokens": self.total_tokens,
            "cache_entries": len(self.cache),
            "max_capacity": self.max_cache_size,
            "utilization": (
                self.total_tokens / self.max_cache_size if self.max_cache_size else 0.0
            ),
        }

# PagedAttention实现
class PagedKVCache:
    """KV cache page allocator modeled on OS paging (the PagedAttention idea).

    Pages are integer indices in ``range(total_pages)``. ``page_table`` maps
    each sequence id to the pages it owns, in allocation order.
    """

    def __init__(self, page_size: int = 16, total_pages: int = 1000):
        self.page_size = page_size
        self.total_pages = total_pages
        # Private name on purpose: an instance attribute called ``free_pages``
        # would shadow the ``free_pages()`` method and make it uncallable.
        self._free_list = list(range(total_pages))
        self.page_table = {}  # sequence_id -> [page_indices]

    @property
    def num_free_pages(self) -> int:
        """Number of pages not owned by any sequence."""
        return len(self._free_list)

    def allocate_pages(self, sequence_id: str, num_pages: int) -> list:
        """Hand ``num_pages`` free pages to ``sequence_id`` and return them.

        If the sequence already owns pages, the new pages are appended to its
        list (a growing sequence) instead of replacing and leaking the old ones.

        Raises:
            ValueError: If ``num_pages`` is negative.
            MemoryError: If fewer than ``num_pages`` pages are free.
        """
        if num_pages < 0:
            raise ValueError("num_pages must be non-negative")
        if len(self._free_list) < num_pages:
            raise MemoryError("Not enough KV cache pages")

        allocated = self._free_list[:num_pages]
        del self._free_list[:num_pages]
        self.page_table.setdefault(sequence_id, []).extend(allocated)
        return allocated

    def free_pages(self, sequence_id: str):
        """Return every page owned by ``sequence_id`` to the free list (no-op if unknown)."""
        pages = self.page_table.pop(sequence_id, None)
        if pages:
            self._free_list.extend(pages)

    def get_utilization(self) -> float:
        """Fraction of pages in use; 0.0 when the cache has no pages at all."""
        if not self.total_pages:
            return 0.0
        used = self.total_pages - len(self._free_list)
        return used / self.total_pages

连续批处理与调度

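连续批处理(Continuous Batching)在每一步解码(token迭代)粒度上调度:某个请求生成结束后立即离开批次,等待中的新请求随即补入空位,因此不同长度的请求可以动态加入和离开批次。静态批处理必须等整批中最长的请求完成后才能接纳新请求,相比之下,连续批处理大幅提升了GPU利用率。vLLM和TensorRT-LLM(称为In-flight Batching)都实现了这一技术。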

# 连续批处理调度器
from typing import Dict, Set
import heapq

class ContinuousBatchScheduler:
    """Admission scheduler for continuous batching.

    Pending requests wait in a min-heap ordered by ``max_tokens`` (shorter
    requests first, FIFO among equals). ``schedule_next_batch`` admits them
    into ``active_batch`` while both the slot limit (``max_batch_size``) and
    the token budget (``max_tokens_per_batch``, summed over every active
    request's ``max_tokens``) allow. Finished requests leave through
    ``update_progress``, freeing capacity for later calls.
    """

    def __init__(self, max_batch_size: int, max_tokens_per_batch: int):
        self.max_batch_size = max_batch_size
        self.max_tokens = max_tokens_per_batch
        # Heap of (max_tokens, submission_seq, request).
        self.pending_queue = []
        self.active_batch: Dict[str, dict] = {}
        self.completed: Set[str] = set()
        # Tie-breaker: keeps equal priorities FIFO and stops heapq from ever
        # comparing two request objects (which would raise TypeError).
        self._seq = 0

    def submit_request(self, request: "LLMRequest"):
        """Queue a request; requests with smaller ``max_tokens`` are scheduled first."""
        heapq.heappush(self.pending_queue, (request.max_tokens, self._seq, request))
        self._seq += 1

    def schedule_next_batch(self) -> List["LLMRequest"]:
        """Admit pending requests into the active batch while capacity allows.

        Requests that do not fit stay queued for a later call.

        Returns:
            The requests newly admitted by this call, in admission order.
        """
        free_slots = self.max_batch_size - len(self.active_batch)
        tokens_in_use = sum(e["request"].max_tokens for e in self.active_batch.values())
        token_budget = self.max_tokens - tokens_in_use

        batch = []
        while self.pending_queue and len(batch) < free_slots:
            request = self.pending_queue[0][-1]
            if request.max_tokens > token_budget:
                # The heap is ordered by max_tokens, so nothing behind it fits either.
                break
            heapq.heappop(self.pending_queue)
            batch.append(request)
            self.active_batch[request.request_id] = {
                "request": request,
                "tokens_generated": 0,
            }
            token_budget -= request.max_tokens

        return batch

    def update_progress(self, request_id: str, tokens_generated: int, generated_text: str = ""):
        """Record newly generated tokens and retire the request once it is done.

        A request completes when its generated-token total reaches
        ``max_tokens`` or when ``generated_text`` contains one of its non-empty
        stop strings. Unknown request ids are ignored.

        Args:
            request_id: Id of an active request.
            tokens_generated: Tokens produced since the last update.
            generated_text: Optional text produced so far; used only for
                stop-string detection.
        """
        entry = self.active_batch.get(request_id)
        if entry is None:
            return

        entry["tokens_generated"] += tokens_generated
        request = entry["request"]
        hit_stop = any(stop and stop in generated_text for stop in request.stop)
        if entry["tokens_generated"] >= request.max_tokens or hit_stop:
            self.completed.add(request_id)
            del self.active_batch[request_id]

    def get_batch_status(self) -> dict:
        """Counts of pending/active/completed requests plus slot occupancy."""
        return {
            "pending": len(self.pending_queue),
            "active": len(self.active_batch),
            "completed": len(self.completed),
            # Fraction of batch slots occupied (0.0 when there are no slots).
            "gpu_utilization": (
                len(self.active_batch) / self.max_batch_size if self.max_batch_size else 0.0
            ),
        }

# 投机解码
class SpeculativeDecoder:
    """Speculative decoding sketch: a draft model proposes tokens, the target verifies them.

    Tokens are whitespace-separated strings here. The model objects are
    stored, but the placeholder helpers below do not call them.
    """

    def __init__(self, draft_model, target_model, gamma: int = 5):
        if gamma < 1:
            raise ValueError("gamma must be at least 1")
        self.draft_model = draft_model
        self.target_model = target_model
        self.gamma = gamma  # speculation length: tokens proposed per draft step

    def generate(self, prompt: str, max_tokens: int) -> str:
        """Return the prompt tokens plus up to ``max_tokens`` generated tokens, space-joined."""
        tokens = prompt.split()
        produced = 0

        while produced < max_tokens:
            # Never draft past the requested length (covers max_tokens % gamma).
            step = min(self.gamma, max_tokens - produced)
            draft_tokens = self._draft_generate(tokens, step)

            # Target-model verification of the drafted suffix.
            accepted = self._verify_tokens(tokens + draft_tokens, step)[:step]
            if not accepted:
                # Nothing accepted: stop instead of looping forever.
                # NOTE(review): full speculative decoding would append the
                # target model's own next token here.
                break

            tokens.extend(accepted)
            produced += len(accepted)

        return " ".join(tokens)

    def _draft_generate(self, tokens: list, gamma: int) -> list:
        """Placeholder draft step: propose ``gamma`` dummy tokens."""
        return [f"draft_{i}" for i in range(gamma)]

    def _verify_tokens(self, tokens: list, num_draft: Optional[int] = None) -> list:
        """Placeholder verification that accepts every drafted token.

        Args:
            tokens: Context followed by the drafted tokens.
            num_draft: How many trailing entries of ``tokens`` are drafts;
                defaults to ``self.gamma``.

        Returns:
            The accepted draft tokens.
        """
        n = self.gamma if num_draft is None else num_draft
        return tokens[-n:] if n > 0 else []

流式响应架构

流式响应是LLM服务的关键特性,需要支持Server-Sent Events(SSE)、WebSocket和gRPC Streaming。架构需要处理背压、连接管理和错误恢复。

# 流式响应处理器
import asyncio
from typing import AsyncGenerator, Callable
import json

class StreamingHandler:
    """Formats generated tokens as Server-Sent Events, grouping ``buffer_size`` tokens per event."""

    def __init__(self, buffer_size: int = 10):
        self.buffer_size = buffer_size
        # Kept for compatibility. stream_tokens buffers per call so concurrent
        # streams on one handler cannot mix each other's tokens.
        self.token_buffer = []
        self.subscribers = []

    async def stream_tokens(self, request: "LLMRequest") -> AsyncGenerator[str, None]:
        """Yield SSE events for ``request.max_tokens`` placeholder tokens, then a finished event."""
        buffer = []  # local to this stream

        for i in range(request.max_tokens):
            buffer.append(f"token_{i}")

            # Emit once a full group of tokens has accumulated.
            if len(buffer) >= self.buffer_size:
                yield self._format_sse(buffer)
                buffer = []

            await asyncio.sleep(0.01)

        # Flush any partial group.
        if buffer:
            yield self._format_sse(buffer)

        # Completion signal.
        yield self._format_sse([], finished=True)

    def _format_sse(self, tokens: list, finished: bool = False) -> str:
        """Encode one SSE ``data:`` event carrying ``tokens`` and the ``finished`` flag."""
        data = {
            "tokens": tokens,
            "finished": finished,
        }
        return f"data: {json.dumps(data)}\n\n"

class WebSocketStreamingManager:
    """Serves token streams over WebSocket connections, tracked by request id."""

    def __init__(self):
        self.connections = {}

    async def handle_connection(self, websocket, request_id: str):
        """Register ``websocket``, answer each JSON message with a token stream, then deregister.

        Assumes ``websocket`` supports ``async for`` over incoming messages and
        an awaitable ``send`` (websockets-library style) — TODO confirm.
        """
        self.connections[request_id] = websocket

        try:
            async for message in websocket:
                request = json.loads(message)
                async for chunk in self.stream_response(request):
                    await websocket.send(chunk)
        finally:
            # Remove only our own registration. A newer connection may have
            # reused request_id, or the entry may already be gone; a KeyError
            # here would also mask any exception raised in the loop above.
            if self.connections.get(request_id) is websocket:
                del self.connections[request_id]

    async def stream_response(self, request: dict) -> AsyncGenerator:
        """Yield one JSON-encoded ``{"token": ...}`` message per generated token."""
        for token in self._generate_tokens(request):
            yield json.dumps({"token": token})

    def _generate_tokens(self, request: dict):
        """Placeholder token source."""
        return ["hello", " ", "world"]

推理成本优化

LLM推理成本主要来自GPU算力。优化策略包括:量化(INT8/INT4)、模型蒸馏、批处理合并、KV Cache复用、请求优先级调度。监控成本指标确保服务经济性。

# 推理成本管理
@dataclass
class CostMetrics:
    """Summary of the estimated GPU cost for a set of inference requests."""

    # Estimated GPU time consumed, in hours.
    gpu_hours: float
    # Total number of tokens counted across the requests.
    tokens_processed: int
    # Cost attributed to a single token.
    cost_per_token: float
    # Overall cost for the requests.
    total_cost: float

class CostOptimizer:
    """Rough GPU-cost and batch-size estimates for LLM serving.

    Args:
        gpu_cost_per_hour: Price of one GPU-hour (default 3.0).
        tokens_per_hour: Assumed throughput of one GPU in tokens per hour
            (default 1,000,000, the value previously hard-coded).

    Raises:
        ValueError: If ``tokens_per_hour`` is not positive.
    """

    def __init__(self, gpu_cost_per_hour: float = 3.0, tokens_per_hour: float = 1_000_000):
        if tokens_per_hour <= 0:
            raise ValueError("tokens_per_hour must be positive")
        self.gpu_cost_per_hour = gpu_cost_per_hour
        self.tokens_per_hour = tokens_per_hour

    def calculate_cost(self, requests: list) -> "CostMetrics":
        """Estimate cost from request dicts; a missing ``"tokens"`` key counts as 0."""
        total_tokens = sum(r.get("tokens", 0) for r in requests)
        estimated_hours = total_tokens / self.tokens_per_hour

        return CostMetrics(
            gpu_hours=estimated_hours,
            tokens_processed=total_tokens,
            cost_per_token=self.gpu_cost_per_hour / self.tokens_per_hour,
            total_cost=estimated_hours * self.gpu_cost_per_hour,
        )

    def optimize_batching(self, requests: list,
                          target_latency_ms: float,
                          max_batch_size: int = 32) -> dict:
        """Choose a batch size that keeps estimated latency within a target.

        Latency is modeled as linear in batch size: ``avg_latency * batch``.

        Args:
            requests: Dicts with an optional ``"latency_ms"`` key (missing = 0).
            target_latency_ms: Latency budget per batch.
            max_batch_size: Upper bound on the suggested batch size.

        Returns:
            Dict with ``optimal_batch_size`` (always >= 1),
            ``estimated_latency_ms`` and ``throughput_improvement``.
        """
        latencies = [r.get("latency_ms", 0) for r in requests]
        avg_latency = sum(latencies) / len(latencies) if latencies else 0

        # Floor the average at 1 ms so an empty or all-zero sample cannot divide by zero.
        by_latency = int(target_latency_ms / max(avg_latency, 1))
        # A batch holds at least one request, even when one request alone
        # already exceeds the target.
        optimal_batch = max(1, min(max_batch_size, by_latency))

        return {
            "optimal_batch_size": optimal_batch,
            "estimated_latency_ms": avg_latency * optimal_batch,
            "throughput_improvement": optimal_batch,
        }