模型压缩
---
title: "模型压缩"
description: "大语言模型压缩技术详解,包括剪枝、量化和知识蒸馏"
tags: ["模型压缩", "剪枝", "量化", "知识蒸馏"]
category: "llm"
icon: "🧠"
---
模型压缩
模型压缩(Model Compression)是在尽可能保持模型性能的前提下,减小模型体积、降低推理成本的技术集合,对于大语言模型的部署和应用至关重要。
结构化剪枝
通过移除整个神经元或注意力头来压缩模型:
import torch
import torch.nn as nn
class StructuredPruner:
    """Structured pruning of attention heads and linear-layer output neurons.

    The model is modified in place.

    Args:
        model: ``nn.Module`` to prune.
        sparsity: Fraction in ``[0, 1]`` of heads / neurons removed from each
            eligible module. The count is ``int(n * sparsity)``, so a
            sparsity of 0 removes nothing.
    """

    def __init__(self, model, sparsity=0.3):
        self.model = model
        self.sparsity = sparsity

    def prune_attention_heads(self):
        """Mask the least important heads of every module exposing ``num_heads``.

        Each such module must also expose ``attention_weights`` (see
        ``_compute_head_importance``); a missing attribute raises
        ``AttributeError``.
        """
        for _name, module in self.model.named_modules():
            if not hasattr(module, 'num_heads'):
                continue
            num_to_prune = int(module.num_heads * self.sparsity)
            if num_to_prune == 0:
                continue
            importance = self._compute_head_importance(module)
            # argsort is ascending, so the first entries are the least
            # important heads.
            prune_indices = torch.argsort(importance)[:num_to_prune]
            self._mask_heads(module, prune_indices)

    def _compute_head_importance(self, attention_layer):
        """Score each head by the negated entropy of its attention rows.

        Reads ``attention_layer.attention_weights``, expected to be laid out
        as ``[batch, heads, seq, seq]`` with probabilities on the last axis.

        Returns:
            A ``[heads]`` tensor where a larger value means more important.
            Low entropy (sharply focused attention) is treated as important,
            so the entropy is negated to make "higher is better".
        """
        with torch.no_grad():
            attn_weights = attention_layer.attention_weights
            # The epsilon keeps log() finite for exactly-zero probabilities.
            entropy = -(attn_weights * torch.log(attn_weights + 1e-10)).sum(dim=-1)
            return -entropy.mean(dim=[0, 2])

    def _mask_heads(self, module, prune_indices):
        """Disable the heads listed in ``prune_indices`` on ``module``.

        A ``head_mask`` buffer of shape ``[num_heads]`` (1 = keep, 0 = pruned)
        is created on the module if absent and updated. Attention
        implementations that consult ``head_mask`` will then drop those heads.

        In addition, when the module has an ``out_proj`` ``nn.Linear`` whose
        input width is a multiple of ``num_heads`` (the ``nn.MultiheadAttention``
        layout), the input columns fed by the pruned heads are zeroed so the
        heads contribute nothing to the output.
        """
        with torch.no_grad():
            head_mask = getattr(module, 'head_mask', None)
            if (not isinstance(head_mask, torch.Tensor)
                    or head_mask.numel() != module.num_heads):
                head_mask = torch.ones(module.num_heads,
                                       device=prune_indices.device)
                module.register_buffer('head_mask', head_mask)
            head_mask[prune_indices.to(head_mask.device)] = 0

            out_proj = getattr(module, 'out_proj', None)
            if (isinstance(out_proj, nn.Linear)
                    and out_proj.in_features % module.num_heads == 0):
                head_dim = out_proj.in_features // module.num_heads
                for head in prune_indices.tolist():
                    start = head * head_dim
                    out_proj.weight[:, start:start + head_dim] = 0

    def prune_feedforward_neurons(self, name_filter=None):
        """Zero the lowest-L1-norm output neurons of ``nn.Linear`` layers.

        Args:
            name_filter: Optional callable taking a module's qualified name
                and returning True if that layer should be pruned. With the
                default ``None`` every ``nn.Linear`` in the model is pruned,
                which includes attention projections and output heads; pass
                a filter to restrict pruning to the feed-forward layers.
        """
        with torch.no_grad():
            for name, module in self.model.named_modules():
                if not isinstance(module, nn.Linear):
                    continue
                if name_filter is not None and not name_filter(name):
                    continue
                num_to_prune = int(module.out_features * self.sparsity)
                if num_to_prune == 0:
                    continue
                # Row i of the weight produces output neuron i; its L1 norm
                # is the importance score.
                importance = module.weight.abs().sum(dim=1)
                prune_rows = torch.argsort(importance)[:num_to_prune]
                module.weight[prune_rows] = 0
                # A pruned neuron must not keep emitting its bias.
                if module.bias is not None:
                    module.bias[prune_rows] = 0
