← 返回首页
🧠

KV Cache:自回归生成加速器

📂 llm ⏱ 4 min 754 words

---
title: "KV Cache:自回归生成加速器"
description: "深入理解KV Cache的工作原理、实现方法和内存优化策略"
tags: ["KV Cache", "自回归生成", "推理优化", "内存管理"]
category: "llm"
icon: "🧠"
---

KV Cache:自回归生成加速器

KV Cache简介

KV Cache(Key-Value Cache)是加速Transformer自回归生成的关键技术。它在每一层缓存已处理token的Key和Value张量,使得生成每个新token时只需为这个新token计算Q/K/V投影,而不必对整个历史序列重新做前向计算,从而显著提升推理速度。

KV Cache的核心价值:

- 减少重复计算:每步只为新token计算投影,历史token的K/V直接从缓存读取;
- 以显存换速度:缓存大小随层数、注意力头数、序列长度和批大小线性增长,是长上下文推理的主要显存开销之一。

工作原理

无KV Cache的生成

import torch
import torch.nn as nn

def generate_without_cache(model, prompt_tokens, max_new_tokens):
    """Greedy decoding WITHOUT a KV cache (baseline for comparison).

    Every step feeds the whole sequence generated so far back into the model,
    so the keys/values of all earlier tokens are recomputed on each step.

    Args:
        model: callable taking a token tensor and returning an object with a
            ``.logits`` tensor indexed as (batch, seq, vocab) -- presumably a
            Hugging Face-style causal LM; confirm for other models.
        prompt_tokens: token-id tensor, assumed (batch, seq); it is cloned,
            never modified in place.
        max_new_tokens: number of greedy decoding steps to run.

    Returns:
        The prompt followed by the generated token ids.
    """
    sequence = prompt_tokens.clone()
    for _step in range(max_new_tokens):
        with torch.no_grad():
            # Only the distribution at the final position decides the next token.
            last_logits = model(sequence).logits[:, -1, :]
        # Greedy pick; keepdim leaves a trailing axis of size 1 so it can be
        # appended along the sequence dimension.
        chosen = last_logits.argmax(dim=-1, keepdim=True)
        sequence = torch.cat((sequence, chosen), dim=-1)
    return sequence


# Cost without a cache: step t re-runs the forward pass over all t tokens, so
# attention alone is O(t^2 * d) per step, summing to O(n^3 * d) for a final
# sequence length n.

使用KV Cache的生成

def generate_with_cache(model, prompt_tokens, max_new_tokens):
    """Greedy decoding WITH a KV cache.

    Prefill: one forward pass over the whole prompt produces the first
    ``past_key_values``. Decode: every later step feeds only the newest token
    plus the cache, so earlier tokens' keys/values are never recomputed.

    Args:
        model: callable accepting ``input_ids``, ``past_key_values=`` and
            ``use_cache=True`` and returning an object with ``.logits``
            (batch, seq, vocab) and ``.past_key_values`` -- presumably a
            Hugging Face-style causal LM; confirm for other models.
        prompt_tokens: token-id tensor, assumed (batch, seq). Because each
            generated id is read with ``.item()``, batch size must be 1.
        max_new_tokens: number of tokens to generate; ``<= 0`` returns ``[]``.

    Returns:
        ``list[int]`` with only the generated token ids (the prompt is NOT
        included, unlike ``generate_without_cache``).
    """
    if max_new_tokens <= 0:
        # Previously one token was still generated in this case.
        return []

    generated_tokens = []
    with torch.no_grad():
        # Prefill over the full prompt; this builds the initial cache.
        outputs = model(prompt_tokens, use_cache=True)
        # Slicing with -1: keeps a (batch, 1) shape, ready to be fed back in.
        next_token = outputs.logits[:, -1:].argmax(dim=-1)
        generated_tokens.append(next_token.item())

        # Decode: only the newest token is processed; the cache supplies history.
        for _ in range(max_new_tokens - 1):
            outputs = model(
                input_ids=next_token,
                past_key_values=outputs.past_key_values,
                use_cache=True,
            )
            next_token = outputs.logits[:, -1:].argmax(dim=-1)
            generated_tokens.append(next_token.item())

    return generated_tokens


