LLM超时模式
---
title: "LLM超时模式"
description: "详解LLM应用中的超时处理策略，包括连接超时、读取超时、超时传播等最佳实践"
tags: ["超时", "错误处理", "可靠性"]
category: "llm"
icon: "🧠"
---
LLM超时模式
为什么需要超时处理
LLM API调用可能因网络问题、服务过载或其他原因而长时间无响应。如果不设置超时（例如 requests 默认不设置任何超时），应用程序可能会无限期等待，逐步耗尽连接池和线程资源。
超时类型
连接超时
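连接超时限制建立TCP连接所用的时间。下面的示例通过 `timeout=(连接超时, 读取超时)` 元组同时设置连接超时和读取超时。注意：requests 的读取超时限制的是两次收到数据之间的最长间隔，而不是整个请求的总耗时。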
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
class LLMClient:
    """Synchronous client for a ``/chat/completions`` HTTP endpoint.

    One ``requests.Session`` is reused for every call so connections are
    pooled, and transient failures are retried by an urllib3 ``Retry`` policy
    mounted on that session.
    """

    def __init__(self, base_url: str, api_key: str, timeout: tuple = (5, 60)):
        """Create the client and its pooled session.

        Args:
            base_url: API root that ``/chat/completions`` is appended to
                (a trailing ``/`` is tolerated).
            api_key: token sent as ``Authorization: Bearer <api_key>``.
            timeout: ``(connect_timeout, read_timeout)`` in seconds, passed
                straight to ``requests``. The read timeout limits the wait
                between received bytes, not the total request duration.
        """
        self.base_url = base_url
        self.api_key = api_key
        self.timeout = timeout
        self.session = self._create_session()

    def _create_session(self) -> requests.Session:
        """Build a session whose adapters retry 429/5xx responses with backoff."""
        session = requests.Session()
        retry_strategy = Retry(
            total=3,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
            # urllib3 does not retry POST by default because it is not
            # idempotent. Chat completions are POSTs, so without this the
            # status-based retries configured above would never fire.
            allowed_methods=Retry.DEFAULT_ALLOWED_METHODS | {"POST"},
            # Do not re-send a POST whose response failed mid-read: the
            # server may already be processing that completion.
            read=False,
            # When retries are exhausted, return the last response so that
            # raise_for_status() in chat() reports the real HTTP status.
            raise_on_status=False,
        )
        adapter = HTTPAdapter(
            max_retries=retry_strategy,
            pool_connections=10,
            pool_maxsize=10,
        )
        session.mount("http://", adapter)
        session.mount("https://", adapter)
        return session

    def chat(self, messages: list, **kwargs) -> dict:
        """POST a chat-completion request and return the decoded JSON body.

        Args:
            messages: chat messages, sent as the ``messages`` field.
            **kwargs: extra top-level request fields (e.g. ``model``), merged
                into the JSON body.

        Raises:
            requests.HTTPError: when the final response (after retries) has a
                4xx/5xx status.
            requests.RequestException: on connection or timeout failures.
        """
        response = self.session.post(
            f"{self.base_url.rstrip('/')}/chat/completions",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={"messages": messages, **kwargs},
            timeout=self.timeout,  # (connect, read) seconds
        )
        # Fail loudly on error statuses instead of parsing an error page as a result.
        response.raise_for_status()
        return response.json()
分阶段超时
针对LLM调用的不同阶段（建立连接、等待响应、包括重试在内的整体调用）分别设置超时时间，并用总超时约束所有重试及退避等待的累计耗时。
import time
from dataclasses import dataclass
from typing import Callable, Any
@dataclass
class TimeoutConfig:
    """Timeout and retry budget for one logical LLM call (all times in seconds)."""

    connect_timeout: float = 5.0  # connection limit (not read by call_with_timeout)
    read_timeout: float = 120.0  # per-attempt limit; LLM generation can be slow
    total_timeout: float = 180.0  # budget across all attempts, including backoff sleeps
    retry_count: int = 3  # total number of attempts (not extra retries)
    retry_delay: float = 1.0  # base backoff delay, doubled after each failed attempt


class TimeoutLLMClient:
    """Run a timeout-aware callable with retries inside a total time budget."""

    def __init__(self, config: TimeoutConfig = None):
        """Use ``config`` if given, otherwise a default ``TimeoutConfig()``."""
        self.config = config or TimeoutConfig()

    def call_with_timeout(self, func: Callable, *args, **kwargs) -> Any:
        """Call ``func(*args, timeout=..., **kwargs)`` with retries under a total budget.

        Each attempt receives ``timeout=min(remaining_budget, read_timeout)``.
        Any exception from ``func`` triggers another attempt after an
        exponential backoff of ``retry_delay * 2**attempt`` seconds. The sleep
        is clamped so it never runs past ``total_timeout``.

        Returns:
            Whatever ``func`` returns on the first successful attempt.

        Raises:
            TimeoutError: when the total budget is exhausted before an attempt,
                or after ``retry_count`` failed attempts. The last exception
                raised by ``func`` is attached as ``__cause__``.
        """
        cfg = self.config
        # monotonic() is immune to wall-clock adjustments during the call.
        start_time = time.monotonic()
        last_exception = None

        for attempt in range(cfg.retry_count):
            # Checked outside the try so this error is not swallowed and retried.
            remaining = cfg.total_timeout - (time.monotonic() - start_time)
            if remaining <= 0:
                raise TimeoutError(
                    f"总超时时间 {cfg.total_timeout}秒已用尽"
                ) from last_exception

            call_timeout = min(remaining, cfg.read_timeout)
            try:
                return func(*args, timeout=call_timeout, **kwargs)
            except Exception as e:
                last_exception = e

            if attempt < cfg.retry_count - 1:
                delay = cfg.retry_delay * (2 ** attempt)
                remaining = cfg.total_timeout - (time.monotonic() - start_time)
                # Never sleep beyond the overall deadline.
                time.sleep(max(0.0, min(delay, remaining)))

        raise TimeoutError(
            f"调用在 {cfg.retry_count} 次重试后仍然失败: {last_exception}"
        ) from last_exception
超时传播
从上游传递超时
from contextvars import ContextVar
from typing import Optional
# Context-local timeout budget in seconds (each thread / asyncio task sees its
# own value); None means "no budget was propagated".
timeout_context: ContextVar[Optional[float]] = ContextVar(
    'timeout_context', default=None
)


class TimeoutPropagation:
    """Read and write the timeout budget carried in ``timeout_context``."""

    def __init__(self, default_timeout: float = 60.0):
        # Budget used when nothing has been propagated into the current context.
        self.default_timeout = default_timeout

    def set_timeout(self, timeout: float):
        """Store ``timeout`` (seconds) as the budget for the current context.

        Returns:
            The ``contextvars.Token`` from ``ContextVar.set``; pass it to
            ``timeout_context.reset(token)`` to restore the previous value.
        """
        return timeout_context.set(timeout)

    def get_timeout(self) -> float:
        """Return the propagated budget, or ``default_timeout`` if none was set."""
        timeout = timeout_context.get()
        # Compare with None explicitly: a propagated budget of 0 means "already
        # expired" and must not silently fall back to the default.
        return self.default_timeout if timeout is None else timeout

    def remaining_time(self, start_time: float) -> float:
        """Return the seconds left in the budget, never negative.

        Args:
            start_time: when the budget started, as a ``time.time()`` value.
        """
        elapsed = time.time() - start_time
        return max(0, self.get_timeout() - elapsed)
# Usage example
def call_downstream_service(start_time: float,
                            propagation: TimeoutPropagation,
                            connect_timeout: float = 2.0) -> dict:
    """Call the LLM provider using whatever is left of the propagated budget.

    Args:
        start_time: when the upstream budget started (``time.time()`` value).
        propagation: source of the propagated timeout budget.
        connect_timeout: preferred connect timeout in seconds. It is capped at
            the remaining budget.

    Raises:
        TimeoutError: if the budget is already used up before the call.
        requests.HTTPError: on a 4xx/5xx response.

    Note: the requests read timeout bounds the gap between received bytes,
    so a response that trickles in can still take longer than ``remaining``.
    """
    remaining = propagation.remaining_time(start_time)
    if remaining <= 0:
        raise TimeoutError("上游超时已传递到下游")

    response = requests.post(
        "https://api.llm-provider.com/chat",
        # Neither the connect phase nor the read phase may exceed the budget.
        timeout=(min(connect_timeout, remaining), remaining),
        json={"prompt": "Hello"},
    )
    response.raise_for_status()
    return response.json()
优雅超时处理
超时降级
from enum import Enum
from dataclasses import dataclass
from typing import Any, Callable
class TimeoutStrategy(Enum):
    """How ``GracefulTimeoutManager`` reacts when an operation raises ``TimeoutError``."""

    FAIL_FAST = "fail_fast"  # re-raise the TimeoutError immediately
    RETRY_WITH_BACKOFF = "retry"  # call again with exponential backoff
    PARTIAL_RESPONSE = "partial"  # return a fallback / "incomplete" response
    CACHED_RESPONSE = "cached"  # return the last successful result for the same call
    QUEUE_FOR_LATER = "queue"  # return a "queued" acknowledgement (no real queue here)


@dataclass
class TimeoutHandler:
    """Timeout policy registered for one named operation."""

    strategy: TimeoutStrategy
    # Returned by PARTIAL_RESPONSE when it is not None.
    fallback_value: Any = None
    # Number of additional attempts made by RETRY_WITH_BACKOFF.
    max_retries: int = 3


class GracefulTimeoutManager:
    """Run callables and apply a per-operation degradation policy on timeout.

    Only ``TimeoutError`` (and subclasses) trigger a policy; any other
    exception propagates unchanged. Operations without a registered handler
    use FAIL_FAST.
    NOTE: ``requests.Timeout`` is not a ``TimeoutError`` subclass; convert it
    first if the wrapped call uses requests.
    """

    def __init__(self):
        self.handlers = {}  # operation name -> TimeoutHandler
        # cache key -> last successful result. Filled only for operations using
        # CACHED_RESPONSE. Unbounded: entries are never evicted.
        self.cache = {}

    def register_handler(self, operation: str, handler: TimeoutHandler):
        """Register (or replace) the timeout policy for ``operation``."""
        self.handlers[operation] = handler

    @staticmethod
    def _cache_key(operation: str, args, kwargs) -> str:
        """Build the CACHED_RESPONSE key from the operation name and call arguments."""
        # Positional-only calls keep the "<operation>:<hash(args)>" format;
        # keyword arguments are folded in so calls that differ only in kwargs
        # do not share an entry.
        key = f"{operation}:{hash(str(args))}"
        if kwargs:
            key += f":{hash(str(sorted(kwargs.items())))}"
        return key

    def execute_with_timeout(self, operation: str, func: Callable,
                             *args, **kwargs) -> Any:
        """Call ``func(*args, **kwargs)``; on ``TimeoutError`` apply the operation's policy."""
        handler = self.handlers.get(operation, TimeoutHandler(
            strategy=TimeoutStrategy.FAIL_FAST
        ))
        try:
            result = func(*args, **kwargs)
        except TimeoutError as e:
            return self._handle_timeout(operation, handler, func, args, kwargs, e)
        if handler.strategy == TimeoutStrategy.CACHED_RESPONSE:
            # Remember successes so a later timeout has something to fall back on.
            self.cache[self._cache_key(operation, args, kwargs)] = result
        return result

    def _handle_timeout(self, operation: str, handler: TimeoutHandler,
                        func: Callable, args, kwargs,
                        original_error: TimeoutError) -> Any:
        """Apply ``handler.strategy`` after ``func`` raised ``original_error``."""
        if handler.strategy == TimeoutStrategy.FAIL_FAST:
            raise original_error

        elif handler.strategy == TimeoutStrategy.RETRY_WITH_BACKOFF:
            import time
            for attempt in range(handler.max_retries):
                try:
                    # Sleeps 0.5s, 1s, 2s, ... before each retry.
                    delay = (2 ** attempt) * 0.5
                    time.sleep(delay)
                    return func(*args, **kwargs)
                except TimeoutError:
                    if attempt == handler.max_retries - 1:
                        raise
            # Only reached when max_retries <= 0.
            raise original_error

        elif handler.strategy == TimeoutStrategy.PARTIAL_RESPONSE:
            if handler.fallback_value is not None:
                return handler.fallback_value
            return {"partial": True, "message": "响应不完整", "fallback": True}

        elif handler.strategy == TimeoutStrategy.CACHED_RESPONSE:
            cache_key = self._cache_key(operation, args, kwargs)
            if cache_key in self.cache:
                return self.cache[cache_key]
            raise original_error

        elif handler.strategy == TimeoutStrategy.QUEUE_FOR_LATER:
            return {"queued": True, "message": "请求已加入队列稍后处理"}

        raise original_error
流式响应超时
import asyncio
from typing import AsyncIterator
class StreamingTimeoutManager:
    """Enforce per-chunk and total deadlines on an async stream of chunks.

    Timeouts are reported in-band: instead of raising, the wrapper yields one
    final ``{"error": ..., "partial_response": ...}`` dict and stops.
    """

    def __init__(self, chunk_timeout: float = 10.0,
                 total_timeout: float = 120.0):
        self.chunk_timeout = chunk_timeout  # max wait (s) for each next chunk
        self.total_timeout = total_timeout  # max time (s) for the whole stream

    async def stream_with_timeout(self, stream_func,
                                  *args, **kwargs) -> AsyncIterator:
        """Re-yield chunks from ``stream_func(*args, **kwargs)`` under both deadlines.

        Args:
            stream_func: callable returning an async iterable of dict chunks.
                Each chunk's ``"content"`` (default ``""``) is accumulated.

        Yields:
            The original chunks. On timeout, one final dict follows whose
            ``"error"`` is ``"stream_timeout"`` (no chunk within
            ``chunk_timeout``) or ``"total_timeout"`` (``total_timeout``
            exhausted), and whose ``"partial_response"`` holds the text
            accumulated so far.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self.total_timeout
        accumulated_response = []
        stream = stream_func(*args, **kwargs).__aiter__()
        try:
            while True:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    yield {
                        "error": "total_timeout",
                        "partial_response": "".join(accumulated_response)
                    }
                    return
                try:
                    # Bound each wait by the chunk timeout AND the remaining
                    # total budget, so a stalled stream cannot block past either.
                    chunk = await asyncio.wait_for(
                        stream.__anext__(),
                        timeout=min(self.chunk_timeout, remaining),
                    )
                except StopAsyncIteration:
                    return
                except asyncio.TimeoutError:
                    # If the total budget was the tighter bound, that is what expired.
                    if remaining <= self.chunk_timeout:
                        error = "total_timeout"
                    else:
                        error = "stream_timeout"
                    yield {
                        "error": error,
                        "partial_response": "".join(accumulated_response)
                    }
                    return
                accumulated_response.append(chunk.get("content", ""))
                yield chunk
        finally:
            # Release the upstream stream (and whatever it holds) on every exit path.
            aclose = getattr(stream, "aclose", None)
            if aclose is not None:
                await aclose()
# Usage example
async def main():
    """Demo: stream ten mock chunks through ``StreamingTimeoutManager``.

    Content chunks are printed inline. A timeout marker, if one arrives, is
    printed with its partial response.
    """
    async def fake_stream(*args, **kwargs):
        # Emits "chunk 0 " ... "chunk 9 ", one every 0.5 s.
        for index in range(10):
            await asyncio.sleep(0.5)
            yield {"content": f"chunk {index} "}

    streamer = StreamingTimeoutManager(total_timeout=60.0, chunk_timeout=5.0)

    async for piece in streamer.stream_with_timeout(fake_stream):
        if "error" not in piece:
            print(piece.get("content", ""), end="", flush=True)
            continue
        print(f"超时: {piece['error']}, 部分响应: {piece['partial_response']}")
超时监控和指标
import time
from collections import defaultdict
class TimeoutMonitor:
    """In-memory per-operation call statistics.

    ``total_calls``, ``timeout_calls`` and ``timeout_rate`` cover every call
    ever recorded. ``avg_response_time`` covers only the most recent
    ``max_samples`` calls.
    """

    def __init__(self, max_samples: int = 1000):
        """
        Args:
            max_samples: how many recent response times to keep per operation.

        Raises:
            ValueError: if ``max_samples`` is less than 1.
        """
        if max_samples < 1:
            raise ValueError("max_samples must be at least 1")
        self.max_samples = max_samples
        self.metrics = defaultdict(lambda: {
            "total_calls": 0,
            "timeout_calls": 0,
            "avg_response_time": 0,
            "timeout_rate": 0
        })
        self.response_times = defaultdict(list)

    def record_call(self, operation: str, response_time: float,
                    is_timeout: bool):
        """Record one call of ``operation`` and refresh its derived metrics."""
        stats = self.metrics[operation]
        stats["total_calls"] += 1
        if is_timeout:
            stats["timeout_calls"] += 1

        times = self.response_times[operation]
        times.append(response_time)
        # Keep exactly the most recent max_samples records (trimmed in place).
        if len(times) > self.max_samples:
            del times[:-self.max_samples]

        stats["avg_response_time"] = sum(times) / len(times)
        stats["timeout_rate"] = stats["timeout_calls"] / stats["total_calls"]

    def get_operation_stats(self, operation: str) -> dict:
        """Return the live metrics dict for ``operation``, or ``{}`` if never recorded."""
        return self.metrics.get(operation, {})

    def get_all_stats(self) -> dict:
        """Return a plain dict mapping each recorded operation to its metrics."""
        return dict(self.metrics)

    def should_alert(self, operation: str,
                     timeout_rate_threshold: float = 0.1) -> bool:
        """True when the lifetime timeout rate of ``operation`` exceeds the threshold."""
        stats = self.metrics.get(operation, {})
        return stats.get("timeout_rate", 0) > timeout_rate_threshold
最佳实践
- 设置合理的超时时间：根据LLM服务的响应特性（如输出长度、是否流式返回）设置超时
- 实现分层超时：为连接、读取和整体调用（含重试）分别设置超时
- 支持超时传播：将剩余的超时预算从上游传递到下游
- 优雅降级：超时时返回有用的降级响应（如缓存结果、部分响应或排队提示）
- 监控超时率：实时统计超时率，并在超过阈值时告警
- 记录超时日志：便于问题排查和性能优化
正确的超时处理是构建可靠LLM应用的关键,能有效防止资源耗尽和级联故障。