← 返回首页
🧠

LLM超时模式

📂 llm ⏱ 4 min 723 words

---
title: "LLM超时模式"
description: "详解LLM应用中的超时处理策略,包括连接超时、读取超时、超时传播等最佳实践"
tags: ["超时", "错误处理", "可靠性"]
category: "llm"
icon: "🧠"
---

LLM超时模式

为什么需要超时处理

LLM API调用可能因网络问题、服务过载或其他原因而长时间无响应。没有超时机制,应用程序可能会无限期等待,耗尽连接池和线程资源。

超时类型

连接超时

建立TCP连接的时间限制。下面的示例通过 timeout=(5, 60) 同时设置了连接超时(5秒)和读取超时(60秒,即等待服务器返回数据的时间限制)。

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

class LLMClient:
    """Small chat-completions client with explicit timeouts and retries.

    A single ``requests.Session`` is created per client and reused for every
    call, so the retry policy and connection pool configured in
    ``_create_session`` apply to all requests made through ``self.session``.
    """

    def __init__(self, base_url: str, api_key: str):
        """Store the endpoint settings and build the shared session.

        Args:
            base_url: Prefix that ``/chat/completions`` is appended to.
            api_key: Token sent as a ``Bearer`` authorization header.
        """
        self.base_url = base_url
        self.api_key = api_key
        self.session = self._create_session()

    def _create_session(self) -> requests.Session:
        """Return a session whose HTTP(S) adapters retry transient failures."""
        session = requests.Session()

        # Retry policy: up to 3 retries with exponential backoff on
        # rate-limit / server-error status codes.
        retry_strategy = Retry(
            total=3,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
            # urllib3 only retries idempotent methods by default, so without
            # this the policy would never apply to the POST issued by chat().
            allowed_methods=Retry.DEFAULT_ALLOWED_METHODS | {"POST"},
            # Hand back the last response once retries are exhausted instead
            # of raising, so chat() keeps returning the parsed body.
            raise_on_status=False,
        )

        adapter = HTTPAdapter(
            max_retries=retry_strategy,
            pool_connections=10,
            pool_maxsize=10,
        )

        session.mount("http://", adapter)
        session.mount("https://", adapter)

        return session

    def chat(self, messages: list, **kwargs) -> dict:
        """POST ``messages`` to ``/chat/completions`` and return the JSON body.

        Args:
            messages: Chat messages, sent under the ``"messages"`` key.
            **kwargs: Extra fields merged into the JSON payload as-is.

        Returns:
            The decoded JSON response. The HTTP status is not inspected here,
            so an error payload is returned rather than raised.
        """
        # (connect timeout, read timeout) in seconds.
        timeout = (5, 60)

        response = self.session.post(
            f"{self.base_url}/chat/completions",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={"messages": messages, **kwargs},
            timeout=timeout,
        )

        return response.json()

分阶段超时

针对LLM调用的不同阶段设置不同的超时时间。

import time
from dataclasses import dataclass
from typing import Callable, Any

@dataclass
class TimeoutConfig:
    """Timeout and retry settings for one logical LLM call."""

    # Limit for establishing the connection.
    connect_timeout: float = 5.0
    # Limit for reading the response (LLM replies can take a long time).
    read_timeout: float = 120.0
    # Overall budget across all attempts.
    total_timeout: float = 180.0
    # Number of attempts.
    retry_count: int = 3
    # Base delay between attempts.
    retry_delay: float = 1.0

