模型版本管理:构建LLM全生命周期管理体系
---
title: "模型版本管理:构建LLM全生命周期管理体系"
description: "掌握模型版本控制和注册表的最佳实践,实现模型从训练到上线的全流程管理"
tags: ["模型版本管理", "MLOps", "模型注册表"]
category: "llm"
icon: "🧠"
---
模型版本管理:构建LLM全生命周期管理体系
为什么需要模型版本管理
大语言模型的开发和部署是一个复杂的过程,涉及数据准备、模型训练、评估、部署等多个环节。没有系统的版本管理,团队将面临以下问题:
- 难以复现:无法回溯到特定版本的模型
- 协作困难:多人同时修改导致混乱
- 追踪缺失:不知道哪个模型对应哪个数据集
- 回滚困难:线上问题无法快速恢复到稳定版本
模型版本管理的核心目标是为每个模型建立完整的"身份证",记录其来源、配置、性能指标和部署历史。
版本控制系统设计
基础版本管理
import hashlib
import json
import os
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Any, Dict, List, Optional
@dataclass
class ModelVersion:
    """Record describing a single version of a model.

    Field order is part of the positional constructor and must not change.
    """

    version_id: str
    model_name: str
    version: str
    created_at: str
    # Maps each artifact file path to its hash.
    artifacts: Dict[str, str]
    metadata: Dict[str, Any]
    # Optional link to another version; defaults to no parent.
    parent_version: Optional[str] = None
class ModelVersionControl:
    """Create model versions, index them in memory and persist them as JSON.

    Each created version is stored in ``self.versions`` keyed by its
    ``version_id`` and written to ``<storage_path>/<version_id>.json``.
    """

    def __init__(self, storage_path: str):
        """Remember the storage directory and start with an empty index.

        Args:
            storage_path: Directory that receives one JSON file per version.
                It is created on first save if it does not exist.
        """
        self.storage_path = storage_path
        self.versions = {}

    def create_version(
        self,
        model_name: str,
        version: str,
        artifacts: Dict[str, str],
        metadata: Dict[str, Any],
    ) -> "ModelVersion":
        """Create, index and persist a new version of ``model_name``.

        The new version's ``parent_version`` is the ``version_id`` of the most
        recently created version of the same model, or ``None`` for the first.

        Args:
            model_name: Name of the model the version belongs to.
            version: Human-readable version label (e.g. ``"1.2.0"``).
            artifacts: Mapping of artifact file path to its hash.
            metadata: Arbitrary metadata; must be JSON-serializable because
                the version is written to disk.

        Returns:
            The newly created ``ModelVersion``.

        Raises:
            TypeError: If ``metadata`` contains values ``json`` cannot encode.
            OSError: If the version file cannot be written.
        """
        version_id = self._generate_version_id(model_name, version)
        model_version = ModelVersion(
            version_id=version_id,
            model_name=model_name,
            version=version,
            created_at=datetime.now().isoformat(),
            artifacts=artifacts,
            metadata=metadata,
            # Looked up before the new version is indexed, so a version can
            # never become its own parent.
            parent_version=self._get_latest_version(model_name),
        )
        self.versions[version_id] = model_version
        self._save_version(model_version)
        return model_version

    def _generate_version_id(self, model_name: str, version: str) -> str:
        """Return a 12-hex-character ID derived from name, version and now.

        The current timestamp is mixed in, so repeated calls with the same
        arguments yield different IDs.
        """
        content = f"{model_name}:{version}:{datetime.now().isoformat()}"
        return hashlib.sha256(content.encode()).hexdigest()[:12]

    def _get_latest_version(self, model_name: str) -> Optional[str]:
        """Return the ``version_id`` of the newest version of ``model_name``.

        "Newest" means most recently inserted into ``self.versions`` (dicts
        preserve insertion order). Returns ``None`` when the model has no
        versions yet.
        """
        for candidate in reversed(list(self.versions.values())):
            if candidate.model_name == model_name:
                return candidate.version_id
        return None

    def _save_version(self, model_version) -> None:
        """Write ``model_version`` to ``<storage_path>/<version_id>.json``."""
        os.makedirs(self.storage_path, exist_ok=True)
        target = os.path.join(
            self.storage_path, f"{model_version.version_id}.json"
        )
        with open(target, "w", encoding="utf-8") as handle:
            json.dump(asdict(model_version), handle, ensure_ascii=False, indent=2)
