← 返回首页
🧠

TensorBoard:使用TensorBoard可视化LLM训练

📂 llm ⏱ 5 min 905 words

TensorBoard:使用TensorBoard可视化LLM训练

TensorBoard简介

TensorBoard是TensorFlow和PyTorch都支持的可视化工具(PyTorch通过`torch.utils.tensorboard.SummaryWriter`写入日志),提供训练过程的实时监控、模型图可视化、权重分布分析、嵌入空间投影等功能。对于大语言模型开发,TensorBoard可以帮助研究人员深入理解模型行为,例如观察损失与学习率曲线、检查梯度是否爆炸或消失、分析注意力模式。

基础配置

安装与启动

# Install the TensorBoard package
pip install tensorboard

# Serve every event file found under ./logs on http://localhost:6006
tensorboard --logdir ./logs --port 6006

# Alternatively, load several runs under explicit display names (name:path pairs)
tensorboard --logdir_spec training:./logs/train,validation:./logs/val

PyTorch集成

from torch.utils.tensorboard import SummaryWriter
import torch

# Create a SummaryWriter. Used as a context manager, it is closed (and its
# pending events flushed) even if one of the logging calls raises.
# NOTE: loss_value, acc_value and global_step are placeholders for values
# produced by your own training loop.
with SummaryWriter(log_dir='./logs/experiment_001') as writer:
    # Log scalar metrics under hierarchical tags
    writer.add_scalar('training/loss', loss_value, global_step)
    writer.add_scalar('validation/accuracy', acc_value, global_step)

完整训练追踪

多维度追踪器

import torch
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
import json
from datetime import datetime

