← 返回首页
🧠

MLflow：使用MLflow管理LLM项目

📂 llm ⏱ 4 min 757 words

MLflow：使用MLflow管理LLM项目

MLflow简介

MLflow是一个开源的机器学习生命周期管理平台，提供实验追踪、模型打包、模型注册和部署等功能。对于大语言模型项目，MLflow可以帮助团队标准化开发流程，提高协作效率。

基础配置

安装与启动

# Install MLflow from PyPI.
pip install mlflow

# Start the MLflow UI locally on port 5000.
mlflow ui --port 5000

# Alternatively, run an MLflow tracking server (production setups) on port 5000,
# listening on all interfaces and using the SQLite file mlflow.db as the backend store.
# NOTE(review): binding to 0.0.0.0 exposes the server to the network — confirm access control.
mlflow server --host 0.0.0.0 --port 5000 --backend-store-uri sqlite:///mlflow.db

基础使用

import mlflow
import mlflow.pytorch
from mlflow.models import infer_signature

# Send tracking calls to the server started above on localhost:5000.
mlflow.set_tracking_uri("http://localhost:5000")
# Log subsequent runs under the "llm-training" experiment.
mlflow.set_experiment("llm-training")

实验追踪

完整的训练追踪

import mlflow
import mlflow.pytorch
import torch
from datetime import datetime

class MLflowLLMTracker:
    """Thin wrapper around MLflow's fluent API (``mlflow.log_*``) for an LLM training run.

    The run started by :meth:`start_run` is kept in ``self.run`` until
    :meth:`end_run` is called.
    """

    def __init__(self, experiment_name: str):
        """Select the experiment that subsequent runs are logged under.

        Args:
            experiment_name: Name passed to ``mlflow.set_experiment``.
        """
        mlflow.set_experiment(experiment_name)
        self.run = None

    def start_run(self, run_name: str = None):
        """Start a new MLflow run and remember it in ``self.run``.

        Args:
            run_name: Display name of the run; defaults to ``run_YYYYmmdd_HHMMSS``
                built from the local time.

        Returns:
            The object returned by ``mlflow.start_run``.
        """
        if run_name is None:
            run_name = f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        self.run = mlflow.start_run(run_name=run_name)
        return self.run

    def log_params(self, config: dict):
        """Log every entry of ``config`` as a run parameter."""
        mlflow.log_params(config)

    def log_metrics(self, metrics: dict, step: int = None):
        """Log the real-valued entries of ``metrics`` at ``step``.

        Values that are not real numbers (strings, lists, tensors, ...) are
        skipped. NumPy scalars such as ``np.float32`` are accepted too, because
        NumPy registers its scalar types with ``numbers.Real``.
        """
        import numbers

        numeric = {
            key: float(value)
            for key, value in metrics.items()
            if isinstance(value, numbers.Real)
        }
        if numeric:
            # One batched call instead of one request per metric.
            mlflow.log_metrics(numeric, step=step)

    @staticmethod
    def _to_numpy(value):
        """Convert a tensor, or a dict of tensors, to NumPy; other values pass through.

        Tensors are detached and moved to CPU first, because ``Tensor.numpy()``
        fails on CUDA tensors and on tensors that require grad.
        """
        if isinstance(value, dict):
            return {key: MLflowLLMTracker._to_numpy(item) for key, item in value.items()}
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    def log_model(self, model, model_name: str, sample_input=None):
        """Log ``model`` via ``mlflow.pytorch.log_model`` under artifact path ``model_name``.

        Args:
            model: The PyTorch model to log.
            model_name: Artifact path; also the prefix of the registered model name.
            sample_input: Optional dict of keyword arguments for ``model``. When
                given, the model is run once on it without gradients to infer a
                signature. The model is then also registered as
                ``f"{model_name}_registered"``.
        """
        if sample_input is None:
            mlflow.pytorch.log_model(model, model_name)
            return

        with torch.no_grad():
            sample_output = model(**sample_input)
        # Prefer ``.logits`` when the output object has it; otherwise use the raw output.
        output = getattr(sample_output, "logits", sample_output)

        signature = infer_signature(
            self._to_numpy(sample_input),
            self._to_numpy(output),
        )

        mlflow.pytorch.log_model(
            model,
            model_name,
            signature=signature,
            registered_model_name=f"{model_name}_registered",
        )

    def log_artifact(self, local_path: str, artifact_path: str = None):
        """Upload the local file ``local_path`` as a run artifact."""
        mlflow.log_artifact(local_path, artifact_path)

    @staticmethod
    def _stable_digest(dataset) -> str:
        """Return a short digest of ``str(dataset)`` that is stable across processes.

        The built-in ``hash()`` of a ``str`` is salted per interpreter
        (PYTHONHASHSEED), so it cannot identify the same dataset across runs.
        """
        import hashlib

        return hashlib.sha256(str(dataset).encode("utf-8")).hexdigest()[:16]

    def log_dataset(self, dataset, name: str, description: str = ""):
        """Record ``dataset`` as a training input of the active run (metadata only).

        Args:
            dataset: Any object; only its ``str()`` is used, to compute a digest.
            name: Dataset name. It is also used to build the source URI ``file://{name}``.
            description: Currently not logged. It is kept for API compatibility.
        """
        from mlflow.data.dataset_source_registry import resolve_dataset_source
        from mlflow.data.meta_dataset import MetaDataset

        # The base ``mlflow.data.Dataset`` takes ``source=`` (a DatasetSource),
        # not ``data_source=`` (a string). MetaDataset is MLflow's metadata-only dataset.
        meta_dataset = MetaDataset(
            source=resolve_dataset_source(f"file://{name}"),
            name=name,
            digest=self._stable_digest(dataset),
        )
        mlflow.log_input(meta_dataset, context="training")

    def log_figure(self, figure, artifact_file: str):
        """Log a figure object as the run artifact ``artifact_file``."""
        mlflow.log_figure(figure, artifact_file)

    def end_run(self, status: str = "COMPLETED"):
        """End the active MLflow run with the given ``status`` and forget it."""
        mlflow.end_run(status=status)
        self.run = None