Git式版本追踪
class GitStyleVersioning:
    """Minimal Git-like history (commits and branches) for model versions.

    A commit's ID is the ``version_id`` of the model version it wraps, so
    committing the same version twice produces two commits with the same ID;
    lookups by ID then resolve to the earliest such commit.
    """

    def __init__(self):
        """Start with no commits and a single empty ``main`` branch."""
        self.commits = []
        self.branches = {"main": []}

    def commit(self, model_version: "ModelVersion", message: str, branch: str = "main"):
        """Record ``model_version`` as the new head of ``branch``.

        Args:
            model_version: Object exposing ``version_id``, ``metadata`` and
                ``artifacts``.
            message: Free-form commit message.
            branch: Target branch; created on first use.
        """
        # Create the branch on demand instead of failing with KeyError.
        history = self.branches.setdefault(branch, [])
        commit = {
            "id": model_version.version_id,
            "parent": self._get_head(branch),
            "message": message,
            "model_version": model_version,
            "timestamp": datetime.now().isoformat(),
        }
        self.commits.append(commit)
        history.append(commit["id"])

    def log(self, branch: str = "main", limit: int = 10):
        """Return up to ``limit`` commits of ``branch``, newest first.

        Args:
            branch: Branch to inspect.
            limit: Maximum number of commits; a non-positive value yields an
                empty list.

        Raises:
            KeyError: If ``branch`` does not exist.
        """
        history = self.branches[branch]
        if limit <= 0:
            # ``history[-0:]`` would be the whole list, not an empty one.
            return []
        return [self._get_commit(commit_id) for commit_id in reversed(history[-limit:])]

    def diff(self, version_id_1: str, version_id_2: str):
        """Compare the metadata and artifacts of two committed versions.

        Returns:
            A dict with ``"metadata_changes"`` and ``"artifacts_changes"``,
            each in the format produced by ``_diff_dict``.

        Raises:
            KeyError: If either ID has never been committed.
        """
        v1 = self._get_commit(version_id_1)["model_version"]
        v2 = self._get_commit(version_id_2)["model_version"]
        return {
            "metadata_changes": self._diff_dict(v1.metadata, v2.metadata),
            "artifacts_changes": self._diff_dict(v1.artifacts, v2.artifacts),
        }

    def _get_head(self, branch: str) -> Optional[str]:
        """Return the latest commit ID on ``branch``, or ``None`` if empty or unknown."""
        history = self.branches.get(branch)
        return history[-1] if history else None

    def _get_commit(self, commit_id: str) -> Dict[str, Any]:
        """Return the first commit whose ``"id"`` equals ``commit_id``.

        Raises:
            KeyError: If no such commit exists.
        """
        for commit in self.commits:
            if commit["id"] == commit_id:
                return commit
        raise KeyError(commit_id)

    @staticmethod
    def _diff_dict(old: Dict[str, Any], new: Dict[str, Any]) -> Dict[str, Any]:
        """Describe how ``new`` differs from ``old``.

        Returns:
            ``{"added": {key: new_value}, "removed": {key: old_value},
            "changed": {key: {"old": ..., "new": ...}}}``.
        """
        return {
            "added": {key: new[key] for key in new if key not in old},
            "removed": {key: old[key] for key in old if key not in new},
            "changed": {
                key: {"old": old[key], "new": new[key]}
                for key in old
                if key in new and old[key] != new[key]
            },
        }
