← 返回首页
🧠

注意力变体

📂 llm ⏱ 5 min 956 words

---
title: "注意力变体"
description: "掌握注意力机制的各类变体,包括MQA、GQA、Flash Attention和线性注意力的原理与实现"
tags: ["注意力变体", "MQA", "GQA", "Flash Attention", "线性注意力"]
category: "llm"
icon: "🧠"
---

注意力变体

注意力机制概述

注意力机制是Transformer架构的核心组件。标准的多头注意力(Multi-Head Attention)虽然强大,但在处理长序列时计算和内存开销巨大。因此,研究人员开发了多种注意力变体来优化性能。

标准多头注意力

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """Standard multi-head scaled dot-product attention.

    Queries, keys and values are each projected to ``embed_dim`` features,
    split into ``num_heads`` heads of ``head_dim`` features, attended per
    head, merged again and passed through an output projection.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # The heads must tile the embedding exactly.
        assert embed_dim == self.head_dim * num_heads

        # Input projections first, output projection last.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)
        # Scores are divided by sqrt(head_dim).
        self.scale = math.sqrt(self.head_dim)

    def _split_heads(self, projected, batch_size, length):
        """Reshape (batch, length, embed_dim) into (batch, heads, length, head_dim)."""
        per_head = projected.view(batch_size, length, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        ``query`` is (batch, q_len, embed_dim); ``key`` and ``value`` are
        (batch, kv_len, embed_dim). When ``mask`` is given, score positions
        where ``mask == 0`` are set to ``-inf`` before the softmax, so it has
        to broadcast against (batch, heads, q_len, kv_len). Returns a tensor
        of shape (batch, q_len, embed_dim).
        """
        batch_size, seq_len, _ = query.shape

        # The key/value length is inferred (-1) so it may differ from q_len.
        q_heads = self._split_heads(self.q_proj(query), batch_size, seq_len)
        k_heads = self._split_heads(self.k_proj(key), batch_size, -1)
        v_heads = self._split_heads(self.v_proj(value), batch_size, -1)

        logits = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / self.scale
        if mask is not None:
            logits = logits.masked_fill(mask == 0, float('-inf'))

        probs = self.dropout(F.softmax(logits, dim=-1))

        # Weighted sum of the values, then merge the heads back together.
        merged = torch.matmul(probs, v_heads).transpose(1, 2).contiguous()
        return self.out_proj(merged.view(batch_size, seq_len, self.embed_dim))

MQA(Multi-Query Attention)

class MultiQueryAttention(nn.Module):
    """Multi-Query Attention (MQA).

    Every query head has its own slice of the query projection, while a
    single key head and a single value head (each ``head_dim`` wide) are
    shared by all query heads.

    Args:
        embed_dim: Width of the input and output embeddings.
        num_heads: Number of query heads; must divide ``embed_dim``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        # Fail early: otherwise the mismatch only surfaces as a reshape error
        # inside forward().
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        # Keys and values have one head only.
        self.k_proj = nn.Linear(embed_dim, self.head_dim)
        self.v_proj = nn.Linear(embed_dim, self.head_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over the shared key/value head.

        Args:
            query: Tensor of shape (batch, q_len, embed_dim).
            key: Tensor of shape (batch, kv_len, embed_dim).
            value: Tensor of shape (batch, kv_len, embed_dim).
            mask: Optional tensor broadcastable to
                (batch, num_heads, q_len, kv_len); positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, embed_dim).
        """
        batch_size, seq_len, _ = query.shape

        # Queries are split into num_heads heads.
        Q = self.q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Keys/values get a singleton head axis; matmul broadcasts that single
        # head against every query head.
        K = self.k_proj(key).unsqueeze(1)
        V = self.v_proj(value).unsqueeze(1)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)

        return self.out_proj(context)

GQA(Grouped-Query Attention)

