Weights & Biases:使用W&B追踪LLM实验
W&B简介
Weights & Biases (W&B) 是一个强大的机器学习实验追踪平台,提供实时可视化、超参数搜索、模型版本管理等功能。对于大语言模型开发,W&B可以帮助团队高效地追踪训练过程、比较实验结果。
基础配置
安装与登录
# Install the W&B client library
pip install wandb
# Log in (only needed the first time)
wandb login
# Alternatively, supply the API key through an environment variable
export WANDB_API_KEY=your_api_key_here
初始化项目
import wandb
# Hyperparameters for this experiment, collected up front so they are easy
# to read, tweak and reuse before being handed to W&B.
TRAIN_CONFIG = {
    "model_name": "gpt2-medium",
    "dataset": "custom-chinese-corpus",
    "batch_size": 8,
    "learning_rate": 5e-5,
    "num_epochs": 3,
    "warmup_steps": 100,
    "max_seq_length": 512,
}

# Start a W&B run in the "llm-training" project and attach the config to it.
wandb.init(project="llm-training", name="experiment-001", config=TRAIN_CONFIG)

# Read the tracked hyperparameters back through W&B's run config.
config = wandb.config
训练过程追踪
完整的训练追踪示例
import torch
import wandb
from torch.utils.data import DataLoader
class WandbTrainer:
    """PyTorch trainer that tracks training and validation with W&B.

    ``config`` must be an object usable with ``vars()`` (its attributes become
    the W&B run config) and must provide ``project_name`` and
    ``experiment_name``. Optional attributes, read with defaults:
    ``max_grad_norm`` (1.0), ``log_interval`` (10) and ``learning_rate``
    (5e-5, only used to build the default optimizer).

    Every ``wandb.log`` call uses ``self.global_step`` (the number of optimizer
    steps taken so far). W&B ignores logs whose step is lower than one it has
    already seen, so all metrics share this single monotonic counter.
    """

    def __init__(self, model, config, optimizer=None, scheduler=None):
        """Start a W&B run and set up optimization.

        Args:
            model: ``torch.nn.Module`` called as ``model(**batch)``; its output
                must expose ``.loss`` and may expose ``.logits``.
            config: experiment configuration (see class docstring).
            optimizer: optimizer to use; defaults to ``torch.optim.AdamW``
                over ``model.parameters()``.
            scheduler: optional LR scheduler, stepped once per optimizer step.
        """
        self.model = model
        self.config = config
        if optimizer is None:
            optimizer = torch.optim.AdamW(
                model.parameters(), lr=getattr(config, "learning_rate", 5e-5)
            )
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.global_step = 0

        wandb.init(
            project=config.project_name,
            name=config.experiment_name,
            config=vars(config),
        )
        # Chart epoch-level metrics against the epoch number instead of the
        # global step; the step metric is logged alongside them below.
        wandb.define_metric("train/epoch_loss", step_metric="train/epoch")
        wandb.define_metric("val/*", step_metric="val/epoch")
        # Let W&B hook the model (log_freq=100 batches).
        wandb.watch(model, log_freq=100)

    def train_epoch(self, train_loader: DataLoader, epoch: int) -> float:
        """Train for one epoch and return the mean per-batch loss.

        Logs the batch loss every ``log_interval`` batches and the epoch mean
        at the end. Returns ``nan`` when ``train_loader`` yields no batches.
        """
        self.model.train()
        max_grad_norm = getattr(self.config, "max_grad_norm", 1.0)
        # Guard against 0, which would make the modulo below raise.
        log_interval = max(1, getattr(self.config, "log_interval", 10))
        total_loss = 0.0
        num_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            # Clear gradients left over from the previous step; without this
            # they accumulate across batches.
            self.optimizer.zero_grad()
            outputs = self.model(**batch)
            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
            # Learning rate actually used for this update (before scheduling).
            current_lr = self.optimizer.param_groups[0]["lr"]
            self.optimizer.step()
            if self.scheduler is not None:
                self.scheduler.step()
            self.global_step += 1

            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1

            if batch_idx % log_interval == 0:
                wandb.log(
                    {
                        "train/loss": loss_value,
                        "train/epoch": epoch,
                        "train/step": self.global_step,
                        "train/learning_rate": current_lr,
                    },
                    step=self.global_step,
                )

        avg_loss = total_loss / num_batches if num_batches else float("nan")
        wandb.log(
            {"train/epoch_loss": avg_loss, "train/epoch": epoch},
            step=self.global_step,
        )
        return avg_loss

    def validate(self, val_loader: DataLoader, epoch: int):
        """Evaluate the model and return ``(avg_loss, accuracy)``.

        Accuracy compares ``argmax(logits, -1)`` with ``batch['labels']`` and
        skips positions labelled -100 (the default ``ignore_index`` of
        ``torch.nn.CrossEntropyLoss``, commonly used for padding).
        NOTE(review): for causal LMs the logits at position t predict the
        label at t+1; they are compared unshifted here, so shift them before
        trusting this number for such models.

        Returns ``nan`` loss for an empty loader and 0 accuracy when no
        labelled positions were seen.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        total_correct = 0
        total_samples = 0

        with torch.no_grad():
            for batch in val_loader:
                outputs = self.model(**batch)
                total_loss += outputs.loss.item()
                num_batches += 1

                if hasattr(outputs, "logits"):
                    labels = batch["labels"]
                    predictions = torch.argmax(outputs.logits, dim=-1)
                    valid = labels != -100
                    total_correct += ((predictions == labels) & valid).sum().item()
                    total_samples += valid.sum().item()

        avg_loss = total_loss / num_batches if num_batches else float("nan")
        accuracy = total_correct / total_samples if total_samples > 0 else 0

        # Logged at the current global step so W&B does not drop it; the
        # "val/epoch" key provides the epoch axis (see define_metric above).
        wandb.log(
            {
                "val/loss": avg_loss,
                "val/accuracy": accuracy,
                "val/epoch": epoch,
            },
            step=self.global_step,
        )
        return avg_loss, accuracy

    def save_model_checkpoint(self, path: str, epoch: int):
        """Save a checkpoint to ``path`` and upload it as a W&B model artifact."""
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                # Optimizer state and step are needed to resume training exactly.
                "optimizer_state_dict": self.optimizer.state_dict(),
                "global_step": self.global_step,
                "epoch": epoch,
            },
            path,
        )

        artifact = wandb.Artifact(
            name=f"model-epoch-{epoch}",
            type="model",
            description=f"Model checkpoint at epoch {epoch}",
        )
        artifact.add_file(path)
        wandb.log_artifact(artifact)

    def finish(self):
        """Finish the W&B run started in ``__init__``."""
        wandb.finish()
超参数搜索
使用W&B Sweep
import wandb
def train_with_sweep(project="llm-hyperparameter-sweep", count=50):
    """Run a Bayesian hyperparameter sweep with W&B and return its sweep id.

    ``build_model``, ``train_epoch`` and ``validate`` are placeholders that
    must be defined by the caller's code. ``validate`` is expected to return
    ``(val_loss, val_accuracy)``.

    Args:
        project: W&B project that owns the sweep.
        count: maximum number of runs the local agent executes.
    """

    def train():
        # The agent injects the sampled hyperparameters into wandb.config.
        # Using the run as a context manager guarantees it is finished.
        with wandb.init():
            config = wandb.config
            model = build_model(config)

            val_acc = None  # stays None if num_epochs == 0
            for epoch in range(config.num_epochs):
                train_loss = train_epoch(model, config)
                val_loss, val_acc = validate(model, config)
                wandb.log({
                    "epoch": epoch,
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "val_accuracy": val_acc,
                })

            if val_acc is not None:
                wandb.log({"final_val_accuracy": val_acc})

    # Search space. The Bayesian search optimizes "val_accuracy".
    sweep_config = {
        "method": "bayes",
        "metric": {
            "name": "val_accuracy",
            "goal": "maximize",
        },
        "parameters": {
            # Fixed value: train() reads config.num_epochs, so it must be
            # part of the sweep config or the attribute lookup fails.
            "num_epochs": {"value": 3},
            "learning_rate": {
                "distribution": "log_uniform_values",
                "min": 1e-6,
                "max": 1e-3,
            },
            "batch_size": {"values": [8, 16, 32, 64]},
            "num_layers": {"values": [6, 8, 10, 12]},
            "hidden_size": {"values": [256, 512, 768, 1024]},
            "warmup_steps": {
                "distribution": "int_uniform",
                "min": 0,
                "max": 500,
            },
            "weight_decay": {"values": [0.0, 0.01, 0.1]},
        },
    }

    sweep_id = wandb.sweep(sweep_config, project=project)
    wandb.agent(sweep_id, function=train, count=count)
    return sweep_id
模型版本管理
使用W&B Artifacts
class ModelVersionManager:
    """Register, download and compare model versions stored as W&B Artifacts.

    Versions live under ``{entity}/{project}/{artifact_name}:{version}``.
    The defaults reproduce the original hard-coded path and should be
    replaced with your own entity/project.
    """

    def __init__(self, entity="your-entity", project="llm-training",
                 artifact_name="llm-model"):
        self.entity = entity
        self.project = project
        self.artifact_name = artifact_name

    def _artifact_path(self, version):
        """Return the fully qualified artifact path for ``version``."""
        return f"{self.entity}/{self.project}/{self.artifact_name}:{version}"

    def register_model(self, model_path: str, metadata: dict):
        """Log ``model_path`` (a file or a directory) as a new artifact version.

        Must be called inside an active run (after ``wandb.init``).
        Returns the ``wandb.Artifact`` that was logged.
        """
        import os

        model_artifact = wandb.Artifact(
            name=self.artifact_name,
            type="model",
            metadata=metadata,
        )
        # add_dir only accepts directories; single-file checkpoints (e.g.
        # from torch.save) need add_file.
        if os.path.isdir(model_path):
            model_artifact.add_dir(model_path)
        else:
            model_artifact.add_file(model_path)

        wandb.log_artifact(model_artifact)
        print(f"Model registered: {model_artifact.name}")
        return model_artifact

    def load_model(self, version: str = "latest"):
        """Download model ``version`` and return the local directory path."""
        artifact = wandb.Api().artifact(self._artifact_path(version))
        return artifact.download()

    def compare_versions(self, version_a: str, version_b: str):
        """Return name, creation time and metadata for two model versions."""
        api = wandb.Api()

        def describe(version):
            artifact = api.artifact(self._artifact_path(version))
            return {
                "name": artifact.name,
                "created_at": artifact.created_at,
                "metadata": artifact.metadata,
            }

        return {
            "version_a": describe(version_a),
            "version_b": describe(version_b),
        }
报告生成
def create_wandb_report(experiment_ids: list, project="llm-training", entity=None):
    """Create and save a W&B report with loss curves and configs per run.

    Args:
        experiment_ids: W&B run ids inside ``entity/project``.
        project: project that contains the runs.
        entity: W&B entity; defaults to the API's default entity.

    Returns:
        The saved report object.
    """
    # The programmatic Reports API lives in wandb.apis.reports; there is no
    # wandb.Report / wandb.visualizations. Recent wandb releases may also
    # require the separate ``wandb-workspaces`` package for this module.
    import wandb.apis.reports as wr

    api = wandb.Api()
    entity = entity or api.default_entity

    blocks = []
    for exp_id in experiment_ids:
        # Read-only lookup. Unlike wandb.init(id=..., resume="allow"), this
        # does not reopen the (possibly finished) run and modify it.
        run = api.run(f"{entity}/{project}/{exp_id}")
        blocks.append(
            wr.PanelGrid(
                runsets=[
                    # NOTE(review): ``query`` searches run display names, so
                    # other runs with the same name will also be included.
                    wr.Runset(entity=entity, project=project,
                              name=exp_id, query=run.name),
                ],
                panels=[
                    # x-axis left at the default step so metrics logged
                    # without a "train/step" key (e.g. validation) still plot.
                    wr.LinePlot(
                        title=f"Loss Curves - {exp_id}",
                        y=["train/loss", "val/loss"],
                    ),
                    # Shows the run's config (hyperparameters) in the report.
                    wr.RunComparer(),
                ],
            )
        )

    report = wr.Report(
        project=project,
        entity=entity,
        title="LLM Training Experiments",
        description="Summary of all training experiments",
        blocks=blocks,
    )
    report.save()
    return report
最佳实践(以下为示意代码片段,其中 model_name、step、checkpoint_path、samples 等变量需替换为实际值)
# 1. Use meaningful run names and tags so runs are easy to find and filter.
#    Without an explicit project, W&B files the run under a default project.
wandb.init(
    project="llm-training",
    name=f"{model_name}_{dataset}_{timestamp}",
    tags=["baseline", "gpt2", "chinese"],
)

# 2. Record every important hyperparameter, grouped by concern.
#    allow_val_change=True permits overwriting keys already in the config;
#    drop it if you want W&B to raise on accidental overwrites.
wandb.config.update({
    "model": model_config,
    "training": training_config,
    "data": data_config,
}, allow_val_change=True)

# 3. Save checkpoints periodically. Reuse one artifact name and mark each
#    version with an alias. Putting the step in the *name* would create a
#    separate artifact per checkpoint instead of versions of one artifact.
if step % save_every == 0:
    wandb.log_artifact(
        checkpoint_path,
        name="checkpoint",
        type="model",
        aliases=[f"step-{step}"],
    )

# 4. Monitor hardware. W&B already records system metrics (GPU utilization,
#    memory, ...) automatically; log extra values explicitly as needed.
#    memory_allocated() returns bytes, converted here to GiB for readability.
wandb.log({
    "system/gpu_memory_gb": torch.cuda.memory_allocated() / 1024**3,
    "system/gpu_utilization": get_gpu_utilization(),
})

# 5. Log sample predictions as a table for qualitative inspection;
#    each item in samples is an (input, prediction, target) triple.
wandb.log({
    "predictions": wandb.Table(
        columns=["input", "prediction", "target"],
        data=[[input_text, pred, target] for input_text, pred, target in samples],
    )
})
W&B是LLM开发中不可或缺的工具,它提供了从实验追踪到团队协作的完整解决方案,帮助研究人员和工程师高效地管理和优化模型训练过程。