Attention Mechanisms in Depth: From Bahdanau to Self-Attention
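What Is Attention?
The core idea of attention is that, when processing information, a model should learn to focus on the most relevant parts of its input instead of treating every input equally.
People read the same way: rather than spreading attention evenly over every word, we concentrate on the key information.
The Evolution of Attention Mechanisms
1. The Seq2Seq bottleneck
A traditional encoder-decoder model compresses the entire input sequence into a single fixed-length vector, which creates an information bottleneck: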
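```python
# The Seq2Seq bottleneck (schematic: encoder and decoder are placeholders)
encoder_output = encoder(input_sequence)  # the whole input compressed into one fixed-length vector
decoder_output = decoder(encoder_output)  # the decoder sees only that vector, so information is lost
```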
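The following sketch makes the bottleneck concrete. It assumes a single-layer GRU encoder with arbitrary sizes, because the text does not fix the encoder type. However long the input is, the decoder receives a context vector of the same fixed size.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Assumed encoder: a single-layer GRU over 16-dim inputs with a 32-dim hidden state
encoder = nn.GRU(input_size=16, hidden_size=32, batch_first=True)

def encode_to_context(input_sequence):
    # input_sequence: (batch, seq_len, 16)
    _, h_n = encoder(input_sequence)  # h_n: (num_layers, batch, hidden_size)
    return h_n[-1]                    # final hidden state serves as the context vector

short_ctx = encode_to_context(torch.randn(1, 5, 16))
long_ctx = encode_to_context(torch.randn(1, 500, 16))
print(short_ctx.shape, long_ctx.shape)  # torch.Size([1, 32]) torch.Size([1, 32])
```

Five tokens or five hundred, everything the decoder knows about the source must fit in those 32 numbers.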
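2. Bahdanau attention (2015)
Bahdanau et al. proposed that the decoder, at every decoding step, dynamically attend to different positions of the encoder outputs instead of relying on a single fixed vector: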
```python
import torch
import torch.nn as nn

class BahdanauAttention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.W_encoder = nn.Linear(encoder_dim, attention_dim)
        self.W_decoder = nn.Linear(decoder_dim, attention_dim)
        self.v = nn.Linear(attention_dim, 1)

    def forward(self, encoder_outputs, decoder_hidden):
        # encoder_outputs: (batch, seq_len, encoder_dim)
        # decoder_hidden: (batch, decoder_dim)
        encoder_proj = self.W_encoder(encoder_outputs)  # (batch, seq_len, attention_dim)
        decoder_proj = self.W_decoder(decoder_hidden).unsqueeze(1)  # (batch, 1, attention_dim)
        # broadcast the decoder state over every source position, then score each position
        attention_scores = self.v(torch.tanh(encoder_proj + decoder_proj))  # (batch, seq_len, 1)
        attention_weights = torch.softmax(attention_scores, dim=1)  # normalize over seq_len
        # context vector: attention-weighted sum of the encoder outputs
        context = torch.sum(attention_weights * encoder_outputs, dim=1)  # (batch, encoder_dim)
        return context, attention_weights
```
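The forward pass above computes the following additive score. Here h_j is the j-th encoder output, s is `decoder_hidden` (in the original paper, the previous decoder state), T is the source length, and the biases of the `nn.Linear` layers are omitted for readability:

```latex
e_j = v^\top \tanh\left(W_{\text{enc}}\, h_j + W_{\text{dec}}\, s\right), \qquad
\alpha_j = \frac{\exp(e_j)}{\sum_{j'=1}^{T} \exp(e_{j'})}, \qquad
c = \sum_{j=1}^{T} \alpha_j\, h_j
```

In code, `attention_scores` holds the e_j, `attention_weights` holds the α_j, and `context` is c.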
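3. Luong attention (2015)
Luong et al. proposed simpler, multiplicative scoring functions. The code below implements their "general" score, h_t · (W h_s), where h_t is the current decoder state and h_s is an encoder output: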
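```python
import torch
import torch.nn as nn

class LuongAttention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim):
        super().__init__()
        self.W = nn.Linear(encoder_dim, decoder_dim)

    def forward(self, encoder_outputs, decoder_hidden):
        # encoder_outputs: (batch, seq_len, encoder_dim); decoder_hidden: (batch, decoder_dim)
        # attention scores: dot product of the decoder state with each projected encoder output
        encoder_proj = self.W(encoder_outputs)  # (batch, seq_len, decoder_dim)
        scores = torch.bmm(encoder_proj, decoder_hidden.unsqueeze(2))  # (batch, seq_len, 1)
        attention_weights = torch.softmax(scores, dim=1)  # normalize over seq_len
        # (batch, 1, seq_len) @ (batch, seq_len, encoder_dim) -> (batch, 1, encoder_dim)
        context = torch.bmm(attention_weights.transpose(1, 2), encoder_outputs)
        return context.squeeze(1), attention_weights
```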
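Luong et al. described several scoring functions. Two of them are "dot" (h_t · h_s, which needs equal dimensions) and "general" (h_t · W h_s). The small functional sketch below uses a hypothetical `luong_score` that takes W as a plain matrix. It shows that "dot" is simply "general" with W fixed to the identity.

```python
import torch

def luong_score(decoder_hidden, encoder_outputs, method="general", W=None):
    # decoder_hidden: (batch, d); encoder_outputs: (batch, seq_len, d); W: (d, d)
    if method == "dot":
        keys = encoder_outputs        # score = h_t . h_s
    elif method == "general":
        keys = encoder_outputs @ W.T  # score = h_t . (W h_s)
    else:
        raise ValueError(f"unknown scoring method: {method}")
    return torch.bmm(keys, decoder_hidden.unsqueeze(2)).squeeze(2)  # (batch, seq_len)

torch.manual_seed(0)
h_t = torch.randn(2, 8)
enc = torch.randn(2, 5, 8)
dot = luong_score(h_t, enc, "dot")
general_identity = luong_score(h_t, enc, "general", W=torch.eye(8))
print(torch.allclose(dot, general_identity))  # True
weights = torch.softmax(dot, dim=1)  # (2, 5); each row sums to 1
```

Compared with Bahdanau's additive score, the multiplicative form needs no extra `tanh` layer or `v` vector. It reduces to a batched matrix product.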
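Self-Attention
Self-attention is the core of the Transformer. It lets every position in a sequence attend to every position in the same sequence, including itself.
Computation steps
- Linear projection: map the input to queries Q, keys K and values V
- Scoring: take the dot products of Q with K and divide by the square root of d_k
- Softmax normalization: turn the scores into attention weights
- Weighted sum: combine the rows of V using the attention weights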
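In matrix form, the input X is projected by learned matrices W_Q, W_K and W_V, and the four steps give the standard scaled dot-product attention formula. The second line sketches the usual justification for the scaling. It assumes the components of a query q and a key k are independent with mean 0 and variance 1:

```latex
Q = XW_Q, \quad K = XW_K, \quad V = XW_V, \qquad
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V

\operatorname{Var}(q \cdot k) = \sum_{i=1}^{d_k} \operatorname{Var}(q_i k_i) = d_k
\quad\Longrightarrow\quad
\operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = 1
```

Without the scaling, a large d_k would push the scores to large magnitudes, where softmax saturates and its gradients become very small.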
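```python
import torch

def self_attention(Q, K, V, mask=None):
    # Scaled dot-product attention (steps 2-4); Q, K, V are assumed to be already projected.
    # Q: (..., seq_q, d_k), K: (..., seq_k, d_k), V: (..., seq_k, d_v)
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)  # (..., seq_q, seq_k)
    if mask is not None:
        # positions where mask == 0 get a large negative score, i.e. ~0 weight after softmax
        scores = scores.masked_fill(mask == 0, -1e9)
    attention_weights = torch.softmax(scores, dim=-1)  # each query's weights sum to 1
    output = torch.matmul(attention_weights, V)  # (..., seq_q, d_v)
    return output, attention_weights
```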
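`self_attention` above covers steps 2-4 and expects Q, K and V as inputs. The minimal single-head sketch below also performs step 1 by projecting all three from the same input X. It adds an optional causal mask of the kind decoder-only models use. The class name, the bias-free projections and the `causal` flag are choices made for this example.

```python
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """Single-head self-attention: Q, K and V are all projected from the same input x."""

    def __init__(self, d_model, d_k):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_k, bias=False)
        self.W_k = nn.Linear(d_model, d_k, bias=False)
        self.W_v = nn.Linear(d_model, d_k, bias=False)

    def forward(self, x, causal=False):
        # x: (batch, seq_len, d_model)
        Q, K, V = self.W_q(x), self.W_k(x), self.W_v(x)        # step 1: linear projections
        scores = Q @ K.transpose(-2, -1) / Q.size(-1) ** 0.5  # step 2: scaled dot products
        if causal:
            seq_len = x.size(1)
            # lower-triangular mask: position i may only attend to positions <= i
            mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device))
            scores = scores.masked_fill(~mask, float("-inf"))
        weights = torch.softmax(scores, dim=-1)  # step 3: normalize over keys
        return weights @ V, weights              # step 4: weighted sum of values

torch.manual_seed(0)
attn = SelfAttention(d_model=16, d_k=8)
x = torch.randn(2, 4, 16)
out, w = attn(x, causal=True)
print(out.shape)  # torch.Size([2, 4, 8])
```

Using `-inf` here is safe because every row keeps at least its diagonal entry, so no row is fully masked.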
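Visualizing Attention
```python
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(attention_weights, tokens_x, tokens_y):
    # attention_weights: (len(tokens_y), len(tokens_x)); rows are queries, columns are keys
    if hasattr(attention_weights, "detach"):
        # accept a PyTorch tensor as well as a NumPy array
        attention_weights = attention_weights.detach().cpu().numpy()
    plt.figure(figsize=(10, 10))
    sns.heatmap(attention_weights,
                xticklabels=tokens_x,
                yticklabels=tokens_y,
                cmap='viridis')
    plt.xlabel('Key')
    plt.ylabel('Query')
    plt.title('Attention Weights')
    plt.show()
```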
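Multi-Head Attention
Multi-head attention runs several attention operations in parallel, each on its own learned projection. This lets the model attend jointly to information from different representation subspaces at different positions: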
```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # dimension per head
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)  # output projection after merging the heads

    def split_heads(self, x, batch_size):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        # mask (optional) must broadcast to (batch, num_heads, seq_q, seq_k); 0 = blocked
        batch_size = Q.size(0)
        # project, then split into heads
        Q = self.split_heads(self.W_q(Q), batch_size)
        K = self.split_heads(self.W_k(K), batch_size)
        V = self.split_heads(self.W_v(V), batch_size)
        # scaled dot-product attention, computed for all heads at once
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attention_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)  # (batch, num_heads, seq_q, d_k)
        # merge heads: (batch, num_heads, seq_q, d_k) -> (batch, seq_q, d_model)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_o(output), attention_weights
```
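Splitting into heads with `view` + `transpose` is a batched way of running independent attention on consecutive `d_k`-wide slices of the projected features. The sketch below checks that equivalence on random inputs. It assumes Q, K and V have already passed through `W_q`, `W_k` and `W_v`, and the function names are illustrative.

```python
import torch

def batched_heads(Q, K, V, num_heads):
    # Q, K, V: (batch, seq_len, d_model), already linearly projected
    b, n, d_model = Q.shape
    d_k = d_model // num_heads

    def split(x):
        return x.view(b, -1, num_heads, d_k).transpose(1, 2)  # (b, heads, seq, d_k)

    q, k, v = split(Q), split(K), split(V)
    w = torch.softmax(q @ k.transpose(-2, -1) / d_k ** 0.5, dim=-1)
    return (w @ v).transpose(1, 2).reshape(b, n, d_model)  # merge heads

def looped_heads(Q, K, V, num_heads):
    # same computation, one head at a time on its own d_k-wide feature slice
    d_k = Q.size(-1) // num_heads
    outs = []
    for h in range(num_heads):
        s = slice(h * d_k, (h + 1) * d_k)
        q, k, v = Q[..., s], K[..., s], V[..., s]
        w = torch.softmax(q @ k.transpose(-2, -1) / d_k ** 0.5, dim=-1)
        outs.append(w @ v)
    return torch.cat(outs, dim=-1)  # concatenate heads back to d_model

torch.manual_seed(0)
Q, K, V = torch.randn(2, 5, 12), torch.randn(2, 5, 12), torch.randn(2, 5, 12)
print(torch.allclose(batched_heads(Q, K, V, 3), looped_heads(Q, K, V, 3), atol=1e-5))  # True
```

The batched form is preferred because it runs all heads in a few large matrix multiplications instead of a Python loop.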
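Attention Variants
1. Sparse attention
Compute scores for only a subset of query-key pairs, reducing the quadratic cost of full attention:
- Local attention: each position attends only to positions within a fixed window around it
- Global attention: a few special tokens attend to all positions and are attended to by all positions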
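The combined local + global pattern can be expressed as a boolean mask, sketched below with a hypothetical helper. The mask can be passed as `mask` to `self_attention`, where entries equal to 0 are blocked. A dense mask only illustrates which pairs are allowed. Real sparse-attention implementations save compute by never evaluating the blocked pairs.

```python
import torch

def local_global_mask(seq_len, window, global_positions=()):
    """Boolean mask (True = may attend): each query sees keys within `window`
    positions of itself; tokens in `global_positions` attend to, and are
    attended to by, every position."""
    idx = torch.arange(seq_len)
    mask = (idx[:, None] - idx[None, :]).abs() <= window  # local band
    for g in global_positions:
        mask[g, :] = True  # the global token attends everywhere
        mask[:, g] = True  # every token attends to the global token
    return mask

mask = local_global_mask(6, window=1, global_positions=[0])
print(mask[3].int().tolist())  # [1, 0, 1, 1, 1, 0]
```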
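2. Linear attention
Replace or approximate softmax attention with kernel feature maps so that the computation can be reordered. This reduces the complexity in the sequence length n from O(n²) to O(n):
- Performer: approximates the softmax kernel with random features
- Linear Transformer: replaces softmax with a kernel feature map, elu(x) + 1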
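A few lines are enough to show the reordering behind linear attention. The sketch below assumes the non-causal case and the elu(x) + 1 feature map of the Linear Transformer. It computes the key-value summary once, in time linear in n, and checks the result against the same kernel evaluated with an explicit n × n matrix. It does not reproduce Performer's random-feature approximation of softmax, and the function names are illustrative.

```python
import torch
import torch.nn.functional as F

def feature_map(x):
    # phi(x) = elu(x) + 1 keeps all features positive
    return F.elu(x) + 1

def linear_attention(Q, K, V, eps=1e-6):
    # Q, K: (batch, n, d_k); V: (batch, n, d_v); non-causal variant
    q, k = feature_map(Q), feature_map(K)
    kv = k.transpose(-2, -1) @ V                       # (batch, d_k, d_v), summed over n once
    z = k.sum(dim=1)                                   # (batch, d_k): sum_j phi(k_j)
    num = q @ kv                                       # (batch, n, d_v)
    den = (q * z.unsqueeze(1)).sum(-1, keepdim=True)   # (batch, n, 1): phi(q_i) . z
    return num / (den + eps)

def quadratic_reference(Q, K, V):
    # the same kernel computed the O(n^2) way, with an explicit n x n similarity matrix
    sim = feature_map(Q) @ feature_map(K).transpose(-2, -1)
    return (sim / sim.sum(-1, keepdim=True)) @ V

torch.manual_seed(0)
Q, K, V = torch.randn(2, 10, 8), torch.randn(2, 10, 8), torch.randn(2, 10, 4)
print(torch.allclose(linear_attention(Q, K, V), quadratic_reference(Q, K, V), atol=1e-5))  # True
```

The saving comes from associativity: (φ(Q)φ(K)ᵀ)V equals φ(Q)(φ(K)ᵀV), and the second form never builds an n × n matrix.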
3. Flash Attention
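Computes exact attention (not an approximation) block by block. Tiles are kept in fast on-chip GPU memory, and the full n × n attention matrix is never materialized:
- Far fewer reads and writes to GPU high-bandwidth memory
- Memory use grows linearly rather than quadratically with sequence length, which makes longer sequences practical
- Training speedups of roughly 2-4× have been reported on some benchmarks; the gain depends on the model and sequence length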
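FlashAttention itself is a fused GPU kernel that tiles Q, K and V through on-chip SRAM. The sketch below reproduces only its core numerical trick, an online softmax over blocks of keys, in plain PyTorch on CPU, and it gives no speedup. The block size is an arbitrary choice for the example.

```python
import torch

def tiled_attention(Q, K, V, block_size=4):
    """Exact softmax attention computed over key/value blocks with an online
    softmax, so the full (n_q x n_k) score matrix is never materialized."""
    d_k = Q.size(-1)
    out = torch.zeros(*Q.shape[:-1], V.size(-1), dtype=Q.dtype)
    row_max = torch.full(Q.shape[:-1], float("-inf"), dtype=Q.dtype)  # running max per query
    row_sum = torch.zeros(Q.shape[:-1], dtype=Q.dtype)               # running softmax denominator
    for start in range(0, K.size(-2), block_size):
        k_blk = K[..., start:start + block_size, :]
        v_blk = V[..., start:start + block_size, :]
        scores = Q @ k_blk.transpose(-2, -1) / d_k ** 0.5  # (..., n_q, block)
        new_max = torch.maximum(row_max, scores.amax(dim=-1))
        scale = torch.exp(row_max - new_max)               # rescales earlier partial results
        p = torch.exp(scores - new_max.unsqueeze(-1))
        row_sum = row_sum * scale + p.sum(dim=-1)
        out = out * scale.unsqueeze(-1) + p @ v_blk
        row_max = new_max
    return out / row_sum.unsqueeze(-1)

torch.manual_seed(0)
Q, K, V = torch.randn(1, 6, 8), torch.randn(1, 6, 8), torch.randn(1, 6, 8)
reference = torch.softmax(Q @ K.transpose(-2, -1) / 8 ** 0.5, dim=-1) @ V
print(torch.allclose(tiled_attention(Q, K, V, block_size=4), reference, atol=1e-5))  # True
```

Keeping a running maximum and rescaling the partial sums gives the same result as the full softmax, while only one block of scores exists at a time.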
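Applications of Attention
1. Text generation
In GPT-style models, causal (masked) self-attention lets each position attend only to earlier tokens, and the next token is predicted from the last position: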
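```python
import torch

# Attention during text generation (sketch: tokenize, detokenize and sample_logits are
# assumed helpers, and model is assumed to return (logits, attention))
@torch.no_grad()
def generate_text(model, prompt, max_length=100):
    tokens = tokenize(prompt)  # (batch, seq_len) tensor of token ids
    for _ in range(max_length):  # max_length = number of new tokens to generate
        # the forward pass also returns the attention weights
        logits, attention = model(tokens)
        # the attention weights can be inspected to help interpret the model
        next_token = sample_logits(logits[:, -1, :])  # (batch,) sampled at the last position
        tokens = torch.cat([tokens, next_token.unsqueeze(1)], dim=1)
    return detokenize(tokens)
```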
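`sample_logits` is used above but never defined. One plausible version is sketched below. It assumes temperature sampling with optional top-k filtering, and the name and parameters are illustrative, not a library API.

```python
import torch

def sample_logits(logits, temperature=1.0, top_k=None):
    """Hypothetical helper: sample one token id per row of `logits` (batch, vocab)."""
    if temperature <= 0:
        return logits.argmax(dim=-1)  # greedy decoding
    logits = logits / temperature     # < 1 sharpens, > 1 flattens the distribution
    if top_k is not None:
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]  # k-th largest per row
        logits = logits.masked_fill(logits < kth, float("-inf"))      # keep only the top k
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)  # (batch,)
```

It returns a `(batch,)` tensor, which matches the `next_token.unsqueeze(1)` in the generation loop.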
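2. Machine translation
In translation, the decoder's attention over the encoder outputs helps the model align source-language and target-language tokens: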
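```python
# Attention alignments in machine translation (sketch: model.encode, model.decode and
# sample_logits are assumed interfaces, not a specific library's API)
def translate_with_attention(model, source, bos_id, eos_id, max_length=100):
    encoder_output = model.encode(source)
    translations = [bos_id]  # target tokens generated so far, starting from BOS
    attentions = []
    for _ in range(max_length):
        # the decoder attends over encoder_output, conditioned on the tokens so far
        decoder_output, attention = model.decode(translations, encoder_output)
        attentions.append(attention)  # source-position weights for this step
        next_token = sample_logits(decoder_output)  # assumed to return a single token id
        translations.append(next_token)
        if next_token == eos_id:
            break
    return translations[1:], attentions  # drop the BOS token
```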
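The per-step weights collected in `attentions` can be turned into a word alignment. This minimal sketch assumes the weights for one sentence have been stacked into a (target length × source length) matrix, for example with `torch.stack`. `hard_alignment` is a hypothetical helper, and the weights are toy values chosen for illustration, not output from a trained model.

```python
import torch

def hard_alignment(attentions, source_tokens, target_tokens):
    """Pair each target token with the source token that received the most attention.

    attentions: (tgt_len, src_len) tensor, one row of weights per decoding step.
    """
    best = attentions.argmax(dim=-1)  # most-attended source position per step
    return [(tgt, source_tokens[i]) for tgt, i in zip(target_tokens, best.tolist())]

# Toy weights invented for illustration
source = ["le", "chat", "noir"]
target = ["the", "black", "cat"]
attn = torch.tensor([[0.8, 0.1, 0.1],
                     [0.1, 0.2, 0.7],
                     [0.1, 0.8, 0.1]])
alignment = hard_alignment(attn, source, target)
print(alignment)  # [('the', 'le'), ('black', 'noir'), ('cat', 'chat')]
```

Because French places "noir" after "chat", the alignment crosses, which is exactly the kind of reordering attention can capture.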
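3. Image understanding
The Vision Transformer (ViT) splits an image into fixed-size patches, treats each flattened patch as a token, and processes the resulting sequence with self-attention: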
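```python
import torch
import torch.nn as nn

class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, d_model, num_heads):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (image_size // patch_size) ** 2
        # each flattened RGB patch is linearly projected to d_model
        self.patch_embedding = nn.Linear(patch_size * patch_size * 3, d_model)
        # learnable [CLS] token; the +1 in the position embedding is its slot
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.position_embedding = nn.Parameter(torch.randn(1, num_patches + 1, d_model))
        self.attention = MultiHeadAttention(d_model, num_heads)  # as defined earlier

    def extract_patches(self, x):
        # (batch, 3, H, W) -> (batch, num_patches, patch_size * patch_size * 3)
        b, c, _, _ = x.shape
        p = self.patch_size
        patches = x.unfold(2, p, p).unfold(3, p, p)   # (b, c, H/p, W/p, p, p)
        patches = patches.permute(0, 2, 3, 1, 4, 5)   # (b, H/p, W/p, c, p, p)
        return patches.reshape(b, -1, c * p * p)

    def forward(self, x):
        # split the image into patches and embed them
        embeddings = self.patch_embedding(self.extract_patches(x))
        # prepend the [CLS] token, then add position embeddings
        cls = self.cls_token.expand(x.size(0), -1, -1)
        embeddings = torch.cat([cls, embeddings], dim=1) + self.position_embedding
        # one self-attention layer (a full ViT stacks many encoder blocks)
        attention_output, attention_weights = self.attention(embeddings, embeddings, embeddings)
        return attention_output, attention_weights
```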
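Self-attention compares every token with every other token, so the patch size largely determines ViT's attention cost. The small sketch below works through the arithmetic. It assumes square images and a single [CLS] token, matching the `num_patches + 1` position embedding above, and `vit_sequence_length` is a hypothetical helper.

```python
def vit_sequence_length(image_size, patch_size, cls_token=True):
    """Number of tokens a ViT attends over: one per patch, plus an optional [CLS] token."""
    assert image_size % patch_size == 0, "image size must be divisible by patch size"
    num_patches = (image_size // patch_size) ** 2
    return num_patches + (1 if cls_token else 0)

for patch in (32, 16, 8):
    n = vit_sequence_length(224, patch)
    print(f"patch {patch:>2}: {n} tokens, {n * n} attention scores per head")
# patch 32: 50 tokens, 2500 attention scores per head
# patch 16: 197 tokens, 38809 attention scores per head
# patch  8: 785 tokens, 616225 attention scores per head
```

Halving the patch size quadruples the number of tokens, which multiplies the attention matrix size by roughly 16.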
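Summary
Attention is one of the core techniques of modern deep learning. From the original Bahdanau attention to the Transformer's self-attention, it has steadily evolved, letting models handle sequence data more effectively and capture contextual relationships. Understanding how attention works is essential for understanding large language models.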