剪枝完全指南
---
title: "剪枝完全指南"
description: "全面介绍大模型剪枝技术,包括非结构化剪枝、结构化剪枝和迭代剪枝策略,实现模型压缩与加速。"
tags: ["模型剪枝", "非结构化剪枝", "结构化剪枝", "迭代剪枝"]
category: "llm"
icon: "🧠"
---
剪枝完全指南
什么是模型剪枝
模型剪枝通过移除模型中不重要的权重、神经元或层来压缩模型。与量化(降低参数的数值精度)不同,剪枝直接删除参数或将其置零。配合稀疏存储格式与稀疏计算支持,或采用结构化剪枝真正缩小矩阵维度,剪枝可以在精度损失较小的同时显著减少模型大小和计算量。
剪枝方法分类
非结构化剪枝
移除单个权重,产生稀疏矩阵:
import torch
import torch.nn.utils.prune as prune
def unstructured_pruning(model, amount=0.5):
    """Apply L1 unstructured pruning to every ``torch.nn.Linear`` in *model*.

    Each Linear layer is pruned independently with
    ``prune.l1_unstructured`` (smallest-magnitude weights are masked), and
    the resulting per-layer sparsity is printed.

    Args:
        model: module whose Linear sub-layers are pruned in place.
        amount: forwarded unchanged to ``prune.l1_unstructured``.

    Returns:
        The same ``model`` object.
    """
    linear_layers = (
        (layer_name, layer)
        for layer_name, layer in model.named_modules()
        if isinstance(layer, torch.nn.Linear)
    )
    for layer_name, layer in linear_layers:
        prune.l1_unstructured(layer, name='weight', amount=amount)
        # After pruning, `layer.weight` is the masked tensor, so the pruned
        # entries are real zeros here.
        masked = layer.weight
        zero_fraction = 1 - torch.count_nonzero(masked) / masked.nelement()
        print(f"{layer_name}: 稀疏度 {zero_fraction:.2%}")
    return model
def structured_pruning(model, amount=0.3):
    """Structured L2 pruning of whole output channels / neurons.

    Every ``Conv2d`` and ``Linear`` layer is passed to
    ``prune.ln_structured`` with ``n=2`` and ``dim=0``, i.e. entire slices
    along the weight's first dimension (conv output channels, linear output
    neurons) are masked according to their L2 norm.

    Args:
        model: module whose Conv2d/Linear sub-layers are pruned in place.
        amount: forwarded unchanged to ``prune.ln_structured``.

    Returns:
        The same ``model`` object.
    """
    prunable_types = (torch.nn.Conv2d, torch.nn.Linear)
    for _, layer in model.named_modules():
        if not isinstance(layer, prunable_types):
            continue
        # dim=0 -> one whole output channel (Conv2d) / output neuron (Linear)
        prune.ln_structured(layer, name='weight', amount=amount, n=2, dim=0)
    return model
全局剪枝
def global_pruning(model, amount=0.5):
    """Global L1 unstructured pruning across all Linear and Conv2d weights.

    Unlike per-layer pruning, the weights of all selected layers are ranked
    together by magnitude, so layers holding many small weights lose more
    of them.

    Args:
        model: module whose Linear/Conv2d ``weight`` tensors are pruned in place.
        amount: forwarded unchanged to ``prune.global_unstructured``.

    Returns:
        The same ``model`` object.
    """
    parameters_to_prune = [
        (module, 'weight')
        for module in model.modules()
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d))
    ]
    if not parameters_to_prune:
        # Nothing to prune; also avoids a division by zero in the report below.
        return model

    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=amount,
    )

    # Report sparsity over the pruned weight tensors only (biases excluded).
    # `module.weight` is now the masked tensor, so pruned entries are zeros.
    # Count zero *entries*; comparing count_nonzero(...) == 0 would only count
    # layers that are entirely zero.
    total_params = sum(module.weight.nelement() for module, _ in parameters_to_prune)
    total_zeros = sum(int((module.weight == 0).sum()) for module, _ in parameters_to_prune)
    print(f"全局稀疏度: {total_zeros / total_params:.2%}")
    return model
