架构设计
---
title: "架构设计"
description: "掌握LLM系统架构设计的核心原则，包括微服务架构、可扩展性和性能优化策略"
tags: ["架构设计", "LLM系统", "微服务", "性能优化"]
category: "llm"
icon: "🧠"
---
架构设计
LLM系统架构概述
构建一个生产级的LLM应用需要考虑多个方面：模型服务、数据处理、缓存、监控、安全等。良好的架构设计能够确保系统的可扩展性、可靠性和可维护性。
核心组件
1. 服务层设计
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, Any, Optional
@dataclass
class RequestContext:
    """Per-request metadata that travels alongside a generation call.

    ``metadata`` is optional; when omitted, each instance receives its own
    fresh empty dict (never a shared one).
    """

    request_id: str  # identifier of this request
    user_id: str  # identifier of the requesting user
    timestamp: float  # presumably Unix epoch seconds — TODO confirm with callers
    # Free-form extra attributes; default_factory gives every instance its own dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
class LLMService(ABC):
    """Abstract interface that every LLM backend implements.

    Concrete services supply a one-shot ``generate`` coroutine and a
    ``stream`` method for incremental output. Both accept an optional dict of
    backend-specific options.
    """

    @abstractmethod
    async def generate(self, prompt: str,
                       options: Optional[Dict[str, Any]] = None) -> str:
        """Return the completion for *prompt*."""
        ...

    @abstractmethod
    async def stream(self, prompt: str,
                     options: Optional[Dict[str, Any]] = None):
        """Produce the completion for *prompt* incrementally."""
        ...
class LLMServiceImpl(LLMService):
    """LLM service that adds response caching and rate limiting to a model client.

    ``generate`` serves a result from the cache when one exists. On a miss it
    waits on the rate limiter, calls the model, and caches the result.
    ``stream`` forwards directly to the model client.

    Args:
        model_client: provides ``async generate(prompt, **options)`` and an
            async-iterable ``stream(prompt, **options)``.
        cache_service: provides ``async get(key)`` (``None`` on a miss) and
            ``async set(key, value, ttl=...)``, e.g. ``CacheService``.
        rate_limiter: provides ``async acquire()``.
        logger: provides ``info(msg)`` and ``error(msg)``.
        cache_ttl: seconds a generated result stays cached (default 3600).
    """

    def __init__(self, model_client, cache_service,
                 rate_limiter, logger, cache_ttl: int = 3600):
        self.model_client = model_client
        self.cache = cache_service
        self.rate_limiter = rate_limiter
        self.logger = logger
        self.cache_ttl = cache_ttl

    async def generate(self, prompt: str,
                       options: Optional[Dict[str, Any]] = None) -> str:
        """Return the model's completion for *prompt*, consulting the cache first.

        *options* are forwarded to the model client as keyword arguments and
        are also part of the cache key.

        Raises:
            Exception: whatever the model client raises; it is logged first.
        """
        options = options or {}

        cache_key = self._get_cache_key(prompt, options)
        cached_result = await self.cache.get(cache_key)
        # The cache reports a miss as None, so a falsy value such as "" is still a hit.
        if cached_result is not None:
            self.logger.info(f"Cache hit for {cache_key}")
            return cached_result

        # NOTE(review): acquire() is never paired with a release; this assumes a
        # token-bucket style limiter rather than a semaphore — confirm.
        await self.rate_limiter.acquire()
        try:
            result = await self.model_client.generate(prompt, **options)
        except Exception as e:
            # Only the model call is guarded, so cache errors are not mislabelled
            # as generation failures.
            self.logger.error(f"Generation failed: {e}")
            raise

        await self.cache.set(cache_key, result, ttl=self.cache_ttl)
        return result

    async def stream(self, prompt: str,
                     options: Optional[Dict[str, Any]] = None):
        """Yield chunks from the model client's stream for *prompt*.

        NOTE(review): unlike generate(), this path bypasses both the cache and
        the rate limiter.
        """
        options = options or {}
        async for chunk in self.model_client.stream(prompt, **options):
            yield chunk

    def _get_cache_key(self, prompt: str, options: Dict) -> str:
        """Return a stable hex digest identifying *prompt* together with *options*.

        Keys are sorted during serialization, so option dicts with the same
        contents but a different insertion order map to the same cache entry.
        """
        import hashlib
        import json

        content = json.dumps({"prompt": prompt, "options": options}, sort_keys=True)
        return hashlib.md5(content.encode()).hexdigest()
