← 返回首页
🧠

实验追踪:LLM训练实验的追踪与管理

📂 llm ⏱ 5 min 961 words

实验追踪:LLM训练实验的追踪与管理

为什么需要实验追踪

大语言模型训练涉及大量超参数配置、数据版本和模型检查点。有效的实验追踪系统可以帮助研究人员记录所有关键信息,比较不同实验结果,并快速回溯到最佳配置。

实验追踪系统架构

核心组件设计

import json
import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

def _current_timestamp() -> str:
    """Return the current local wall-clock time as a naive ISO-8601 string."""
    return datetime.now().isoformat()


@dataclass
class ExperimentMetrics:
    """A single row of training telemetry.

    Only ``step`` and ``train_loss`` are mandatory. Every other measurement
    defaults to ``None`` (or ``0.0`` for ``learning_rate``) when it was not
    collected. ``timestamp`` is filled in automatically at construction time.
    """
    step: int                                   # global step the row refers to
    train_loss: float                           # training loss at that step
    val_loss: Optional[float] = None            # set only on validation rows
    learning_rate: float = 0.0                  # 0.0 when the caller did not supply one
    perplexity: Optional[float] = None
    throughput: Optional[float] = None          # tokens per second
    gpu_memory_used: Optional[float] = None     # units not fixed here — TODO confirm (GB vs MB)
    timestamp: str = field(default_factory=_current_timestamp)

@dataclass
class ExperimentConfig:
    """Hyper-parameters and identifiers describing one training run.

    The first seven fields are required. ``additional_params`` holds any extra
    options; ExperimentTracker.init_experiment merges them into the saved
    ``config`` section. Each instance gets its own fresh dict.
    """
    # --- identity of the run ---
    model_name: str
    dataset: str
    # --- optimisation schedule ---
    batch_size: int
    learning_rate: float
    num_epochs: int
    optimizer: str
    scheduler: str
    # --- optional knobs with defaults ---
    weight_decay: float = 0.01
    max_seq_length: int = 512
    additional_params: Dict[str, Any] = field(default_factory=dict)

