投机解码:加速自回归生成
---
title: "投机解码:加速自回归生成"
description: "掌握投机解码的原理和实现,通过并行验证显著提升生成速度"
tags: ["投机解码", "Speculative Decoding", "并行解码", "推理加速"]
category: "llm"
icon: "🧠"
---
投机解码:加速自回归生成
投机解码简介
投机解码(Speculative Decoding)是一种加速自回归文本生成的技术。它使用一个小型"草稿模型"(Draft Model)快速生成多个候选token,然后用大型"目标模型"(Target Model)在一次前向传播中并行验证这些token。草稿token被接受得越多,目标模型每次前向传播产出的token就越多,所需的目标模型前向传播次数随之减少,从而显著加速生成过程。
投机解码的核心优势:
- 加速生成:通常可达2-3倍加速(具体取决于草稿模型的接受率以及两个模型的速度差)
- 无损质量:输出分布与仅用目标模型解码完全一致(贪心解码时输出逐token相同)
- 并行验证:一次前向传播验证多个token
- 灵活部署:可以使用任意大小的草稿模型(前提是与目标模型共享同一词表)
工作原理
标准自回归解码
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def standard_autoregressive(model, tokenizer, prompt, max_new_tokens=100):
    """Greedy one-token-per-step decoding: the baseline speculative decoding speeds up.

    Every step re-runs ``model`` over the whole sequence (no KV cache), appends
    the arg-max token, and stops early right after the EOS token is emitted
    (the EOS token itself is kept).

    Args:
        model: causal LM whose output exposes ``.logits``.
        tokenizer: provides ``encode``/``decode`` and ``eos_token_id``.
        prompt: text to continue.
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        The decoded prompt plus continuation.
    """
    ids = tokenizer.encode(prompt, return_tensors="pt")
    eos_id = tokenizer.eos_token_id
    for _step in range(max_new_tokens):
        with torch.no_grad():
            all_logits = model(ids).logits
        # Greedy pick at the last position, kept 2-D so it can be concatenated.
        chosen = all_logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ids = torch.cat((ids, chosen), dim=-1)
        if chosen.item() == eos_id:
            break
    return tokenizer.decode(ids[0])