# Benefit: the prompt is processed once (prefill; attention O(n_prompt^2 * d)).
# Each decode step then projects a single new token, and its attention over the
# cached keys costs O(t * d) -- linear in the current length t, not O(1).

实现细节

手动实现KV Cache

class KVCache:
    """Fixed-capacity, preallocated KV cache for a single attention layer.

    Buffers have shape (batch_size, num_heads, max_seq_len, head_dim). New
    keys/values are written starting at ``current_len`` along dim 2 and the
    filled prefix is returned. The returned tensors are views into the
    preallocated buffers, not copies.
    """

    def __init__(self, max_seq_len, num_heads, head_dim, dtype=torch.float16,
                 device=None, batch_size=1):
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.head_dim = head_dim

        # Preallocate once so update() never reallocates.
        shape = (batch_size, num_heads, max_seq_len, head_dim)
        self.k_cache = torch.zeros(shape, dtype=dtype, device=device)
        self.v_cache = torch.zeros(shape, dtype=dtype, device=device)
        self.current_len = 0

    def update(self, key, value):
        """Append ``key``/``value`` along the sequence axis (dim 2).

        Args:
            key, value: tensors shaped like (batch_size, num_heads, seq_len,
                head_dim); they are cast to the cache dtype on assignment.

        Returns:
            (k, v): views of the cache covering all tokens stored so far.

        Raises:
            RuntimeError: if the new tokens would exceed ``max_seq_len``.
                Without this check, a write at full capacity silently
                broadcasts into an empty slice and the data is lost.
        """
        seq_len = key.shape[2]
        end = self.current_len + seq_len
        if end > self.max_seq_len:
            raise RuntimeError(
                f"KVCache overflow: {self.current_len} cached + {seq_len} new "
                f"tokens exceeds max_seq_len={self.max_seq_len}"
            )

        self.k_cache[:, :, self.current_len:end, :] = key
        self.v_cache[:, :, self.current_len:end, :] = value
        self.current_len = end

        return self.get()

    def get(self):
        """Return (k, v) views of the filled prefix of the cache."""
        return (self.k_cache[:, :, :self.current_len, :],
                self.v_cache[:, :, :self.current_len, :])

与模型集成

