注意力变体
---
title: "注意力变体"
description: "掌握注意力机制的各类变体,包括MQA、GQA、Flash Attention和线性注意力的原理与实现"
tags: ["注意力变体", "MQA", "GQA", "Flash Attention", "线性注意力"]
category: "llm"
icon: "🧠"
---
注意力变体
注意力机制概述
注意力机制是Transformer架构的核心组件。标准的多头注意力(Multi-Head Attention)虽然强大,但其计算量和注意力矩阵的内存占用都随序列长度呈二次方增长,在处理长序列时开销巨大。因此,研究人员开发了多种注意力变体来优化性能。
标准多头注意力
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
    """Standard multi-head scaled dot-product attention.

    Query, key and value are each projected to ``embed_dim`` features,
    split into ``num_heads`` heads of ``head_dim`` features, attended
    per head, merged back and passed through an output projection.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        per_head = embed_dim // num_heads
        self.embed_dim, self.num_heads, self.head_dim = embed_dim, num_heads, per_head
        # The heads must tile the embedding exactly.
        assert per_head * num_heads == embed_dim

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)
        # Scores are divided by sqrt(head_dim).
        self.scale = math.sqrt(per_head)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: tensor of shape (batch, q_len, embed_dim).
            key: tensor of shape (batch, k_len, embed_dim).
            value: tensor of shape (batch, k_len, embed_dim).
            mask: optional tensor broadcastable to
                (batch, num_heads, q_len, k_len); positions equal to 0
                are excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, embed_dim).
        """
        batch_size, seq_len, _ = query.shape
        heads, width = self.num_heads, self.head_dim

        # Project, then move the head axis in front of the sequence axis.
        q_heads = self.q_proj(query).view(batch_size, seq_len, heads, width).transpose(1, 2)
        k_heads = self.k_proj(key).view(batch_size, -1, heads, width).transpose(1, 2)
        v_heads = self.v_proj(value).view(batch_size, -1, heads, width).transpose(1, 2)

        # Scaled dot-product logits, one (q_len, k_len) matrix per head.
        logits = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / self.scale
        if mask is not None:
            logits = logits.masked_fill(mask == 0, float('-inf'))

        probs = self.dropout(F.softmax(logits, dim=-1))

        # Weighted sum of values, then merge the heads back together.
        merged = torch.matmul(probs, v_heads).transpose(1, 2).contiguous()
        return self.out_proj(merged.view(batch_size, seq_len, self.embed_dim))
