← 返回首页
🧠

模型注册中心

📂 llm ⏱ 2 min 349 words

---
title: "模型注册中心"
description: "构建和使用模型注册中心来管理LLM模型的全生命周期,包括版本控制、元数据管理和部署跟踪"
tags: ["模型管理", "版本控制", "元数据"]
category: "llm"
icon: "🧠"
---

模型注册中心

概述

模型注册中心(Model Registry)是 MLOps 的核心组件,用于集中管理模型的版本、元数据、性能指标和部署状态。对于LLM项目,模型注册中心帮助团队追踪从实验到生产的所有模型变体,确保可追溯性和可复现性。

核心功能

模型注册中心应提供以下能力:模型版本管理、元数据存储、性能指标追踪、部署状态管理、血缘关系追踪以及访问控制。

架构设计

数据模型

# models/registry.py
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Optional

class ModelStatus(Enum):
    """Status values that a ``ModelVersion`` can carry.

    Each member's value is the lowercase string form of its name.
    """

    DEVELOPMENT = "development"
    STAGING = "staging"
    PRODUCTION = "production"
    ARCHIVED = "archived"

@dataclass
class ModelVersion:
    """Metadata record describing one version of a named model.

    The first six fields are required; the remaining fields are optional
    and default to empty containers, ``None`` or an empty string.
    """

    name: str
    version: str
    base_model: str
    status: ModelStatus
    created_at: datetime
    created_by: str
    
    # Performance metrics (metric name -> value)
    metrics: dict = field(default_factory=dict)
    
    # Deployment information
    deployment_config: Optional[dict] = None
    
    # Data lineage: references to the training data and training config
    training_data_ref: Optional[str] = None
    training_config_ref: Optional[str] = None
    
    # Location of the model files
    artifact_uri: Optional[str] = None
    
    # Tags and free-form description
    tags: list = field(default_factory=list)
    description: str = ""

@dataclass
class ModelLineage:
    """Record linking a parent model to a child model derived from it."""

    # presumably model identifiers (e.g. "name/version") - confirm with callers
    parent_model: str
    child_model: str
    relationship: str  # "finetune", "distill", "merge"
    created_at: datetime
    # Free-form extra information about the relationship
    metadata: dict = field(default_factory=dict)

注册中心实现

# registry/store.py
import json
from pathlib import Path
from datetime import datetime
from typing import Optional

class ModelRegistryStore:
    """File-system backed store for ``ModelVersion`` records.

    Every version is persisted as a JSON document at
    ``<store_path>/<name>/<version>/model.json``.
    """

    def __init__(self, store_path: str):
        """Open the store rooted at *store_path*, creating the directory if needed."""
        self.store_path = Path(store_path)
        self.store_path.mkdir(parents=True, exist_ok=True)

    def register_model(self, model: ModelVersion) -> str:
        """Write *model* to disk, overwriting any existing record.

        Args:
            model: The model version to persist.

        Returns:
            The registry key of the record, formatted as ``"<name>/<version>"``.
        """
        model_dir = self.store_path / model.name / model.version
        model_dir.mkdir(parents=True, exist_ok=True)

        model_data = {
            "name": model.name,
            "version": model.version,
            "base_model": model.base_model,
            "status": model.status.value,
            "created_at": model.created_at.isoformat(),
            "created_by": model.created_by,
            "metrics": model.metrics,
            "tags": model.tags,
            "description": model.description,
            "artifact_uri": model.artifact_uri,
            # Persist deployment and lineage fields too; otherwise they are
            # lost on every load/save round-trip (e.g. during promotion).
            "deployment_config": model.deployment_config,
            "training_data_ref": model.training_data_ref,
            "training_config_ref": model.training_config_ref,
        }

        # ensure_ascii=False emits raw non-ASCII characters, so the file
        # encoding must be pinned instead of relying on the platform default.
        (model_dir / "model.json").write_text(
            json.dumps(model_data, indent=2, ensure_ascii=False),
            encoding="utf-8",
        )

        return f"{model.name}/{model.version}"

    def get_model(self, name: str, version: str) -> Optional[ModelVersion]:
        """Load the record for *name*/*version*.

        Returns:
            The stored ``ModelVersion``, or ``None`` when no record exists.
        """
        model_file = self.store_path / name / version / "model.json"

        if not model_file.exists():
            return None

        data = json.loads(model_file.read_text(encoding="utf-8"))
        # Optional keys are read with .get() so records written before a
        # field was persisted still load.
        return ModelVersion(
            name=data["name"],
            version=data["version"],
            base_model=data["base_model"],
            status=ModelStatus(data["status"]),
            created_at=datetime.fromisoformat(data["created_at"]),
            created_by=data["created_by"],
            metrics=data.get("metrics", {}),
            deployment_config=data.get("deployment_config"),
            training_data_ref=data.get("training_data_ref"),
            training_config_ref=data.get("training_config_ref"),
            artifact_uri=data.get("artifact_uri"),
            tags=data.get("tags", []),
            description=data.get("description", ""),
        )

    def list_models(self, name: Optional[str] = None) -> list:
        """List registered model names, or the versions of one model.

        Args:
            name: When given, list the version directories of that model;
                otherwise list all model names.

        Returns:
            A sorted list of directory names; empty when *name* is unknown.
        """
        root = self.store_path / name if name else self.store_path
        if not root.exists():
            return []
        # Sorted so the result does not depend on file-system iteration order.
        return sorted(entry.name for entry in root.iterdir() if entry.is_dir())

    def promote_model(self, name: str, version: str, target_status: ModelStatus):
        """Change the status of a stored model version and save it.

        Args:
            name: Model name.
            version: Model version.
            target_status: The status to assign.

        Raises:
            ValueError: If the model does not exist, or if the target is
                ``ModelStatus.PRODUCTION`` and no accuracy metric is recorded.
        """
        model = self.get_model(name, version)
        if not model:
            raise ValueError(f"Model {name}/{version} not found")

        if target_status == ModelStatus.PRODUCTION:
            # Check for presence, not truthiness: an accuracy of 0.0 is still
            # a recorded metric.
            if model.metrics.get("accuracy") is None:
                raise ValueError("Model must have accuracy metric before promotion")

        model.status = target_status
        self.register_model(model)
        print(f"Promoted {name}/{version} to {target_status.value}")

使用示例

from datetime import datetime

from models.registry import ModelStatus, ModelVersion
from registry.store import ModelRegistryStore

# Open (or create) a registry rooted at ./model_store
registry = ModelRegistryStore("./model_store")

# Register a new model version
model = ModelVersion(
    name="chat-assistant",
    version="2.1.0",
    base_model="qwen-7b",
    status=ModelStatus.DEVELOPMENT,
    created_at=datetime.now(),
    created_by="data-team",
    metrics={"accuracy": 0.91, "latency_ms": 350},
    tags=["production-ready", "chat"],
    training_data_ref="s3://data/instruction-v4",
    artifact_uri="s3://models/chat-assistant/v2.1.0"
)
registry.register_model(model)

# Promote it to production (requires an accuracy metric to be recorded)
registry.promote_model("chat-assistant", "2.1.0", ModelStatus.PRODUCTION)

与CI/CD集成

模型注册中心应与CI/CD流水线紧密集成。训练完成后自动注册模型,评估通过后自动提升状态,部署时从注册中心拉取指定版本的模型配置。这种集成确保了从训练到部署的全链路自动化。