class CachedAttention(nn.Module):
    """Multi-head self-attention layer that can reuse cached keys/values.

    ``forward`` takes ``x`` of shape (batch, seq, hidden). When
    ``past_key_values=(past_k, past_v)`` is given, the new keys/values are
    appended to it along the sequence axis. With ``use_cache=True`` the
    concatenated (k, v) are returned so the caller can pass them back on the
    next step.

    With ``causal=True`` (the default, required for autoregressive
    generation), every query attends only to keys at the same or earlier
    absolute positions. This also holds for multi-token chunks processed on
    top of an existing cache.
    """

    def __init__(self, hidden_size, num_heads, causal=True):
        super().__init__()
        self.num_heads = num_heads
        # hidden_size is assumed divisible by num_heads (view() below needs it).
        self.head_dim = hidden_size // num_heads
        self.causal = causal

        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

        # Not read by forward(); caches are passed in and returned explicitly.
        self.cache = None

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (batch, seq, hidden) -> (batch, heads, seq, head_dim)."""
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def _causal_mask(self, seq_len, total_len, device):
        """Boolean (seq_len, total_len) mask; True marks keys a query may attend to.

        Query i of the current chunk sits at absolute position
        total_len - seq_len + i and may see keys 0 .. that position.
        """
        past_len = total_len - seq_len
        q_pos = torch.arange(past_len, total_len, device=device).unsqueeze(1)
        k_pos = torch.arange(total_len, device=device).unsqueeze(0)
        return k_pos <= q_pos

    def forward(self, x, use_cache=False, past_key_values=None):
        """Run attention over ``x``, optionally extending a KV cache.

        Returns:
            (output, present_key_values). The output has shape (batch, seq,
            hidden). present_key_values is None unless ``use_cache``.
        """
        batch_size, seq_len, _ = x.shape

        q = self._split_heads(self.q_proj(x), batch_size, seq_len)
        k = self._split_heads(self.k_proj(x), batch_size, seq_len)
        v = self._split_heads(self.v_proj(x), batch_size, seq_len)

        # Always attend over the provided history, even if the caller does
        # not want the updated cache returned.
        if past_key_values is not None:
            past_k, past_v = past_key_values
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        present_key_values = (k, v) if use_cache else None

        # A single new token may attend to the whole cache, so no mask is needed.
        attn_mask = None
        if self.causal and seq_len > 1:
            attn_mask = self._causal_mask(seq_len, k.shape[2], x.device)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask
        )
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        return self.o_proj(attn_output), present_key_values

内存优化

内存占用计算

def calculate_kv_cache_memory(num_layers, num_heads, head_dim, seq_len,
                              dtype=torch.float16, batch_size=1):
    """Estimate the memory occupied by a full KV cache.

    Per layer the cache holds 2 (K and V) * batch_size * num_heads * seq_len *
    head_dim elements. For GQA/MQA models pass the number of KEY/VALUE heads
    as ``num_heads``.

    Args:
        num_layers: number of transformer layers.
        num_heads: number of K/V heads per layer.
        head_dim: size of each head.
        seq_len: number of cached tokens per sequence.
        dtype: element dtype of the cache. The byte size is taken from torch,
            so bfloat16, int8, float32, etc. are all sized correctly.
        batch_size: number of sequences cached at once (default 1).

    Returns:
        dict with "per_layer_mb", "total_mb" (MiB, 1024**2 bytes) and
        "total_gb" (GiB, 1024**3 bytes).
    """
    bytes_per_element = torch.empty((), dtype=dtype).element_size()
    per_layer = 2 * batch_size * num_heads * head_dim * seq_len * bytes_per_element

    total = num_layers * per_layer

    mib = 1024 * 1024
    return {
        "per_layer_mb": per_layer / mib,
        "total_mb": total / mib,
        "total_gb": total / mib / 1024,
    }

# Example: LLaMA-7B (32 layers, 32 heads, head_dim 128), 2048 tokens, fp16,
# batch 1 -> 2 * 32 * 128 * 2048 * 2 B = 32 MiB per layer, 1.00 GiB in total.
memory = calculate_kv_cache_memory(
    num_layers=32,
    num_heads=32,
    head_dim=128,
    seq_len=2048
)
print(f"KV Cache内存: {memory['total_gb']:.2f} GB")

量化缓存

def _fake_quantize(tensor, bits):
    """Quantize one tensor to unsigned ``bits``-bit codes and dequantize to float32."""
    levels = 2 ** bits - 1
    x = tensor.float()  # do the arithmetic in float32 to avoid fp16 rounding
    x_min, x_max = x.min(), x.max()
    if x_max == x_min:
        # Constant tensor: the range is zero and dividing by the scale would
        # produce NaN; it is exactly representable as-is.
        return x.clone()

    scale = (x_max - x_min) / levels
    # Round to the nearest level and clamp into the unsigned code range.
    codes = torch.round((x - x_min) / scale).clamp_(0, levels)
    # Codes live in [0, 2**bits - 1]. Signed int8 would overflow above 127,
    # so use uint8 for <= 8 bits and int32 for wider codes.
    storage_dtype = torch.uint8 if bits <= 8 else torch.int32
    codes = codes.to(storage_dtype)

    return codes.float() * scale + x_min


def quantize_kv_cache(k_cache, v_cache, bits=8):
    """Simulate per-tensor asymmetric (min/max) quantization of a KV cache.

    Each tensor is mapped to ``bits``-bit unsigned integer codes with one
    scale and offset, then immediately dequantized. The result therefore
    shows the precision loss of quantization but does NOT itself save
    memory. A real implementation would keep the integer codes plus
    (scale, min) and dequantize lazily.

    Args:
        k_cache, v_cache: floating-point tensors of any shape.
        bits: code width, 1..16 (default 8).

    Returns:
        (k_dequantized, v_dequantized) as float32 tensors of the input shapes.

    Raises:
        ValueError: if ``bits`` is outside [1, 16].
    """
    if not 1 <= bits <= 16:
        raise ValueError(f"bits must be in [1, 16], got {bits}")
    return _fake_quantize(k_cache, bits), _fake_quantize(v_cache, bits)

分页KV Cache(PagedAttention)

class PagedKVCache:
    """Block-based ("paged") KV cache bookkeeping, in the spirit of PagedAttention.

    Storage is one pool of ``num_blocks`` fixed-size blocks, each holding
    ``block_size`` tokens. A sequence owns a list of block ids (its block
    table) instead of one contiguous buffer, so memory is claimed block by
    block. Only allocation and freeing are implemented here; writing tokens
    into blocks is not shown.
    """

    def __init__(self, block_size=16, num_blocks=1000, num_heads=32, head_dim=128,
                 dtype=torch.float32, device=None):
        self.block_size = block_size
        self.num_blocks = num_blocks

        # Preallocate the whole pool. NOTE: with the defaults this is
        # 1000*32*16*128 float32 values = 250 MiB per pool (500 MiB for K+V).
        shape = (num_blocks, num_heads, block_size, head_dim)
        self.k_blocks = torch.zeros(shape, dtype=dtype, device=device)
        self.v_blocks = torch.zeros(shape, dtype=dtype, device=device)

        # Ids of blocks not owned by any sequence; allocation takes from the end.
        self.free_blocks = list(range(num_blocks))

        # sequence_id -> list of block ids, in allocation order.
        self.block_tables = {}

    def allocate_block(self, sequence_id):
        """Give one free block to ``sequence_id`` and return its id.

        Raises:
            MemoryError: if every block is already in use.
        """
        if not self.free_blocks:
            raise MemoryError("No free blocks available")

        block_id = self.free_blocks.pop()
        self.block_tables.setdefault(sequence_id, []).append(block_id)
        return block_id

    def free_sequence(self, sequence_id):
        """Return all blocks of ``sequence_id`` to the free list (no-op if unknown)."""
        blocks = self.block_tables.pop(sequence_id, None)
        if blocks is not None:
            self.free_blocks.extend(blocks)

vLLM中的KV Cache

from vllm import LLM, SamplingParams

# vLLM manages the KV cache itself with PagedAttention; these engine options
# size it. block_size is the number of tokens stored per KV-cache block.
engine_options = dict(
    model="meta-llama/Llama-2-7b-hf",
    max_model_len=4096,           # maximum context length per sequence
    gpu_memory_utilization=0.9,   # fraction of GPU memory vLLM may reserve
    block_size=16,                # tokens per KV-cache block
    max_num_batched_tokens=8192,  # token budget per scheduler iteration
    max_num_seqs=256,             # maximum number of concurrent sequences
)
llm = LLM(**engine_options)

# Cache blocks are allocated and freed automatically during generation.
prompts = ["Hello!" * 100, "Hi!" * 50]
sampling = SamplingParams(max_tokens=256)
outputs = llm.generate(prompts, sampling)

性能优化

预分配策略

def preallocate_kv_cache(model, max_seq_len, dtype=torch.float16, batch_size=1):
    """Preallocate zero-filled K and V buffers for ONE attention layer.

    The buffer shape is (batch_size, num_kv_heads, max_seq_len, head_dim).
    ``num_kv_heads`` comes from ``config.num_key_value_heads`` when present,
    because GQA/MQA models store fewer K/V heads than query heads; otherwise
    it falls back to ``num_attention_heads``. ``head_dim`` comes from
    ``config.head_dim`` when set, else ``hidden_size // num_attention_heads``.
    Call this once per layer (``config.num_hidden_layers`` times) to cover a
    whole model.

    Args:
        model: object exposing ``.config`` and ``.device`` -- presumably a
            Hugging Face model; confirm for other model types.
        max_seq_len: capacity in tokens.
        dtype: buffer dtype (default float16).
        batch_size: number of sequences (default 1).

    Returns:
        (k_cache, v_cache) zero tensors on ``model.device``.
    """
    config = model.config
    num_heads = config.num_attention_heads
    num_kv_heads = getattr(config, "num_key_value_heads", None) or num_heads
    head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_heads

    shape = (batch_size, num_kv_heads, max_seq_len, head_dim)
    k_cache = torch.zeros(shape, dtype=dtype, device=model.device)
    v_cache = torch.zeros(shape, dtype=dtype, device=model.device)

    return k_cache, v_cache

动态扩展

class DynamicKVCache:
    """Single-layer K/V buffers whose capacity grows geometrically on demand.

    Buffers have shape (1, num_heads, capacity, head_dim). They are allocated
    lazily by the first ``ensure_capacity`` call. ``current_size`` is the
    buffer capacity in tokens, not the number of tokens stored.
    """

    def __init__(self, initial_size=1024, growth_factor=2, num_heads=32,
                 head_dim=128, dtype=torch.float32, device=None):
        self.initial_size = initial_size
        self.growth_factor = growth_factor
        self.current_size = initial_size

        # Buffer geometry and placement used by _expand().
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dtype = dtype
        self.device = device

        self.k_cache = None
        self.v_cache = None

    def ensure_capacity(self, required_size):
        """Make sure the buffers can hold at least ``required_size`` tokens.

        If too small, they grow to max(required_size, current_size *
        growth_factor). Otherwise, on first use, they are allocated at
        ``current_size``.
        """
        if self.current_size < required_size:
            new_size = max(required_size, int(self.current_size * self.growth_factor))
            self._expand(new_size)
        elif self.k_cache is None:
            self._expand(self.current_size)

    def _expand(self, new_size):
        """Allocate ``new_size``-token buffers and copy any existing contents."""
        shape = (1, self.num_heads, new_size, self.head_dim)
        new_k = torch.zeros(shape, dtype=self.dtype, device=self.device)
        new_v = torch.zeros(shape, dtype=self.dtype, device=self.device)

        if self.k_cache is not None:
            old_len = self.k_cache.shape[2]
            new_k[:, :, :old_len, :] = self.k_cache
            new_v[:, :, :old_len, :] = self.v_cache

        self.k_cache = new_k
        self.v_cache = new_v
        self.current_size = new_size

常见问题

内存溢出

# Illustrative snippet: max_seq_len, available_memory, per_token_memory,
# k_cache and v_cache are placeholders that are not defined in this file.

# Fix 1: cap the maximum sequence length to what fits in free memory.
# (per_token_memory presumably = bytes of K+V per token across all layers -- confirm.)
max_seq_len = min(max_seq_len, available_memory // per_token_memory)

# Fix 2: quantize the cache. quantize_kv_cache returns dequantized tensors,
# so this line alone only simulates the precision loss; real savings require
# keeping the integer codes instead.
k_cache, v_cache = quantize_kv_cache(k_cache, v_cache, bits=8)

# Fix 3: use a paged cache. Note that the default arguments preallocate a
# pool of num_blocks=1000 blocks up front.
paged_cache = PagedKVCache(block_size=16)

性能下降

# Diagnostic: cache hit rate
def calculate_cache_hit_rate(cache_access_pattern, cache=None):
    """Return the fraction of accesses whose key is present in ``cache``.

    Args:
        cache_access_pattern: iterable of accessed keys (any iterable; it is
            consumed exactly once).
        cache: container of cached keys supporting ``in``. It must be passed
            explicitly; ``None`` is treated as an empty cache.

    Returns:
        hits / total, or 0 when there are no accesses.
    """
    accesses = list(cache_access_pattern)
    if not accesses:
        return 0
    cached = () if cache is None else cache
    hits = sum(1 for key in accesses if key in cached)
    return hits / len(accesses)

KV Cache是Transformer推理的核心优化技术,理解其原理对于构建高效LLM系统至关重要。