← 返回首页
🧠

LLM隔舱模式

📂 llm ⏱ 4 min 764 words

---
title: "LLM隔舱模式"
description: "详解LLM应用中的隔舱模式,通过资源隔离防止级联故障,提高系统可用性"
tags: ["隔舱", "资源隔离", "容错"]
category: "llm"
icon: "🧠"
---

LLM隔舱模式

什么是隔舱模式

隔舱模式(Bulkhead Pattern)源自船舶设计:船体被水密隔板分成若干独立舱室,一个舱室进水不会导致整艘船沉没。软件中的隔舱模式同样把系统资源划分为相互独立的隔舱(bulkhead),确保一个隔舱的故障不会影响其他隔舱。在LLM应用中,它用于隔离不同用户、不同服务或不同类型请求所占用的资源(如并发数和等待队列)。

为什么需要资源隔离

当多个LLM服务共享同一个资源池(线程、连接、并发配额)时,一个服务的故障或过载(例如上游模型响应变慢)可能耗尽全部资源,导致整个系统不可用。隔舱模式为每个服务单独限定可用资源,使局部故障被限制在其所属隔舱内,不会扩散为级联故障。

信号量隔舱(多线程)

基础实现

import threading
import time
from concurrent.futures import ThreadPoolExecutor, Future
from dataclasses import dataclass, field
from typing import Callable, Any, Dict

@dataclass
class BulkheadConfig:
    """Limits that govern a single bulkhead.

    Fields are declared in constructor order, so positional construction such
    as ``BulkheadConfig(10, 100, 30.0, "default")`` keeps working.
    """

    # Maximum number of calls allowed to execute at the same time.
    max_concurrent: int = field(default=10)
    # Maximum number of calls allowed to wait for a free slot.
    max_queue_size: int = field(default=100)
    # How long (seconds) a waiting call may block before giving up.
    queue_timeout: float = field(default=30.0)
    # Label used in status reports and error messages.
    name: str = field(default="default")

class Bulkhead:
    """Thread-based bulkhead that limits concurrent and waiting calls.

    At most ``config.max_concurrent`` calls run ``func`` at once. At most
    ``config.max_queue_size`` callers may be admitted but not yet running;
    each of them waits up to ``config.queue_timeout`` seconds for a slot.
    Callers beyond that are rejected immediately with ``BulkheadFullError``.
    """

    def __init__(self, config: BulkheadConfig):
        self.config = config
        self.semaphore = threading.Semaphore(config.max_concurrent)
        self.queue_count = 0    # admitted callers that do not hold a slot yet
        self.active_count = 0   # callers currently running func
        self.lock = threading.Lock()  # guards the counters and stats
        self.stats = {
            "total_calls": 0,
            "successful_calls": 0,
            "failed_calls": 0,
            "rejected_calls": 0,
            "timeout_calls": 0,
        }

    def execute(self, func: Callable, *args, **kwargs) -> Any:
        """Run ``func(*args, **kwargs)`` inside the bulkhead and return its result.

        Raises:
            BulkheadFullError: ``max_queue_size`` callers are already waiting.
            BulkheadTimeoutError: no slot became free within ``queue_timeout``.
            Exception: whatever ``func`` raises is re-raised unchanged.
        """
        with self.lock:
            self.stats["total_calls"] += 1
            if self.queue_count >= self.config.max_queue_size:
                self.stats["rejected_calls"] += 1
                raise BulkheadFullError(
                    f"隔舱 {self.config.name} 队列已满,请求被拒绝"
                )
            self.queue_count += 1

        acquired = False
        try:
            acquired = self.semaphore.acquire(
                timeout=self.config.queue_timeout
            )
        finally:
            # Leave the queue exactly once, even if acquire() is interrupted.
            # Moving into active_count in the same critical section keeps the
            # two counters consistent for get_status().
            with self.lock:
                self.queue_count -= 1
                if acquired:
                    self.active_count += 1

        if not acquired:
            with self.lock:
                self.stats["timeout_calls"] += 1
            # No slot was obtained, so there is nothing to release here.
            raise BulkheadTimeoutError(
                f"隔舱 {self.config.name} 等待超时"
            )

        succeeded = False
        try:
            result = func(*args, **kwargs)
            succeeded = True
            return result
        finally:
            with self.lock:
                self.active_count -= 1
                outcome = "successful_calls" if succeeded else "failed_calls"
                self.stats[outcome] += 1
            # Release only the slot this call actually acquired.
            self.semaphore.release()

    def get_status(self) -> dict:
        """Return a consistent snapshot of limits, live counters and stats."""
        with self.lock:
            return {
                "name": self.config.name,
                "max_concurrent": self.config.max_concurrent,
                "max_queue_size": self.config.max_queue_size,
                "active_count": self.active_count,
                "queue_count": self.queue_count,
                "available": self.config.max_concurrent - self.active_count,
                "stats": self.stats.copy(),
            }