class TensorBoardLLMTracker:
    """Thin wrapper around ``SummaryWriter`` for tracking LLM training runs.

    Each tracker writes to ``<log_dir>/<experiment_name>`` and keeps its own
    ``global_step`` counter, which every ``log_*`` method falls back to when it
    is called without an explicit ``step``.
    """

    def __init__(self, log_dir: str, experiment_name: str = None):
        """Create the run directory and its ``SummaryWriter``.

        Args:
            log_dir: Root directory holding all experiment runs.
            experiment_name: Sub-directory for this run; defaults to the current
                local time formatted as ``YYYYmmdd_HHMMSS``.
        """
        if experiment_name is None:
            experiment_name = datetime.now().strftime("%Y%m%d_%H%M%S")

        self.log_dir = Path(log_dir) / experiment_name
        self.log_dir.mkdir(parents=True, exist_ok=True)

        self.writer = SummaryWriter(log_dir=str(self.log_dir))
        self.global_step = 0

        # Destination of the JSON snapshot written by log_config().
        self.config_file = self.log_dir / "config.json"

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _resolve_step(self, step: int = None) -> int:
        """Return *step*, or the tracker's ``global_step`` when it is None."""
        return self.global_step if step is None else step

    @staticmethod
    def _config_to_json(config: dict) -> str:
        """Serialize *config* as indented, UTF-8-friendly JSON text.

        ``default=str`` is deliberate: configs built with ``vars(...)`` often
        hold values JSON cannot encode (e.g. ``Path`` objects), and for a
        human-readable log a string form is preferable to a crash.
        """
        return json.dumps(config, indent=2, ensure_ascii=False, default=str)

    @staticmethod
    def _histogram_values(values):
        """Prepare *values* for ``add_histogram``.

        Tensors are detached, moved to CPU, and low-precision float dtypes
        (e.g. bfloat16/float16) are upcast to float32. numpy has no bfloat16
        dtype, so converting explicitly avoids depending on how the installed
        torch version handles it (and avoids overflow from a float16 downcast).
        Non-tensor inputs are returned unchanged.
        """
        if not isinstance(values, torch.Tensor):
            return values
        values = values.detach()
        if values.is_floating_point() and values.dtype not in (torch.float32, torch.float64):
            values = values.float()
        return values.cpu()

    @staticmethod
    def _samples_to_html(samples: list) -> str:
        """Render ``input``/``output``/``target`` samples as an HTML table.

        Cell contents are HTML-escaped so that model text containing markup
        characters (e.g. ``<eos>`` or ``&``) is displayed verbatim instead of
        being interpreted as tags.
        """
        import html

        rows = "".join(
            "<tr>"
            + "".join(
                f"<td>{html.escape(str(sample[key]))}</td>"
                for key in ("input", "output", "target")
            )
            + "</tr>"
            for sample in samples
        )
        return "<table><tr><th>Input</th><th>Output</th><th>Target</th></tr>" + rows + "</table>"

    # ------------------------------------------------------------------
    # Public logging API
    # ------------------------------------------------------------------

    def log_config(self, config: dict):
        """Save *config* to ``config.json`` and show it as a text summary.

        The file and the TensorBoard text use the same JSON rendering.
        """
        config_text = self._config_to_json(config)
        self.config_file.write_text(config_text, encoding="utf-8")

        # Wrap in a fenced code block so TensorBoard's markdown keeps the layout.
        self.writer.add_text('config', f"```\n{config_text}\n```")

    def log_scalar(self, tag: str, value: float, step: int = None):
        """Log a single scalar *value* under *tag*."""
        self.writer.add_scalar(tag, value, self._resolve_step(step))

    def log_scalars(self, main_tag: str, tag_scalar_dict: dict, step: int = None):
        """Log several related scalars grouped under *main_tag*."""
        self.writer.add_scalars(main_tag, tag_scalar_dict, self._resolve_step(step))

    def log_histogram(self, tag: str, values: torch.Tensor, step: int = None):
        """Log a histogram of *values* (e.g. weight or gradient distributions)."""
        self.writer.add_histogram(tag, self._histogram_values(values), self._resolve_step(step))

    def log_image(self, tag: str, img_tensor: torch.Tensor, step: int = None):
        """Log an image tensor (passed straight to ``add_image``)."""
        self.writer.add_image(tag, img_tensor, self._resolve_step(step))

    def log_text(self, tag: str, text: str, step: int = None):
        """Log a markdown/text snippet."""
        self.writer.add_text(tag, text, self._resolve_step(step))

    def log_model_graph(self, model: torch.nn.Module, input_size: tuple = None,
                        dummy_input: torch.Tensor = None):
        """Trace *model* and log its graph.

        Args:
            model: Module to trace.
            input_size: Shape of a random float (``torch.randn``) input created
                on the device of the model's first parameter.
            dummy_input: Explicit example input; takes precedence over
                *input_size*. Needed for models whose forward expects integer
                token ids rather than floats.

        Raises:
            ValueError: if neither *input_size* nor *dummy_input* is given.
        """
        if dummy_input is None:
            if input_size is None:
                raise ValueError("either input_size or dummy_input must be given")
            first_param = next(model.parameters(), None)
            # Match the model's device so tracing a GPU model does not fail.
            device = first_param.device if first_param is not None else None
            dummy_input = torch.randn(*input_size, device=device)
        self.writer.add_graph(model, dummy_input)

    def log_embedding(self, tensor: torch.Tensor, metadata: list = None,
                      tag: str = 'embeddings', step: int = None):
        """Log *tensor* to the embedding projector, with optional per-row *metadata*."""
        self.writer.add_embedding(
            tensor,
            metadata=metadata,
            tag=tag,
            global_step=self._resolve_step(step),
        )

    def log_weights(self, model: torch.nn.Module, step: int = None):
        """Log weight and (when present) gradient histograms of trainable parameters.

        Tags are ``weights/<param name>`` and ``gradients/<param name>``.
        """
        step = self._resolve_step(step)

        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            self.writer.add_histogram(f"weights/{name}", self._histogram_values(param.data), step)
            if param.grad is not None:
                self.writer.add_histogram(f"gradients/{name}", self._histogram_values(param.grad), step)

    def log_text_samples(self, samples: list, tag: str = "samples", step: int = None):
        """Log samples (dicts with ``input``, ``output``, ``target`` keys) as an HTML table."""
        self.writer.add_text(tag, self._samples_to_html(samples), self._resolve_step(step))

    def increment_step(self):
        """Advance the tracker's ``global_step`` by one."""
        self.global_step += 1

    def flush(self):
        """Flush pending events to disk."""
        self.writer.flush()

    def close(self):
        """Flush and close the underlying writer."""
        self.writer.close()

