KV Cache:自回归生成加速器
---
title: "KV Cache:自回归生成加速器"
description: "深入理解KV Cache的工作原理、实现方法和内存优化策略"
tags: ["KV Cache", "自回归生成", "推理优化", "内存管理"]
category: "llm"
icon: "🧠"
---
KV Cache:自回归生成加速器
KV Cache简介
KV Cache(Key-Value Cache)是加速Transformer自回归生成的关键技术。它缓存已计算的Key和Value张量,使得生成每个新token时无需再为之前所有token重复计算Key/Value投影,只需对新token做一次投影并与缓存拼接,从而显著提升推理速度。
KV Cache的核心价值:
- 避免重复计算:已计算的K/V无需重新计算
- 降低单步复杂度:每生成一个token,注意力计算从O(n²)降低到O(n)(n为当前序列长度)
- 流式生成:支持逐token生成
工作原理
无KV Cache的生成
import torch
import torch.nn as nn
def generate_without_cache(model, prompt_tokens, max_new_tokens):
    """Greedy decoding WITHOUT a KV cache (baseline for comparison).

    Every step feeds the whole sequence produced so far back through the
    model, so the keys/values of all earlier tokens are recomputed each time.

    Args:
        model: callable returning an object with a ``logits`` attribute; the
            last position is read with ``[:, -1, :]`` (presumably shaped
            (batch, seq, vocab) — confirm for your model).
        prompt_tokens: token-id tensor; new ids are appended on the last dim.
        max_new_tokens: number of greedy decoding steps.

    Returns:
        A new tensor: the prompt followed by the generated token ids.
    """
    sequence = prompt_tokens.clone()
    for _step in range(max_new_tokens):
        with torch.no_grad():
            last_logits = model(sequence).logits[:, -1, :]
        # Greedy choice: the highest-scoring token at the final position.
        chosen = last_logits.argmax(dim=-1, keepdim=True)
        sequence = torch.cat((sequence, chosen), dim=-1)
    return sequence
# Drawback: every step re-runs attention over the full prefix, O(n^2 * d) per step.
使用KV Cache的生成
def generate_with_cache(model, prompt_tokens, max_new_tokens):
    """Greedy decoding WITH a KV cache (batch size 1).

    Prefill runs the model once over the whole prompt with ``use_cache=True``
    and keeps the returned ``past_key_values``. Every later step feeds only the
    newest token plus that cache, so earlier tokens' K/V are not recomputed.

    Args:
        model: callable accepting ``input_ids``, ``past_key_values`` and
            ``use_cache`` and returning an object with ``logits`` and
            ``past_key_values`` attributes (Hugging Face style).
        prompt_tokens: token-id tensor for the prompt. Generated ids are
            collected with ``.item()``, so this must hold a single sequence.
        max_new_tokens: number of tokens to generate.

    Returns:
        list[int]: the newly generated token ids (the prompt is NOT included);
        an empty list when ``max_new_tokens <= 0``.
    """
    # Previously one token was produced even when zero were requested.
    if max_new_tokens <= 0:
        return []

    with torch.no_grad():
        # Prefill: a single pass over the full prompt builds the cache.
        outputs = model(prompt_tokens, use_cache=True)
        next_token = outputs.logits[:, -1:].argmax(dim=-1)  # keeps a (batch, 1) shape
        generated_tokens = [next_token.item()]

        # Decode: only the newest token is fed; past K/V come from the cache.
        for _ in range(max_new_tokens - 1):
            outputs = model(
                input_ids=next_token,
                past_key_values=outputs.past_key_values,
                use_cache=True,
            )
            next_token = outputs.logits[:, -1:].argmax(dim=-1)
            generated_tokens.append(next_token.item())
    return generated_tokens
