MLOps架构:模型版本管理、A/B测试与全链路监控
MLOps架构:模型版本管理、A/B测试与全链路监控
MLOps平台架构
MLOps是将机器学习模型从实验阶段推向生产环境的工程实践,核心目标是实现模型开发、部署、监控的自动化和标准化。平台架构包括:实验管理、模型注册、自动化流水线、部署编排和生产监控。
# MLOps平台核心组件
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any
from datetime import datetime
from enum import Enum
import uuid
class ModelStage(Enum):
    """Lifecycle stages a registered model version can be in.

    Each member's value is the lowercase stage name as a string.
    """

    DEVELOPMENT = "development"
    STAGING = "staging"
    PRODUCTION = "production"
    ARCHIVED = "archived"
@dataclass
class ModelArtifact:
    """Registry record for one version of a model.

    ``created_at`` defaults to the construction time, ``tags`` to a fresh
    empty list and ``description`` to an empty string.
    """

    # --- identity ---
    model_id: str
    name: str
    version: str

    # --- lifecycle ---
    stage: ModelStage

    # --- training outcome and configuration ---
    metrics: Dict[str, float]
    parameters: Dict[str, Any]
    artifacts_uri: str

    # --- optional metadata ---
    created_at: datetime = field(default_factory=datetime.now)
    tags: List[str] = field(default_factory=list)
    description: str = ""
@dataclass
class Experiment:
    """Snapshot of a single experiment run.

    ``status`` starts out as ``"running"`` unless given explicitly.
    """

    # --- identifiers ---
    experiment_id: str
    name: str
    run_id: str

    # --- recorded data ---
    params: Dict[str, Any]
    metrics: Dict[str, float]
    artifacts: List[str]

    # --- state ---
    status: str = "running"
class MLOpsPlatform:
    """In-memory model registry with guarded stage transitions.

    Models are stored in ``model_registry`` under the key
    ``"<name>:<version>"``. ``experiments`` and ``deployments`` are created
    as empty dicts and are not touched by the methods of this class.
    """

    def __init__(self):
        self.model_registry = {}
        self.experiments = {}
        self.deployments = {}

    def register_model(self, artifact: ModelArtifact):
        """Store ``artifact`` in the registry, replacing any same-keyed entry."""
        registry_key = f"{artifact.name}:{artifact.version}"
        self.model_registry[registry_key] = artifact
        print(f"Registered model: {registry_key}")

    def transition_model(self, model_name: str, version: str,
                         target_stage: ModelStage) -> bool:
        """Move a registered model version to ``target_stage`` if allowed.

        Returns:
            ``True`` when the stage was changed; ``False`` when the model is
            not registered or the transition is not in the allowed table.
        """
        registry_key = f"{model_name}:{version}"
        entry = self.model_registry.get(registry_key)
        if not entry:
            return False

        # Allowed moves per current stage; ARCHIVED has no outgoing moves.
        allowed_by_stage = {
            ModelStage.DEVELOPMENT: [ModelStage.STAGING],
            ModelStage.STAGING: [ModelStage.PRODUCTION, ModelStage.ARCHIVED],
            ModelStage.PRODUCTION: [ModelStage.ARCHIVED],
        }
        permitted = allowed_by_stage.get(entry.stage, [])
        if target_stage not in permitted:
            return False

        entry.stage = target_stage
        print(f"Model {registry_key} transitioned to {target_stage.value}")
        return True

    def get_production_model(self, model_name: str) -> Optional[ModelArtifact]:
        """Return the first registered PRODUCTION model named ``model_name``.

        Entries are scanned in registry insertion order; ``None`` is returned
        when no entry matches.
        """
        return next(
            (
                candidate
                for candidate in self.model_registry.values()
                if candidate.name == model_name
                and candidate.stage == ModelStage.PRODUCTION
            ),
            None,
        )
