← 返回首页
🧠

长上下文

📂 llm ⏱ 2 min 347 words

---
title: "长上下文"
description: "长上下文处理技术,包括滑动窗口注意力、位置编码扩展和上下文窗口管理"
tags: ["长上下文", "滑动窗口", "位置编码", "上下文扩展"]
category: "llm"
icon: "🧠"
---

长上下文

长上下文处理是大语言模型的核心挑战之一。标准Transformer的注意力机制计算复杂度为O(n²),直接处理超长序列会导致显存爆炸和推理缓慢。本文介绍长上下文的核心技术及其实现。

位置编码扩展

RoPE位置插值

旋转位置编码(RoPE)是目前最主流的位置编码方案。通过线性插值或NTK-aware缩放,可以将预训练时的上下文长度扩展到更长:

import torch
import math

def precompute_freqs(dim, max_seq_len, base=10000.0, scaling_factor=1.0):
    """Precompute RoPE cos/sin tables, optionally with linear position interpolation.

    Args:
        dim: Rotary dimension. One frequency is produced for every index in
            ``range(0, dim, 2)``.
        max_seq_len: Number of positions to tabulate.
        base: Base of the geometric frequency progression.
        scaling_factor: Context-extension ratio (target length divided by the
            pretrained length). Position indices are divided by this value so
            that the extended range is compressed back onto the range seen
            during pretraining. ``1.0`` leaves positions unchanged.

    Returns:
        A ``(cos, sin)`` tuple of float tensors, each with one row per
        position and one column per frequency.

    Raises:
        ValueError: If ``scaling_factor`` is not positive.
    """
    if scaling_factor <= 0:
        raise ValueError(
            f"scaling_factor must be positive, got {scaling_factor}"
        )

    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # Linear interpolation divides the positions: multiplying would push the
    # rotation angles beyond the pretrained range instead of keeping the
    # longer sequence inside it.
    positions = torch.arange(max_seq_len).float() / scaling_factor
    angles = torch.outer(positions, inv_freq)
    return torch.cos(angles), torch.sin(angles)

def apply_rope(x, cos, sin):
    """Rotate the two halves of ``x``'s last dimension using RoPE tables.

    The last dimension of ``x`` is split at its midpoint. ``cos`` and ``sin``
    are truncated to the first ``x.shape[-2]`` rows and given a leading
    broadcast axis before being combined with the two halves.

    Returns:
        A tensor with the rotated first half followed by the rotated second
        half along the last dimension.
    """
    seq_len = x.shape[-2]
    half = x.shape[-1] // 2

    lower = x[..., :half]
    upper = x[..., half:]

    cos_t = cos[:seq_len][None]
    sin_t = sin[:seq_len][None]

    rotated_lower = lower * cos_t - upper * sin_t
    rotated_upper = upper * cos_t + lower * sin_t
    return torch.cat((rotated_lower, rotated_upper), dim=-1)

# Linear interpolation: extend a 4K context to a 32K context.
# scaling_factor is the ratio of the target length to the original length.
scaling_factor = 32768 / 4096  # = 8
# Build the cos/sin tables for 32768 positions with a rotary dimension of 128.
cos, sin = precompute_freqs(128, 32768, scaling_factor=scaling_factor)

YaRN(Yet another RoPE extensioN method)

YaRN结合了NTK-aware缩放和注意力缩放,效果优于纯线性插值:

def yarn_rope(dim, max_seq_len, base=10000.0, original_max=4096):
    """Build RoPE tables with NTK-aware base scaling and a YaRN attention factor.

    This is a simplified variant: the whole frequency spectrum is rescaled by
    enlarging the base. The per-frequency correction range used by the full
    YaRN method is not applied here.

    Args:
        dim: Rotary dimension; must be greater than 2.
        max_seq_len: Target context length to tabulate.
        base: RoPE base used during pretraining.
        original_max: Context length used during pretraining.

    Returns:
        ``((cos, sin), attn_factor)`` where ``(cos, sin)`` comes from
        ``precompute_freqs`` with the enlarged base, and ``attn_factor`` is
        ``0.1 * ln(scale) + 1``.

    Raises:
        ValueError: If ``dim <= 2`` or ``original_max`` is not positive.
    """
    if dim <= 2:
        raise ValueError(f"dim must be greater than 2, got {dim}")
    if original_max <= 0:
        raise ValueError(f"original_max must be positive, got {original_max}")

    # Derive the extension ratio from this call's own arguments instead of a
    # module-level variable; never scale below 1 (no extension needed).
    scaling_factor = max(1.0, max_seq_len / original_max)

    # NTK-aware scaling: enlarge the base so low frequencies are stretched
    # more than high frequencies.
    new_base = base * (scaling_factor ** (dim / (dim - 2)))

    # Attention factor compensating for the stretched positions.
    attn_factor = 0.1 * math.log(scaling_factor) + 1.0

    return precompute_freqs(dim, max_seq_len, base=new_base), attn_factor

滑动窗口注意力