class ExperimentTracker:
    """File-based tracker for a single training experiment.

    Each instance creates a fresh directory
    ``<storage_dir>/<project_name>/<experiment_id>/`` and writes into it:

    - ``config.json`` (init_experiment)
    - ``metrics.jsonl`` (log_metrics, one JSON object per line)
    - ``artifacts.json`` (log_artifact)
    - ``tags.txt`` (log_tag)
    - ``notes.txt`` (log_note)
    - ``summary.json`` (finish_experiment)
    """

    def __init__(self, project_name: str, storage_dir: str = "./experiments"):
        self.project_name = project_name
        self.storage_dir = Path(storage_dir) / project_name
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        # An 8-hex-char UUID prefix is readable but can collide; retry rather
        # than silently appending into another run's files.
        while True:
            self.experiment_id = str(uuid.uuid4())[:8]
            self.experiment_dir = self.storage_dir / self.experiment_id
            try:
                self.experiment_dir.mkdir()
            except FileExistsError:
                continue
            break

        self.metrics_history: List[ExperimentMetrics] = []
        self.artifacts: Dict[str, str] = {}
        # ISO timestamp set by init_experiment(); reported as the summary's start time.
        self.created_at: Optional[str] = None

        # Output files inside the experiment directory.
        self.log_file = self.experiment_dir / "metrics.jsonl"
        self.config_file = self.experiment_dir / "config.json"
        self.artifacts_file = self.experiment_dir / "artifacts.json"

    def init_experiment(self, config: ExperimentConfig):
        """Write ``config.json`` for this run and return the experiment id.

        The standard config fields are stored under ``"config"``.
        ``config.additional_params`` is merged into the same mapping, so its
        keys override standard keys of the same name.
        """
        self.created_at = datetime.now().isoformat()
        record = {
            "experiment_id": self.experiment_id,
            "project_name": self.project_name,
            "created_at": self.created_at,
            "config": {
                "model_name": config.model_name,
                "dataset": config.dataset,
                "batch_size": config.batch_size,
                "learning_rate": config.learning_rate,
                "num_epochs": config.num_epochs,
                "optimizer": config.optimizer,
                "scheduler": config.scheduler,
                "weight_decay": config.weight_decay,
                "max_seq_length": config.max_seq_length,
                **config.additional_params
            }
        }
        with open(self.config_file, 'w', encoding='utf-8') as f:
            json.dump(record, f, indent=2)

        print(f"Experiment initialized: {self.experiment_id}")
        return self.experiment_id

    def log_metrics(self, metrics: ExperimentMetrics):
        """Keep ``metrics`` in memory and append it as one JSON line to ``metrics.jsonl``."""
        self.metrics_history.append(metrics)

        row = {
            "step": metrics.step,
            "train_loss": metrics.train_loss,
            "val_loss": metrics.val_loss,
            "learning_rate": metrics.learning_rate,
            "perplexity": metrics.perplexity,
            "throughput": metrics.throughput,
            "gpu_memory_used": metrics.gpu_memory_used,
            "timestamp": metrics.timestamp
        }
        with open(self.log_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(row) + '\n')

    def log_artifact(self, artifact_name: str, artifact_path: str):
        """Record an artifact path (checkpoint, dataset, ...) under ``artifact_name``.

        The whole artifact map is rewritten to ``artifacts.json`` on every call.
        A repeated name overwrites the earlier path.
        """
        self.artifacts[artifact_name] = artifact_path

        with open(self.artifacts_file, 'w', encoding='utf-8') as f:
            json.dump(self.artifacts, f, indent=2)

    def log_tag(self, tag: str):
        """Append ``tag`` as one line of ``tags.txt``."""
        tags_file = self.experiment_dir / "tags.txt"
        with open(tags_file, 'a', encoding='utf-8') as f:
            f.write(f"{tag}\n")

    def log_note(self, note: str):
        """Append a timestamped free-text ``note`` to ``notes.txt``."""
        notes_file = self.experiment_dir / "notes.txt"
        with open(notes_file, 'a', encoding='utf-8') as f:
            f.write(f"[{datetime.now().isoformat()}] {note}\n")

    def finish_experiment(self, status: str = "completed"):
        """Write ``summary.json`` for the run and return the summary dict.

        ``total_steps`` is the highest step that was logged (0 if none).
        Metric rows are sparse (every ``logging_steps``) and validation adds
        extra rows, so counting rows would be wrong. ``final_metrics.val_loss``
        is the most recent *measured* validation loss, because train-only rows
        carry ``val_loss=None``.
        """
        history = self.metrics_history
        last = history[-1] if history else None
        final_val_loss = next(
            (m.val_loss for m in reversed(history) if m.val_loss is not None),
            None,
        )
        summary = {
            "experiment_id": self.experiment_id,
            "status": status,
            # Prefer the init_experiment() time; fall back to the first metric row.
            "started_at": self.created_at or (history[0].timestamp if history else None),
            "finished_at": datetime.now().isoformat(),
            "total_steps": max((m.step for m in history), default=0),
            "final_metrics": {
                "train_loss": last.train_loss if last else None,
                "val_loss": final_val_loss,
            }
        }

        summary_file = self.experiment_dir / "summary.json"
        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2)

        print(f"Experiment finished: {self.experiment_id}")
        return summary

训练集成

PyTorch训练集成

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

