TensorBoard:使用TensorBoard可视化LLM训练
TensorBoard:使用TensorBoard可视化LLM训练
TensorBoard简介
TensorBoard是TensorFlow和PyTorch都支持的可视化工具,提供训练过程的实时监控、模型图可视化、权重分布分析、嵌入空间投影等功能。对于大语言模型开发,TensorBoard可以帮助研究人员深入理解模型行为。
基础配置
安装与启动
# Install TensorBoard
pip install tensorboard
# Launch TensorBoard on port 6006, serving the event files under ./logs
tensorboard --logdir=./logs --port=6006
# Alternatively, pass several log directories, each with a display name (name:path pairs, comma-separated)
tensorboard --logdir_spec=training:./logs/train,validation:./logs/val
PyTorch集成
from torch.utils.tensorboard import SummaryWriter
import torch
# Create a SummaryWriter. Using it as a context manager guarantees close()
# runs even if one of the logging calls raises.
with SummaryWriter(log_dir='./logs/experiment_001') as writer:
    # Log scalars. loss_value, acc_value and global_step are placeholders
    # that are expected to come from your own training loop.
    writer.add_scalar('training/loss', loss_value, global_step)
    writer.add_scalar('validation/accuracy', acc_value, global_step)
完整训练追踪
多维度追踪器
import torch
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
import json
from datetime import datetime
class TensorBoardLLMTracker:
    """Track an LLM training run with TensorBoard.

    Each run gets its own directory ``<log_dir>/<experiment_name>``. The
    experiment configuration is saved there as ``config.json`` and is also
    shown in TensorBoard's text tab. Every ``log_*`` method takes an optional
    ``step``; when it is omitted, the tracker's own ``global_step`` counter
    (advanced by :meth:`increment_step`) is used.

    The tracker can be used as a context manager, which closes the writer on exit.
    """

    def __init__(self, log_dir: str, experiment_name: str = None):
        """Create the run directory and open a SummaryWriter in it.

        Args:
            log_dir: Root directory under which the run directory is created.
            experiment_name: Run directory name. Defaults to the current local
                time formatted as ``YYYYmmdd_HHMMSS``.
        """
        if experiment_name is None:
            experiment_name = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_dir = Path(log_dir) / experiment_name
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.writer = SummaryWriter(log_dir=str(self.log_dir))
        self.global_step = 0
        # File where log_config() persists the experiment configuration.
        self.config_file = self.log_dir / "config.json"

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def _resolve_step(self, step):
        """Return ``step`` if given, otherwise the tracker's ``global_step``."""
        return self.global_step if step is None else step

    def log_config(self, config: dict):
        """Save ``config`` to ``config.json`` and show it as text in TensorBoard.

        The dict is serialised once, so the file and the TensorBoard text are
        identical. The file is written only after serialisation succeeds, so a
        non-JSON-serialisable value raises ``TypeError`` before anything is
        written and cannot leave a truncated ``config.json`` behind.
        """
        config_text = json.dumps(config, indent=2, ensure_ascii=False)
        # Explicit UTF-8 because ensure_ascii=False keeps non-ASCII text as-is.
        self.config_file.write_text(config_text, encoding="utf-8")
        # Wrap in a markdown code fence so the indentation survives rendering.
        self.writer.add_text('config', f"```\n{config_text}\n```")

    def log_scalar(self, tag: str, value: float, step: int = None):
        """Log a single scalar value under ``tag``."""
        self.writer.add_scalar(tag, value, self._resolve_step(step))

    def log_scalars(self, main_tag: str, tag_scalar_dict: dict, step: int = None):
        """Log several related scalars (``{sub_tag: value}``) under ``main_tag``."""
        self.writer.add_scalars(main_tag, tag_scalar_dict, self._resolve_step(step))

    def log_histogram(self, tag: str, values: torch.Tensor, step: int = None):
        """Log a histogram of ``values``, e.g. a weight or gradient tensor."""
        if isinstance(values, torch.Tensor):
            # Detach and upcast: numpy (used for histograms) has no bfloat16.
            values = values.detach().float()
        self.writer.add_histogram(tag, values, self._resolve_step(step))

    def log_image(self, tag: str, img_tensor: torch.Tensor, step: int = None):
        """Log an image tensor (``add_image``'s default layout is CHW)."""
        self.writer.add_image(tag, img_tensor, self._resolve_step(step))

    def log_text(self, tag: str, text: str, step: int = None):
        """Log a text entry under ``tag``."""
        self.writer.add_text(tag, text, self._resolve_step(step))

    def log_model_graph(self, model: torch.nn.Module, input_size: tuple):
        """Trace ``model`` on a ``torch.randn(*input_size)`` input and log its graph.

        The dummy input is created on the device of the model's first parameter,
        so tracing a GPU model does not fail with a device mismatch.
        NOTE: ``randn`` yields float inputs. Models whose first layer expects
        integer token ids (e.g. an embedding) need a different dummy input.
        """
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None
        dummy_input = torch.randn(*input_size, device=device)
        self.writer.add_graph(model, dummy_input)

    def log_embedding(self, tensor: torch.Tensor, metadata: list = None,
                      tag: str = 'embeddings', step: int = None):
        """Add ``tensor`` (one row per point) to TensorBoard's embedding projector."""
        self.writer.add_embedding(
            tensor,
            metadata=metadata,
            tag=tag,
            global_step=self._resolve_step(step),
        )

    def log_weights(self, model: torch.nn.Module, step: int = None):
        """Log histograms of every trainable parameter and of its gradient.

        Weights are logged under ``weights/<name>``. Gradients, when present,
        are logged under ``gradients/<name>``. Tensors are detached and upcast
        to float32 because numpy cannot represent bfloat16.
        """
        step = self._resolve_step(step)
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            self.writer.add_histogram(f"weights/{name}", param.detach().float(), step)
            if param.grad is not None:
                self.writer.add_histogram(
                    f"gradients/{name}", param.grad.detach().float(), step
                )

    @staticmethod
    def _format_samples_table(samples: list) -> str:
        """Render samples as an HTML table with Input/Output/Target columns.

        Cell contents are HTML-escaped, so model text containing ``<``, ``>``
        or ``&`` cannot break the table markup.
        """
        import html

        rows = "".join(
            "<tr><td>{}</td><td>{}</td><td>{}</td></tr>".format(
                html.escape(str(sample['input'])),
                html.escape(str(sample['output'])),
                html.escape(str(sample['target'])),
            )
            for sample in samples
        )
        return (
            "<table><tr><th>Input</th><th>Output</th><th>Target</th></tr>"
            + rows
            + "</table>"
        )

    def log_text_samples(self, samples: list, tag: str = "samples", step: int = None):
        """Log input/output/target samples as an HTML table.

        Args:
            samples: Dicts with ``'input'``, ``'output'`` and ``'target'`` keys.
            tag: TensorBoard tag for the text entry.
            step: Global step; defaults to the tracker's ``global_step``.
        """
        self.writer.add_text(
            tag, self._format_samples_table(samples), self._resolve_step(step)
        )

    def increment_step(self):
        """Advance the tracker's global step counter by one."""
        self.global_step += 1

    def flush(self):
        """Flush pending events to disk."""
        self.writer.flush()

    def close(self):
        """Close the underlying SummaryWriter."""
        self.writer.close()
