← 返回首页
🧠

LLM熔断机制

📂 llm ⏱ 4 min 669 words

---
title: "LLM熔断机制"
description: "详解LLM应用中的熔断器模式,包括熔断状态管理、恢复策略和降级方案"
tags: ["熔断", "容错", "高可用"]
category: "llm"
icon: "🧠"
---

LLM熔断机制

什么是熔断器

熔断器(Circuit Breaker)是一种保护机制,当检测到下游服务故障时,自动切断请求,避免级联故障。它像电路保险丝一样,在故障时"熔断"以保护整个系统。

熔断器的三种状态

关闭(正常) → 打开(熔断) → 半开(测试) → 关闭(恢复);若半开状态下的测试调用失败,则重新回到打开(熔断)状态。

状态实现

from enum import Enum
from dataclasses import dataclass
import time
import threading

class CircuitState(Enum):
    """The three states of a circuit breaker; each value is the state's string label."""
    CLOSED = "closed"      # normal state
    OPEN = "open"          # tripped state
    HALF_OPEN = "half_open"  # testing state

@dataclass
class CircuitBreakerConfig:
    """Thresholds and timings that control a circuit breaker.

    Raises:
        ValueError: if any of the count fields is below 1 or ``timeout`` is
            negative. Such values would produce a breaker that can never
            recover (or never stays open), so they are rejected up front.
    """
    failure_threshold: int = 5      # number of failures that trips the breaker
    success_threshold: int = 3      # number of successes required to recover
    timeout: float = 60.0           # how long the breaker stays tripped (seconds)
    half_open_max_calls: int = 3    # maximum trial calls in the half-open state

    def __post_init__(self):
        # Count fields must allow at least one event, otherwise the state
        # machine has no way to make progress.
        for field_name in (
            "failure_threshold",
            "success_threshold",
            "half_open_max_calls",
        ):
            value = getattr(self, field_name)
            if value < 1:
                raise ValueError(f"{field_name} must be >= 1, got {value!r}")
        # Zero is allowed: it means "retry immediately".
        if self.timeout < 0:
            raise ValueError(f"timeout must be >= 0, got {self.timeout!r}")

