MLflow：使用MLflow管理LLM项目
MLflow简介
MLflow是一个开源的机器学习生命周期管理平台,提供实验追踪、模型打包、模型注册和部署等功能。对于大语言模型项目,MLflow可以帮助团队标准化开发流程,提高协作效率。
基础配置
安装与启动
# Install MLflow
pip install mlflow
# Start the MLflow UI locally on port 5000
mlflow ui --port 5000
# Or run an MLflow tracking server (production setting): listens on all
# interfaces (0.0.0.0) on port 5000 and stores run metadata in the SQLite file mlflow.db
mlflow server --host 0.0.0.0 --port 5000 --backend-store-uri sqlite:///mlflow.db
基础使用
import mlflow
import mlflow.pytorch
from mlflow.models import infer_signature
# Point the MLflow client at the local tracking server, then make
# "llm-training" the active experiment for subsequent runs.
mlflow.set_tracking_uri(uri="http://localhost:5000")
mlflow.set_experiment(experiment_name="llm-training")
实验追踪
完整的训练追踪
import mlflow
import mlflow.pytorch
import torch
from datetime import datetime
class MLflowLLMTracker:
    """Track an LLM training run through MLflow's fluent (module-level) API.

    Every method forwards to the corresponding ``mlflow.*`` call, so the
    tracking URI and the active run are MLflow's process-wide globals.
    """

    def __init__(self, experiment_name: str):
        """Make ``experiment_name`` the active MLflow experiment."""
        mlflow.set_experiment(experiment_name)
        self.run = None  # ActiveRun returned by start_run(); None when no run is open

    def start_run(self, run_name: str = None):
        """Start a new run and return its ``ActiveRun`` handle.

        Args:
            run_name: Display name of the run; defaults to ``run_<YYYYmmdd_HHMMSS>``.
        """
        if run_name is None:
            run_name = f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.run = mlflow.start_run(run_name=run_name)
        return self.run

    def log_params(self, config: dict):
        """Log every key/value pair of ``config`` as a run parameter."""
        mlflow.log_params(config)

    def log_metrics(self, metrics: dict, step: int = None):
        """Log the numeric (int/float) entries of ``metrics`` at ``step``; skip the rest."""
        for key, value in metrics.items():
            if isinstance(value, (int, float)):
                mlflow.log_metric(key, value, step=step)

    @staticmethod
    def _to_numpy(value):
        """Return ``value`` as a NumPy array if it is a torch tensor, otherwise unchanged.

        ``detach().cpu()`` makes this safe for CUDA tensors and tensors that
        require grad, both of which make a bare ``.numpy()`` raise.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    def log_model(self, model, model_name: str, sample_input=None):
        """Log ``model`` under the artifact path ``model_name``.

        When ``sample_input`` (a mapping of keyword arguments for ``model``)
        is given, the model is run once on it to infer a signature, and the
        model is additionally registered as ``<model_name>_registered``.
        """
        if sample_input is None:
            mlflow.pytorch.log_model(model, model_name)
            return

        with torch.no_grad():
            sample_output = model(**sample_input)
        # Outputs exposing `.logits` (e.g. Hugging Face style) are reduced to that tensor.
        output = getattr(sample_output, "logits", sample_output)
        # infer_signature works on NumPy arrays, not torch tensors.
        signature = infer_signature(
            {key: self._to_numpy(value) for key, value in sample_input.items()},
            self._to_numpy(output),
        )
        mlflow.pytorch.log_model(
            model,
            model_name,
            signature=signature,
            registered_model_name=f"{model_name}_registered",
        )

    def log_artifact(self, local_path: str, artifact_path: str = None):
        """Upload the local file ``local_path`` into the run's artifacts (under ``artifact_path``)."""
        mlflow.log_artifact(local_path, artifact_path)

    def log_dataset(self, dataset, name: str, description: str = ""):
        """Log ``dataset`` as a "training" input of the active run.

        Args:
            dataset: A pandas DataFrame, or anything ``numpy.asarray`` accepts.
            name: Dataset name shown in the MLflow UI.
            description: Optional free text, stored as a ``description`` input tag.
        """
        import numpy as np
        import pandas as pd
        from mlflow.data import from_numpy, from_pandas

        # mlflow.data.Dataset is an abstract base class and cannot be built
        # directly; the concrete factories also compute a content digest that
        # is stable across processes (the built-in hash() of a str is salted).
        if isinstance(dataset, pd.DataFrame):
            mlflow_dataset = from_pandas(dataset, name=name)
        else:
            mlflow_dataset = from_numpy(np.asarray(dataset), name=name)
        tags = {"description": description} if description else None
        mlflow.log_input(mlflow_dataset, context="training", tags=tags)

    def log_figure(self, figure, artifact_file: str):
        """Log ``figure`` as the run artifact ``artifact_file``."""
        mlflow.log_figure(figure, artifact_file)

    def end_run(self, status: str = "FINISHED"):
        """End the active run with ``status`` and forget the run handle.

        ``status`` must be an MLflow RunStatus name such as "FINISHED",
        "FAILED" or "KILLED".
        """
        mlflow.end_run(status=status)
        self.run = None
