← 返回首页
🧠

架构设计

📂 llm ⏱ 4 min 710 words

--- title: "架构设计" description: "掌握LLM系统架构设计的核心原则,包括微服务架构、可扩展性和性能优化策略" tags: ["架构设计", "LLM系统", "微服务", "性能优化"] category: "llm" icon: "🧠"

架构设计

LLM系统架构概述

构建一个生产级的LLM应用需要考虑多个方面：模型服务、数据处理、缓存、监控、安全等。良好的架构设计能够确保系统的可扩展性、可靠性和可维护性。

核心组件

1. 服务层设计

from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
from dataclasses import dataclass
import asyncio

@dataclass()
class RequestContext:
    """Per-request bookkeeping record.

    A plain dataclass: constructor arguments are positional in the order the
    fields are declared below, and instances compare by field values.
    """

    request_id: str             # identifier of this request
    user_id: str                # identifier of the caller
    timestamp: float            # creation time; presumably epoch seconds — TODO confirm
    metadata: Dict[str, Any]    # free-form extra data (required, no default)

class LLMService(ABC):
    """Abstract interface every LLM backend must implement.

    Subclasses provide a one-shot ``generate`` and an incremental ``stream``.
    """

    @abstractmethod
    async def generate(self, prompt: str,
                       options: Optional[Dict[str, Any]] = None) -> str:
        """Return the completion for *prompt*.

        Args:
            prompt: Input text.
            options: Backend-specific generation options (for example
                ``max_tokens`` / ``temperature``); ``None`` means no options.
        """

    @abstractmethod
    async def stream(self, prompt: str,
                     options: Optional[Dict[str, Any]] = None):
        """Produce the completion for *prompt* incrementally.

        Implementations are expected to be async generators (``async def``
        containing ``yield``) so callers can consume them with ``async for``.

        Args:
            prompt: Input text.
            options: Backend-specific generation options; ``None`` means none.
        """

class LLMServiceImpl(LLMService):
    """LLMService backed by a model client, with response caching and rate limiting.

    ``generate`` checks the cache first, acquires the rate limiter only on a
    cache miss, calls ``model_client.generate`` and caches the result.
    ``stream`` forwards directly to ``model_client.stream`` with no caching
    and no rate limiting.
    """

    def __init__(self, model_client, cache_service,
                 rate_limiter, logger, cache_ttl: int = 3600):
        """
        Args:
            model_client: Object with async ``generate(prompt, **options)`` and
                an async-iterable ``stream(prompt, **options)``.
            cache_service: Object with async ``get(key)`` and
                ``set(key, value, ttl=...)``.
            rate_limiter: Object with an async ``acquire()``.
            logger: Object with ``info`` and ``error`` methods.
            cache_ttl: Lifetime of cached results in seconds (default one hour).
        """
        self.model_client = model_client
        self.cache = cache_service
        self.rate_limiter = rate_limiter
        self.logger = logger
        self.cache_ttl = cache_ttl

    async def generate(self, prompt: str,
                       options: Optional[Dict[str, Any]] = None) -> str:
        """Generate text for *prompt*, serving repeated requests from the cache.

        Args:
            prompt: Input text.
            options: Keyword arguments forwarded to ``model_client.generate``;
                they are also part of the cache key.

        Returns:
            The model client's result, or the cached copy of an earlier one.

        Raises:
            Whatever ``model_client.generate`` raises; it is logged first.
        """
        options = options or {}

        cache_key = self._get_cache_key(prompt, options)
        cached_result = await self.cache.get(cache_key)
        # Compare against None so a cached-but-falsy result (e.g. "") still
        # counts as a hit instead of triggering another model call.
        if cached_result is not None:
            self.logger.info(f"Cache hit for {cache_key}")
            return cached_result

        # Only cache misses consume rate-limit capacity.
        # NOTE(review): acquire() has no matching release(); that is correct for
        # a token-bucket style limiter but leaks permits for a semaphore-style
        # one — confirm which kind is injected.
        await self.rate_limiter.acquire()

        try:
            result = await self.model_client.generate(prompt, **options)
        except Exception as e:
            self.logger.error(f"Generation failed: {e}")
            raise

        # Outside the try so a cache-write error is not logged as a
        # generation failure.
        await self.cache.set(cache_key, result, ttl=self.cache_ttl)
        return result

    async def stream(self, prompt: str,
                     options: Optional[Dict[str, Any]] = None):
        """Yield chunks from ``model_client.stream`` unchanged.

        No caching and no rate limiting are applied on this path.
        """
        options = options or {}

        async for chunk in self.model_client.stream(prompt, **options):
            yield chunk

    def _get_cache_key(self, prompt: str, options: Dict) -> str:
        """Return a stable hex digest identifying the (prompt, options) pair.

        ``sort_keys=True`` makes the key independent of dict insertion order,
        so {"a": 1, "b": 2} and {"b": 2, "a": 1} share one cache entry.
        MD5 is used only as a fast fingerprint here, not for security.
        """
        import hashlib
        import json

        content = json.dumps({"prompt": prompt, "options": options},
                             sort_keys=True)
        return hashlib.md5(content.encode()).hexdigest()