迭代剪枝
逐步剪枝策略
class IterativePruner:
    """Iterative pruning: raise the sparsity of all Linear weights step by step.

    The sparsity goal grows linearly from ``initial_sparsity`` (first
    iteration) to ``target_sparsity`` (last iteration).  After each pruning
    step the model is fine-tuned and evaluated through user callbacks.
    """

    # Supported values of `pruner` -> torch pruning function.
    _PRUNE_FNS = {
        'l1': prune.l1_unstructured,
        'random': prune.random_unstructured,
    }

    def __init__(self, model, initial_sparsity=0.1, target_sparsity=0.9,
                 num_iterations=10, pruner='l1'):
        """Store the model and schedule.

        Raises:
            ValueError: for an unknown ``pruner``, ``num_iterations < 1``, or
                sparsities outside ``0 <= initial <= target <= 1``.
        """
        if pruner not in self._PRUNE_FNS:
            raise ValueError(
                f"unknown pruner {pruner!r}; expected one of {sorted(self._PRUNE_FNS)}")
        if num_iterations < 1:
            raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
        if not 0.0 <= initial_sparsity <= target_sparsity <= 1.0:
            raise ValueError(
                "expected 0 <= initial_sparsity <= target_sparsity <= 1, got "
                f"{initial_sparsity} and {target_sparsity}")
        self.model = model
        self.initial_sparsity = initial_sparsity
        self.target_sparsity = target_sparsity
        self.num_iterations = num_iterations
        self.pruner = pruner

    def _linear_layers(self):
        """Return every torch.nn.Linear sub-module (the set this class prunes)."""
        return [m for m in self.model.modules() if isinstance(m, torch.nn.Linear)]

    def _next_sparsity(self, current_sparsity: float) -> float:
        """Return the absolute sparsity the next iteration should reach."""
        if current_sparsity <= 0:
            return self.initial_sparsity
        # Linear schedule: num_iterations - 1 equal steps from initial to target.
        step = (self.target_sparsity - self.initial_sparsity) / max(self.num_iterations - 1, 1)
        return min(self.target_sparsity, current_sparsity + step)

    def prune_one_iteration(self, current_sparsity: float) -> float:
        """Run one pruning step.

        Args:
            current_sparsity: sparsity reached so far (0 before the first step).

        Returns:
            Measured sparsity of the Linear weights after this step.
        """
        if current_sparsity >= 1.0:
            return self._get_sparsity()  # nothing left to prune
        desired = self._next_sparsity(current_sparsity)
        # Re-pruning an already pruned tensor applies `amount` to the weights
        # that are still unpruned, so convert the absolute increment into a
        # fraction of what remains.
        amount = (desired - current_sparsity) / (1.0 - current_sparsity)
        if amount > 0:
            prune_fn = self._PRUNE_FNS[self.pruner]
            for module in self._linear_layers():
                prune_fn(module, 'weight', amount=min(amount, 1.0))
        return self._get_sparsity()

    def iterative_prune(self, train_fn, eval_fn, epochs_per_iteration=5):
        """Alternate pruning, fine-tuning and evaluation for num_iterations rounds.

        Args:
            train_fn: called as ``train_fn(model, epochs=epochs_per_iteration)``.
            eval_fn: called as ``eval_fn(model)``; result is printed as a percentage.
            epochs_per_iteration: fine-tuning epochs passed to ``train_fn``.

        Returns:
            The (pruned, fine-tuned) model.
        """
        sparsity = 0.0
        for i in range(self.num_iterations):
            print(f"迭代 {i+1}/{self.num_iterations}")
            sparsity = self.prune_one_iteration(sparsity)
            print(f"当前稀疏度: {sparsity:.2%}")
            # Fine-tune to recover accuracy lost in this step.
            train_fn(self.model, epochs=epochs_per_iteration)
            accuracy = eval_fn(self.model)
            print(f"准确率: {accuracy:.2%}")
        return self.model

    def _get_sparsity(self) -> float:
        """Fraction of zero entries across all Linear weights (0.0 if there are none).

        Reads ``module.weight``, which after pruning is the masked tensor;
        ``named_parameters()`` would instead expose the unmasked ``weight_orig``.
        """
        total = 0
        zeros = 0
        for module in self._linear_layers():
            total += module.weight.nelement()
            zeros += int((module.weight == 0).sum())
        return zeros / total if total else 0.0