class CircuitBreaker:
    """Three-state circuit breaker (closed / open / half-open).

    All state changes happen while holding ``self.lock``. The intended
    protocol is: ask ``can_execute()``; if it returns True, perform the call
    and report the outcome with ``record_success()`` or ``record_failure()``.

    Attributes:
        name: Identifier used in status reports.
        config: The ``CircuitBreakerConfig`` in effect.
        state: Current ``CircuitState``.
        failure_count: Failures counted since the last reset.
        success_count: Successes counted in the current half-open phase.
        last_failure_time: ``time.time()`` of the most recent failure (0 if none).
        half_open_calls: Trial-call slots currently handed out while half-open.
        state_changes: List of ``{"timestamp", "from_state", "to_state"}`` dicts,
            one per transition, oldest first.
    """

    def __init__(self, name: str, config: CircuitBreakerConfig = None):
        """Create a closed breaker; ``config`` defaults to ``CircuitBreakerConfig()``."""
        self.name = name
        self.config = config or CircuitBreakerConfig()

        self.state = CircuitState.CLOSED
        self.failure_count = 0
        self.success_count = 0
        self.last_failure_time = 0
        self.half_open_calls = 0
        # Start of the current half-open probe window; used to reclaim slots
        # from callers that were admitted but never reported an outcome.
        self._half_open_since = 0.0

        self.lock = threading.Lock()
        self.state_changes = []

    def can_execute(self) -> bool:
        """Return True if a call may proceed now.

        * CLOSED: always True.
        * OPEN: False until ``config.timeout`` seconds have passed since the
          last failure; then the breaker becomes HALF_OPEN and this call is
          admitted as the first trial.
        * HALF_OPEN: True while fewer than ``config.half_open_max_calls``
          trial slots are handed out. Each admitted call takes one slot;
          ``record_success()`` gives it back. If all slots have been held for
          a full ``config.timeout``, they are assumed abandoned and a fresh
          window starts, so the breaker cannot get stuck half-open.
        """
        with self.lock:
            if self.state == CircuitState.CLOSED:
                return True

            now = time.time()

            if self.state == CircuitState.OPEN:
                # Still inside the cool-down period: keep rejecting.
                if now - self.last_failure_time < self.config.timeout:
                    return False
                self._transition_to(CircuitState.HALF_OPEN)
                self._half_open_since = now
                # The caller that triggers the transition is the first trial.
                self.half_open_calls = 1
                return True

            if self.state == CircuitState.HALF_OPEN:
                if self.half_open_calls < self.config.half_open_max_calls:
                    self.half_open_calls += 1
                    return True
                # Every slot is taken. Reclaim them if they have been held for
                # a whole timeout period without any outcome being reported.
                if now - self._half_open_since >= self.config.timeout:
                    self._half_open_since = now
                    self.half_open_calls = 1
                    return True
                return False

            return False

    def record_success(self) -> None:
        """Report a successful call.

        While HALF_OPEN this frees one trial slot and counts toward
        ``config.success_threshold``; reaching the threshold closes the
        breaker and resets both counters. In any other state it resets
        ``failure_count``.
        """
        with self.lock:
            if self.state == CircuitState.HALF_OPEN:
                # The trial finished, so its slot becomes available again.
                if self.half_open_calls > 0:
                    self.half_open_calls -= 1
                self.success_count += 1
                if self.success_count >= self.config.success_threshold:
                    self._transition_to(CircuitState.CLOSED)
                    self.failure_count = 0
                    self.success_count = 0
            else:
                self.failure_count = 0

    def record_failure(self) -> None:
        """Report a failed call.

        Always increments ``failure_count``, clears ``success_count`` and
        stamps ``last_failure_time``. The breaker opens when a trial call
        fails in HALF_OPEN, or when a CLOSED breaker reaches
        ``config.failure_threshold``. A failure reported while already OPEN
        does not add another transition to ``state_changes``.
        """
        with self.lock:
            self.failure_count += 1
            self.success_count = 0
            self.last_failure_time = time.time()

            should_open = self.state == CircuitState.HALF_OPEN or (
                self.state == CircuitState.CLOSED
                and self.failure_count >= self.config.failure_threshold
            )
            if should_open:
                self._transition_to(CircuitState.OPEN)

    def _transition_to(self, new_state: CircuitState) -> None:
        """Log the transition, then switch state. Caller must hold ``self.lock``."""
        # Order matters: the log entry reads self.state as the *previous* state.
        self._record_state_change(new_state)
        self.state = new_state

    def _record_state_change(self, new_state: CircuitState):
        """Append a history entry; must run before ``self.state`` is reassigned."""
        self.state_changes.append({
            "timestamp": time.time(),
            "from_state": self.state.value,
            "to_state": new_state.value
        })

    def get_status(self) -> dict:
        """Return a snapshot dict with name, state label, counters and last failure time."""
        # Read under the lock so the fields come from one consistent moment.
        with self.lock:
            return {
                "name": self.name,
                "state": self.state.value,
                "failure_count": self.failure_count,
                "success_count": self.success_count,
                "last_failure_time": self.last_failure_time
            }

LLM API熔断器封装

import time
from typing import Callable, Any