# Advantage: prefill is one pass over the prompt; each decode step projects only the
# new token and attends over the cached keys, i.e. O(n * d) per step instead of O(n^2 * d).
实现细节
手动实现KV Cache
class KVCache:
    """Pre-allocated, append-only key/value cache for one attention layer.

    Buffers have shape (batch_size, num_heads, max_seq_len, head_dim). ``update``
    writes new entries along dim 2 (the sequence axis) and returns views of
    everything cached so far.
    """

    def __init__(self, max_seq_len, num_heads, head_dim, dtype=torch.float16,
                 batch_size=1, device=None):
        """Allocate zeroed K/V buffers.

        Args:
            max_seq_len: capacity in token positions.
            num_heads: number of attention heads stored.
            head_dim: per-head feature size.
            dtype: buffer dtype (incoming keys/values are cast on write).
            batch_size: leading batch dimension (previously hard-coded to 1).
            device: device for the buffers; ``None`` uses torch's default.
        """
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.head_dim = head_dim
        shape = (batch_size, num_heads, max_seq_len, head_dim)
        self.k_cache = torch.zeros(shape, dtype=dtype, device=device)
        self.v_cache = torch.zeros(shape, dtype=dtype, device=device)
        self.current_len = 0

    def update(self, key, value):
        """Append ``key``/``value`` along dim 2 and return the filled prefix.

        Args:
            key, value: tensors whose dim 2 is the number of new positions and
                whose other dims match the cache buffers.

        Returns:
            (k, v) views covering positions ``[0, current_len)``.

        Raises:
            RuntimeError: if the new entries would exceed ``max_seq_len``; the
                cache is left unchanged in that case.
        """
        seq_len = key.shape[2]
        end = self.current_len + seq_len
        if end > self.max_seq_len:
            raise RuntimeError(
                f"KVCache overflow: need {end} positions, capacity is {self.max_seq_len}"
            )
        self.k_cache[:, :, self.current_len:end, :] = key
        self.v_cache[:, :, self.current_len:end, :] = value
        self.current_len = end
        return self.get()

    def get(self):
        """Return (k, v) views of the positions written so far."""
        return (self.k_cache[:, :, :self.current_len, :],
                self.v_cache[:, :, :self.current_len, :])

    def reset(self):
        """Forget all cached positions (buffers are reused, not reallocated)."""
        self.current_len = 0
与模型集成
class CachedAttention(nn.Module):
    """Multi-head self-attention that can consume and return a KV cache.

    Incremental decoding with a cache only matches a full forward pass if
    attention is causal, so a causal mask is applied by default.
    """

    def __init__(self, hidden_size, num_heads, causal=True):
        """
        Args:
            hidden_size: model width; must be divisible by ``num_heads``.
            num_heads: number of attention heads.
            causal: mask out future positions (required for autoregressive
                decoding; the previous version never masked).

        Raises:
            ValueError: if ``hidden_size`` is not divisible by ``num_heads``.
        """
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.causal = causal
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)
        self.cache = None  # not used by forward(); kept for compatibility

    def _split_heads(self, t, batch_size, seq_len):
        """(batch, seq, hidden) -> (batch, heads, seq, head_dim)."""
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, use_cache=False, past_key_values=None):
        """Attend from the tokens in ``x`` over cached + new keys/values.

        Args:
            x: (batch, seq_len, hidden_size) hidden states for the NEW tokens.
            use_cache: when True, also return the concatenated (k, v).
            past_key_values: optional (k, v), each (batch, heads, past_len,
                head_dim), from an earlier call. It is used whenever given,
                regardless of ``use_cache``.

        Returns:
            (output, present_key_values): output is (batch, seq_len, hidden);
            present_key_values is (k, v) covering past + new positions, or
            None when ``use_cache`` is False.
        """
        batch_size, seq_len, _ = x.shape

        q = self._split_heads(self.q_proj(x), batch_size, seq_len)
        k = self._split_heads(self.k_proj(x), batch_size, seq_len)
        v = self._split_heads(self.v_proj(x), batch_size, seq_len)

        if past_key_values is not None:
            past_k, past_v = past_key_values
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        present_key_values = (k, v) if use_cache else None

        attn_mask = None
        if self.causal and seq_len > 1:
            kv_len = k.shape[2]
            # Query i sits at absolute position (kv_len - seq_len + i) and may see
            # keys 0..that position. This is bottom-right aligned, which SDPA's
            # is_causal flag does not do when a past cache is present.
            attn_mask = torch.ones(
                seq_len, kv_len, dtype=torch.bool, device=x.device
            ).tril(diagonal=kv_len - seq_len)
        # With a single new query, every cached key is in the past: no mask needed.

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask
        )
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.o_proj(attn_output), present_key_values
内存优化
内存占用计算
def calculate_kv_cache_memory(num_layers, num_heads, head_dim, seq_len,
                              dtype=torch.float16, batch_size=1):
    """Estimate KV cache memory for a full model.

    Args:
        num_layers: number of transformer layers.
        num_heads: number of K/V heads per layer.
        head_dim: per-head feature size.
        seq_len: cached sequence length.
        dtype: element dtype. Its real byte size is used, so bf16 counts as 2
            bytes and int8 as 1 (the old "fp16 else 4" rule got these wrong).
        batch_size: number of sequences cached at once.

    Returns:
        dict with ``per_layer_mb``, ``total_mb`` and ``total_gb``, computed
        with binary units (1 MB = 1024 * 1024 bytes).
    """
    bytes_per_element = torch.empty((), dtype=dtype).element_size()
    # Factor 2: every layer stores one K and one V tensor.
    per_layer = 2 * batch_size * num_heads * head_dim * seq_len * bytes_per_element
    total = num_layers * per_layer
    return {
        "per_layer_mb": per_layer / 1024 / 1024,
        "total_mb": total / 1024 / 1024,
        "total_gb": total / 1024 / 1024 / 1024,
    }
