← 返回首页
🧠

剪枝完全指南

📂 llm ⏱ 4 min 731 words

---
title: "剪枝完全指南"
description: "全面介绍大模型剪枝技术,包括非结构化剪枝、结构化剪枝和迭代剪枝策略,实现模型压缩与加速。"
tags: ["模型剪枝", "非结构化剪枝", "结构化剪枝", "迭代剪枝"]
category: "llm"
icon: "🧠"
---

剪枝完全指南

什么是模型剪枝

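模型剪枝通过移除模型中不重要的权重、神经元或层来压缩模型。与量化(降低每个参数的数值精度)不同,剪枝直接删除参数(或将其置零),可以在精度损失可控的前提下减少模型大小和计算量。需要注意:非结构化剪枝只是把权重置零,只有配合稀疏存储格式或支持稀疏计算的硬件/内核才能真正节省显存并加速;结构化剪枝移除整个通道或神经元,更容易直接转化为实际加速。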

剪枝方法分类

非结构化剪枝

移除单个权重,产生稀疏矩阵。下面的代码同时给出了用于对比的结构化剪枝 `structured_pruning`,它移除整个输出通道/神经元:

import torch
import torch.nn.utils.prune as prune

def unstructured_pruning(model, amount=0.5):
    """Apply L1 unstructured (magnitude) pruning to every ``nn.Linear`` weight.

    Within each linear layer, the ``amount`` fraction of weights with the
    smallest absolute value is masked to zero through
    ``torch.nn.utils.prune``. The mask stays attached as a
    reparametrization; it is not made permanent here.

    Args:
        model: Module whose ``nn.Linear`` submodules are pruned in place.
        amount: Fraction (float in [0, 1]) or absolute count of weights to
            prune per layer, as accepted by ``prune.l1_unstructured``.

    Returns:
        The same ``model`` object, pruned in place.
    """
    for layer_name, layer in model.named_modules():
        if not isinstance(layer, torch.nn.Linear):
            continue

        # Mask out the smallest-magnitude weights of this layer.
        prune.l1_unstructured(layer, name='weight', amount=amount)

        # After pruning, ``layer.weight`` is the masked tensor; report the
        # fraction of its entries that are zero.
        masked = layer.weight
        sparsity = 1 - torch.count_nonzero(masked) / masked.nelement()
        print(f"{layer_name}: 稀疏度 {sparsity:.2%}")

    return model

def structured_pruning(model, amount=0.3):
    """Structured L2 pruning that zeroes whole output channels / neurons.

    Every ``nn.Conv2d`` and ``nn.Linear`` weight is pruned with
    ``prune.ln_structured`` along ``dim=0``. For a Conv2d weight that axis
    indexes output channels; for a Linear weight it indexes output
    neurons. The slices with the smallest L2 norm are masked to zero.

    Args:
        model: Module pruned in place.
        amount: Fraction (float in [0, 1]) or absolute count of slices to
            prune per layer, as accepted by ``prune.ln_structured``.

    Returns:
        The same ``model`` object.
    """
    # Conv2d and Linear are treated identically (L2 norm, dim 0), so a
    # single isinstance check covers both.
    prunable_types = (torch.nn.Conv2d, torch.nn.Linear)
    for _, layer in model.named_modules():
        if isinstance(layer, prunable_types):
            prune.ln_structured(layer, name='weight', amount=amount, n=2, dim=0)
    return model

全局剪枝

def global_pruning(model, amount=0.5):
    """Global L1 unstructured pruning across all Linear and Conv2d layers.

    Unlike per-layer pruning, the weights to remove are selected by absolute
    value over the union of every eligible weight tensor. Layers can
    therefore end up with different individual sparsities.

    Args:
        model: Module pruned in place.
        amount: Fraction (float in [0, 1]) or absolute count of weights to
            prune globally, as accepted by ``prune.global_unstructured``.

    Returns:
        The same ``model`` object. If the model has no Linear/Conv2d layers,
        it is returned untouched.
    """
    parameters_to_prune = [
        (module, 'weight')
        for module in model.modules()
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d))
    ]
    if not parameters_to_prune:
        # global_unstructured cannot operate on an empty parameter list.
        return model

    # Global L1 pruning: one threshold shared by all listed weights.
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=amount,
    )

    # Count zero *elements* of the masked weights. The previous code counted
    # how many layers were entirely zero, which is not a sparsity.
    total_params = sum(module.weight.nelement() for module, _ in parameters_to_prune)
    total_zeros = sum(
        module.weight.nelement() - int(torch.count_nonzero(module.weight))
        for module, _ in parameters_to_prune
    )
    sparsity = total_zeros / total_params if total_params else 0.0
    print(f"全局稀疏度: {sparsity:.2%}")

    return model

迭代剪枝

逐步剪枝策略

class IterativePruner:
    """Iterative pruning whose sparsity grows gradually over the rounds.

    Each round prunes every ``nn.Linear`` weight up to a scheduled sparsity,
    then hands the model to the caller for fine-tuning and evaluation. The
    schedule rises linearly from ``initial_sparsity`` (first round) to
    ``target_sparsity`` (last round).

    All sparsity values here are the fraction of zero entries across the
    weights of the ``nn.Linear`` layers that this pruner touches.
    """

    def __init__(self, model, initial_sparsity=0.1, target_sparsity=0.9,
                 num_iterations=10, pruner='l1'):
        """
        Args:
            model: Module whose ``nn.Linear`` weights are pruned in place.
            initial_sparsity: Sparsity reached after the first round.
            target_sparsity: Sparsity reached after the last round.
            num_iterations: Number of prune / fine-tune / evaluate rounds.
            pruner: ``'l1'`` (smallest magnitude first) or ``'random'``.
        """
        self.model = model
        self.initial_sparsity = initial_sparsity
        self.target_sparsity = target_sparsity
        self.num_iterations = num_iterations
        self.pruner = pruner

    def _scheduled_sparsity(self, iteration: int) -> float:
        """Return the goal sparsity for 0-based ``iteration`` (linear ramp)."""
        if self.num_iterations <= 1:
            return self.target_sparsity
        progress = iteration / (self.num_iterations - 1)
        return self.initial_sparsity + (self.target_sparsity - self.initial_sparsity) * progress

    def prune_one_iteration(self, current_sparsity: float, goal_sparsity=None):
        """Prune every Linear weight from ``current_sparsity`` up to a goal.

        Args:
            current_sparsity: Sparsity of the Linear weights before this step.
            goal_sparsity: Sparsity to reach after this step. If omitted, the
                goal is ``initial_sparsity`` when nothing is pruned yet and
                ``target_sparsity`` otherwise.

        Returns:
            The measured sparsity after pruning (see ``_get_sparsity``).

        Raises:
            ValueError: If ``self.pruner`` is neither ``'l1'`` nor ``'random'``.
        """
        if goal_sparsity is None:
            goal_sparsity = (self.initial_sparsity if current_sparsity == 0
                             else self.target_sparsity)

        if self.pruner == 'l1':
            prune_fn = prune.l1_unstructured
        elif self.pruner == 'random':
            prune_fn = prune.random_unstructured
        else:
            raise ValueError(f"unknown pruner: {self.pruner!r}")

        # Re-pruning an already pruned tensor applies ``amount`` only to the
        # entries that are still unpruned. Convert the absolute increase
        # into a fraction of the remaining weights, clamped to [0, 1]. It must
        # stay a float, because an int amount means "absolute count".
        remaining = 1.0 - current_sparsity
        if remaining <= 0:
            amount = 0.0
        else:
            amount = (goal_sparsity - current_sparsity) / remaining
        amount = float(min(max(amount, 0.0), 1.0))

        for module in self.model.modules():
            if isinstance(module, torch.nn.Linear):
                prune_fn(module, 'weight', amount=amount)

        return self._get_sparsity()

    def iterative_prune(self, train_fn, eval_fn, epochs=5):
        """Run the prune -> fine-tune -> evaluate loop.

        Args:
            train_fn: Called as ``train_fn(model, epochs=epochs)`` after each
                pruning step to recover accuracy.
            eval_fn: Called as ``eval_fn(model)``. Its result is printed with
                the ``%`` format, so presumably a fraction in [0, 1].
            epochs: Fine-tuning epochs per round (default 5).

        Returns:
            The pruned model (the same object that was passed in).
        """
        sparsity = self._get_sparsity()

        for i in range(self.num_iterations):
            print(f"迭代 {i+1}/{self.num_iterations}")

            # Prune up to this round's scheduled sparsity.
            sparsity = self.prune_one_iteration(sparsity, self._scheduled_sparsity(i))
            print(f"当前稀疏度: {sparsity:.2%}")

            # Fine-tune to recover accuracy.
            train_fn(self.model, epochs=epochs)

            # Evaluate.
            accuracy = eval_fn(self.model)
            print(f"准确率: {accuracy:.2%}")

        return self.model

    def _get_sparsity(self) -> float:
        """Return the fraction of zero entries across all Linear weights.

        This reads ``module.weight`` rather than ``named_parameters()``. Once
        ``torch.nn.utils.prune`` has been applied, the registered parameter
        is the unmasked ``weight_orig``, while ``weight`` holds the masked
        tensor. Returns 0.0 when the model has no Linear layers.
        """
        total_params = 0
        zero_params = 0

        for module in self.model.modules():
            if isinstance(module, torch.nn.Linear):
                weight = module.weight
                total_params += weight.nelement()
                zero_params += weight.nelement() - int(torch.count_nonzero(weight))

        return zero_params / total_params if total_params else 0.0

渐进式剪枝

class ProgressivePruner:
    """Layer-wise pruning where less important layers are pruned harder.

    Each ``nn.Linear`` layer is scored by the norm of its weight. The least
    important layer receives ``target_sparsity``, and each more important
    layer receives a linearly smaller share of it.
    """

    def __init__(self, model):
        self.model = model
        # Module name -> importance score from the latest compute_importance().
        self.layer_importance = {}

    def compute_importance(self):
        """Score every ``nn.Linear`` layer by ``weight.norm()``.

        Returns:
            ``(name, importance)`` pairs sorted from most to least important.
        """
        # Start fresh so names from an earlier call cannot linger.
        self.layer_importance.clear()
        for name, module in self.model.named_modules():
            if isinstance(module, torch.nn.Linear):
                # Use the weight norm as the importance score.
                importance = module.weight.data.norm().item()
                self.layer_importance[name] = importance

        # Sort by importance, most important first.
        sorted_layers = sorted(
            self.layer_importance.items(),
            key=lambda x: x[1],
            reverse=True
        )
        return sorted_layers

    def progressive_prune(self, target_sparsity: float):
        """Apply L1 unstructured pruning per layer, least important first.

        With ``n`` ranked layers, the layer at rank ``i`` counted from the
        least important (``i = 0``) is pruned with amount
        ``target_sparsity * (1 - i / n)``.

        Args:
            target_sparsity: Fraction pruned from the least important layer.

        Returns:
            The same model object, pruned in place.
        """
        layers = self.compute_importance()
        num_layers = len(layers)
        # Build the name -> module map once, instead of rescanning every
        # module for each layer.
        modules_by_name = dict(self.model.named_modules())

        # Start from the least important layer.
        for rank, (layer_name, _importance) in enumerate(reversed(layers)):
            # Pruning fraction for this layer.
            layer_sparsity = target_sparsity * (1 - rank / num_layers)

            module = modules_by_name.get(layer_name)
            if module is None:
                continue
            prune.l1_unstructured(module, 'weight', amount=layer_sparsity)
            print(f"剪枝层 {layer_name}: {layer_sparsity:.2%}")

        return self.model

结构化剪枝实现

class StructuredPruner:
    """Structured pruning that zeroes whole neurons or attention heads."""

    def prune_linear_layer(self, module: torch.nn.Linear,
                           amount: float, dim: int = 0):
        """Zero the least important rows (``dim=0``) or columns of a Linear weight.

        Importance is the L1 norm of each row (``dim=0``) or column
        (otherwise). ``int(count * amount)`` of them are zeroed in place.

        Args:
            module: Linear layer whose ``weight.data`` is modified in place.
            amount: Fraction of rows/columns to zero.
            dim: 0 to score and zero rows; any other value for columns.

        Returns:
            The 0/1 mask (same shape, dtype and device as the weight) that
            was applied.
        """
        # Importance of each row/column.
        weight = module.weight.data
        if dim == 0:
            importance = weight.abs().sum(dim=1)
        else:
            importance = weight.abs().sum(dim=0)

        # Select the least important rows/columns.
        num_prune = int(importance.numel() * amount)
        _, indices_to_prune = importance.topk(num_prune, largest=False)

        # ones_like keeps the mask on the weight's device and dtype.
        mask = torch.ones_like(weight)
        if dim == 0:
            mask[indices_to_prune] = 0
        else:
            mask[:, indices_to_prune] = 0

        # Apply the mask.
        module.weight.data = weight * mask

        return mask

    def prune_attention(self, attention_module, amount=0.3):
        """Zero the q/k/v projection rows of the least important heads.

        Head ``i`` owns rows ``[i * head_dim, (i + 1) * head_dim)``. Head
        importance is the sum of |W_q| over those rows.

        NOTE(review): assumes ``k_proj``/``v_proj`` have ``num_heads * head_dim``
        output rows, i.e. no grouped-query attention. TODO confirm for the
        target model. Only ``.weight`` is masked; biases, if any, are left
        untouched.

        Args:
            attention_module: Object exposing ``num_heads``, ``head_dim`` and
                ``q_proj``/``k_proj``/``v_proj`` linear layers.
            amount: Fraction of heads to remove (``int(num_heads * amount)``).

        Returns:
            A float32 CPU mask of length ``num_heads * head_dim``, with 0 on
            the rows of the removed heads.
        """
        num_heads = attention_module.num_heads
        head_size = attention_module.head_dim

        # Per-head importance from the query projection.
        q_proj = attention_module.q_proj.weight.data
        head_importance = [
            q_proj[i * head_size:(i + 1) * head_size].abs().sum().item()
            for i in range(num_heads)
        ]

        # Select the heads to remove.
        num_prune = int(num_heads * amount)
        _, indices_to_prune = torch.tensor(head_importance).topk(num_prune, largest=False)

        # Build the row mask.
        mask = torch.ones(num_heads * head_size)
        for idx in indices_to_prune.tolist():
            mask[idx * head_size:(idx + 1) * head_size] = 0

        # Move the mask to each weight's device/dtype before the in-place
        # multiply. A CPU mask cannot be combined with CUDA weights.
        row_mask = mask.unsqueeze(1)
        for proj in (attention_module.q_proj,
                     attention_module.k_proj,
                     attention_module.v_proj):
            proj.weight.data *= row_mask.to(proj.weight.data)

        print(f"移除 {num_prune} 个注意力头")
        return mask

剪枝效果评估

class PruningEvaluator:
    """Compare an original and a pruned model by size, sparsity and accuracy."""

    def evaluate(self, original_model, pruned_model, eval_data):
        """Collect comparison metrics for the two models.

        Sizes are dense parameter storage. Masking-based (unstructured)
        pruning does not shrink that storage, so ``compression_ratio`` stays
        at 1.0 unless parameters were actually removed.

        Returns:
            A dict with sizes (MB), compression ratio, sparsity and the
            accuracy figures from ``_evaluate_accuracy``.
        """
        # Model sizes.
        original_size = self._get_model_size(original_model)
        pruned_size = self._get_model_size(pruned_model)

        # Sparsity of the pruned model.
        sparsity = self._get_sparsity(pruned_model)

        # Accuracy.
        original_acc = self._evaluate_accuracy(original_model, eval_data)
        pruned_acc = self._evaluate_accuracy(pruned_model, eval_data)

        results = {
            'original_size_mb': original_size,
            'pruned_size_mb': pruned_size,
            # Guard against a parameter-less pruned model.
            'compression_ratio': original_size / pruned_size if pruned_size else float('inf'),
            'sparsity': sparsity,
            'original_accuracy': original_acc,
            'pruned_accuracy': pruned_acc,
            'accuracy_drop': original_acc - pruned_acc,
        }

        return results

    def _get_model_size(self, model) -> float:
        """Return the dense size of ``model.parameters()`` in MiB (buffers excluded)."""
        total_size = 0
        for param in model.parameters():
            total_size += param.nelement() * param.element_size()
        return total_size / 1024 / 1024

    def _get_sparsity(self, model) -> float:
        """Return the fraction of zero entries across the model's parameters.

        For tensors reparametrized by ``torch.nn.utils.prune``, the masked
        tensor (``<name>``) is counted instead of the raw ``<name>_orig``
        parameter, which carries no zeros. Shared parameters are counted
        once. Returns 0.0 for a model without parameters.
        """
        total = 0
        zeros = 0
        seen = set()
        for module in model.modules():
            for name, param in module.named_parameters(recurse=False):
                if id(param) in seen:
                    continue
                seen.add(id(param))
                tensor = param
                if name.endswith('_orig'):
                    masked = getattr(module, name[:-len('_orig')], None)
                    if isinstance(masked, torch.Tensor):
                        tensor = masked
                total += tensor.nelement()
                zeros += (tensor.data == 0).sum().item()
        return zeros / total if total else 0.0

    def _evaluate_accuracy(self, model, eval_data):
        """Placeholder: ignores its inputs and always returns 0.95."""
        return 0.95

剪枝与量化结合

class CombinedCompression:
    """Combined compression: prune first, then quantize."""

    def compress(self, model, sparsity=0.5, bits=4):
        """Prune every Linear weight, then quantize the result.

        Args:
            model: Module modified in place by the pruning step.
            sparsity: Fraction of weights pruned per Linear layer.
            bits: 8 (dynamic int8 quantization) or 4 (placeholder).

        Returns:
            The compressed model.

        Raises:
            ValueError: If ``bits`` is neither 4 nor 8.
        """
        print("步骤1: 剪枝")
        pruned_model = self._prune(model, sparsity)

        print("步骤2: 量化")
        quantized_model = self._quantize(pruned_model, bits)

        return quantized_model

    def _prune(self, model, sparsity):
        """L1-prune each Linear weight and make the pruning permanent.

        ``prune.remove`` drops the mask/``weight_orig`` reparametrization,
        leaving ``weight`` as a plain parameter that contains the zeros.
        """
        import torch.nn.utils.prune as prune

        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                prune.l1_unstructured(module, 'weight', amount=sparsity)
                prune.remove(module, 'weight')  # permanently remove the mask

        return model

    def _quantize(self, model, bits):
        """Quantize ``model`` to ``bits`` bits.

        Raises:
            ValueError: For unsupported bit widths. Previously these silently
                returned ``None``.
        """
        if bits == 8:
            from torch.quantization import quantize_dynamic
            return quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
        if bits == 4:
            # Placeholder: a real 4-bit path would use GPTQ or AWQ. The model
            # is returned unchanged.
            print(f"应用{bits}bit量化")
            return model
        raise ValueError(f"unsupported bit width: {bits} (expected 4 or 8)")

实践效果(以下为示意数据,实际数值取决于模型、剪枝方法与评测设置)

# Hard-coded example figures for LLaMA-7B, printed as a small table.
# The source of these numbers is not shown here; treat them as illustrative.
compression_results = {
    'LLaMA-7B': {
        'baseline': {'size': '13GB', 'perplexity': 5.28},
        'prune_50': {'size': '7GB', 'perplexity': 5.35, 'speedup': '1.3x'},
        'quant_4bit': {'size': '3.5GB', 'perplexity': 5.40, 'speedup': '2.0x'},
        'prune+quant': {'size': '2GB', 'perplexity': 5.50, 'speedup': '2.5x'},
    },
}

print("7B模型压缩效果:")
llama_rows = compression_results['LLaMA-7B']
for method_name in llama_rows:
    stats = llama_rows[method_name]
    print(f"{method_name}: 大小={stats['size']}, 困惑度={stats['perplexity']}")

最佳实践