class LLMCircuitBreaker:
    """Run an LLM call behind a ``CircuitBreaker`` with an optional fallback.

    Attributes:
        breaker: The underlying ``CircuitBreaker``.
        fallback_func: Callable invoked with the original call arguments when
            the call is rejected, raises, or returns an invalid response;
            ``None`` until ``set_fallback`` is called.
    """

    def __init__(self, name: str, config: CircuitBreakerConfig = None):
        """Create the wrapped breaker; no fallback is configured initially."""
        self.breaker = CircuitBreaker(name, config)
        self.fallback_func = None

    def set_fallback(self, func: Callable):
        """Register ``func`` as the fallback, replacing any previous one."""
        self.fallback_func = func

    def call(self, llm_func: Callable, *args, **kwargs) -> Any:
        """Invoke ``llm_func(*args, **kwargs)`` under breaker protection.

        Returns:
            The result of ``llm_func`` when it is valid; otherwise the
            fallback's result if a fallback is set. With no fallback, an
            invalid response is returned unchanged.

        Raises:
            CircuitOpenError: the breaker refused the call and no fallback
                is set.
            Exception: whatever ``llm_func`` raised, when no fallback is set.
                Exceptions raised by the fallback itself propagate as-is.
        """
        if not self.breaker.can_execute():
            # Breaker refused the call: degrade instead of calling the LLM.
            if self.fallback_func:
                return self.fallback_func(*args, **kwargs)
            raise CircuitOpenError(
                f"熔断器 {self.breaker.name} 处于打开状态,"
                f"将在 {self.breaker.config.timeout}秒后重试"
            )

        # Only the LLM call is guarded. Keeping the fallback outside the try
        # ensures one bad call is counted as exactly one failure and the
        # fallback is never invoked twice for the same request.
        try:
            result = llm_func(*args, **kwargs)
        except Exception:
            self.breaker.record_failure()
            if self.fallback_func:
                return self.fallback_func(*args, **kwargs)
            raise

        # Check whether the response is usable.
        if self._is_valid_response(result):
            self.breaker.record_success()
            return result

        self.breaker.record_failure()
        return self._handle_invalid_response(llm_func, args, kwargs, result)

    def _is_valid_response(self, response: Any) -> bool:
        """Return False for ``None`` and for dicts containing an ``"error"`` key."""
        if response is None:
            return False
        if isinstance(response, dict):
            return "error" not in response
        return True

    def _handle_invalid_response(self, llm_func: Callable,
                                  args, kwargs, response) -> Any:
        """Return the fallback's result if one is set, else ``response`` unchanged."""
        if self.fallback_func:
            return self.fallback_func(*args, **kwargs)
        return response

class CircuitOpenError(Exception):
    """Raised by ``LLMCircuitBreaker.call`` when the breaker refuses a call and no fallback is set."""

使用示例

def main():
    """Demo: send one prompt to an LLM HTTP endpoint through a circuit breaker."""
    # Configure the breaker.
    config = CircuitBreakerConfig(
        failure_threshold=3,
        success_threshold=2,
        timeout=30.0
    )

    breaker = LLMCircuitBreaker("llm-api", config)

    # Degraded response used when the call is rejected or fails.
    def fallback(*args, **kwargs):
        return {"error": "服务暂时不可用,请稍后重试"}

    breaker.set_fallback(fallback)

    # The LLM call to protect.
    def call_llm(prompt: str) -> dict:
        import requests
        response = requests.post(
            "https://api.llm-provider.com/chat",
            json={"prompt": prompt},
            # requests has no default timeout; without one a hung connection
            # blocks forever and the breaker never sees a failure.
            timeout=10,
        )
        # Turn HTTP 4xx/5xx into exceptions so the breaker counts them.
        response.raise_for_status()
        return response.json()

    # Call the LLM through the breaker.
    result = breaker.call(call_llm, "你好,请介绍一下自己")
    print(result)

if __name__ == "__main__":
    main()

多级熔断器

针对不同错误类型设置不同的熔断策略。

