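---
title: "Long Context"
description: "Long-context processing techniques, including sliding window attention, position encoding extension, and context window management"
tags: ["long context", "sliding window", "position encoding", "context extension"]
category: "llm"
icon: "🧠"
---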
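Long Context
Long-context processing is one of the core challenges for large language models. The attention mechanism of a standard Transformer scales as O(n²) in the sequence length n, so feeding very long sequences in directly causes GPU memory to blow up and inference to slow down. This article introduces the core long-context techniques and shows how they are implemented.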
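A quick back-of-the-envelope calculation makes the quadratic cost concrete. This sketch assumes the full n×n score matrix is materialized in fp16 (2 bytes per element) for a single head in a single layer; `attention_matrix_bytes` is a hypothetical helper used only for illustration.

```python
def attention_matrix_bytes(seq_len: int, bytes_per_elem: int = 2) -> int:
    """Memory for one materialized n x n attention score matrix (one head, one layer)."""
    return seq_len * seq_len * bytes_per_elem


for n in (4_096, 32_768, 131_072):
    gib = attention_matrix_bytes(n) / 2**30
    print(f"n={n:>7}: {gib:8.3f} GiB per head per layer")
```

Going from 4K to 128K tokens (32× longer) multiplies this figure by 32² = 1024: from about 0.03 GiB to 32 GiB for every head in every layer.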
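Position Encoding Extension
RoPE Position Interpolation
Rotary Position Embedding (RoPE) is currently the most widely used position encoding scheme. With linear position interpolation or NTK-aware scaling, a model can be extended beyond the context length it saw during pretraining: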
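```python
import math
import torch

def precompute_freqs(dim, max_seq_len, base=10000.0, scaling_factor=1.0):
    """Precompute RoPE cos/sin tables, with optional linear position interpolation"""
    # One rotation frequency per dimension pair j: theta_j = base^(-2j/dim)
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # Linear interpolation divides positions by scaling_factor, so the extended
    # range [0, max_seq_len) is squeezed back into the pretrained position range
    t = torch.arange(max_seq_len).float() / scaling_factor
    freqs = torch.outer(t, freqs)  # shape: (max_seq_len, dim // 2)
    return torch.cos(freqs), torch.sin(freqs)

def apply_rope(x, cos, sin):
    """Apply rotary position embedding to x of shape (..., seq_len, dim)"""
    # "Rotate-half" layout: dimension j is paired with dimension j + dim // 2
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    cos = cos[:x.shape[-2]]  # (seq_len, dim // 2), broadcasts over leading dims
    sin = sin[:x.shape[-2]]
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

# Linear interpolation: extend a 4K context to 32K
scaling_factor = 32768 / 4096  # = 8
cos, sin = precompute_freqs(128, 32768, scaling_factor=scaling_factor)
```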
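NTK-aware scaling, mentioned above, takes a different route: it keeps positions unchanged and enlarges the RoPE base instead. The dependency-free sketch below assumes the commonly used rule base' = base · s^(dim/(dim−2)); the helpers `rope_inv_freqs` and `ntk_aware_base` are hypothetical names for illustration.

```python
import math


def rope_inv_freqs(dim: int, base: float = 10000.0) -> list[float]:
    """RoPE rotation frequency for each dimension pair j: base^(-2j/dim)."""
    return [base ** (-2 * j / dim) for j in range(dim // 2)]


def ntk_aware_base(base: float, dim: int, scaling_factor: float) -> float:
    """NTK-aware scaling: enlarge the base instead of compressing positions."""
    return base * scaling_factor ** (dim / (dim - 2))


dim, s = 128, 8.0
orig = rope_inv_freqs(dim)
ntk = rope_inv_freqs(dim, ntk_aware_base(10000.0, dim, s))
# The highest-frequency pair is untouched, the lowest-frequency pair is slowed
# by exactly 1/s (as with linear interpolation), and pairs in between are scaled
# by a factor between 1/s and 1.
print(ntk[0] / orig[0], ntk[-1] / orig[-1])
```

This prints 1.0 and approximately 0.125. Unlike linear interpolation, the high-frequency dimensions that encode fine-grained local order are barely changed.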
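YaRN (Yet another RoPE extension)
YaRN combines "NTK-by-parts" interpolation with attention scaling. High-frequency dimensions keep their original frequencies, low-frequency dimensions are linearly interpolated, and a ramp blends the dimensions in between; an extra attention factor compensates for the change in attention entropy. It typically extends context better than pure linear interpolation: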
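```python
def yarn_find_correction_range(beta_fast, beta_slow, dim, base, original_max):
    """Dimension-pair index range over which YaRN blends extrapolation into interpolation"""
    def correction_dim(num_rotations):
        # Index of the pair that completes `num_rotations` full turns over original_max positions
        return dim * math.log(original_max / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
    low = max(math.floor(correction_dim(beta_fast)), 0)
    high = min(math.ceil(correction_dim(beta_slow)), dim - 1)
    return low, high

def yarn_rope(dim, max_seq_len, base=10000.0, original_max=4096, beta_fast=32.0, beta_slow=1.0):
    """YaRN position encoding extension: NTK-by-parts interpolation + attention scaling"""
    scaling_factor = max_seq_len / original_max  # assumption: s = target length / original length
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    low, high = yarn_find_correction_range(beta_fast, beta_slow, dim, base, original_max)
    # ramp = 0 for high-frequency pairs (kept as-is), 1 for low-frequency pairs (divided by s)
    ramp = ((torch.arange(dim // 2).float() - low) / max(high - low, 1e-3)).clamp(0, 1)
    inv_freq = inv_freq * (1 - ramp) + (inv_freq / scaling_factor) * ramp
    # Attention factor compensation, multiplied into both cos and sin
    attn_factor = 0.1 * math.log(scaling_factor) + 1.0
    freqs = torch.outer(torch.arange(max_seq_len).float(), inv_freq)
    return torch.cos(freqs) * attn_factor, torch.sin(freqs) * attn_factor
```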
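To see what the NTK-by-parts ramp does to each dimension, the dependency-free sketch below computes only the per-pair frequency multipliers, using `beta_fast = 32`, `beta_slow = 1`, and an original length of 4096. The correction-range formula follows the one used in the YaRN paper and the Hugging Face implementation; the helper `yarn_dim_scales` is a hypothetical name for illustration.

```python
import math


def yarn_dim_scales(dim, s, base=10000.0, original_max=4096, beta_fast=32.0, beta_slow=1.0):
    """Per-pair frequency multiplier under NTK-by-parts (1 = unchanged, 1/s = interpolated)."""
    def correction_dim(num_rotations):
        return dim * math.log(original_max / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    low = max(math.floor(correction_dim(beta_fast)), 0)
    high = min(math.ceil(correction_dim(beta_slow)), dim - 1)
    scales = []
    for j in range(dim // 2):
        # Linear ramp from 0 (at pair `low`) to 1 (at pair `high`), clamped to [0, 1]
        ramp = min(max((j - low) / max(high - low, 1e-3), 0.0), 1.0)
        scales.append((1 - ramp) + ramp / s)
    return low, high, scales


low, high, scales = yarn_dim_scales(dim=128, s=8.0)
print(low, high)  # boundary pair indices of the blending ramp
print(scales[:3], scales[-3:])
```

For dim = 128, pairs 0 to 20 keep their original frequency, pairs from 46 onward are divided by 8, and the pairs in between are blended.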
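Sliding Window Attention
Mistral 7B uses Sliding Window Attention: each token attends only to a fixed-size window of the most recent tokens rather than to the entire prefix. With window size w, the cost drops from O(n²) to O(n·w), which is linear in the sequence length n for a fixed w: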
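```python
import math
import torch

def sliding_window_attention(q, k, v, window_size=4096):
    """Causal sliding window attention; q, k, v: (batch, heads, seq_len, head_dim)"""
    seq_len = q.shape[-2]
    output = torch.zeros_like(q)
    for i in range(seq_len):
        # Token i attends to key positions max(0, i - window_size) .. i (inclusive)
        start = max(0, i - window_size)
        # Only compute attention scores inside the window: O(window_size) work per token
        scores = torch.matmul(q[:, :, i:i+1, :],
                              k[:, :, start:i+1, :].transpose(-2, -1))
        scores = scores / math.sqrt(q.shape[-1])
        weights = torch.softmax(scores, dim=-1)
        output[:, :, i:i+1, :] = torch.matmul(weights, v[:, :, start:i+1, :])
    return output
```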
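The per-token loop above can also be expressed as a boolean mask. The dependency-free sketch below (the helpers `sliding_window_mask` and `attended_pairs` are hypothetical) uses the same window bounds as `sliding_window_attention` and counts how many query-key scores are actually computed, comparing the result with full causal attention.

```python
def sliding_window_mask(seq_len, window_size):
    """Causal sliding-window mask: row i allows key positions max(0, i - window_size) .. i."""
    return [[max(0, i - window_size) <= j <= i for j in range(seq_len)]
            for i in range(seq_len)]


def attended_pairs(mask):
    """Number of (query, key) scores the mask allows."""
    return sum(sum(row) for row in mask)


w = 4
for n in (8, 16, 32):
    full = n * (n + 1) // 2  # full causal attention
    print(n, attended_pairs(sliding_window_mask(n, w)), full)
```

Each doubling of n roughly doubles the windowed count (30, 70, 150), while the full causal count roughly quadruples (36, 136, 528).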
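Sparse Attention Patterns
Longformer and BigBird use sparse attention patterns that combine local (windowed) attention with global attention; BigBird additionally adds random attention: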
```python
import torch

class SparseAttention:
    @staticmethod
    def longformer_pattern(seq_len, window=512, num_global=64):
        """Longformer: local window + global tokens; returns a (seq_len, seq_len) bool mask"""
        mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
        # Local window attention: each position sees `window` neighbors on each side
        for i in range(seq_len):
            start = max(0, i - window)
            end = min(seq_len, i + window + 1)
            mask[i, start:end] = True
        # Global tokens (here simply the first num_global positions) attend to,
        # and are attended by, every position
        mask[:num_global, :] = True
        mask[:, :num_global] = True
        return mask
```
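BigBird adds a third component, random attention, to the local + global combination. The following is a deliberately simplified, token-level sketch (real BigBird operates on blocks of tokens); `bigbird_like_pattern` and its parameters are hypothetical, and plain Python lists stand in for tensors so it runs without dependencies.

```python
import random


def bigbird_like_pattern(seq_len, window=3, num_global=2, num_random=2, seed=0):
    """Simplified token-level BigBird-style mask: local window + global tokens + random keys."""
    rng = random.Random(seed)
    mask = [[False] * seq_len for _ in range(seq_len)]
    for i in range(seq_len):
        # Local: `window` neighbors on each side (same convention as longformer_pattern)
        for j in range(max(0, i - window), min(seq_len, i + window + 1)):
            mask[i][j] = True
        # Random: each query also attends to num_random randomly chosen keys
        for j in rng.sample(range(seq_len), num_random):
            mask[i][j] = True
    # Global: the first num_global tokens see, and are seen by, every position
    for g in range(num_global):
        for j in range(seq_len):
            mask[g][j] = True
            mask[j][g] = True
    return mask


m = bigbird_like_pattern(seq_len=64)
density = sum(map(sum, m)) / 64**2
print(f"density: {density:.2%}")
```

With a fixed window, global count, and random count per row, the number of allowed pairs grows linearly in seq_len, so the density falls as sequences get longer.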
Practical Application: Processing Very Long Documents
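```python
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
# Load once and reuse, instead of reloading the model on every call
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")

def chat(system_prompt, user_text, max_new_tokens):
    """Run one chat turn and return only the newly generated text"""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_text},
    ]
    # add_generation_prompt=True appends the assistant header so the model replies
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_dict=True, return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # outputs[0] holds prompt + completion; decode only the completion
    prompt_len = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

def process_long_document(text, max_chunk=8192, overlap=512):
    """Process a long document in chunks; overlap keeps context across chunk boundaries"""
    # Split by tokens (no BOS token), keeping `overlap` tokens shared between chunks
    tokens = tokenizer.encode(text, add_special_tokens=False)
    chunks = []
    for i in range(0, len(tokens), max_chunk - overlap):
        chunks.append(tokens[i:i + max_chunk])
        if i + max_chunk >= len(tokens):
            break  # this chunk already reaches the end of the document
    # Summarize each chunk independently
    summaries = [chat("Summarize the following text.", tokenizer.decode(chunk), 1024)
                 for chunk in chunks]
    # Merge the per-chunk summaries into a final summary
    return chat("Merge the following summaries into one summary.", "\n".join(summaries), 2048)
```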
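The chunking step in `process_long_document` advances by a stride of `max_chunk - overlap` tokens. The dependency-free sketch below (the helper `chunk_spans` is hypothetical) computes only the resulting (start, end) token spans, which makes the overlap easy to check.

```python
def chunk_spans(num_tokens, max_chunk=8192, overlap=512):
    """(start, end) token spans from overlapping chunking with stride max_chunk - overlap."""
    if overlap >= max_chunk:
        raise ValueError("overlap must be smaller than max_chunk")
    spans = []
    for start in range(0, num_tokens, max_chunk - overlap):
        end = min(start + max_chunk, num_tokens)
        spans.append((start, end))
        if end == num_tokens:
            break  # the last chunk already reaches the end of the document
    return spans


print(chunk_spans(20_000))  # [(0, 8192), (7680, 15872), (15360, 20000)]
```

Neighboring chunks share 512 tokens, so a passage shorter than the overlap that is cut at one chunk's boundary appears whole in the adjacent chunk.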
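Long-context techniques are evolving quickly, and 128K or even 1M-token context windows are now a reality. In practice, however, you still have to weigh the cost (memory and latency) against the quality gains.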