class TimeoutLLMClient:
    """Runs a callable under a per-attempt timeout, a total budget and retries."""

    def __init__(self, config: "TimeoutConfig | None" = None):
        """Use ``config`` if given, otherwise a default ``TimeoutConfig()``."""
        self.config = config or TimeoutConfig()

    def call_with_timeout(self, func: Callable, *args, **kwargs) -> Any:
        """Call ``func`` with a ``timeout`` keyword, retrying on failure.

        Each attempt receives ``timeout=min(remaining budget, read_timeout)``.
        Failed attempts are followed by an exponentially growing sleep
        (``retry_delay * 2 ** attempt``) that is capped to the remaining
        budget so the overall deadline is not overrun by waiting.

        Args:
            func: Callable that accepts a ``timeout`` keyword argument.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            Whatever ``func`` returns on its first successful attempt.

        Raises:
            TimeoutError: If the total budget runs out, or every attempt
                failed. The last underlying exception is attached as
                ``__cause__``.
        """
        config = self.config
        start_time = time.time()
        last_exception = None

        for attempt in range(config.retry_count):
            remaining = config.total_timeout - (time.time() - start_time)
            if remaining <= 0:
                # Checked outside the try block on purpose: once the budget
                # is gone another sleep/retry cycle cannot succeed.
                raise TimeoutError(
                    f"总超时时间 {config.total_timeout}秒已用尽"
                ) from last_exception

            try:
                return func(
                    *args,
                    timeout=min(remaining, config.read_timeout),
                    **kwargs,
                )
            except Exception as e:
                # Any failure of func counts as a retryable attempt.
                last_exception = e

            if attempt < config.retry_count - 1:
                delay = config.retry_delay * (2 ** attempt)
                remaining = config.total_timeout - (time.time() - start_time)
                # Never sleep past the overall deadline.
                time.sleep(max(0.0, min(delay, remaining)))

        raise TimeoutError(
            f"调用在 {config.retry_count} 次重试后仍然失败: {last_exception}"
        ) from last_exception

超时传播

从上游传递超时

from contextvars import ContextVar
from typing import Optional

# Context-local carrier for the timeout handed down by the caller.
timeout_context: ContextVar[Optional[float]] = ContextVar(
    'timeout_context', default=None
)


class TimeoutPropagation:
    """Reads and writes the propagated timeout stored in ``timeout_context``."""

    def __init__(self, default_timeout: float = 60.0):
        """Remember the timeout to use when none has been propagated."""
        self.default_timeout = default_timeout

    def set_timeout(self, timeout: float):
        """Store ``timeout`` in the current context."""
        timeout_context.set(timeout)

    def get_timeout(self) -> float:
        """Return the propagated timeout, or the default if none was set.

        Only ``None`` means "not set": a propagated value of ``0`` is an
        already-exhausted budget and must not be replaced by the default.
        """
        timeout = timeout_context.get()
        return self.default_timeout if timeout is None else timeout

    def remaining_time(self, start_time: float) -> float:
        """Return the seconds left since ``start_time``, never below zero.

        Args:
            start_time: A ``time.time()`` timestamp taken when the budget
                started counting.
        """
        timeout = self.get_timeout()
        elapsed = time.time() - start_time
        return max(0, timeout - elapsed)

# Usage example
def call_downstream_service(start_time: float,
                            propagation: TimeoutPropagation) -> dict:
    """Call the downstream LLM endpoint within the remaining upstream budget.

    Args:
        start_time: ``time.time()`` timestamp at which the budget started.
        propagation: Source of the remaining time for this request.

    Returns:
        The decoded JSON body of the downstream response.

    Raises:
        TimeoutError: If no time is left before the call is made.
    """
    remaining = propagation.remaining_time(start_time)

    if remaining <= 0:
        raise TimeoutError("上游超时已传递到下游")

    # Connect within 2 seconds at most, but never longer than the time that
    # is left; the read phase gets the remaining time.
    response = requests.post(
        "https://api.llm-provider.com/chat",
        timeout=(min(2, remaining), remaining),
        json={"prompt": "Hello"}
    )

    return response.json()

优雅超时处理

超时降级

from enum import Enum
from dataclasses import dataclass
from typing import Any, Callable

class TimeoutStrategy(Enum):
    """Ways an operation can react once a call has timed out."""

    # Fail immediately.
    FAIL_FAST = "fail_fast"
    # Retry with backoff.
    RETRY_WITH_BACKOFF = "retry"
    # Return a partial response.
    PARTIAL_RESPONSE = "partial"
    # Return a cached response.
    CACHED_RESPONSE = "cached"
    # Queue the request for later processing.
    QUEUE_FOR_LATER = "queue"

@dataclass
class TimeoutHandler:
    """Per-operation settings describing how a timeout should be handled."""

    # Strategy applied when the operation times out.
    strategy: TimeoutStrategy
    # Optional fallback value carried with the handler.
    fallback_value: Any = None
    # Upper bound on retry attempts.
    max_retries: int = 3

