← 返回首页
🧠

稀疏模型技术

📂 llm ⏱ 3 min 440 words

--- title: "稀疏模型技术" description: "深入解析大语言模型中的稀疏化技术,包括稀疏注意力、MoE稀疏激活和稀疏加速等核心方法,帮助降低模型计算开销" tags: ["稀疏模型", "稀疏注意力", "MoE", "稀疏加速"] category: "llm" icon: "🧠"

稀疏模型技术

什么是稀疏模型

在大语言模型(LLM)中,稀疏化是一种通过减少模型中有效参数或激活数量来降低计算开销的核心技术。稀疏模型的核心思想是:模型中的大部分权重或激活值可以被设为零,而不会显著影响模型性能。与稠密模型中所有参数都参与计算不同,稀疏模型只激活模型的一部分神经元,从而大幅减少推理时间和内存占用。

稀疏化的分类

非结构化稀疏

非结构化稀疏是指随机将权重置为零,不考虑网络结构。这种方法可以达到很高的稀疏率,但对硬件加速不友好。

import torch
import torch.nn.utils.prune as prune

# 对线性层进行L1非结构化剪枝(稀疏率80%)
linear = torch.nn.Linear(1024, 1024)
prune.l1_unstructured(linear, name='weight', amount=0.8)

# 检查稀疏率
weight = linear.weight
sparsity = 1.0 - (weight.count_nonzero() / weight.numel())
print(f"稀疏率: {sparsity:.2%}")  # 输出约80%

结构化稀疏

结构化稀疏按行、列或块为单位进行稀疏化,更有利于硬件加速。例如,将整个注意力头或前馈网络层整体移除。

# 结构化剪枝:按通道裁剪
conv = torch.nn.Conv2d(64, 128, 3)
prune.ln_structured(conv, name='weight', amount=0.3, n=2, dim=0)

# 验证结构化稀疏
print(f"稀疏后权重形状: {conv.weight.shape}")  # 形状不变
print(f"非零元素比例: {conv.weight.count_nonzero() / conv.weight.numel():.2%}")

稀疏注意力机制

稀疏注意力是Transformer模型中最重要的稀疏化技术之一。标准注意力的复杂度为O(n²),而稀疏注意力通过限制每个token只关注部分其他token来降低复杂度。

局部窗口注意力

import torch
import torch.nn.functional as F

def sparse_local_attention(Q, K, V, window_size=256):
    """局部窗口稀疏注意力"""
    batch, heads, seq_len, dim = Q.shape
    output = torch.zeros_like(V)
    
    for i in range(seq_len):
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2 + 1)
        
        # 只计算窗口内的注意力
        q = Q[:, :, i:i+1, :]
        k = K[:, :, start:end, :]
        v = V[:, :, start:end, :]
        
        scores = torch.matmul(q, k.transpose(-2, -1)) / (dim ** 0.5)
        weights = F.softmax(scores, dim=-1)
        output[:, :, i:i+1, :] = torch.matmul(weights, v)
    
    return output

全局-局部混合注意力

class SparseAttention(torch.nn.Module):
    """BigBird风格的稀疏注意力"""
    def __init__(self, hidden_dim, window_size=64, num_global_tokens=64):
        super().__init__()
        self.window_size = window_size
        self.num_global_tokens = num_global_tokens
    
    def forward(self, x):
        batch, seq_len, dim = x.shape
        
        # 全局token关注所有位置
        global_tokens = x[:, :self.num_global_tokens, :]
        
        # 局部token只关注窗口内
        sparse_mask = self._create_sparse_mask(seq_len)
        
        return global_tokens, sparse_mask
    
    def _create_sparse_mask(self, seq_len):
        mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
        for i in range(seq_len):
            start = max(0, i - self.window_size // 2)
            end = min(seq_len, i + self.window_size // 2 + 1)
            mask[i, start:end] = True
        return mask

MoE稀疏激活

混合专家模型(MoE)是实现稀疏激活的主流方案。每个token只激活部分专家网络,而非全部。

import torch
import torch.nn as nn

class SparseMoELayer(nn.Module):
    def __init__(self, hidden_dim, num_experts=8, top_k=2):
        super().__init__()
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim * 4),
                nn.GELU(),
                nn.Linear(hidden_dim * 4, hidden_dim)
            ) for _ in range(num_experts)
        ])
        self.gate = nn.Linear(hidden_dim, num_experts)
        self.top_k = top_k
    
    def forward(self, x):
        batch, seq_len, dim = x.shape
        
        # 计算门控分数
        gate_scores = self.gate(x)  # [batch, seq_len, num_experts]
        
        # 选择top-k专家
        topk_indices = torch.topk(gate_scores, self.top_k, dim=-1).indices
        
        # 稀疏计算:只激活top-k专家
        output = torch.zeros_like(x)
        for k in range(self.top_k):
            expert_idx = topk_indices[:, :, k]
            for i in range(len(self.experts)):
                mask = (expert_idx == i)
                if mask.any():
                    selected = x[mask]
                    output[mask] += self.experts[i](selected) / self.top_k
        
        return output

稀疏加速技术

稀疏矩阵乘法

现代GPU支持稀疏矩阵乘法,可以跳过零值计算。

# 将稠密权重转换为稀疏格式
dense_weight = torch.randn(1024, 1024)
sparse_weight = dense_weight.to_sparse()

# 稀疏矩阵乘法(跳过零值)
input_tensor = torch.randn(1, 1024)
output = torch.sparse.mm(sparse_weight, input_tensor.t()).t()

动态稀疏推理

class DynamicSparseInference:
    """根据输入动态调整稀疏率"""
    def __init__(self, model, sparsity_threshold=0.5):
        self.model = model
        self.threshold = sparsity_threshold
    
    def forward(self, x):
        # 计算激活值
        activations = self.model(x)
        
        # 动态稀疏化:保留top-k%的激活
        k = int(activations.numel() * (1 - self.threshold))
        _, topk_indices = torch.topk(activations.abs().flatten(), k)
        
        sparse_output = torch.zeros_like(activations).flatten()
        sparse_output[topk_indices] = activations.flatten()[topk_indices]
        
        return sparse_output.reshape(activations.shape)

稀疏模型的实际应用

推理加速效果

稀疏率 参数量 推理速度 精度损失
50% 50% 1.5x <1%
70% 30% 2.5x 1-3%
90% 10% 5x 3-8%

稀疏模型训练策略

def train_sparse_model(model, train_loader, optimizer, epochs=10):
    """渐进式稀疏训练"""
    for epoch in range(epochs):
        # 每个epoch逐渐增加稀疏率
        target_sparsity = min(0.9, epoch * 0.1)
        
        for batch in train_loader:
            output = model(batch)
            loss = compute_loss(output, batch)
            
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            
            # 应用稀疏掩码
            apply_sparsity_mask(model, target_sparsity)
            
            optimizer.step()

总结

稀疏模型技术是降低LLM计算成本的关键手段。通过稀疏注意力、MoE激活和结构化剪枝等方法,可以在保持模型性能的同时大幅降低推理成本。随着硬件对稀疏计算的支持不断改善,稀疏模型将在实际部署中发挥越来越重要的作用。