LLM推理管道:优化模型部署与服务
---
title: "LLM推理管道:优化模型部署与服务"
description: "构建高效的LLM推理管道,实现低延迟、高吞吐的模型服务"
tags: ["LLM", "推理管道", "模型部署", "性能优化", "服务化"]
category: "llm"
icon: "⚡"
---
LLM推理管道:优化模型部署与服务
推理管道概述
LLM推理管道负责将训练好的模型转化为可服务的API,需要处理模型加载、请求调度、批处理优化和结果缓存等环节。
推理架构设计
1. 模型服务类
import json
from pathlib import Path
from typing import List, Dict, Optional
from dataclasses import dataclass
@dataclass
class InferenceConfig:
    """Settings for one inference service instance.

    Only ``model_path`` is required; every other field has a default.
    Field order is part of the positional constructor signature.
    """

    # Model location; LLMService.load_model prints it when loading.
    model_path: str
    # Not read by any code in this file; presumably the per-batch request
    # cap -- TODO confirm against the serving layer.
    max_batch_size: int = 32
    # LLMService.preprocess truncates prompts to this many *characters*.
    max_sequence_length: int = 2048
    # Not read by any code in this file; presumably the target device name.
    device: str = "cuda"
    # Not read by any code in this file; presumably a worker count.
    num_workers: int = 4
    # Upper bound on the number of entries in LLMService's result cache.
    cache_size: int = 1000
class LLMService:
    """Text-generation service with an in-memory result cache.

    Inference is simulated here: ``_inference`` returns a placeholder string
    per prompt, and ``load_model`` only prints progress messages.
    """

    # The annotation is quoted so the class definition does not require
    # InferenceConfig to be resolvable at definition time.
    def __init__(self, config: "InferenceConfig"):
        self.config = config
        self.model = None
        self.tokenizer = None
        # Maps cache key -> post-processed output. Relies on dict insertion
        # order to find the oldest entry when evicting.
        self.cache = {}

    def load_model(self):
        """Load the model and tokenizer (currently only prints messages)."""
        print(f"加载模型: {self.config.model_path}")
        # A real implementation would use transformers or vllm, e.g.:
        # from transformers import AutoModelForCausalLM, AutoTokenizer
        # self.model = AutoModelForCausalLM.from_pretrained(self.config.model_path)
        # self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)
        print("模型加载完成")

    def preprocess(self, prompts: List[str]) -> List[str]:
        """Truncate each prompt to ``config.max_sequence_length`` characters.

        The limit is applied to the string length, not to a token count.
        Slicing leaves shorter prompts unchanged.
        """
        limit = self.config.max_sequence_length
        return [prompt[:limit] if len(prompt) > limit else prompt for prompt in prompts]

    def postprocess(self, outputs: List[str]) -> List[str]:
        """Strip leading and trailing whitespace from every output."""
        return [output.strip() for output in outputs]

    def generate(self, prompts: List[str], **kwargs) -> List[str]:
        """Generate one output per prompt, reusing cached results.

        Args:
            prompts: Input prompts; the outputs are returned in the same order.
            **kwargs: Generation parameters. They are forwarded to
                ``_inference`` and are part of the cache key, so they must be
                JSON-serializable.

        Returns:
            The post-processed outputs, ordered like ``prompts``.
        """
        results_by_index = {}
        pending = []  # (original index, prompt, cache key) for cache misses
        for index, prompt in enumerate(prompts):
            cache_key = self._get_cache_key(prompt, kwargs)
            if cache_key in self.cache:
                results_by_index[index] = self.cache[cache_key]
            else:
                pending.append((index, prompt, cache_key))

        if pending:
            processed = self.preprocess([prompt for _, prompt, _ in pending])
            outputs = self.postprocess(self._inference(processed, **kwargs))
            for (index, _, cache_key), output in zip(pending, outputs):
                self._store_in_cache(cache_key, output)
                results_by_index[index] = output

        return [results_by_index[index] for index in sorted(results_by_index)]

    def _store_in_cache(self, cache_key, result):
        """Insert a result, evicting the oldest entries once the cache is full.

        A non-positive ``config.cache_size`` disables caching entirely.
        """
        limit = self.config.cache_size
        if limit <= 0:
            return
        if cache_key not in self.cache:
            # Without eviction a full cache would never accept new results.
            while len(self.cache) >= limit:
                del self.cache[next(iter(self.cache))]
        self.cache[cache_key] = result

    def _inference(self, prompts: List[str], **kwargs) -> List[str]:
        """Run inference (simulated: echoes the first 20 characters)."""
        return [f"生成的回复: {p[:20]}..." for p in prompts]

    def _get_cache_key(self, prompt, kwargs):
        """Return a hex digest identifying a (prompt, parameters) pair.

        The parameters are nested under their own key so that a generation
        kwarg that happens to be named ``prompt`` cannot overwrite the real
        prompt and make different prompts share a cache entry.

        Raises:
            TypeError: If ``kwargs`` holds values ``json.dumps`` cannot encode.
        """
        import hashlib

        key_str = json.dumps({"prompt": prompt, "params": kwargs}, sort_keys=True)
        # MD5 is used only as a cache fingerprint, not for security.
        return hashlib.md5(key_str.encode()).hexdigest()
