← 返回首页
🤖

MLOps架构:模型版本管理、AB测试与全链路监控

📂 architecture ⏱ 7 min 1400 words

MLOps架构:模型版本管理、AB测试与全链路监控

MLOps平台架构

MLOps是将机器学习模型从实验阶段推向生产环境的工程实践,核心目标是实现模型开发、部署、监控的自动化和标准化。平台架构包括:实验管理、模型注册、自动化流水线、部署编排和生产监控。

# MLOps平台核心组件
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any
from datetime import datetime
from enum import Enum
import uuid

class ModelStage(Enum):
    """Lifecycle stage of a registered model version.

    Which moves between stages are legal is decided by
    ``MLOpsPlatform.transition_model``.
    """
    DEVELOPMENT = "development"
    STAGING = "staging"
    PRODUCTION = "production"
    ARCHIVED = "archived"

@dataclass
class ModelArtifact:
    """A registered model version together with its metrics and parameters.

    ``MLOpsPlatform.register_model`` stores it under the key
    ``"<name>:<version>"``.
    """
    model_id: str
    name: str
    version: str
    stage: ModelStage  # current lifecycle stage; reassigned by MLOpsPlatform.transition_model
    metrics: Dict[str, float]  # read by ModelValidator "metric_threshold" checks
    parameters: Dict[str, Any]  # read by ModelValidator "schema_check" checks
    artifacts_uri: str  # NOTE(review): presumably where the model files live; not read by the code shown
    created_at: datetime = field(default_factory=datetime.now)  # naive local time at construction
    tags: List[str] = field(default_factory=list)
    description: str = ""

@dataclass
class Experiment:
    experiment_id: str
    name: str
    run_id: str
    params: Dict[str, Any]
    metrics: Dict[str, float]
    artifacts: List[str]
    status: str = "running"

class MLOpsPlatform:
    """Minimal in-memory MLOps platform: a model registry plus placeholders
    for experiments and deployments.

    Registered models are keyed by ``"<name>:<version>"``.
    """

    # Legal stage moves. ARCHIVED has no entry, so it is terminal.
    _VALID_TRANSITIONS = {
        ModelStage.DEVELOPMENT: (ModelStage.STAGING,),
        ModelStage.STAGING: (ModelStage.PRODUCTION, ModelStage.ARCHIVED),
        ModelStage.PRODUCTION: (ModelStage.ARCHIVED,),
    }

    def __init__(self):
        self.model_registry = {}  # "<name>:<version>" -> ModelArtifact
        self.experiments = {}
        self.deployments = {}

    def register_model(self, artifact: ModelArtifact):
        """Store ``artifact`` under ``"<name>:<version>"``.

        An existing entry with the same key is silently overwritten.
        """
        key = f"{artifact.name}:{artifact.version}"
        self.model_registry[key] = artifact
        print(f"Registered model: {key}")

    def transition_model(self, model_name: str, version: str,
                         target_stage: ModelStage,
                         archive_existing: bool = True) -> bool:
        """Move a registered model version to ``target_stage``.

        Args:
            model_name: Registered model name.
            version: Version to transition.
            target_stage: Desired stage; must be a legal move from the
                version's current stage (see ``_VALID_TRANSITIONS``).
            archive_existing: When promoting to PRODUCTION, archive every
                other version of the same model that is currently in
                PRODUCTION, so each model name has one production version.

        Returns:
            True if the transition happened; False if the version is not
            registered or the move is not allowed.
        """
        key = f"{model_name}:{version}"
        model = self.model_registry.get(key)
        if model is None:
            return False

        if target_stage not in self._VALID_TRANSITIONS.get(model.stage, ()):
            return False

        if target_stage == ModelStage.PRODUCTION and archive_existing:
            # Without this the old production version keeps its stage and
            # get_production_model() keeps returning it instead of the new one.
            for other_key, other in self.model_registry.items():
                if (other is not model and other.name == model.name
                        and other.stage == ModelStage.PRODUCTION):
                    other.stage = ModelStage.ARCHIVED
                    print(f"Model {other_key} transitioned to "
                          f"{ModelStage.ARCHIVED.value}")

        model.stage = target_stage
        print(f"Model {key} transitioned to {target_stage.value}")
        return True

    def get_production_model(self, model_name: str) -> Optional[ModelArtifact]:
        """Return the PRODUCTION version of ``model_name``, or None if none exists.

        If several versions are in PRODUCTION (possible only with
        ``archive_existing=False``), the first one in registry insertion order
        is returned.
        """
        return next(
            (model for model in self.model_registry.values()
             if model.name == model_name and model.stage == ModelStage.PRODUCTION),
            None,
        )

