← 返回首页
🧠

注意力可视化在LLM中的应用

📂 llm ⏱ 7 min 1379 words

---
title: "注意力可视化在LLM中的应用"
description: "介绍注意力可视化技术在大型语言模型理解和解释中的应用。"
tags: ["注意力可视化", "llm", "模型解释", "可解释性", "注意力机制"]
category: "llm"
icon: "🧠"
---

注意力可视化在LLM中的应用

什么是注意力可视化?

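注意力可视化是将大型语言模型中的注意力权重以图形化方式展示的技术,帮助理解模型如何关注输入文本的不同部分。需要注意的是,下文基于钩子的提取方法只能捕获注意力模块实际返回的权重:以 Hugging Face Transformers 为例,通常需要在前向传播时设置 output_attentions=True(并使用会显式计算注意力矩阵的实现,如 eager),否则模块不会返回注意力权重,可视化也就无从谈起。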

注意力可视化原理

1. 基本注意力权重提取

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

class AttentionExtractor:
    """Capture attention-weight tensors from a model via forward hooks.

    Hooks are attached to every submodule that either has an ``attention``
    attribute or contains ``"attention"`` in its qualified name. A hook only
    records something when the module returns a tuple whose second element is
    a tensor. Many implementations fill that slot only when attention output is
    explicitly requested (e.g. ``output_attentions=True`` in Hugging Face
    models -- assumption, verify for the model in use).
    """

    def __init__(self, model):
        self.model = model
        # Maps module name -> detached attention tensor from the latest capture.
        self.attention_weights = {}
        # Handles returned by register_forward_hook, kept so hooks can be removed.
        self.hooks = []

    def register_hooks(self):
        """Register forward hooks on attention-like submodules.

        Previously registered hooks are removed first, so repeated calls do not
        stack duplicate hooks. Previously captured weights are discarded, so a
        capture never mixes data from different inputs.
        """
        self.remove_hooks()
        # Rebind rather than clear(), so dicts handed out earlier stay intact.
        self.attention_weights = {}
        for name, module in self.model.named_modules():
            if hasattr(module, 'attention') or 'attention' in name.lower():
                hook = module.register_forward_hook(self._attention_hook(name))
                self.hooks.append(hook)

    def _attention_hook(self, name):
        """Build a forward hook that stores the module's attention weights under *name*."""
        def hook(module, input, output):
            # Some modules return (hidden_states, attention_weights, ...). The
            # second slot can also be None (weights not requested) or some other
            # non-tensor object, so only real tensors are stored.
            if isinstance(output, tuple) and len(output) > 1:
                weights = output[1]
                if isinstance(weights, torch.Tensor):
                    self.attention_weights[name] = weights.detach()
        return hook

    def remove_hooks(self):
        """Remove every registered hook; already captured weights are kept."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def get_attention_weights(self):
        """Return a shallow copy of the captured ``{module_name: tensor}`` mapping.

        A copy is returned so a later capture cannot mutate a result the caller
        is still holding.
        """
        return dict(self.attention_weights)

2. 注意力热力图

class AttentionVisualizer:
    """Static plots of attention matrices: heatmaps, per-head grids and rollout bars."""

    def __init__(self):
        # Placeholder for figure caching; none of the methods below populate it.
        self.figures = {}

    @staticmethod
    def _to_numpy(attention):
        """Return *attention* as a NumPy array (torch tensors are detached and moved to CPU)."""
        if isinstance(attention, torch.Tensor):
            return attention.detach().cpu().numpy()
        return attention

    def plot_attention_heatmap(self, attention_matrix, tokens,
                               title="Attention Heatmap",
                               figsize=(10, 8)):
        """Draw a query-by-key heatmap of one attention matrix.

        Args:
            attention_matrix: 2-D ``(query, key)`` matrix, or a 3-D stack whose
                first axis is averaged away; NumPy array or torch tensor.
            tokens: labels used for both axes.
            title: figure title.
            figsize: matplotlib figure size.

        Returns:
            The matplotlib Figure.
        """
        fig, ax = plt.subplots(figsize=figsize)

        attention_matrix = self._to_numpy(attention_matrix)
        # A 3-D input is treated as a stack of heads; average them.
        if len(attention_matrix.shape) == 3:
            attention_matrix = attention_matrix.mean(axis=0)

        sns.heatmap(attention_matrix,
                    xticklabels=tokens,
                    yticklabels=tokens,
                    cmap='viridis',
                    ax=ax)

        ax.set_title(title)
        ax.set_xlabel('Key')
        ax.set_ylabel('Query')

        plt.tight_layout()
        return fig

    def plot_multi_head_attention(self, attention_matrices, tokens,
                                  n_heads=8, figsize=(15, 10)):
        """Plot one heatmap per attention head on a (up to) two-row grid.

        Args:
            attention_matrices: sequence of per-head 2-D matrices (arrays or tensors).
            tokens: axis labels.
            n_heads: number of grid slots; at most this many heads are drawn.
            figsize: matplotlib figure size.

        Returns:
            The matplotlib Figure.

        Raises:
            ValueError: if ``n_heads`` is smaller than 1.
        """
        if n_heads < 1:
            raise ValueError("n_heads must be at least 1")

        # Two rows as before, but round the column count up so odd head counts
        # (and a single head) still get a slot instead of an IndexError or an
        # empty grid.
        n_rows = 2 if n_heads > 1 else 1
        n_cols = (n_heads + n_rows - 1) // n_rows
        fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
        axes = axes.flatten()

        n_plotted = min(n_heads, len(attention_matrices))
        for i in range(n_plotted):
            ax = axes[i]
            attn = self._to_numpy(attention_matrices[i])

            sns.heatmap(attn,
                        xticklabels=tokens if i == 0 else [],
                        yticklabels=tokens if i % n_cols == 0 else [],
                        cmap='viridis',
                        ax=ax)
            ax.set_title(f'Head {i+1}')

        # Hide grid slots that did not receive a head.
        for ax in axes[n_plotted:]:
            ax.axis('off')

        plt.suptitle('Multi-Head Attention')
        plt.tight_layout()
        return fig

    def plot_attention_rollout(self, attention_matrices, tokens,
                               figsize=(12, 6), query_index=0):
        """Bar-plot the attention-rollout scores of one query position.

        Args:
            attention_matrices: per-layer matrices, first layer first (each one
                is left-multiplied onto the running product).
            tokens: x-axis labels, one per key position.
            figsize: matplotlib figure size.
            query_index: row of the rollout to plot; see ``_compute_attention_rollout``.

        Returns:
            The matplotlib Figure.
        """
        rollout = self._compute_attention_rollout(attention_matrices,
                                                  query_index=query_index)

        fig, ax = plt.subplots(figsize=figsize)

        x_pos = np.arange(len(tokens))
        ax.bar(x_pos, rollout, align='center')

        ax.set_xticks(x_pos)
        ax.set_xticklabels(tokens, rotation=45, ha='right')
        ax.set_ylabel('Attention Score')
        ax.set_title('Attention Rollout')

        plt.tight_layout()
        return fig

    def _compute_attention_rollout(self, attention_matrices, query_index=0):
        """Compute attention rollout and return one row of the result.

        Each layer's matrix is head-averaged (if 3-D), mixed 50/50 with the
        identity to account for the residual connection, row-normalised, and
        left-multiplied onto the running product.

        Args:
            attention_matrices: per-layer ``(seq, seq)`` or ``(heads, seq, seq)``
                matrices, first layer first.
            query_index: row of the final rollout to return. ``0`` (the default)
                suits a leading [CLS]-style token. Under a causal mask row 0 can
                only attend to itself, so decoder-only LLMs usually want ``-1``.

        Returns:
            1-D array with one rollout score per key position.
        """
        rollout = np.eye(attention_matrices[0].shape[-1])

        for attn in attention_matrices:
            attn = self._to_numpy(attn)

            if len(attn.shape) == 3:
                attn = attn.mean(axis=0)

            # Residual connection: half attention, half identity.
            attn = 0.5 * attn + 0.5 * np.eye(attn.shape[0])
            attn = attn / attn.sum(axis=-1, keepdims=True)

            rollout = np.dot(attn, rollout)

        return rollout[query_index]

3. 交互式注意力可视化

class InteractiveAttentionVisualizer:
    """Prepare attention data for interactive front-ends and draw attention graphs."""

    def __init__(self):
        # Last payload built by prepare_for_interactive().
        self.attention_data = {}

    @staticmethod
    def _to_2d(attention):
        """Return *attention* as a NumPy array, averaging over the first axis when it is 3-D."""
        if isinstance(attention, torch.Tensor):
            attention = attention.detach().cpu().numpy()
        attention = np.asarray(attention)
        if attention.ndim == 3:
            attention = attention.mean(axis=0)
        return attention

    @staticmethod
    def _attention_edges(attention_matrix, n_tokens, threshold):
        """List ``(query, key, weight)`` triples whose weight is strictly above *threshold*.

        Only the top-left ``n_tokens x n_tokens`` block is scanned; weights are
        returned as plain Python floats.
        """
        edges = []
        for i in range(n_tokens):
            for j in range(n_tokens):
                weight = attention_matrix[i][j]
                if weight > threshold:
                    edges.append((i, j, float(weight)))
        return edges

    def prepare_for_interactive(self, attention_matrices, tokens, layer_idx=0):
        """Build (and cache) a JSON-like payload of per-layer 2-D attention matrices.

        Args:
            attention_matrices: per-layer matrices (arrays or tensors); 3-D
                inputs are averaged over their first axis.
            tokens: token strings, stored as-is.
            layer_idx: unused; kept for API compatibility.

        Returns:
            ``{'tokens': tokens, 'layers': {layer_index: ndarray}}``, also stored
            on ``self.attention_data``.
        """
        self.attention_data = {
            'tokens': tokens,
            'layers': {i: self._to_2d(attn) for i, attn in enumerate(attention_matrices)},
        }
        return self.attention_data

    def create_attention_graph(self, attention_matrix, tokens,
                               threshold=0.1, figsize=(12, 8), seed=None):
        """Draw tokens as nodes and attention weights above *threshold* as directed edges.

        Args:
            attention_matrix: ``(query, key)`` matrix (array or tensor); a 3-D
                input is averaged over its first axis, like prepare_for_interactive.
            tokens: node labels; only the first ``len(tokens)`` rows/columns are used.
            threshold: edges are drawn only for weights strictly above this value.
            figsize: matplotlib figure size.
            seed: optional seed for ``networkx.spring_layout`` so the layout is
                reproducible; ``None`` keeps the previous random layout.

        Returns:
            The matplotlib Figure.
        """
        import networkx as nx

        attention_matrix = self._to_2d(attention_matrix)

        G = nx.DiGraph()
        for i, token in enumerate(tokens):
            G.add_node(i, label=token)

        # Edge i -> j: query i attends to key j with a weight above threshold.
        for i, j, weight in self._attention_edges(attention_matrix, len(tokens), threshold):
            G.add_edge(i, j, weight=weight)

        fig, ax = plt.subplots(figsize=figsize)

        pos = nx.spring_layout(G, k=2, iterations=50, seed=seed)

        nx.draw_networkx_nodes(G, pos, node_size=2000, node_color='lightblue', ax=ax)

        # Line width scales with the attention weight.
        weights = [data['weight'] * 5 for _, _, data in G.edges(data=True)]
        nx.draw_networkx_edges(G, pos, width=weights, alpha=0.6,
                               edge_color='gray', arrows=True, ax=ax)

        labels = dict(enumerate(tokens))
        nx.draw_networkx_labels(G, pos, labels, font_size=10, ax=ax)

        ax.set_title('Attention Graph')
        plt.tight_layout()
        return fig

LLM注意力可视化实践

1. Transformer注意力分析

class TransformerAttentionAnalyzer:
    """Extract, plot and summarise the attention weights of a transformer model."""

    def __init__(self, model):
        self.model = model
        self.extractor = AttentionExtractor(model)
        self.visualizer = AttentionVisualizer()

    def analyze_input(self, input_ids, tokens):
        """Run one no-grad forward pass and return the captured attention weights.

        Args:
            input_ids: model input, passed positionally to ``self.model``.
            tokens: token strings; currently unused, kept for API compatibility.

        Returns:
            A new ``{module_name: tensor}`` dict owned by the caller.
        """
        self.extractor.register_hooks()
        try:
            with torch.no_grad():
                self.model(input_ids)
            # Copy: the extractor is reused across calls, so handing out its
            # internal dict would let a later call mutate this result.
            return dict(self.extractor.get_attention_weights())
        finally:
            # Always detach the hooks, even if the forward pass raises;
            # otherwise they stay on the model and fire on unrelated calls.
            self.extractor.remove_hooks()

    def visualize_layer_attention(self, attention_weights, tokens,
                                  layer_name, head_idx=None):
        """Plot the heatmap of one captured layer.

        Args:
            attention_weights: ``{layer_name: attention}`` as returned by analyze_input.
            tokens: axis labels.
            layer_name: key of the layer to plot.
            head_idx: for 4-D ``[batch, heads, seq, seq]`` weights, plot this head
                of the first batch item; ``None`` averages all heads.

        Returns:
            The matplotlib Figure, or ``None`` if *layer_name* was not captured.
        """
        if layer_name not in attention_weights:
            return None

        attn = attention_weights[layer_name]

        if isinstance(attn, torch.Tensor):
            attn = attn.cpu().numpy()

        if len(attn.shape) == 4:  # [batch, heads, seq, seq]
            if head_idx is not None:
                attn = attn[0, head_idx]
            else:
                attn = attn[0].mean(axis=0)

        return self.visualizer.plot_attention_heatmap(
            attn, tokens, title=f"Attention in {layer_name}"
        )

    def analyze_attention_patterns(self, attention_weights):
        """Compute entropy, sparsity and locality for every captured layer.

        4-D ``[batch, heads, seq, seq]`` weights are reduced to the first batch
        item with heads averaged before the metrics are computed.

        Returns:
            ``{layer_name: {'entropy': ..., 'sparsity': ..., 'locality': ...}}``.
        """
        patterns = {}

        for layer_name, attn in attention_weights.items():
            if isinstance(attn, torch.Tensor):
                attn = attn.cpu().numpy()

            if len(attn.shape) == 4:
                attn = attn[0].mean(axis=0)

            patterns[layer_name] = {
                'entropy': self._calculate_entropy(attn),
                'sparsity': self._calculate_sparsity(attn),
                'locality': self._calculate_locality(attn)
            }

        return patterns

    def _calculate_entropy(self, attention_matrix):
        """Mean per-row Shannon entropy (natural log) of the attention matrix."""
        # Shift by a tiny constant so log(0) never occurs.
        attention_matrix = attention_matrix + 1e-10

        entropy = -np.sum(attention_matrix * np.log(attention_matrix), axis=-1)

        return np.mean(entropy)

    def _calculate_sparsity(self, attention_matrix, threshold=0.1):
        """Fraction of weights strictly below *threshold*."""
        sparse_count = np.sum(attention_matrix < threshold)
        total_count = attention_matrix.size

        return sparse_count / total_count

    def _calculate_locality(self, attention_matrix):
        """Mean of ``1 / (1 + weighted |query - key| distance)`` over query rows.

        1.0 means every row attends only to its own position; smaller values
        mean attention mass sits further away.
        """
        seq_len = attention_matrix.shape[0]
        positions = np.arange(seq_len)

        locality_scores = []
        for i in range(seq_len):
            distances = np.abs(positions - i)
            weighted_distance = np.sum(distances * attention_matrix[i])
            locality_scores.append(1.0 / (1.0 + weighted_distance))

        return np.mean(locality_scores)

2. 提示工程中的注意力分析

class PromptAttentionAnalyzer:
    """Inspect how attention is distributed across the parts of a prompt."""

    def __init__(self, model):
        self.model = model
        self.visualizer = AttentionVisualizer()

    def _extract_attention(self, *args, **kwargs):
        """Run one no-grad forward pass with hooks attached; return the captured weights.

        Arguments are forwarded to ``self.model`` unchanged. Hooks are removed
        even if the forward pass raises, so they never stay attached to the model.
        """
        extractor = AttentionExtractor(self.model)
        extractor.register_hooks()
        try:
            with torch.no_grad():
                self.model(*args, **kwargs)
            return dict(extractor.get_attention_weights())
        finally:
            extractor.remove_hooks()

    @staticmethod
    def _to_2d(attn):
        """Convert one layer's weights to a NumPy array.

        4-D ``[batch, heads, seq, seq]`` input is reduced to the first batch item
        with heads averaged; other ranks are returned unchanged.
        """
        if isinstance(attn, torch.Tensor):
            attn = attn.detach().cpu().numpy()
        if len(attn.shape) == 4:
            attn = attn[0].mean(axis=0)
        return attn

    def analyze_prompt_parts(self, input_ids, tokens, prompt_parts):
        """Average, over layers, the attention rows belonging to each prompt part.

        Args:
            input_ids: model input, passed positionally to the model.
            tokens: token strings aligned with ``input_ids``.
            prompt_parts: iterable of ``(part_name, start_idx, end_idx)``;
                ``end_idx`` is exclusive (slice semantics).

        Returns:
            ``{part_name: {'tokens': [...], 'attention': ndarray}}`` where
            ``'attention'`` is the layer-averaged slice of rows
            ``start_idx:end_idx``.
        """
        attention_weights = self._extract_attention(input_ids)

        # Convert every layer once, instead of once per prompt part.
        layer_matrices = [self._to_2d(attn) for attn in attention_weights.values()]

        part_attention = {}
        for part_name, start_idx, end_idx in prompt_parts:
            part_attns = [matrix[start_idx:end_idx, :] for matrix in layer_matrices]
            part_attention[part_name] = {
                'tokens': tokens[start_idx:end_idx],
                'attention': np.mean(part_attns, axis=0)
            }

        return part_attention

    def compare_prompt_templates(self, templates, tokenizer):
        """Tokenize each template, run the model, and summarise its attention.

        Args:
            templates: ``{template_name: prompt_text}``.
            tokenizer: callable returning a mapping with ``'input_ids'`` when
                called with ``return_tensors="pt"``; must also provide
                ``convert_ids_to_tokens``.

        Returns:
            ``{template_name: {'tokens': [...], 'statistics': {...}}}``.
        """
        results = {}

        for template_name, template in templates.items():
            inputs = tokenizer(template, return_tensors="pt")
            attention_weights = self._extract_attention(**inputs)

            results[template_name] = {
                'tokens': tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]),
                'statistics': self._compute_attention_stats(attention_weights)
            }

        return results

    def _compute_attention_stats(self, attention_weights):
        """Average per-layer entropy, max weight and coverage over all layers.

        Coverage is the fraction of weights above 0.01.

        Returns:
            ``{'mean_entropy': ..., 'max_attention': ..., 'attention_coverage': ...}``.
        """
        stats = {
            'mean_entropy': [],
            'max_attention': [],
            'attention_coverage': []
        }

        for attn in attention_weights.values():
            attn = self._to_2d(attn)

            # Shift by a tiny constant so log(0) never occurs.
            attn_plus = attn + 1e-10
            entropy = -np.sum(attn_plus * np.log(attn_plus), axis=-1)
            stats['mean_entropy'].append(np.mean(entropy))

            stats['max_attention'].append(np.max(attn))

            stats['attention_coverage'].append(np.mean(attn > 0.01))

        return {k: np.mean(v) for k, v in stats.items()}

3. 多语言注意力分析

class MultilingualAttentionAnalyzer:
    """Compare attention statistics of the same model across texts in different languages."""

    def __init__(self, model):
        self.model = model
        self.visualizer = AttentionVisualizer()

    def compare_languages(self, texts_dict, tokenizer):
        """Run the model on each text and analyse the captured attention.

        Args:
            texts_dict: ``{language_label: text}``.
            tokenizer: callable returning a mapping with ``'input_ids'``; must
                also provide ``convert_ids_to_tokens``.

        Returns:
            ``{language_label: {'text': ..., 'tokens': [...], 'analysis': {...}}}``.
        """
        results = {}

        for lang, text in texts_dict.items():
            inputs = tokenizer(text, return_tensors="pt", padding=True)

            extractor = AttentionExtractor(self.model)
            extractor.register_hooks()
            try:
                with torch.no_grad():
                    self.model(**inputs)
                attention_weights = dict(extractor.get_attention_weights())
            finally:
                # Detach the hooks even if the forward pass raises.
                extractor.remove_hooks()

            analysis = self._analyze_attention(attention_weights, inputs, tokenizer)

            results[lang] = {
                'text': text,
                'tokens': tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]),
                'analysis': analysis
            }

        return results

    def _analyze_attention(self, attention_weights, inputs, tokenizer, top_k=5):
        """Rank tokens by received attention and summarise each layer.

        Args:
            attention_weights: ``{layer_name: attention}``; 4-D
                ``[batch, heads, seq, seq]`` weights use the first batch item
                with heads averaged.
            inputs: mapping with ``'input_ids'``.
            tokenizer: provides ``convert_ids_to_tokens``.
            top_k: how many of the most-attended tokens to keep per layer
                (default 5, the former hard-coded value).

        Returns:
            ``{'token_importance': {layer: [(token, score), ...]},
            'attention_patterns': {layer: {'entropy': ..., 'sparsity': ...}}}``.
        """
        analysis = {
            'token_importance': {},
            'attention_patterns': {}
        }

        tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        for layer_name, attn in attention_weights.items():
            if isinstance(attn, torch.Tensor):
                attn = attn.detach().cpu().numpy()

            if len(attn.shape) == 4:
                attn = attn[0].mean(axis=0)

            # Importance of a token = mean attention it receives (column mean).
            token_importance = np.mean(attn, axis=0)

            important_tokens = sorted(
                zip(tokens, token_importance),
                key=lambda x: x[1],
                reverse=True
            )

            analysis['token_importance'][layer_name] = important_tokens[:top_k]
            analysis['attention_patterns'][layer_name] = {
                'entropy': self._calculate_entropy(attn),
                'sparsity': np.mean(attn < 0.01)
            }

        return analysis

    def _calculate_entropy(self, attention_matrix):
        """Mean per-row Shannon entropy (natural log); 1e-10 is added to avoid log(0)."""
        attention_matrix = attention_matrix + 1e-10
        entropy = -np.sum(attention_matrix * np.log(attention_matrix), axis=-1)
        return np.mean(entropy)

实际应用案例

案例:LLM注意力分析系统

# LLM注意力分析系统
class LLMAttentionAnalysisSystem:
    """End-to-end pipeline: tokenize text, capture attention, plot it and build a report."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.analyzer = TransformerAttentionAnalyzer(model)
        self.visualizer = AttentionVisualizer()

    def analyze_text(self, text):
        """Tokenize *text*, capture attention and compute per-layer pattern metrics.

        Returns:
            ``{'tokens': [...], 'attention_weights': {...}, 'patterns': {...}}``.
        """
        inputs = self.tokenizer(text, return_tensors="pt")
        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        attention_weights = self.analyzer.analyze_input(inputs['input_ids'], tokens)
        patterns = self.analyzer.analyze_attention_patterns(attention_weights)

        return {
            'tokens': tokens,
            'attention_weights': attention_weights,
            'patterns': patterns
        }

    def visualize_analysis(self, analysis, save_dir=None):
        """Plot one attention heatmap per captured layer.

        Args:
            analysis: output of analyze_text (needs ``'tokens'`` and
                ``'attention_weights'``).
            save_dir: optional directory. When given, each figure is also
                written there as ``<layer_name>.png``; the directory is created
                if missing.

        Returns:
            ``{layer_name: matplotlib Figure}``.
        """
        from pathlib import Path

        tokens = analysis['tokens']
        attention_weights = analysis['attention_weights']

        out_dir = None
        if save_dir is not None:
            out_dir = Path(save_dir)
            out_dir.mkdir(parents=True, exist_ok=True)

        visualizations = {}

        for layer_name, attn in attention_weights.items():
            if isinstance(attn, torch.Tensor):
                attn = attn.cpu().numpy()

            # [batch, heads, seq, seq] -> first batch item, heads averaged.
            if len(attn.shape) == 4:
                attn = attn[0].mean(axis=0)

            fig = self.visualizer.plot_attention_heatmap(
                attn, tokens, title=f"Layer: {layer_name}"
            )

            if out_dir is not None:
                # Replace path separators so the layer name is always a single
                # file name; the root module's name is the empty string.
                safe_name = layer_name.replace('/', '_').replace('\\', '_') or 'model'
                fig.savefig(out_dir / f"{safe_name}.png")

            visualizations[layer_name] = fig

        return visualizations

    def generate_report(self, analysis):
        """Summarise the per-layer pattern metrics of analyze_text output.

        Returns:
            ``{'summary': {...averages...}, 'layer_analysis': {layer: metrics},
            'recommendations': [str, ...]}``. Recommendations are triggered by
            fixed thresholds on the averages (entropy < 1.0, sparsity > 0.8,
            locality < 0.3).
        """
        patterns = analysis['patterns']

        avg_entropy = np.mean([p['entropy'] for p in patterns.values()])
        avg_sparsity = np.mean([p['sparsity'] for p in patterns.values()])
        avg_locality = np.mean([p['locality'] for p in patterns.values()])

        report = {
            'summary': {
                'average_entropy': avg_entropy,
                'average_sparsity': avg_sparsity,
                'average_locality': avg_locality
            },
            # Previously declared but never filled: expose per-layer metrics
            # (copied so editing the report cannot alter the analysis).
            'layer_analysis': {name: dict(metrics) for name, metrics in patterns.items()},
            'recommendations': []
        }

        if avg_entropy < 1.0:
            report['recommendations'].append("注意力过于集中,可能影响泛化能力")

        if avg_sparsity > 0.8:
            report['recommendations'].append("注意力非常稀疏,考虑使用稀疏注意力机制")

        if avg_locality < 0.3:
            report['recommendations'].append("注意力局部性较弱,模型可能难以捕捉长距离依赖")

        return report

# 使用示例
# 假设已有模型和分词器
# analyzer_system = LLMAttentionAnalysisSystem(model, tokenizer)
# analysis = analyzer_system.analyze_text("这是一个测试句子。")
# visualizations = analyzer_system.visualize_analysis(analysis)
# report = analyzer_system.generate_report(analysis)

案例:注意力质量评估

# 注意力质量评估
def evaluate_attention_quality(attention_weights, tokens,
                               expected_patterns=None):
    """Compute per-layer attention quality metrics and an overall score.

    Args:
        attention_weights: ``{layer_name: attention}``. Tensors are converted to
            NumPy, and 4-D ``[batch, heads, seq, seq]`` inputs are reduced to the
            first batch item with heads averaged.
        tokens: token strings; currently unused, kept for API compatibility.
        expected_patterns: optional ``{layer_name: pattern}``. A
            ``'pattern_match'`` metric is added only for layers present in the
            mapping, so a partial mapping no longer raises ``KeyError``.

    Returns:
        ``{'layer_metrics': {layer_name: metrics}, 'overall_score': score}``.
    """
    quality_metrics = {}

    for layer_name, attn in attention_weights.items():
        if isinstance(attn, torch.Tensor):
            attn = attn.detach().cpu().numpy()

        if len(attn.shape) == 4:
            attn = attn[0].mean(axis=0)

        metrics = {
            'entropy': calculate_entropy(attn),
            'sparsity': calculate_sparsity(attn),
            'locality': calculate_locality(attn),
            'focus_consistency': calculate_focus_consistency(attn)
        }

        if expected_patterns and layer_name in expected_patterns:
            metrics['pattern_match'] = calculate_pattern_match(
                attn, expected_patterns[layer_name]
            )

        quality_metrics[layer_name] = metrics

    overall_score = calculate_overall_quality_score(quality_metrics)

    return {
        'layer_metrics': quality_metrics,
        'overall_score': overall_score
    }

def calculate_entropy(attention_matrix):
    """Return the mean Shannon entropy (natural log) of the rows of *attention_matrix*.

    Every weight is shifted by 1e-10 before taking the log so that zero weights
    never produce ``log(0)``.
    """
    smoothed = attention_matrix + 1e-10
    per_row = np.sum(smoothed * np.log(smoothed), axis=-1)
    return np.mean(-per_row)

def calculate_sparsity(attention_matrix, threshold=0.01):
    """Return the fraction of attention weights that are strictly below *threshold*."""
    is_negligible = attention_matrix < threshold
    return is_negligible.mean()

def calculate_locality(attention_matrix):
    """Score how close to the diagonal each query row's attention mass lies.

    For row ``i`` the attention-weighted mean distance ``sum_j |j - i| * a[i, j]``
    is mapped to ``1 / (1 + distance)``. The result is the mean over rows, so
    1.0 means every row attends only to its own position.
    """
    n_rows = attention_matrix.shape[0]
    positions = np.arange(n_rows)
    scores = [
        1.0 / (1.0 + np.sum(np.abs(positions - row_idx) * attention_matrix[row_idx]))
        for row_idx in range(n_rows)
    ]
    return np.mean(scores)

def calculate_focus_consistency(attention_matrix, top_k=3):
    """Score how strongly all queries attend to each query's favourite keys.

    For every row ``i``, the ``top_k`` keys with the largest weight in that row
    are selected. The mean weight that *all* rows assign to those keys is then
    taken, and the result is the average of these per-row values.

    Args:
        attention_matrix: 2-D array (query rows x key columns).
        top_k: number of highest-weighted keys per row (default 3, the former
            hard-coded value). Rows with fewer keys use all of them.

    Returns:
        The mean consistency score.

    Raises:
        ValueError: if ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        # argsort(...)[-0:] would silently select *every* key.
        raise ValueError("top_k must be at least 1")

    consistency_scores = []
    for row in attention_matrix:
        top_indices = np.argsort(row)[-top_k:]
        consistency_scores.append(np.mean(attention_matrix[:, top_indices]))

    return np.mean(consistency_scores)

def calculate_pattern_match(attention_matrix, expected_pattern):
    """Return how well *attention_matrix* matches *expected_pattern*.

    NOTE: placeholder. Both arguments are ignored and a fixed score of 0.8 is
    returned; a real matching algorithm still needs to be implemented.
    """
    placeholder_score = 0.8
    return placeholder_score

def calculate_overall_quality_score(quality_metrics):
    """Average a heuristic per-layer quality score over all layers.

    Each layer earns 0.3 if its entropy lies strictly between 1.0 and 3.0, and
    0.3 if its sparsity lies strictly between 0.3 and 0.7 (moderate values are
    preferred). It also earns ``0.4 * locality`` (higher locality is better).
    """
    def layer_score(metrics):
        entropy_bonus = 0.3 if 1.0 < metrics['entropy'] < 3.0 else 0
        sparsity_bonus = 0.3 if 0.3 < metrics['sparsity'] < 0.7 else 0
        return entropy_bonus + sparsity_bonus + 0.4 * metrics['locality']

    return np.mean([layer_score(metrics) for metrics in quality_metrics.values()])

总结

注意力可视化是理解和解释LLM的重要工具:

  1. 模型理解 - 帮助理解模型如何处理输入
  2. 调试工具 - 诊断模型行为异常
  3. 提示优化 - 指导提示工程改进
  4. 多语言分析 - 比较不同语言的处理方式
  5. 质量评估 - 评估注意力机制的质量

通过注意力可视化,我们可以更好地理解LLM的内部工作机制,提高模型的可解释性和可控性。需要注意的是,注意力权重只反映信息在token之间的分配情况,并不等同于对模型决策的因果解释,解读时宜结合梯度归因等其他可解释性方法交叉验证。