← 返回首页
🧠

重试策略:LLM API调用的智能重试机制

📂 llm ⏱ 3 min 563 words

--- title: "重试策略:LLM API调用的智能重试机制" description: "详解LLM API调用中的各种重试策略,包括指数退避、自适应重试、条件重试等,提升应用可用性" tags: ["重试策略", "指数退避", "容错机制", "LLM API"] category: "llm" icon: "🧠"

重试策略:LLM API调用的智能重试机制

为什么需要智能重试

LLM API调用可能因多种临时性原因失败:网络波动、服务端过载、速率限制等。合理的重试策略能在不影响用户体验的前提下,显著提升调用成功率。

基础重试模式

简单重试

import time

def simple_retry(func, max_retries=3):
    for attempt in range(max_retries):
        try:
            return func()
        except Exception as e:
            if attempt == max_retries - 1:
                raise
            print(f"第{attempt+1}次尝试失败,重试中...")
            time.sleep(1)

简单重试的问题:固定间隔可能导致大量请求同时重试(雷群效应)。

指数退避

import random
import time

def exponential_backoff(func, max_retries=5, base_delay=1, max_delay=60):
    for attempt in range(max_retries):
        try:
            return func()
        except Exception as e:
            if attempt == max_retries - 1:
                raise

            # 指数退避 + 随机抖动
            delay = min(base_delay * (2 ** attempt), max_delay)
            jitter = delay * 0.5 * random.random()
            wait_time = delay + jitter

            print(f"等待 {wait_time:.1f}s 后重试...")
            time.sleep(wait_time)

使用tenacity库的高级重试

from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
    before_sleep_log,
)
import logging

logger = logging.getLogger(__name__)

@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=2, max=60),
    retry=retry_if_exception_type((
        ConnectionError,
        TimeoutError,
        500, 502, 503, 504,
    )),
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def call_llm_with_retry(messages: list[dict]) -> str:
    response = await client.chat.completions.create(
        model="gpt-4o",
        messages=messages,
    )
    return response.choices[0].message.content

条件重试策略

并非所有错误都应该重试。不同错误类型需要不同的处理:

from openai import (
    RateLimitError,
    AuthenticationError,
    BadRequestError,
    APIConnectionError,
)

class ConditionalRetryStrategy:
    def __init__(self):
        self.retryable_errors = {
            APIConnectionError: {"max_retries": 5, "backoff": "exponential"},
            RateLimitError: {"max_retries": 3, "backoff": "fixed", "fixed_delay": 60},
            500: {"max_retries": 3, "backoff": "exponential"},
            502: {"max_retries": 2, "backoff": "exponential"},
            503: {"max_retries": 3, "backoff": "exponential"},
        }
        self.non_retryable_errors = {
            AuthenticationError: "请检查API密钥",
            BadRequestError: "请求参数错误,不可重试",
        }

    def should_retry(self, error, attempt):
        error_type = type(error)
        error_code = getattr(error, "status_code", None)

        # 不可重试的错误
        if error_type in self.non_retryable_errors:
            return False, self.non_retryable_errors[error_type]

        # 可重试的错误
        for retryable_type, config in self.retryable_errors.items():
            if error_type == retryable_type or error_code == retryable_type:
                if attempt < config["max_retries"]:
                    return True, f"第{attempt+1}次重试"
                return False, f"重试次数已达上限({config['max_retries']})"

        return False, "未知错误类型,不重试"

流式输出的重试

流式输出的重试需要特殊处理:

async def stream_with_retry(messages, max_retries=3):
    for attempt in range(max_retries):
        collected_content = []
        current_pos = 0

        try:
            stream = await client.chat.completions.create(
                model="gpt-4o",
                messages=messages,
                stream=True,
            )

            async for chunk in stream:
                if chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    collected_content.append(content)
                    yield content
                    current_pos += len(content)

            return  # 成功完成

        except Exception as e:
            if isinstance(e, (RateLimitError, APIConnectionError)):
                logger.warning(f"流式输出中断,重试中... (已接收{current_pos}字符)")
                # 可以选择是否丢弃已接收的内容
                yield f"\n[重试中...已接收{current_pos}字符]\n"
                await asyncio.sleep(2 ** attempt)
                continue
            raise

    raise MaxRetriesExceeded("流式输出重试次数耗尽")

批量请求的重试

处理批量请求时,部分成功部分失败的处理:

from dataclasses import dataclass
from typing import Optional

@dataclass
class BatchResult:
    index: int
    success: bool
    content: Optional[str] = None
    error: Optional[str] = None

async def batch_with_retry(prompts: list[str]) -> list[BatchResult]:
    results = []
    remaining = [(i, p) for i, p in enumerate(prompts)]

    for attempt in range(3):
        if not remaining:
            break

        tasks = []
        for idx, prompt in remaining:
            tasks.append((idx, _single_call(prompt)))

        failed = []
        for idx, task in tasks:
            try:
                content = await task
                results.append(BatchResult(idx, True, content))
            except Exception as e:
                failed.append((idx, prompts[idx]))

        remaining = failed
        if remaining:
            await asyncio.sleep(2 ** attempt)

    # 处理最终失败的
    for idx, _ in remaining:
        results.append(BatchResult(idx, False, error="重试失败"))

    return sorted(results, key=lambda r: r.index)

自适应重试策略

根据历史成功率动态调整重试参数:

class AdaptiveRetryStrategy:
    def __init__(self):
        self.error_history = []
        self.base_delay = 1

    def record_error(self, error_type: str, latency: float):
        self.error_history.append({
            "type": error_type,
            "latency": latency,
            "timestamp": time.time(),
        })
        # 保留最近100条记录
        self.error_history = self.error_history[-100:]

    def get_retry_delay(self, error_type: str, attempt: int) -> float:
        recent_errors = [
            e for e in self.error_history
            if e["type"] == error_type
            and time.time() - e["timestamp"] < 300
        ]
        error_rate = len(recent_errors) / 100

        # 错误率高时,增加等待时间
        multiplier = 1 + error_rate * 3
        delay = self.base_delay * (2 ** attempt) * multiplier
        jitter = delay * 0.3 * random.random()

        return min(delay + jitter, 120)

重试监控

class RetryMonitor:
    def __init__(self):
        self.metrics = {
            "total_calls": 0,
            "successful_first_try": 0,
            "retries_used": 0,
            "final_failures": 0,
        }

    def record_success(self, attempts: int):
        self.metrics["total_calls"] += 1
        if attempts == 1:
            self.metrics["successful_first_try"] += 1
        else:
            self.metrics["retries_used"] += attempts - 1

    def record_failure(self, attempts: int):
        self.metrics["total_calls"] += 1
        self.metrics["final_failures"] += 1

    def get_stats(self):
        total = self.metrics["total_calls"]
        if total == 0:
            return {}
        return {
            "first_try_success_rate": self.metrics["successful_first_try"] / total,
            "avg_retries": self.metrics["retries_used"] / total,
            "failure_rate": self.metrics["final_failures"] / total,
        }

总结

智能重试策略的核心是:区分可重试与不可重试错误、使用指数退避避免雷群效应、为流式和批量请求定制重试逻辑、通过监控持续优化重试参数。