class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention (GQA).

    The ``num_heads`` query heads are partitioned into ``num_kv_heads``
    groups of equal size; all query heads in a group share one key head and
    one value head.

    Args:
        embed_dim: Width of the input and output embeddings.
        num_heads: Number of query heads; must divide ``embed_dim``.
        num_kv_heads: Number of key/value heads; must divide ``num_heads``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``, or if ``num_kv_heads`` is not positive or does not
            divide ``num_heads``.
    """

    def __init__(self, embed_dim: int, num_heads: int,
                 num_kv_heads: int, dropout: float = 0.1):
        super().__init__()
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        # With a non-divisor, repeat_interleave in forward() would produce
        # fewer key heads than there are query heads.
        if num_kv_heads <= 0 or num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be divisible by a positive "
                f"num_kv_heads ({num_kv_heads})"
            )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = embed_dim // num_heads
        # Number of query heads served by each key/value head.
        self.num_groups = num_heads // num_kv_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over the grouped key/value heads.

        Args:
            query: Tensor of shape (batch, q_len, embed_dim).
            key: Tensor of shape (batch, kv_len, embed_dim).
            value: Tensor of shape (batch, kv_len, embed_dim).
            mask: Optional tensor broadcastable to
                (batch, num_heads, q_len, kv_len); positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, embed_dim).
        """
        batch_size, seq_len, _ = query.shape

        # Query projection: (batch, num_heads, q_len, head_dim).
        Q = self.q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Key/value projections: (batch, num_kv_heads, kv_len, head_dim).
        K = self.k_proj(key).view(batch_size, -1, self.num_kv_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(value).view(batch_size, -1, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Repeat each key/value head num_groups times so that query head h
        # is paired with key/value head h // num_groups.
        K = K.repeat_interleave(self.num_groups, dim=1)
        V = V.repeat_interleave(self.num_groups, dim=1)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)

        return self.out_proj(context)

Flash Attention

class FlashAttention(nn.Module):
    """Block-wise exact attention using an online softmax.

    This is a plain-PyTorch illustration of the tiling idea behind Flash
    Attention: scores are computed one (query block, key block) tile at a
    time and folded into running softmax statistics, so the full
    (q_len, kv_len) score matrix is never materialised. The result equals
    ordinary softmax attention up to floating-point rounding.

    Args:
        embed_dim: Width of the input and output embeddings.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        block_size: Number of query rows / key columns per tile.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``, or if ``block_size`` is not positive.
    """

    def __init__(self, embed_dim: int, num_heads: int,
                 block_size: int = 64):
        super().__init__()
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        # A non-positive step would make the tiling loops raise or silently
        # skip every tile.
        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size}")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.block_size = block_size

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.scale = math.sqrt(self.head_dim)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: Tensor of shape (batch, q_len, embed_dim).
            key: Tensor of shape (batch, kv_len, embed_dim).
            value: Tensor of shape (batch, kv_len, embed_dim).
            mask: Optional tensor broadcastable to
                (batch, num_heads, q_len, kv_len); positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, embed_dim).
        """
        batch_size, seq_len, _ = query.shape

        Q = self.q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Infer the key/value length (-1) so it may differ from the query length.
        K = self.k_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Tiled attention.
        output = self._flash_attention_forward(Q, K, V, mask)

        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        return self.out_proj(output)

    def _flash_attention_forward(self, Q, K, V, mask=None):
        """Compute softmax(Q K^T / scale) V tile by tile.

        ``Q`` is (batch, heads, q_len, head_dim); ``K`` and ``V`` are
        (batch, heads, kv_len, head_dim). Returns a tensor shaped like ``Q``.
        A query row whose keys are all masked yields NaN, the same result a
        softmax over an all ``-inf`` row gives.
        """
        batch_size, num_heads, q_len, _ = Q.shape
        kv_len = K.shape[2]
        step = self.block_size
        neg_inf = float('-inf')

        if mask is not None:
            # Expand to the full score shape (as a view) so that tiles can be
            # sliced whatever the mask's original rank or singleton axes.
            mask = torch.broadcast_to(mask, (batch_size, num_heads, q_len, kv_len))

        O = torch.zeros_like(Q)

        for i in range(0, q_len, step):
            Q_block = Q[:, :, i:i + step, :]
            rows = Q_block.shape[2]

            # Running statistics for this query block. They are rebound on
            # every step rather than updated in place, which keeps the
            # autograd graph valid.
            acc = torch.zeros_like(Q_block)  # unnormalised sum of P @ V
            row_sum = torch.zeros(
                batch_size, num_heads, rows, 1, dtype=Q.dtype, device=Q.device
            )
            row_max = torch.full(
                (batch_size, num_heads, rows, 1), neg_inf,
                dtype=Q.dtype, device=Q.device
            )

            for j in range(0, kv_len, step):
                K_block = K[:, :, j:j + step, :]
                V_block = V[:, :, j:j + step, :]

                S_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) / self.scale

                if mask is not None:
                    mask_block = mask[:, :, i:i + step, j:j + step]
                    S_block = S_block.masked_fill(mask_block == 0, neg_inf)

                new_max = torch.max(row_max, S_block.max(dim=-1, keepdim=True).values)
                # While every key seen so far is masked the running max is
                # still -inf; subtracting it would give (-inf) - (-inf) = NaN.
                # Use 0 as the reference instead: all exponentials are then 0.
                safe_max = new_max.masked_fill(new_max == neg_inf, 0.0)

                P_block = torch.exp(S_block - safe_max)
                # Rescale what was accumulated under the previous maximum.
                alpha = torch.exp(row_max - safe_max)

                row_sum = alpha * row_sum + P_block.sum(dim=-1, keepdim=True)
                acc = alpha * acc + torch.matmul(P_block, V_block)
                row_max = new_max

            # Normalise once per query block, after all key blocks are folded in.
            O[:, :, i:i + step, :] = acc / row_sum

        return O

线性注意力

class LinearAttention(nn.Module):
    """Kernelised linear attention.

    Replaces ``softmax(Q K^T) V`` with ``phi(Q) (phi(K)^T V)`` normalised by
    ``phi(Q) . sum_s phi(K_s)``, where ``phi`` is a non-negative feature map.
    Because ``phi(K)^T V`` is a (head_dim, head_dim) matrix per head, the cost
    grows linearly rather than quadratically with the sequence length.

    Args:
        embed_dim: Width of the input and output embeddings.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        feature_map: Name of the feature map ``phi``: ``"relu"``, ``"elu"``
            or ``"softmax"``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``, or if ``feature_map`` is not a supported name.
    """

    def __init__(self, embed_dim: int, num_heads: int,
                 feature_map: str = "relu"):
        super().__init__()
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )

        # Named functions instead of lambdas; an unknown name is rejected
        # here rather than surfacing as an AttributeError in forward().
        feature_maps = {
            "relu": self._relu_feature_map,
            "elu": self._elu_feature_map,
            "softmax": self._softmax_feature_map,
        }
        if feature_map not in feature_maps:
            raise ValueError(
                f"unknown feature_map {feature_map!r}; "
                f"expected one of {sorted(feature_maps)}"
            )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.feature_map = feature_maps[feature_map]

    @staticmethod
    def _relu_feature_map(x):
        """ReLU plus a small offset so features are strictly positive."""
        return torch.relu(x) + 1e-6

    @staticmethod
    def _elu_feature_map(x):
        """ELU shifted by one, which is positive everywhere."""
        return F.elu(x) + 1

    @staticmethod
    def _softmax_feature_map(x):
        """Softmax over the feature axis."""
        return F.softmax(x, dim=-1)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value`` in linear time.

        Args:
            query: Tensor of shape (batch, q_len, embed_dim).
            key: Tensor of shape (batch, kv_len, embed_dim).
            value: Tensor of shape (batch, kv_len, embed_dim).
            mask: Accepted for signature parity with the other attention
                modules but NOT applied: this formulation always attends
                over every key position.

        Returns:
            Tensor of shape (batch, q_len, embed_dim).
        """
        batch_size, seq_len, _ = query.shape

        # Projections; the key/value length is inferred (-1).
        Q = self.q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Apply the feature map phi.
        Q = self.feature_map(Q)
        K = self.feature_map(K)

        # Standard attention: softmax(Q K^T / sqrt(d)) V
        # Linear attention:   phi(Q) (phi(K)^T V) / (phi(Q) . sum_s phi(K_s))

        KV = torch.einsum('bhsv,bhst->bhvt', K, V)  # [B, H, D, D]
        # K.sum(dim=2) sums over key positions and is [B, H, D], so it is
        # contracted with Q over the feature axis `v`, not the sequence axis.
        Z = 1.0 / (torch.einsum('bhsv,bhv->bhs', Q, K.sum(dim=2)) + 1e-6)

        output = torch.einsum('bhsv,bhvt,bhs->bhst', Q, KV, Z)

        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        return self.out_proj(output)