模型注册表(Model Registry)
MLflow集成
import mlflow
import mlflow.pytorch
from mlflow.models import ModelSignature
from mlflow.types import Schema, TensorSpec
import numpy as np
class LLMModelRegistry:
    """Thin wrapper that logs and registers PyTorch LLMs through MLflow."""

    def __init__(self, experiment_name: str):
        """Select ``experiment_name`` as the active MLflow experiment."""
        mlflow.set_experiment(experiment_name)
        self.experiment_name = experiment_name

    def register_model(
        self,
        model,
        model_name: str,
        version: str,
        metrics: Dict[str, float],
        params: Dict[str, Any],
        artifacts: List[str],
        vocab_size: int = 32000,
    ):
        """Log one training result in a new MLflow run and register the model.

        Args:
            model: Model object handed to ``mlflow.pytorch.log_model``.
            model_name: Name passed as ``registered_model_name``.
            version: Caller-defined version label. It is stored as the run tag
                ``model_version``.
                NOTE(review): the MLflow registry presumably assigns its own
                version numbers independently of this label - confirm.
            metrics: Metric name to value, logged with ``mlflow.log_metrics``.
            params: Parameter name to value, logged with ``mlflow.log_params``.
            artifacts: Local paths logged one by one with ``mlflow.log_artifact``.
            vocab_size: Size of the last dimension of the ``logits`` output in
                the model signature. Defaults to the previously hard-coded 32000.
        """
        with mlflow.start_run():
            # Keep the caller's version label; previously it was discarded.
            mlflow.set_tag("model_version", version)

            # Log hyperparameters.
            mlflow.log_params(params)

            # Log evaluation metrics.
            mlflow.log_metrics(metrics)

            # Declare the model signature: two int64 inputs and float32 logits.
            # NOTE(review): the (-1, -1) shapes look like (batch, sequence) - confirm.
            signature = ModelSignature(
                inputs=Schema([
                    TensorSpec(np.dtype("int64"), (-1, -1), "input_ids"),
                    TensorSpec(np.dtype("int64"), (-1, -1), "attention_mask"),
                ]),
                outputs=Schema([
                    TensorSpec(np.dtype("float32"), (-1, -1, vocab_size), "logits"),
                ]),
            )

            # Log the model and register it under ``model_name``.
            mlflow.pytorch.log_model(
                model,
                "model",
                signature=signature,
                registered_model_name=model_name,
            )

            # Log any additional artifacts.
            for artifact in artifacts:
                mlflow.log_artifact(artifact)

    def load_model(self, model_name: str, version: str = "latest"):
        """Load a registered model from the URI ``models:/<model_name>/<version>``.

        Args:
            model_name: Registered model name.
            version: Version segment of the model URI; defaults to ``"latest"``.

        Returns:
            Whatever ``mlflow.pytorch.load_model`` returns for that URI.
        """
        model_uri = f"models:/{model_name}/{version}"
        return mlflow.pytorch.load_model(model_uri)