class GracefulTimeoutManager:
    """Runs operations and applies a per-operation strategy on timeout."""

    def __init__(self):
        # operation name -> TimeoutHandler
        self.handlers = {}
        # cache key -> last successful result, used by CACHED_RESPONSE
        self.cache = {}

    def register_handler(self, operation: str, handler: TimeoutHandler):
        """Associate ``handler`` with ``operation``, replacing any previous one."""
        self.handlers[operation] = handler

    @staticmethod
    def _cache_key(operation: str, args) -> str:
        """Build the cache key for an operation and its positional arguments."""
        # NOTE: keyword arguments are not part of the key.
        return f"{operation}:{hash(str(args))}"

    def execute_with_timeout(self, operation: str, func: Callable,
                             *args, **kwargs) -> Any:
        """Call ``func`` and fall back to the operation's strategy on timeout.

        Operations without a registered handler use ``FAIL_FAST``. When the
        handler's strategy is ``CACHED_RESPONSE``, each successful result is
        stored so that a later timeout of the same call can be served from
        the cache.

        Args:
            operation: Name used to look up the handler and the cache entry.
            func: Callable to execute.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            The result of ``func``, or the fallback chosen by the strategy.

        Raises:
            TimeoutError: If ``func`` times out and the strategy cannot
                produce a fallback.
        """
        handler = self.handlers.get(operation)
        if handler is None:
            handler = TimeoutHandler(strategy=TimeoutStrategy.FAIL_FAST)

        try:
            result = func(*args, **kwargs)
        except TimeoutError as e:
            return self._handle_timeout(operation, handler, func, args, kwargs, e)

        if handler.strategy == TimeoutStrategy.CACHED_RESPONSE:
            # Without this write the cache would stay empty and the
            # CACHED_RESPONSE fallback could never return anything.
            self.cache[self._cache_key(operation, args)] = result
        return result

    def _handle_timeout(self, operation: str, handler: TimeoutHandler,
                        func: Callable, args, kwargs,
                        original_error: TimeoutError) -> Any:
        """Apply ``handler.strategy`` after ``func`` raised ``original_error``."""
        strategy = handler.strategy

        if strategy == TimeoutStrategy.FAIL_FAST:
            raise original_error

        if strategy == TimeoutStrategy.RETRY_WITH_BACKOFF:
            import time
            for attempt in range(handler.max_retries):
                # Back off 0.5s, 1s, 2s, ... before each retry.
                time.sleep((2 ** attempt) * 0.5)
                try:
                    return func(*args, **kwargs)
                except TimeoutError:
                    if attempt == handler.max_retries - 1:
                        raise
            # Only reached when max_retries is zero or negative.
            raise original_error

        if strategy == TimeoutStrategy.PARTIAL_RESPONSE:
            return {"partial": True, "message": "响应不完整", "fallback": True}

        if strategy == TimeoutStrategy.CACHED_RESPONSE:
            cache_key = self._cache_key(operation, args)
            if cache_key in self.cache:
                return self.cache[cache_key]
            raise original_error

        if strategy == TimeoutStrategy.QUEUE_FOR_LATER:
            return {"queued": True, "message": "请求已加入队列稍后处理"}

        raise original_error

流式响应超时

import asyncio
from typing import AsyncIterator