实验跟踪与版本管理
实验跟踪记录每次训练的参数、指标和产物,支持实验对比和可重现性。使用MLflow或Weights & Biases进行实验管理,确保每次实验都有完整的上下文记录。
# 实验跟踪器
import json
from pathlib import Path
class ExperimentTracker:
    """Minimal in-memory tracker for experiments and their runs.

    Experiments and runs live in the ``experiments`` and ``runs`` dicts,
    keyed by the first eight characters of a random UUID string.
    ``tracking_uri`` is stored on the instance but not otherwise used by the
    methods below.
    """

    def __init__(self, tracking_uri: str):
        self.tracking_uri = tracking_uri
        self.experiments = {}
        self.runs = {}

    def create_experiment(self, name: str, tags: Dict = None) -> str:
        """Register a new experiment and return its generated id."""
        exp_id = str(uuid.uuid4())[:8]
        record = {
            "name": name,
            "tags": tags or {},
            "created_at": datetime.now().isoformat(),
            "runs": [],
        }
        self.experiments[exp_id] = record
        return exp_id

    def start_run(self, experiment_id: str, run_name: str = None) -> str:
        """Create a run in the "running" state and return its generated id.

        The run is always recorded; it is attached to the experiment's run
        list only when ``experiment_id`` is a known experiment.
        """
        run_id = str(uuid.uuid4())[:8]
        display_name = run_name or f"run_{run_id}"
        self.runs[run_id] = {
            "experiment_id": experiment_id,
            "name": display_name,
            "status": "running",
            "params": {},
            "metrics": {},
            "artifacts": [],
            "start_time": datetime.now().isoformat(),
        }
        if experiment_id in self.experiments:
            owner = self.experiments[experiment_id]
            owner["runs"].append(run_id)
        return run_id

    def log_param(self, run_id: str, key: str, value: Any):
        """Set parameter ``key`` on the run; unknown run ids are ignored."""
        if run_id not in self.runs:
            return
        self.runs[run_id]["params"][key] = value

    def log_metric(self, run_id: str, key: str, value: float,
                   step: int = None):
        """Append a timestamped data point to metric ``key`` of the run.

        Unknown run ids are ignored.
        """
        if run_id not in self.runs:
            return
        series = self.runs[run_id]["metrics"].setdefault(key, [])
        point = {
            "value": value,
            "step": step,
            "timestamp": datetime.now().isoformat(),
        }
        series.append(point)

    def log_artifact(self, run_id: str, artifact_path: str):
        """Append ``artifact_path`` to the run; unknown run ids are ignored."""
        if run_id not in self.runs:
            return
        self.runs[run_id]["artifacts"].append(artifact_path)

    def end_run(self, run_id: str, status: str = "completed"):
        """Set the run's final status and end time; unknown ids are ignored."""
        if run_id not in self.runs:
            return
        finished = self.runs[run_id]
        finished["status"] = status
        finished["end_time"] = datetime.now().isoformat()

    def compare_runs(self, run_ids: List[str]) -> Dict:
        """Collect params and latest metric values for the given runs.

        Returns:
            ``{"params": {run_id: params}, "metrics": {run_id: {name: last}}}``
            where ``last`` is the most recently logged value of the metric
            (``None`` for an empty series). Unknown run ids are skipped.
        """
        comparison = {"params": {}, "metrics": {}}
        for run_id in run_ids:
            run = self.runs.get(run_id)
            if not run:
                continue
            comparison["params"][run_id] = run["params"]
            latest = {}
            for metric_key, points in run["metrics"].items():
                latest[metric_key] = points[-1]["value"] if points else None
            comparison["metrics"][run_id] = latest
        return comparison