自定义注册表
import sqlite3
import json
from typing import List, Optional
class CustomModelRegistry:
    """SQLite-backed model registry with a fixed stage-transition workflow.

    Every operation opens its own connection to ``db_path`` and closes it
    again, even when the statement fails.
    """

    # Allowed target stages for each current stage.
    _VALID_TRANSITIONS = {
        "development": ["staging"],
        "staging": ["production", "development"],
        "production": ["archived"],
        "archived": [],
    }

    def __init__(self, db_path: str = "models.db"):
        """Remember the database path and create the ``models`` table if needed."""
        self.db_path = db_path
        self._init_db()

    def _run(self, sql: str, params: tuple = ()) -> list:
        """Execute one statement, commit, and return all fetched rows.

        The connection is closed in ``finally`` so it cannot leak when the
        statement raises.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            rows = conn.execute(sql, params).fetchall()
            conn.commit()
            return rows
        finally:
            conn.close()

    def _init_db(self):
        """Create the ``models`` table when it does not exist yet."""
        self._run("""
            CREATE TABLE IF NOT EXISTS models (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                version TEXT NOT NULL,
                stage TEXT DEFAULT 'development',
                metrics TEXT,
                params TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                UNIQUE(name, version)
            )
        """)

    def register(
        self,
        name: str,
        version: str,
        stage: str = "development",
        metrics: Optional[Dict] = None,
        params: Optional[Dict] = None,
    ):
        """Insert a model version, replacing any existing row for it.

        Because this uses ``INSERT OR REPLACE``, re-registering an existing
        ``(name, version)`` pair creates a fresh row and overwrites its stage.

        Args:
            name: Model name.
            version: Version label; unique together with ``name``.
            stage: Initial stage.
            metrics: JSON-serializable metrics; ``None`` is stored as ``{}``.
            params: JSON-serializable parameters; ``None`` is stored as ``{}``.
        """
        # Store "{}" rather than "null" so readers always get a dict back.
        self._run(
            """
            INSERT OR REPLACE INTO models (name, version, stage, metrics, params)
            VALUES (?, ?, ?, ?, ?)
            """,
            (name, version, stage, json.dumps(metrics or {}), json.dumps(params or {})),
        )

    def _get_stage(self, name: str, version: str) -> Optional[str]:
        """Return the stored stage of a model version, or ``None`` if absent."""
        rows = self._run(
            "SELECT stage FROM models WHERE name = ? AND version = ?",
            (name, version),
        )
        return rows[0][0] if rows else None

    def transition_stage(self, name: str, version: str, new_stage: str):
        """Move a model version to ``new_stage`` if the workflow allows it.

        Raises:
            ValueError: If the model version is not registered, or the move
                from its current stage to ``new_stage`` is not permitted.
        """
        current_stage = self._get_stage(name, version)
        if current_stage is None:
            raise ValueError(f"Model {name} version {version} is not registered")
        if new_stage not in self._VALID_TRANSITIONS.get(current_stage, []):
            raise ValueError(f"Cannot transition from {current_stage} to {new_stage}")
        self._run(
            """
            UPDATE models SET stage = ?, updated_at = CURRENT_TIMESTAMP
            WHERE name = ? AND version = ?
            """,
            (new_stage, name, version),
        )

    def get_production_model(self, name: str) -> Optional[Dict]:
        """Return the most recently created production version of ``name``.

        Returns:
            A dict with ``id``, ``name``, ``version``, ``stage``, ``metrics``
            and ``params``, or ``None`` when no version is in production.
        """
        # ``id DESC`` breaks ties between rows created within the same second.
        rows = self._run(
            """
            SELECT id, name, version, stage, metrics, params
            FROM models WHERE name = ? AND stage = 'production'
            ORDER BY created_at DESC, id DESC LIMIT 1
            """,
            (name,),
        )
        if not rows:
            return None
        row_id, row_name, row_version, row_stage, raw_metrics, raw_params = rows[0]
        return {
            "id": row_id,
            "name": row_name,
            "version": row_version,
            "stage": row_stage,
            # ``or {}`` also covers rows that already hold the JSON text "null".
            "metrics": (json.loads(raw_metrics) if raw_metrics else {}) or {},
            "params": (json.loads(raw_params) if raw_params else {}) or {},
        }
模型血缘追踪
class ModelLineage:
    """In-memory record of which dataset, config and parent produced each model."""

    def __init__(self):
        """Start with an empty lineage table keyed by model ID."""
        self.lineage = {}

    def record_training(
        self,
        model_id: str,
        dataset_id: str,
        config: Dict[str, Any],
        parent_model_id: Optional[str] = None,
    ):
        """Store (or overwrite) the training provenance of ``model_id``.

        Args:
            model_id: Identifier of the trained model.
            dataset_id: Identifier of the dataset it was trained on.
            config: Training configuration.
            parent_model_id: Model this one was derived from, if any.
        """
        self.lineage[model_id] = {
            "dataset": dataset_id,
            "config": config,
            "parent": parent_model_id,
            "timestamp": datetime.now().isoformat(),
        }

    def get_lineage(self, model_id: str) -> List[Dict]:
        """Return the ancestry of ``model_id``, starting with the model itself.

        Each entry is the stored record plus a ``"model_id"`` key. The walk
        stops at the first parent that is missing, unrecorded, or already
        visited, so a cyclic parent chain cannot loop forever.
        """
        chain = []
        visited = set()
        current = model_id
        while current and current in self.lineage and current not in visited:
            visited.add(current)
            record = self.lineage[current]
            chain.append({"model_id": current, **record})
            current = record.get("parent")
        return chain

    def get_impact_analysis(
        self, dataset_id: str, include_descendants: bool = False
    ) -> List[str]:
        """List the models affected by a change to ``dataset_id``.

        Args:
            dataset_id: Dataset to analyse.
            include_descendants: When ``False`` (default) only models trained
                directly on the dataset are returned. When ``True``, models
                with any recorded ancestor trained on it are included too.

        Returns:
            Model IDs in the order they were recorded.
        """
        if not include_descendants:
            return [
                model_id
                for model_id, info in self.lineage.items()
                if info["dataset"] == dataset_id
            ]
        return [
            model_id
            for model_id in self.lineage
            if any(
                entry["dataset"] == dataset_id
                for entry in self.get_lineage(model_id)
            )
        ]
最佳实践
- 语义化版本:使用Major.Minor.Patch格式,清晰表达变更性质
- 元数据完整:记录训练数据、超参数、评估结果等关键信息
- 阶段管理:建立development → staging → production → archived的流转流程,staging阶段验证不通过时可退回development
- 自动化集成:将版本管理集成到CI/CD流程中
- 定期审计:清理过期模型,保持注册表的整洁
模型版本管理是MLOps的基石,建立规范的版本管理体系能显著提升团队的协作效率和模型质量。