非结构化剪枝
移除单个权重参数,产生稀疏矩阵:
class UnstructuredPruner:
    """Unstructured (per-weight) magnitude pruning, modifying the model in place.

    Args:
        model: ``nn.Module`` to prune.
        sparsity: Fraction in ``[0, 1]`` of entries zeroed in each pruned
            tensor.

    Attributes:
        masks: Mapping of parameter name to the boolean keep-mask produced by
            the most recent ``magnitude_pruning`` call. Re-apply these after
            optimizer steps if pruned weights must stay at zero.
    """

    def __init__(self, model, sparsity=0.5):
        self.model = model
        self.sparsity = sparsity
        self.masks = {}

    def magnitude_pruning(self):
        """Zero the smallest-magnitude entries of every weight matrix.

        Only parameters whose name contains ``'weight'`` and that have at
        least two dimensions are pruned, so biases and 1-D normalisation
        scales are left intact.

        Returns:
            Dict mapping each pruned parameter name to its boolean keep-mask.
        """
        masks = {}
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if 'weight' not in name or param.dim() < 2:
                    continue
                num_to_prune = int(param.numel() * self.sparsity)
                if num_to_prune == 0:
                    continue
                magnitude = param.abs()
                # kthvalue (1-indexed) has no input-size limit, unlike
                # torch.quantile, which rejects very large tensors.
                threshold = (
                    magnitude.float().flatten().kthvalue(num_to_prune).values
                )
                # Entries tied with the threshold are pruned as well.
                mask = magnitude > threshold
                param.mul_(mask.to(param.dtype))
                masks[name] = mask
        self.masks = masks
        return masks

    def lottery_ticket_pruning(self, initial_weights=None):
        """Lottery-ticket pruning: prune by magnitude, then rewind survivors.

        Args:
            initial_weights: Optional dict mapping parameter names to the
                tensors the network had at initialisation (snapshot them
                before training). Parameters missing from the dict are left
                at their current values. When omitted, the current weights
                are used as the rewind target, which reduces this method to
                plain magnitude pruning.

        Returns:
            The pruned model (the same object as ``self.model``). Fine-tuning
            the resulting sub-network is left to the caller.
        """
        if initial_weights is None:
            initial_weights = {
                name: param.detach().clone()
                for name, param in self.model.named_parameters()
            }

        # The mask is chosen from the trained magnitudes.
        masks = self.magnitude_pruning()

        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name not in initial_weights:
                    continue
                rewound = initial_weights[name].to(
                    device=param.device, dtype=param.dtype
                )
                mask = masks.get(name)
                if mask is not None:
                    rewound = rewound * mask.to(param.dtype)
                param.copy_(rewound)
        return self.model
知识蒸馏
用大模型(教师)指导小模型(学生)学习:
class DistillationTrainer:
    """Knowledge distillation: a teacher model supervises a student model.

    Args:
        teacher_model: Model providing soft targets; switched to eval mode.
        student_model: Model being trained.
        temperature: Softmax temperature applied to both sets of logits for
            the soft-label term.
        alpha: Weight of the soft-label (KL) term; the hard-label
            (cross-entropy) term is weighted ``1 - alpha``.
    """

    def __init__(self, teacher_model, student_model, temperature=4.0, alpha=0.5):
        self.teacher = teacher_model
        self.student = student_model
        self.temperature = temperature
        self.alpha = alpha
        self.teacher.eval()

    def distillation_loss(self, student_logits, teacher_logits, labels):
        """Combine the soft-label KL loss with the hard-label cross-entropy.

        Accepts classification logits ``[batch, classes]`` as well as
        sequence logits ``[batch, seq, vocab]``; leading dimensions are
        flattened so both terms are averaged per prediction.

        Args:
            student_logits: Student logits, class dimension last.
            teacher_logits: Teacher logits with the same shape.
            labels: Integer targets matching the leading dimensions of the
                logits. Entries equal to -100 are skipped by the
                cross-entropy term (its default ``ignore_index``) but still
                take part in the KL term.

        Returns:
            Scalar tensor ``alpha * soft + (1 - alpha) * hard``.

        NOTE(review): labels must already be aligned with the logits. For
        next-token prediction that means the caller shifts them; no shift is
        applied here.
        """
        student_flat = student_logits.reshape(-1, student_logits.size(-1))
        # The teacher is a fixed target; never backpropagate into it.
        teacher_flat = teacher_logits.detach().reshape(
            -1, teacher_logits.size(-1)
        )
        labels_flat = labels.reshape(-1)

        # Soft-label loss (KL divergence). The T^2 factor keeps its gradient
        # magnitude comparable to the hard-label term.
        soft_loss = nn.functional.kl_div(
            nn.functional.log_softmax(student_flat / self.temperature, dim=-1),
            nn.functional.softmax(teacher_flat / self.temperature, dim=-1),
            reduction='batchmean'
        ) * (self.temperature ** 2)

        # Hard-label loss (cross-entropy) on the flattened predictions.
        hard_loss = nn.functional.cross_entropy(student_flat, labels_flat)

        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

    def train_step(self, input_ids, attention_mask, labels):
        """Run teacher and student forward passes and return the loss.

        Both models are called as ``model(input_ids, attention_mask=...)``
        and must return an object with a ``.logits`` attribute. No backward
        pass or optimizer step is performed here.
        """
        with torch.no_grad():
            teacher_logits = self.teacher(
                input_ids, attention_mask=attention_mask
            ).logits

        student_logits = self.student(
            input_ids, attention_mask=attention_mask
        ).logits

        return self.distillation_loss(student_logits, teacher_logits, labels)