# 使用示例
# Walk through one complete run: create the tracker, an experiment and a run,
# log two parameters and two metrics, then close the run.
# The tracking URI is only stored on the tracker instance.
tracker = ExperimentTracker("./mlruns")
exp_id = tracker.create_experiment("classifier_training")
run_id = tracker.start_run(exp_id, "resnet50_v1")
# Parameters are stored per key; logging the same key again overwrites it.
tracker.log_param(run_id, "model", "resnet50")
tracker.log_param(run_id, "lr", 0.001)
# Each metric call appends a timestamped data point (step is left as None).
tracker.log_metric(run_id, "accuracy", 0.95)
tracker.log_metric(run_id, "loss", 0.12)
# Uses the default final status "completed" and records the end time.
tracker.end_run(run_id)
自动化部署流水线
自动化部署流水线确保模型从注册到上线的全流程自动化,包括模型验证、容器化、基础设施配置和流量切换。
# 部署流水线
class DeploymentPipeline:
    """Ordered pipeline of validation, containerization and deployment stages.

    Stages are appended with the ``add_*`` methods, each of which returns
    ``self`` so calls can be chained, and are run in insertion order by
    :meth:`execute`.
    """

    def __init__(self, platform: "MLOpsPlatform"):
        self.platform = platform
        # List of (stage_name, stage_config) tuples in execution order.
        self.stages = []

    def add_validation(self, validation_fn):
        """Append a validation stage.

        ``validation_fn`` is called with the model being deployed; a falsy
        result fails the pipeline.
        """
        self.stages.append(("validate", validation_fn))
        return self

    def add_containerization(self, dockerfile_template: str):
        """Append a containerization stage carrying ``dockerfile_template``."""
        self.stages.append(("containerize", dockerfile_template))
        return self

    def add_deployment(self, deployment_config: Dict):
        """Append a deployment stage carrying ``deployment_config``."""
        self.stages.append(("deploy", deployment_config))
        return self

    def execute(self, model_name: str, version: str) -> Dict:
        """Run all stages for the model version ``model_name:version``.

        The model is resolved from the platform registry by its exact
        ``"<name>:<version>"`` key, so validation runs against the version
        being deployed rather than whatever is currently in production.

        Returns:
            A dict with ``"stages"`` (one entry per executed stage) and
            ``"status"`` (``"success"`` or ``"failed"``). When the requested
            version is not registered, no stage is run and an ``"error"``
            message is included.
        """
        results = {"stages": [], "status": "success"}

        model = self.platform.model_registry.get(f"{model_name}:{version}")
        if model is None:
            # Fail fast instead of handing None to the validators.
            results["status"] = "failed"
            results["error"] = f"Model {model_name}:{version} is not registered"
            return results

        for stage_name, stage_config in self.stages:
            print(f"Executing stage: {stage_name}")
            if stage_name == "validate":
                passed = stage_config(model)
                results["stages"].append({
                    "name": stage_name,
                    "passed": passed
                })
                if not passed:
                    # A failed validation stops every later stage.
                    results["status"] = "failed"
                    break
            elif stage_name == "containerize":
                image_tag = f"{model_name}:{version}"
                results["stages"].append({
                    "name": stage_name,
                    "image": image_tag
                })
            elif stage_name == "deploy":
                results["stages"].append({
                    "name": stage_name,
                    "config": stage_config
                })
        return results
# 模型验证器
class ModelValidator:
    """Runs a list of declarative checks against a model artifact.

    Supported check dicts:

    * ``{"type": "metric_threshold", "metric": name, "threshold": value}`` -
      the model's metric must be present and not below ``value``.
    * ``{"type": "schema_check", "required": [param, ...]}`` - every listed
      parameter must be present in the model's parameters.
    """

    def __init__(self, checks: List[Dict]):
        self.checks = checks

    def validate(self, model: "ModelArtifact") -> bool:
        """Return ``True`` only if every configured check passes.

        Evaluation stops at the first failing check, which is reported with
        ``print``. A metric missing from the model and an unrecognised check
        type both count as failures, so a typo in the configuration cannot
        silently let a model through.
        """
        for check in self.checks:
            check_type = check["type"]
            if check_type == "metric_threshold":
                metric_name = check["metric"]
                threshold = check["threshold"]
                if metric_name not in model.metrics:
                    # Do not substitute 0: that would pass thresholds <= 0.
                    print(f"Validation failed: missing metric {metric_name}")
                    return False
                actual = model.metrics[metric_name]
                if actual < threshold:
                    print(f"Validation failed: {metric_name}={actual} < {threshold}")
                    return False
            elif check_type == "schema_check":
                for param in check["required"]:
                    if param not in model.parameters:
                        print(f"Validation failed: missing parameter {param}")
                        return False
            else:
                print(f"Validation failed: unknown check type {check_type}")
                return False
        return True