投机解码
def speculative_decoding(target_model, draft_model, tokenizer,
                         prompt, max_new_tokens=100, gamma=5):
    """Greedy speculative decoding.

    Each round the draft model greedily proposes up to ``gamma`` tokens, then the
    target model scores prompt + proposals in a single forward pass. The longest
    prefix of proposals that matches the target's own greedy choice is kept,
    followed by exactly one token taken from the target: the correction at the
    first mismatch, or a bonus token when every proposal matched. The result
    therefore equals greedy decoding with ``target_model`` alone (up to arg-max
    tie-breaking).

    Args:
        target_model: large causal LM whose output exposes ``.logits``.
        draft_model: small causal LM sharing the target's vocabulary.
        tokenizer: provides ``encode``/``decode`` and ``eos_token_id``.
        prompt: text to continue.
        max_new_tokens: hard upper bound on the number of generated tokens.
        gamma: number of draft tokens proposed per round.

    Returns:
        The decoded prompt plus continuation; generation stops right after EOS.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    eos_id = tokenizer.eos_token_id
    generated_tokens = 0
    while generated_tokens < max_new_tokens:
        remaining = max_new_tokens - generated_tokens
        # Never draft more tokens than the remaining budget can hold.
        k = max(0, min(gamma, remaining))

        # 1. Draft model greedily proposes k tokens; candidate_ids ends up as
        #    input_ids followed by the k draft tokens (a 2-D (1, seq) tensor).
        candidate_ids = input_ids
        draft_tokens = []
        with torch.no_grad():
            for _ in range(k):
                draft_logits = draft_model(candidate_ids).logits[:, -1, :]
                draft_token = torch.argmax(draft_logits, dim=-1, keepdim=True)
                draft_tokens.append(draft_token.item())
                candidate_ids = torch.cat([candidate_ids, draft_token], dim=-1)

            # 2. One target pass over prompt + drafts. Position -k-1 predicts
            #    draft 0, ..., position -2 predicts draft k-1, and position -1
            #    predicts the bonus token after the last draft.
            target_logits = target_model(candidate_ids).logits[:, -k - 1:, :]
        target_tokens = torch.argmax(target_logits, dim=-1)[0].tolist()

        # 3. Accept the longest prefix on which draft and target agree.
        accepted_count = 0
        while (accepted_count < k
               and draft_tokens[accepted_count] == target_tokens[accepted_count]):
            accepted_count += 1

        # 4. Accepted drafts + one greedy target token (correction or bonus),
        #    clipped to the budget and cut right after EOS.
        new_tokens = draft_tokens[:accepted_count] + [target_tokens[accepted_count]]
        new_tokens = new_tokens[:remaining]
        hit_eos = eos_id is not None and eos_id in new_tokens
        if hit_eos:
            new_tokens = new_tokens[:new_tokens.index(eos_id) + 1]

        input_ids = torch.cat([
            input_ids,
            torch.tensor([new_tokens], dtype=input_ids.dtype, device=input_ids.device),
        ], dim=-1)
        generated_tokens += len(new_tokens)
        if hit_eos:
            break
    return tokenizer.decode(input_ids[0])
实现细节
简化实现
class SpeculativeDecoder:
    """Speculative decoder using speculative (rejection) sampling.

    The draft model samples ``gamma`` tokens; the target model scores them in a
    single forward pass. A draft token x drafted with probability q(x) is kept
    with probability min(1, p(x) / q(x)), where p is the target distribution.
    On the first rejection a replacement is sampled from norm(max(0, p - q)).
    If every draft token is kept, one bonus token is sampled from the target's
    distribution after the last draft. This makes the output distribution equal
    to sampling from the target model alone.

    NOTE: assumes draft and target share one tokenizer/vocabulary -- draft ids
    are fed to the target directly and ``p - q`` is taken element-wise.
    """

    def __init__(self, target_model, draft_model, tokenizer, gamma=5):
        self.target_model = target_model
        self.draft_model = draft_model
        self.tokenizer = tokenizer
        self.gamma = gamma  # draft tokens proposed per round
        self.target_model.eval()
        self.draft_model.eval()

    def generate(self, prompt, max_new_tokens=100):
        """Generate at most ``max_new_tokens`` tokens after ``prompt``.

        Generation stops right after the EOS token; tokens produced in the same
        round after EOS are dropped. Returns the decoded prompt + continuation.
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        eos_id = self.tokenizer.eos_token_id
        tokens_generated = 0
        while tokens_generated < max_new_tokens:
            # Draft phase (keep the full distributions for the residual).
            draft_tokens, draft_probs, draft_dists = self._draft_phase(
                input_ids, return_dists=True)
            # Verify phase: always yields at least one token.
            new_tokens = self._verify_phase(
                input_ids, draft_tokens, draft_probs, draft_dists).tolist()
            # Respect the token budget and cut everything after EOS.
            new_tokens = new_tokens[:max_new_tokens - tokens_generated]
            hit_eos = eos_id is not None and eos_id in new_tokens
            if hit_eos:
                new_tokens = new_tokens[:new_tokens.index(eos_id) + 1]
            input_ids = torch.cat(
                [input_ids, self._as_ids(new_tokens, input_ids)], dim=-1)
            tokens_generated += len(new_tokens)
            if hit_eos:
                break
        return self.tokenizer.decode(input_ids[0])

    @staticmethod
    def _as_ids(tokens, like):
        """Wrap a list of token ids as a (1, len) tensor with ``like``'s dtype/device."""
        return torch.tensor([tokens], dtype=like.dtype, device=like.device)

    def _draft_phase(self, input_ids, return_dists=False):
        """Sample ``self.gamma`` tokens autoregressively from the draft model.

        Returns ``(draft_tokens, draft_probs)``: the sampled ids and the draft
        probability q(x) of each. With ``return_dists=True`` a third element is
        returned: the full draft distribution (1-D tensor) at each step, which
        the rejection residual needs.
        """
        draft_tokens, draft_probs, draft_dists = [], [], []
        current_input = input_ids.clone()
        with torch.no_grad():
            for _ in range(self.gamma):
                logits = self.draft_model(current_input).logits[:, -1, :]
                probs = torch.softmax(logits, dim=-1)
                token = torch.multinomial(probs, 1).item()
                draft_tokens.append(token)
                draft_probs.append(probs[0, token].item())
                draft_dists.append(probs[0])
                current_input = torch.cat(
                    [current_input, self._as_ids([token], current_input)], dim=-1)
        if return_dists:
            return draft_tokens, draft_probs, draft_dists
        return draft_tokens, draft_probs

    def _verify_phase(self, input_ids, draft_tokens, draft_probs, draft_dists=None):
        """Verify draft tokens with the target model via rejection sampling.

        Returns a 1-D LongTensor: the accepted draft prefix followed by exactly
        one target-sampled token (the replacement at the first rejection, or a
        bonus token when every draft was accepted).

        ``draft_dists`` are the full draft distributions from ``_draft_phase``;
        when omitted they are recomputed with one draft-model pass.
        """
        n = len(draft_tokens)
        candidate_ids = torch.cat(
            [input_ids, self._as_ids(draft_tokens, input_ids)], dim=-1)
        with torch.no_grad():
            # Positions -n-1 .. -2 predict the n drafts; position -1 the bonus token.
            target_logits = self.target_model(candidate_ids).logits[:, -n - 1:, :]
            if draft_dists is None and n:
                draft_logits = self.draft_model(candidate_ids).logits[0, -n - 1:-1, :]
                draft_dists = list(torch.softmax(draft_logits, dim=-1))
        target_probs = torch.softmax(target_logits, dim=-1)[0]

        accepted_tokens = []
        for i, draft_token in enumerate(draft_tokens):
            target_prob = target_probs[i, draft_token].item()
            draft_prob = draft_probs[i]
            # Acceptance probability min(1, p/q); q == 0 cannot arise from sampling,
            # but guard it instead of dividing by zero.
            accept_ratio = 1.0 if draft_prob <= 0 else min(1.0, target_prob / draft_prob)
            if torch.rand(1).item() < accept_ratio:
                accepted_tokens.append(draft_token)
                continue
            # Rejected: resample from the residual distribution norm(max(0, p - q)).
            residual = torch.clamp(
                target_probs[i] - draft_dists[i].to(target_probs.device), min=0)
            total = residual.sum()
            adjusted_probs = residual / total if total > 0 else target_probs[i]
            accepted_tokens.append(torch.multinomial(adjusted_probs, 1).item())
            return torch.tensor(accepted_tokens, dtype=torch.long)

        # Every draft accepted: sample a bonus token from the target's next distribution.
        accepted_tokens.append(torch.multinomial(target_probs[n], 1).item())
        return torch.tensor(accepted_tokens, dtype=torch.long)