class MultiLevelCircuitBreaker:
    """One circuit breaker per error category, each with its own policy.

    Categories are the keys of ``self.breakers``: ``"timeout"``,
    ``"rate_limit"``, ``"auth_error"`` and ``"server_error"``.
    """

    def __init__(self):
        """Create the four per-category breakers with their thresholds."""
        self.breakers = {
            "timeout": CircuitBreaker("timeout", CircuitBreakerConfig(
                failure_threshold=3,
                timeout=30.0
            )),
            "rate_limit": CircuitBreaker("rate_limit", CircuitBreakerConfig(
                failure_threshold=5,
                timeout=60.0
            )),
            "auth_error": CircuitBreaker("auth_error", CircuitBreakerConfig(
                failure_threshold=2,
                timeout=300.0  # auth errors stay tripped for longer
            )),
            "server_error": CircuitBreaker("server_error", CircuitBreakerConfig(
                failure_threshold=5,
                timeout=60.0
            ))
        }

    def classify_error(self, error: Exception) -> str:
        """Map ``error`` to one of the category keys.

        ``TimeoutError`` instances are timeouts regardless of their message.
        Otherwise the lower-cased message is searched for marker substrings,
        checked in this order: timeout, rate limit, auth. Anything
        unrecognized is treated as ``"server_error"``.
        """
        # Built-in timeouts often carry the message "timed out" or none at
        # all, so the type is checked before the text.
        if isinstance(error, TimeoutError):
            return "timeout"

        error_str = str(error).lower()

        if "timeout" in error_str or "timed out" in error_str:
            return "timeout"
        if "rate limit" in error_str or "429" in error_str:
            return "rate_limit"
        if "unauthorized" in error_str or "401" in error_str:
            return "auth_error"
        # Explicit server errors and everything unclassified share a bucket.
        return "server_error"

    def can_execute(self, error_type: str = None) -> bool:
        """Return whether calls are allowed.

        With ``error_type``, only that category's breaker is consulted
        (``KeyError`` if the category is unknown). Without it, breakers are
        consulted in insertion order and the first refusal ends the check.
        """
        if error_type:
            return self.breakers[error_type].can_execute()

        return all(b.can_execute() for b in self.breakers.values())

    def record_error(self, error: Exception):
        """Classify ``error`` and record a failure on the matching breaker."""
        error_type = self.classify_error(error)
        self.breakers[error_type].record_failure()

    def record_success(self):
        """Record a success on every breaker."""
        for breaker in self.breakers.values():
            breaker.record_success()

    def get_status(self) -> dict:
        """Return ``{category: breaker.get_status()}`` for all breakers."""
        return {
            name: breaker.get_status()
            for name, breaker in self.breakers.items()
        }

熔断器监控

class CircuitBreakerMonitor:
    """Registry of circuit breakers that derives alerts and aggregate metrics.

    ``breakers`` maps breaker name to breaker. ``alerts`` starts as an empty
    list and is not modified by any method of this class.
    """

    def __init__(self):
        self.breakers = {}
        self.alerts = []

    def register(self, breaker: CircuitBreaker):
        """Track ``breaker`` under its ``name``, replacing any earlier entry with that name."""
        self.breakers[breaker.name] = breaker

    def check_all(self) -> list:
        """Return one alert dict per breaker that is currently open or half-open.

        Open breakers produce a ``"critical"`` alert, half-open ones a
        ``"warning"``; all other breakers produce nothing. Each alert carries
        the breaker's full status under ``"details"``.
        """
        found = []

        for breaker_name, breaker in self.breakers.items():
            snapshot = breaker.get_status()
            current = snapshot["state"]

            if current == "open":
                level = "critical"
                text = f"熔断器 {breaker_name} 处于打开状态"
            elif current == "half_open":
                level = "warning"
                text = f"熔断器 {breaker_name} 正在测试恢复"
            else:
                # Nothing to report for this breaker.
                continue

            found.append({
                "severity": level,
                "breaker": breaker_name,
                "message": text,
                "details": snapshot
            })

        return found

    def get_metrics(self) -> dict:
        """Return the summed failure count, the number of open breakers, and each breaker's status."""
        tracked = self.breakers

        failures = 0
        for breaker in tracked.values():
            failures += breaker.failure_count

        tripped = [
            breaker for breaker in tracked.values()
            if breaker.state == CircuitState.OPEN
        ]

        per_breaker = {}
        for breaker_name, breaker in tracked.items():
            per_breaker[breaker_name] = breaker.get_status()

        return {
            "total_failures": failures,
            "open_breakers": len(tripped),
            "breakers": per_breaker
        }

最佳实践

  1. 合理设置阈值:根据实际业务需求调整失败阈值和恢复时间
  2. 实现优雅降级:熔断时提供有用的降级响应
  3. 监控熔断状态:实时监控并告警
  4. 定期测试:定期验证熔断器配置是否合理
  5. 记录日志:记录所有状态变化便于问题排查

熔断器是构建高可用LLM应用的重要组件,能有效防止级联故障,提高系统稳定性。