class BulkheadError(Exception):
    """Base class for errors raised by a bulkhead itself, not by the wrapped call."""

class BulkheadFullError(BulkheadError):
    """The bulkhead's waiting queue is full; the call was rejected without running."""

class BulkheadTimeoutError(BulkheadError):
    """No execution slot became free within the configured queue timeout."""

多隔舱管理器

按用户划分隔舱

class UserBulkheadManager:
    """Lazily creates and tracks one Bulkhead per user id.

    Each user gets a separate Bulkhead instance, so concurrency and queue
    limits are counted independently per user. Users created without a
    custom config all share ``default_config`` (including its ``name``).
    """

    def __init__(self, default_config: BulkheadConfig = None):
        self.default_config = default_config or BulkheadConfig()
        self.user_bulkheads: Dict[str, Bulkhead] = {}
        # Guards user_bulkheads against concurrent insertion and iteration.
        self.lock = threading.Lock()

    def get_bulkhead(self, user_id: str,
                     custom_config: BulkheadConfig = None) -> Bulkhead:
        """Return the user's bulkhead, creating it on first use.

        ``custom_config`` only takes effect when the bulkhead is created;
        later calls return the existing bulkhead unchanged.
        """
        with self.lock:
            bulkhead = self.user_bulkheads.get(user_id)
            if bulkhead is None:
                bulkhead = Bulkhead(custom_config or self.default_config)
                self.user_bulkheads[user_id] = bulkhead
            return bulkhead

    def execute_for_user(self, user_id: str, func: Callable,
                         *args, **kwargs) -> Any:
        """Run ``func`` through the user's bulkhead (created on demand)."""
        bulkhead = self.get_bulkhead(user_id)
        return bulkhead.execute(func, *args, **kwargs)

    def get_user_status(self, user_id: str) -> dict:
        """Return the user's bulkhead status, or None if the user has none yet."""
        with self.lock:
            bulkhead = self.user_bulkheads.get(user_id)
        return None if bulkhead is None else bulkhead.get_status()

    def get_all_status(self) -> dict:
        """Return ``{user_id: status}`` for every known user."""
        # Snapshot under the lock so a concurrent get_bulkhead() cannot change
        # the dict while we iterate; query each bulkhead outside our lock.
        with self.lock:
            snapshot = list(self.user_bulkheads.items())
        return {user_id: bulkhead.get_status() for user_id, bulkhead in snapshot}

按服务划分隔舱

class ServiceBulkheadManager:
    """Registry that maps a service name to its own Bulkhead."""

    def __init__(self):
        self.services: Dict[str, Bulkhead] = {}

    def register_service(self, service_name: str, config: BulkheadConfig):
        """Create a bulkhead for ``service_name``, replacing any existing one."""
        new_bulkhead = Bulkhead(config)
        self.services[service_name] = new_bulkhead

    def execute(self, service_name: str, func: Callable,
                *args, **kwargs) -> Any:
        """Run ``func`` through the named service's bulkhead.

        Raises:
            ValueError: ``service_name`` has not been registered.
        """
        target = self.services.get(service_name)
        if target is None:
            raise ValueError(f"服务 {service_name} 未注册")
        return target.execute(func, *args, **kwargs)

    def get_service_status(self, service_name: str) -> dict:
        """Return the service's status, or None if it is not registered."""
        target = self.services.get(service_name)
        return target.get_status() if target is not None else None

    def get_all_status(self) -> dict:
        """Return ``{service_name: status}`` for every registered service."""
        return {svc: bh.get_status() for svc, bh in self.services.items()}

