← 返回首页
🧠

投机解码:加速自回归生成

📂 llm ⏱ 4 min 795 words

---
title: "投机解码:加速自回归生成"
description: "掌握投机解码的原理和实现,通过并行验证显著提升生成速度"
tags: ["投机解码", "Speculative Decoding", "并行解码", "推理加速"]
category: "llm"
icon: "🧠"
---

投机解码:加速自回归生成

投机解码简介

投机解码(Speculative Decoding)是一种加速自回归文本生成的技术。它使用一个小型"草稿模型"(Draft Model)快速生成多个候选token,然后用大型"目标模型"(Target Model)在一次前向传播中并行验证这些token。被接受的草稿token无需目标模型再逐个生成,因此目标模型的串行前向传播次数大幅减少,从而显著加速生成过程;配合基于拒绝采样的验证规则,最终输出的分布与直接用目标模型采样完全一致。

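投机解码的核心优势:

- 无损:配合拒绝采样验证,输出分布与目标模型单独生成完全相同;
- 加速:目标模型一次前向传播可确认多个token,减少串行解码步数;
- 即插即用:无需修改或重新训练目标模型,只需一个与其共享词表的小型草稿模型。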

工作原理

标准自回归解码

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def standard_autoregressive(model, tokenizer, prompt, max_new_tokens=100):
    """Greedy autoregressive decoding, one model forward pass per new token.

    Baseline for comparison with speculative decoding: every step re-feeds
    the whole sequence (no KV cache is passed) and keeps the argmax token.

    Args:
        model: causal LM; called as ``model(ids)`` and must return an object
            with a ``.logits`` tensor indexed as ``[batch, position, vocab]``.
        tokenizer: provides ``encode(..., return_tensors="pt")``, ``decode``
            and ``eos_token_id``.
        prompt: text to continue.
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        The decoded text of prompt + generated tokens (stops after EOS).
    """
    sequence = tokenizer.encode(prompt, return_tensors="pt")
    stop_id = tokenizer.eos_token_id

    for _step in range(max_new_tokens):
        with torch.no_grad():
            last_logits = model(sequence).logits[:, -1, :]
        chosen = last_logits.argmax(dim=-1, keepdim=True)
        sequence = torch.cat((sequence, chosen), dim=-1)
        # The EOS token itself is kept in the output before stopping.
        if chosen.item() == stop_id:
            break

    return tokenizer.decode(sequence[0])

投机解码

def speculative_decoding(target_model, draft_model, tokenizer,
                         prompt, max_new_tokens=100, gamma=5):
    """Greedy speculative decoding (output identical to greedy target decoding).

    Each round:
      1. the draft model greedily proposes ``gamma`` tokens;
      2. the target model scores prompt + drafts in ONE forward pass;
      3. the longest prefix of drafts matching the target's argmax is kept;
      4. the target's own argmax token is appended — the correction at the
         first mismatch, or a free "bonus" token when all drafts matched.

    Assumes batch size 1 and that both models share the tokenizer's
    vocabulary.

    Args:
        target_model, draft_model: causal LMs called as ``model(ids)`` that
            return an object with ``.logits`` of shape (batch, seq, vocab).
        tokenizer: provides ``encode(..., return_tensors="pt")``, ``decode``
            and ``eos_token_id`` (may be None).
        prompt: text to continue.
        max_new_tokens: hard cap on generated tokens (never exceeded).
        gamma: number of draft tokens proposed per round.

    Returns:
        Decoded text of prompt + generated tokens; generation stops right
        after an EOS token.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    eos_id = tokenizer.eos_token_id

    generated_tokens = 0
    while generated_tokens < max_new_tokens:
        # 1. Draft: the cheap model proposes gamma tokens one at a time.
        draft_tokens = []
        draft_input = input_ids
        for _ in range(gamma):
            with torch.no_grad():
                draft_logits = draft_model(draft_input).logits[:, -1, :]
            draft_token = torch.argmax(draft_logits, dim=-1, keepdim=True)
            draft_tokens.append(draft_token.item())
            draft_input = torch.cat([draft_input, draft_token], dim=-1)

        # 2. Verify: draft_input already equals prompt + drafts (2-D, batch 1).
        #    The last gamma + 1 positions predict draft 0 .. gamma-1 and then
        #    the token after the final draft.
        with torch.no_grad():
            target_logits = target_model(draft_input).logits[:, -gamma - 1:, :]
        target_choice = torch.argmax(target_logits, dim=-1)[0].tolist()

        # 3. Keep the longest matching prefix of drafts.
        accepted_count = 0
        for draft_token, target_token in zip(draft_tokens, target_choice):
            if draft_token != target_token:
                break
            accepted_count += 1

        # 4. Always gain one target token: the correction at the mismatch,
        #    or the bonus token when every draft was accepted.
        new_tokens = draft_tokens[:accepted_count] + [target_choice[accepted_count]]

        # Respect the token budget and stop right after EOS.
        new_tokens = new_tokens[:max_new_tokens - generated_tokens]
        hit_eos = eos_id is not None and eos_id in new_tokens
        if hit_eos:
            new_tokens = new_tokens[:new_tokens.index(eos_id) + 1]

        input_ids = torch.cat([
            input_ids,
            torch.tensor([new_tokens], dtype=input_ids.dtype, device=input_ids.device),
        ], dim=-1)
        generated_tokens += len(new_tokens)

        if hit_eos:
            break

    return tokenizer.decode(input_ids[0])

实现细节

简化实现

class SpeculativeDecoder:
    """Speculative decoder with stochastic (rejection-sampling) verification.

    Per round, the draft model samples ``gamma`` tokens autoregressively and
    the target model scores all of them in a single forward pass. Draft token
    x is accepted with probability ``min(1, p(x) / q(x))`` (p = target,
    q = draft). On the first rejection, a replacement is sampled from the
    residual ``normalize(max(0, p - q))`` and the round ends. If every draft
    is accepted, one extra "bonus" token is sampled from the target's
    distribution at the following position.

    Assumes batch size 1 and that both models share the tokenizer's
    vocabulary (same token ids and logit width) — TODO confirm for the model
    pair in use.
    """

    def __init__(self, target_model, draft_model, tokenizer, gamma=5):
        self.target_model = target_model
        self.draft_model = draft_model
        self.tokenizer = tokenizer
        self.gamma = gamma

        # Inference mode for both models (e.g. disables dropout).
        self.target_model.eval()
        self.draft_model.eval()

    def generate(self, prompt, max_new_tokens=100):
        """Generate a continuation of *prompt*.

        Args:
            prompt: text to continue.
            max_new_tokens: hard cap on generated tokens (never exceeded).

        Returns:
            Decoded text of prompt + generated tokens; generation stops
            right after an EOS token (if the tokenizer defines one).
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        eos_id = self.tokenizer.eos_token_id

        tokens_generated = 0
        while tokens_generated < max_new_tokens:
            # Draft phase (keep full distributions for the residual step).
            draft_tokens, draft_probs, draft_dists = self._draft_phase(
                input_ids, return_dists=True)

            # Verification phase.
            accepted_tokens = self._verify_phase(
                input_ids, draft_tokens, draft_probs, draft_dists)

            # Respect the token budget and stop right after EOS.
            new_tokens = accepted_tokens.tolist()[:max_new_tokens - tokens_generated]
            hit_eos = eos_id is not None and eos_id in new_tokens
            if hit_eos:
                new_tokens = new_tokens[:new_tokens.index(eos_id) + 1]

            input_ids = torch.cat([
                input_ids,
                torch.tensor([new_tokens], dtype=input_ids.dtype, device=input_ids.device),
            ], dim=-1)
            tokens_generated += len(new_tokens)

            if hit_eos:
                break

        return self.tokenizer.decode(input_ids[0])

    def _draft_phase(self, input_ids, return_dists=False):
        """Sample ``self.gamma`` tokens from the draft model.

        Args:
            input_ids: current sequence, shape (1, seq_len).
            return_dists: also return each step's full draft distribution.

        Returns:
            ``(draft_tokens, draft_probs)``: the sampled token ids and the
            draft probability q(x) of each sampled token. With
            ``return_dists=True`` a third element is added: the list of full
            1-D draft distributions, one per drafted position.
        """
        draft_tokens = []
        draft_probs = []
        draft_dists = []

        current_input = input_ids
        for _ in range(self.gamma):
            with torch.no_grad():
                logits = self.draft_model(current_input).logits[:, -1, :]
            probs = torch.softmax(logits, dim=-1)

            token = torch.multinomial(probs, 1).item()
            draft_tokens.append(token)
            draft_probs.append(probs[0, token].item())
            draft_dists.append(probs[0])

            current_input = torch.cat([
                current_input,
                torch.tensor([[token]], dtype=current_input.dtype,
                             device=current_input.device),
            ], dim=-1)

        if return_dists:
            return draft_tokens, draft_probs, draft_dists
        return draft_tokens, draft_probs

    def _verify_phase(self, input_ids, draft_tokens, draft_probs, draft_dists=None):
        """Verify draft tokens against the target model in one forward pass.

        Args:
            input_ids: sequence before the drafts, shape (1, seq_len).
            draft_tokens: drafted token ids.
            draft_probs: draft probability q(x) of each drafted token.
            draft_dists: optional full draft distributions (one 1-D tensor per
                draft). If omitted, they are recomputed with one parallel
                forward pass of the draft model.

        Returns:
            1-D tensor of the tokens to append: the accepted draft prefix
            followed by exactly one target-derived token (the residual sample
            at the first rejection, or the bonus token if all were accepted).
        """
        gamma = len(draft_tokens)
        candidate_ids = torch.cat([
            input_ids,
            torch.tensor([draft_tokens], dtype=input_ids.dtype, device=input_ids.device),
        ], dim=-1)

        # Position -gamma-1+i predicts draft i; the final position predicts
        # the token after the last draft (used as the bonus token).
        with torch.no_grad():
            target_logits = self.target_model(candidate_ids).logits[:, -gamma - 1:, :]
        target_probs = torch.softmax(target_logits, dim=-1)[0]  # (gamma + 1, vocab)

        if draft_dists is None and gamma > 0:
            with torch.no_grad():
                draft_logits = self.draft_model(candidate_ids).logits[:, -gamma - 1:-1, :]
            draft_dists = list(torch.softmax(draft_logits, dim=-1)[0])

        accepted_tokens = []
        for i, draft_token in enumerate(draft_tokens):
            target_prob = target_probs[i, draft_token].item()
            draft_prob = draft_probs[i]

            # Acceptance probability min(1, p/q); guard q == 0.
            accept_ratio = min(1.0, target_prob / draft_prob) if draft_prob > 0 else 1.0
            if torch.rand(1).item() < accept_ratio:
                accepted_tokens.append(draft_token)
                continue

            # Rejected: sample from normalize(max(0, p - q)) over the vocab.
            residual = torch.clamp(
                target_probs[i] - draft_dists[i].to(target_probs.device), min=0)
            total = residual.sum()
            if total.item() <= 0:
                # Numerically degenerate (p == q); fall back to p itself.
                residual, total = target_probs[i], target_probs[i].sum()
            new_token = torch.multinomial(residual / total, 1).item()
            accepted_tokens.append(new_token)
            return torch.tensor(accepted_tokens)

        # Every draft accepted: the target's next-position distribution
        # gives one extra token for free.
        bonus_token = torch.multinomial(target_probs[gamma], 1).item()
        accepted_tokens.append(bonus_token)
        return torch.tensor(accepted_tokens)

使用Hugging Face

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

def hf_speculative_decoding(target_model_name, draft_model_name, prompt,
                            max_new_tokens=256, gamma=5):
    """Run speculative decoding with a Hugging Face target/draft model pair.

    Args:
        target_model_name: model id or path of the large target model.
        draft_model_name: model id or path of the small draft model.
        prompt: text to continue.
        max_new_tokens: cap on generated tokens.
        gamma: draft tokens proposed per round.

    Returns:
        The decoded generation (prompt + continuation).

    Raises:
        ValueError: if the two tokenizers' vocabularies differ. The decoder
            compares token ids across the two models and subtracts their
            distributions, so both models must share one vocabulary.
    """
    # Load tokenizers first so an incompatible pair fails before the
    # (expensive) model weights are loaded.
    target_tokenizer = AutoTokenizer.from_pretrained(target_model_name)
    draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_name)
    if target_tokenizer.get_vocab() != draft_tokenizer.get_vocab():
        raise ValueError(
            f"draft model {draft_model_name!r} does not share the vocabulary of "
            f"target model {target_model_name!r}; speculative decoding needs "
            "identical token ids"
        )

    target_model = AutoModelForCausalLM.from_pretrained(target_model_name)
    draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name)

    decoder = SpeculativeDecoder(target_model, draft_model, target_tokenizer, gamma=gamma)
    return decoder.generate(prompt, max_new_tokens=max_new_tokens)

草稿模型选择

使用n-gram模型

class NgramDraftModel:
    """Count-based n-gram draft model over whitespace-split tokens.

    For every (n-1)-token context seen in training, stores how often each
    token followed it, and predicts a normalized next-token distribution
    from those counts.
    """

    def __init__(self, n=3):
        """Create an empty model.

        Args:
            n: n-gram order; the context length is ``n - 1``.

        Raises:
            ValueError: if ``n < 1``.
        """
        if n < 1:
            raise ValueError(f"n must be >= 1, got {n}")
        self.n = n
        # context tuple (length n-1) -> {next_token: count}
        self.ngram_counts = {}

    def train(self, corpus):
        """Accumulate n-gram counts from an iterable of text strings.

        Repeated calls add to the existing counts.
        """
        context_len = self.n - 1
        for text in corpus:
            tokens = text.split()
            for i in range(len(tokens) - context_len):
                context = tuple(tokens[i:i + context_len])
                next_token = tokens[i + context_len]
                followers = self.ngram_counts.setdefault(context, {})
                followers[next_token] = followers.get(next_token, 0) + 1

    def predict(self, context):
        """Return ``{token: probability}`` for the token following *context*.

        Only the last ``n - 1`` tokens of *context* are used. Returns None
        when the context is too short or was never seen in training.
        """
        context_len = self.n - 1
        if len(context) < context_len:
            return None

        # Slice from an explicit start index: ``context[-0:]`` (n == 1)
        # would select the WHOLE context instead of the empty tuple.
        key = tuple(context[len(context) - context_len:])
        counts = self.ngram_counts.get(key)
        if not counts:
            return None

        total = sum(counts.values())
        return {token: count / total for token, count in counts.items()}

使用量化模型

from transformers import BitsAndBytesConfig

def create_quantized_draft_model(model_name):
    """Load *model_name* as a 4-bit (NF4) quantized causal LM for drafting.

    Args:
        model_name: Hugging Face model id or local path.

    Returns:
        The loaded ``AutoModelForCausalLM`` with a bitsandbytes 4-bit
        quantization config applied.
    """
    four_bit = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
    return AutoModelForCausalLM.from_pretrained(model_name, quantization_config=four_bit)

性能优化

加速比分析

def analyze_speedup(target_model, draft_model, tokenizer, prompts,
                    gamma_values=(1, 3, 5, 7), max_new_tokens=100):
    """Measure decoding throughput of speculative decoding for several gammas.

    For each gamma, every prompt is generated once and the total wall-clock
    time is recorded.

    NOTE(review): throughput uses the nominal token count
    ``len(prompts) * max_new_tokens``; generation that stops early at EOS
    produces fewer tokens, so throughput is then overestimated.

    Args:
        target_model, draft_model, tokenizer: passed to ``SpeculativeDecoder``.
        prompts: non-empty iterable of prompt strings (consumed once).
        gamma_values: draft lengths to compare.
        max_new_tokens: generation cap per prompt.

    Returns:
        List of dicts with keys ``"gamma"``, ``"throughput"`` (tokens/s) and
        ``"time_per_token"`` (s).

    Raises:
        ValueError: if there are no prompts or ``max_new_tokens <= 0``.
    """
    import time

    # Materialize so a generator is not exhausted after the first gamma.
    prompts = list(prompts)
    nominal_tokens = len(prompts) * max_new_tokens
    if nominal_tokens <= 0:
        raise ValueError("prompts must be non-empty and max_new_tokens positive")

    results = []
    for gamma in gamma_values:
        decoder = SpeculativeDecoder(target_model, draft_model, tokenizer, gamma=gamma)

        start_time = time.perf_counter()
        for prompt in prompts:
            decoder.generate(prompt, max_new_tokens=max_new_tokens)
        total_time = time.perf_counter() - start_time

        results.append({
            "gamma": gamma,
            "throughput": nominal_tokens / total_time,
            "time_per_token": total_time / nominal_tokens,
        })

    return results

接受率优化

def optimize_acceptance_rate(target_model, draft_model, tokenizer, eval_data):
    """Estimate the draft model's token acceptance rate on *eval_data*.

    Despite the name, nothing is optimized; this measures the rate so that
    draft models or gamma values can be compared. For each prompt, one
    round of draft tokens is sampled and verified the way speculative
    sampling does it: token i is accepted with probability
    ``min(1, p_i / q_i)`` (p = target, q = draft), and verification stops at
    the first rejection.

    Args:
        target_model, draft_model, tokenizer: passed to ``SpeculativeDecoder``
            (default gamma).
        eval_data: iterable of prompt strings.

    Returns:
        accepted / drafted tokens as a float, or 0.0 if nothing was drafted
        (e.g. empty *eval_data*). The result is stochastic.
    """
    # One decoder suffices; its construction does not depend on the prompt.
    decoder = SpeculativeDecoder(target_model, draft_model, tokenizer)

    total_tokens = 0
    accepted_tokens = 0
    for prompt in eval_data:
        input_ids = tokenizer.encode(prompt, return_tensors="pt")
        draft_tokens, draft_probs = decoder._draft_phase(input_ids)
        gamma = len(draft_tokens)
        if gamma == 0:
            continue

        candidate_ids = torch.cat([
            input_ids,
            torch.tensor([draft_tokens], dtype=input_ids.dtype, device=input_ids.device),
        ], dim=-1)
        # Position -gamma-1+i of the target output predicts draft token i.
        with torch.no_grad():
            target_logits = target_model(candidate_ids).logits[0, -gamma - 1:-1, :]
        target_probs = torch.softmax(target_logits, dim=-1)

        total_tokens += gamma
        for i, (token, draft_prob) in enumerate(zip(draft_tokens, draft_probs)):
            target_prob = target_probs[i, token].item()
            ratio = min(1.0, target_prob / draft_prob) if draft_prob > 0 else 1.0
            if torch.rand(1).item() >= ratio:
                break  # first rejection ends the round, as in real decoding
            accepted_tokens += 1

    return accepted_tokens / total_tokens if total_tokens else 0.0

与其他技术结合

# Speculative decoding preceded by a warm-up forward pass (no KV-cache reuse).
def optimized_speculative_decoding(target_model, draft_model, tokenizer, prompt,
                                   gamma=5, max_new_tokens=256):
    """Run speculative decoding after one warm-up forward pass per model.

    NOTE: the warm-up outputs, including any KV cache returned because of
    ``use_cache=True``, are discarded and never given to the decoder.
    ``SpeculativeDecoder.generate`` re-encodes the prompt and runs full
    forward passes itself, so no KV cache is actually reused. The warm-up
    can only absorb one-time startup cost before generation (e.g. lazy
    kernel initialization — assumption, depends on the backend).

    Args:
        target_model, draft_model, tokenizer: passed to ``SpeculativeDecoder``.
        prompt: text to continue.
        gamma: draft tokens proposed per round.
        max_new_tokens: cap on generated tokens.

    Returns:
        The decoded generation (prompt + continuation).
    """
    decoder = SpeculativeDecoder(target_model, draft_model, tokenizer, gamma=gamma)

    # Warm-up pass on the prompt; results intentionally unused.
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    with torch.no_grad():
        target_model(input_ids, use_cache=True)
        draft_model(input_ids, use_cache=True)

    return decoder.generate(prompt, max_new_tokens=max_new_tokens)

评估指标

def evaluate_speculative_decoding(target_model, draft_model, tokenizer, test_prompts):
    """Evaluate acceptance rate, latency and throughput of speculative decoding.

    The prompts are processed in two passes:

    1. Acceptance: for each prompt, one round of draft tokens (gamma=5) is
       sampled and verified as speculative sampling does it. Token i is
       accepted with probability ``min(1, p_i / q_i)`` (p = target,
       q = draft), stopping at the first rejection. ``total_tokens`` counts
       drafted tokens and ``accepted_tokens`` the accepted ones.
    2. Speed: ``decoder.generate(prompt)`` is timed for every prompt.
       ``latency`` is the mean wall-clock seconds per prompt. ``throughput``
       is generated tokens per second, where generated tokens are estimated
       as ``len(encode(output)) - len(encode(prompt))``. This is approximate:
       decode/encode may not round-trip exactly (e.g. special tokens).

    Args:
        target_model, draft_model, tokenizer: passed to ``SpeculativeDecoder``.
        test_prompts: iterable of prompt strings.

    Returns:
        Dict with keys ``total_tokens``, ``accepted_tokens``,
        ``acceptance_rate``, ``latency`` and ``throughput``. Every value is 0
        when there are no prompts. Acceptance numbers are stochastic.
    """
    import time

    results = {
        "total_tokens": 0,
        "accepted_tokens": 0,
        "acceptance_rate": 0,
        "latency": 0,
        "throughput": 0
    }

    test_prompts = list(test_prompts)
    if not test_prompts:
        return results

    decoder = SpeculativeDecoder(target_model, draft_model, tokenizer, gamma=5)

    # Pass 1: acceptance statistics from one verified draft round per prompt.
    for prompt in test_prompts:
        input_ids = tokenizer.encode(prompt, return_tensors="pt")
        draft_tokens, draft_probs = decoder._draft_phase(input_ids)
        gamma = len(draft_tokens)
        if gamma == 0:
            continue

        candidate_ids = torch.cat([
            input_ids,
            torch.tensor([draft_tokens], dtype=input_ids.dtype, device=input_ids.device),
        ], dim=-1)
        # Position -gamma-1+i of the target output predicts draft token i.
        with torch.no_grad():
            target_logits = target_model(candidate_ids).logits[0, -gamma - 1:-1, :]
        target_probs = torch.softmax(target_logits, dim=-1)

        results["total_tokens"] += gamma
        for i, (token, draft_prob) in enumerate(zip(draft_tokens, draft_probs)):
            target_prob = target_probs[i, token].item()
            ratio = min(1.0, target_prob / draft_prob) if draft_prob > 0 else 1.0
            if torch.rand(1).item() >= ratio:
                break  # first rejection ends the round
            results["accepted_tokens"] += 1

    # Pass 2: time real end-to-end generation.
    generated_tokens = 0
    start_time = time.perf_counter()
    for prompt in test_prompts:
        output = decoder.generate(prompt)
        generated_tokens += max(
            0, len(tokenizer.encode(output)) - len(tokenizer.encode(prompt)))
    total_time = time.perf_counter() - start_time

    if results["total_tokens"]:
        results["acceptance_rate"] = results["accepted_tokens"] / results["total_tokens"]
    results["latency"] = total_time / len(test_prompts)
    results["throughput"] = generated_tokens / total_time if total_time > 0 else 0.0

    return results

投机解码通过并行验证显著提升了自回归生成速度,是LLM推理优化的重要技术之一。