2. 微服务架构

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn

# FastAPI application object; the endpoints in this snippet register on it
# through the @app.post decorators.
app = FastAPI()

class GenerationRequest(BaseModel):
    # Request body for POST /generate and POST /stream.
    # max_tokens and temperature are forwarded to the service as options.
    prompt: str
    max_tokens: int = 1_000
    temperature: float = 0.7

class GenerationResponse(BaseModel):
    # Response body of POST /generate, filled from the service result.
    text: str         # generated text
    tokens_used: int  # token count taken from the service result
    model: str        # model name taken from the service result

class LLMGateway:
    """Registry of named LLM services that routes API requests to them."""

    def __init__(self):
        # Service name -> LLMService instance.
        self.services = {}
        # Placeholder; not read anywhere in this class.
        self.load_balancer = None

    def register_service(self, name: str, service: LLMService):
        """Register *service* under *name*, replacing any existing entry."""
        self.services[name] = service

    def get_service(self, name: str) -> LLMService:
        """Return the service registered as *name*.

        Raises:
            ValueError: If no service is registered under *name*.
        """
        try:
            return self.services[name]
        except KeyError:
            raise ValueError(f"Service {name} not found") from None

    async def route_request(self, request: GenerationRequest,
                            service_name: str = "default") -> GenerationResponse:
        """Forward *request* to the named service and wrap its result.

        Two result shapes are accepted:

        * a mapping with ``"text"``, ``"tokens_used"`` and ``"model"`` keys;
        * a plain ``str`` — the return type ``LLMService.generate`` declares.
          No usage metadata is available in that case, so ``tokens_used`` is
          reported as 0 and ``model`` as *service_name*.

        Raises:
            ValueError: If *service_name* is not registered.
            KeyError: If a mapping result lacks one of the required keys.
        """
        service = self.get_service(service_name)

        result = await service.generate(
            prompt=request.prompt,
            options={
                "max_tokens": request.max_tokens,
                "temperature": request.temperature,
            },
        )

        if isinstance(result, str):
            return GenerationResponse(
                text=result,
                tokens_used=0,
                model=service_name,
            )

        return GenerationResponse(
            text=result["text"],
            tokens_used=result["tokens_used"],
            model=result["model"],
        )

# Initialise the gateway: one module-level LLMGateway shared by the request
# handlers. It starts with no registered services, so a service named
# "default" must be registered before /generate or /stream can succeed.
gateway = LLMGateway()

@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    """POST /generate: route the request to the "default" service.

    Any error becomes an HTTP 500 whose detail is the exception text; an
    HTTPException raised further down keeps its own status and detail.
    """
    try:
        return await gateway.route_request(request)
    except HTTPException:
        # Already an HTTP error with a deliberate status; don't flatten it to 500.
        raise
    except Exception as e:
        # NOTE(review): str(e) exposes internal error text to API clients;
        # consider a generic detail plus server-side logging.
        # `from e` keeps the original exception as __cause__ in tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.post("/stream")