class TrackedTrainer:
    """Minimal PyTorch training loop that reports progress to an ExperimentTracker.

    Each batch is passed to the model as ``model(**batch)``. The returned
    object must expose a scalar ``.loss`` tensor; this is assumed to follow the
    Hugging Face model-output convention and is not checked here.

    Apart from ``learning_rate`` and ``weight_decay``, two optional knobs are
    read from ``config``:

    - ``max_grad_norm`` (default 1.0; ``None`` disables clipping)
    - ``logging_steps`` (default 10)

    Each is looked up as a config attribute first, then in
    ``config.additional_params``, because ExperimentConfig does not declare
    them as fields.
    """

    def __init__(self, model, config, tracker: ExperimentTracker):
        self.model = model
        self.config = config
        self.tracker = tracker

        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # Global optimizer-step counter, shared across epochs.
        self.current_step = 0
        self.best_val_loss = float('inf')

    def _config_value(self, name: str, default):
        """Return ``config.<name>``, else ``config.additional_params[name]``, else ``default``."""
        if hasattr(self.config, name):
            return getattr(self.config, name)
        extra = getattr(self.config, "additional_params", None) or {}
        return extra.get(name, default)

    def train_epoch(self, train_loader: DataLoader, epoch: int):
        """Run one pass over ``train_loader`` and return the mean training loss.

        A metrics row is logged every ``logging_steps`` global steps. ``epoch``
        is accepted for API symmetry but not used. Returns ``nan`` when the
        loader yields no batches.
        """
        self.model.train()
        max_grad_norm = self._config_value("max_grad_norm", 1.0)
        # Clamp to >= 1 so a misconfigured 0 cannot trigger modulo-by-zero.
        logging_steps = max(1, int(self._config_value("logging_steps", 10)))

        total_loss = 0.0
        num_batches = 0
        # Tokens and wall time accumulated since the last log, for throughput.
        window_tokens = 0
        window_start = time.perf_counter()

        for batch in train_loader:
            # Forward pass
            outputs = self.model(**batch)
            loss = outputs.loss

            # Backward pass, optional gradient clipping, parameter update
            self.optimizer.zero_grad()
            loss.backward()
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    max_grad_norm
                )
            self.optimizer.step()

            # .item() forces a device sync; do it once per step and reuse it.
            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1
            self.current_step += 1
            window_tokens += self._count_tokens(batch)

            if self.current_step % logging_steps == 0:
                elapsed = time.perf_counter() - window_start
                metrics = ExperimentMetrics(
                    step=self.current_step,
                    train_loss=loss_value,
                    learning_rate=self.optimizer.param_groups[0]['lr'],
                    throughput=self._calculate_throughput(window_tokens, elapsed)
                )
                self.tracker.log_metrics(metrics)
                print(f"Step {self.current_step}, Loss: {loss_value:.4f}")
                window_tokens = 0
                window_start = time.perf_counter()

        # len(train_loader) is undefined for iterable-style datasets and zero for
        # empty ones, so average over the batches actually seen.
        return total_loss / num_batches if num_batches else float("nan")

    def validate(self, val_loader: DataLoader):
        """Compute the mean validation loss, log it, and checkpoint on improvement.

        The validation row reuses the train loss of the tracker's latest row,
        because ExperimentMetrics requires one. It is only logged when such a
        row exists. Returns ``nan`` (and neither logs nor saves) when the loader
        is empty.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in val_loader:
                outputs = self.model(**batch)
                total_loss += outputs.loss.item()
                num_batches += 1

        if num_batches == 0:
            return float("nan")
        avg_val_loss = total_loss / num_batches

        # Metrics live on the tracker, not on the trainer.
        history = self.tracker.metrics_history
        if history:
            metrics = ExperimentMetrics(
                step=self.current_step,
                train_loss=history[-1].train_loss,
                val_loss=avg_val_loss,
                learning_rate=self.optimizer.param_groups[0]['lr']
            )
            self.tracker.log_metrics(metrics)

        # Keep the checkpoint with the lowest validation loss seen so far.
        if avg_val_loss < self.best_val_loss:
            self.best_val_loss = avg_val_loss
            self._save_checkpoint("best_model")

        return avg_val_loss

    @staticmethod
    def _count_tokens(batch) -> int:
        """Return a best-effort token count for ``batch``.

        This is ``batch['input_ids'].numel()`` when present, else 0. Padding
        positions are counted too, so the figure is an upper bound on real
        tokens.
        """
        input_ids = batch.get("input_ids") if hasattr(batch, "get") else None
        return int(input_ids.numel()) if hasattr(input_ids, "numel") else 0

    def _calculate_throughput(self, num_tokens: int = 0, elapsed: float = 0.0):
        """Return tokens per second for a logging window.

        Returns ``None`` when the window has no tokens or no elapsed time, so
        that no placeholder number is written to the metrics log.
        """
        if num_tokens <= 0 or elapsed <= 0:
            return None
        return num_tokens / elapsed

    def _save_checkpoint(self, name: str):
        """Save model and optimizer state to ``<experiment_dir>/checkpoints/<name>.pt``.

        The path is also registered as a tracker artifact.
        """
        checkpoint_path = self.tracker.experiment_dir / "checkpoints" / f"{name}.pt"
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'step': self.current_step,
            'best_val_loss': self.best_val_loss
        }, checkpoint_path)

        self.tracker.log_artifact(name, str(checkpoint_path))

实验比较与分析

class ExperimentAnalyzer:
    """Load, compare and summarize experiments written by ExperimentTracker.

    ``storage_dir`` must be the directory that directly contains the
    experiment-id sub-directories. For ExperimentTracker that is
    ``<tracker storage_dir>/<project_name>``.
    """

    def __init__(self, storage_dir: str):
        self.storage_dir = Path(storage_dir)

    def load_experiment(self, experiment_id: str):
        """Return ``{"config", "metrics", "summary"}`` for one experiment.

        ``config.json`` is required; FileNotFoundError propagates if it is
        missing. ``metrics`` is ``[]`` when ``metrics.jsonl`` does not exist,
        and blank lines in it are skipped. ``summary`` is ``{}`` for runs
        that were never finished (no ``summary.json``).
        """
        exp_dir = self.storage_dir / experiment_id

        with open(exp_dir / "config.json", 'r', encoding='utf-8') as f:
            config = json.load(f)

        metrics = []
        metrics_path = exp_dir / "metrics.jsonl"
        if metrics_path.exists():
            with open(metrics_path, 'r', encoding='utf-8') as f:
                metrics = [json.loads(line) for line in f if line.strip()]

        summary = {}
        summary_path = exp_dir / "summary.json"
        if summary_path.exists():
            with open(summary_path, 'r', encoding='utf-8') as f:
                summary = json.load(f)

        return {
            "config": config,
            "metrics": metrics,
            "summary": summary
        }

    @staticmethod
    def _last_value(metrics, key):
        """Return the most recent non-None ``key`` value in ``metrics`` rows, or None."""
        for row in reversed(metrics):
            value = row.get(key)
            if value is not None:
                return value
        return None

    @classmethod
    def _final_metric(cls, exp_data, key):
        """Return the summary's ``final_metrics[key]``.

        Falls back to the latest logged value when the summary lacks it.
        """
        final = exp_data["summary"].get("final_metrics") or {}
        value = final.get(key)
        if value is not None:
            return value
        return cls._last_value(exp_data["metrics"], key)

    @staticmethod
    def _best_id(experiments, key):
        """Return the id of the experiment with the smallest non-None ``key``, or None."""
        candidates = [e for e in experiments if e.get(key) is not None]
        if not candidates:
            return None
        return min(candidates, key=lambda e: e[key])["id"]

    @staticmethod
    def _format_loss(value):
        """Format a number with 4 decimals, or ``"N/A"`` when it is missing."""
        return "N/A" if value is None else f"{value:.4f}"

    def compare_experiments(self, experiment_ids: List[str]):
        """Summarize several experiments and pick the best by val loss and perplexity.

        Lower is better for both. Experiments missing a metric are ignored for
        that ranking, and the corresponding ``best_by_*`` entry stays None when
        no experiment has the metric.
        """
        comparison = {
            "experiments": [],
            "best_by_val_loss": None,
            "best_by_perplexity": None
        }

        for exp_id in experiment_ids:
            exp_data = self.load_experiment(exp_id)
            comparison["experiments"].append({
                "id": exp_id,
                "config": exp_data["config"]["config"],
                "final_val_loss": self._final_metric(exp_data, "val_loss"),
                "final_perplexity": self._last_value(exp_data["metrics"], "perplexity"),
                "total_steps": exp_data["summary"].get("total_steps")
            })

        comparison["best_by_val_loss"] = self._best_id(
            comparison["experiments"], "final_val_loss"
        )
        comparison["best_by_perplexity"] = self._best_id(
            comparison["experiments"], "final_perplexity"
        )

        return comparison

    def generate_report(self, experiment_id: str):
        """Render a Markdown report for one experiment.

        Missing final losses or step counts are shown as ``N/A`` instead of
        raising.
        """
        exp_data = self.load_experiment(experiment_id)
        cfg = exp_data['config']['config']
        summary = exp_data['summary']
        train_loss = self._final_metric(exp_data, 'train_loss')
        val_loss = self._final_metric(exp_data, 'val_loss')

        report = f"""# 实验报告: {experiment_id}

