LLM回滚策略
---
title: "LLM回滚策略"
description: "LLM模型的回滚策略设计与实现,包括即时回滚(蓝绿部署)、渐进式回滚(金丝雀发布)、状态快照回滚和自动化回滚触发"
tags: ["回滚策略", "故障恢复", "版本管理"]
category: "llm"
icon: "🧠"
---
# LLM回滚策略

## 概述

回滚(Rollback)是指当新部署的模型出现问题时,快速恢复到上一个稳定版本的能力。对于LLM项目,回滚策略尤为关键:模型行为具有不确定性,有问题的版本可能迅速造成严重的用户体验问题或业务损失。
## 回滚类型

### 即时回滚

即时回滚适用于蓝绿部署场景:上一版本仍在运行,只需把流量切回即可。下面的 `InstantRollback` 负责记录部署历史并选出回滚目标,返回该版本的端点和配置;实际的流量切换由调用方完成:
# rollback/instant.py
import json
import os
from datetime import datetime
from typing import Optional


class InstantRollback:
    """Track deployment history in a JSON file and pick rollback targets.

    Each history entry is a dict with ``version``, ``endpoint``, ``config``,
    ``deployed_at`` and ``status`` keys. At most one entry is ``"active"``:
    recording a new deployment demotes the previous active entry to
    ``"superseded"``, and a rollback marks the abandoned entry
    ``"rolled_back"`` so it is never chosen as a default rollback target.

    This class only updates the record and returns the target's endpoint and
    config; actually switching traffic is left to the caller.
    """

    def __init__(self, config_path: str):
        self.config_path = config_path
        self.deployment_history: list[dict] = []
        self._load_history()

    def _load_history(self) -> None:
        """Load history from ``config_path``; a missing file means no history."""
        try:
            with open(self.config_path, encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            self.deployment_history = []
            return
        self.deployment_history = data.get("history", [])

    def _save_history(self) -> None:
        """Persist the history atomically: write a temp file, then replace.

        Writing in place would leave a truncated file if the process died
        mid-write -- exactly when the history is needed to recover.
        """
        tmp_path = f"{self.config_path}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump({"history": self.deployment_history}, f, indent=2)
        os.replace(tmp_path, self.config_path)

    def _active_index(self) -> Optional[int]:
        """Return the index of the most recent ``"active"`` entry, or None.

        The most recent one is used because history files written by older
        code may contain several stale ``"active"`` entries.
        """
        history = self.deployment_history
        return next(
            (i for i in reversed(range(len(history)))
             if history[i]["status"] == "active"),
            None,
        )

    def record_deployment(self, version: str, endpoint: str,
                          config: dict):
        """Append a new deployment and make it the only active one."""
        for d in self.deployment_history:
            if d["status"] == "active":
                d["status"] = "superseded"
        self.deployment_history.append({
            "version": version,
            "endpoint": endpoint,
            "config": config,
            "deployed_at": datetime.now().isoformat(),
            "status": "active",
        })
        self._save_history()

    def rollback(self, target_version: Optional[str] = None) -> dict:
        """Make a previous deployment the active one.

        Args:
            target_version: Version to roll back to; the most recent history
                entry with that version is used. When omitted, the nearest
                entry before the active one that was not itself rolled back
                is chosen.

        Returns:
            A dict with ``rollback_to``, ``endpoint`` and ``config`` of the
            newly active entry.

        Raises:
            ValueError: If no suitable rollback target exists (history is
                unchanged in that case).
        """
        history = self.deployment_history
        active_idx = self._active_index()
        if target_version:
            target = next(
                (d for d in reversed(history) if d["version"] == target_version),
                None,
            )
        elif active_idx is None:
            target = None
        else:
            # Skip versions that were already abandoned by an earlier rollback.
            target = next(
                (d for d in reversed(history[:active_idx])
                 if d["status"] != "rolled_back"),
                None,
            )
        if not target:
            raise ValueError("No valid rollback target found")

        # Abandon the current deployment; demote any stale extra "active"
        # entries (from older history files) without marking them bad.
        for idx, d in enumerate(history):
            if d["status"] == "active":
                d["status"] = "rolled_back" if idx == active_idx else "superseded"
        target["status"] = "active"
        target["rolled_back_at"] = datetime.now().isoformat()
        self._save_history()
        return {
            "rollback_to": target["version"],
            "endpoint": target["endpoint"],
            "config": target["config"],
        }

    def get_current_version(self) -> Optional[dict]:
        """Return the active deployment entry, or None if there is none."""
        idx = self._active_index()
        return None if idx is None else self.deployment_history[idx]
### 渐进式回滚

对于金丝雀发布,需要分步降低新版本(金丝雀)的流量比例,直至全部流量回到稳定版本。注意:示例中的 `_check_rollback_health` 是占位实现,始终返回 True,实际使用时需要接入监控数据:
# rollback/gradual.py
import time
from typing import Optional, Sequence


class GradualRollback:
    """Move traffic off a canary deployment in timed steps.

    ``router`` must expose ``set_canary_weight(weight)``; ``weight`` is the
    fraction of traffic sent to the canary (it is logged as a percentage).
    """

    def __init__(self, router, rollback_steps: Optional[Sequence[float]] = None,
                 step_duration_seconds: float = 60):
        """Configure the rollback schedule.

        Args:
            router: Object with a ``set_canary_weight(weight)`` method.
            rollback_steps: Canary traffic fractions applied in order.
                Defaults to ``[0.5, 0.2, 0.05, 0.0]``. A final ``0.0`` step
                is appended at execution time if the schedule lacks one.
            step_duration_seconds: Seconds to wait after each non-final step
                before running the health check.

        Raises:
            ValueError: If any step lies outside ``[0, 1]``.
        """
        self.router = router
        steps = [0.5, 0.2, 0.05, 0.0] if rollback_steps is None else list(rollback_steps)
        if any(not 0.0 <= w <= 1.0 for w in steps):
            raise ValueError("rollback_steps must be fractions within [0, 1]")
        self.rollback_steps = steps  # canary traffic fraction per step
        self.step_duration_seconds = step_duration_seconds

    def execute_rollback(self) -> bool:
        """Apply each step, waiting and health-checking between steps.

        Returns:
            True once the canary weight reaches 0 (all traffic on the stable
            version); False if a health check fails, in which case the canary
            is left at the last applied weight.
        """
        steps = list(self.rollback_steps)
        # Guarantee completion: a schedule not ending at 0 would otherwise
        # report success while the canary still receives traffic.
        if not steps or steps[-1] != 0:
            steps.append(0.0)
        print("Starting gradual rollback...")
        for step, target_weight in enumerate(steps):
            print(f"Step {step + 1}/{len(steps)}: "
                  f"Reducing canary to {target_weight * 100:.0f}%")
            self.router.set_canary_weight(target_weight)
            if target_weight == 0:
                print("Rollback complete: 100% traffic to stable version")
                return True
            time.sleep(self.step_duration_seconds)
            if not self._check_rollback_health():
                print("⚠️ Health check failed during rollback, pausing")
                return False
        return True  # not reached: `steps` always ends with 0

    def _check_rollback_health(self) -> bool:
        """Placeholder health check that always reports healthy.

        During a rollback the main signal should be a falling error rate;
        replace this with a real query against your monitoring system.
        """
        return True
### 数据库状态回滚

除了切换模型版本,还可以把部署相关的状态以 JSON 快照的形式保存在 SQLite 中,回滚时按快照 ID 恢复:
# rollback/state.py
import json
import sqlite3
from contextlib import closing
from datetime import datetime
from typing import Optional


class StateRollback:
    """Save and restore JSON-serializable deployment state snapshots in SQLite."""

    def __init__(self, db_path: str):
        self.db_path = db_path
        self._init_db()

    def _connect(self):
        """Open a connection wrapped in ``closing`` so it is always closed.

        ``sqlite3.Connection`` used as a context manager only commits or
        rolls back; it does not close the connection.
        """
        return closing(sqlite3.connect(self.db_path))

    def _init_db(self) -> None:
        """Create the snapshot table if it does not exist yet."""
        with self._connect() as conn, conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS deployment_snapshots (
                    id INTEGER PRIMARY KEY,
                    version TEXT,
                    state_json TEXT,
                    created_at TEXT,
                    description TEXT
                )
            """)

    def save_snapshot(self, version: str, state: dict,
                      description: str = "") -> int:
        """Store ``state`` as a snapshot of ``version`` and return its id.

        Raises:
            TypeError: If ``state`` is not JSON-serializable (from json.dumps).
        """
        state_json = json.dumps(state)  # fail before touching the database
        with self._connect() as conn, conn:
            cursor = conn.execute(
                """INSERT INTO deployment_snapshots
                (version, state_json, created_at, description)
                VALUES (?, ?, ?, ?)""",
                (version, state_json, datetime.now().isoformat(), description),
            )
            return cursor.lastrowid

    def restore_snapshot(self, snapshot_id: int) -> dict:
        """Return the state dict stored under ``snapshot_id``.

        Raises:
            ValueError: If no snapshot with that id exists.
        """
        with self._connect() as conn:
            row = conn.execute(
                "SELECT state_json FROM deployment_snapshots WHERE id = ?",
                (snapshot_id,),
            ).fetchone()
        if row is None:
            raise ValueError(f"Snapshot {snapshot_id} not found")
        return json.loads(row[0])

    def list_snapshots(self, version: Optional[str] = None) -> list:
        """List snapshot metadata, newest first, optionally for one version.

        Returns:
            A list of dicts with ``id``, ``version``, ``created_at`` and
            ``description``. Ties on ``created_at`` are broken by id so
            ordering is deterministic.
        """
        query = ("SELECT id, version, created_at, description "
                 "FROM deployment_snapshots")
        params: tuple = ()
        if version:
            query += " WHERE version = ?"
            params = (version,)
        query += " ORDER BY created_at DESC, id DESC"
        with self._connect() as conn:
            rows = conn.execute(query, params).fetchall()
        return [
            {"id": s[0], "version": s[1], "created_at": s[2], "description": s[3]}
            for s in rows
        ]
### 自动化回滚触发器

根据监控指标自动判断是否回滚:只有 CRITICAL 级别的触发器会调用回滚,WARNING 级别只输出警告;触发器触发后,在冷却时间内不会再次触发:
# rollback/auto_trigger.py
import time
from dataclasses import dataclass
from enum import Enum
from operator import eq, ge, gt, le, lt

# Comparison function for each supported RollbackTrigger.operator value.
_OPERATORS = {"gt": gt, "lt": lt, "eq": eq, "ge": ge, "le": le}


class AlertSeverity(Enum):
    INFO = "info"
    WARNING = "warning"
    CRITICAL = "critical"


@dataclass
class RollbackTrigger:
    """A threshold rule on one metric; only CRITICAL rules cause a rollback."""

    metric_name: str
    threshold: float
    operator: str  # "gt", "lt", "eq", "ge" or "le": metric <op> threshold
    severity: AlertSeverity
    cooldown_minutes: int = 5  # minimum gap between firings of this trigger

    def __post_init__(self):
        # Fail fast: an unknown operator would otherwise silently never fire.
        if self.operator not in _OPERATORS:
            raise ValueError(
                f"Unsupported operator {self.operator!r}; "
                f"expected one of {sorted(_OPERATORS)}"
            )


class AutoRollbackTrigger:
    """Evaluate metrics against triggers and roll back on CRITICAL breaches."""

    def __init__(self, rollback_manager, triggers: list[RollbackTrigger]):
        self.rollback_manager = rollback_manager  # must expose rollback()
        self.triggers = triggers
        # Last firing time (time.monotonic() seconds) per trigger key.
        self.last_triggered: dict[str, float] = {}

    @staticmethod
    def _trigger_key(trigger: RollbackTrigger) -> str:
        """Identify a trigger for cooldown tracking.

        Keying by metric name alone would let a WARNING rule on a metric
        suppress the CRITICAL rule on the same metric during its cooldown.
        """
        return (f"{trigger.metric_name}|{trigger.operator}|"
                f"{trigger.threshold}|{trigger.severity.value}")

    def check_and_trigger(self, metrics: dict) -> bool:
        """Check ``metrics`` (metric name -> value) against every trigger.

        Triggers whose metric is missing are skipped. A breached trigger that
        is not cooling down is logged and starts its cooldown; a CRITICAL one
        also calls ``rollback_manager.rollback()``.

        Returns:
            True if a rollback was initiated, otherwise False.
        """
        for trigger in self.triggers:
            if trigger.metric_name not in metrics:
                continue
            value = metrics[trigger.metric_name]
            if not self._evaluate(value, trigger):
                continue
            key = self._trigger_key(trigger)
            if self._in_cooldown(key, trigger.cooldown_minutes):
                continue
            print(f"🚨 Trigger activated: {trigger.metric_name} "
                  f"{trigger.operator} {trigger.threshold} "
                  f"(current: {value})")
            self.last_triggered[key] = time.monotonic()
            if trigger.severity == AlertSeverity.CRITICAL:
                self.rollback_manager.rollback()
                return True
            if trigger.severity == AlertSeverity.WARNING:
                print("Warning: auto-rollback not triggered for warning severity")
        return False

    def _evaluate(self, value: float, trigger: RollbackTrigger) -> bool:
        """Return whether ``value <operator> threshold`` holds.

        "eq" uses exact equality, so it is only reliable for integer-valued
        metrics. Unknown operators evaluate to False.
        """
        compare = _OPERATORS.get(trigger.operator)
        return compare is not None and bool(compare(value, trigger.threshold))

    def _in_cooldown(self, key: str, cooldown_min: int) -> bool:
        """Return whether the trigger ``key`` fired within ``cooldown_min`` minutes."""
        if key not in self.last_triggered:
            return False
        # Monotonic clock: wall-clock adjustments cannot shorten/extend cooldown.
        elapsed = time.monotonic() - self.last_triggered[key]
        return elapsed < cooldown_min * 60
## 最佳实践

- 保留历史版本:至少保留最近3个稳定版本的完整配置和模型文件
- 回滚演练:定期执行回滚演练,确保回滚流程可靠
- 监控驱动:设置自动回滚触发器,基于关键指标自动决策
- 回滚后验证:回滚完成后确认系统已恢复正常(如错误率、延迟等关键指标回落)
- 事后分析:回滚后记录原因,防止类似问题再次发生