# Example: LLaMA-7B dimensions (32 layers, 32 heads, head_dim 128) at 2048 tokens,
# using the function's default dtype.
llama_7b_dims = dict(num_layers=32, num_heads=32, head_dim=128, seq_len=2048)
memory = calculate_kv_cache_memory(**llama_7b_dims)
print(f"KV Cache内存: {memory['total_gb']:.2f} GB")
量化缓存
def quantize_kv_cache(k_cache, v_cache, bits=8):
    """Simulate min-max quantization of a KV cache ("fake quantization").

    Each tensor is mapped onto the unsigned integer grid ``[0, 2**bits - 1]``
    using its own global min/max, rounded, then immediately dequantized. The
    returned tensors therefore carry the quantization error but occupy as
    much memory as the inputs.

    Args:
        k_cache, v_cache: tensors to quantize (each gets its own scale).
        bits: code width, 1..16.

    Returns:
        (k_dequantized, v_dequantized) in the input dtype (float32 for
        non-floating inputs).

    Raises:
        ValueError: if ``bits`` is outside 1..16.
    """
    if not 1 <= bits <= 16:
        raise ValueError(f"bits must be in 1..16, got {bits}")
    return _fake_quantize(k_cache, bits), _fake_quantize(v_cache, bits)


def _fake_quantize(x, bits):
    """Quantize *x* to unsigned *bits*-bit codes and dequantize it back."""
    qmax = 2 ** bits - 1
    x32 = x.float()
    x_min, x_max = x32.min(), x32.max()
    scale = (x_max - x_min) / qmax
    # A constant tensor has zero range; any non-zero scale maps it exactly to code 0.
    scale = torch.where(scale > 0, scale, torch.ones_like(scale))
    codes = torch.round((x32 - x_min) / scale).clamp_(0, qmax)
    # Codes are unsigned: uint8 holds [0, 255], whereas int8 wrapped codes above 127.
    codes = codes.to(torch.uint8 if bits <= 8 else torch.int32)
    dequantized = codes.float() * scale + x_min
    out_dtype = x.dtype if x.is_floating_point() else torch.float32
    return dequantized.to(out_dtype)
