注意力机制深入解析
--- title: "注意力机制深入解析" description: "从数学原理到代码实现,全面理解Transformer中的注意力机制" tags: ["注意力机制", "Self-Attention", "Transformer", "深度学习"] category: "llm" icon: "🧠"
注意力机制深入解析
注意力机制的直觉
注意力机制的核心思想可以用一个简单的比喻来理解:当你阅读一篇文章时,你的大脑并不会对每个词给予相同的关注,而是会根据当前的任务,将注意力集中在相关的词上。
例如,当理解句子"小明把书放在桌子上"时:
- 如果问"谁放了书?",注意力会集中在"小明"上
- 如果问"书放在了哪里?",注意力会集中在"桌子上"
注意力机制的数学表达
基本形式
注意力函数可以看作是一个查询(Query)到一组键值对(Key-Value pairs)的映射:
Attention(Q, K, V) = softmax(QK^T / √d_k) V
各组件的含义:
- Q(Query):当前要关注的位置的表示
- K(Key):每个位置的索引表示
- V(Value):每个位置的内容表示
- d_k:键向量的维度,用于缩放
缩放点积注意力
import torch
import torch.nn.functional as F
import math
def scaled_dot_product_attention(Q, K, V, mask=None):
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, V)
return output, attention_weights
# 示例
batch_size, seq_len, d_model = 2, 5, 64
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)
output, weights = scaled_dot_product_attention(Q, K, V)
print(f"输出形状: {output.shape}") # [2, 5, 64]
print(f"注意力权重形状: {weights.shape}") # [2, 5, 5]
注意力权重的可视化
import matplotlib.pyplot as plt
import seaborn as sns
def visualize_attention(tokens, attention_weights, head_idx=0):
"""可视化注意力权重"""
weights = attention_weights[0, head_idx].detach().numpy()
plt.figure(figsize=(8, 6))
sns.heatmap(
weights,
xticklabels=tokens,
yticklabels=tokens,
annot=True,
fmt='.2f',
cmap='Blues'
)
plt.title(f'Head {head_idx} Attention')
plt.xlabel('Key')
plt.ylabel('Query')
plt.tight_layout()
plt.savefig('attention可视化.png', dpi=150)
plt.show()
# 使用示例
tokens = ['我', '喜欢', '学习', '自然', '语言']
# attention_weights 来自模型计算
# visualize_attention(tokens, attention_weights)
注意力的不同类型
1. 自注意力(Self-Attention)
自注意力是 Transformer 中最常用的注意力形式,其中 Q、K、V 都来自同一个输入序列。
class SelfAttention(nn.Module):
def __init__(self, d_model):
super().__init__()
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
def forward(self, x):
Q = self.W_q(x)
K = self.W_k(x)
V = self.W_v(x)
return scaled_dot_product_attention(Q, K, V)
2. 交叉注意力(Cross-Attention)
在编码器-解码器模型中,解码器通过交叉注意力关注编码器的输出。
class CrossAttention(nn.Module):
def __init__(self, d_model):
super().__init__()
self.W_q = nn.Linear(d_model, d_model) # 来自解码器
self.W_k = nn.Linear(d_model, d_model) # 来自编码器
self.W_v = nn.Linear(d_model, d_model) # 来自编码器
def forward(self, decoder_input, encoder_output):
Q = self.W_q(decoder_input)
K = self.W_k(encoder_output)
V = self.W_v(encoder_output)
return scaled_dot_product_attention(Q, K, V)
3. 掩码注意力(Masked Attention)
在自回归生成时,模型不能看到未来的位置,需要使用掩码。
def create_causal_mask(seq_len):
"""创建因果掩码(下三角矩阵)"""
mask = torch.tril(torch.ones(seq_len, seq_len))
return mask.unsqueeze(0).unsqueeze(0)
# 示例
mask = create_causal_mask(5)
print(mask)
# tensor([[[[1., 0., 0., 0., 0.],
# [1., 1., 0., 0., 0.],
# [1., 1., 1., 0., 0.],
# [1., 1., 1., 1., 0.],
# [1., 1., 1., 1., 1.]]]])
注意力的计算复杂度
标准自注意力的计算复杂度为 O(n²d),其中 n 是序列长度,d 是模型维度。这意味着:
| 序列长度 | 计算量(相对) | 内存占用 |
|---|---|---|
| 128 | 1x | 1x |
| 256 | 4x | 4x |
| 512 | 16x | 16x |
| 1024 | 64x | 64x |
这解释了为什么处理长序列时需要特殊的优化技术。
注意力机制的变体
为了提高效率和性能,研究者提出了多种注意力变体:
- 稀疏注意力:只关注部分位置,降低复杂度
- 线性注意力:使用核函数近似,将复杂度降至 O(n)
- 分组查询注意力(GQA):多头共享键值,减少计算量
- Flash Attention:优化内存访问模式,提升实际速度
实际应用示例
import torch.nn as nn
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model=512, num_heads=8, d_ff=2048):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, mask=None):
# 自注意力 + 残差连接 + 层归一化
attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)
x = self.norm1(x + attn_output)
# FFN + 残差连接 + 层归一化
ffn_output = self.ffn(x)
x = self.norm2(x + ffn_output)
return x
# 使用
layer = TransformerEncoderLayer()
x = torch.randn(2, 10, 512) # batch=2, seq_len=10
output = layer(x)
print(f"输出形状: {output.shape}") # [2, 10, 512]
总结
注意力机制是 Transformer 和大语言模型的核心组件。理解其数学原理、不同类型和计算复杂度,对于优化模型性能和设计新的架构至关重要。在接下来的文章中,我们将探讨 GPT 系列模型如何利用注意力机制。