实验追踪:LLM训练实验的追踪与管理
为什么需要实验追踪
大语言模型训练涉及大量超参数配置、数据版本和模型检查点。有效的实验追踪系统可以帮助研究人员记录所有关键信息,比较不同实验的结果,并快速回溯到最佳配置。
实验追踪系统架构
核心组件设计
import json
import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Any, Optional
def _isoformat_now() -> str:
    """Return the current local time as an ISO-8601 string."""
    return datetime.now().isoformat()


@dataclass
class ExperimentMetrics:
    """A single logged record of training metrics.

    Only ``step`` and ``train_loss`` are required. ``learning_rate`` defaults
    to ``0.0``; every other measurement defaults to ``None``. ``timestamp`` is
    filled with the (local, naive) creation time of the record.
    """

    step: int  # step counter value this record belongs to
    train_loss: float
    val_loss: Optional[float] = None
    learning_rate: float = 0.0
    perplexity: Optional[float] = None
    throughput: Optional[float] = None  # tokens/second
    gpu_memory_used: Optional[float] = None  # NOTE(review): unit not specified here
    timestamp: str = field(default_factory=_isoformat_now)
@dataclass
class ExperimentConfig:
    """Hyper-parameters that describe one training run.

    Settings without a dedicated field belong in ``additional_params``; each
    instance receives its own fresh, empty dict by default.
    """

    # Required settings.
    model_name: str
    dataset: str
    batch_size: int
    learning_rate: float
    num_epochs: int
    optimizer: str
    scheduler: str
    # Optional settings with defaults.
    weight_decay: float = 0.01
    max_seq_length: int = 512
    additional_params: Dict[str, Any] = field(default_factory=dict)
