← 返回首页
🧠

注意力机制深入解析

📂 llm ⏱ 3 min 428 words

--- title: "注意力机制深入解析" description: "从数学原理到代码实现,全面理解Transformer中的注意力机制" tags: ["注意力机制", "Self-Attention", "Transformer", "深度学习"] category: "llm" icon: "🧠"

注意力机制深入解析

注意力机制的直觉

注意力机制的核心思想可以用一个简单的比喻来理解:当你阅读一篇文章时,你的大脑并不会对每个词给予相同的关注,而是会根据当前的任务,将注意力集中在相关的词上。

例如,当理解句子"小明把书放在桌子上"时:

注意力机制的数学表达

基本形式

注意力函数可以看作是一个查询(Query)到一组键值对(Key-Value pairs)的映射:

Attention(Q, K, V) = softmax(QK^T / √d_k) V

各组件的含义:

缩放点积注意力

import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, V)
    return output, attention_weights

# 示例
batch_size, seq_len, d_model = 2, 5, 64
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

output, weights = scaled_dot_product_attention(Q, K, V)
print(f"输出形状: {output.shape}")  # [2, 5, 64]
print(f"注意力权重形状: {weights.shape}")  # [2, 5, 5]

注意力权重的可视化

import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(tokens, attention_weights, head_idx=0):
    """可视化注意力权重"""
    weights = attention_weights[0, head_idx].detach().numpy()
    
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        weights, 
        xticklabels=tokens, 
        yticklabels=tokens,
        annot=True, 
        fmt='.2f',
        cmap='Blues'
    )
    plt.title(f'Head {head_idx} Attention')
    plt.xlabel('Key')
    plt.ylabel('Query')
    plt.tight_layout()
    plt.savefig('attention可视化.png', dpi=150)
    plt.show()

# 使用示例
tokens = ['我', '喜欢', '学习', '自然', '语言']
# attention_weights 来自模型计算
# visualize_attention(tokens, attention_weights)

注意力的不同类型

1. 自注意力(Self-Attention)

自注意力是 Transformer 中最常用的注意力形式,其中 Q、K、V 都来自同一个输入序列。

class SelfAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        return scaled_dot_product_attention(Q, K, V)

2. 交叉注意力(Cross-Attention)

在编码器-解码器模型中,解码器通过交叉注意力关注编码器的输出。

class CrossAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model)  # 来自解码器
        self.W_k = nn.Linear(d_model, d_model)  # 来自编码器
        self.W_v = nn.Linear(d_model, d_model)  # 来自编码器
    
    def forward(self, decoder_input, encoder_output):
        Q = self.W_q(decoder_input)
        K = self.W_k(encoder_output)
        V = self.W_v(encoder_output)
        return scaled_dot_product_attention(Q, K, V)

3. 掩码注意力(Masked Attention)

在自回归生成时,模型不能看到未来的位置,需要使用掩码。

def create_causal_mask(seq_len):
    """创建因果掩码(下三角矩阵)"""
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.unsqueeze(0).unsqueeze(0)

# 示例
mask = create_causal_mask(5)
print(mask)
# tensor([[[[1., 0., 0., 0., 0.],
#           [1., 1., 0., 0., 0.],
#           [1., 1., 1., 0., 0.],
#           [1., 1., 1., 1., 0.],
#           [1., 1., 1., 1., 1.]]]])

注意力的计算复杂度

标准自注意力的计算复杂度为 O(n²d),其中 n 是序列长度,d 是模型维度。这意味着:

序列长度 计算量(相对) 内存占用
128 1x 1x
256 4x 4x
512 16x 16x
1024 64x 64x

这解释了为什么处理长序列时需要特殊的优化技术。

注意力机制的变体

为了提高效率和性能,研究者提出了多种注意力变体:

  1. 稀疏注意力:只关注部分位置,降低复杂度
  2. 线性注意力:使用核函数近似,将复杂度降至 O(n)
  3. 分组查询注意力(GQA):多头共享键值,减少计算量
  4. Flash Attention:优化内存访问模式,提升实际速度

实际应用示例

import torch.nn as nn

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model=512, num_heads=8, d_ff=2048):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
    
    def forward(self, x, mask=None):
        # 自注意力 + 残差连接 + 层归一化
        attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)
        x = self.norm1(x + attn_output)
        
        # FFN + 残差连接 + 层归一化
        ffn_output = self.ffn(x)
        x = self.norm2(x + ffn_output)
        return x

# 使用
layer = TransformerEncoderLayer()
x = torch.randn(2, 10, 512)  # batch=2, seq_len=10
output = layer(x)
print(f"输出形状: {output.shape}")  # [2, 10, 512]

总结

注意力机制是 Transformer 和大语言模型的核心组件。理解其数学原理、不同类型和计算复杂度,对于优化模型性能和设计新的架构至关重要。在接下来的文章中,我们将探讨 GPT 系列模型如何利用注意力机制。