PyTorch模型集成
import mlflow
import mlflow.pytorch
import torch
import torch.nn as nn
class LLMWithMLflow:
    """Train an LLM while logging params, metrics, checkpoints and models to MLflow.

    ``config`` is read for: experiment_name, run_name, model_name, batch_size,
    learning_rate, optimizer and max_seq_length.
    """

    def __init__(self, model, config, optimizer=None):
        """
        Args:
            model: Torch module whose ``forward(**batch)`` returns an object with
                ``.loss`` (and optionally ``.logits``).
            config: Training configuration (attributes listed on the class).
            optimizer: Optional pre-built ``torch.optim.Optimizer``. When omitted,
                one is built from ``config.optimizer`` and ``config.learning_rate``.
        """
        self.model = model
        self.config = config
        self.optimizer = optimizer if optimizer is not None else self._build_optimizer()
        mlflow.set_experiment(config.experiment_name)

    def _build_optimizer(self):
        """Instantiate the torch.optim class named by ``config.optimizer`` (AdamW by default)."""
        import warnings

        name = getattr(self.config, "optimizer", None)
        optim_cls = torch.optim.AdamW
        if name:
            # Case-insensitive lookup so both "AdamW" and "adamw" resolve.
            by_name = {
                attr.lower(): obj
                for attr, obj in vars(torch.optim).items()
                if isinstance(obj, type)
                and issubclass(obj, torch.optim.Optimizer)
                and obj is not torch.optim.Optimizer
            }
            found = by_name.get(str(name).lower())
            if found is None:
                warnings.warn(f"Unknown optimizer {name!r}; falling back to AdamW", stacklevel=2)
            else:
                optim_cls = found
        return optim_cls(self.model.parameters(), lr=self.config.learning_rate)

    def train(self, train_loader, val_loader, num_epochs):
        """Run ``num_epochs`` of training and validation inside one MLflow run.

        Logs hyper-parameters and per-epoch metrics. Whenever validation loss
        improves, it also logs a checkpoint artifact. At the end it logs the
        final model (registered as "llm-model") and a training-curve figure.
        """
        with mlflow.start_run(run_name=self.config.run_name):
            mlflow.log_params({
                "model_name": self.config.model_name,
                "batch_size": self.config.batch_size,
                "learning_rate": self.config.learning_rate,
                "num_epochs": num_epochs,
                "optimizer": self.config.optimizer,
                "max_seq_length": self.config.max_seq_length,
            })

            best_val_loss = float('inf')
            for epoch in range(num_epochs):
                train_loss = self._train_epoch(train_loader, epoch)
                val_loss, val_acc = self._validate(val_loader, epoch)
                mlflow.log_metrics({
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "val_accuracy": val_acc,
                    "epoch": epoch,
                }, step=epoch)

                # Checkpoint only when validation loss improves.
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    self._save_model(f"best_model_epoch_{epoch}")

            mlflow.pytorch.log_model(
                self.model,
                "final_model",
                registered_model_name="llm-model",
            )
            self._log_training_curves()

    def _train_epoch(self, train_loader, epoch):
        """Train for one epoch and return the mean batch loss.

        ``batch_loss`` is logged every 100 batches at the global step
        ``epoch * len(train_loader) + batch_idx``.

        Raises:
            ValueError: if ``train_loader`` is empty.
        """
        self.model.train()
        steps_per_epoch = len(train_loader)
        if steps_per_epoch == 0:
            raise ValueError("train_loader yielded no batches")

        total_loss = 0.0
        for batch_idx, batch in enumerate(train_loader):
            # Clear gradients from the previous step; otherwise they accumulate.
            self.optimizer.zero_grad()
            outputs = self.model(**batch)
            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            # Apply the update; without this the weights never change.
            self.optimizer.step()

            loss_value = loss.item()
            total_loss += loss_value
            if batch_idx % 100 == 0:
                mlflow.log_metric(
                    "batch_loss",
                    loss_value,
                    step=epoch * steps_per_epoch + batch_idx,
                )
        return total_loss / steps_per_epoch

    def _validate(self, val_loader, epoch):
        """Evaluate on ``val_loader`` and return ``(mean_loss, accuracy)``.

        Accuracy compares ``argmax(logits, dim=-1)`` with ``batch['labels']``.
        Positions labelled -100 (PyTorch's default cross-entropy ignore_index)
        are excluded. ``epoch`` is currently unused.

        Raises:
            ValueError: if ``val_loader`` is empty.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in val_loader:
                outputs = self.model(**batch)
                total_loss += outputs.loss.item()
                num_batches += 1
                if hasattr(outputs, 'logits'):
                    labels = batch['labels']
                    predictions = torch.argmax(outputs.logits, dim=-1)
                    # NOTE(review): for causal LMs the logits at position t predict
                    # label t+1; this compares unshifted positions — confirm per model.
                    mask = labels != -100
                    correct += ((predictions == labels) & mask).sum().item()
                    total += mask.sum().item()

        if num_batches == 0:
            raise ValueError("val_loader yielded no batches")
        avg_loss = total_loss / num_batches
        accuracy = correct / total if total > 0 else 0
        return avg_loss, accuracy

    def _save_model(self, name: str):
        """Write ``{name}.pt`` (state dict + config) to the CWD and log it as a run artifact."""
        path = f"{name}.pt"
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'config': self.config,
        }, path)
        mlflow.log_artifact(path)

    def _log_training_curves(self):
        """Plot the run's logged loss/accuracy history and log it as training_curves.png.

        Must be called while the training run is active.
        """
        import matplotlib.pyplot as plt

        run_id = mlflow.active_run().info.run_id
        client = mlflow.tracking.MlflowClient()
        fig, axes = plt.subplots(1, 2, figsize=(12, 4))
        panels = (
            (axes[0], ("train_loss", "val_loss"), "Training Loss"),
            (axes[1], ("val_accuracy",), "Validation Accuracy"),
        )
        try:
            for ax, keys, title in panels:
                for key in keys:
                    history = sorted(client.get_metric_history(run_id, key), key=lambda m: m.step)
                    ax.plot([m.step for m in history], [m.value for m in history], label=key)
                ax.set_title(title)
                ax.set_xlabel("epoch")
                ax.legend()
            mlflow.log_figure(fig, "training_curves.png")
        finally:
            # Release the figure; pyplot otherwise keeps every figure alive.
            plt.close(fig)
模型注册与版本管理
class ModelRegistry:
    """Convenience wrapper around the MLflow Model Registry."""

    def __init__(self, tracking_uri: str = "http://localhost:5000"):
        """Point MLflow at ``tracking_uri`` and keep an ``MlflowClient`` for registry calls."""
        mlflow.set_tracking_uri(tracking_uri)
        self.client = mlflow.tracking.MlflowClient()

    @staticmethod
    def _stage_uri(model_name, stage):
        """Build the ``models:/<name>/<stage>`` URI understood by MLflow loaders."""
        return f"models:/{model_name}/{stage}"

    def register_model(self, model_uri: str, model_name: str):
        """Register the model stored at ``model_uri`` as ``model_name``; return the new ModelVersion."""
        registered = mlflow.register_model(model_uri=model_uri, name=model_name)
        print(f"Model registered: {model_name}, version: {registered.version}")
        return registered

    def transition_model_stage(self, model_name: str, version: str, stage: str):
        """Move version ``version`` of ``model_name`` into ``stage``.

        Stages: "None", "Staging", "Production", "Archived".
        """
        self.client.transition_model_version_stage(name=model_name, version=version, stage=stage)
        print(f"Model {model_name} v{version} moved to {stage}")

    def add_model_description(self, model_name: str, version: str, description: str):
        """Set the free-text description of one model version."""
        self.client.update_model_version(name=model_name, version=version, description=description)

    def get_model_versions(self, model_name: str):
        """Return every registered version of ``model_name``."""
        name_filter = f"name='{model_name}'"
        return self.client.search_model_versions(name_filter)

    def load_model(self, model_name: str, stage: str = "Production"):
        """Load the ``stage`` version of ``model_name`` as a PyTorch model."""
        return mlflow.pytorch.load_model(self._stage_uri(model_name, stage))
部署服务
class MLflowDeployer:
    """Produce serving artifacts (a Flask app or Docker files) for a registered model."""

    def __init__(self):
        pass  # stateless; kept so existing `MLflowDeployer()` calls keep working

    def deploy_flask(self, model_name: str, stage: str = "Production"):
        """Load ``models:/<model_name>/<stage>`` and return an (unstarted) Flask app.

        Routes:
            POST /predict -- expects a JSON object with a "text" field; inference
                             is still a placeholder (tokenization TODO).
            GET  /health  -- liveness check.
        """
        import mlflow.pytorch
        from flask import Flask, request, jsonify

        model_uri = f"models:/{model_name}/{stage}"
        model = mlflow.pytorch.load_model(model_uri)
        model.eval()

        app = Flask(__name__)

        @app.route('/predict', methods=['POST'])
        def predict():
            # silent=True yields None (instead of raising) for a missing or invalid JSON body.
            data = request.get_json(silent=True)
            if not isinstance(data, dict):
                return jsonify({"error": "request body must be a JSON object"}), 400
            input_text = data.get('text', '')  # noqa: F841 -- used once preprocessing exists
            # TODO: tokenize `input_text`, run `model` under torch.no_grad(),
            # and turn its outputs into a prediction.
            return jsonify({
                "prediction": "placeholder",
                "model_version": stage,
            })

        @app.route('/health', methods=['GET'])
        def health():
            return jsonify({"status": "healthy"})

        return app

    def deploy_docker(self, model_name: str, stage: str, output_dir: str):
        """Write a Dockerfile and requirements.txt into ``output_dir`` (created if missing).

        The Dockerfile expects ``model/`` and ``app.py`` to be placed next to it
        before ``docker build``. It records ``models:/<model_name>/<stage>`` in
        the ``MODEL_URI`` environment variable.
        """
        import os

        model_uri = f"models:/{model_name}/{stage}"
        dockerfile = "\n".join([
            "FROM python:3.9-slim",
            "WORKDIR /app",
            f'ENV MODEL_URI="{model_uri}"',
            "# 安装依赖",
            "COPY requirements.txt .",
            "RUN pip install -r requirements.txt",
            "# 复制模型",
            "COPY model/ ./model/",
            "# 复制应用代码",
            "COPY app.py .",
            "EXPOSE 5001",
            'CMD ["python", "app.py"]',
        ]) + "\n"
        requirements = "\n".join(["mlflow", "torch", "transformers", "flask", "gunicorn"]) + "\n"

        os.makedirs(output_dir, exist_ok=True)
        # Explicit UTF-8: the Dockerfile contains non-ASCII comments, which would
        # fail to encode under a non-UTF-8 default locale.
        with open(os.path.join(output_dir, "Dockerfile"), 'w', encoding="utf-8") as f:
            f.write(dockerfile)
        with open(os.path.join(output_dir, "requirements.txt"), 'w', encoding="utf-8") as f:
            f.write(requirements)
        print(f"Docker files generated in {output_dir}")
团队协作
class MLflowCollaboration:
    """Helpers for sharing, tagging and comparing runs on a shared tracking server."""

    def __init__(self, tracking_uri: str):
        """Point MLflow at ``tracking_uri`` and keep an ``MlflowClient`` bound to it."""
        mlflow.set_tracking_uri(tracking_uri)
        self.client = mlflow.tracking.MlflowClient()

    def search_experiments(self, filter_string: str = None):
        """Return experiments matching ``filter_string``, most recently updated first."""
        newest_first = ["last_update_timestamp DESC"]
        return self.client.search_experiments(filter_string=filter_string, order_by=newest_first)

    def _summarize_run(self, run_id):
        """Fetch one run and flatten its identity, status, metrics and params into a dict."""
        fetched = self.client.get_run(run_id)
        info, data = fetched.info, fetched.data
        return {
            "run_id": run_id,
            "run_name": info.run_name,
            "status": info.status,
            "start_time": info.start_time,
            "metrics": data.metrics,
            "params": data.params,
        }

    def compare_runs(self, experiment_id: str, run_ids: list):
        """Return one summary dict per id in ``run_ids``, in order.

        ``experiment_id`` is accepted for API symmetry but not used.
        """
        return [self._summarize_run(rid) for rid in run_ids]

    def create_tag(self, run_id: str, key: str, value: str):
        """Attach the tag ``key=value`` to run ``run_id``."""
        self.client.set_tag(run_id, key, value)

    def search_runs(self, experiment_id: str, filter_string: str = None):
        """Return runs of ``experiment_id`` matching ``filter_string``, lowest val_loss first."""
        by_val_loss = ["metrics.val_loss ASC"]
        return self.client.search_runs(
            experiment_ids=[experiment_id],
            filter_string=filter_string,
            order_by=by_val_loss,
        )
MLflow为LLM项目提供了完整的生命周期管理解决方案,从实验追踪到模型部署,帮助团队提高开发效率和模型质量。