← 返回首页
🧠

模型版本管理:构建LLM全生命周期管理体系

📂 llm ⏱ 4 min 641 words

---
title: "模型版本管理:构建LLM全生命周期管理体系"
description: "掌握模型版本控制和注册表的最佳实践,实现模型从训练到上线的全流程管理"
tags: ["模型版本管理", "MLOps", "模型注册表"]
category: "llm"
icon: "🧠"
---

模型版本管理:构建LLM全生命周期管理体系

为什么需要模型版本管理

大语言模型的开发和部署是一个复杂的过程,涉及数据准备、模型训练、评估、部署等多个环节。如果缺乏系统的版本管理,团队很容易遇到模型来源无法追溯、实验结果难以复现、线上版本无法快速回滚等问题。

模型版本管理的核心目标是为每个模型建立完整的"身份证",记录其来源、配置、性能指标和部署历史。

版本控制系统设计

基础版本管理

import hashlib
import json
import os
from contextlib import closing
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional

@dataclass
class ModelVersion:
    """Immutable-by-convention record describing one version of a model.

    Instances are built by ``ModelVersionControl.create_version`` and carried
    inside the commits stored by ``GitStyleVersioning``.
    """

    # Identifier of this record; ``ModelVersionControl`` fills it with the
    # first 12 hex characters of a SHA-256 digest.
    version_id: str
    # Name of the model this version belongs to.
    model_name: str
    # Caller-supplied version label.
    # NOTE(review): presumably semantic "Major.Minor.Patch" — not enforced here.
    version: str
    # Creation time as a string; ``ModelVersionControl`` stores
    # ``datetime.now().isoformat()`` (naive local time).
    created_at: str
    artifacts: Dict[str, str]  # mapping from file path to hash
    # Free-form extra information about the version.
    metadata: Dict[str, Any]
    # Set from ``_get_latest_version(model_name)`` at creation time;
    # presumably the id of the preceding version, ``None`` when there is none.
    parent_version: Optional[str] = None

class ModelVersionControl:
    """Create ``ModelVersion`` records and persist them as JSON files.

    Records are kept in memory in ``self.versions`` (keyed by version id, in
    creation order) and each one is also written to
    ``<storage_path>/<version_id>.json``.
    """

    def __init__(self, storage_path: str):
        """Remember where records are stored; the directory is created lazily."""
        self.storage_path = storage_path
        # version_id -> ModelVersion. Dict insertion order doubles as creation
        # order, which ``_get_latest_version`` relies on.
        self.versions = {}

    def create_version(
        self,
        model_name: str,
        version: str,
        artifacts: Dict[str, str],
        metadata: Dict[str, Any]
    ) -> "ModelVersion":
        """Build, persist and register a new version of ``model_name``.

        Args:
            model_name: Name of the model the version belongs to.
            version: Caller-supplied version label.
            artifacts: Mapping from artifact file path to its hash.
            metadata: Free-form, JSON-serialisable information.

        Returns:
            The newly created ``ModelVersion``; its ``parent_version`` is the
            id of the previously created version of the same model, or
            ``None`` for the first one.

        Raises:
            TypeError: If ``metadata`` or ``artifacts`` cannot be encoded as JSON.
            OSError: If the record cannot be written under ``storage_path``.
        """
        version_id = self._generate_version_id(model_name, version)

        model_version = ModelVersion(
            version_id=version_id,
            model_name=model_name,
            version=version,
            created_at=datetime.now().isoformat(),
            artifacts=artifacts,
            metadata=metadata,
            # Looked up before the new record is registered, so this names the
            # previous version rather than the one being created.
            parent_version=self._get_latest_version(model_name)
        )

        # Persist first: a failed write must not leave a phantom in-memory
        # record that later versions would pick up as their parent.
        self._save_version(model_version)
        self.versions[version_id] = model_version

        return model_version

    def _generate_version_id(self, model_name: str, version: str) -> str:
        """Return a 12-hex-character id derived from name, version and current time."""
        content = f"{model_name}:{version}:{datetime.now().isoformat()}"
        return hashlib.sha256(content.encode()).hexdigest()[:12]

    def _get_latest_version(self, model_name: str) -> Optional[str]:
        """Return the id of the most recently created version of ``model_name``.

        Returns ``None`` when no version of that model has been registered.
        """
        for record in reversed(list(self.versions.values())):
            if record.model_name == model_name:
                return record.version_id
        return None

    def _save_version(self, model_version: "ModelVersion") -> None:
        """Write ``model_version`` to ``<storage_path>/<version_id>.json``."""
        os.makedirs(self.storage_path, exist_ok=True)
        target = os.path.join(
            self.storage_path, f"{model_version.version_id}.json"
        )
        with open(target, "w", encoding="utf-8") as handle:
            # ensure_ascii=False keeps non-ASCII metadata readable on disk.
            json.dump(asdict(model_version), handle, ensure_ascii=False, indent=2)