class ExperimentTracker:
    """File-based experiment tracker.

    Each experiment gets its own directory ``<storage_dir>/<project_name>/<id>``
    that holds ``config.json``, ``metrics.jsonl``, ``artifacts.json``,
    ``tags.txt``, ``notes.txt`` and ``summary.json``.
    """

    def __init__(self, project_name: str, storage_dir: str = "./experiments"):
        """Create the project directory and a fresh, unique experiment directory.

        Args:
            project_name: sub-directory of ``storage_dir`` grouping experiments.
            storage_dir: root directory for all tracked projects.
        """
        self.project_name = project_name
        self.storage_dir = Path(storage_dir) / project_name
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        # An 8-hex-char id can collide; retry rather than silently reusing
        # (and appending to) another experiment's directory.
        while True:
            self.experiment_id = uuid.uuid4().hex[:8]
            self.experiment_dir = self.storage_dir / self.experiment_id
            try:
                self.experiment_dir.mkdir()
            except FileExistsError:
                continue
            break

        self.metrics_history: List["ExperimentMetrics"] = []
        self.artifacts: Dict[str, str] = {}
        # Set by init_experiment(); used as the summary's start time.
        self.created_at: Optional[str] = None

        # Paths of the files written by this tracker.
        self.log_file = self.experiment_dir / "metrics.jsonl"
        self.config_file = self.experiment_dir / "config.json"
        self.artifacts_file = self.experiment_dir / "artifacts.json"

    def init_experiment(self, config: "ExperimentConfig"):
        """Write ``config.json`` for this experiment and return its id.

        ``config.additional_params`` is merged into the ``"config"`` mapping
        after the named fields, so its keys win on a name clash.
        """
        self.created_at = datetime.now().isoformat()
        with open(self.config_file, 'w', encoding='utf-8') as f:
            json.dump({
                "experiment_id": self.experiment_id,
                "project_name": self.project_name,
                "created_at": self.created_at,
                "config": {
                    "model_name": config.model_name,
                    "dataset": config.dataset,
                    "batch_size": config.batch_size,
                    "learning_rate": config.learning_rate,
                    "num_epochs": config.num_epochs,
                    "optimizer": config.optimizer,
                    "scheduler": config.scheduler,
                    "weight_decay": config.weight_decay,
                    "max_seq_length": config.max_seq_length,
                    **config.additional_params
                }
            }, f, indent=2)
        print(f"Experiment initialized: {self.experiment_id}")
        return self.experiment_id

    def log_metrics(self, metrics: "ExperimentMetrics"):
        """Keep *metrics* in memory and append it as one JSON line to metrics.jsonl."""
        self.metrics_history.append(metrics)
        with open(self.log_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps({
                "step": metrics.step,
                "train_loss": metrics.train_loss,
                "val_loss": metrics.val_loss,
                "learning_rate": metrics.learning_rate,
                "perplexity": metrics.perplexity,
                "throughput": metrics.throughput,
                "gpu_memory_used": metrics.gpu_memory_used,
                "timestamp": metrics.timestamp
            }) + '\n')

    def log_artifact(self, artifact_name: str, artifact_path: str):
        """Record an artifact path (checkpoint, data, ...) and rewrite artifacts.json."""
        self.artifacts[artifact_name] = artifact_path
        with open(self.artifacts_file, 'w', encoding='utf-8') as f:
            json.dump(self.artifacts, f, indent=2)

    def log_tag(self, tag: str):
        """Append *tag* as one line to tags.txt."""
        tags_file = self.experiment_dir / "tags.txt"
        # Explicit UTF-8: tags/notes may contain non-ASCII text.
        with open(tags_file, 'a', encoding='utf-8') as f:
            f.write(f"{tag}\n")

    def log_note(self, note: str):
        """Append a timestamped *note* line to notes.txt."""
        notes_file = self.experiment_dir / "notes.txt"
        with open(notes_file, 'a', encoding='utf-8') as f:
            f.write(f"[{datetime.now().isoformat()}] {note}\n")

    def finish_experiment(self, status: str = "completed") -> Dict[str, Any]:
        """Write summary.json and return the summary dict.

        Summary fields:
            started_at: time of init_experiment(), else the first record's
                timestamp, else None.
            total_steps: highest ``step`` seen in the logged records (0 if none).
            num_records: number of logged metric records.
            final_metrics.train_loss: ``train_loss`` of the last record.
            final_metrics.val_loss: most recent non-None ``val_loss``; a training
                record logged after the last validation must not hide it.
        """
        history = self.metrics_history
        last = history[-1] if history else None
        first_timestamp = history[0].timestamp if history else None
        last_val_loss = next(
            (m.val_loss for m in reversed(history) if m.val_loss is not None),
            None,
        )
        summary = {
            "experiment_id": self.experiment_id,
            "status": status,
            "started_at": self.created_at or first_timestamp,
            "finished_at": datetime.now().isoformat(),
            "total_steps": max((m.step for m in history), default=0),
            "num_records": len(history),
            "final_metrics": {
                "train_loss": last.train_loss if last is not None else None,
                "val_loss": last_val_loss,
            }
        }
        summary_file = self.experiment_dir / "summary.json"
        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2)
        print(f"Experiment finished: {self.experiment_id}")
        return summary