性能对比

def compare_attention_variants(embed_dim: int = 1024, num_heads: int = 16,
                               seq_len: int = 2048, batch_size: int = 4,
                               num_kv_heads: int = 4) -> None:
    """Print output shapes and parameter counts of the attention variants.

    Builds MHA, MQA, GQA and linear attention with the same embedding size,
    runs each once as self-attention on one random input, and prints the
    output shape and the number of parameters of each module.

    Args:
        embed_dim: Embedding width passed to every module.
        num_heads: Number of (query) heads passed to every module.
        seq_len: Length of the random input sequence.
        batch_size: Batch size of the random input.
        num_kv_heads: Number of key/value heads used for GQA.
    """
    # Shared random input.
    x = torch.randn(batch_size, seq_len, embed_dim)

    variants = {
        "MHA": MultiHeadAttention(embed_dim, num_heads),
        "MQA": MultiQueryAttention(embed_dim, num_heads),
        "GQA": GroupedQueryAttention(embed_dim, num_heads, num_kv_heads=num_kv_heads),
        "Linear Attention": LinearAttention(embed_dim, num_heads),
    }

    # Only the shapes are reported, so skip autograd bookkeeping and keep
    # just the shape of each output rather than the output tensor itself.
    with torch.no_grad():
        shapes = {label: model(x, x, x).shape for label, model in variants.items()}

    # Each label is paired with its own module's output shape.
    for label, shape in shapes.items():
        print(f"{label} output shape: {shape}")

    # Parameter-count comparison.
    print("\n参数量对比:")
    for label, model in variants.items():
        total = sum(p.numel() for p in model.parameters())
        print(f"{label}: {total:,}")

# Module-level call: runs the comparison as soon as this code is executed or imported.
compare_attention_variants()

总结

不同的注意力变体适用于不同的场景:MQA和GQA通过减少K/V头的数量来缩小KV缓存,适合大模型的推理加速;Flash Attention在不改变计算结果的前提下通过分块计算降低显存读写开销,适合长序列处理;线性注意力用核函数近似softmax,将复杂度降到O(n),适合超长序列。选择合适的注意力机制对于优化模型性能至关重要。