# Kubernetes部署配置
def generate_k8s_config(model_name: str, version: str,
                        replicas: int = 3) -> Dict:
    """Build an ``apps/v1`` Deployment manifest dict for serving a model.

    The Deployment is named ``"<model_name>-serving"``, selects pods by the
    ``app`` label, and runs a single container from the image
    ``"<model_name>:<version>"`` exposing container port 8000.

    Args:
        model_name: Used for the Deployment name, labels, container name and
            image repository.
        version: Used for the ``version`` label and the image tag.
        replicas: Desired replica count.

    Returns:
        The manifest as a nested dict.
    """
    # Single serving container with fixed resource requests and a GPU limit.
    container = {
        "name": model_name,
        "image": f"{model_name}:{version}",
        "ports": [{"containerPort": 8000}],
        "resources": {
            "requests": {"memory": "2Gi", "cpu": "1"},
            "limits": {"nvidia.com/gpu": 1},
        },
    }

    # Pod template; its labels include the "app" label the selector matches.
    pod_template = {
        "metadata": {"labels": {"app": model_name, "version": version}},
        "spec": {"containers": [container]},
    }

    deployment_spec = {
        "replicas": replicas,
        "selector": {"matchLabels": {"app": model_name}},
        "template": pod_template,
    }

    return {
        "apiVersion": "apps/v1",
        "kind": "Deployment",
        "metadata": {
            "name": f"{model_name}-serving",
            "labels": {"app": model_name, "version": version},
        },
        "spec": deployment_spec,
    }
A/B测试与灰度发布
A/B测试验证新模型是否优于现有模型,灰度发布逐步将流量切换到新模型。关键是定义清晰的成功指标、控制流量分配和设置自动回滚条件。
# A/B测试管理器
@dataclass
class ABTest:
    """State of one A/B test between a control and a treatment model.

    ``end_time`` and ``winner`` are declared here, defaulting to ``None``,
    because ``ABTestManager.end_test`` assigns them when a test finishes.
    Declaring them keeps them in the repr and equality of the dataclass and
    makes them readable on tests that are still running.
    """

    test_id: str
    name: str
    control_model: str
    treatment_model: str
    # Maps model name -> share of traffic routed to it.
    traffic_split: Dict[str, float]
    status: str = "running"
    start_time: datetime = field(default_factory=datetime.now)
    # Maps model name -> metric name -> list of recorded values.
    metrics: Dict[str, Dict] = field(default_factory=dict)
    # Populated when the test is ended.
    end_time: Optional[datetime] = None
    winner: Optional[str] = None