MQA(Multi-Query Attention)
class MultiQueryAttention(nn.Module):
    """Multi-Query Attention (MQA).

    Only the queries are split into ``num_heads`` heads; a single key
    head and a single value head of ``head_dim`` features are shared by
    every query head.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1):
        """Build the projections.

        Raises:
            ValueError: if ``num_heads`` is not positive or does not
                divide ``embed_dim``.
        """
        super().__init__()
        # Fail here rather than with an opaque reshape error in forward().
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        # K and V are projected to a single head.
        self.k_proj = nn.Linear(embed_dim, self.head_dim)
        self.v_proj = nn.Linear(embed_dim, self.head_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: tensor of shape (batch, q_len, embed_dim).
            key: tensor of shape (batch, k_len, embed_dim).
            value: tensor of shape (batch, k_len, embed_dim).
            mask: optional tensor broadcastable to
                (batch, num_heads, q_len, k_len); positions equal to 0
                are excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, embed_dim).
        """
        batch_size, seq_len, _ = query.shape
        # Queries: (batch, num_heads, q_len, head_dim).
        Q = self.q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # The single K/V head is expanded (a view, no copy) across all query heads.
        K = self.k_proj(key).unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        V = self.v_proj(value).unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        # Scaled dot-product attention.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context = torch.matmul(attn_weights, V)
        # Merge the heads back into the embedding dimension.
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        return self.out_proj(context)
GQA(Grouped-Query Attention)
class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention (GQA).

    The ``num_heads`` query heads are partitioned into ``num_kv_heads``
    groups of ``num_groups`` consecutive heads; every head in a group
    attends with the same key/value head.
    """

    def __init__(self, embed_dim: int, num_heads: int,
                 num_kv_heads: int, dropout: float = 0.1):
        """Build the projections.

        Raises:
            ValueError: if ``num_heads`` is not positive or does not
                divide ``embed_dim``, or if ``num_kv_heads`` is not
                positive or does not divide ``num_heads``.
        """
        super().__init__()
        # Validate up front: a bad configuration would otherwise only fail
        # inside forward() with a shape mismatch between Q and K.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        if num_kv_heads <= 0 or num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be divisible by a positive "
                f"num_kv_heads ({num_kv_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = embed_dim // num_heads
        # Number of query heads served by each K/V head.
        self.num_groups = num_heads // num_kv_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: tensor of shape (batch, q_len, embed_dim).
            key: tensor of shape (batch, k_len, embed_dim).
            value: tensor of shape (batch, k_len, embed_dim).
            mask: optional tensor broadcastable to
                (batch, num_heads, q_len, k_len); positions equal to 0
                are excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, embed_dim).
        """
        batch_size, seq_len, _ = query.shape
        # Queries: (batch, num_heads, q_len, head_dim).
        Q = self.q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Keys/values: (batch, num_kv_heads, k_len, head_dim).
        K = self.k_proj(key).view(batch_size, -1, self.num_kv_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(value).view(batch_size, -1, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Repeat each K/V head num_groups times so that query head h is
        # paired with K/V head h // num_groups.
        K = K.repeat_interleave(self.num_groups, dim=1)
        V = V.repeat_interleave(self.num_groups, dim=1)
        # Scaled dot-product attention.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context = torch.matmul(attn_weights, V)
        # Merge the heads back into the embedding dimension.
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        return self.out_proj(context)
Flash Attention
class FlashAttention(nn.Module):
    """Block-wise attention in the style of Flash Attention.

    The score matrix is never materialised in full: queries and keys are
    processed in tiles of ``block_size`` rows/columns, and the softmax is
    accumulated with running per-row maxima and sums (online softmax).
    This is a plain PyTorch reference of the tiling scheme, not a fused
    kernel.
    """

    def __init__(self, embed_dim: int, num_heads: int,
                 block_size: int = 64):
        """Build the projections.

        Raises:
            ValueError: if ``num_heads`` is not positive or does not
                divide ``embed_dim``, or if ``block_size`` is not positive.
        """
        super().__init__()
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size}")
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.block_size = block_size
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.scale = math.sqrt(self.head_dim)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: tensor of shape (batch, q_len, embed_dim).
            key: tensor of shape (batch, k_len, embed_dim).
            value: tensor of shape (batch, k_len, embed_dim).
            mask: optional tensor broadcastable to
                (batch, num_heads, q_len, k_len); positions equal to 0
                are excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, embed_dim).
        """
        batch_size, seq_len, _ = query.shape
        Q = self.q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # -1 lets the key/value length differ from the query length.
        K = self.k_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # Tiled attention over (batch, num_heads, len, head_dim) tensors.
        output = self._flash_attention_forward(Q, K, V, mask)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        return self.out_proj(output)

    def _flash_attention_forward(self, Q, K, V, mask=None):
        """Compute softmax(QK^T / scale) V tile by tile.

        For every query tile the method keeps three running quantities
        while sweeping over the key tiles: the row maximum seen so far,
        the sum of exponentials relative to that maximum, and the
        un-normalised output. The output is normalised once at the end
        of the sweep.

        A query row whose keys are all masked ends with a zero sum and
        therefore yields NaN (0 / 0), the same result as a softmax over
        a row of ``-inf``.

        Args:
            Q: (batch, num_heads, q_len, head_dim).
            K: (batch, num_heads, k_len, head_dim).
            V: (batch, num_heads, k_len, head_dim).
            mask: optional, broadcastable to (batch, num_heads, q_len, k_len).

        Returns:
            Tensor with the same shape as ``Q``.
        """
        batch_size, num_heads, q_len, _ = Q.shape
        kv_len = K.shape[2]
        if q_len == 0:
            return torch.zeros_like(Q)
        if mask is not None:
            # Materialise the broadcast (as a view) so tiles can be sliced
            # regardless of how many dimensions the caller supplied.
            mask = torch.broadcast_to(mask, (batch_size, num_heads, q_len, kv_len))

        step = self.block_size
        neg_inf = float('-inf')
        out_tiles = []
        for i in range(0, q_len, step):
            q_tile = Q[:, :, i:i + step, :]
            rows = q_tile.shape[2]
            # Running statistics live in Q's dtype/device. They are rebound
            # (not written in place) so autograd can differentiate the loop.
            acc = torch.zeros_like(q_tile)
            row_sum = q_tile.new_zeros(batch_size, num_heads, rows, 1)
            row_max = q_tile.new_full((batch_size, num_heads, rows, 1), neg_inf)

            for j in range(0, kv_len, step):
                k_tile = K[:, :, j:j + step, :]
                v_tile = V[:, :, j:j + step, :]
                scores = torch.matmul(q_tile, k_tile.transpose(-2, -1)) / self.scale
                if mask is not None:
                    scores = scores.masked_fill(mask[:, :, i:i + step, j:j + step] == 0, neg_inf)

                new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
                # A row that has only seen masked keys still has max = -inf;
                # subtracting it would give -inf - (-inf) = NaN. Use 0 as
                # the reference there so every exponential is exp(-inf) = 0.
                safe_max = torch.where(new_max == neg_inf, torch.zeros_like(new_max), new_max)

                probs = torch.exp(scores - safe_max)
                # Rescale what was accumulated under the previous maximum.
                alpha = torch.exp(row_max - safe_max)
                row_sum = alpha * row_sum + probs.sum(dim=-1, keepdim=True)
                acc = alpha * acc + torch.matmul(probs, v_tile)
                row_max = new_max

            out_tiles.append(acc / row_sum)
        return torch.cat(out_tiles, dim=2)
线性注意力
class LinearAttention(nn.Module):
    """Kernelised linear attention.

    The softmax kernel is replaced by a feature map ``phi`` so that
    attention can be evaluated as ``phi(Q) (phi(K)^T V)``, normalised by
    ``phi(Q) . sum_s phi(K_s)``. The key/value summary has size
    head_dim x head_dim, so the cost grows linearly with sequence length
    instead of quadratically.
    """

    _FEATURE_MAPS = ("relu", "elu", "softmax")

    def __init__(self, embed_dim: int, num_heads: int,
                 feature_map: str = "relu"):
        """Build the projections.

        Args:
            embed_dim: model width.
            num_heads: number of attention heads.
            feature_map: one of ``"relu"``, ``"elu"`` or ``"softmax"``.

        Raises:
            ValueError: if ``num_heads`` is not positive or does not
                divide ``embed_dim``, or if ``feature_map`` is not one
                of the supported names.
        """
        super().__init__()
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        # Reject unknown names here; otherwise the failure would only show
        # up later as a missing attribute inside forward().
        if feature_map not in self._FEATURE_MAPS:
            raise ValueError(
                f"feature_map must be one of {self._FEATURE_MAPS}, got {feature_map!r}"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        # Only the name is stored; a lambda attribute would make the
        # module impossible to pickle.
        self._feature_map_name = feature_map

    def feature_map(self, x):
        """Apply the configured feature map ``phi`` to ``x``."""
        if self._feature_map_name == "relu":
            # Small offset keeps the normaliser away from zero.
            return torch.relu(x) + 1e-6
        if self._feature_map_name == "elu":
            # elu(x) + 1 is strictly positive.
            return F.elu(x) + 1
        return F.softmax(x, dim=-1)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: tensor of shape (batch, q_len, embed_dim).
            key: tensor of shape (batch, k_len, embed_dim).
            value: tensor of shape (batch, k_len, embed_dim).
            mask: accepted for signature parity with the other attention
                modules; it is not applied by this implementation.

        Returns:
            Tensor of shape (batch, q_len, embed_dim).
        """
        batch_size, seq_len, _ = query.shape
        # Projections, split into heads: (batch, num_heads, len, head_dim).
        Q = self.q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # -1 lets the key/value length differ from the query length.
        K = self.k_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # Feature maps.
        Q = self.feature_map(Q)
        K = self.feature_map(K)
        # Standard attention: softmax(QK^T / sqrt(d)) V
        # Linear attention:   phi(Q) (phi(K)^T V) / (phi(Q) . sum_s phi(K_s))
        KV = torch.einsum('bhsd,bhse->bhde', K, V)  # [B, H, D, D]
        K_sum = K.sum(dim=2)  # [B, H, D]
        # The normaliser contracts over the feature axis d, giving one
        # scalar per query position.
        Z = 1.0 / (torch.einsum('bhsd,bhd->bhs', Q, K_sum) + 1e-6)
        output = torch.einsum('bhsd,bhde,bhs->bhse', Q, KV, Z)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        return self.out_proj(output)
