← 返回首页
🧠

Weights & Biases:使用W&B追踪LLM实验

📂 llm ⏱ 3 min 600 words

Weights & Biases:使用W&B追踪LLM实验

W&B简介

Weights & Biases (W&B) 是一个强大的机器学习实验追踪平台,提供实时可视化、超参数搜索、模型版本管理等功能。对于大语言模型开发,W&B可以帮助团队高效地追踪训练过程、比较实验结果。

基础配置

安装与登录

# 安装W&B
pip install wandb

# 登录(首次使用)
wandb login

# 或使用环境变量
export WANDB_API_KEY=your_api_key_here

初始化项目

import wandb

# 初始化项目
wandb.init(
    project="llm-training",
    name="experiment-001",
    config={
        "model_name": "gpt2-medium",
        "dataset": "custom-chinese-corpus",
        "batch_size": 8,
        "learning_rate": 5e-5,
        "num_epochs": 3,
        "warmup_steps": 100,
        "max_seq_length": 512,
    }
)

# 获取配置
config = wandb.config

训练过程追踪

完整的训练追踪示例

import torch
import wandb
from torch.utils.data import DataLoader

class WandbTrainer:
    """使用W&B追踪的训练器"""
    
    def __init__(self, model, config):
        self.model = model
        self.config = config
        
        # 初始化W&B
        wandb.init(
            project=config.project_name,
            name=config.experiment_name,
            config=vars(config)
        )
        
        # 监控模型参数
        wandb.watch(model, log_freq=100)
    
    def train_epoch(self, train_loader: DataLoader, epoch: int):
        """训练一个epoch"""
        self.model.train()
        total_loss = 0
        num_batches = 0
        
        for batch_idx, batch in enumerate(train_loader):
            # 前向传播
            outputs = self.model(**batch)
            loss = outputs.loss
            
            # 反向传播
            loss.backward()
            
            # 梯度裁剪
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.max_grad_norm
            )
            
            # 记录损失
            total_loss += loss.item()
            num_batches += 1
            
            # 每N步记录到W&B
            if batch_idx % self.config.log_interval == 0:
                global_step = epoch * len(train_loader) + batch_idx
                
                wandb.log({
                    "train/loss": loss.item(),
                    "train/epoch": epoch,
                    "train/step": global_step,
                    "train/learning_rate": self.optimizer.param_groups[0]['lr'],
                }, step=global_step)
        
        avg_loss = total_loss / num_batches
        wandb.log({"train/epoch_loss": avg_loss}, step=epoch)
        
        return avg_loss
    
    def validate(self, val_loader: DataLoader, epoch: int):
        """验证"""
        self.model.eval()
        total_loss = 0
        total_correct = 0
        total_samples = 0
        
        with torch.no_grad():
            for batch in val_loader:
                outputs = self.model(**batch)
                loss = outputs.loss
                
                total_loss += loss.item()
                
                # 计算准确率(如果适用)
                if hasattr(outputs, 'logits'):
                    predictions = torch.argmax(outputs.logits, dim=-1)
                    correct = (predictions == batch['labels']).sum().item()
                    total_correct += correct
                    total_samples += batch['labels'].numel()
        
        avg_loss = total_loss / len(val_loader)
        accuracy = total_correct / total_samples if total_samples > 0 else 0
        
        # 记录验证指标
        wandb.log({
            "val/loss": avg_loss,
            "val/accuracy": accuracy,
            "val/epoch": epoch,
        }, step=epoch)
        
        return avg_loss, accuracy
    
    def save_model_checkpoint(self, path: str, epoch: int):
        """保存模型检查点"""
        # 保存本地
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'epoch': epoch,
        }, path)
        
        # 上传到W&B
        artifact = wandb.Artifact(
            name=f"model-epoch-{epoch}",
            type="model",
            description=f"Model checkpoint at epoch {epoch}"
        )
        artifact.add_file(path)
        wandb.log_artifact(artifact)
    
    def finish(self):
        """结束W&B运行"""
        wandb.finish()

超参数搜索

使用W&B Sweep

import wandb