Mistral模型采用滑动窗口注意力(Sliding Window Attention),每个token只关注固定窗口内的邻居,将复杂度从O(n²)降低为O(n·w)(w为窗口大小;窗口固定时,计算量随序列长度线性增长):

def sliding_window_attention(q, k, v, window_size=4096):
    """Causal attention in which each query sees a bounded span of earlier keys.

    The query at position ``i`` attends to keys at positions
    ``max(0, i - window_size)`` through ``i`` inclusive. ``q``, ``k`` and
    ``v`` are indexed as 4-D tensors whose second-to-last axis is the
    sequence axis.

    Returns:
        A tensor shaped like ``q`` holding the attended values.
    """
    n_positions = q.shape[-2]
    scale = math.sqrt(q.shape[-1])
    keys_t = k.transpose(-2, -1)
    result = torch.zeros_like(q)

    for pos in range(n_positions):
        lo = max(0, pos - window_size)
        hi = pos + 1
        # Restrict the score computation to keys inside the window.
        logits = torch.matmul(q[:, :, pos:hi, :], keys_t[:, :, :, lo:hi])
        logits = logits / scale
        probs = torch.softmax(logits, dim=-1)
        result[:, :, pos:hi, :] = torch.matmul(probs, v[:, :, lo:hi, :])

    return result

稀疏注意力模式

Longformer和BigBird采用稀疏注意力模式,结合局部注意力和全局注意力:

class SparseAttention:
    """Builders for boolean sparse-attention masks."""

    @staticmethod
    def longformer_pattern(seq_len, window=512, num_global=64):
        """Build a Longformer-style mask: a local band plus global tokens.

        Args:
            seq_len: Sequence length; the mask is ``seq_len x seq_len``.
            window: One-sided window radius. Position ``i`` may attend to
                position ``j`` when ``|i - j| <= window``.
            num_global: Number of leading positions treated as global tokens.
                Values larger than ``seq_len`` are clamped to ``seq_len``;
                negative values are treated as zero.

        Returns:
            A ``torch.bool`` tensor where ``mask[i, j]`` is ``True`` when
            position ``i`` is allowed to attend to position ``j``.
        """
        # Local band: keep only the diagonals within ``window`` of the main one.
        mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
        mask = mask.tril(window).triu(-window)

        # Clamp so that short sequences do not index past the mask edge.
        num_global = max(0, min(num_global, seq_len))

        # Global tokens attend everywhere and are visible from everywhere.
        mask[:num_global, :] = True
        mask[:, :num_global] = True

        return mask

实际应用:超长文档处理

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

def _split_with_overlap(tokens, max_chunk, overlap):
    """Split ``tokens`` into windows of at most ``max_chunk`` items that share ``overlap`` items."""
    step = max_chunk - overlap
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + max_chunk])
        # Once a window reaches the end, any further window would contain
        # nothing but already-seen overlap.
        if start + max_chunk >= len(tokens):
            break
    return chunks


def _chat_completion(model, tokenizer, messages, max_new_tokens):
    """Generate a reply to ``messages`` and return only the newly generated text."""
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # open the assistant turn before generating
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + continuation; drop the prompt tokens so the
    # caller does not get its own input echoed back.
    prompt_len = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)


def process_long_document(text, max_chunk=8192, overlap=512):
    """Summarize a long document chunk by chunk, then merge the summaries.

    The text is tokenized and split into overlapping windows. Each window is
    summarized separately, and the partial summaries are then combined by one
    more generation call.

    Args:
        text: Document to summarize.
        max_chunk: Maximum number of tokens per chunk; must be positive.
        overlap: Number of tokens shared by consecutive chunks; must satisfy
            ``0 <= overlap < max_chunk``.

    Returns:
        The merged summary text, or an empty string when ``text`` contains no
        tokens.

    Raises:
        ValueError: If ``max_chunk`` or ``overlap`` is out of range.
    """
    if max_chunk <= 0:
        raise ValueError(f"max_chunk must be positive, got {max_chunk}")
    if not 0 <= overlap < max_chunk:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < max_chunk, got {overlap}"
        )

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

    # Split by tokens, keeping an overlap. Special tokens are left out so they
    # do not end up inside the decoded chunk text.
    tokens = tokenizer.encode(text, add_special_tokens=False)
    if not tokens:
        return ""
    chunks = _split_with_overlap(tokens, max_chunk, overlap)

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.1-8B-Instruct", device_map="auto"
    )

    summaries = []
    for chunk in chunks:
        input_text = tokenizer.decode(chunk)
        messages = [
            {"role": "system", "content": "请对以下文本进行摘要"},
            {"role": "user", "content": input_text},
        ]
        summaries.append(
            _chat_completion(model, tokenizer, messages, max_new_tokens=1024)
        )

    # Merge the partial summaries, using the same chat format as above.
    final_input = "请整合以下摘要:\n" + "\n".join(summaries)
    final_messages = [{"role": "user", "content": final_input}]
    return _chat_completion(
        model, tokenizer, final_messages, max_new_tokens=2048
    )

长上下文技术正快速发展,128K乃至1M上下文窗口已成为现实,但在实际应用中仍需权衡成本与效果。