LLM应用集成

包装LLM客户端

from typing import Dict, Optional

class BulkheadLLMClient:
    """LLM client facade that routes each endpoint through its own bulkhead.

    The endpoint bodies are stubs that return canned responses; ``messages``,
    ``prompt``, ``texts``, ``model`` and ``**kwargs`` are accepted but unused.
    """

    def __init__(self):
        self.service_bulkheads = ServiceBulkheadManager()
        self._initialize_default_bulkheads()

    def _initialize_default_bulkheads(self):
        """Register one bulkhead per endpoint with endpoint-specific limits."""
        # service name -> (max_concurrent, max_queue_size, queue_timeout)
        limits = {
            "chat": (50, 200, 60.0),
            "completion": (30, 100, 30.0),
            "embedding": (100, 500, 10.0),
        }
        for service, (concurrent, queue_size, timeout) in limits.items():
            self.service_bulkheads.register_service(
                service,
                BulkheadConfig(
                    max_concurrent=concurrent,
                    max_queue_size=queue_size,
                    queue_timeout=timeout,
                    name=service,
                ),
            )

    def chat(self, messages: list, model: str = "gpt-4",
             **kwargs) -> dict:
        """Send a chat request through the "chat" bulkhead (simulated)."""
        return self.service_bulkheads.execute(
            "chat",
            lambda: {"choices": [{"message": {"content": "响应"}}]},
        )

    def completion(self, prompt: str, model: str = "gpt-3.5-turbo",
                   **kwargs) -> dict:
        """Send a completion request through the "completion" bulkhead (simulated)."""
        return self.service_bulkheads.execute(
            "completion",
            lambda: {"choices": [{"text": "补全结果"}]},
        )

    def embedding(self, texts: list, model: str = "text-embedding-ada-002",
                  **kwargs) -> dict:
        """Request embeddings through the "embedding" bulkhead (simulated)."""
        return self.service_bulkheads.execute(
            "embedding",
            lambda: {"data": [{"embedding": [0.1, 0.2, 0.3]}]},
        )

    def get_status(self) -> dict:
        """Return the status of every endpoint bulkhead, keyed by service name."""
        return self.service_bulkheads.get_all_status()

异步隔舱

import asyncio
from asyncio import Semaphore

class AsyncBulkhead:
    """asyncio counterpart of Bulkhead: limits concurrent and waiting calls.

    Every counter update in ``execute`` happens between ``await`` points, so it
    is atomic with respect to other tasks on the same event loop.
    """

    def __init__(self, config: BulkheadConfig):
        self.config = config
        self.semaphore = Semaphore(config.max_concurrent)
        self.active_count = 0   # calls currently running func
        self.queue_count = 0    # admitted calls that do not hold a slot yet
        # Kept for backward compatibility; execute() does not need it because
        # its counter updates contain no await.
        self.lock = asyncio.Lock()

    async def execute(self, func: Callable, *args, **kwargs) -> Any:
        """Run ``func`` (sync or coroutine function) once a slot is free.

        Returns:
            The value returned (or awaited) from ``func``.

        Raises:
            BulkheadFullError: ``max_queue_size`` calls are already waiting.
            BulkheadTimeoutError: no slot became free within ``queue_timeout``.
            Exception: whatever ``func`` raises is re-raised unchanged.
        """
        if self.queue_count >= self.config.max_queue_size:
            raise BulkheadFullError("异步隔舱队列已满")
        self.queue_count += 1

        try:
            await asyncio.wait_for(
                self.semaphore.acquire(),
                timeout=self.config.queue_timeout,
            )
        except asyncio.TimeoutError as exc:
            raise BulkheadTimeoutError("异步隔舱等待超时") from exc
        finally:
            # Leave the queue exactly once: on success, timeout, or
            # cancellation (CancelledError is a BaseException on 3.8+, so an
            # `except Exception` handler would miss it).
            self.queue_count -= 1

        # No await between the decrement above and this increment, so no
        # other task can observe the call in neither counter.
        self.active_count += 1
        try:
            if asyncio.iscoroutinefunction(func):
                return await func(*args, **kwargs)
            return func(*args, **kwargs)
        finally:
            self.active_count -= 1
            self.semaphore.release()