渐进式剪枝
class ProgressivePruner:
    """Layer-wise L1 pruning whose ratio depends on each Linear layer's weight norm.

    Layers are ranked by the norm of their weight.  The smallest-norm layer
    is pruned at ``target_sparsity`` and the ratio decreases linearly for
    larger-norm layers, down to ``target_sparsity / num_layers``.  Hence
    ``target_sparsity`` is the *maximum* per-layer ratio, not the overall
    model sparsity.
    """

    def __init__(self, model):
        self.model = model
        # layer name -> importance score (weight norm); filled by compute_importance()
        self.layer_importance = {}

    def compute_importance(self):
        """Score every Linear layer by its weight norm.

        Returns:
            ``(layer_name, score)`` pairs sorted from most to least important.
        """
        for name, module in self.model.named_modules():
            if isinstance(module, torch.nn.Linear):
                # Norm of the whole weight tensor; note it grows with layer size.
                self.layer_importance[name] = module.weight.data.norm().item()
        return sorted(
            self.layer_importance.items(),
            key=lambda item: item[1],
            reverse=True,
        )

    def progressive_prune(self, target_sparsity: float):
        """Prune low-importance layers hardest.

        Args:
            target_sparsity: ratio applied to the least important layer.

        Returns:
            The model, pruned in place.
        """
        layers = self.compute_importance()
        num_layers = len(layers)
        # Build the name -> module lookup once instead of rescanning the
        # whole module tree for every layer.
        modules_by_name = dict(self.model.named_modules())
        # Walk from the least important layer (rank 0) upward.
        for rank, (layer_name, _importance) in enumerate(reversed(layers)):
            layer_sparsity = target_sparsity * (1 - rank / num_layers)
            module = modules_by_name.get(layer_name)
            if module is None:
                # Stale entry from an earlier compute_importance() call.
                continue
            prune.l1_unstructured(module, 'weight', amount=layer_sparsity)
            print(f"剪枝层 {layer_name}: {layer_sparsity:.2%}")
        return self.model