实验跟踪与版本管理

实验跟踪记录每次训练的参数、指标和产物,支持实验对比并保证结果可复现。生产环境中通常使用MLflow或Weights & Biases等工具进行实验管理,确保每次实验都有完整的上下文记录;下面的ExperimentTracker是仿照这类工具接口的简化内存实现,仅用于说明原理,数据不会持久化。

# 实验跟踪器
import json
from pathlib import Path

class ExperimentTracker:
    """In-memory experiment/run tracker with an MLflow-like logging API.

    Nothing is persisted: ``tracking_uri`` is stored but never used, and all
    state lives in the ``experiments`` and ``runs`` dicts. Logging calls that
    name an unknown ``run_id`` are silently ignored.
    """

    def __init__(self, tracking_uri: str):
        self.tracking_uri = tracking_uri
        self.experiments = {}  # experiment_id -> experiment record
        self.runs = {}         # run_id -> run record

    @staticmethod
    def _short_id() -> str:
        """Return 8 random hex characters (the first group of a UUID4)."""
        return uuid.uuid4().hex[:8]

    def create_experiment(self, name: str, tags: Dict = None) -> str:
        """Register a new experiment and return its generated id."""
        experiment_id = self._short_id()
        self.experiments[experiment_id] = dict(
            name=name,
            tags=tags if tags else {},
            created_at=datetime.now().isoformat(),
            runs=[],
        )
        return experiment_id

    def start_run(self, experiment_id: str, run_name: str = None) -> str:
        """Open a run (status "running") and return its id.

        The run is linked to ``experiment_id`` only if that experiment exists;
        otherwise it is still created, just not listed under any experiment.
        """
        run_id = self._short_id()
        self.runs[run_id] = dict(
            experiment_id=experiment_id,
            name=run_name if run_name else f"run_{run_id}",
            status="running",
            params={},
            metrics={},
            artifacts=[],
            start_time=datetime.now().isoformat(),
        )

        experiment = self.experiments.get(experiment_id)
        if experiment is not None:
            experiment["runs"].append(run_id)

        return run_id

    def log_param(self, run_id: str, key: str, value: Any):
        """Set (or overwrite) one parameter of a run."""
        run = self.runs.get(run_id)
        if run is None:
            return
        run["params"][key] = value

    def log_metric(self, run_id: str, key: str, value: float,
                   step: int = None):
        """Append a timestamped metric point; every call adds to the history."""
        run = self.runs.get(run_id)
        if run is None:
            return
        point = {
            "value": value,
            "step": step,
            "timestamp": datetime.now().isoformat(),
        }
        run["metrics"].setdefault(key, []).append(point)

    def log_artifact(self, run_id: str, artifact_path: str):
        """Remember an artifact path for a run (the file itself is not touched)."""
        run = self.runs.get(run_id)
        if run is None:
            return
        run["artifacts"].append(artifact_path)

    def end_run(self, run_id: str, status: str = "completed"):
        """Close a run with ``status`` and stamp its end time."""
        run = self.runs.get(run_id)
        if run is None:
            return
        run["status"] = status
        run["end_time"] = datetime.now().isoformat()

    def compare_runs(self, run_ids: List[str]) -> Dict:
        """Side-by-side view of params and latest metric values; unknown ids are skipped."""
        params_by_run = {}
        latest_by_run = {}

        for rid in run_ids:
            run = self.runs.get(rid)
            if not run:
                continue
            params_by_run[rid] = run["params"]
            latest_by_run[rid] = {
                metric_name: (points[-1]["value"] if points else None)
                for metric_name, points in run["metrics"].items()
            }

        return {"params": params_by_run, "metrics": latest_by_run}

# Usage example: one experiment with a single run that logs two
# hyperparameters and two final metrics, then closes the run.
# "./mlruns" is only stored; this tracker keeps everything in memory.
tracker = ExperimentTracker("./mlruns")
exp_id = tracker.create_experiment("classifier_training")
run_id = tracker.start_run(exp_id, "resnet50_v1")