分页KV Cache(PagedAttention)
class PagedKVCache:
    """Block-based KV cache bookkeeping, in the spirit of PagedAttention.

    K/V storage is a pool of ``num_blocks`` fixed-size blocks, each holding
    ``block_size`` token positions for ``num_heads`` heads. Each sequence owns a
    list of block ids (its block table) instead of one contiguous buffer.
    """

    def __init__(self, block_size=16, num_blocks=1000, num_heads=32, head_dim=128,
                 dtype=None, device=None):
        """Pre-allocate the block pool.

        Args:
            block_size: token positions per block.
            num_blocks: total blocks in the pool.
            num_heads: heads stored per block.
            head_dim: per-head feature size.
            dtype: pool dtype; ``None`` uses torch's default (float32 unless changed).
            device: pool device; ``None`` uses torch's default.
        """
        self.block_size = block_size
        self.num_blocks = num_blocks
        self.num_heads = num_heads
        self.head_dim = head_dim
        shape = (num_blocks, num_heads, block_size, head_dim)
        self.k_blocks = torch.zeros(shape, dtype=dtype, device=device)
        self.v_blocks = torch.zeros(shape, dtype=dtype, device=device)
        # Ids of unused blocks; pop() takes from the end.
        self.free_blocks = list(range(num_blocks))
        # sequence_id -> list of block ids, in allocation order.
        self.block_tables = {}

    def allocate_block(self, sequence_id):
        """Take one free block, append it to ``sequence_id``'s table, return its id.

        Raises:
            MemoryError: when no free blocks remain (no table entry is created).
        """
        if not self.free_blocks:
            raise MemoryError("No free blocks available")
        block_id = self.free_blocks.pop()
        self.block_tables.setdefault(sequence_id, []).append(block_id)
        return block_id

    def free_sequence(self, sequence_id):
        """Return all of ``sequence_id``'s blocks to the free list.

        Unknown ids are ignored. Block contents are not cleared.
        """
        self.free_blocks.extend(self.block_tables.pop(sequence_id, ()))
vLLM中的KV Cache
from vllm import LLM, SamplingParams

# Engine configuration; vLLM manages the KV cache with PagedAttention.
engine_kwargs = dict(
    model="meta-llama/Llama-2-7b-hf",
    max_model_len=4096,
    gpu_memory_utilization=0.9,
    block_size=16,                # KV-cache block size
    max_num_batched_tokens=8192,  # maximum tokens per batch
    max_num_seqs=256,             # maximum concurrent sequences
)
llm = LLM(**engine_kwargs)

# The KV cache is managed automatically during generation.
prompts = ["Hello!" * 100, "Hi!" * 50]
sampling_params = SamplingParams(max_tokens=256)
outputs = llm.generate(prompts, sampling_params)
性能优化
预分配策略
def preallocate_kv_cache(model, max_seq_len, batch_size=1, dtype=torch.float16):
    """Allocate zeroed K and V buffers for ONE attention layer of *model*.

    Callers that cache every layer need one (k, v) pair per layer
    (``model.config.num_hidden_layers`` of them).

    Args:
        model: object exposing ``config`` (``num_attention_heads``,
            ``hidden_size`` and optionally ``num_key_value_heads`` /
            ``head_dim``) and ``device``.
        max_seq_len: number of token positions to reserve.
        batch_size: leading batch dimension (default 1, as before).
        dtype: buffer dtype (default float16, as before).

    Returns:
        (k_cache, v_cache), each shaped
        (batch_size, num_kv_heads, max_seq_len, head_dim), on ``model.device``.
    """
    config = model.config
    num_heads = config.num_attention_heads
    # Grouped-/multi-query attention models cache fewer K/V heads than query
    # heads; configs that expose num_key_value_heads are sized by it.
    num_kv_heads = getattr(config, "num_key_value_heads", None) or num_heads
    # Prefer an explicit head_dim when the config provides one.
    head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_heads

    shape = (batch_size, num_kv_heads, max_seq_len, head_dim)
    k_cache = torch.zeros(shape, dtype=dtype, device=model.device)
    v_cache = torch.zeros(shape, dtype=dtype, device=model.device)
    return k_cache, v_cache