async def stream_text(request: GenerationRequest):
    """POST /stream: stream the "default" service's output as Server-Sent Events.

    Each chunk is sent as one ``data:`` event carrying the JSON-encoded chunk,
    and the stream ends with a ``data: [DONE]`` event.
    """
    import json

    from fastapi.responses import StreamingResponse

    # Look the service up before the response starts. Once StreamingResponse
    # begins sending, the 200 status is committed and a failure inside the
    # generator can no longer be reported as an HTTP error.
    try:
        service = gateway.get_service("default")
    except ValueError as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Forward the same options as the non-streaming path so the requested
    # temperature is honoured here too.
    options = {
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
    }

    async def generate_stream():
        async for chunk in service.stream(request.prompt, options):
            yield f"data: {json.dumps(chunk)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
    )

缓存策略

import redis
import json
from typing import Optional, Any
from datetime import datetime, timedelta

class CacheService:
    """Best-effort JSON cache backed by Redis.

    Values are stored as JSON text with a TTL. Backend and serialisation
    errors are logged and swallowed: ``get`` then reports a miss and
    ``set`` / ``delete`` return False. A cache outage therefore degrades to
    cache misses instead of failing requests.

    NOTE(review): ``redis.from_url`` returns redis-py's synchronous client,
    so each call below blocks the event loop while waiting on Redis despite
    the ``async def`` signatures; consider ``redis.asyncio`` or
    ``asyncio.to_thread``.
    """

    def __init__(self, redis_url: str = "redis://localhost:6379",
                 default_ttl: int = 3600):
        self.redis = redis.from_url(redis_url)
        # Seconds; used when set() receives no TTL (default 3600 = one hour).
        self.default_ttl = default_ttl

    @staticmethod
    def _log_failure(operation: str, key: str) -> None:
        """Log the exception currently being handled, with its traceback."""
        import logging

        logging.getLogger(__name__).warning(
            "Cache %s failed for key %r", operation, key, exc_info=True
        )

    async def get(self, key: str) -> Optional[Any]:
        """Return the JSON-decoded value stored at *key*, or None on miss or error."""
        try:
            data = self.redis.get(key)
            if data:
                return json.loads(data)
        except Exception:
            self._log_failure("get", key)
        return None

    async def set(self, key: str, value: Any,
                  ttl: Optional[int] = None) -> bool:
        """Store *value* as JSON under *key*; return True on success.

        A falsy *ttl* (None or 0) falls back to ``default_ttl``.
        """
        try:
            ttl = ttl or self.default_ttl
            self.redis.setex(key, ttl, json.dumps(value))
            return True
        except Exception:
            self._log_failure("set", key)
            return False

    async def delete(self, key: str) -> bool:
        """Remove *key*; return False only if the backend call failed."""
        try:
            self.redis.delete(key)
            return True
        except Exception:
            self._log_failure("delete", key)
            return False

    async def get_or_set(self, key: str,
                         factory, ttl: Optional[int] = None) -> Any:
        """Return the cached value for *key*, computing and caching it on a miss.

        Args:
            key: Cache key.
            factory: Zero-argument callable returning an awaitable that
                produces the value to cache.
            ttl: Passed through to ``set``.

        A stored value of None is indistinguishable from a miss, so a
        factory that returns None is invoked again on every lookup.
        """
        cached = await self.get(key)
        if cached is not None:
            return cached

        value = await factory()
        await self.set(key, value, ttl)
        return value

class SemanticCache:
    """Cache keyed by meaning rather than exact text.

    A stored value is reused when a new query's embedding is similar enough
    to the embedding of a previously cached query.

    *cache_service* must provide async ``search_by_embedding(embedding,
    threshold=...)`` and ``set_with_embedding(query, value, embedding)``.
    The Redis-backed CacheService does not provide these, so a
    vector-capable store is required.
    """

    def __init__(self, embedding_model, cache_service,
                 similarity_threshold: float = 0.95):
        self.embedding_model = embedding_model
        self.cache = cache_service
        # Minimum similarity for a stored query to count as a hit. Its scale
        # depends on the store's metric (presumably cosine similarity — confirm).
        self.similarity_threshold = similarity_threshold

    async def get_similar(self, query: str) -> Optional[Any]:
        """Return the cached value of the best match for *query*, or None.

        The first search result is used; this assumes the store returns
        matches ordered best-first — TODO confirm against the store.
        """
        # NOTE(review): encode() is called synchronously inside a coroutine; a
        # heavy local embedding model would block the event loop.
        query_embedding = self.embedding_model.encode(query)

        similar_queries = await self.cache.search_by_embedding(
            query_embedding,
            threshold=self.similarity_threshold,
        )

        if similar_queries:
            return similar_queries[0]["value"]
        return None

    async def set_with_embedding(self, query: str, value: Any):
        """Embed *query* and store *value* together with that embedding."""
        embedding = self.embedding_model.encode(query)
        await self.cache.set_with_embedding(query, value, embedding)