tracker.log_param(run_id, "model", "resnet50")
tracker.log_param(run_id, "lr", 0.001)
tracker.log_metric(run_id, "accuracy", 0.95)  # no step given -> step is None
tracker.log_metric(run_id, "loss", 0.12)
tracker.end_run(run_id)  # status defaults to "completed"

自动化部署流水线

自动化部署流水线把模型从注册到上线的全流程串联为自动化步骤,包括模型验证、容器化、基础设施配置和流量切换。下面的示例只模拟各阶段并记录结果,不会真正构建镜像或操作集群。

# 部署流水线
class DeploymentPipeline:
    """Fluent builder for a simulated deployment pipeline.

    Stages run in the order they were added. Each stage only records what it
    would do in the result dict; no image is built and nothing is applied to
    a cluster.
    """

    def __init__(self, platform: MLOpsPlatform):
        self.platform = platform
        self.stages = []  # ordered (stage_name, stage_config) pairs

    def add_validation(self, validation_fn):
        """Append a validation stage; ``validation_fn(model)`` must return truthy to pass."""
        self.stages.append(("validate", validation_fn))
        return self

    def add_containerization(self, dockerfile_template: str):
        """Append a containerization stage.

        NOTE: ``dockerfile_template`` is stored but not rendered by ``execute``;
        only the image tag ``<model_name>:<version>`` is recorded.
        """
        self.stages.append(("containerize", dockerfile_template))
        return self

    def add_deployment(self, deployment_config: Dict):
        """Append a deployment stage; ``deployment_config`` is echoed into the result."""
        self.stages.append(("deploy", deployment_config))
        return self

    def execute(self, model_name: str, version: str) -> Dict:
        """Run every stage for the registered model ``model_name:version``.

        Returns:
            ``{"stages": [...], "status": "success" | "failed"}``. Execution
            stops at the first failing validation. If ``model_name:version``
            is not registered, no stage runs, status is "failed" and an
            ``"error"`` message is included.
        """
        results = {"stages": [], "status": "success"}

        # Validate the exact version being shipped. Using the current
        # production model here could validate a different version than the
        # image tagged below, or hand None to the validators.
        model = self.platform.model_registry.get(f"{model_name}:{version}")
        if model is None:
            results["status"] = "failed"
            results["error"] = f"Model {model_name}:{version} not found in registry"
            return results

        for stage_name, stage_config in self.stages:
            print(f"Executing stage: {stage_name}")

            if stage_name == "validate":
                passed = stage_config(model)
                results["stages"].append({
                    "name": stage_name,
                    "passed": passed
                })
                if not passed:
                    results["status"] = "failed"
                    break

            elif stage_name == "containerize":
                image_tag = f"{model_name}:{version}"
                results["stages"].append({
                    "name": stage_name,
                    "image": image_tag
                })

            elif stage_name == "deploy":
                results["stages"].append({
                    "name": stage_name,
                    "config": stage_config
                })

        return results

# 模型验证器
class ModelValidator:
    """Gate a model on a list of declarative checks.

    Supported check dicts:

    * ``{"type": "metric_threshold", "metric": str, "threshold": float,
      "higher_is_better": bool (default True)}``: the metric must be
      ``>= threshold``, or ``<= threshold`` when ``higher_is_better`` is False.
    * ``{"type": "schema_check", "required": [name, ...]}``: every listed
      name must be a key of ``model.parameters``.
    """

    def __init__(self, checks: List[Dict]):
        self.checks = checks

    def validate(self, model: ModelArtifact) -> bool:
        """Return True only if every check passes; print the first failure reason.

        Fails closed: a missing model, a metric the model did not record, or an
        unknown check type all count as failures rather than being skipped.
        """
        if model is None:
            print("Validation failed: no model to validate")
            return False

        for check in self.checks:
            check_type = check.get("type")
            if check_type == "metric_threshold":
                if not self._check_metric(model, check):
                    return False
            elif check_type == "schema_check":
                if not self._check_schema(model, check):
                    return False
            else:
                # A typo in a check type must not silently disable the gate.
                print(f"Validation failed: unknown check type {check_type!r}")
                return False

        return True

    @staticmethod
    def _check_metric(model, check: Dict) -> bool:
        """Compare one recorded metric against its threshold."""
        metric_name = check["metric"]
        threshold = check["threshold"]

        if metric_name not in model.metrics:
            # Defaulting a missing metric to 0 would let it pass any
            # threshold <= 0 (and every lower-is-better check).
            print(f"Validation failed: missing metric {metric_name}")
            return False

        actual = model.metrics[metric_name]
        if check.get("higher_is_better", True):
            if actual < threshold:
                print(f"Validation failed: {metric_name}={actual} < {threshold}")
                return False
        elif actual > threshold:
            print(f"Validation failed: {metric_name}={actual} > {threshold}")
            return False

        return True

    @staticmethod
    def _check_schema(model, check: Dict) -> bool:
        """Require every name in ``check["required"]`` to be a model parameter."""
        for param in check["required"]:
            if param not in model.parameters:
                print(f"Validation failed: missing parameter {param}")
                return False
        return True