Git式版本追踪

class GitStyleVersioning:
    """Minimal git-like history of model versions.

    ``self.commits`` holds every commit dict in creation order;
    ``self.branches`` maps a branch name to the ordered list of commit ids on
    that branch.
    """

    def __init__(self):
        self.commits = []
        self.branches = {"main": []}

    def commit(self, model_version: "ModelVersion", message: str, branch: str = "main"):
        """Record ``model_version`` as the new head of ``branch``.

        The branch is created on first use, so committing to a new branch name
        no longer fails after the commit was already appended to
        ``self.commits``.

        Args:
            model_version: Object exposing ``version_id``, ``metadata`` and
                ``artifacts``; its ``version_id`` becomes the commit id.
            message: Human-readable description of the change.
            branch: Branch to append to.
        """
        entry = {
            "id": model_version.version_id,
            # Captured before the branch list is touched: the previous head.
            "parent": self._get_head(branch),
            "message": message,
            "model_version": model_version,
            "timestamp": datetime.now().isoformat()
        }
        self.commits.append(entry)
        self.branches.setdefault(branch, []).append(entry["id"])

    def log(self, branch: str = "main", limit: int = 10):
        """Return up to ``limit`` commits of ``branch``, newest first.

        Raises:
            KeyError: If ``branch`` does not exist.
        """
        # A ``[-0:]`` slice would yield the *entire* history, so a
        # non-positive limit is answered explicitly.
        if limit <= 0:
            return []
        history = self.branches[branch][-limit:]
        return [self._get_commit(commit_id) for commit_id in reversed(history)]

    def diff(self, version_id_1: str, version_id_2: str):
        """Compare the metadata and artifacts of two committed versions.

        Returns:
            ``{"metadata_changes": ..., "artifacts_changes": ...}`` where each
            value has the keys ``"added"``, ``"removed"`` and ``"changed"``.

        Raises:
            KeyError: If either id has never been committed.
        """
        v1 = self._get_commit(version_id_1)["model_version"]
        v2 = self._get_commit(version_id_2)["model_version"]

        return {
            "metadata_changes": self._diff_dict(v1.metadata, v2.metadata),
            "artifacts_changes": self._diff_dict(v1.artifacts, v2.artifacts)
        }

    def _get_head(self, branch: str):
        """Return the id of the latest commit on ``branch``, or ``None`` if empty/unknown."""
        history = self.branches.get(branch)
        return history[-1] if history else None

    def _get_commit(self, commit_id: str):
        """Return the most recent commit dict whose id is ``commit_id``.

        Raises:
            KeyError: If no commit carries that id.
        """
        for entry in reversed(self.commits):
            if entry["id"] == commit_id:
                return entry
        raise KeyError(commit_id)

    @staticmethod
    def _diff_dict(old, new):
        """Describe how the mapping ``new`` differs from ``old``."""
        return {
            "added": {key: new[key] for key in new if key not in old},
            "removed": {key: old[key] for key in old if key not in new},
            "changed": {
                key: {"old": old[key], "new": new[key]}
                for key in old
                if key in new and old[key] != new[key]
            },
        }

模型注册表(Model Registry)

MLflow集成

import mlflow
import mlflow.pytorch
from mlflow.models import ModelSignature
from mlflow.types import Schema, TensorSpec
import numpy as np