# Usage example: distil a 13B teacher checkpoint into a 7B student.
from transformers import AutoModelForCausalLM
# Both checkpoints are loaded with from_pretrained's default settings.
teacher = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b-hf")
student = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# Uses the trainer defaults: temperature=4.0, alpha=0.5.
trainer = DistillationTrainer(teacher, student)
层剪枝
直接移除Transformer的某些层:
class LayerPruner:
    """Depth pruning: drop whole layers that change the hidden state least.

    The layer stack is read from ``model.transformer.layers`` and the layer
    count is written back to ``model.config.num_hidden_layers``.

    Args:
        model: Model to prune in place.
        prune_ratio: Fraction of layers to remove; the count is
            ``int(num_layers * prune_ratio)``.
    """

    def __init__(self, model, prune_ratio=0.2):
        self.model = model
        self.prune_ratio = prune_ratio

    def compute_layer_importance(self, eval_data):
        """Score each layer by the mean absolute change it applies.

        A forward hook on every layer records
        ``mean(|output - input|)`` during a single gradient-free forward pass
        of ``self.model(eval_data)``. Hooks are always removed afterwards,
        even if the forward pass raises.

        Args:
            eval_data: Input passed positionally to the model.

        Returns:
            List of floats, one per layer in layer order. A layer that is
            never executed during the pass scores 0.0.

        NOTE(review): the hook reads the layer's first positional input, so
        layers that receive their hidden states as a keyword argument are
        not supported.
        """
        layers = self.model.transformer.layers
        layer_diffs = [0.0] * len(layers)
        handles = []

        def make_hook(index):
            # Bind the index per layer so scores land in layer order
            # regardless of execution order.
            def hook(module, inputs, output):
                hidden_out = output[0] if isinstance(output, (tuple, list)) else output
                layer_diffs[index] = (hidden_out - inputs[0]).abs().mean().item()
            return hook

        try:
            for index, layer in enumerate(layers):
                handles.append(layer.register_forward_hook(make_hook(index)))
            with torch.no_grad():
                self.model(eval_data)
        finally:
            for handle in handles:
                handle.remove()

        return layer_diffs

    def prune_layers(self, eval_data):
        """Remove the least important layers and update the layer count.

        Args:
            eval_data: Input used to score the layers.

        Returns:
            The pruned model (the same object as ``self.model``).
        """
        importances = self.compute_layer_importance(eval_data)
        num_to_prune = int(len(importances) * self.prune_ratio)

        # Indices of the lowest-scoring layers (stable for ties).
        prune_indices = sorted(
            range(len(importances)), key=importances.__getitem__
        )[:num_to_prune]

        # Delete from the back so earlier indices stay valid.
        for idx in sorted(prune_indices, reverse=True):
            del self.model.transformer.layers[idx]

        # Keep the configured layer count in sync with the stack.
        self.model.config.num_hidden_layers -= len(prune_indices)

        return self.model
量化技术
# INT8量化
def quantize_int8(model):
    """Apply dynamic INT8 quantization to the model's ``nn.Linear`` layers.

    Weights are converted to int8 ahead of time while activations are
    quantized on the fly at inference, so no calibration data is needed.
    The input model is left untouched; a quantized copy is returned.

    Args:
        model: Float ``nn.Module`` to quantize.

    Returns:
        A new module in which every ``nn.Linear`` is replaced by its
        dynamically quantized counterpart.
    """
    import torch
    import torch.nn as nn
    import torch.quantization as quant

    # The previous prepare()/convert() flow is static quantization and is
    # only valid after observers have seen calibration data.
    return quant.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
# GPTQ/AWQ(参考100-int4-inference.md获取完整示例)
压缩技术对比
| 技术 | 压缩率 | 速度提升 | 精度损失 | 实现难度 |
|---|---|---|---|---|
| 结构化剪枝 | 20-40% | 2-3x | 低 | 中 |
| 非结构化剪枝 | 50-90% | 依赖硬件 | 中 | 中 |
| 知识蒸馏 | 可变 | 取决于学生模型 | 低-中 | 高 |
| 层剪枝 | 20-50% | 2-5x | 中-高 | 低 |
模型压缩是大模型落地的关键技术,结合多种方法可以实现更优的压缩效果。