# Kubernetes部署配置
def generate_k8s_config(model_name: str, version: str,
                        replicas: int = 3, *,
                        container_port: int = 8000,
                        memory: str = "2Gi",
                        cpu: str = "1",
                        gpus: int = 1) -> Dict:
    """Build a Kubernetes ``apps/v1`` Deployment manifest (as a dict) for a model server.

    Args:
        model_name: Used for the Deployment name (``<model_name>-serving``),
            the container name, the ``app`` label and the image repository.
        version: Image tag and value of the ``version`` label.
        replicas: Number of pod replicas.
        container_port: Port the serving container listens on.
        memory: Memory request, as a Kubernetes quantity string.
        cpu: CPU request, as a Kubernetes quantity string.
        gpus: ``nvidia.com/gpu`` limit; 0 omits the GPU limit (CPU-only serving).

    Returns:
        The manifest as a plain dict. The defaults reproduce the original
        fixed configuration.

    NOTE: the Deployment name does not include ``version``, so applying a
    manifest for a new version updates the same Deployment rather than
    running both versions side by side.
    """
    resources = {"requests": {"memory": memory, "cpu": cpu}}
    if gpus:
        resources["limits"] = {"nvidia.com/gpu": gpus}

    return {
        "apiVersion": "apps/v1",
        "kind": "Deployment",
        "metadata": {
            "name": f"{model_name}-serving",
            "labels": {"app": model_name, "version": version}
        },
        "spec": {
            "replicas": replicas,
            "selector": {
                "matchLabels": {"app": model_name}
            },
            "template": {
                "metadata": {
                    "labels": {"app": model_name, "version": version}
                },
                "spec": {
                    "containers": [{
                        "name": model_name,
                        "image": f"{model_name}:{version}",
                        "ports": [{"containerPort": container_port}],
                        "resources": resources
                    }]
                }
            }
        }
    }

A/B测试与灰度发布

A/B测试用于验证新模型是否优于现有模型,灰度发布则逐步将流量切换到新模型。关键是定义清晰的成功指标、控制流量分配并设置自动回滚条件。注意:下面示例中的analyze_results只按均值的相对变化(±5%)给出建议,并未做统计显著性检验;实际决策前还应结合样本量进行假设检验。

# A/B测试管理器
@dataclass
class ABTest:
    """State of one A/B test between a control model and a treatment model."""
    test_id: str
    name: str
    control_model: str
    treatment_model: str
    traffic_split: Dict[str, float]  # model name -> fraction of traffic (0-1)
    status: str = "running"
    start_time: datetime = field(default_factory=datetime.now)
    metrics: Dict[str, Dict] = field(default_factory=dict)  # model -> metric name -> list of values
    # Filled in when the test is ended. They are declared here so they always
    # exist and show up in repr()/asdict() instead of being ad-hoc attributes.
    end_time: Optional[datetime] = None
    winner: Optional[str] = None

