← 返回首页
🤖

注意力机制深入解析:从Bahdanau到Self-Attention

📂 ai ⏱ 3 min 480 words

注意力机制深入解析:从Bahdanau到Self-Attention

什么是注意力机制?

注意力机制的核心思想是:在处理信息时,模型应该学会"关注"最相关的部分,而不是平等对待所有输入。

就像人类阅读时,我们不会平均分配注意力给每个词,而是会重点关注关键信息。

注意力机制的发展历程

1. Seq2Seq的瓶颈

传统的编码器-解码器模型将整个输入序列压缩成一个固定长度的向量,这导致了信息瓶颈:

# 传统Seq2Seq的问题
encoder_output = encoder(input_sequence)  # 固定长度向量
decoder_output = decoder(encoder_output)  # 信息损失

2. Bahdanau注意力(2015)

Bahdanau等人提出在解码每一步时,动态关注编码器的不同位置:

class BahdanauAttention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.W_encoder = nn.Linear(encoder_dim, attention_dim)
        self.W_decoder = nn.Linear(decoder_dim, attention_dim)
        self.v = nn.Linear(attention_dim, 1)
    
    def forward(self, encoder_outputs, decoder_hidden):
        # encoder_outputs: (batch, seq_len, encoder_dim)
        # decoder_hidden: (batch, decoder_dim)
        
        encoder_proj = self.W_encoder(encoder_outputs)  # (batch, seq_len, attention_dim)
        decoder_proj = self.W_decoder(decoder_hidden).unsqueeze(1)  # (batch, 1, attention_dim)
        
        attention_scores = self.v(torch.tanh(encoder_proj + decoder_proj))  # (batch, seq_len, 1)
        attention_weights = torch.softmax(attention_scores, dim=1)  # (batch, seq_len, 1)
        
        context = torch.sum(attention_weights * encoder_outputs, dim=1)  # (batch, encoder_dim)
        
        return context, attention_weights

3. Luong注意力(2015)

Luong提出了更简洁的注意力计算方式:

class LuongAttention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim):
        super().__init__()
        self.W = nn.Linear(encoder_dim, decoder_dim)
    
    def forward(self, encoder_outputs, decoder_hidden):
        # 计算注意力分数
        encoder_proj = self.W(encoder_outputs)  # (batch, seq_len, decoder_dim)
        scores = torch.bmm(encoder_proj, decoder_hidden.unsqueeze(2))  # (batch, seq_len, 1)
        attention_weights = torch.softmax(scores, dim=1)
        
        context = torch.bmm(attention_weights.transpose(1, 2), encoder_outputs)
        return context.squeeze(1), attention_weights

自注意力(Self-Attention)

自注意力是Transformer的核心,它允许序列中的每个位置关注所有其他位置。

计算过程

  1. 线性变换:将输入映射为Q、K、V
  2. 注意力计算:Q和K的点积,除以根号d_k
  3. Softmax归一化:得到注意力权重
  4. 加权求和:用注意力权重对V加权
def self_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
    
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    attention_weights = torch.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, V)
    
    return output, attention_weights

注意力可视化

import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(attention_weights, tokens_x, tokens_y):
    plt.figure(figsize=(10, 10))
    sns.heatmap(attention_weights, 
                xticklabels=tokens_x, 
                yticklabels=tokens_y,
                cmap='viridis')
    plt.xlabel('Key')
    plt.ylabel('Query')
    plt.title('Attention Weights')
    plt.show()

多头注意力

多头注意力让模型能够同时关注不同位置的不同表示子空间:

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
    
    def split_heads(self, x, batch_size):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)
    
    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        
        # 线性变换并分头
        Q = self.split_heads(self.W_q(Q), batch_size)
        K = self.split_heads(self.W_k(K), batch_size)
        V = self.split_heads(self.W_v(V), batch_size)
        
        # 计算注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attention_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        
        # 合并多头
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        
        return self.W_o(output), attention_weights

注意力的变体

1. 稀疏注意力

只计算部分位置的注意力,降低计算复杂度:

2. 线性注意力

将softmax注意力近似为线性运算,复杂度从O(n²)降到O(n):

3. Flash Attention

通过分块计算和内存优化,提高注意力计算效率:

注意力机制的应用

1. 文本生成

在GPT等模型中,自注意力用于预测下一个token:

# 文本生成时的注意力
def generate_text(model, prompt, max_length=100):
    tokens = tokenize(prompt)
    
    for _ in range(max_length):
        # 获取注意力权重
        logits, attention = model(tokens)
        
        # 可以分析注意力来理解模型
        next_token = sample_logits(logits[:, -1, :])
        tokens = torch.cat([tokens, next_token.unsqueeze(1)], dim=1)
    
    return detokenize(tokens)

2. 机器翻译

在翻译任务中,注意力帮助模型对齐源语言和目标语言:

# 机器翻译中的注意力对齐
def translate_with_attention(model, source):
    encoder_output = model.encode(source)
    
    translations = []
    attentions = []
    
    for step in range(max_length):
        decoder_output, attention = model.decode(encoder_output)
        attentions.append(attention)
        
        next_token = sample_logits(decoder_output)
        translations.append(next_token)
    
    return translations, attentions

3. 图像理解

Vision Transformer将图像分割成patch,用自注意力处理:

class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, d_model, num_heads):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        self.patch_embedding = nn.Linear(patch_size * patch_size * 3, d_model)
        self.position_embedding = nn.Parameter(torch.randn(1, num_patches + 1, d_model))
        self.attention = MultiHeadAttention(d_model, num_heads)
    
    def forward(self, x):
        # 将图像分割成patch
        patches = self.extract_patches(x)
        
        # 添加位置编码
        embeddings = self.patch_embedding(patches) + self.position_embedding
        
        # 自注意力处理
        attention_output, attention_weights = self.attention(embeddings, embeddings, embeddings)
        
        return attention_output, attention_weights

总结

注意力机制是现代深度学习的核心技术之一。从最初的Bahdanau注意力到Transformer的自注意力,注意力机制不断发展,使得模型能够更好地处理序列数据,理解上下文关系。掌握注意力机制的原理对于理解大语言模型至关重要。