class ABTestManager:
    """Creates A/B tests, routes traffic between arms and analyses results.

    Running tests are kept in ``active_tests`` keyed by test id; finished
    tests are moved to ``test_history``.
    """

    def __init__(self):
        self.active_tests = {}
        self.test_history = []

    def create_test(self, name: str, control: str, treatment: str,
                    traffic_split: Dict[str, float] = None) -> "ABTest":
        """Create and activate a test; defaults to a 50/50 traffic split."""
        test = ABTest(
            test_id=str(uuid.uuid4())[:8],
            name=name,
            control_model=control,
            treatment_model=treatment,
            traffic_split=traffic_split or {control: 0.5, treatment: 0.5}
        )
        self.active_tests[test.test_id] = test
        return test

    def route_request(self, test_id: str) -> str:
        """Pick the model that should serve one request of the given test.

        Models are drawn at random in proportion to ``traffic_split``. If the
        shares do not cover the drawn value (they sum to less than 1, or
        floating-point rounding leaves a gap), the test's control model is
        returned. For an unknown ``test_id`` the literal ``"control"`` is
        returned, since no model name is available.
        """
        test = self.active_tests.get(test_id)
        if test is None:
            return "control"
        import random
        rand = random.random()
        cumulative = 0
        for model, percentage in test.traffic_split.items():
            cumulative += percentage
            # Strict comparison: random() can return 0.0, which must not
            # select an arm configured with a 0% share.
            if rand < cumulative:
                return model
        return test.control_model

    def record_metric(self, test_id: str, model: str,
                      metric_name: str, value: float):
        """Append ``value`` to the model's metric; unknown tests are ignored."""
        test = self.active_tests.get(test_id)
        if test is None:
            return
        per_model = test.metrics.setdefault(model, {})
        per_model.setdefault(metric_name, []).append(value)

    def analyze_results(self, test_id: str) -> Dict:
        """Summarise recorded metrics and recommend a next action.

        For every model and metric with at least one value the mean, count
        and population standard deviation are reported.

        The recommendation compares the mean ``conversion_rate`` of the two
        arms with a fixed 5% relative margin: ``"promote_treatment"`` above
        +5%, ``"rollback"`` below -5%, otherwise ``"continue"``. This is a
        simple threshold rule, not a statistical significance test. When
        either arm has no ``conversion_rate`` data the recommendation stays
        ``"continue"`` rather than treating the missing arm as zero.

        Returns:
            ``{"results": ..., "recommendation": ...}``, or
            ``{"error": "Test not found"}`` for an unknown test id.
        """
        test = self.active_tests.get(test_id)
        if test is None:
            return {"error": "Test not found"}

        results = {}
        for model, metrics in test.metrics.items():
            results[model] = {}
            for metric_name, values in metrics.items():
                if not values:
                    continue
                count = len(values)
                # Compute the mean once instead of once per element.
                mean = sum(values) / count
                variance = sum((x - mean) ** 2 for x in values) / count
                results[model][metric_name] = {
                    "mean": mean,
                    "count": count,
                    "std": variance ** 0.5
                }

        control_stats = results.get(test.control_model, {}).get("conversion_rate")
        treatment_stats = results.get(test.treatment_model, {}).get("conversion_rate")
        recommendation = "continue"
        if control_stats and treatment_stats:
            control_conv = control_stats["mean"]
            treatment_conv = treatment_stats["mean"]
            if treatment_conv > control_conv * 1.05:  # at least 5% uplift
                recommendation = "promote_treatment"
            elif treatment_conv < control_conv * 0.95:  # at least 5% drop
                recommendation = "rollback"

        return {
            "results": results,
            "recommendation": recommendation
        }

    def end_test(self, test_id: str, winner: str = None):
        """Finish a test and move it from the active set to the history.

        The test's ``status`` becomes ``"completed"``, ``end_time`` is set to
        now and ``winner`` is always set (``None`` when no winner is given),
        so finished tests expose the same attributes. Unknown test ids are
        ignored.
        """
        test = self.active_tests.pop(test_id, None)
        if test is None:
            return
        test.status = "completed"
        test.end_time = datetime.now()
        test.winner = winner or None
        self.test_history.append(test)