class ABTestManager:
    """Create A/B tests, route traffic between their models, and summarize metrics.

    Active tests live in ``active_tests`` (test_id -> ABTest); ended tests are
    moved to ``test_history``.
    """

    def __init__(self):
        self.active_tests = {}
        self.test_history = []

    def create_test(self, name: str, control: str, treatment: str,
                    traffic_split: Dict[str, float] = None) -> ABTest:
        """Create, register and return a new A/B test.

        ``traffic_split`` maps model name -> fraction of traffic (0-1). It
        defaults to an even 50/50 split between ``control`` and ``treatment``.
        """
        test = ABTest(
            test_id=str(uuid.uuid4())[:8],
            name=name,
            control_model=control,
            treatment_model=treatment,
            traffic_split=traffic_split or {control: 0.5, treatment: 0.5}
        )

        self.active_tests[test.test_id] = test
        return test

    def route_request(self, test_id: str) -> str:
        """Pick the model name that should serve one request of ``test_id``.

        Returns the literal ``"control"`` if the test is unknown, since there
        is no model name to return. Traffic not covered by the split goes to
        the test's control model; this covers shares summing below 1 and
        float rounding just under 1.0.
        """
        import random

        test = self.active_tests.get(test_id)
        if not test:
            return "control"

        rand = random.random()
        cumulative = 0.0
        for model, share in test.traffic_split.items():
            cumulative += share
            if rand <= cumulative:
                return model

        # Return a real model name, not the placeholder string "control".
        return test.control_model

    def record_metric(self, test_id: str, model: str,
                      metric_name: str, value: float):
        """Append one observation of ``metric_name`` for ``model``; unknown tests are ignored."""
        test = self.active_tests.get(test_id)
        if test:
            test.metrics.setdefault(model, {}).setdefault(metric_name, []).append(value)

    @staticmethod
    def _summarize(values: List[float]) -> Dict[str, float]:
        """Mean, count and population standard deviation of a non-empty list."""
        count = len(values)
        mean = sum(values) / count  # computed once, not once per element
        return {
            "mean": mean,
            "count": count,
            "std": (sum((x - mean) ** 2 for x in values) / count) ** 0.5,
        }

    def analyze_results(self, test_id: str, metric: str = "conversion_rate",
                        min_relative_lift: float = 0.05) -> Dict:
        """Summarize every recorded metric per model and recommend an action.

        The recommendation compares the mean of ``metric`` for treatment
        against control:

        * more than ``min_relative_lift`` above control gives "promote_treatment";
        * more than that below control gives "rollback";
        * otherwise, or if either side has not recorded ``metric``, it is "continue".

        This is a plain relative-change rule, NOT a statistical significance
        test. Sample counts are reported but not used.

        Returns:
            ``{"results": {model: {metric_name: {"mean", "count", "std"}}},
            "recommendation": str}``, or ``{"error": "Test not found"}``.
        """
        test = self.active_tests.get(test_id)
        if not test:
            return {"error": "Test not found"}

        results = {
            model: {name: self._summarize(values)
                    for name, values in metrics.items() if values}
            for model, metrics in test.metrics.items()
        }

        control_metrics = results.get(test.control_model, {})
        treatment_metrics = results.get(test.treatment_model, {})

        recommendation = "continue"
        # Only compare when both sides actually measured the metric; treating a
        # missing metric as 0 would produce a spurious rollback/promotion.
        if metric in control_metrics and metric in treatment_metrics:
            control_value = control_metrics[metric]["mean"]
            treatment_value = treatment_metrics[metric]["mean"]

            if treatment_value > control_value * (1 + min_relative_lift):
                recommendation = "promote_treatment"
            elif treatment_value < control_value * (1 - min_relative_lift):
                recommendation = "rollback"

        return {
            "results": results,
            "recommendation": recommendation
        }

    def end_test(self, test_id: str, winner: str = None):
        """Mark a test completed, stamp ``end_time``, record ``winner`` if given, and archive it."""
        test = self.active_tests.pop(test_id, None)
        if test:
            test.status = "completed"
            test.end_time = datetime.now()
            if winner:
                test.winner = winner
            self.test_history.append(test)