## 实验配置
- 模型: {cfg['model_name']}
- 数据集: {cfg['dataset']}
- 批量大小: {cfg['batch_size']}
- 学习率: {cfg['learning_rate']}
- 训练轮数: {cfg['num_epochs']}

## 训练结果
- 总步数: {summary.get('total_steps', 'N/A')}
- 最终训练损失: {self._format_loss(train_loss)}
- 最终验证损失: {self._format_loss(val_loss)}

## 指标变化
"""
        # First -> last logged value of each key metric
        for metric in ['train_loss', 'val_loss']:
            values = [m[metric] for m in exp_data['metrics'] if m.get(metric) is not None]
            if values:
                report += f"- {metric}: {values[0]:.4f} -> {values[-1]:.4f}\n"

        return report

可视化

import matplotlib.pyplot as plt
from typing import List, Dict

class ExperimentVisualizer:
    """Matplotlib plots built from an experiment's ``metrics.jsonl`` log.

    ``storage_dir`` is the directory holding one sub-directory per
    experiment id.
    """

    def __init__(self, storage_dir: str):
        self.storage_dir = Path(storage_dir)

    def _read_metrics(self, experiment_id: str):
        """Parse ``<storage_dir>/<experiment_id>/metrics.jsonl`` into one dict per line."""
        path = self.storage_dir / experiment_id / "metrics.jsonl"
        with open(path, 'r') as handle:
            return [json.loads(row) for row in handle]

    def plot_training_curves(self, experiment_id: str, save_path: Optional[str] = None):
        """Draw loss curves (top) and the learning-rate curve (bottom) for one run.

        Train loss is plotted at every logged row. Val loss and learning rate
        are plotted only at rows where they are not None; the bottom panel
        stays empty if no learning rate was logged. The figure is saved at
        150 dpi when ``save_path`` is given, and it is always returned.
        """
        records = self._read_metrics(experiment_id)
        steps = [rec['step'] for rec in records]
        train_series = [rec['train_loss'] for rec in records]
        val_points = [(s, rec.get('val_loss')) for s, rec in zip(steps, records)
                      if rec.get('val_loss') is not None]
        lr_points = [(s, rec.get('learning_rate')) for s, rec in zip(steps, records)
                     if rec.get('learning_rate') is not None]

        fig, axes = plt.subplots(2, 1, figsize=(10, 8))
        loss_ax, lr_ax = axes

        # Top panel: losses
        loss_ax.plot(steps, train_series, label='Train Loss', marker='o', markersize=2)
        if val_points:
            loss_ax.plot([s for s, _ in val_points], [v for _, v in val_points],
                         label='Val Loss', marker='s', markersize=2)
        loss_ax.set_xlabel('Step')
        loss_ax.set_ylabel('Loss')
        loss_ax.set_title('Training and Validation Loss')
        loss_ax.legend()
        loss_ax.grid(True)

        # Bottom panel: learning-rate schedule, only when something was logged
        if lr_points:
            lr_ax.plot([s for s, _ in lr_points], [v for _, v in lr_points],
                       label='Learning Rate', color='green')
            lr_ax.set_xlabel('Step')
            lr_ax.set_ylabel('Learning Rate')
            lr_ax.set_title('Learning Rate Schedule')
            lr_ax.legend()
            lr_ax.grid(True)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=150)
            print(f"Plot saved to {save_path}")

        return fig

    def compare_experiments(self, experiment_ids: List[str], metric: str = 'val_loss'):
        """Overlay ``metric`` over steps for several runs on one axis.

        Runs that never logged ``metric`` are skipped. Returns the Figure.
        """
        fig, ax = plt.subplots(figsize=(12, 6))

        for exp_id in experiment_ids:
            points = [(rec['step'], rec[metric]) for rec in self._read_metrics(exp_id)
                      if rec.get(metric) is not None]
            if not points:
                continue
            ax.plot([s for s, _ in points], [v for _, v in points],
                    label=exp_id, marker='o', markersize=3)

        ax.set_xlabel('Step')
        ax.set_ylabel(metric.replace('_', ' ').title())
        ax.set_title(f'Experiment Comparison: {metric}')
        ax.legend()
        ax.grid(True)

        return fig

有效的实验追踪系统是LLM开发流程的关键组成部分,它帮助团队保持组织性,加速迭代,并确保研究的可复现性。