# 灰度发布控制器
class CanaryReleaseController:
    """Steps canary traffic through a fixed schedule of percentages.

    The schedule is ``[0.1, 0.25, 0.5, 0.75, 1.0]``. ``current_stage`` is an
    index into it and ``current_percentage`` is the traffic share currently
    applied, which is 0 after a full rollback.

    Args:
        initial_percentage: Traffic share to start with.
        max_error_rate: ``should_promote`` refuses promotion above this
            error rate.
        max_latency_p99: ``should_promote`` refuses promotion above this P99
            latency (same unit as the reported metric - presumably
            milliseconds; confirm against the metrics source).
    """

    def __init__(self, initial_percentage: float = 0.1,
                 max_error_rate: float = 0.01,
                 max_latency_p99: float = 200):
        self.current_percentage = initial_percentage
        self.stages = [0.1, 0.25, 0.5, 0.75, 1.0]
        self.max_error_rate = max_error_rate
        self.max_latency_p99 = max_latency_p99
        # Start from the last scheduled stage that does not exceed the
        # initial share, so advancing never lowers the traffic percentage.
        reached = [idx for idx, pct in enumerate(self.stages)
                   if pct <= initial_percentage]
        self.current_stage = reached[-1] if reached else 0

    def should_promote(self, metrics: Dict) -> bool:
        """Return whether the metrics allow moving to the next stage.

        Missing ``error_rate`` or ``latency_p99`` entries are read as 0.
        """
        error_rate = metrics.get("error_rate", 0)
        latency_p99 = metrics.get("latency_p99", 0)
        if error_rate > self.max_error_rate:
            return False
        if latency_p99 > self.max_latency_p99:
            return False
        return True

    def advance(self):
        """Raise traffic to the next scheduled percentage.

        When the current share is below the current stage's percentage (after
        a full rollback to 0, or an initial share smaller than the first
        stage), the current stage is applied first instead of being skipped.
        At the last stage this is a no-op.
        """
        stage_percentage = self.stages[self.current_stage]
        if self.current_percentage < stage_percentage:
            self.current_percentage = stage_percentage
        elif self.current_stage < len(self.stages) - 1:
            self.current_stage += 1
            self.current_percentage = self.stages[self.current_stage]

    def rollback(self):
        """Drop back one stage, or to 0% traffic when already at the first."""
        if self.current_stage > 0:
            self.current_stage -= 1
            self.current_percentage = self.stages[self.current_stage]
        else:
            self.current_percentage = 0
生产环境监控
生产环境需要全方位监控模型健康状况,包括:性能指标(延迟、吞吐)、业务指标(转化率、收益)、数据质量(漂移检测、缺失率)和系统资源(GPU、内存)。
# 模型监控系统
from collections import deque
import statistics
class ModelMonitor:
    """Records per-model metrics, raises threshold alerts and reports health.

    Metric samples are kept in ``metrics_history`` under the key
    ``"<model_name>:<metric_name>"``, each in a deque holding at most the
    1000 most recent samples. Alerts are appended to ``alerts`` and never
    removed by this class.

    Args:
        alert_thresholds: Maps metric name -> upper limit. When omitted or
            empty, defaults for ``error_rate``, ``latency_p99`` and
            ``drift_score`` are used.
    """

    def __init__(self, alert_thresholds: Dict = None):
        self.metrics_history = {}
        self.alert_thresholds = alert_thresholds or {
            "error_rate": 0.01,
            "latency_p99": 200,
            "drift_score": 0.1
        }
        self.alerts = []

    def record_metric(self, metric_name: str, value: float,
                      model_name: str = "default"):
        """Store one timestamped sample and check it against its threshold."""
        key = f"{model_name}:{metric_name}"
        if key not in self.metrics_history:
            self.metrics_history[key] = deque(maxlen=1000)
        self.metrics_history[key].append({
            "value": value,
            "timestamp": datetime.now()
        })
        self._check_alerts(key, value)

    def _check_alerts(self, metric_key: str, value: float):
        """Record and print an alert when ``value`` exceeds its threshold.

        The metric name is the part of ``metric_key`` after the last colon.
        Severity is "critical" above twice the threshold, else "warning".
        """
        metric_name = metric_key.split(":")[-1]
        threshold = self.alert_thresholds.get(metric_name)
        # Compare against None so that a configured threshold of 0 is honoured.
        if threshold is not None and value > threshold:
            alert = {
                "metric": metric_key,
                "value": value,
                "threshold": threshold,
                "timestamp": datetime.now(),
                "severity": "critical" if value > threshold * 2 else "warning"
            }
            self.alerts.append(alert)
            print(f"Alert: {metric_key} = {value} exceeds threshold {threshold}")

    def detect_drift(self, metric_name: str, model_name: str,
                     reference_window: int = 100) -> Dict:
        """Compare the oldest and newest windows of a metric's history.

        The drift score is the absolute difference of the two window means
        relative to the magnitude of the reference mean. Drift is reported
        when the score exceeds the ``drift_score`` threshold (0.1 when the
        configured thresholds do not define one).

        Returns:
            ``{"drift_detected": False, "reason": "insufficient_data"}`` when
            fewer than ``2 * reference_window`` samples are retained or the
            window is not positive; otherwise a dict with ``drift_detected``,
            ``drift_score``, ``reference_mean`` and ``current_mean``.
        """
        key = f"{model_name}:{metric_name}"
        history = list(self.metrics_history.get(key, []))
        if reference_window <= 0 or len(history) < reference_window * 2:
            return {"drift_detected": False, "reason": "insufficient_data"}

        reference = [h["value"] for h in history[:reference_window]]
        current = [h["value"] for h in history[-reference_window:]]

        ref_mean = statistics.mean(reference)
        curr_mean = statistics.mean(current)
        # Normalise by the magnitude so negative means still yield a
        # non-negative score; the epsilon guards a zero reference mean.
        drift_score = abs(curr_mean - ref_mean) / (abs(ref_mean) + 1e-8)
        drift_limit = self.alert_thresholds.get("drift_score", 0.1)

        return {
            "drift_detected": drift_score > drift_limit,
            "drift_score": drift_score,
            "reference_mean": ref_mean,
            "current_mean": curr_mean
        }

    def get_health_report(self, model_name: str) -> Dict:
        """Summarise metrics, recent alerts and a 0-100 health score.

        Only history keys and alerts whose model part (the text before the
        last colon) equals ``model_name`` exactly are included, so models
        that merely share a name prefix are not mixed together. The report
        lists at most the 10 most recent alerts, while the score is reduced
        for every recorded alert of the model: 20 per critical alert and 5
        per other alert, floored at 0.
        """
        report = {
            "model": model_name,
            "timestamp": datetime.now(),
            "metrics": {},
            "alerts": [],
            "health_score": 100
        }

        for key, history in self.metrics_history.items():
            owner, _, metric_name = key.rpartition(":")
            if owner != model_name:
                continue
            values = [h["value"] for h in history]
            if values:
                report["metrics"][metric_name] = {
                    "current": values[-1],
                    "mean": statistics.mean(values),
                    "min": min(values),
                    "max": max(values)
                }

        model_alerts = [a for a in self.alerts
                        if a["metric"].rpartition(":")[0] == model_name]
        report["alerts"] = model_alerts[-10:]

        if model_alerts:
            critical_count = sum(
                1 for a in model_alerts if a["severity"] == "critical"
            )
            report["health_score"] -= critical_count * 20
            report["health_score"] -= (len(model_alerts) - critical_count) * 5
            report["health_score"] = max(0, report["health_score"])

        return report
