LLM训练管道:自动化模型训练流程
---
title: "LLM训练管道:自动化模型训练流程"
description: "构建端到端的LLM训练管道,从数据准备到模型评估的完整流程"
tags: ["LLM", "训练管道", "模型训练", "自动化", "MLOps"]
category: "llm"
icon: "🏋️"
---
LLM训练管道:自动化模型训练流程
训练管道概述
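LLM训练管道将数据预处理、模型训练、评估和部署等环节串联起来并自动化,使训练过程可复现、可追踪、可扩展。本文依次介绍训练配置管理(统一配置与超参数搜索)和训练流程(训练监控、检查点管理,以及把它们串联起来的自动化训练管道);示例代码中的数据加载与模型加载以占位实现代替,训练循环中的指标也是模拟值。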
训练配置管理
1. 统一配置系统
import json
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Optional
@dataclass
class TrainingConfig:
    """Every setting the training pipeline reads, gathered in one object.

    All fields are plain str/int/float/bool values, so the whole config can be
    written to JSON with :meth:`save` and rebuilt with :meth:`load`.
    """

    # Model
    model_name: str = "meta-llama/Llama-2-7b-hf"
    model_max_length: int = 2048
    # Optimization
    num_epochs: int = 3
    batch_size: int = 8
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1
    # Data
    train_data: str = "data/train.json"
    val_data: str = "data/val.json"
    # Output
    output_dir: str = "output/checkpoints"
    logging_dir: str = "output/logs"
    # Misc
    seed: int = 42
    fp16: bool = True
    gradient_accumulation_steps: int = 2

    def save(self, path) -> None:
        """Serialize the config to *path* as 2-space-indented JSON.

        Missing parent directories are created first.
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        payload = asdict(self)
        with target.open("w") as fh:
            json.dump(payload, fh, indent=2)

    @classmethod
    def load(cls, path) -> "TrainingConfig":
        """Rebuild a config from a JSON file whose keys are this class's field names."""
        with Path(path).open() as fh:
            payload = json.load(fh)
        return cls(**payload)
# Example: override a few defaults, then persist the config to disk.
config = TrainingConfig(**{
    "model_name": "meta-llama/Llama-2-7b-hf",
    "num_epochs": 5,
    "batch_size": 16,
})
config.save("configs/training_config.json")
2. 超参数搜索
import itertools
class HyperparameterSearch:
    """Grid search over the fields of a dataclass config.

    Register candidate values per field with :meth:`add_param` (chainable).
    :meth:`generate_configs` then returns one config per element of the
    Cartesian product of all candidate lists. Each config is a copy of
    ``base_config`` with those fields overridden.
    """

    def __init__(self, base_config):
        self.base_config = base_config
        # field name -> list of candidate values; insertion order fixes the
        # order in which combinations are produced
        self.param_grid = {}

    def add_param(self, name, values):
        """Register candidate *values* for field *name* and return ``self``.

        *values* is materialized into a list so that one-shot iterables
        (e.g. generators) still work on repeated :meth:`generate_configs` calls.
        """
        self.param_grid[name] = list(values)
        return self

    def generate_configs(self):
        """Return a list with one config per combination of registered values.

        Uses ``dataclasses.replace`` on the base config instead of rebuilding a
        hard-coded class. Subclassed configs therefore keep their type and extra
        fields, and nested dataclass fields are not flattened to dicts. An empty
        grid yields a single copy of the base config. A field registered with
        an empty list yields no configs at all.

        Raises:
            TypeError: if a registered name is not a field of the base config.
        """
        from dataclasses import replace

        keys = list(self.param_grid)
        return [
            replace(self.base_config, **dict(zip(keys, combo)))
            for combo in itertools.product(*self.param_grid.values())
        ]
# Example: 3 learning rates x 3 batch sizes -> 9 candidate configs.
base = TrainingConfig()
search = (
    HyperparameterSearch(base)
    .add_param("learning_rate", [1e-5, 2e-5, 5e-5])
    .add_param("batch_size", [4, 8, 16])
)
configs = search.generate_configs()
print(f"生成 {len(configs)} 个配置")
训练流程
1. 训练监控器
import time
from datetime import datetime
class TrainingMonitor:
    """Collect per-step training metrics, echo them to stdout, and dump them to JSON.

    Args:
        log_dir: Directory that receives the metrics JSON files; created if missing.
    """

    def __init__(self, log_dir="output/logs"):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.metrics = []  # one record dict per log_metrics() call, in call order
        self.start_time = None  # set by start(); None until then

    def start(self):
        """Record the wall-clock start time used for each record's ``elapsed`` field."""
        self.start_time = time.time()
        print(f"训练开始: {datetime.now()}")

    @staticmethod
    def _format_value(value):
        """Render one metric value for the console summary line.

        Floats use 4 decimals, except tiny non-zero magnitudes (such as a 2e-5
        learning rate). Those would print as ``0.0000``, so they are shown in
        scientific notation instead. Non-floats use ``str()``.
        """
        if isinstance(value, float):
            if value != 0 and abs(value) < 1e-3:
                return f"{value:.2e}"
            return f"{value:.4f}"
        return f"{value}"

    def log_metrics(self, step, metrics):
        """Store *metrics* for *step* (plus timestamp/elapsed) and print a summary.

        The caller's dict is copied rather than mutated, so it can safely be
        reused afterwards (e.g. handed to a checkpoint manager). If
        :meth:`start` has not been called yet, ``elapsed`` is recorded as 0.0
        instead of raising.
        """
        elapsed = 0.0 if self.start_time is None else time.time() - self.start_time
        record = {
            **metrics,
            "step": step,
            "timestamp": datetime.now().isoformat(),
            "elapsed": elapsed,
        }
        self.metrics.append(record)

        # Live console output; the ISO timestamp is left out to keep the line short.
        metric_str = ", ".join(
            f"{k}: {self._format_value(v)}"
            for k, v in record.items()
            if k != "timestamp"
        )
        print(f"Step {step}: {metric_str}")

    def save_metrics(self):
        """Write all records collected so far to a timestamped JSON file in log_dir."""
        log_file = self.log_dir / f"metrics_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(log_file, "w", encoding="utf-8") as f:
            json.dump(self.metrics, f, indent=2)
        print(f"指标已保存: {log_file}")

    def get_best_step(self, metric="val_loss", mode="min"):
        """Return the logged record with the best value of *metric*.

        Despite the name, the whole record dict is returned; its step is
        ``record["step"]``. Only records that actually contain *metric* are
        considered, and None is returned when none do. ``mode="min"`` picks the
        smallest value; any other mode picks the largest.
        """
        candidates = [record for record in self.metrics if metric in record]
        if not candidates:
            return None
        pick = min if mode == "min" else max
        return pick(candidates, key=lambda record: record[metric])
2. 检查点管理
class CheckpointManager:
    """Save step-numbered checkpoints under *output_dir*, keeping only the newest few.

    Args:
        output_dir: Root directory; each checkpoint goes to ``checkpoint-{step}``.
        max_checkpoints: Maximum number of checkpoints kept on disk. Once it is
            exceeded, the oldest ones are deleted first.

    Raises:
        ValueError: if ``max_checkpoints`` is less than 1 (every checkpoint
            would be deleted right after being written).
    """

    def __init__(self, output_dir, max_checkpoints=3):
        if max_checkpoints < 1:
            raise ValueError(f"max_checkpoints must be >= 1, got {max_checkpoints}")
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.max_checkpoints = max_checkpoints
        # Tracked checkpoints, oldest first: {"step", "path", "metrics"}
        self.checkpoints = []

    def save_checkpoint(self, model, step, metrics):
        """Write the model (if any) plus a metadata.json for *step*, then rotate.

        Args:
            model: Object exposing ``save_pretrained(directory)``, or None. With
                None only the metadata is written, so dry runs with a stubbed
                model still exercise checkpoint rotation.
            step: Global step number; used in the directory name.
            metrics: Mapping of metric values stored in the metadata. A copy is
                kept, so later edits to the caller's dict do not leak in.

        Returns:
            Path of the checkpoint directory.
        """
        ckpt_dir = self.output_dir / f"checkpoint-{step}"
        ckpt_dir.mkdir(parents=True, exist_ok=True)

        # Save the model weights, if there is a model
        if model is not None:
            model.save_pretrained(ckpt_dir)

        # Save metadata alongside the weights
        metrics = dict(metrics)
        metadata = {
            "step": step,
            "metrics": metrics,
            "timestamp": datetime.now().isoformat()
        }
        with open(ckpt_dir / "metadata.json", "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        # Re-saving an existing step replaces its entry instead of duplicating
        # it; otherwise rotation could delete the directory just written.
        self.checkpoints = [c for c in self.checkpoints if c["path"] != ckpt_dir]
        self.checkpoints.append({"step": step, "path": ckpt_dir, "metrics": metrics})
        self._cleanup_old_checkpoints()
        print(f"检查点已保存: {ckpt_dir}")
        return ckpt_dir

    def _cleanup_old_checkpoints(self):
        """Delete the oldest tracked checkpoints until at most max_checkpoints remain."""
        import shutil

        while len(self.checkpoints) > self.max_checkpoints:
            old = self.checkpoints.pop(0)
            if old["path"].exists():
                shutil.rmtree(old["path"])
            print(f"删除旧检查点: {old['path']}")

    def load_best_checkpoint(self, metric="val_loss"):
        """Return the path of the retained checkpoint with the lowest *metric*.

        Despite the name, nothing is loaded; only the directory path is returned.
        Only checkpoints still tracked (the newest ``max_checkpoints``) are
        considered, so an earlier, better checkpoint may already have been
        rotated out. Returns None when no tracked checkpoint records *metric*.
        """
        candidates = [c for c in self.checkpoints if metric in c["metrics"]]
        if not candidates:
            return None
        best = min(candidates, key=lambda c: c["metrics"][metric])
        return best["path"]
3. 自动化训练管道
class TrainingPipeline:
    """Wire data preparation, model setup, a (simulated) training loop, metric
    logging and periodic checkpointing into a single :meth:`train` call.

    Args:
        config: TrainingConfig-like object. This class reads ``logging_dir``,
            ``output_dir``, ``model_name``, ``num_epochs`` and ``learning_rate``.
        steps_per_epoch: Number of simulated optimizer steps per epoch.
        save_steps: Save a checkpoint every this many global steps; 0 disables
            checkpointing.
    """

    def __init__(self, config, steps_per_epoch=100, save_steps=50):
        self.config = config
        self.steps_per_epoch = steps_per_epoch
        self.save_steps = save_steps
        self.monitor = TrainingMonitor(config.logging_dir)
        self.ckpt_manager = CheckpointManager(config.output_dir)

    def prepare_data(self):
        """准备训练数据"""
        print("加载训练数据...")
        # Placeholder: a real implementation would load the tokenizer and datasets.
        return {"train": None, "val": None}

    def setup_model(self):
        """设置模型"""
        print(f"加载模型: {self.config.model_name}")
        # Placeholder: a real implementation would load the model here.
        return None

    def train(self):
        """Run the (simulated) training loop, logging metrics and saving checkpoints.

        The collected metrics are written to disk in a ``finally`` block, so
        they survive an exception raised mid-training.
        """
        self.monitor.start()
        # Unused by the simulation below; a real loop would iterate over it.
        datasets = self.prepare_data()
        model = self.setup_model()

        total_steps = self.config.num_epochs * self.steps_per_epoch
        global_step = 0
        try:
            for epoch in range(self.config.num_epochs):
                print(f"\nEpoch {epoch + 1}/{self.config.num_epochs}")
                for step in range(self.steps_per_epoch):
                    # Linear decay over the whole run (not restarted each epoch),
                    # starting from the configured learning rate.
                    progress = global_step / total_steps
                    global_step += 1
                    # Simulated metrics
                    metrics = {
                        "train_loss": 2.5 - epoch * 0.3 - step * 0.001,
                        "val_loss": 2.3 - epoch * 0.2 - step * 0.0008,
                        "learning_rate": self.config.learning_rate * (1 - progress),
                    }
                    self.monitor.log_metrics(global_step, metrics)
                    if self.save_steps and global_step % self.save_steps == 0:
                        self.ckpt_manager.save_checkpoint(model, global_step, metrics)
        finally:
            # Persist whatever was logged, even if training was interrupted.
            self.monitor.save_metrics()
        print("\n训练完成!")
最佳实践
- 配置管理:将训练配置文件纳入版本控制,使每次实验都能追溯到确切的参数
- 实验追踪:记录每次训练的完整指标(例如上文监控器保存的 JSON 指标文件),便于对比不同实验
- 资源管理:监控GPU利用率和显存占用
- 失败恢复:支持从检查点恢复训练——除模型权重外,还需保存优化器与学习率调度器状态、随机数状态和当前步数,才能从中断处无缝继续
总结
训练管道是LLM开发的核心组件。通过自动化的训练流程、完善的监控机制和灵活的配置管理,我们可以让每次训练都可复现、可比较,从而更高效地迭代出高质量的语言模型。