监控和告警

class BulkheadMonitor:
    """Evaluates bulkhead status snapshots against alert thresholds."""

    # Fallback thresholds (fractions in [0, 1]) for any key a caller omits.
    DEFAULT_THRESHOLDS = {
        "queue_usage": 0.8,
        "concurrent_usage": 0.9,
        "rejection_rate": 0.1,
    }

    def __init__(self):
        # NOTE(review): nothing in this class appends to these lists yet.
        self.alerts = []
        self.metrics_history = []

    def check_bulkhead(self, bulkhead: Bulkhead,
                       thresholds: dict = None) -> list:
        """Return an alert for every metric of ``bulkhead`` above its threshold.

        Args:
            bulkhead: object exposing ``get_status()`` (and ``config`` when the
                status lacks ``max_queue_size``), e.g. a Bulkhead.
            thresholds: optional overrides for "queue_usage",
                "concurrent_usage" and "rejection_rate"; missing keys fall back
                to DEFAULT_THRESHOLDS.

        Returns:
            A list of dicts with "severity", "bulkhead", "message" and "metric"
            keys; empty when every metric is within its threshold.
        """
        limits = {**self.DEFAULT_THRESHOLDS, **(thresholds or {})}
        status = bulkhead.get_status()
        alerts = []

        def add_alert(severity, metric, label, value):
            alerts.append({
                "severity": severity,
                "bulkhead": status["name"],
                "message": f"{label} {value:.2%} 超过阈值",
                "metric": metric,
            })

        # Queue usage: waiting calls relative to the queue limit. Fall back to
        # the config when the status snapshot does not carry the limit.
        max_queue = status.get("max_queue_size")
        if max_queue is None:
            max_queue = bulkhead.config.max_queue_size
        if max_queue > 0:
            queue_usage = status["queue_count"] / max_queue
            if queue_usage > limits["queue_usage"]:
                add_alert("warning", "queue_usage", "队列使用率", queue_usage)

        # Concurrency usage: running calls relative to the slot limit
        # (skipped when the limit is 0 to avoid dividing by zero).
        max_concurrent = status["max_concurrent"]
        if max_concurrent > 0:
            concurrent_usage = status["active_count"] / max_concurrent
            if concurrent_usage > limits["concurrent_usage"]:
                add_alert("critical", "concurrent_usage", "并发使用率",
                          concurrent_usage)

        # Rejection rate: share of all calls that were rejected.
        stats = status["stats"]
        total_calls = stats["total_calls"]
        if total_calls > 0:
            rejection_rate = stats["rejected_calls"] / total_calls
            if rejection_rate > limits["rejection_rate"]:
                add_alert("warning", "rejection_rate", "拒绝率", rejection_rate)

        return alerts

    def check_all(self, bulkhead_manager, thresholds: dict = None) -> list:
        """Check every bulkhead held by a manager and return all alerts.

        Accepts a manager exposing ``services`` (ServiceBulkheadManager) or
        ``user_bulkheads`` (UserBulkheadManager). ``thresholds`` is passed to
        ``check_bulkhead`` for each bulkhead.
        """
        if hasattr(bulkhead_manager, "services"):
            bulkheads = bulkhead_manager.services
        else:
            # Raises AttributeError, as before, if neither mapping exists.
            bulkheads = bulkhead_manager.user_bulkheads
        # Snapshot the values so registrations made while checks run do not
        # disturb the iteration.
        return [
            alert
            for bulkhead in list(bulkheads.values())
            for alert in self.check_bulkhead(bulkhead, thresholds)
        ]

最佳实践

  1. 合理配置隔舱大小:根据实际流量和资源情况设置并发数和队列大小
  2. 分层隔离:按用户、服务、模型等多维度隔离
  3. 监控告警:实时监控隔舱状态,及时发现问题
  4. 优雅降级:捕获 BulkheadFullError / BulkheadTimeoutError,返回有意义的降级响应(如缓存结果或"请稍后重试"提示),而不是直接报错
  5. 定期调整:根据业务变化定期调整隔舱配置

隔舱模式是构建高可用LLM应用的关键技术,能有效隔离故障,提高系统的整体稳定性。