2. 微服务架构
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
# FastAPI application; the route handlers below register on it via @app.post.
app: FastAPI = FastAPI()
class GenerationRequest(BaseModel):
    """Request body accepted by the ``/generate`` and ``/stream`` endpoints."""

    prompt: str  # text sent to the model
    max_tokens: int = 1000  # forwarded to the service as the "max_tokens" option
    temperature: float = 0.7  # forwarded to the service as the "temperature" option
class GenerationResponse(BaseModel):
    """Response body returned by the ``/generate`` endpoint."""

    text: str  # generated text
    tokens_used: int  # token count as reported by the backing service
    model: str  # model identifier as reported by the backing service
class LLMGateway:
    """Route generation requests to LLM services registered by name."""

    def __init__(self):
        # name -> service, populated via register_service().
        self.services = {}
        # Not read by any method of this class; kept as a public attribute.
        self.load_balancer = None

    def register_service(self, name: str, service: LLMService):
        """Register *service* under *name*, replacing any existing entry."""
        self.services[name] = service

    def get_service(self, name: str) -> LLMService:
        """Return the service registered under *name*.

        Raises:
            ValueError: if nothing is registered under *name*.
        """
        if name not in self.services:
            raise ValueError(f"Service {name} not found")
        return self.services[name]

    async def route_request(self, request: GenerationRequest,
                            service_name: str = "default") -> GenerationResponse:
        """Generate text for *request* using the service named *service_name*.

        ``max_tokens`` and ``temperature`` from the request are forwarded as
        options. The service result may be either:

        * a mapping with ``text``, ``tokens_used`` and ``model`` keys, or
        * a plain string, which is the return type declared by
          ``LLMService.generate``. A string carries no usage data, so
          ``tokens_used`` is reported as 0 and ``model`` as *service_name*.

        Raises:
            ValueError: if *service_name* is not registered.
        """
        from collections.abc import Mapping

        service = self.get_service(service_name)
        result = await service.generate(
            prompt=request.prompt,
            options={
                "max_tokens": request.max_tokens,
                "temperature": request.temperature,
            },
        )
        if isinstance(result, Mapping):
            return GenerationResponse(
                text=result["text"],
                tokens_used=result["tokens_used"],
                model=result["model"],
            )
        # Plain-string result: token usage is unknown (0 = "not reported").
        return GenerationResponse(text=result, tokens_used=0, model=service_name)
# Module-level gateway shared by the endpoints below. No service is registered
# here, and both endpoints route to "default". Register one first with
# gateway.register_service("default", ...), or every request fails with
# "Service default not found".
gateway: LLMGateway = LLMGateway()
@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    """Generate text through the gateway's ``"default"`` service.

    Raises:
        HTTPException: re-raised unchanged if already raised downstream;
            any other error becomes a 500 carrying the error message.
    """
    try:
        return await gateway.route_request(request)
    except HTTPException:
        # Already an HTTP error with its own status; don't rewrap it as a 500.
        raise
    except Exception as e:
        # NOTE(review): str(e) exposes internal error text to clients; consider
        # logging it and returning a generic message instead.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/stream")
async def stream_text(request: GenerationRequest):
    """Stream generated text as Server-Sent Events.

    Each chunk from the ``"default"`` service is JSON-encoded into a
    ``data:`` line, and the stream ends with ``data: [DONE]``. The request's
    ``max_tokens`` and ``temperature`` are forwarded, matching
    ``LLMGateway.route_request``.

    Raises:
        ValueError: if no ``"default"`` service is registered. This is raised
            before the response starts.
    """
    import json

    from fastapi.responses import StreamingResponse

    # Resolve the service up front: once StreamingResponse has sent its
    # headers, a lookup failure can no longer turn into a proper HTTP error.
    service = gateway.get_service("default")
    options = {
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
    }

    async def generate_stream():
        async for chunk in service.stream(request.prompt, options):
            yield f"data: {json.dumps(chunk)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
    )