结构化剪枝实现
class StructuredPruner:
    """Structured pruning by zeroing whole neurons, input features or attention heads.

    Pruning is simulated by writing zeros into weights (and the matching bias
    entries); tensor shapes are left unchanged.
    """

    def prune_linear_layer(self, module: torch.nn.Linear,
                           amount: float, dim: int = 0):
        """Zero the rows or columns of a Linear weight with the smallest L1 norm.

        Args:
            module: layer to prune in place.
            amount: fraction of rows/columns to zero; the count is
                ``int(n * amount)`` (rounded down).
            dim: 0 (or -2) -> output neurons: weight rows plus their bias
                entries; 1 (or -1) -> input features: weight columns.

        Returns:
            The 0/1 mask (same shape as the weight) that was applied.

        Raises:
            ValueError: if ``dim`` does not address one of the two weight axes.
        """
        weight = module.weight.data
        if dim < 0:
            dim += weight.dim()
        if dim not in (0, 1):
            raise ValueError(f"dim must be 0 or 1 (or -2/-1), got {dim}")

        # L1 norm of each row (dim=0) or each column (dim=1).
        importance = weight.abs().sum(dim=1 - dim)
        num_prune = int(importance.numel() * amount)
        _, indices_to_prune = importance.topk(num_prune, largest=False)

        mask = torch.ones_like(weight)
        if dim == 0:
            mask[indices_to_prune] = 0
        else:
            mask[:, indices_to_prune] = 0
        module.weight.data = weight * mask
        # A removed output neuron must also lose its bias; otherwise it still
        # emits a constant value.
        if dim == 0 and module.bias is not None:
            module.bias.data[indices_to_prune] = 0
        return mask

    def prune_attention(self, attention_module, amount=0.3):
        """Zero the attention heads whose q_proj rows have the smallest L1 norm.

        Reads ``num_heads``, ``head_dim`` and the ``q_proj``/``k_proj``/``v_proj``
        layers from ``attention_module``.  Assumes output rows are laid out head
        by head (rows ``[i*head_dim, (i+1)*head_dim)`` belong to head i), as the
        slicing below implies — confirm for the target architecture.
        NOTE(review): the same row mask is applied to k_proj/v_proj, so they
        must have ``num_heads * head_dim`` output rows; grouped-query layouts
        with fewer key/value heads are not supported.

        Args:
            attention_module: module holding the projections; modified in place.
            amount: fraction of heads to remove (count rounded down).

        Returns:
            Float32 CPU mask of length ``num_heads * head_dim`` (0 = removed row).
        """
        num_heads = attention_module.num_heads
        head_size = attention_module.head_dim
        q_weight = attention_module.q_proj.weight.data

        # Per-head importance: L1 norm of that head's block of q_proj rows.
        head_importance = (
            q_weight[:num_heads * head_size].abs().reshape(num_heads, -1).sum(dim=1)
        )
        num_prune = int(num_heads * amount)
        _, indices_to_prune = head_importance.topk(num_prune, largest=False)

        mask = torch.ones(num_heads * head_size)
        mask.view(num_heads, head_size)[indices_to_prune.cpu()] = 0

        for proj in (attention_module.q_proj, attention_module.k_proj,
                     attention_module.v_proj):
            # Match the weight's device/dtype; a CPU mask cannot be multiplied
            # in place into GPU weights.
            row_mask = mask.to(device=proj.weight.device, dtype=proj.weight.dtype)
            proj.weight.data *= row_mask.unsqueeze(1)
            # Zero the bias too, so removed heads contribute nothing (notably
            # through v_proj's bias).
            if getattr(proj, 'bias', None) is not None:
                proj.bias.data *= row_mask

        print(f"移除 {num_prune} 个注意力头")
        return mask
剪枝效果评估
class PruningEvaluator:
    """Compare an original and a pruned model: storage size, sparsity, accuracy."""

    def evaluate(self, original_model, pruned_model, eval_data):
        """Collect size, sparsity and accuracy figures for both models.

        Returns:
            dict with ``original_size_mb``, ``pruned_size_mb``,
            ``compression_ratio``, ``sparsity``, ``original_accuracy``,
            ``pruned_accuracy`` and ``accuracy_drop``.
        """
        original_size = self._get_model_size(original_model)
        pruned_size = self._get_model_size(pruned_model)
        sparsity = self._get_sparsity(pruned_model)
        original_acc = self._evaluate_accuracy(original_model, eval_data)
        pruned_acc = self._evaluate_accuracy(pruned_model, eval_data)

        results = {
            'original_size_mb': original_size,
            'pruned_size_mb': pruned_size,
            # Ratio of dense sizes: zeroed entries still count (see
            # _get_model_size), so mask-based pruning alone leaves this at 1.0.
            # inf when the pruned model has no parameters at all.
            'compression_ratio': original_size / pruned_size if pruned_size else float('inf'),
            'sparsity': sparsity,
            'original_accuracy': original_acc,
            'pruned_accuracy': pruned_acc,
            'accuracy_drop': original_acc - pruned_acc,
        }
        return results

    def _get_model_size(self, model) -> float:
        """Dense parameter storage in MB (every element counted, zeros included)."""
        total_size = 0
        for param in model.parameters():
            total_size += param.nelement() * param.element_size()
        return total_size / 1024 / 1024

    def _get_sparsity(self, model) -> float:
        """Fraction of zero entries over all parameters, using the effective tensors.

        While torch pruning reparametrization is attached, a pruned tensor is
        registered as the parameter ``<name>_orig`` (unmasked) and the masked
        result is the attribute ``<name>``; the masked one is read so pruned
        zeros are counted.  Shared parameters are counted once.  Returns 0.0
        for a model without parameters.
        """
        seen = set()
        total = 0
        zeros = 0
        for module in model.modules():
            for param_name, param in module.named_parameters(recurse=False):
                if id(param) in seen:
                    continue
                seen.add(id(param))
                effective = param
                if param_name.endswith('_orig'):
                    base_name = param_name[:-len('_orig')]
                    if hasattr(module, base_name + '_mask'):
                        effective = getattr(module, base_name)
                total += effective.nelement()
                zeros += int((effective == 0).sum())
        return zeros / total if total else 0.0

    def _evaluate_accuracy(self, model, eval_data):
        """Placeholder: ignores ``model`` and ``eval_data`` and always returns 0.95.

        NOTE(review): replace with a real evaluation loop before trusting
        ``accuracy_drop``.
        """
        return 0.95