性能对比
def compare_attention_variants(embed_dim: int = 1024, num_heads: int = 16,
                               seq_len: int = 2048, batch_size: int = 4,
                               num_kv_heads: int = 4):
    """Run each attention variant on random input and print a comparison.

    Every module is applied as self-attention to the same random tensor of
    shape (batch_size, seq_len, embed_dim). The function prints each
    output shape followed by each module's parameter count; it returns
    nothing.

    Args:
        embed_dim: model width passed to every module.
        num_heads: number of query heads passed to every module.
        seq_len: length of the random input sequence.
        batch_size: batch size of the random input.
        num_kv_heads: number of key/value heads used for the GQA module.
    """
    # Shared random input.
    x = torch.randn(batch_size, seq_len, embed_dim)

    # Only shapes and parameter counts are reported, so skip building
    # autograd graphs for the forward passes.
    with torch.no_grad():
        # Standard multi-head attention
        mha = MultiHeadAttention(embed_dim, num_heads)
        output_mha = mha(x, x, x)
        # MQA
        mqa = MultiQueryAttention(embed_dim, num_heads)
        output_mqa = mqa(x, x, x)
        # GQA
        gqa = GroupedQueryAttention(embed_dim, num_heads, num_kv_heads=num_kv_heads)
        output_gqa = gqa(x, x, x)
        # Flash Attention
        fa = FlashAttention(embed_dim, num_heads)
        output_fa = fa(x, x, x)
        # Linear attention
        la = LinearAttention(embed_dim, num_heads)
        output_la = la(x, x, x)

    print(f"MHA output shape: {output_mha.shape}")
    # Report the MQA result itself, not the MHA one.
    print(f"MQA output shape: {output_mqa.shape}")
    print(f"GQA output shape: {output_gqa.shape}")
    print(f"Flash Attention output shape: {output_fa.shape}")
    print(f"Linear Attention output shape: {output_la.shape}")

    # Parameter counts
    def count_params(model):
        return sum(p.numel() for p in model.parameters())

    print(f"\n参数量对比:")
    print(f"MHA: {count_params(mha):,}")
    print(f"MQA: {count_params(mqa):,}")
    print(f"GQA: {count_params(gqa):,}")
    print(f"Flash Attention: {count_params(fa):,}")
    print(f"Linear Attention: {count_params(la):,}")
# Run the comparison demo; it prints output shapes and parameter counts.
compare_attention_variants()
总结
不同的注意力变体适用于不同的场景:MQA和GQA通过减少K/V头的数量来缩小KV缓存,适合大模型的推理加速;Flash Attention在不改变计算结果的前提下通过分块计算降低显存占用,适合长序列处理;线性注意力用核函数近似softmax,将复杂度降低到线性,适合超长序列。选择合适的注意力机制对于优化模型性能至关重要。