监控与可观测性

import time
from dataclasses import dataclass
from typing import List
from prometheus_client import Counter, Histogram, Gauge

class MetricsCollector:
    """Prometheus metrics for LLM traffic, each labelled by model.

    Tracks request counts, latency, token usage and errors.

    The metric names are fixed, and prometheus_client rejects registering a
    name twice in one registry, so only one instance may use a given
    registry. Pass a separate ``CollectorRegistry`` as *registry* for extra
    instances (e.g. in tests); by default the global registry is used.
    """

    def __init__(self, registry=None):
        if registry is None:
            from prometheus_client import REGISTRY
            registry = REGISTRY

        # Request count, labelled by model and outcome status.
        self.request_count = Counter(
            'llm_requests_total',
            'Total LLM requests',
            ['model', 'status'],
            registry=registry,
        )

        # Request latency histogram in seconds.
        self.latency = Histogram(
            'llm_request_latency_seconds',
            'LLM request latency',
            ['model'],
            registry=registry,
        )

        # Token usage, split by type ("prompt" / "completion").
        self.token_usage = Counter(
            'llm_tokens_total',
            'Total tokens used',
            ['model', 'type'],
            registry=registry,
        )

        # Error count, labelled by exception type name.
        self.error_count = Counter(
            'llm_errors_total',
            'Total LLM errors',
            ['model', 'error_type'],
            registry=registry,
        )

    def record_request(self, model: str, status: str):
        """Increment the request counter for (*model*, *status*)."""
        self.request_count.labels(model=model, status=status).inc()

    def record_latency(self, model: str, duration: float):
        """Observe one request latency of *duration* seconds for *model*."""
        self.latency.labels(model=model).observe(duration)

    def record_tokens(self, model: str, prompt_tokens: int,
                      completion_tokens: int):
        """Add prompt and completion token counts for *model*."""
        self.token_usage.labels(model=model, type='prompt').inc(prompt_tokens)
        self.token_usage.labels(model=model, type='completion').inc(completion_tokens)

    def record_error(self, model: str, error_type: str):
        """Increment the error counter for (*model*, *error_type*)."""
        self.error_count.labels(model=model, error_type=error_type).inc()

class MonitoredLLMService:
    """Wrapper that records metrics around an LLM service's ``generate``.

    Each call records the request outcome. Successful calls also record
    their latency; failed calls also record the exception type.
    """

    def __init__(self, llm_service: "LLMService",
                 metrics: "MetricsCollector", model_name: str = "unknown"):
        """
        Args:
            llm_service: Service whose ``generate`` calls are measured.
            metrics: Collector receiving the measurements.
            model_name: Value of the ``model`` label on every recorded metric.
        """
        self.llm_service = llm_service
        self.metrics = metrics
        self.model_name = model_name

    async def generate(self, prompt: str,
                       options: Optional[Dict[str, Any]] = None) -> str:
        """Call the wrapped service's ``generate`` and record metrics.

        Returns:
            The wrapped service's result, unchanged.

        Raises:
            Whatever the wrapped service raises, after it is recorded as an error.
        """
        # perf_counter is monotonic, so wall-clock adjustments cannot produce
        # negative or skewed latencies.
        start = time.perf_counter()

        try:
            result = await self.llm_service.generate(prompt, options)
        except Exception as e:
            self.metrics.record_request(self.model_name, "error")
            self.metrics.record_error(self.model_name, type(e).__name__)
            raise

        # Recorded outside the try so a metrics failure is not counted as a
        # failed request.
        self.metrics.record_request(self.model_name, "success")
        self.metrics.record_latency(self.model_name, time.perf_counter() - start)
        return result

总结

良好的架构设计是构建可靠LLM系统的基础。通过合理的服务分层、缓存策略和监控体系，可以确保系统在生产环境中的稳定运行。