剪枝与量化结合
class CombinedCompression:
    """Prune-then-quantize pipeline applied to the Linear layers of a model."""

    # Bit widths _quantize knows how to handle.
    _SUPPORTED_BITS = (4, 8)

    def compress(self, model, sparsity=0.5, bits=4):
        """Prune (in place, permanently) and then quantize ``model``.

        Args:
            model: module to compress; its Linear weights are modified in place.
            sparsity: per-layer L1 pruning amount for every Linear layer.
            bits: 8 -> dynamic int8 quantization; 4 -> placeholder (unchanged).

        Returns:
            The quantized model (for bits=8 the object returned by
            ``quantize_dynamic``; for bits=4 the pruned input model).

        Raises:
            ValueError: if ``bits`` is unsupported (checked before any mutation).
        """
        if bits not in self._SUPPORTED_BITS:
            raise ValueError(f"unsupported bits={bits!r}; expected one of {self._SUPPORTED_BITS}")
        print("步骤1: 剪枝")
        pruned_model = self._prune(model, sparsity)
        print("步骤2: 量化")
        quantized_model = self._quantize(pruned_model, bits)
        return quantized_model

    def _prune(self, model, sparsity):
        """L1-prune every Linear weight and make the zeros permanent."""
        import torch.nn.utils.prune as prune
        for module in model.modules():
            if isinstance(module, torch.nn.Linear):
                prune.l1_unstructured(module, 'weight', amount=sparsity)
                # Fold the mask into `weight` and drop weight_orig/weight_mask so
                # the next stage sees a plain parameter with the zeros baked in.
                prune.remove(module, 'weight')
        return model

    def _quantize(self, model, bits):
        """Quantize ``model``; raises ValueError for unsupported ``bits``."""
        if bits == 8:
            from torch.quantization import quantize_dynamic
            return quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
        if bits == 4:
            # Placeholder: no 4-bit backend (e.g. GPTQ/AWQ) is wired in here;
            # the model is returned unchanged.
            print(f"应用{bits}bit量化")
            return model
        # Previously fell through and returned None silently.
        raise ValueError(f"unsupported bits={bits!r}; expected one of {self._SUPPORTED_BITS}")
实践效果
# Hard-coded example figures for LLaMA-7B; they are not computed by any code
# here and their measurement setup is not stated.
compression_results = {
    'LLaMA-7B': {
        'baseline': {'size': '13GB', 'perplexity': 5.28},
        'prune_50': {'size': '7GB', 'perplexity': 5.35, 'speedup': '1.3x'},
        'quant_4bit': {'size': '3.5GB', 'perplexity': 5.40, 'speedup': '2.0x'},
        'prune+quant': {'size': '2GB', 'perplexity': 5.50, 'speedup': '2.5x'},
    },
}

print("7B模型压缩效果:")
llama_rows = compression_results['LLaMA-7B']
for label in llama_rows:
    row = llama_rows[label]
    print(f"{label}: 大小={row['size']}, 困惑度={row['perplexity']}")
最佳实践
- 非结构化剪枝只有在稀疏存储格式或支持稀疏计算的硬件/内核配合下,才能真正减小模型体积并带来加速
- 结构化剪枝直接缩小矩阵维度,更适合在通用硬件上实际加速
- 迭代剪枝通常比一次性剪枝精度损失更小
- 剪枝后需要微调恢复精度
- 剪枝与量化结合可实现更大压缩比