# 灰度发布控制器
class CanaryReleaseController:
    """Step a canary's traffic share through ``stages`` based on health metrics.

    ``current_percentage`` is the fraction (0-1) of traffic on the new model.
    ``current_stage`` is the index into ``stages`` of the stage last reached.
    """

    def __init__(self, initial_percentage: float = 0.1,
                 stages: Optional[List[float]] = None,
                 max_error_rate: float = 0.01,
                 max_latency_p99: float = 200):
        self.current_percentage = initial_percentage
        self.stages = (sorted(stages) if stages is not None
                       else [0.1, 0.25, 0.5, 0.75, 1.0])
        self.max_error_rate = max_error_rate
        self.max_latency_p99 = max_latency_p99
        # Align the index with the starting share. It used to be fixed at 0, so
        # e.g. initial_percentage=0.3 would "advance" down to 0.25.
        self.current_stage = max(
            (i for i, pct in enumerate(self.stages) if pct <= initial_percentage),
            default=0,
        )

    def should_promote(self, metrics: Dict) -> bool:
        """Return True if the canary is healthy enough to take more traffic.

        Fails when ``error_rate`` exceeds ``max_error_rate`` or ``latency_p99``
        exceeds ``max_latency_p99``; missing metrics count as 0. This only
        answers yes/no. Whether to pause or call ``rollback()`` is left to
        the caller.
        """
        error_rate = metrics.get("error_rate", 0)
        latency_p99 = metrics.get("latency_p99", 0)
        return (error_rate <= self.max_error_rate
                and latency_p99 <= self.max_latency_p99)

    def advance(self):
        """Move to the smallest stage above the current share (no-op at the top).

        Stepping by percentage rather than by index means that after a full
        rollback to 0% the ramp restarts at the first stage instead of
        skipping it.
        """
        higher = [(i, pct) for i, pct in enumerate(self.stages)
                  if pct > self.current_percentage]
        if higher:
            self.current_stage, self.current_percentage = min(
                higher, key=lambda item: item[1])

    def rollback(self):
        """Move to the largest stage below the current share, or to 0% if there is none."""
        lower = [(i, pct) for i, pct in enumerate(self.stages)
                 if pct < self.current_percentage]
        if lower:
            self.current_stage, self.current_percentage = max(
                lower, key=lambda item: item[1])
        else:
            self.current_stage = 0
            self.current_percentage = 0

生产环境监控

生产环境需要全方位监控模型健康状况,包括:性能指标(延迟、吞吐)、业务指标(转化率、收益)、数据质量(漂移检测、缺失率)和系统资源(GPU、内存)。

# 模型监控系统
from collections import deque
import statistics