动态扩展
class DynamicKVCache:
    """Single-layer KV cache whose sequence capacity grows on demand.

    Buffers are shaped (1, num_heads, capacity, head_dim) and are allocated
    lazily on the first ``ensure_capacity`` call. After that they grow by
    ``growth_factor`` (or straight to the required size, if larger),
    preserving existing contents.
    """

    def __init__(self, initial_size=1024, growth_factor=2, num_heads=32, head_dim=128,
                 dtype=None, device=None):
        """
        Args:
            initial_size: minimum capacity of the first allocation.
            growth_factor: multiplier applied to the capacity when growing.
            num_heads, head_dim: buffer dimensions. The previous version never
                set these, so every expansion raised AttributeError. The
                defaults match PagedKVCache's defaults.
            dtype, device: buffer dtype/device; ``None`` uses torch's defaults
                (matching the old ``torch.zeros`` calls).
        """
        self.initial_size = initial_size
        self.growth_factor = growth_factor
        self.current_size = initial_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dtype = dtype
        self.device = device
        self.k_cache = None
        self.v_cache = None

    def ensure_capacity(self, required_size):
        """Make sure the buffers exist and hold at least ``required_size`` positions."""
        if self.k_cache is None:
            # First use: the old code skipped allocation whenever
            # required_size <= initial_size, leaving the buffers as None.
            self._expand(max(required_size, self.current_size))
        elif self.current_size < required_size:
            new_size = max(required_size, int(self.current_size * self.growth_factor))
            self._expand(new_size)

    def _expand(self, new_size):
        """Reallocate both buffers at ``new_size`` positions, copying old contents."""
        shape = (1, self.num_heads, new_size, self.head_dim)
        new_k = torch.zeros(shape, dtype=self.dtype, device=self.device)
        new_v = torch.zeros(shape, dtype=self.dtype, device=self.device)
        if self.k_cache is not None:
            old_len = self.k_cache.shape[2]
            new_k[:, :, :old_len, :] = self.k_cache
            new_v[:, :, :old_len, :] = self.v_cache
        self.k_cache = new_k
        self.v_cache = new_v
        self.current_size = new_size
常见问题
内存溢出
# Out-of-memory mitigations. Illustrative snippet: max_seq_len, available_memory,
# per_token_memory, k_cache and v_cache must already be defined by the caller.
# Option 1: cap the maximum sequence length to what fits in the available memory.
max_seq_len = min(max_seq_len, available_memory // per_token_memory)
# Option 2: quantize the cached K/V to 8 bits. Note: quantize_kv_cache returns
# dequantized tensors, so this call alone does not reduce memory.
k_cache, v_cache = quantize_kv_cache(k_cache, v_cache, bits=8)
# Option 3: use a paged (block-based) cache with 16-token blocks.
paged_cache = PagedKVCache(block_size=16)
性能下降
# Checkpoint: cache hit rate.
def calculate_cache_hit_rate(cache_access_pattern, cache=None):
    """Return the fraction of accesses in *cache_access_pattern* found in *cache*.

    Args:
        cache_access_pattern: iterable of accessed keys; a one-shot iterator
            is fine (it is consumed exactly once, with no ``len()`` call).
        cache: any container supporting ``in``. The original read an
            undefined global named ``cache``. When this argument is omitted,
            a module-level ``cache`` is still used for backward compatibility.

    Returns:
        hits / total as a float, or 0 when there were no accesses.

    Raises:
        NameError: if no ``cache`` is passed and no module-level ``cache`` exists.
    """
    if cache is None:
        cache = globals().get("cache")
        if cache is None:
            raise NameError("calculate_cache_hit_rate() requires a `cache` argument")
    hits = total = 0
    for access in cache_access_pattern:
        total += 1
        if access in cache:
            hits += 1
    return hits / total if total > 0 else 0
KV Cache是Transformer推理的核心优化技术,理解其原理对于构建高效LLM系统至关重要。