缓存策略
import redis
import json
from typing import Optional, Any
from datetime import datetime, timedelta
class CacheService:
    """Redis-backed JSON cache with best-effort error handling.

    Values are stored with ``json.dumps`` and read back with ``json.loads``.
    Redis and serialization errors are swallowed: reads report a miss
    (``None``) and writes/deletes return ``False``. A cache outage therefore
    degrades to cache misses instead of failing the caller.

    The client comes from the synchronous ``redis`` package. Each call runs in
    a worker thread (``asyncio.to_thread``), so the async methods do not block
    the event loop.

    Args:
        redis_url: connection URL passed to ``redis.from_url``.
    """

    def __init__(self, redis_url: str = "redis://localhost:6379"):
        self.redis = redis.from_url(redis_url)
        self.default_ttl = 3600  # seconds (1 hour)

    async def _call(self, method, *args):
        """Run a blocking client *method* in a worker thread and await its result."""
        import asyncio

        return await asyncio.to_thread(method, *args)

    async def get(self, key: str) -> Optional[Any]:
        """Return the decoded value stored under *key*, or ``None`` on a miss or error."""
        try:
            data = await self._call(self.redis.get, key)
            if data:
                return json.loads(data)
        except Exception:
            # Deliberate best effort: an unreachable Redis or a corrupt entry is a miss.
            pass
        return None

    async def set(self, key: str, value: Any,
                  ttl: int = None) -> bool:
        """Store *value* as JSON under *key* with an expiry of *ttl* seconds.

        A falsy *ttl* (``None`` or ``0``) falls back to ``default_ttl``.
        Returns ``True`` on success, or ``False`` if serialization or Redis fails.
        """
        try:
            ttl = ttl or self.default_ttl
            payload = json.dumps(value)
            await self._call(self.redis.setex, key, ttl, payload)
            return True
        except Exception:
            return False

    async def delete(self, key: str) -> bool:
        """Remove *key*; return ``False`` only if the Redis call fails."""
        try:
            await self._call(self.redis.delete, key)
            return True
        except Exception:
            return False

    async def get_or_set(self, key: str,
                         factory, ttl: int = None) -> Any:
        """Return the cached value for *key*, computing and storing it on a miss.

        *factory* is a zero-argument coroutine function. A cached JSON
        ``null`` is indistinguishable from a miss, so *factory* runs again.
        """
        cached = await self.get(key)
        if cached is not None:
            return cached
        value = await factory()
        await self.set(key, value, ttl)
        return value
class SemanticCache:
    """Cache lookup by embedding similarity rather than by exact key.

    Args:
        embedding_model: object whose ``encode(text)`` returns an embedding;
            it is called synchronously.
        cache_service: backend providing
            ``async search_by_embedding(embedding, threshold=...)``, which
            returns a list of dicts with a ``"value"`` key, and
            ``async set_with_embedding(query, value, embedding)``.
            NOTE(review): ``CacheService`` in this document does not
            implement either method; a vector-capable backend is required.
        similarity_threshold: minimum similarity forwarded to
            ``search_by_embedding`` (default 0.95).
    """

    def __init__(self, embedding_model, cache_service,
                 similarity_threshold: float = 0.95):
        self.embedding_model = embedding_model
        self.cache = cache_service
        self.similarity_threshold = similarity_threshold

    async def get_similar(self, query: str) -> Optional[Any]:
        """Return the value of the first cached entry similar to *query*, or ``None``."""
        query_embedding = self.embedding_model.encode(query)
        similar_queries = await self.cache.search_by_embedding(
            query_embedding,
            threshold=self.similarity_threshold,
        )
        if similar_queries:
            # Presumably ordered best match first by the backend — TODO confirm.
            return similar_queries[0]["value"]
        return None

    async def set_with_embedding(self, query: str, value: Any):
        """Store *value* for *query* together with the query's embedding."""
        embedding = self.embedding_model.encode(query)
        await self.cache.set_with_embedding(query, value, embedding)