2. 批处理管理器
import asyncio
from collections import deque
from typing import Callable
class BatchManager:
    """Collect requests into batches and resolve each caller's future.

    A batch is processed as soon as ``max_batch_size`` requests are queued,
    or ``max_wait_time`` seconds after the first request of a partial batch
    was queued, whichever comes first.

    ``process_func`` is called synchronously with the list of prompts of one
    batch and must return one result per prompt, in order.
    """

    def __init__(self, process_func: Callable, max_batch_size: int, max_wait_time: float = 0.1):
        """Create a manager.

        Raises:
            ValueError: If ``max_batch_size`` is smaller than 1 (no request
                could ever be dequeued, so every caller would wait forever).
        """
        if max_batch_size < 1:
            raise ValueError("max_batch_size must be at least 1")
        self.process_func = process_func
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time
        self.queue = deque()
        self.results = {}
        # Handle of the pending timeout flush, or None when none is scheduled.
        self._flush_timer = None

    async def add_request(self, request_id: str, data: dict) -> dict:
        """Queue a request and wait for its result.

        Args:
            request_id: Identifier stored alongside the queued request.
            data: Request payload; ``data["prompt"]`` is passed to
                ``process_func``.

        Returns:
            The element of ``process_func``'s output that corresponds to
            this request.

        Raises:
            Exception: Whatever ``process_func`` raised for this batch, or
                ``ValueError`` if it returned the wrong number of results.
        """
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        self.queue.append({"id": request_id, "data": data, "future": future})
        if len(self.queue) >= self.max_batch_size:
            await self._process_batch()
        elif self._flush_timer is None:
            # Guarantee that a partial batch is eventually processed even if
            # no further requests arrive.
            self._flush_timer = loop.call_later(self.max_wait_time, self._flush_on_timeout)
        return await future

    async def _process_batch(self):
        """Process one batch from the queue, if any requests are waiting."""
        self._run_batch()
        if not self.queue and self._flush_timer is not None:
            # Nothing is left to flush, so the pending timeout is obsolete.
            self._flush_timer.cancel()
            self._flush_timer = None

    def _flush_on_timeout(self):
        """Timer callback: drain everything that is still queued."""
        self._flush_timer = None
        while self.queue:
            self._run_batch()

    def _run_batch(self):
        """Dequeue up to ``max_batch_size`` requests and resolve their futures."""
        if not self.queue:
            return
        batch = [self.queue.popleft() for _ in range(min(self.max_batch_size, len(self.queue)))]
        try:
            inputs = [item["data"]["prompt"] for item in batch]
            results = list(self.process_func(inputs))
            if len(results) != len(batch):
                raise ValueError(
                    f"process_func returned {len(results)} results for {len(batch)} requests"
                )
        except Exception as exc:
            # Deliver the failure to every waiter; otherwise they would hang.
            for item in batch:
                if not item["future"].done():
                    item["future"].set_exception(exc)
            return
        for item, result in zip(batch, results):
            # A waiter may have been cancelled while the request was queued.
            if not item["future"].done():
                item["future"].set_result(result)
3. 请求路由器
class RequestRouter:
    """Dispatch request dicts to async handlers keyed by their ``"path"``."""

    def __init__(self):
        # path -> awaitable handler; ``fallback`` serves unregistered paths.
        self.routes = {}
        self.fallback = None

    def register_route(self, path: str, handler: Callable):
        """Bind ``handler`` to ``path``, replacing any previous binding."""
        self.routes[path] = handler

    def set_fallback(self, handler: Callable):
        """Set the handler used when a path has no registered route."""
        self.fallback = handler

    async def route(self, request: dict) -> dict:
        """Await the handler matching the request and return its response.

        A request without a ``"path"`` entry is treated as ``"/default"``.
        When neither a route nor a fallback is available, an error dict is
        returned instead of raising.
        """
        target = request.get("path", "/default")
        chosen = self.routes[target] if target in self.routes else self.fallback
        if chosen is not None:
            return await chosen(request)
        return {"error": f"未知路径: {target}"}