训练集成
PyTorch训练集成
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
class TrackedTrainer:
    """Minimal training loop that reports metrics to an experiment tracker.

    ``config`` must expose ``learning_rate`` and ``weight_decay`` attributes.
    ``max_grad_norm`` and ``logging_steps`` are read as attributes when present,
    otherwise from ``config.additional_params``. This lets an
    ``ExperimentConfig``, which has no such fields, be used directly.
    """

    def __init__(self, model, config, tracker: "ExperimentTracker"):
        self.model = model
        self.config = config
        self.tracker = tracker
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )
        self.current_step = 0
        self.best_val_loss = float('inf')
        # Loss of the most recent training step (None until training runs).
        self._last_train_loss: Optional[float] = None
        # Throughput window: tokens counted and start time since the last log.
        self._window_tokens = 0
        self._window_start: Optional[float] = None

    def _config_value(self, name: str):
        """Return option *name* from a config attribute or ``additional_params``.

        Raises:
            AttributeError: if the option is found in neither place.
        """
        if hasattr(self.config, name):
            return getattr(self.config, name)
        extra = getattr(self.config, "additional_params", None) or {}
        if name in extra:
            return extra[name]
        raise AttributeError(
            f"config has no '{name}' attribute or additional_params entry"
        )

    @staticmethod
    def _count_tokens(batch) -> int:
        """Return the element count of ``batch['input_ids']``, or 0 if it is absent.

        NOTE(review): padding positions are counted too.
        """
        input_ids = batch.get("input_ids") if hasattr(batch, "get") else None
        return int(input_ids.numel()) if hasattr(input_ids, "numel") else 0

    def train_epoch(self, train_loader: DataLoader, epoch: int):
        """Run one training epoch and return the mean per-batch loss.

        Args:
            train_loader: iterable of mapping batches, each passed to the model
                as keyword arguments; the model output must have ``.loss``.
            epoch: epoch index (not used by the loop itself).

        Returns:
            Mean batch loss, or NaN if the loader yielded no batches.
        """
        self.model.train()
        max_grad_norm = self._config_value("max_grad_norm")
        logging_steps = self._config_value("logging_steps")
        total_loss = 0.0
        num_batches = 0
        # Restart the throughput window so that time spent in validation or
        # between epochs is not counted.
        self._window_start = time.perf_counter()
        self._window_tokens = 0

        for batch in train_loader:
            # Forward pass
            outputs = self.model(**batch)
            loss = outputs.loss

            # Backward pass, gradient clipping, update
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
            self.optimizer.step()

            # .item() synchronizes with the device; call it once per step.
            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1
            self.current_step += 1
            self._last_train_loss = loss_value
            self._window_tokens += self._count_tokens(batch)

            if self.current_step % logging_steps == 0:
                metrics = ExperimentMetrics(
                    step=self.current_step,
                    train_loss=loss_value,
                    learning_rate=self.optimizer.param_groups[0]['lr'],
                    throughput=self._calculate_throughput()
                )
                self.tracker.log_metrics(metrics)
                print(f"Step {self.current_step}, Loss: {loss_value:.4f}")

        return total_loss / num_batches if num_batches else float('nan')

    def validate(self, val_loader: DataLoader):
        """Compute the mean validation loss, log it, and checkpoint on improvement.

        A validation record is logged when a training loss is known, taken from
        this trainer's last step or else the tracker's last record. Returns NaN,
        without logging or checkpointing, if the loader yielded no batches.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for batch in val_loader:
                outputs = self.model(**batch)
                total_loss += outputs.loss.item()
                num_batches += 1
        if num_batches == 0:
            return float('nan')
        avg_val_loss = total_loss / num_batches

        # The history lives on the tracker, not on the trainer.
        history = self.tracker.metrics_history
        train_loss = self._last_train_loss
        if train_loss is None and history:
            train_loss = history[-1].train_loss
        if train_loss is not None:
            metrics = ExperimentMetrics(
                step=self.current_step,
                train_loss=train_loss,
                val_loss=avg_val_loss,
                learning_rate=self.optimizer.param_groups[0]['lr']
            )
            self.tracker.log_metrics(metrics)

        # Keep the checkpoint with the lowest validation loss.
        if avg_val_loss < self.best_val_loss:
            self.best_val_loss = avg_val_loss
            self._save_checkpoint("best_model")
        return avg_val_loss

    def _calculate_throughput(self) -> Optional[float]:
        """Return tokens/second since the window started, then reset the window.

        Returns None when no tokens were counted or no time has elapsed.
        """
        now = time.perf_counter()
        start, tokens = self._window_start, self._window_tokens
        self._window_start = now
        self._window_tokens = 0
        if start is None or tokens == 0:
            return None
        elapsed = now - start
        return tokens / elapsed if elapsed > 0 else None

    def _save_checkpoint(self, name: str):
        """Save model/optimizer state to ``checkpoints/<name>.pt`` and register it."""
        checkpoint_path = self.tracker.experiment_dir / "checkpoints" / f"{name}.pt"
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'step': self.current_step,
            'best_val_loss': self.best_val_loss
        }, checkpoint_path)
        self.tracker.log_artifact(name, str(checkpoint_path))
实验比较与分析
class ExperimentAnalyzer:
    """Load, compare and report on experiments stored on disk.

    ``storage_dir`` is the directory that directly contains one sub-directory
    per experiment id. Each sub-directory holds config.json, metrics.jsonl and
    summary.json.
    """

    def __init__(self, storage_dir: str):
        self.storage_dir = Path(storage_dir)

    @staticmethod
    def _last_value(metrics: List[Dict[str, Any]], key: str) -> Any:
        """Return the most recent non-None value of *key* in *metrics*, or None."""
        for record in reversed(metrics):
            value = record.get(key)
            if value is not None:
                return value
        return None

    @staticmethod
    def _fmt(value: Any) -> str:
        """Format a loss with four decimals, or "N/A" when it is missing."""
        return "N/A" if value is None else f"{value:.4f}"

    def _final_val_loss(self, exp_data: Dict[str, Any]) -> Optional[float]:
        """Return the summary's final val_loss, falling back to the metrics log.

        summary.json may hold None here even though validation ran, for example
        when it recorded only the last logged record. The metrics log is
        consulted in that case.
        """
        final = exp_data["summary"].get("final_metrics") or {}
        value = final.get("val_loss")
        if value is None:
            value = self._last_value(exp_data["metrics"], "val_loss")
        return value

    def load_experiment(self, experiment_id: str):
        """Load one experiment.

        Returns:
            dict with "config" (parsed config.json), "metrics" (records of
            metrics.jsonl; [] if that file does not exist) and "summary"
            (parsed summary.json).

        Raises:
            FileNotFoundError: if config.json or summary.json is missing.
        """
        exp_dir = self.storage_dir / experiment_id
        # Load the config
        with open(exp_dir / "config.json", 'r', encoding='utf-8') as f:
            config = json.load(f)
        # Load the metrics. The file exists only after a metric has been logged;
        # blank lines are skipped.
        metrics_path = exp_dir / "metrics.jsonl"
        metrics = []
        if metrics_path.exists():
            with open(metrics_path, 'r', encoding='utf-8') as f:
                metrics = [json.loads(line) for line in f if line.strip()]
        # Load the summary
        with open(exp_dir / "summary.json", 'r', encoding='utf-8') as f:
            summary = json.load(f)
        return {
            "config": config,
            "metrics": metrics,
            "summary": summary
        }

    def compare_experiments(self, experiment_ids: List[str]):
        """Compare experiments and pick the best by final val loss and by perplexity.

        Each entry carries "id", "config", "final_val_loss", "final_perplexity"
        (last non-None perplexity in the metrics log) and "total_steps".
        Experiments lacking a value are ignored when picking that "best_by_*" id.
        """
        comparison = {
            "experiments": [],
            "best_by_val_loss": None,
            "best_by_perplexity": None
        }
        for exp_id in experiment_ids:
            exp_data = self.load_experiment(exp_id)
            comparison["experiments"].append({
                "id": exp_id,
                "config": exp_data["config"]["config"],
                "final_val_loss": self._final_val_loss(exp_data),
                "final_perplexity": self._last_value(exp_data["metrics"], "perplexity"),
                "total_steps": exp_data["summary"]["total_steps"]
            })

        # Find the best experiment for each criterion
        for key, best_key in (("final_val_loss", "best_by_val_loss"),
                              ("final_perplexity", "best_by_perplexity")):
            candidates = [e for e in comparison["experiments"] if e[key] is not None]
            if candidates:
                comparison[best_key] = min(candidates, key=lambda e: e[key])["id"]
        return comparison

    def generate_report(self, experiment_id: str):
        """Return a Markdown report for one experiment; missing losses show "N/A"."""
        exp_data = self.load_experiment(experiment_id)
        cfg = exp_data['config']['config']
        summary = exp_data['summary']
        final_train_loss = (summary.get('final_metrics') or {}).get('train_loss')
        lines = [
            f"# 实验报告: {experiment_id}",
            "## 实验配置",
            f"- 模型: {cfg['model_name']}",
            f"- 数据集: {cfg['dataset']}",
            f"- 批量大小: {cfg['batch_size']}",
            f"- 学习率: {cfg['learning_rate']}",
            f"- 训练轮数: {cfg['num_epochs']}",
            "## 训练结果",
            f"- 总步数: {summary['total_steps']}",
            f"- 最终训练损失: {self._fmt(final_train_loss)}",
            f"- 最终验证损失: {self._fmt(self._final_val_loss(exp_data))}",
            "## 指标变化",
        ]
        report = "\n".join(lines) + "\n"
        # Append the first -> last value of each key metric
        for metric in ['train_loss', 'val_loss']:
            values = [m[metric] for m in exp_data['metrics'] if m.get(metric) is not None]
            if values:
                report += f"- {metric}: {values[0]:.4f} -> {values[-1]:.4f}\n"
        return report
可视化
import matplotlib.pyplot as plt
from typing import List, Dict
class ExperimentVisualizer:
    """Plot metrics stored in ``<storage_dir>/<experiment_id>/metrics.jsonl``."""

    def __init__(self, storage_dir: str):
        self.storage_dir = Path(storage_dir)

    def _load_metrics(self, experiment_id: str) -> List[Dict]:
        """Return the JSON records of an experiment's metrics.jsonl, skipping blank lines.

        Raises:
            FileNotFoundError: if the metrics file does not exist.
        """
        path = self.storage_dir / experiment_id / "metrics.jsonl"
        with open(path, 'r', encoding='utf-8') as f:
            return [json.loads(line) for line in f if line.strip()]

    @staticmethod
    def _legend_if_labeled(ax):
        """Call ``ax.legend()`` only if the axes has labeled artists.

        Calling legend() with no labeled artists makes matplotlib emit a
        "No artists with labels" warning.
        """
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend()

    def plot_training_curves(self, experiment_id: str, save_path: Optional[str] = None):
        """Plot loss curves (top) and the learning-rate schedule (bottom).

        Args:
            experiment_id: experiment sub-directory to read.
            save_path: if given, the figure is also saved there at 150 dpi.

        Returns:
            The matplotlib Figure; the caller owns it.
        """
        records = self._load_metrics(experiment_id)
        steps = [r['step'] for r in records]
        train_losses = [r['train_loss'] for r in records]
        val_losses = [r.get('val_loss') for r in records]
        learning_rates = [r.get('learning_rate') for r in records]

        fig, axes = plt.subplots(2, 1, figsize=(10, 8))

        # Loss curves; validation points only where a val_loss was logged.
        axes[0].plot(steps, train_losses, label='Train Loss', marker='o', markersize=2)
        val_points = [(s, v) for s, v in zip(steps, val_losses) if v is not None]
        if val_points:
            val_steps, val_values = zip(*val_points)
            axes[0].plot(val_steps, val_values, label='Val Loss', marker='s', markersize=2)
        axes[0].set_xlabel('Step')
        axes[0].set_ylabel('Loss')
        axes[0].set_title('Training and Validation Loss')
        self._legend_if_labeled(axes[0])
        axes[0].grid(True)

        # Learning-rate curve
        lr_points = [(s, v) for s, v in zip(steps, learning_rates) if v is not None]
        if lr_points:
            lr_steps, lr_values = zip(*lr_points)
            axes[1].plot(lr_steps, lr_values, label='Learning Rate', color='green')
        axes[1].set_xlabel('Step')
        axes[1].set_ylabel('Learning Rate')
        axes[1].set_title('Learning Rate Schedule')
        self._legend_if_labeled(axes[1])
        axes[1].grid(True)

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=150)
            print(f"Plot saved to {save_path}")
        return fig

    def compare_experiments(self, experiment_ids: List[str], metric: str = 'val_loss',
                            save_path: Optional[str] = None):
        """Overlay *metric* across experiments on one axes.

        Records where *metric* is missing or None are skipped. Experiments with
        no such records draw nothing.

        Args:
            experiment_ids: experiment sub-directories to read.
            metric: metrics.jsonl key to plot.
            save_path: if given, the figure is also saved there at 150 dpi.

        Returns:
            The matplotlib Figure; the caller owns it.
        """
        fig, ax = plt.subplots(figsize=(12, 6))
        for exp_id in experiment_ids:
            points = [(r['step'], r[metric]) for r in self._load_metrics(exp_id)
                      if r.get(metric) is not None]
            if points:
                steps, values = zip(*points)
                ax.plot(steps, values, label=exp_id, marker='o', markersize=3)
        ax.set_xlabel('Step')
        ax.set_ylabel(metric.replace('_', ' ').title())
        ax.set_title(f'Experiment Comparison: {metric}')
        self._legend_if_labeled(ax)
        ax.grid(True)
        if save_path:
            fig.savefig(save_path, dpi=150)
            print(f"Plot saved to {save_path}")
        return fig
有效的实验追踪系统是LLM开发流程的关键组成部分:它帮助团队保持条理、加速迭代,并确保研究的可复现性。