def train_with_sweep():
    """使用Sweep进行超参数搜索"""
    
    def train():
        # 初始化运行
        run = wandb.init()
        config = wandb.config
        
        # 构建模型
        model = build_model(config)
        
        # 训练循环
        for epoch in range(config.num_epochs):
            train_loss = train_epoch(model, config)
            val_loss, val_acc = validate(model, config)
            
            # 记录指标
            wandb.log({
                "epoch": epoch,
                "train_loss": train_loss,
                "val_loss": val_loss,
                "val_accuracy": val_acc,
            })
        
        # 记录最终指标
        wandb.log({"final_val_accuracy": val_acc})
    
    # 定义搜索空间
    sweep_config = {
        "method": "bayes",
        "metric": {
            "name": "val_accuracy",
            "goal": "maximize"
        },
        "parameters": {
            "learning_rate": {
                "distribution": "log_uniform_values",
                "min": 1e-6,
                "max": 1e-3
            },
            "batch_size": {
                "values": [8, 16, 32, 64]
            },
            "num_layers": {
                "values": [6, 8, 10, 12]
            },
            "hidden_size": {
                "values": [256, 512, 768, 1024]
            },
            "warmup_steps": {
                "distribution": "int_uniform",
                "min": 0,
                "max": 500
            },
            "weight_decay": {
                "values": [0.0, 0.01, 0.1]
            }
        }
    }
    
    # 创建并启动Sweep
    sweep_id = wandb.sweep(sweep_config, project="llm-hyperparameter-sweep")
    wandb.agent(sweep_id, function=train, count=50)

模型版本管理

使用W&B Artifacts

class ModelVersionManager:
    """使用W&B Artifacts管理模型版本"""
    
    def __init__(self):
        pass
    
    def register_model(self, model_path: str, metadata: dict):
        """注册新模型版本"""
        # 创建模型Artifact
        model_artifact = wandb.Artifact(
            name="llm-model",
            type="model",
            metadata=metadata
        )
        
        # 添加模型文件
        model_artifact.add_dir(model_path)
        
        # 注册到W&B
        wandb.log_artifact(model_artifact)
        
        print(f"Model registered: {model_artifact.name}")
        return model_artifact
    
    def load_model(self, version: str = "latest"):
        """加载模型版本"""
        api = wandb.Api()
        
        # 获取最新的模型Artifact
        artifact = api.artifact(
            f"your-entity/llm-training/llm-model:{version}"
        )
        
        # 下载模型
        model_dir = artifact.download()
        
        return model_dir
    
    def compare_versions(self, version_a: str, version_b: str):
        """比较两个模型版本"""
        api = wandb.Api()
        
        artifact_a = api.artifact(
            f"your-entity/llm-training/llm-model:{version_a}"
        )
        artifact_b = api.artifact(
            f"your-entity/llm-training/llm-model:{version_b}"
        )
        
        comparison = {
            "version_a": {
                "name": artifact_a.name,
                "created_at": artifact_a.created_at,
                "metadata": artifact_a.metadata
            },
            "version_b": {
                "name": artifact_b.name,
                "created_at": artifact_b.created_at,
                "metadata": artifact_b.metadata
            }
        }
        
        return comparison

报告生成

def create_wandb_report(experiment_ids: list):
    """创建W&B报告"""
    
    # 创建报告
    report = wandb.Report(
        title="LLM Training Experiments",
        description="Summary of all training experiments"
    )
    
    # 添加指标面板
    for exp_id in experiment_ids:
        run = wandb.init(project="llm-training", id=exp_id, resume="allow")
        
        # 添加损失曲线
        report.add_panel(
            wandb.visualizations.line_series(
                x_key="train/step",
                y_keys=["train/loss", "val/loss"],
                title=f"Loss Curves - {exp_id}"
            )
        )
        
        # 添加超参数表格
        report.add_panel(
            wandb.visualizations.table(
                data=[[k, v] for k, v in wandb.config.items()],
                columns=["Parameter", "Value"],
                title=f"Config - {exp_id}"
            )
        )
        
        wandb.finish()
    
    # 保存报告
    report.save()
    
    return report

最佳实践

# 1. 使用有意义的运行名称
wandb.init(
    name=f"{model_name}_{dataset}_{timestamp}",
    tags=["baseline", "gpt2", "chinese"]
)

# 2. 记录所有重要的超参数
wandb.config.update({
    "model": model_config,
    "training": training_config,
    "data": data_config,
}, allow_val_change=True)

# 3. 定期保存检查点
if step % save_every == 0:
    wandb.log_artifact(checkpoint_path, name=f"checkpoint-step-{step}")

# 4. 使用系统指标监控硬件
wandb.log({
    "system/gpu_memory": torch.cuda.memory_allocated(),
    "system/gpu_utilization": get_gpu_utilization(),
})

# 5. 记录样本预测
wandb.log({
    "predictions": wandb.Table(
        columns=["input", "prediction", "target"],
        data=[[input_text, pred, target] for input_text, pred, target in samples]
    )
})

W&B是LLM开发中不可或缺的工具,它提供了从实验追踪到团队协作的完整解决方案,帮助研究人员和工程师高效地管理和优化模型训练过程。