使用Hugging Face
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
def hf_speculative_decoding(target_model_name, draft_model_name, prompt,
                            max_new_tokens=256, gamma=5):
    """Run speculative decoding with two Hugging Face causal LMs.

    Args:
        target_model_name: model id or path of the large (target) model.
        draft_model_name: model id or path of the small (draft) model.
        prompt: text to continue.
        max_new_tokens: generation budget (defaults to the previous fixed 256).
        gamma: draft tokens proposed per round (defaults to the previous fixed 5).

    Returns:
        The text returned by ``SpeculativeDecoder.generate``.

    Warns:
        UserWarning if the two tokenizers have different sizes, which suggests
        the models do not share a vocabulary.
    """
    import warnings

    # Load models
    target_tokenizer = AutoTokenizer.from_pretrained(target_model_name)
    target_model = AutoModelForCausalLM.from_pretrained(target_model_name)
    draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_name)
    draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name)

    # Draft ids are fed to the target unchanged, so both models must use the
    # same vocabulary; differing tokenizer sizes are a strong hint they do not.
    if len(draft_tokenizer) != len(target_tokenizer):
        warnings.warn(
            f"draft tokenizer ({len(draft_tokenizer)} tokens) and target tokenizer "
            f"({len(target_tokenizer)} tokens) differ; speculative decoding "
            "requires a shared vocabulary",
            stacklevel=2,
        )

    decoder = SpeculativeDecoder(target_model, draft_model, target_tokenizer, gamma=gamma)
    return decoder.generate(prompt, max_new_tokens=max_new_tokens)