class ModelMonitor:
    """Per-model metric history with threshold alerts, drift detection and health reports.

    Series are keyed ``"<model_name>:<metric_name>"`` and keep only the latest
    1000 points each.
    """

    # Used when no (or an empty) threshold dict is passed to __init__.
    _DEFAULT_THRESHOLDS = {
        "error_rate": 0.01,
        "latency_p99": 200,
        "drift_score": 0.1,
    }

    def __init__(self, alert_thresholds: Dict = None):
        self.metrics_history = {}
        # A caller-supplied dict replaces the defaults entirely (not merged).
        self.alert_thresholds = alert_thresholds or dict(self._DEFAULT_THRESHOLDS)
        self.alerts = []  # every alert raised so far; never trimmed

    def record_metric(self, metric_name: str, value: float,
                      model_name: str = "default"):
        """Append a timestamped value to the model's series and check it against its threshold."""
        key = f"{model_name}:{metric_name}"
        if key not in self.metrics_history:
            self.metrics_history[key] = deque(maxlen=1000)

        self.metrics_history[key].append({
            "value": value,
            "timestamp": datetime.now()
        })

        self._check_alerts(key, value)

    def _check_alerts(self, metric_key: str, value: float):
        """Record an alert when ``value`` exceeds the metric's threshold.

        Severity is "critical" above twice the threshold, else "warning".
        """
        metric_name = metric_key.split(":")[-1]
        threshold = self.alert_thresholds.get(metric_name)

        # `is not None` so an explicit threshold of 0 is honoured.
        if threshold is not None and value > threshold:
            alert = {
                "metric": metric_key,
                "value": value,
                "threshold": threshold,
                "timestamp": datetime.now(),
                "severity": "critical" if value > threshold * 2 else "warning"
            }
            self.alerts.append(alert)
            print(f"Alert: {metric_key} = {value} exceeds threshold {threshold}")

    def detect_drift(self, metric_name: str, model_name: str,
                     reference_window: int = 100) -> Dict:
        """Detect a shift in a metric's mean between old and recent values.

        Compares the mean of the oldest ``reference_window`` retained points
        with the mean of the newest ``reference_window`` points. Because the
        history is capped at 1000 points, the reference window slides once
        more values arrive.

        Returns:
            ``{"drift_detected": False, "reason": "insufficient_data"}`` when
            fewer than ``2 * reference_window`` points exist; otherwise the
            drift flag, the relative mean change ``drift_score`` and both means.

        Raises:
            ValueError: if ``reference_window`` is less than 1.
        """
        if reference_window < 1:
            raise ValueError("reference_window must be >= 1")

        key = f"{model_name}:{metric_name}"
        history = list(self.metrics_history.get(key, []))

        if len(history) < reference_window * 2:
            return {"drift_detected": False, "reason": "insufficient_data"}

        reference = [h["value"] for h in history[:reference_window]]
        current = [h["value"] for h in history[-reference_window:]]

        ref_mean = statistics.mean(reference)
        curr_mean = statistics.mean(current)

        # abs() in the denominator keeps the score non-negative (and thus
        # comparable to the threshold) for negative-valued metrics.
        drift_score = abs(curr_mean - ref_mean) / (abs(ref_mean) + 1e-8)
        # Custom threshold dicts may omit "drift_score"; fall back to the default.
        threshold = self.alert_thresholds.get(
            "drift_score", self._DEFAULT_THRESHOLDS["drift_score"])

        return {
            "drift_detected": drift_score > threshold,
            "drift_score": drift_score,
            "reference_mean": ref_mean,
            "current_mean": curr_mean
        }

    def get_health_report(self, model_name: str) -> Dict:
        """Summarize a model's metrics and alerts into a 0-100 health score.

        Each critical alert costs 20 points and each warning 5. All alerts ever
        recorded for the model count, not just recent ones.
        """
        report = {
            "model": model_name,
            "timestamp": datetime.now(),
            "metrics": {},
            "alerts": [],
            "health_score": 100
        }

        # Match "<model_name>:" exactly; a bare prefix match would also pull in
        # e.g. "model_v2" metrics when reporting on "model".
        prefix = f"{model_name}:"

        for key, history in self.metrics_history.items():
            if not key.startswith(prefix):
                continue
            metric_name = key[len(prefix):]
            values = [h["value"] for h in history]
            if values:
                report["metrics"][metric_name] = {
                    "current": values[-1],
                    "mean": statistics.mean(values),
                    "min": min(values),
                    "max": max(values)
                }

        model_alerts = [a for a in self.alerts if a["metric"].startswith(prefix)]
        report["alerts"] = model_alerts[-10:]  # the 10 most recent alerts

        if model_alerts:
            critical_count = sum(1 for a in model_alerts if a["severity"] == "critical")
            report["health_score"] -= critical_count * 20
            report["health_score"] -= (len(model_alerts) - critical_count) * 5

        report["health_score"] = max(0, report["health_score"])

        return report

# 自动回滚系统
class AutoRollbackSystem:
    def __init__(self, monitor: ModelMonitor, platform: MLOpsPlatform):
        """Wire the rollback checker to a metrics monitor and the model registry.

        ``rollback_thresholds["error_rate"]`` (0.05) is the current error rate
        above which a rollback is triggered. NOTE(review):
        ``latency_degradation`` (2.0) is configured but not read by
        check_and_rollback.
        """
        self.platform = platform
        self.monitor = monitor
        self.rollback_thresholds = dict(error_rate=0.05, latency_degradation=2.0)
    
    def check_and_rollback(self, model_name: str) -> bool:
        """Roll ``model_name`` back if its health report looks bad.

        Triggers, checked in order:

        1. the most recent ``error_rate`` value is above
           ``rollback_thresholds["error_rate"]``;
        2. the health score is below 50.

        Returns True if a rollback was executed, else False.
        """
        health = self.monitor.get_health_report(model_name)
        latest_error_rate = health["metrics"].get("error_rate", {}).get("current", 0)

        if latest_error_rate > self.rollback_thresholds["error_rate"]:
            reason = "high error rate"
        elif health["health_score"] < 50:
            reason = "low health score"
        else:
            return False

        print(f"Triggering rollback for {model_name}: {reason}")
        self._execute_rollback(model_name)
        return True
    
    def _execute_rollback(self, model_name: str):
        """执行回滚"""
        # 获取上一个稳定版本
        print(f"Rolling back model {model_name}")
        # 实际实现需要从注册表中获取上一个生产版本