性能优化
1. KV缓存优化
class KVCacheManager:
    """Bounded key/value cache with least-recently-used eviction.

    Recency is tracked through the insertion order of ``self.cache``: the
    first key is the least recently used and the last key the most recent.
    This replaces a separate list whose ``remove``/``pop(0)`` calls scanned
    or shifted the whole list on every access.
    """

    def __init__(self, max_size: int = 1000):
        self.max_size = max_size
        self.cache = {}

    @property
    def access_order(self):
        """Keys from least to most recently used, as a new list."""
        return list(self.cache)

    def get(self, key):
        """Return the cached value and mark ``key`` as most recently used.

        Returns ``None`` when ``key`` is not cached, so a stored ``None``
        cannot be told apart from a miss.
        """
        if key not in self.cache:
            return None
        # Re-inserting moves the key to the end of the dict's order.
        value = self.cache.pop(key)
        self.cache[key] = value
        return value

    def put(self, key, value):
        """Store ``value`` under ``key`` as the most recently used entry.

        When the cache is full and ``key`` is new, the least recently used
        entry is evicted first. With a non-positive ``max_size`` nothing is
        stored (previously this raised ``IndexError`` on an empty cache).
        """
        if self.max_size <= 0:
            return
        if key in self.cache:
            del self.cache[key]
        elif len(self.cache) >= self.max_size:
            del self.cache[next(iter(self.cache))]
        self.cache[key] = value
2. 流式输出
class StreamingGenerator:
    """Async generator wrapper that emits output token by token.

    The stream is simulated: a fixed sequence of words is emitted and neither
    ``prompt`` nor the wrapped service is consulted.
    """

    def __init__(self, model_service: LLMService):
        self.service = model_service

    async def generate_stream(self, prompt: str, **kwargs):
        """Yield ``{"token", "finished"}`` dicts, ending with a final marker.

        Each word is yielded with ``finished=False`` and followed by a 50 ms
        pause; the last item has an empty token and ``finished=True``.
        """
        # Simulated streaming output.
        for piece in ("这是", "一个", "流式", "输出", "的", "示例"):
            yield {"token": piece, "finished": False}
            await asyncio.sleep(0.05)
        yield {"token": "", "finished": True}
完整推理管道
class InferencePipeline:
    """Wire an LLMService to a RequestRouter with generate/chat handlers."""

    def __init__(self, config: InferenceConfig):
        self.service = LLMService(config)
        self.router = RequestRouter()
        self._setup_routes()

    def _setup_routes(self):
        """Register the path handlers and the fallback on the router."""
        bindings = (
            ("/generate", self.handle_generate),
            ("/chat", self.handle_chat),
        )
        for path, handler in bindings:
            self.router.register_route(path, handler)
        self.router.set_fallback(self.handle_default)

    async def handle_generate(self, request: dict) -> dict:
        """Generate a response for ``request["prompt"]`` (default: empty)."""
        outputs = self.service.generate([request.get("prompt", "")])
        return {"response": outputs[0]}

    async def handle_chat(self, request: dict) -> dict:
        """Generate a response for the content of the last chat message.

        Only the final entry of ``request["messages"]`` is used; an empty or
        missing message list results in an empty prompt.
        """
        messages = request.get("messages", [])
        if messages:
            prompt = messages[-1].get("content", "")
        else:
            prompt = ""
        outputs = self.service.generate([prompt])
        return {"response": outputs[0]}

    async def handle_default(self, request: dict) -> dict:
        """Return the error response for unrecognised request types."""
        return {"error": "未知的请求类型"}

    def start(self):
        """Load the model, then announce that the pipeline is running."""
        self.service.load_model()
        print("推理管道已启动")
# Usage example: build a pipeline for a model path and start it.
# These statements run at import time; start() calls
# LLMService.load_model() and then prints a startup message.
config = InferenceConfig(model_path="models/llama-2-7b")
pipeline = InferencePipeline(config)
pipeline.start()
最佳实践
- 负载均衡:在多GPU或多节点间分配请求
- 健康检查:监控服务状态和性能指标
- 降级策略:在模型不可用时提供降级方案
- A/B测试:支持多版本模型并行服务
总结
推理管道是LLM应用的关键组件。通过批处理优化、KV缓存、流式输出等技术,我们可以构建高性能、高可用的LLM推理服务。