← 返回首页
🧠

Flash Attention:高效注意力计算

📂 llm ⏱ 4 min 718 words

---
title: "Flash Attention:高效注意力计算"
description: "掌握Flash Attention的原理和实现,显著提升Transformer推理速度"
tags: ["Flash Attention", "注意力优化", "GPU加速", "IO感知"]
category: "llm"
icon: "🧠"
---

Flash Attention:高效注意力计算

Flash Attention简介

Flash Attention是一种IO感知(IO-aware)的精确注意力算法,通过优化GPU内存访问模式,显著提升了注意力计算的速度和内存效率。它由Tri Dao等人于2022年提出,现已成为现代LLM训练与推理的标准组件。

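Flash Attention的核心优势:

- 更快:分块计算并在片上SRAM中完成,大幅减少对HBM的读写次数
- 更省内存:无需存储完整的N×N注意力矩阵,额外内存随序列长度线性增长
- 结果精确:与标准注意力在数学上等价,而不是近似算法
- 支持更长序列:在相同显存下可以处理更长的上下文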

工作原理

标准注意力的问题

import torch
import torch.nn.functional as F

def standard_attention(Q, K, V):
    """Compute softmax(Q K^T / sqrt(d)) V the straightforward way.

    The full score matrix of shape [..., seq_q, seq_k] is materialized in
    memory, which is exactly the cost Flash Attention is designed to avoid.

    Args:
        Q, K, V: tensors laid out as [batch, heads, seq_len, head_dim].

    Returns:
        The attention output, one row per query.
    """
    scale = Q.size(-1) ** 0.5
    # Full N x N score matrix: O(N^2) memory.
    logits = Q @ K.transpose(-2, -1) / scale
    probs = logits.softmax(dim=-1)
    return probs @ V

# 问题:
# 1. 需要存储完整的N×N注意力矩阵
# 2. 多次读写HBM(高带宽内存)
# 3. 内存占用O(N²)

Flash Attention的优化

# Flash Attention核心思想:
# 1. 分块计算:将Q, K, V分成小块
# 2. 在SRAM中完成计算:减少HBM访问
# 3. 在线softmax:避免存储完整注意力矩阵

def flash_attention_forward(Q, K, V, block_size=128):
    """Simplified, pure-PyTorch Flash Attention forward pass (no masking).

    K and V are consumed in blocks of ``block_size`` keys. An online softmax
    keeps two statistics per query row: the running maximum ``M`` and the
    running normalizer ``L``. Only a [seq_q, block_size] slice of scores
    exists at any time. Unlike the real kernel, Q is not tiled and nothing is
    explicitly placed in SRAM; this only illustrates the math.

    Args:
        Q: query tensor, [batch, heads, seq_q, head_dim].
        K: key tensor, [batch, heads, seq_k, head_dim]; seq_k may differ
            from seq_q (e.g. cross-attention).
        V: value tensor, [batch, heads, seq_k, head_dim_v].
        block_size: number of keys processed per iteration; must be > 0.

    Returns:
        Tensor of shape [batch, heads, seq_q, head_dim_v] with dtype
        ``Q.dtype``. It equals softmax(Q K^T / sqrt(head_dim)) V up to
        rounding. If seq_k is 0 the result is all zeros.

    Raises:
        ValueError: if ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    batch, heads, seq_q, head_dim = Q.shape
    # Iterate over the *key* length. Using Q's length would drop keys, or
    # yield empty blocks, whenever seq_k != seq_q.
    seq_k = K.size(-2)
    # Keep statistics and the accumulator in at least float32 so exp/sum
    # stay stable for fp16/bf16 inputs; cast back to Q's dtype at the end.
    acc_dtype = torch.promote_types(Q.dtype, torch.float32)
    q, k, v = Q.to(acc_dtype), K.to(acc_dtype), V.to(acc_dtype)

    # Output accumulator (uses V's feature dim) and softmax statistics.
    O = torch.zeros(batch, heads, seq_q, V.size(-1), dtype=acc_dtype, device=Q.device)
    L = torch.zeros(batch, heads, seq_q, 1, dtype=acc_dtype, device=Q.device)
    M = torch.full((batch, heads, seq_q, 1), float('-inf'), dtype=acc_dtype, device=Q.device)
    scale = head_dim ** 0.5

    for start in range(0, seq_k, block_size):
        # Current block of keys/values (the last block may be shorter).
        k_j = k[:, :, start:start + block_size, :]
        v_j = v[:, :, start:start + block_size, :]

        # Scores of every query against this key block only.
        S_j = torch.matmul(q, k_j.transpose(-2, -1)) / scale

        # Online softmax: update the running row maximum.
        M_new = torch.maximum(M, S_j.amax(dim=-1, keepdim=True))
        # Rescale factor for earlier terms. It is 0 on the first block,
        # where M is -inf, so the zero-initialized state drops out.
        exp_old = torch.exp(M - M_new)
        P_j = torch.exp(S_j - M_new)

        # Re-normalize the accumulated output and add this block's share.
        L_new = L * exp_old + P_j.sum(dim=-1, keepdim=True)
        O = O * (L * exp_old / L_new) + torch.matmul(P_j, v_j) / L_new

        M, L = M_new, L_new

    return O.to(Q.dtype)

实现使用

PyTorch原生支持

import torch
import torch.nn.functional as F

# PyTorch 2.0+原生支持Flash Attention
def flash_attention_pytorch(Q, K, V, is_causal=False):
    """Attention via PyTorch's native ``F.scaled_dot_product_attention``.

    SDPA selects a backend automatically (a Flash Attention kernel when the
    device, dtype and inputs allow it).

    Args:
        Q, K, V: tensors laid out as [batch, heads, seq_len, head_dim].
        is_causal: apply causal masking inside SDPA.

    Returns:
        The attention output, using SDPA's default 1/sqrt(head_dim) scaling.
    """
    # ``scale`` is deliberately not passed. Its default (None) already means
    # 1/sqrt(head_dim), and the keyword does not exist on PyTorch 2.0.
    output = F.scaled_dot_product_attention(
        Q, K, V,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=is_causal,
    )
    return output

# 使用示例
# Flash kernels need a CUDA device and fp16/bf16 inputs. Without a GPU, fall
# back to CPU/float32 so the example still runs; SDPA then uses another backend.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float16 if device == 'cuda' else torch.float32
batch, heads, seq_len, head_dim = 2, 32, 2048, 128
Q = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=dtype)
K = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=dtype)
V = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=dtype)

output = flash_attention_pytorch(Q, K, V, is_causal=True)

xformers库

from xformers.ops import memory_efficient_attention

def flash_attention_xformers(Q, K, V, attn_bias=None):
    """Attention via xformers' ``memory_efficient_attention``.

    NOTE(review): xformers documents its inputs as [batch, seq_len, heads,
    head_dim]. This is not the [batch, heads, seq_len, head_dim] layout used
    by SDPA; confirm against the installed version.

    Args:
        Q, K, V: query/key/value tensors passed straight through.
        attn_bias: optional bias/mask object forwarded as ``attn_bias``.

    Returns:
        The result of ``memory_efficient_attention``.
    """
    # op=None lets xformers choose the operator implementation itself.
    return memory_efficient_attention(Q, K, V, attn_bias=attn_bias, op=None)

# 安装:pip install xformers

Flash Attention库

from flash_attn import flash_attn_func

def flash_attention_cuda(Q, K, V, dropout_p=0.0, causal=True):
    """Attention through the ``flash_attn`` package's ``flash_attn_func``.

    Args:
        Q, K, V: tensors laid out as [batch, seq_len, heads, head_dim].
            Note that this differs from the [batch, heads, seq_len, head_dim]
            layout used with SDPA.
        dropout_p: dropout probability forwarded to ``flash_attn_func``.
        causal: whether causal masking is requested (default True).

    Returns:
        Whatever ``flash_attn_func`` returns for these arguments.

    NOTE(review): flash_attn kernels presumably require fp16/bf16 CUDA
    tensors; confirm against the installed flash_attn version.
    """
    attn_kwargs = {"dropout_p": dropout_p, "causal": causal}
    return flash_attn_func(Q, K, V, **attn_kwargs)

内存优化

内存占用对比

def compare_memory_usage(seq_len, heads, head_dim, dtype=torch.float16, batch=1):
    """Back-of-the-envelope memory comparison: standard vs Flash Attention.

    The "standard" figure counts only the materialized [seq_len, seq_len]
    score matrix for every head. The "flash" figure counts the Q, K and V
    inputs, which grow linearly with seq_len. These are estimates, not
    measurements.

    Args:
        seq_len: sequence length N.
        heads: number of attention heads.
        head_dim: per-head feature dimension.
        dtype: element dtype. Its real byte size is used: 2 for
            float16/bfloat16, 4 for float32, 8 for float64.
        batch: batch-size multiplier (default 1 = a single sequence).

    Returns:
        dict with "standard_mb" and "flash_mb" (in MiB), plus "savings"
        (fraction of the standard figure saved; 0.0 if that figure is 0).
    """
    # Derive the byte width from the dtype itself. A float16-vs-rest check
    # would report 4 bytes for bfloat16 and float64.
    bytes_per_element = torch.empty((), dtype=dtype).element_size()

    # Standard attention: N x N score matrix per head.
    standard_attn_memory = batch * heads * seq_len * seq_len * bytes_per_element

    # Flash Attention: only the Q, K, V activations (O(N)).
    flash_attn_memory = batch * heads * seq_len * head_dim * bytes_per_element * 3

    if standard_attn_memory:
        savings = 1 - flash_attn_memory / standard_attn_memory
    else:
        savings = 0.0

    return {
        "standard_mb": standard_attn_memory / 1024 / 1024,
        "flash_mb": flash_attn_memory / 1024 / 1024,
        "savings": savings
    }

# 示例
# N=4096, 32 heads, head_dim=128, using the function's default dtype.
memory = compare_memory_usage(seq_len=4096, heads=32, head_dim=128)
print(f"标准注意力: {memory['standard_mb']:.2f} MB")
print(f"Flash Attention: {memory['flash_mb']:.2f} MB")
print(f"内存节省: {memory['savings']*100:.1f}%")

长序列支持

def estimate_max_seq_len(gpu_memory_gb, heads, head_dim, dtype=torch.float16,
                         usable_fraction=0.8):
    """Very rough upper bound on the sequence length that fits in GPU memory.

    Only the Q, K and V activations of a single attention layer at batch
    size 1 are counted. Model weights, KV cache and all other activations are
    ignored, so the practical limit is far lower than this estimate.

    Args:
        gpu_memory_gb: total GPU memory in GiB.
        heads: number of attention heads.
        head_dim: per-head feature dimension.
        dtype: element dtype. Its real byte size is used (bfloat16 = 2 bytes).
        usable_fraction: share of memory assumed available; the rest is
            reserved for other work (default 0.8).

    Returns:
        The estimated maximum number of tokens, as an int.
    """
    # Byte width straight from the dtype (float16/bfloat16 -> 2, float32 -> 4).
    bytes_per_element = torch.empty((), dtype=dtype).element_size()

    available_memory = gpu_memory_gb * 1024 * 1024 * 1024 * usable_fraction

    # Per-token cost: one row each of Q, K and V across all heads.
    per_token_memory = heads * head_dim * bytes_per_element * 3

    return int(available_memory / per_token_memory)

# LLaMA-7B attention shape (32 heads, head_dim=128) on an A100 80GB.
# Only Q/K/V of one layer are counted, so treat this as a loose upper bound.
max_len = estimate_max_seq_len(gpu_memory_gb=80, heads=32, head_dim=128)
print(f"最大序列长度: {max_len}")

与Transformer集成

修改模型配置

from transformers import AutoModelForCausalLM

def load_model_with_flash_attention(model_name):
    """Load a causal LM with the FlashAttention-2 attention implementation.

    Args:
        model_name: model id or path accepted by
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        The loaded model. Weights are in float16 and placement across
        devices is left to ``device_map="auto"``.

    NOTE(review): "flash_attention_2" presumably requires the flash-attn
    package and a supported GPU; confirm with the transformers docs.
    """
    load_kwargs = dict(
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
        device_map="auto",
    )
    return AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

# 使用
# Runs at module level: loading starts as soon as this snippet executes.
model = load_model_with_flash_attention(model_name="meta-llama/Llama-2-7b-hf")

自定义注意力层

import torch.nn as nn

class FlashAttentionLayer(nn.Module):
    """Multi-head self-attention layer built on ``F.scaled_dot_product_attention``.

    Args:
        hidden_size: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: attention dropout probability, used only in training mode.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads, dropout=0.1):
        super().__init__()
        # Fail here with a clear message, rather than with an opaque
        # view() error on the first forward() call.
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

        self.dropout = dropout

    def forward(self, x, attention_mask=None, is_causal=False):
        """Apply self-attention to ``x``.

        Args:
            x: input of shape [batch, seq_len, hidden_size].
            attention_mask: optional mask, forwarded unchanged as SDPA's
                ``attn_mask``. NOTE(review): some PyTorch versions reject
                an explicit mask combined with ``is_causal=True``.
            is_causal: apply causal masking inside SDPA.

        Returns:
            Tensor of shape [batch, seq_len, hidden_size].
        """
        batch_size, seq_len, _ = x.shape

        # Project, then split heads: [batch, seq_len, heads, head_dim].
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)

        # SDPA expects [batch, heads, seq_len, head_dim].
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # SDPA picks a fused (e.g. Flash) kernel when inputs allow it.
        output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0,
            is_causal=is_causal
        )

        # Merge heads back to [batch, seq_len, hidden_size], then project.
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        output = self.o_proj(output)

        return output

性能基准测试

import time

def _time_attention(fn, num_iterations, warmup):
    """Return wall-clock seconds for ``num_iterations`` calls of ``fn`` after ``warmup`` untimed calls."""
    # Warmup absorbs one-time costs (CUDA context, kernel selection, caching
    # allocator growth) that would otherwise skew the first measurement.
    for _ in range(warmup):
        fn()
    # CUDA launches are asynchronous, so synchronize before and after timing.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_iterations):
        fn()
    torch.cuda.synchronize()
    return time.perf_counter() - start


def benchmark_attention(seq_len, heads, head_dim, num_iterations=100,
                        warmup=10, dtype=torch.float16):
    """Benchmark naive attention against ``F.scaled_dot_product_attention`` on CUDA.

    Both paths compute the same non-causal attention, so the speedup compares
    equal amounts of work. Causal SDPA can skip masked blocks, which would
    inflate the ratio.

    Args:
        seq_len: sequence length.
        heads: number of attention heads.
        head_dim: per-head feature dimension.
        num_iterations: timed calls per implementation; must be > 0.
        warmup: untimed calls per implementation before timing starts.
        dtype: input dtype. Flash kernels need fp16/bf16; with float32, SDPA
            falls back to a different backend.

    Returns:
        dict with "standard_ms" and "flash_ms" (mean ms per call) and
        "speedup" (standard time / SDPA time).

    Raises:
        ValueError: if ``num_iterations`` is not positive.
    """
    if num_iterations <= 0:
        raise ValueError(f"num_iterations must be positive, got {num_iterations}")

    Q = torch.randn(1, heads, seq_len, head_dim, device='cuda', dtype=dtype)
    K = torch.randn(1, heads, seq_len, head_dim, device='cuda', dtype=dtype)
    V = torch.randn(1, heads, seq_len, head_dim, device='cuda', dtype=dtype)

    # Naive attention: materializes the full N x N score matrix.
    standard_time = _time_attention(
        lambda: standard_attention(Q, K, V), num_iterations, warmup
    )
    # SDPA: dispatches to Flash Attention when device/dtype allow it.
    flash_time = _time_attention(
        lambda: F.scaled_dot_product_attention(Q, K, V), num_iterations, warmup
    )

    return {
        "standard_ms": standard_time / num_iterations * 1000,
        "flash_ms": flash_time / num_iterations * 1000,
        "speedup": standard_time / flash_time
    }

# 测试
# Single sequence: N=2048, 32 heads, head_dim=128.
results = benchmark_attention(seq_len=2048, heads=32, head_dim=128)
print(f"标准注意力: {results['standard_ms']:.2f} ms")
print(f"Flash Attention: {results['flash_ms']:.2f} ms")
print(f"加速比: {results['speedup']:.2f}x")

最佳实践

  1. 优先使用:在硬件和数据类型(fp16/bf16)支持的前提下,Transformer模型应默认启用Flash Attention
  2. 检查兼容性:FlashAttention-2需要Ampere及更新架构的GPU(如A100、RTX 30/40系列、H100),其中A100、H100效果最佳;并确认输入为fp16/bf16
  3. 内存监控:监控显存使用情况
  4. 结合其他优化:与KV Cache、量化等技术结合使用

Flash Attention是现代LLM推理的必备技术,能显著提升性能和内存效率。