草稿模型选择
使用n-gram模型
class NgramDraftModel:
    """Word-level n-gram model that can serve as a training-free draft model.

    Predicts the next whitespace-delimited token from the previous ``n - 1``
    tokens using maximum-likelihood counts.
    """

    def __init__(self, n=3):
        if n < 1:
            raise ValueError(f"n must be >= 1, got {n}")
        self.n = n
        # Maps an (n-1)-token context tuple -> {next_token: count}.
        self.ngram_counts = {}

    def train(self, corpus):
        """Accumulate n-gram counts from an iterable of strings (split on whitespace)."""
        context_len = self.n - 1
        for text in corpus:
            tokens = text.split()
            for i in range(len(tokens) - context_len):
                context = tuple(tokens[i:i + context_len])
                next_token = tokens[i + context_len]
                counts = self.ngram_counts.setdefault(context, {})
                counts[next_token] = counts.get(next_token, 0) + 1

    def predict(self, context):
        """Return ``{token: probability}`` for the token that follows ``context``.

        ``context`` is a sequence of tokens; only its last ``n - 1`` are used.
        Returns None when the context is too short or was never seen in training.
        """
        context_len = self.n - 1
        if len(context) < context_len:
            return None
        # Slice from an explicit start index: ``context[-0:]`` would be the whole
        # sequence and break the unigram (n == 1) case.
        key = tuple(context[len(context) - context_len:])
        counts = self.ngram_counts.get(key)
        if not counts:
            return None
        total = sum(counts.values())
        return {token: count / total for token, count in counts.items()}
使用量化模型
from transformers import BitsAndBytesConfig
def create_quantized_draft_model(model_name):
    """Load ``model_name`` as a 4-bit (NF4) quantized causal LM for use as a draft model.

    Args:
        model_name: model id or path passed to ``from_pretrained``.

    Returns:
        The loaded, quantized model.
    """
    quant_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
    return AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quant_cfg)
性能优化
加速比分析
def analyze_speedup(target_model, draft_model, tokenizer, prompts,
                    gamma_values=(1, 3, 5, 7), max_new_tokens=100):
    """Measure speculative-decoding throughput for several draft lengths.

    Args:
        target_model, draft_model, tokenizer: forwarded to ``SpeculativeDecoder``.
        prompts: prompts to generate from for each gamma.
        gamma_values: draft lengths to try (a tuple, so the default is immutable).
        max_new_tokens: generation budget per prompt.

    Returns:
        One dict per gamma with ``gamma``, ``throughput`` (tokens/s) and
        ``time_per_token`` (s). Token counts are nominal
        (``len(prompts) * max_new_tokens``), so throughput is overstated when
        generation stops early at EOS. Empty ``prompts`` yields zeros.
    """
    import time  # imported locally so this snippet runs standalone

    nominal_tokens = len(prompts) * max_new_tokens
    results = []
    for gamma in gamma_values:
        decoder = SpeculativeDecoder(target_model, draft_model, tokenizer, gamma=gamma)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        for prompt in prompts:
            decoder.generate(prompt, max_new_tokens=max_new_tokens)
        total_time = time.perf_counter() - start_time
        results.append({
            "gamma": gamma,
            "throughput": nominal_tokens / total_time if total_time > 0 else 0.0,
            "time_per_token": total_time / nominal_tokens if nominal_tokens else 0.0,
        })
    return results
接受率优化
def optimize_acceptance_rate(target_model, draft_model, tokenizer, eval_data):
    """Estimate how often the target model accepts the draft model's tokens.

    For each prompt, one round of ``gamma`` draft tokens is sampled and scored by
    the target model in a single forward pass. Each draft token x contributes its
    acceptance probability min(1, p(x) / q(x)) (p = target, q = draft), so the
    result is the mean *expected* acceptance rate rather than a noisy sample.
    Higher values mean more draft tokens survive each target forward pass.

    Returns:
        Mean acceptance probability in [0, 1]; 0.0 if no tokens were drafted
        (e.g. ``eval_data`` is empty).
    """
    # One decoder is enough; only its draft phase is used.
    decoder = SpeculativeDecoder(target_model, draft_model, tokenizer)
    total_tokens = 0
    expected_accepted = 0.0
    for prompt in eval_data:
        input_ids = tokenizer.encode(prompt, return_tensors="pt")
        draft_tokens, draft_probs = decoder._draft_phase(input_ids)
        n = len(draft_tokens)
        if n == 0:
            continue
        candidate_ids = torch.cat([
            input_ids,
            torch.tensor([draft_tokens], dtype=input_ids.dtype, device=input_ids.device),
        ], dim=-1)
        with torch.no_grad():
            # Positions -n-1 .. -2 hold the target's predictions for the n drafts.
            target_logits = target_model(candidate_ids).logits[0, -n - 1:-1, :]
        target_probs = torch.softmax(target_logits, dim=-1)
        for i, (token, q) in enumerate(zip(draft_tokens, draft_probs)):
            p = target_probs[i, token].item()
            expected_accepted += 1.0 if q <= 0 else min(1.0, p / q)
            total_tokens += 1
    return expected_accepted / total_tokens if total_tokens else 0.0