class LLMModelRegistry:
    """Thin wrapper around MLflow for logging and loading PyTorch LLMs."""

    def __init__(self, experiment_name: str):
        """Select ``experiment_name`` as the active MLflow experiment."""
        mlflow.set_experiment(experiment_name)
        self.experiment_name = experiment_name

    def register_model(
        self,
        model,
        model_name: str,
        version: str,
        metrics: Dict[str, float],
        params: Dict[str, Any],
        artifacts: List[str],
        vocab_size: int = 32000
    ):
        """Log a model, its parameters, metrics and extra files in one MLflow run.

        Args:
            model: PyTorch model object passed to ``mlflow.pytorch.log_model``.
            model_name: Name under which the model is registered.
            version: Caller-side version label. MLflow assigns its own registry
                version numbers, so the label is stored as the run tag
                ``model_version`` instead of being dropped.
            metrics: Metric name -> value, logged with ``mlflow.log_metrics``.
            params: Parameter name -> value, logged with ``mlflow.log_params``.
            artifacts: Local file paths logged with ``mlflow.log_artifact``.
            vocab_size: Size of the last ``logits`` dimension declared in the
                model signature; defaults to the previously hard-coded 32000.
        """
        with mlflow.start_run():
            # Keep the caller's version label attached to the run.
            mlflow.set_tag("model_version", version)

            # Log hyper-parameters.
            mlflow.log_params(params)

            # Log evaluation metrics.
            mlflow.log_metrics(metrics)

            # Declare the model signature: variable batch and sequence length.
            signature = ModelSignature(
                inputs=Schema([
                    TensorSpec(np.dtype("int64"), (-1, -1), "input_ids"),
                    TensorSpec(np.dtype("int64"), (-1, -1), "attention_mask")
                ]),
                outputs=Schema([
                    TensorSpec(np.dtype("float32"), (-1, -1, vocab_size), "logits")
                ])
            )

            # Log the model and register it under ``model_name``.
            mlflow.pytorch.log_model(
                model,
                "model",
                signature=signature,
                registered_model_name=model_name
            )

            # Log any additional artifact files.
            for artifact in artifacts:
                mlflow.log_artifact(artifact)

    def load_model(self, model_name: str, version: str = "latest"):
        """Load a registered model through the ``models:/<name>/<version>`` URI."""
        model_uri = f"models:/{model_name}/{version}"
        return mlflow.pytorch.load_model(model_uri)

自定义注册表

import sqlite3
import json
from typing import List, Optional

