← 返回首页
🧠

LLM回滚策略

📂 llm ⏱ 4 min 648 words

---
title: "LLM回滚策略"
description: "LLM模型的回滚策略设计与实现,包括即时(蓝绿)回滚、渐进式回滚、状态快照回滚和自动化回滚触发机制"
tags: ["回滚策略", "故障恢复", "版本管理"]
category: "llm"
icon: "🧠"
---

LLM回滚策略

概述

回滚(Rollback)是指在新部署的模型出现问题时,快速恢复到上一个稳定版本的能力。对于LLM项目,回滚策略尤为重要:模型行为存在不确定性,新版本一旦表现异常,就可能造成严重的用户体验问题或业务损失。

回滚类型

即时回滚

即时回滚适用于蓝绿部署场景,只需将流量切回上一个版本即可。下面的类负责记录部署历史并确定回滚目标,实际的流量切换由调用方根据返回的 endpoint 和 config 完成:

# rollback/instant.py
from datetime import datetime
from typing import Optional
import json

class InstantRollback:
    """Track deployments in a JSON history file and roll back by re-activating an entry.

    Intended for blue-green deployments: this class only updates the recorded
    statuses and returns the endpoint/config of the rollback target; actually
    switching traffic to that endpoint is left to the caller.

    Statuses written to history entries:
      - ``"active"``: the entry currently considered live. ``record_deployment``
        and ``rollback`` keep at most one entry in this state.
      - ``"superseded"``: replaced by a newer ``record_deployment`` call.
      - ``"rolled_back"``: taken out of service by ``rollback``.
    """

    def __init__(self, config_path: str):
        """Load any existing history from *config_path* (a JSON file)."""
        self.config_path = config_path
        self.deployment_history: list[dict] = []
        self._load_history()

    def _load_history(self) -> None:
        """Read ``{"history": [...]}`` from disk; a missing file means no history."""
        try:
            with open(self.config_path, encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            self.deployment_history = []
            return
        self.deployment_history = data.get("history", [])

    def _save_history(self) -> None:
        """Overwrite the history file with the in-memory history."""
        with open(self.config_path, "w", encoding="utf-8") as f:
            json.dump({"history": self.deployment_history}, f, indent=2)

    def _active_index(self) -> Optional[int]:
        """Return the index of the most recent ``"active"`` entry, or None."""
        # Searching from the end also copes with history files written by
        # older code that left several entries marked "active".
        return next(
            (
                idx
                for idx in range(len(self.deployment_history) - 1, -1, -1)
                if self.deployment_history[idx]["status"] == "active"
            ),
            None,
        )

    def _deactivate_all(self, new_status: str) -> None:
        """Give every currently ``"active"`` entry the status *new_status*."""
        for entry in self.deployment_history:
            if entry["status"] == "active":
                entry["status"] = new_status

    def record_deployment(self, version: str, endpoint: str,
                          config: dict):
        """Append *version* as the new active deployment and persist the history.

        Any previously active entry is marked ``"superseded"`` so that exactly
        one entry is active afterwards; otherwise ``rollback`` and
        ``get_current_version`` could not tell which deployment is live.
        """
        self._deactivate_all("superseded")
        self.deployment_history.append({
            "version": version,
            "endpoint": endpoint,
            "config": config,
            "deployed_at": datetime.now().isoformat(),
            "status": "active"
        })
        self._save_history()

    def rollback(self, target_version: Optional[str] = None) -> dict:
        """Make a previous deployment active again and persist the change.

        Args:
            target_version: Version to restore. If it was recorded more than
                once, the most recent record is used. When omitted, the nearest
                entry recorded before the active one that has not itself been
                rolled back is chosen.

        Returns:
            ``{"rollback_to": version, "endpoint": ..., "config": ...}`` of the
            restored entry, for the caller to switch traffic with.

        Raises:
            ValueError: if no suitable target exists (unknown version, no
                active entry, or nothing earlier than the active entry).
        """
        if target_version:
            target = next(
                (d for d in reversed(self.deployment_history)
                 if d["version"] == target_version),
                None,
            )
        else:
            active_idx = self._active_index()
            target = None
            if active_idx is not None:
                # Skip versions that were already rolled back: they are known bad.
                target = next(
                    (d for d in reversed(self.deployment_history[:active_idx])
                     if d["status"] != "rolled_back"),
                    None,
                )

        if target is None:
            raise ValueError("No valid rollback target found")

        # Take the current live version out of service, then activate the target.
        self._deactivate_all("rolled_back")
        target["status"] = "active"
        # Key name kept for compatibility with existing history files; it records
        # when this entry was re-activated by a rollback.
        target["rolled_back_at"] = datetime.now().isoformat()

        self._save_history()

        return {
            "rollback_to": target["version"],
            "endpoint": target["endpoint"],
            "config": target["config"],
        }

    def get_current_version(self) -> Optional[dict]:
        """Return the live (most recent ``"active"``) history entry, or None."""
        idx = self._active_index()
        return None if idx is None else self.deployment_history[idx]

渐进式回滚

对于金丝雀发布,需要逐步降低新版本(金丝雀)的流量比例,并在每一步之后进行健康检查:

# rollback/gradual.py
import time

class GradualRollback:
    """Roll a canary release back by stepping its traffic weight down to zero.

    ``router`` must provide ``set_canary_weight(weight)``, where *weight* is the
    fraction of traffic (0.0-1.0) sent to the canary.
    """

    def __init__(self, router, rollback_steps=None, step_duration_seconds=60):
        """
        Args:
            router: Object exposing ``set_canary_weight(weight)``.
            rollback_steps: Canary traffic fractions applied in order; defaults
                to ``[0.5, 0.2, 0.05, 0.0]``. If the sequence does not end at
                0, ``execute_rollback`` still finishes at 0.
            step_duration_seconds: Seconds to wait at each non-zero step before
                the health check.
        """
        self.router = router
        self.rollback_steps = (
            list(rollback_steps) if rollback_steps is not None
            else [0.5, 0.2, 0.05, 0.0]
        )
        self.step_duration_seconds = step_duration_seconds

    def execute_rollback(self) -> bool:
        """Apply each step in turn, checking health after every non-zero step.

        Returns:
            True once the canary weight has been set to 0; False if a health
            check failed, in which case the canary is left at the weight of
            the step that failed.
        """
        print("Starting gradual rollback...")
        total_steps = len(self.rollback_steps)

        for step, target_weight in enumerate(self.rollback_steps):
            print(f"Step {step + 1}/{total_steps}: "
                  f"Reducing canary to {target_weight * 100:.0f}%")

            self.router.set_canary_weight(target_weight)

            if target_weight <= 0:
                print("Rollback complete: 100% traffic to stable version")
                return True

            time.sleep(self.step_duration_seconds)

            if not self._check_rollback_health():
                print("⚠️ Health check failed during rollback, pausing")
                return False

        # The configured steps never reached 0: finish the job, so that a True
        # result always means the canary no longer receives traffic.
        self.router.set_canary_weight(0.0)
        print("Rollback complete: 100% traffic to stable version")
        return True

    def _check_rollback_health(self) -> bool:
        """Placeholder health check that always reports healthy.

        TODO: replace with a real check; during a rollback the main signal to
        watch is the error rate going down.
        """
        return True

数据库状态回滚

除了模型版本本身,部署相关的状态也需要能够回滚。下面的实现把状态(任意可JSON序列化的dict)作为快照保存到SQLite中,需要时按快照ID恢复:

# rollback/state.py
import json
import sqlite3
from contextlib import closing
from datetime import datetime
from typing import Optional

class StateRollback:
    """Persist deployment-state snapshots in SQLite so they can be restored later.

    Each row stores a version label, the state serialized as JSON, a creation
    timestamp (``datetime.now().isoformat()``) and an optional description.
    A new connection is opened for every call and always closed afterwards.
    """

    def __init__(self, db_path: str):
        """Create the snapshot table in *db_path* if it does not exist yet."""
        self.db_path = db_path
        self._init_db()

    def _connect(self):
        """Open a connection wrapped so that ``with`` closes it even on error."""
        # sqlite3's own context manager only commits/rolls back; it does not close.
        return closing(sqlite3.connect(self.db_path))

    def _init_db(self) -> None:
        with self._connect() as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS deployment_snapshots (
                    id INTEGER PRIMARY KEY,
                    version TEXT,
                    state_json TEXT,
                    created_at TEXT,
                    description TEXT
                )
            """)
            conn.commit()

    def save_snapshot(self, version: str, state: dict,
                      description: str = "") -> int:
        """Store *state* (must be JSON-serializable) and return the new snapshot id."""
        with self._connect() as conn:
            cursor = conn.execute(
                """INSERT INTO deployment_snapshots
                   (version, state_json, created_at, description)
                   VALUES (?, ?, ?, ?)""",
                (version, json.dumps(state),
                 datetime.now().isoformat(), description)
            )
            conn.commit()
            return cursor.lastrowid

    def restore_snapshot(self, snapshot_id: int) -> dict:
        """Return the state saved under *snapshot_id*.

        Raises:
            ValueError: if no snapshot with that id exists.
        """
        with self._connect() as conn:
            row = conn.execute(
                "SELECT state_json FROM deployment_snapshots WHERE id = ?",
                (snapshot_id,)
            ).fetchone()

        if not row:
            raise ValueError(f"Snapshot {snapshot_id} not found")

        return json.loads(row[0])

    def list_snapshots(self, version: Optional[str] = None) -> list:
        """List snapshot metadata, newest first, optionally filtered by *version*.

        Returns:
            A list of ``{"id", "version", "created_at", "description"}`` dicts.
        """
        query = (
            "SELECT id, version, created_at, description "
            "FROM deployment_snapshots"
        )
        params: tuple = ()
        if version:
            query += " WHERE version = ?"
            params = (version,)
        # Same ordering for both branches; id breaks ties between equal timestamps.
        query += " ORDER BY created_at DESC, id DESC"

        with self._connect() as conn:
            snapshots = conn.execute(query, params).fetchall()

        return [
            {"id": s[0], "version": s[1], "created_at": s[2], "description": s[3]}
            for s in snapshots
        ]

自动化回滚触发器

根据监控指标与阈值的比较结果自动决策:CRITICAL级别的触发器会直接调用回滚,WARNING级别仅输出告警;触发后在冷却时间(cooldown_minutes)内不会重复触发:

# rollback/auto_trigger.py
import time
from dataclasses import dataclass
from enum import Enum

class AlertSeverity(Enum):
    """How a fired trigger is handled: only CRITICAL causes an automatic rollback."""
    INFO = "info"
    WARNING = "warning"
    CRITICAL = "critical"

# Comparison applied as ``compare(metric_value, threshold)`` for each operator name.
_COMPARATORS = {
    "gt": lambda value, threshold: value > threshold,
    "ge": lambda value, threshold: value >= threshold,
    "lt": lambda value, threshold: value < threshold,
    "le": lambda value, threshold: value <= threshold,
    "eq": lambda value, threshold: value == threshold,
}

@dataclass
class RollbackTrigger:
    """A threshold rule on one metric, e.g. ``error_rate gt 0.05``."""
    metric_name: str
    threshold: float
    operator: str  # "gt", "ge", "lt", "le" or "eq"
    severity: AlertSeverity
    cooldown_minutes: int = 5

    def __post_init__(self):
        # A typo'd operator would otherwise make the trigger silently never fire.
        if self.operator not in _COMPARATORS:
            raise ValueError(
                f"Unsupported operator {self.operator!r}; "
                f"expected one of {sorted(_COMPARATORS)}"
            )

class AutoRollbackTrigger:
    """Evaluate metric snapshots against triggers and roll back on CRITICAL ones.

    ``rollback_manager`` must provide a no-argument ``rollback()`` method.
    """

    def __init__(self, rollback_manager, triggers: list[RollbackTrigger]):
        self.rollback_manager = rollback_manager
        self.triggers = triggers
        # time.monotonic() of the last firing, keyed per trigger rule (not just
        # per metric, so a WARNING rule cannot suppress a CRITICAL rule on the
        # same metric).
        self.last_triggered: dict[tuple, float] = {}

    @staticmethod
    def _cooldown_key(trigger: RollbackTrigger) -> tuple:
        """Identify a trigger rule for cooldown bookkeeping."""
        return (trigger.metric_name, trigger.operator,
                trigger.threshold, trigger.severity)

    def check_and_trigger(self, metrics: dict) -> bool:
        """Evaluate every trigger against *metrics* (metric name -> value).

        Triggers whose metric is missing, whose condition is false, or that are
        still in cooldown are skipped. The first CRITICAL trigger that fires
        calls ``rollback_manager.rollback()`` and ends the check.

        Returns:
            True if a rollback was requested, otherwise False.
        """
        for trigger in self.triggers:
            if trigger.metric_name not in metrics:
                continue

            value = metrics[trigger.metric_name]
            if not self._evaluate(value, trigger):
                continue

            key = self._cooldown_key(trigger)
            if self._in_cooldown(key, trigger.cooldown_minutes):
                continue

            print(f"🚨 Trigger activated: {trigger.metric_name} "
                  f"{trigger.operator} {trigger.threshold} "
                  f"(current: {value})")

            self.last_triggered[key] = time.monotonic()

            if trigger.severity == AlertSeverity.CRITICAL:
                self.rollback_manager.rollback()
                return True
            elif trigger.severity == AlertSeverity.WARNING:
                print("Warning: auto-rollback not triggered for warning severity")

        return False

    def _evaluate(self, value: float, trigger: RollbackTrigger) -> bool:
        """Return whether *value* satisfies the trigger's condition."""
        compare = _COMPARATORS.get(trigger.operator)
        # Unknown operators (possible only if mutated after construction) never fire.
        return compare is not None and compare(value, trigger.threshold)

    def _in_cooldown(self, key: tuple, cooldown_min: int) -> bool:
        """Return whether the trigger identified by *key* fired within *cooldown_min* minutes."""
        if key not in self.last_triggered:
            return False
        # Monotonic clock: wall-clock adjustments cannot shorten or extend cooldown.
        elapsed = time.monotonic() - self.last_triggered[key]
        return elapsed < cooldown_min * 60

最佳实践

  1. 保留历史版本:至少保留最近3个稳定版本的完整配置和模型文件
  2. 回滚演练:定期执行回滚演练,确保回滚流程可靠
  3. 监控驱动:设置自动回滚触发器,基于关键指标自动决策
  4. 回滚后验证:回滚完成后验证系统恢复正常
  5. 事后分析:回滚后记录原因,防止类似问题再次发生