PyTorch模型集成

import mlflow
import mlflow.pytorch
import torch
import torch.nn as nn

class LLMWithMLflow:
    """Train a model while recording params, metrics, checkpoints and plots in MLflow.

    ``config`` is read by attribute.

    Required attributes:
        ``experiment_name``, ``run_name``, ``model_name``, ``batch_size``,
        ``learning_rate``, ``optimizer``, ``max_seq_length``.

    Optional attributes:
        ``registered_model_name`` (default ``"llm-model"``), ``max_grad_norm``
        (default ``1.0``), ``log_every_n_steps`` (default ``100``).
    """

    def __init__(self, model, config, optimizer=None):
        """
        Args:
            model: A ``torch.nn.Module`` that is called with a batch as keyword
                arguments and returns an object exposing ``.loss`` (and
                optionally ``.logits``).
            config: Training configuration object (see the class docstring).
            optimizer: Optional ready-made ``torch.optim.Optimizer``. When
                omitted, one is built on first use from ``config.optimizer`` and
                ``config.learning_rate``. ``config.optimizer`` is the name of a
                ``torch.optim`` class, matched case-insensitively.
        """
        self.model = model
        self.config = config
        self.optimizer = optimizer
        # Per-epoch history used to draw the training curves.
        self._history = {"train_loss": [], "val_loss": [], "val_accuracy": []}

        # Log subsequent runs under this experiment.
        mlflow.set_experiment(config.experiment_name)

    def _get_optimizer(self):
        """Return the optimizer, building it from ``config`` on first use.

        Raises:
            ValueError: if ``config.optimizer`` names no ``torch.optim`` optimizer.
        """
        optimizer = getattr(self, "optimizer", None)
        if optimizer is None:
            wanted = str(self.config.optimizer).lower()
            optim_cls = next(
                (
                    cls
                    for attr, cls in vars(torch.optim).items()
                    if attr.lower() == wanted
                    and isinstance(cls, type)
                    and issubclass(cls, torch.optim.Optimizer)
                ),
                None,
            )
            if optim_cls is None:
                raise ValueError(f"Unknown torch.optim optimizer: {self.config.optimizer!r}")
            optimizer = optim_cls(self.model.parameters(), lr=self.config.learning_rate)
            self.optimizer = optimizer
        return optimizer

    def train(self, train_loader, val_loader, num_epochs):
        """Run ``num_epochs`` of training and validation inside one MLflow run.

        The run records:
            - the hyper-parameters, once;
            - per-epoch ``train_loss``, ``val_loss``, ``val_accuracy`` and ``epoch``;
            - a checkpoint artifact whenever the validation loss improves;
            - the final model, registered under ``config.registered_model_name``
              (default ``"llm-model"``);
            - a plot of the per-epoch curves.
        """
        # Fail fast on a bad optimizer name before a run is opened.
        self._get_optimizer()
        self._history = {"train_loss": [], "val_loss": [], "val_accuracy": []}

        with mlflow.start_run(run_name=self.config.run_name):
            mlflow.log_params({
                "model_name": self.config.model_name,
                "batch_size": self.config.batch_size,
                "learning_rate": self.config.learning_rate,
                "num_epochs": num_epochs,
                "optimizer": self.config.optimizer,
                "max_seq_length": self.config.max_seq_length,
            })

            best_val_loss = float('inf')

            for epoch in range(num_epochs):
                train_loss = self._train_epoch(train_loader, epoch)
                val_loss, val_acc = self._validate(val_loader, epoch)

                self._history["train_loss"].append(train_loss)
                self._history["val_loss"].append(val_loss)
                self._history["val_accuracy"].append(val_acc)

                mlflow.log_metrics({
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "val_accuracy": val_acc,
                    "epoch": epoch
                }, step=epoch)

                # Keep a checkpoint of every new best validation loss.
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    self._save_model(f"best_model_epoch_{epoch}")

            mlflow.pytorch.log_model(
                self.model,
                "final_model",
                registered_model_name=getattr(self.config, "registered_model_name", "llm-model"),
            )

            self._log_training_curves()

    def _train_epoch(self, train_loader, epoch):
        """Train for one epoch and return the mean batch loss.

        Returns NaN when ``train_loader`` yields no batches. Every
        ``config.log_every_n_steps`` batches (default 100), the batch loss is
        logged as ``batch_loss`` at the global step
        ``epoch * len(train_loader) + batch_idx``.
        """
        self.model.train()
        optimizer = self._get_optimizer()
        max_grad_norm = getattr(self.config, "max_grad_norm", 1.0)
        log_every = getattr(self.config, "log_every_n_steps", 100)
        steps_per_epoch = len(train_loader)
        total_loss = 0.0

        for batch_idx, batch in enumerate(train_loader):
            # Clear gradients from the previous step; otherwise they accumulate.
            optimizer.zero_grad()
            outputs = self.model(**batch)
            loss = outputs.loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
            # Apply the update; without this call the weights never change.
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss

            if batch_idx % log_every == 0:
                mlflow.log_metric(
                    "batch_loss",
                    batch_loss,
                    step=epoch * steps_per_epoch + batch_idx
                )

        if steps_per_epoch == 0:
            return float("nan")
        return total_loss / steps_per_epoch

    def _validate(self, val_loader, epoch):
        """Evaluate on ``val_loader`` and return ``(mean_loss, accuracy)``.

        Accuracy compares ``argmax(logits, dim=-1)`` with ``batch['labels']``
        element-wise. Positions labelled ``-100`` (the default ``ignore_index``
        of ``torch.nn.CrossEntropyLoss``) are skipped. The mean loss is NaN for
        an empty loader. ``epoch`` is currently unused.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in val_loader:
                outputs = self.model(**batch)
                total_loss += outputs.loss.item()
                num_batches += 1

                if hasattr(outputs, 'logits'):
                    labels = batch['labels']
                    predictions = torch.argmax(outputs.logits, dim=-1)
                    # NOTE(review): for causal LMs the logits at position i predict
                    # token i+1; shift before comparing if that is the model type here.
                    mask = labels != -100
                    correct += (predictions[mask] == labels[mask]).sum().item()
                    total += mask.sum().item()

        avg_loss = total_loss / num_batches if num_batches else float("nan")
        accuracy = correct / total if total > 0 else 0

        return avg_loss, accuracy

    def _save_model(self, name: str):
        """Save ``{model_state_dict, config}`` and log it as the artifact ``{name}.pt``.

        The file is written to a temporary directory, which is removed after
        the upload. This keeps large checkpoints from piling up in the working
        directory.
        """
        import os
        import tempfile

        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, f"{name}.pt")
            torch.save({
                'model_state_dict': self.model.state_dict(),
                'config': self.config
            }, path)
            mlflow.log_artifact(path)

    def _log_training_curves(self):
        """Plot the recorded per-epoch history and log it as ``training_curves.png``."""
        import matplotlib.pyplot as plt

        history = getattr(self, "_history", {})
        fig, axes = plt.subplots(1, 2, figsize=(12, 4))

        axes[0].plot(history.get("train_loss", []), label="train_loss")
        axes[0].plot(history.get("val_loss", []), label="val_loss")
        axes[0].set_title("Training Loss")
        axes[0].set_xlabel("epoch")
        axes[0].legend()

        axes[1].plot(history.get("val_accuracy", []))
        axes[1].set_title("Validation Accuracy")
        axes[1].set_xlabel("epoch")

        try:
            mlflow.log_figure(fig, "training_curves.png")
        finally:
            # Release the figure; pyplot keeps every open figure alive otherwise.
            plt.close(fig)

模型注册与版本管理

class ModelRegistry:
    """Convenience wrapper around the MLflow Model Registry.

    NOTE(review): stage transitions and ``models:/<name>/<stage>`` URIs are
    deprecated in recent MLflow releases in favour of model-version aliases.
    Confirm against the installed MLflow version.
    """

    def __init__(self, tracking_uri: str = "http://localhost:5000"):
        """Point MLflow at ``tracking_uri`` and create a client for registry calls."""
        mlflow.set_tracking_uri(tracking_uri)
        self.client = mlflow.tracking.MlflowClient()

    def register_model(self, model_uri: str, model_name: str):
        """Register the model at ``model_uri`` as ``model_name``.

        Prints the new version number and returns the object from
        ``mlflow.register_model``.
        """
        registered = mlflow.register_model(model_uri=model_uri, name=model_name)
        print(f"Model registered: {model_name}, version: {registered.version}")
        return registered

    def transition_model_stage(self, model_name: str, version: str, stage: str):
        """Move version ``version`` of ``model_name`` to ``stage``.

        ``stage`` is one of "None", "Staging", "Production" or "Archived".
        """
        self.client.transition_model_version_stage(name=model_name, version=version, stage=stage)
        print(f"Model {model_name} v{version} moved to {stage}")

    def add_model_description(self, model_name: str, version: str, description: str):
        """Set the free-text description of one model version."""
        self.client.update_model_version(name=model_name, version=version, description=description)

    def get_model_versions(self, model_name: str):
        """Return every registered version of ``model_name``."""
        name_filter = f"name='{model_name}'"
        return self.client.search_model_versions(name_filter)

    def load_model(self, model_name: str, stage: str = "Production"):
        """Load the PyTorch model currently in ``stage`` for ``model_name``."""
        stage_uri = f"models:/{model_name}/{stage}"
        return mlflow.pytorch.load_model(stage_uri)

部署服务

class MLflowDeployer:
    """Helpers for serving a registered MLflow PyTorch model."""

    def deploy_flask(self, model_name: str, stage: str = "Production"):
        """Build (but do not start) a Flask app serving ``models:/{model_name}/{stage}``.

        Routes:
            POST /predict: reads ``{"text": ...}`` from the JSON body. Inference
                is still a placeholder that needs tokenization and
                post-processing. A body that is not a JSON object gets a 400.
            GET /health: liveness check.

        Returns:
            The ``flask.Flask`` application object.
        """
        import mlflow.pytorch
        import torch
        from flask import Flask, request, jsonify

        model_uri = f"models:/{model_name}/{stage}"
        model = mlflow.pytorch.load_model(model_uri)
        model.eval()

        app = Flask(__name__)

        @app.route('/predict', methods=['POST'])
        def predict():
            # silent=True makes a missing or malformed JSON body return None instead
            # of raising; reject anything that is not a JSON object explicitly.
            data = request.get_json(silent=True)
            if not isinstance(data, dict):
                return jsonify({"error": "request body must be a JSON object"}), 400
            input_text = data.get('text', '')

            with torch.no_grad():
                # TODO: tokenize `input_text`, run `model`, and post-process its outputs.
                pass

            return jsonify({
                "prediction": "placeholder",
                "model_version": stage
            })

        @app.route('/health', methods=['GET'])
        def health():
            return jsonify({"status": "healthy"})

        return app

    def deploy_docker(self, model_name: str, stage: str, output_dir: str):
        """Write ``Dockerfile`` and ``requirements.txt`` into ``output_dir``.

        ``output_dir`` is created if it does not exist. The Dockerfile expects
        the build context to also contain ``app.py`` and a ``model/`` directory;
        this method does not generate them. ``model_name`` and ``stage`` are
        currently not used in the generated files.
        """
        import os

        dockerfile = """