class CustomModelRegistry:
    """SQLite-backed model registry with a small stage-transition workflow.

    Every operation opens its own short-lived connection to ``db_path`` and
    closes it even when a statement fails.
    """

    # Allowed lifecycle moves: current stage -> stages it may transition to.
    _VALID_TRANSITIONS = {
        "development": ["staging"],
        "staging": ["production", "development"],
        "production": ["archived"],
        "archived": []
    }

    def __init__(self, db_path: str = "models.db"):
        """Remember the database path and make sure the ``models`` table exists."""
        self.db_path = db_path
        self._init_db()

    def _init_db(self):
        """Create the ``models`` table if it is missing."""
        # ``closing`` releases the connection; the inner ``with conn`` commits
        # on success and rolls back on error (it does not close by itself).
        with closing(sqlite3.connect(self.db_path)) as conn:
            with conn:
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS models (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        name TEXT NOT NULL,
                        version TEXT NOT NULL,
                        stage TEXT DEFAULT 'development',
                        metrics TEXT,
                        params TEXT,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        UNIQUE(name, version)
                    )
                """)

    def register(
        self,
        name: str,
        version: str,
        stage: str = "development",
        metrics: Optional[Dict] = None,
        params: Optional[Dict] = None
    ):
        """Insert a model version, replacing any existing row for (name, version).

        Because the statement is ``INSERT OR REPLACE``, re-registering an
        existing version produces a fresh row: its id, timestamps and stage
        are reset to the values of this call.

        Args:
            name: Model name.
            version: Version label, unique per model name.
            stage: Initial lifecycle stage.
            metrics: JSON-serialisable metrics, stored as text.
            params: JSON-serialisable parameters, stored as text.
        """
        with closing(sqlite3.connect(self.db_path)) as conn:
            with conn:
                conn.execute("""
                    INSERT OR REPLACE INTO models (name, version, stage, metrics, params)
                    VALUES (?, ?, ?, ?, ?)
                """, (name, version, stage, json.dumps(metrics), json.dumps(params)))

    def transition_stage(self, name: str, version: str, new_stage: str):
        """Move a registered model version to ``new_stage``.

        Raises:
            ValueError: If the version is not registered, or the move from its
                current stage to ``new_stage`` is not an allowed transition.
        """
        current_stage = self._get_stage(name, version)
        if current_stage is None:
            raise ValueError(f"Model {name} version {version} is not registered")
        if new_stage not in self._VALID_TRANSITIONS.get(current_stage, []):
            raise ValueError(f"Cannot transition from {current_stage} to {new_stage}")

        with closing(sqlite3.connect(self.db_path)) as conn:
            with conn:
                conn.execute("""
                    UPDATE models SET stage = ?, updated_at = CURRENT_TIMESTAMP
                    WHERE name = ? AND version = ?
                """, (new_stage, name, version))

    def get_production_model(self, name: str) -> Optional[Dict]:
        """Return the newest production row for ``name`` as a dict, or ``None``.

        The dict has the keys ``id``, ``name``, ``version``, ``stage``,
        ``metrics`` and ``params``; the last two are ``{}`` when nothing was
        stored for them.
        """
        with closing(sqlite3.connect(self.db_path)) as conn:
            # ``id DESC`` breaks ties: CURRENT_TIMESTAMP only has one-second
            # resolution, so two rows can share the same created_at.
            row = conn.execute("""
                SELECT id, name, version, stage, metrics, params
                FROM models WHERE name = ? AND stage = 'production'
                ORDER BY created_at DESC, id DESC LIMIT 1
            """, (name,)).fetchone()

        if row is None:
            return None
        return {
            "id": row[0],
            "name": row[1],
            "version": row[2],
            "stage": row[3],
            "metrics": self._decode_json(row[4]),
            "params": self._decode_json(row[5])
        }

    def _get_stage(self, name: str, version: str) -> Optional[str]:
        """Return the stored stage of (name, version), or ``None`` if absent."""
        with closing(sqlite3.connect(self.db_path)) as conn:
            row = conn.execute(
                "SELECT stage FROM models WHERE name = ? AND version = ?",
                (name, version),
            ).fetchone()
        return row[0] if row else None

    @staticmethod
    def _decode_json(raw):
        """Decode a stored JSON column, mapping SQL NULL and JSON ``null`` to ``{}``.

        ``register`` stores ``json.dumps(None)`` (the text ``"null"``) when no
        value is given; that text is truthy, so it must be handled after
        decoding rather than by a plain truthiness check on the column.
        """
        if not raw:
            return {}
        value = json.loads(raw)
        return {} if value is None else value

模型血缘追踪

class ModelLineage:
    """In-memory record of the dataset, config and parent behind each model."""

    def __init__(self):
        # model_id -> {"dataset", "config", "parent", "timestamp"}
        self.lineage = {}

    def record_training(
        self,
        model_id: str,
        dataset_id: str,
        config: Dict[str, Any],
        parent_model_id: Optional[str] = None
    ):
        """Store (or overwrite) the training provenance of ``model_id``.

        Args:
            model_id: Identifier of the trained model.
            dataset_id: Identifier of the dataset it was trained on.
            config: Training configuration.
            parent_model_id: Model this one was derived from, if any.
        """
        self.lineage[model_id] = {
            "dataset": dataset_id,
            "config": config,
            "parent": parent_model_id,
            "timestamp": datetime.now().isoformat()
        }

    def get_lineage(self, model_id: str) -> List[Dict]:
        """Return the ancestry of ``model_id``, starting with the model itself.

        Each element is the stored record plus a ``"model_id"`` key. The walk
        stops at the first model that has no record, has no parent, or was
        already visited.
        """
        chain = []
        # Guards against a cyclic parent chain (e.g. a model recorded as its
        # own ancestor), which would otherwise loop forever.
        visited = set()
        current = model_id
        while current and current in self.lineage and current not in visited:
            visited.add(current)
            chain.append({
                "model_id": current,
                **self.lineage[current]
            })
            current = self.lineage[current].get("parent")
        return chain

    def get_impact_analysis(
        self,
        dataset_id: str,
        *,
        include_descendants: bool = False
    ) -> List[str]:
        """List the models affected by ``dataset_id``.

        Args:
            dataset_id: Dataset to look for.
            include_descendants: When true, also return models derived
                (directly or transitively) from a model trained on the
                dataset. Defaults to false: only models trained directly on
                the dataset are returned.

        Returns:
            Model ids in recording order, followed by any descendants.
        """
        affected = [
            model_id
            for model_id, info in self.lineage.items()
            if info["dataset"] == dataset_id
        ]
        if not include_descendants:
            return affected

        seen = set(affected)
        grew = True
        # Repeat until a full pass adds nothing, so children recorded before
        # their parents are still picked up.
        while grew:
            grew = False
            for model_id, info in self.lineage.items():
                if model_id not in seen and info.get("parent") in seen:
                    seen.add(model_id)
                    affected.append(model_id)
                    grew = True
        return affected

最佳实践

  1. 语义化版本:使用Major.Minor.Patch格式,清晰表达变更性质
  2. 元数据完整:记录训练数据、超参数、评估结果等关键信息
  3. 阶段管理:建立development → staging → production的流转流程
  4. 自动化集成:将版本管理集成到CI/CD流程中
  5. 定期审计:清理过期模型,保持注册表的整洁

模型版本管理是MLOps的基石,建立规范的版本管理体系能显著提升团队的协作效率和模型质量。