class StreamingTimeoutManager:
    """Relays an async chunk stream while enforcing chunk and total timeouts."""

    def __init__(self, chunk_timeout: float = 10.0,
                 total_timeout: float = 120.0):
        """Store the limits, in seconds.

        Args:
            chunk_timeout: Longest wait allowed for any single chunk.
            total_timeout: Longest duration allowed for the whole stream.
        """
        self.chunk_timeout = chunk_timeout
        self.total_timeout = total_timeout

    async def stream_with_timeout(self, stream_func,
                                  *args, **kwargs) -> AsyncIterator:
        """Yield chunks from ``stream_func(*args, **kwargs)`` under both limits.

        Every chunk is awaited with ``asyncio.wait_for`` so a stalled stream
        is interrupted instead of blocking forever. When a limit is hit, one
        final dict ``{"error": ..., "partial_response": ...}`` is yielded and
        the generator ends. ``error`` is ``"stream_timeout"`` when the wait
        for a single chunk expired and ``"total_timeout"`` when the overall
        budget ran out.

        Args:
            stream_func: Callable returning an async iterator of dict chunks;
                each chunk's ``"content"`` (if any) is accumulated.
            *args: Positional arguments forwarded to ``stream_func``.
            **kwargs: Keyword arguments forwarded to ``stream_func``.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self.total_timeout
        accumulated_response = []
        stream = stream_func(*args, **kwargs).__aiter__()

        try:
            while True:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    error = "total_timeout"
                    break

                try:
                    # Wait no longer than the tighter of the two limits.
                    chunk = await asyncio.wait_for(
                        stream.__anext__(),
                        timeout=min(self.chunk_timeout, remaining),
                    )
                except StopAsyncIteration:
                    return
                except asyncio.TimeoutError:
                    # Report whichever limit was the binding one.
                    if self.chunk_timeout <= remaining:
                        error = "stream_timeout"
                    else:
                        error = "total_timeout"
                    break

                accumulated_response.append(chunk.get("content", ""))
                yield chunk

            yield {
                "error": error,
                "partial_response": "".join(accumulated_response)
            }
        finally:
            # Release the underlying stream when it supports closing.
            close = getattr(stream, "aclose", None)
            if close is not None:
                await close()

# Usage example
async def main():
    """Stream ten mock chunks through a StreamingTimeoutManager and print them."""

    async def fake_stream(*args, **kwargs):
        # Emit one chunk every half second.
        index = 0
        while index < 10:
            await asyncio.sleep(0.5)
            yield {"content": f"chunk {index} "}
            index += 1

    manager = StreamingTimeoutManager(chunk_timeout=5.0, total_timeout=60.0)

    async for event in manager.stream_with_timeout(fake_stream):
        if "error" not in event:
            print(event.get("content", ""), end="", flush=True)
            continue
        print(f"超时: {event['error']}, 部分响应: {event['partial_response']}")

超时监控和指标

import time
from collections import defaultdict

class TimeoutMonitor:
    """Collects per-operation call counts, timeout rate and response times."""

    def __init__(self, max_samples: int = 1000):
        """Create an empty monitor.

        Args:
            max_samples: Number of most recent response times kept per
                operation for the running average.

        Raises:
            ValueError: If ``max_samples`` is less than 1.
        """
        if max_samples < 1:
            raise ValueError("max_samples must be at least 1")
        self.max_samples = max_samples
        self.metrics = defaultdict(lambda: {
            "total_calls": 0,
            "timeout_calls": 0,
            "avg_response_time": 0,
            "timeout_rate": 0
        })
        self.response_times = defaultdict(list)

    def record_call(self, operation: str, response_time: float,
                    is_timeout: bool):
        """Record one call and refresh the derived metrics for ``operation``.

        Args:
            operation: Name the call is counted under.
            response_time: Duration of the call.
            is_timeout: Whether the call ended in a timeout.
        """
        stats = self.metrics[operation]
        stats["total_calls"] += 1
        if is_timeout:
            stats["timeout_calls"] += 1

        times = self.response_times[operation]
        times.append(response_time)

        # Keep exactly the most recent max_samples entries so the average
        # always covers a window of the same size.
        overflow = len(times) - self.max_samples
        if overflow > 0:
            del times[:overflow]

        stats["avg_response_time"] = sum(times) / len(times)
        # The rate is over all recorded calls, not just the sample window.
        stats["timeout_rate"] = stats["timeout_calls"] / stats["total_calls"]

    def get_operation_stats(self, operation: str) -> dict:
        """Return the metrics dict for ``operation``, or ``{}`` if unknown."""
        return self.metrics.get(operation, {})

    def get_all_stats(self) -> dict:
        """Return a plain dict mapping every recorded operation to its metrics."""
        return dict(self.metrics)

    def should_alert(self, operation: str,
                     timeout_rate_threshold: float = 0.1) -> bool:
        """Return True if the operation's timeout rate exceeds the threshold."""
        stats = self.metrics.get(operation, {})
        return stats.get("timeout_rate", 0) > timeout_rate_threshold

最佳实践

  1. 设置合理的超时时间:根据LLM服务的响应特性设置超时
  2. 实现分层超时:为不同阶段设置不同的超时
  3. 支持超时传播:将超时信息从上游传递到下游
  4. 优雅降级:超时时提供有用的降级响应
  5. 监控超时率:实时监控并告警异常超时
  6. 记录超时日志:便于问题排查和性能优化

正确的超时处理是构建可靠LLM应用的关键,能有效防止资源耗尽和级联故障。