与其他技术结合
# Speculative decoding + model warm-up (KV-cache reuse is NOT implemented here)
def optimized_speculative_decoding(target_model, draft_model, tokenizer, prompt,
                                   gamma=5, max_new_tokens=256):
    """Run one warm-up forward pass per model on ``prompt``, then speculative decoding.

    NOTE: the ``use_cache=True`` outputs of the warm-up passes are discarded.
    ``SpeculativeDecoder.generate`` re-encodes the prompt and calls each model on
    the full sequence without ``past_key_values``, so no KV cache is carried
    into generation. The passes only act as a warm-up (e.g. lazy kernel or
    allocator initialization -- device dependent). Real KV-cache reuse would
    require threading ``past_key_values`` through the draft and verify phases.

    Args:
        gamma: draft tokens proposed per round (defaults to the previous fixed 5).
        max_new_tokens: generation budget (defaults to the previous fixed 256).

    Returns:
        The text returned by ``SpeculativeDecoder.generate``.
    """
    decoder = SpeculativeDecoder(target_model, draft_model, tokenizer, gamma=gamma)
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    with torch.no_grad():
        target_model(input_ids, use_cache=True)
        draft_model(input_ids, use_cache=True)
    return decoder.generate(prompt, max_new_tokens=max_new_tokens)
评估指标
def evaluate_speculative_decoding(target_model, draft_model, tokenizer, test_prompts,
                                  gamma=5, max_new_tokens=100):
    """Evaluate acceptance rate, latency and throughput of speculative decoding.

    For each prompt:
      * one draft round of ``gamma`` tokens is verified against the target model
        with the speculative-sampling rule (accept x with probability
        min(1, p/q), stop at the first rejection) to count accepted vs. drafted
        tokens;
      * a full ``generate`` call is timed to measure latency and throughput.

    Returns:
        dict with ``total_tokens`` (drafted), ``accepted_tokens``,
        ``acceptance_rate``, ``latency`` (mean seconds per prompt) and
        ``throughput`` (generated tokens per second). All zero for empty input.
        Generated-token counts are approximate: they come from re-encoding the
        decoded output, so special tokens may shift them slightly.
    """
    import time  # imported locally so this snippet runs standalone

    results = {
        "total_tokens": 0,
        "accepted_tokens": 0,
        "acceptance_rate": 0,
        "latency": 0,
        "throughput": 0
    }
    if not test_prompts:
        return results

    decoder = SpeculativeDecoder(target_model, draft_model, tokenizer, gamma=gamma)
    generation_time = 0.0
    generated_tokens = 0
    for prompt in test_prompts:
        input_ids = tokenizer.encode(prompt, return_tensors="pt")

        # Acceptance: verify one draft round against the target model.
        draft_tokens, draft_probs = decoder._draft_phase(input_ids)
        n = len(draft_tokens)
        results["total_tokens"] += n
        if n:
            candidate_ids = torch.cat([
                input_ids,
                torch.tensor([draft_tokens], dtype=input_ids.dtype, device=input_ids.device),
            ], dim=-1)
            with torch.no_grad():
                # Positions -n-1 .. -2 hold the target's predictions for the drafts.
                target_logits = target_model(candidate_ids).logits[0, -n - 1:-1, :]
            target_probs = torch.softmax(target_logits, dim=-1)
            for i, (token, q) in enumerate(zip(draft_tokens, draft_probs)):
                p = target_probs[i, token].item()
                accept_ratio = 1.0 if q <= 0 else min(1.0, p / q)
                if torch.rand(1).item() >= accept_ratio:
                    break  # the first rejection ends the accepted prefix
                results["accepted_tokens"] += 1

        # Latency / throughput: time an actual generation.
        start = time.perf_counter()
        output = decoder.generate(prompt, max_new_tokens=max_new_tokens)
        generation_time += time.perf_counter() - start
        generated_tokens += max(
            0, len(tokenizer.encode(output)) - len(tokenizer.encode(prompt)))

    if results["total_tokens"]:
        results["acceptance_rate"] = results["accepted_tokens"] / results["total_tokens"]
    results["latency"] = generation_time / len(test_prompts)
    if generation_time > 0:
        results["throughput"] = generated_tokens / generation_time
    return results
投机解码通过并行验证显著提升了自回归生成速度,是LLM推理优化的重要技术之一。