监控与可观测性
import time
from dataclasses import dataclass
from typing import List
from prometheus_client import Counter, Histogram, Gauge
class MetricsCollector:
    """Prometheus metrics for LLM requests, latency, token usage and errors.

    Deliberately not a ``@dataclass``. With no fields, a dataclass-generated
    ``__eq__`` would make every collector compare equal, and ``__hash__``
    would be set to ``None``.

    Metric names must be unique within a registry. Creating a second
    instance against the same registry fails, so pass a separate
    ``registry`` when more than one instance is needed (e.g. in tests).

    Args:
        registry: registry to register the metrics in. ``None`` (the
            default) uses prometheus_client's default registry.
    """

    def __init__(self, registry=None):
        # Omit the keyword entirely when no registry is given: prometheus_client
        # treats an explicit registry=None as "do not register".
        extra = {} if registry is None else {"registry": registry}

        # Request count, by model and outcome status
        self.request_count = Counter(
            'llm_requests_total',
            'Total LLM requests',
            ['model', 'status'],
            **extra,
        )
        # Request latency histogram, in seconds
        self.latency = Histogram(
            'llm_request_latency_seconds',
            'LLM request latency',
            ['model'],
            **extra,
        )
        # Token usage, split by type ('prompt' / 'completion')
        self.token_usage = Counter(
            'llm_tokens_total',
            'Total tokens used',
            ['model', 'type'],
            **extra,
        )
        # Error count, by exception type name
        self.error_count = Counter(
            'llm_errors_total',
            'Total LLM errors',
            ['model', 'error_type'],
            **extra,
        )

    def record_request(self, model: str, status: str):
        """Increment the request counter for *model* with outcome *status*."""
        self.request_count.labels(model=model, status=status).inc()

    def record_latency(self, model: str, duration: float):
        """Observe one request *duration* (seconds) for *model*."""
        self.latency.labels(model=model).observe(duration)

    def record_tokens(self, model: str, prompt_tokens: int,
                      completion_tokens: int):
        """Add prompt and completion token counts for *model*."""
        self.token_usage.labels(model=model, type='prompt').inc(prompt_tokens)
        self.token_usage.labels(model=model, type='completion').inc(completion_tokens)

    def record_error(self, model: str, error_type: str):
        """Increment the error counter for *model* and *error_type*."""
        self.error_count.labels(model=model, error_type=error_type).inc()
class MonitoredLLMService:
    """Wrap an LLM service and record request, latency and error metrics.

    Args:
        llm_service: wrapped service; ``generate`` and ``stream`` delegate to it.
        metrics: collector exposing ``record_request``, ``record_latency``
            and ``record_error`` (e.g. ``MetricsCollector``).
        model_name: ``model`` label attached to every recorded metric
            (default ``"default"``).
    """

    def __init__(self, llm_service: "LLMService",
                 metrics: "MetricsCollector",
                 model_name: str = "default"):
        self.llm_service = llm_service
        self.metrics = metrics
        # Previously never assigned, so every metric call raised AttributeError.
        self.model_name = model_name

    async def generate(self, prompt: str,
                       options: Optional[Dict[str, Any]] = None) -> str:
        """Delegate to the wrapped service and record the outcome.

        On success, records a "success" request and the latency. On failure,
        records an "error" request plus the exception type name, then re-raises.
        """
        # perf_counter is monotonic, so durations are unaffected by clock changes.
        start_time = time.perf_counter()
        try:
            result = await self.llm_service.generate(prompt, options)
        except Exception as e:
            self.metrics.record_request(self.model_name, "error")
            self.metrics.record_error(self.model_name, type(e).__name__)
            raise
        self.metrics.record_request(self.model_name, "success")
        self.metrics.record_latency(
            self.model_name,
            time.perf_counter() - start_time,
        )
        return result

    async def stream(self, prompt: str,
                     options: Optional[Dict[str, Any]] = None):
        """Yield chunks from the wrapped service's stream and record the outcome.

        Records a "success" request once the stream is exhausted, or an
        "error" request plus the exception type name if it raises.
        """
        try:
            async for chunk in self.llm_service.stream(prompt, options):
                yield chunk
        except Exception as e:
            self.metrics.record_request(self.model_name, "error")
            self.metrics.record_error(self.model_name, type(e).__name__)
            raise
        self.metrics.record_request(self.model_name, "success")
总结
良好的架构设计是构建可靠LLM系统的基础。通过合理的服务分层、缓存策略和监控体系，可以保障系统在生产环境中稳定运行。