# 自动回滚系统
class AutoRollbackSystem:
    """Decides from the monitor's health report whether to roll a model back.

    ``rollback_thresholds`` holds an ``error_rate`` limit, which
    ``check_and_rollback`` uses, and a ``latency_degradation`` factor, which
    the methods below do not read.
    """

    def __init__(self, monitor: ModelMonitor, platform: MLOpsPlatform):
        self.monitor = monitor
        self.platform = platform
        self.rollback_thresholds = {
            "error_rate": 0.05,
            "latency_degradation": 2.0
        }

    def check_and_rollback(self, model_name: str) -> bool:
        """Roll back ``model_name`` when its health report warrants it.

        A rollback is triggered when the current error rate (0 if none is
        reported) exceeds the ``error_rate`` threshold, or otherwise when the
        health score is below 50.

        Returns:
            ``True`` if a rollback was triggered, ``False`` otherwise.
        """
        report = self.monitor.get_health_report(model_name)
        error_stats = report["metrics"].get("error_rate", {})
        current_error_rate = error_stats.get("current", 0)

        # The error-rate rule takes precedence over the health-score rule.
        if current_error_rate > self.rollback_thresholds["error_rate"]:
            message = f"Triggering rollback for {model_name}: high error rate"
        elif report["health_score"] < 50:
            message = f"Triggering rollback for {model_name}: low health score"
        else:
            return False

        print(message)
        self._execute_rollback(model_name)
        return True

    def _execute_rollback(self, model_name: str):
        """Announce the rollback of ``model_name``.

        Placeholder: it only prints. A real implementation would fetch the
        previous production version from the registry.
        """
        print(f"Rolling back model {model_name}")