FROM python:3.9-slim

WORKDIR /app

# 安装依赖
COPY requirements.txt .
RUN pip install -r requirements.txt

# 复制模型
COPY model/ ./model/

# 复制应用代码
COPY app.py .

EXPOSE 5001

CMD ["python", "app.py"]
"""

        requirements = """
mlflow
torch
transformers
flask
gunicorn
"""

        os.makedirs(output_dir, exist_ok=True)
        # Explicit UTF-8: the Dockerfile contains non-ASCII comments, which fail to
        # encode under non-UTF-8 default locales (e.g. cp1252 on Windows).
        for filename, content in (("Dockerfile", dockerfile), ("requirements.txt", requirements)):
            with open(os.path.join(output_dir, filename), 'w', encoding="utf-8") as f:
                f.write(content)

        print(f"Docker files generated in {output_dir}")

团队协作

class MLflowCollaboration:
    """Query helpers for sharing and comparing experiments and runs on one tracking server."""

    def __init__(self, tracking_uri: str):
        """Point MLflow at the shared ``tracking_uri`` and create a client for it."""
        mlflow.set_tracking_uri(tracking_uri)
        self.client = mlflow.tracking.MlflowClient()

    def search_experiments(self, filter_string: str = None):
        """Return experiments matching ``filter_string``, most recently updated first."""
        experiments = self.client.search_experiments(
            filter_string=filter_string,
            # The sortable experiment attribute is ``last_update_time``;
            # ``last_update_timestamp`` is not a valid order_by key.
            order_by=["last_update_time DESC"]
        )
        return experiments

    def _summarize_run(self, run_id: str) -> dict:
        """Fetch one run and flatten its identity, status, metrics and params into a dict."""
        run = self.client.get_run(run_id)
        return {
            "run_id": run_id,
            "run_name": run.info.run_name,
            "status": run.info.status,
            "start_time": run.info.start_time,
            "metrics": run.data.metrics,
            "params": run.data.params
        }

    def compare_runs(self, experiment_id: str, run_ids: list):
        """Return one summary dict per run id, in the order given.

        ``experiment_id`` is currently unused; it is kept for API compatibility.
        """
        return [self._summarize_run(run_id) for run_id in run_ids]

    def create_tag(self, run_id: str, key: str, value: str):
        """Set the tag ``key=value`` on run ``run_id``."""
        self.client.set_tag(run_id, key, value)

    def search_runs(self, experiment_id: str, filter_string: str = None, order_by: list = None):
        """Search runs of one experiment.

        Args:
            experiment_id: Experiment to search in.
            filter_string: MLflow search filter; ``None`` means no filter.
            order_by: Sort clauses; defaults to ``["metrics.val_loss ASC"]``
                (best validation loss first).
        """
        if order_by is None:
            order_by = ["metrics.val_loss ASC"]
        runs = self.client.search_runs(
            experiment_ids=[experiment_id],
            filter_string=filter_string or "",
            order_by=order_by
        )
        return runs

MLflow为LLM项目提供了完整的生命周期管理解决方案，从实验追踪到模型部署，帮助团队提高开发效率和模型质量。