模型训练集成
完整的训练循环
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
class TensorBoardTrainer:
    """Training/validation loop that reports to TensorBoard via TensorBoardLLMTracker.

    Args:
        model: Called as ``model(**batch)``. It must return an object with a
            ``loss`` attribute and may also expose ``logits``.
        config: Object whose attributes (``vars(config)``) are logged as the
            experiment configuration. The trainer reads ``log_dir``,
            ``experiment_name`` and ``max_grad_norm``, and optionally
            ``input_size`` (log the model graph) and ``histogram_interval``
            (log weight/gradient histograms every N steps; default 1).
        optimizer: Optimizer that updates ``model``. Without one, gradients are
            computed, clipped and logged, but the parameters are never updated.
    """

    def __init__(self, model, config, optimizer=None):
        self.model = model
        self.config = config
        self.optimizer = optimizer
        self.tracker = TensorBoardLLMTracker(
            log_dir=config.log_dir,
            experiment_name=config.experiment_name,
        )
        self.tracker.log_config(vars(config))
        if hasattr(config, 'input_size'):
            self.tracker.log_model_graph(model, config.input_size)

    def train_epoch(self, train_loader: DataLoader, epoch: int, optimizer=None):
        """Train for one epoch and return the mean batch loss.

        For each batch the trainer:
        1. zeroes the gradients and runs the forward and backward passes;
        2. logs weight/gradient histograms (before clipping);
        3. clips the gradients to ``config.max_grad_norm`` and steps the optimizer;
        4. logs the batch loss at step ``epoch * len(train_loader) + batch_idx``.

        Args:
            train_loader: Yields dict batches that are passed to the model as kwargs.
            epoch: Epoch index used to derive global steps (presumably 0-based).
            optimizer: Overrides the optimizer given to the constructor.

        Returns:
            Mean training loss over the epoch, or NaN if the loader is empty.
        """
        import warnings

        optimizer = optimizer if optimizer is not None else self.optimizer
        if optimizer is None:
            warnings.warn(
                "TensorBoardTrainer.train_epoch: no optimizer supplied; gradients "
                "are computed and logged but model parameters are not updated.",
                RuntimeWarning,
            )
        # Full histograms of every parameter are expensive for large models.
        # A falsy interval disables them.
        histogram_interval = getattr(self.config, 'histogram_interval', 1)

        self.model.train()
        num_batches = len(train_loader)
        total_loss = 0.0
        for batch_idx, batch in enumerate(train_loader):
            step = epoch * num_batches + batch_idx
            # Clear the previous batch's gradients; otherwise they accumulate.
            if optimizer is not None:
                optimizer.zero_grad()
            else:
                self.model.zero_grad()

            outputs = self.model(**batch)
            loss = outputs.loss
            loss.backward()

            if histogram_interval and step % histogram_interval == 0:
                self.tracker.log_weights(self.model, step=step)

            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.max_grad_norm,
            )
            if optimizer is not None:
                optimizer.step()

            # .item() synchronises with the device; call it once per batch.
            loss_value = loss.item()
            total_loss += loss_value
            self.tracker.log_scalar('training/batch_loss', loss_value, step=step)
            self.tracker.increment_step()

        avg_loss = total_loss / num_batches if num_batches else float('nan')
        self.tracker.log_scalar('training/epoch_loss', avg_loss, epoch)
        return avg_loss

    @staticmethod
    def _valid_pairs(predictions, labels, ignore_index=-100):
        """Flatten predictions/labels to 1-D CPU tensors, dropping ``ignore_index`` labels.

        -100 is the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
        It is commonly used to mark padding or prompt tokens that take no part
        in the loss.
        """
        predictions = predictions.reshape(-1).cpu()
        labels = labels.reshape(-1).cpu()
        keep = labels != ignore_index
        return predictions[keep], labels[keep]

    def validate(self, val_loader: DataLoader, epoch: int):
        """Evaluate on ``val_loader`` and return the mean loss (NaN if empty).

        Logs ``validation/loss``. If the model output has ``logits``, it also
        logs a confusion matrix of argmax predictions against ``batch['labels']``.
        Both are flattened, so token-level (batch, seq) outputs work too.
        NOTE(review): with vocabulary-sized label sets the confusion matrix
        becomes very large.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        all_predictions = []
        all_targets = []
        with torch.no_grad():
            for batch in val_loader:
                outputs = self.model(**batch)
                total_loss += outputs.loss.item()
                num_batches += 1
                if hasattr(outputs, 'logits'):
                    predictions = torch.argmax(outputs.logits, dim=-1)
                    preds, targets = self._valid_pairs(predictions, batch['labels'])
                    all_predictions.append(preds)
                    all_targets.append(targets)

        avg_loss = total_loss / num_batches if num_batches else float('nan')
        self.tracker.log_scalar('validation/loss', avg_loss, epoch)

        if all_predictions:
            predictions = torch.cat(all_predictions).numpy()
            targets = torch.cat(all_targets).numpy()
            if predictions.size:
                self._log_confusion_matrix(predictions, targets, epoch)
        return avg_loss

    def _log_confusion_matrix(self, predictions, targets, step):
        """Plot a confusion matrix of ``targets`` vs ``predictions`` and log the figure."""
        import matplotlib.pyplot as plt
        from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

        cm = confusion_matrix(targets, predictions)
        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            ConfusionMatrixDisplay(confusion_matrix=cm).plot(ax=ax, cmap='Blues')
            ax.set_title(f'Confusion Matrix - Epoch {step}')
            self.tracker.writer.add_figure('validation/confusion_matrix', fig, step)
        finally:
            # Always release the figure, even if plotting/logging failed.
            plt.close(fig)

    def log_attention_weights(self, attention_weights: torch.Tensor, step: int,
                              max_samples: int = 4):
        """Log head-averaged attention maps as images for up to ``max_samples`` samples.

        Expects a 4-D ``(batch, heads, seq_len, seq_len)`` tensor. Tensors of
        any other rank are ignored.
        """
        if attention_weights.dim() != 4:
            return
        # Detach and upcast: image conversion goes through numpy (no bfloat16).
        avg_attention = attention_weights.detach().float().mean(dim=1)
        for i in range(min(avg_attention.shape[0], max_samples)):
            self.tracker.writer.add_image(
                f'attention/sample_{i}',
                avg_attention[i].unsqueeze(0),  # add a channel dim -> CHW
                step,
            )

    def log_learning_rate(self, optimizer, step: int):
        """Log the learning rate of each optimizer param group.

        A single group is logged as ``training/learning_rate``. With several
        groups, each gets its own tag ``training/learning_rate/group_<i>`` so
        their values are not mixed into one series.
        """
        groups = optimizer.param_groups
        if len(groups) == 1:
            self.tracker.log_scalar('training/learning_rate', groups[0]['lr'], step)
            return
        for idx, group in enumerate(groups):
            self.tracker.log_scalar(
                f'training/learning_rate/group_{idx}', group['lr'], step
            )

    def finish(self):
        """Finish training and close the tracker."""
        self.tracker.close()
高级可视化
嵌入空间分析
import torch
from torch.utils.tensorboard import SummaryWriter
from sklearn.manifold import TSNE
import numpy as np
class EmbeddingVisualizer:
    """Render embedding-space and attention plots as TensorBoard figures."""

    def __init__(self, writer: SummaryWriter):
        self.writer = writer

    def visualize_embeddings(self, embeddings: torch.Tensor,
                             labels: list = None,
                             tag: str = 'token_embeddings',
                             step: int = 0,
                             num_points: int = 1000):
        """Project ``embeddings`` to 2-D with t-SNE and log the scatter plot.

        Args:
            embeddings: 2-D tensor with one row per point (e.g. token embeddings).
            labels: Optional per-row labels. The first 100 (after sampling)
                are drawn as text annotations.
            tag: TensorBoard tag for the figure.
            step: Global step for the figure.
            num_points: Rows are randomly subsampled to at most this many.

        With fewer than two points t-SNE cannot run, so nothing is logged.
        """
        import matplotlib.pyplot as plt

        if embeddings.shape[0] > num_points:
            indices = torch.randperm(embeddings.shape[0])[:num_points]
            embeddings = embeddings[indices]
            if labels:
                labels = [labels[i] for i in indices.tolist()]

        num_samples = embeddings.shape[0]
        if num_samples < 2:
            return

        # detach(): embedding weights often require grad, and .numpy() refuses
        # such tensors. float(): numpy has no bfloat16.
        points = embeddings.detach().float().cpu().numpy()
        # scikit-learn requires 0 < perplexity < n_samples.
        tsne = TSNE(n_components=2, random_state=42,
                    perplexity=min(30, num_samples - 1))
        embeddings_2d = tsne.fit_transform(points)

        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            scatter = ax.scatter(
                embeddings_2d[:, 0],
                embeddings_2d[:, 1],
                c=range(len(embeddings_2d)),
                cmap='viridis',
                alpha=0.6,
            )
            if labels:
                for i, label in enumerate(labels[:100]):  # cap annotations at 100
                    ax.annotate(
                        str(label),
                        (embeddings_2d[i, 0], embeddings_2d[i, 1]),
                        fontsize=8,
                        alpha=0.7,
                    )
            ax.set_title(f'{tag} - t-SNE Visualization')
            fig.colorbar(scatter, ax=ax)
            self.writer.add_figure(tag, fig, step)
        finally:
            plt.close(fig)

    def visualize_attention_patterns(self, attention_weights: torch.Tensor,
                                     tokenizer=None,
                                     step: int = 0,
                                     sample_idx: int = 0):
        """Plot one attention heatmap per head for a single batch element.

        Args:
            attention_weights: ``(batch, heads, seq_len, seq_len)`` tensor.
            tokenizer: Not used by this method.
            step: Global step for the figure.
            sample_idx: Batch element to plot (default: the first one).

        Heads are laid out on a 2-row grid, or a single row for one head, so any
        head count works, including odd ones. Unused grid cells are hidden.
        """
        import math
        import matplotlib.pyplot as plt

        num_heads = attention_weights.shape[1]
        if num_heads == 0:
            return
        nrows = 1 if num_heads == 1 else 2
        ncols = math.ceil(num_heads / nrows)
        fig, axes = plt.subplots(nrows, ncols, figsize=(20, 10), squeeze=False)
        axes = axes.flatten()
        try:
            # Detach and upcast once for all heads (numpy has no bfloat16).
            head_maps = attention_weights[sample_idx].detach().float().cpu().numpy()
            for head_idx in range(num_heads):
                ax = axes[head_idx]
                ax.imshow(head_maps[head_idx], cmap='viridis', aspect='auto')
                ax.set_title(f'Head {head_idx + 1}')
                ax.set_xlabel('Key Position')
                ax.set_ylabel('Query Position')
            for ax in axes[num_heads:]:
                ax.axis('off')
            fig.tight_layout()
            self.writer.add_figure('attention/patterns', fig, step)
        finally:
            plt.close(fig)
文本生成监控
class TextGenerationMonitor:
    """Log sampled generations and next-token distributions to TensorBoard."""

    def __init__(self, writer: SummaryWriter):
        self.writer = writer

    @staticmethod
    def _model_device(model):
        """Best-effort device of ``model``: its ``device`` attribute, else its first parameter's."""
        device = getattr(model, 'device', None)
        if device is not None:
            return device
        parameters = getattr(model, 'parameters', None)
        if parameters is None:
            return None
        first = next(iter(parameters()), None)
        return first.device if first is not None else None

    def log_generation_samples(self, model, tokenizer, prompts: list,
                               step: int, max_length: int = 100):
        """Generate one sampled continuation per prompt and log them as an HTML table.

        Generation uses ``do_sample=True, temperature=0.7, top_k=50``. The
        model is put in eval mode for generation and returned to training mode
        afterwards if it was training, so calling this mid-training does not
        leave dropout disabled. Prompt and output text are HTML-escaped.

        Args:
            model: Object with ``eval``/``generate`` (e.g. a Hugging Face model).
            tokenizer: Callable accepting ``return_tensors="pt"``; must also
                provide ``decode``.
            prompts: Prompt strings.
            step: Global step for the text entry.
            max_length: Forwarded to ``generate`` as ``max_length``.
        """
        import html

        was_training = getattr(model, 'training', False)
        model.eval()
        device = self._model_device(model)
        samples = []
        try:
            for prompt in prompts:
                inputs = tokenizer(prompt, return_tensors="pt")
                if device is not None:
                    # Tokenizer output lives on the CPU; move it next to the model.
                    inputs = {
                        key: value.to(device) if isinstance(value, torch.Tensor) else value
                        for key, value in inputs.items()
                    }
                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_length=max_length,
                        num_return_sequences=1,
                        do_sample=True,
                        temperature=0.7,
                        top_k=50,
                    )
                generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
                samples.append((prompt, generated_text))
        finally:
            if was_training:
                model.train()

        rows = "".join(
            f"<tr><td>{html.escape(str(prompt))}</td>"
            f"<td>{html.escape(str(generated))}</td></tr>"
            for prompt, generated in samples
        )
        table_html = f"<table><tr><th>Prompt</th><th>Generated</th></tr>{rows}</table>"
        self.writer.add_text('generation/samples', table_html, step)

    def log_token_distribution(self, logits: torch.Tensor, step: int,
                               top_k: int = 20, tokenizer=None):
        """Plot the ``top_k`` most probable next tokens for the first batch element.

        Args:
            logits: ``(batch, vocab_size)``. A ``(batch, seq_len, vocab_size)``
                tensor is also accepted, in which case the last position is used.
            step: Global step for the figure.
            top_k: Number of tokens to show (capped at the vocabulary size).
            tokenizer: Optional. If given, bars are labelled with decoded tokens;
                otherwise they are labelled with token ids.
        """
        import matplotlib.pyplot as plt

        token_logits = logits[0].detach().float()
        if token_logits.dim() == 2:
            token_logits = token_logits[-1]
        probs = torch.softmax(token_logits, dim=-1)
        k = min(top_k, probs.shape[-1])
        top_probs, top_indices = torch.topk(probs, k)

        token_ids = top_indices.tolist()
        if tokenizer is not None:
            tick_labels = [tokenizer.decode([token_id]) for token_id in token_ids]
        else:
            tick_labels = [str(token_id) for token_id in token_ids]

        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            ax.bar(range(k), top_probs.cpu().numpy())
            ax.set_xticks(range(k))
            ax.set_xticklabels(tick_labels, rotation=45, ha='right', fontsize=8)
            ax.set_xlabel('Token Index')
            ax.set_ylabel('Probability')
            ax.set_title('Token Probability Distribution')
            self.writer.add_figure('generation/token_distribution', fig, step)
        finally:
            plt.close(fig)
TensorBoard提供了强大的可视化能力,帮助LLM研究人员和工程师深入理解模型训练过程和模型行为,是大语言模型开发中不可或缺的工具。