模型训练集成

完整的训练循环

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

class TensorBoardTrainer:
    """Training/validation loop helper that reports to a TensorBoardLLMTracker.

    ``config`` is read by attribute access and serialized with ``vars(config)``,
    so it must be an object with a ``__dict__`` (e.g. ``argparse.Namespace``).
    Attributes used: ``log_dir``, ``experiment_name``, ``max_grad_norm``, and
    optionally ``input_size`` (graph logging), ``histogram_interval`` (log
    weight/gradient histograms every N steps, default 1) and
    ``max_confusion_matrix_classes`` (default 100).
    """

    def __init__(self, model, config, optimizer=None):
        """Set up the tracker, log the config and (optionally) the model graph.

        Args:
            model: Module whose forward accepts a batch dict as keyword
                arguments and returns an object with ``.loss`` (and optionally
                ``.logits``).
            config: Training configuration (see class docstring).
            optimizer: Optional default optimizer for ``train_epoch``. Without
                one, gradients are still cleared every batch but no parameter
                update is performed.
        """
        self.model = model
        self.config = config
        self.optimizer = optimizer

        self.tracker = TensorBoardLLMTracker(
            log_dir=config.log_dir,
            experiment_name=config.experiment_name
        )

        self.tracker.log_config(vars(config))

        if hasattr(config, 'input_size'):
            self.tracker.log_model_graph(model, config.input_size)

    def train_epoch(self, train_loader: DataLoader, epoch: int, optimizer=None):
        """Run one training epoch and return the mean batch loss.

        Args:
            train_loader: Loader yielding batch dicts.
            epoch: Zero-based epoch index, used to derive the global step.
            optimizer: Optimizer to step; defaults to the one given at
                construction time.

        Returns:
            Mean loss over the epoch, or ``nan`` for an empty loader.
        """
        optimizer = optimizer if optimizer is not None else self.optimizer
        self.model.train()
        num_batches = len(train_loader)
        histogram_interval = max(1, int(getattr(self.config, 'histogram_interval', 1)))
        total_loss = 0.0

        for batch_idx, batch in enumerate(train_loader):
            step = epoch * num_batches + batch_idx

            # Clear the previous batch's gradients; backward() accumulates into .grad.
            if optimizer is not None:
                optimizer.zero_grad()
            else:
                self.model.zero_grad()

            outputs = self.model(**batch)
            loss = outputs.loss
            loss.backward()

            # Weight and gradient histograms, taken before clipping.
            if step % histogram_interval == 0:
                self.tracker.log_weights(self.model, step=step)

            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.max_grad_norm
            )

            if optimizer is not None:
                optimizer.step()
                self.log_learning_rate(optimizer, step)

            loss_value = loss.item()
            total_loss += loss_value
            self.tracker.log_scalar('training/batch_loss', loss_value, step=step)

            self.tracker.increment_step()

        avg_loss = total_loss / num_batches if num_batches else float('nan')
        self.tracker.log_scalar('training/epoch_loss', avg_loss, epoch)

        return avg_loss

    def validate(self, val_loader: DataLoader, epoch: int):
        """Evaluate on *val_loader*, log the mean loss and a confusion matrix.

        Returns:
            Mean validation loss, or ``nan`` for an empty loader.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        all_predictions = []
        all_targets = []

        with torch.no_grad():
            for batch in val_loader:
                outputs = self.model(**batch)
                total_loss += outputs.loss.item()
                num_batches += 1

                logits = getattr(outputs, 'logits', None)
                if logits is not None and 'labels' in batch:
                    predictions, targets = self._classification_pairs(logits, batch['labels'])
                    all_predictions.extend(predictions)
                    all_targets.extend(targets)

        avg_loss = total_loss / num_batches if num_batches else float('nan')

        self.tracker.log_scalar('validation/loss', avg_loss, epoch)

        if all_predictions and all_targets:
            self._log_confusion_matrix(all_predictions, all_targets, epoch)

        return avg_loss

    @staticmethod
    def _classification_pairs(logits, labels, ignore_index=-100):
        """Return aligned flat lists of argmax predictions and labels.

        Works for ``(batch, classes)`` logits with ``(batch,)`` labels as well
        as per-token ``(batch, seq, vocab)`` logits with ``(batch, seq)``
        labels. Positions whose label equals *ignore_index* (-100 by default,
        a common padding marker) are dropped. If the prediction shape does not
        match the label shape, two empty lists are returned.
        NOTE(review): for causal LMs, logits at position t usually predict
        label t+1; no shift is applied here.
        """
        predictions = torch.argmax(logits, dim=-1).cpu()
        labels = labels.cpu()
        if predictions.shape != labels.shape:
            return [], []
        predictions = predictions.reshape(-1)
        labels = labels.reshape(-1)
        keep = labels != ignore_index
        return predictions[keep].tolist(), labels[keep].tolist()

    def _log_confusion_matrix(self, predictions, targets, step):
        """Plot a confusion matrix of *targets* vs *predictions* to TensorBoard.

        Skipped when the number of distinct classes exceeds
        ``config.max_confusion_matrix_classes``, because a vocabulary-sized
        matrix is neither readable nor cheap to build.
        """
        import matplotlib.pyplot as plt
        from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

        max_classes = getattr(self.config, 'max_confusion_matrix_classes', 100)
        if len(set(predictions) | set(targets)) > max_classes:
            return

        cm = confusion_matrix(targets, predictions)

        fig, ax = plt.subplots(figsize=(10, 8))
        disp = ConfusionMatrixDisplay(confusion_matrix=cm)
        disp.plot(ax=ax, cmap='Blues')
        ax.set_title(f'Confusion Matrix - Epoch {step}')

        self.tracker.writer.add_figure('validation/confusion_matrix', fig, step)
        plt.close(fig)

    @staticmethod
    def _normalize_attention(attn):
        """Return *attn* as float32, scaled so its maximum is 1 (if positive).

        ``add_image`` maps float values in [0, 1] to [0, 255]; head-averaged
        attention rows sum to 1, so raw values are tiny for long sequences and
        would render almost black.
        """
        attn = attn.detach().float()
        if attn.numel() == 0:
            return attn
        peak = attn.max()
        return attn / peak if peak > 0 else attn

    def log_attention_weights(self, attention_weights: torch.Tensor, step: int):
        """Log head-averaged attention maps of up to 4 samples as images.

        Only 4-D input (assumed ``(batch, heads, seq_len, seq_len)``) is
        handled; tensors of any other rank are ignored.
        """
        if attention_weights.dim() != 4:
            return

        # Average over the head dimension.
        avg_attention = attention_weights.detach().float().mean(dim=1)

        for i in range(min(avg_attention.shape[0], 4)):
            self.tracker.writer.add_image(
                f'attention/sample_{i}',
                self._normalize_attention(avg_attention[i]).unsqueeze(0),  # add channel dim (CHW)
                step
            )

    def log_learning_rate(self, optimizer, step: int):
        """Log the learning rate of each optimizer param group.

        A single group is logged as ``training/learning_rate``; with several
        groups each gets its own ``training/learning_rate/group_<i>`` tag so
        they do not overwrite one another.
        """
        param_groups = optimizer.param_groups
        if len(param_groups) == 1:
            self.tracker.log_scalar('training/learning_rate', param_groups[0]['lr'], step)
            return
        for idx, param_group in enumerate(param_groups):
            self.tracker.log_scalar(
                f'training/learning_rate/group_{idx}',
                param_group['lr'],
                step
            )

    def finish(self):
        """Close the tracker at the end of training."""
        self.tracker.close()

高级可视化

嵌入空间分析

import torch
from torch.utils.tensorboard import SummaryWriter
from sklearn.manifold import TSNE
import numpy as np

class EmbeddingVisualizer:
    """Render embedding and attention plots as matplotlib figures in TensorBoard."""

    def __init__(self, writer: "SummaryWriter"):
        """Args: writer: destination ``SummaryWriter`` for every figure."""
        self.writer = writer

    @staticmethod
    def _to_numpy(tensor):
        """Return *tensor* as a CPU numpy array.

        The tensor is detached first (``Tensor.numpy()`` refuses tensors that
        require grad, such as an ``nn.Embedding`` weight). Floating dtypes
        other than float32/float64 (e.g. bfloat16, which numpy lacks) are
        upcast to float32.
        """
        tensor = tensor.detach()
        if tensor.is_floating_point() and tensor.dtype not in (torch.float32, torch.float64):
            tensor = tensor.float()
        return tensor.cpu().numpy()

    @staticmethod
    def _grid_shape(num_heads):
        """Return ``(rows, cols)`` laying out *num_heads* panels in at most two rows."""
        rows = 2 if num_heads > 1 else 1
        cols = max(1, (num_heads + rows - 1) // rows)
        return rows, cols

    def visualize_embeddings(self, embeddings: torch.Tensor,
                             labels: list = None,
                             tag: str = 'token_embeddings',
                             step: int = 0,
                             num_points: int = 1000):
        """Project *embeddings* to 2-D with t-SNE and log a scatter plot.

        Args:
            embeddings: 2-D tensor, one embedding per row.
            labels: Optional per-row labels; the first 100 (after sampling)
                are drawn as annotations.
            tag: TensorBoard tag, also used in the plot title.
            step: Global step for the figure.
            num_points: Rows beyond this count are randomly sub-sampled.

        Raises:
            ValueError: if fewer than two embeddings are provided (t-SNE
                requires ``0 < perplexity < n_samples``).
        """
        import matplotlib.pyplot as plt

        # Sub-sample when there are too many points to project quickly.
        if embeddings.shape[0] > num_points:
            indices = torch.randperm(embeddings.shape[0])[:num_points]
            embeddings = embeddings[indices]
            if labels:
                labels = [labels[i] for i in indices.tolist()]

        num_samples = embeddings.shape[0]
        if num_samples < 2:
            raise ValueError("t-SNE needs at least 2 embeddings to project")

        tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, num_samples - 1))
        embeddings_2d = tsne.fit_transform(self._to_numpy(embeddings))

        fig, ax = plt.subplots(figsize=(10, 8))

        scatter = ax.scatter(
            embeddings_2d[:, 0],
            embeddings_2d[:, 1],
            c=range(len(embeddings_2d)),
            cmap='viridis',
            alpha=0.6
        )

        if labels:
            for i, label in enumerate(labels[:100]):  # annotate at most 100 points
                ax.annotate(
                    str(label),
                    (embeddings_2d[i, 0], embeddings_2d[i, 1]),
                    fontsize=8,
                    alpha=0.7
                )

        ax.set_title(f'{tag} - t-SNE Visualization')
        fig.colorbar(scatter, ax=ax)

        self.writer.add_figure(tag, fig, step)
        plt.close(fig)

    def visualize_attention_patterns(self, attention_weights: torch.Tensor,
                                     tokenizer=None,
                                     step: int = 0):
        """Log one heatmap per attention head for the first sample.

        Args:
            attention_weights: Assumed ``(batch, heads, seq_len, seq_len)``;
                only index 0 of the batch is plotted.
            tokenizer: Currently unused; kept for API compatibility.
            step: Global step for the figure.
        """
        import matplotlib.pyplot as plt

        num_heads = attention_weights.shape[1]
        rows, cols = self._grid_shape(num_heads)

        # squeeze=False keeps a 2-D axes array even for a single row/column.
        fig, axes = plt.subplots(rows, cols, figsize=(20, 10), squeeze=False)
        axes = axes.flatten()

        first_sample = self._to_numpy(attention_weights[0])
        for head_idx in range(num_heads):
            ax = axes[head_idx]
            ax.imshow(first_sample[head_idx], cmap='viridis', aspect='auto')
            ax.set_title(f'Head {head_idx + 1}')
            ax.set_xlabel('Key Position')
            ax.set_ylabel('Query Position')

        # Hide grid cells left over when num_heads does not fill the grid.
        for ax in axes[num_heads:]:
            ax.set_visible(False)

        fig.tight_layout()
        self.writer.add_figure('attention/patterns', fig, step)
        plt.close(fig)

文本生成监控

class TextGenerationMonitor:
    """Log sampled generations and next-token distributions to TensorBoard."""

    def __init__(self, writer: "SummaryWriter"):
        """Args: writer: destination ``SummaryWriter``."""
        self.writer = writer

    @staticmethod
    def _model_device(model):
        """Return the device of *model*'s first parameter (CPU if it has none)."""
        first_param = next(model.parameters(), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    @staticmethod
    def _samples_to_html(samples):
        """Render ``prompt``/``generated`` pairs as an HTML table with escaped cells."""
        import html

        rows = "".join(
            f"<tr><td>{html.escape(str(sample['prompt']))}</td>"
            f"<td>{html.escape(str(sample['generated']))}</td></tr>"
            for sample in samples
        )
        return "<table><tr><th>Prompt</th><th>Generated</th></tr>" + rows + "</table>"

    @staticmethod
    def _top_token_probs(logits, top_k=20):
        """Return ``(probabilities, token_ids)`` lists for the *top_k* tokens.

        Uses the first batch element. If that element still has leading
        dimensions (e.g. ``(seq_len, vocab)``), the last position is used,
        which for a causal LM is the next-token distribution. *top_k* is
        capped at the vocabulary size.
        """
        row = logits[0].detach().float()
        if row.dim() > 1:
            row = row.reshape(-1, row.shape[-1])[-1]
        probs = torch.softmax(row, dim=-1)
        k = min(top_k, probs.shape[-1])
        top_probs, top_ids = torch.topk(probs, k)
        return top_probs.cpu().tolist(), top_ids.cpu().tolist()

    def log_generation_samples(self, model, tokenizer, prompts: list,
                               step: int, max_length: int = 100):
        """Sample one completion per prompt and log them as an HTML table.

        The model is put in eval mode while sampling and its previous
        train/eval mode is restored afterwards, even if generation fails.
        Tokenized inputs are moved to the model's device.
        NOTE(review): *max_length* is passed to ``model.generate`` unchanged;
        in Hugging Face it counts prompt tokens too.
        """
        was_training = model.training
        model.eval()
        device = self._model_device(model)
        samples = []

        try:
            for prompt in prompts:
                encoded = tokenizer(prompt, return_tensors="pt")
                inputs = {
                    key: value.to(device) if isinstance(value, torch.Tensor) else value
                    for key, value in encoded.items()
                }

                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_length=max_length,
                        num_return_sequences=1,
                        do_sample=True,
                        temperature=0.7,
                        top_k=50
                    )

                generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
                samples.append({
                    "prompt": prompt,
                    "generated": generated_text
                })
        finally:
            # Calling this mid-training must not leave dropout etc. disabled.
            model.train(was_training)

        self.writer.add_text('generation/samples', self._samples_to_html(samples), step)

    def log_token_distribution(self, logits: torch.Tensor, step: int):
        """Plot the top-20 token probabilities, labelling bars with token ids.

        Args:
            logits: ``(batch, vocab_size)``, or ``(batch, seq_len, vocab_size)``
                in which case the last position of the first sample is used.
            step: Global step for the figure.
        """
        import matplotlib.pyplot as plt

        top_probs, top_ids = self._top_token_probs(logits, top_k=20)
        positions = list(range(len(top_probs)))

        fig, ax = plt.subplots(figsize=(12, 6))
        ax.bar(positions, top_probs)
        ax.set_xticks(positions)
        ax.set_xticklabels([str(token_id) for token_id in top_ids], rotation=45)
        ax.set_xlabel('Token Index')
        ax.set_ylabel('Probability')
        ax.set_title('Token Probability Distribution')

        self.writer.add_figure('generation/token_distribution', fig, step)
        plt.close(fig)

TensorBoard提供了强大的可视化能力,帮助LLM研究人员和工程师深入理解模型训练过程和模型行为,是大语言模型开发中不可或缺的工具。需要注意的是,对大模型而言,每一步都记录全部参数的直方图或图像开销很大,实际使用时应按固定间隔记录。