← 返回首页
🧠

模型压缩

📂 llm ⏱ 3 min 414 words

--- title: "模型压缩" description: "大语言模型压缩技术详解,包括剪枝、量化和知识蒸馏" tags: ["模型压缩", "剪枝", "量化", "知识蒸馏"] category: "llm" icon: "🧠"

模型压缩

模型压缩(Model Compression)是在尽可能保持模型性能的前提下,减小模型体积、降低推理成本的技术集合。对于大语言模型的部署和应用至关重要。

结构化剪枝

通过移除整个神经元或注意力头来压缩模型:

import torch
import torch.nn as nn

class StructuredPruner:
    def __init__(self, model, sparsity=0.3):
        self.model = model
        self.sparsity = sparsity
    
    def prune_attention_heads(self):
        """剪枝不重要的注意力头"""
        for name, module in self.model.named_modules():
            if hasattr(module, 'num_heads'):
                # 计算每个头的重要性
                importance = self._compute_head_importance(module)
                # 移除重要性最低的头
                num_to_prune = int(module.num_heads * self.sparsity)
                prune_indices = torch.argsort(importance)[:num_to_prune]
                self._mask_heads(module, prune_indices)
    
    def _compute_head_importance(self, attention_layer):
        """基于梯度和激活计算头的重要性"""
        # 使用注意力熵作为重要性度量
        with torch.no_grad():
            attn_weights = attention_layer.attention_weights  # [batch, heads, seq, seq]
            # 熵越低说明注意力越集中,越重要
            entropy = -(attn_weights * torch.log(attn_weights + 1e-10)).sum(dim=-1)
            head_importance = entropy.mean(dim=[0, 2])  # 平均每个头的重要性
        return head_importance
    
    def prune_feedforward_neurons(self):
        """剪枝FFN层的神经元"""
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                # L1范数作为重要性指标
                importance = module.weight.abs().sum(dim=1)
                threshold = torch.quantile(importance, self.sparsity)
                mask = importance > threshold
                module.weight.data[~mask] = 0

非结构化剪枝

移除单个权重参数,产生稀疏矩阵:

class UnstructuredPruner:
    def __init__(self, model, sparsity=0.5):
        self.model = model
        self.sparsity = sparsity
    
    def magnitude_pruning(self):
        """基于幅度的剪枝"""
        for name, param in self.model.named_parameters():
            if 'weight' in name:
                threshold = torch.quantile(param.abs(), self.sparsity)
                mask = param.abs() > threshold
                param.data *= mask.float()
    
    def lottery_ticket_pruning(self):
        """彩票假设:找到可独立训练的子网络"""
        # 第1步:初始化并训练完整网络
        initial_weights = {n: p.clone() for n, p in self.model.named_parameters()}
        
        # 第2步:剪枝小幅度权重
        self.magnitude_pruning()
        
        # 第3步:将剪枝后的权重重置为初始值
        for name, param in self.model.named_parameters():
            mask = param.data != 0
            param.data = initial_weights[name] * mask.float()
        
        # 第4步:微调剪枝后的子网络
        return self.model

知识蒸馏

用大模型(教师)指导小模型(学生)学习:

class DistillationTrainer:
    def __init__(self, teacher_model, student_model, temperature=4.0, alpha=0.5):
        self.teacher = teacher_model
        self.student = student_model
        self.temperature = temperature
        self.alpha = alpha
        self.teacher.eval()
    
    def distillation_loss(self, student_logits, teacher_logits, labels):
        """蒸馏损失:结合软标签和硬标签"""
        # 软标签损失(KL散度)
        soft_loss = nn.functional.kl_div(
            nn.functional.log_softmax(student_logits / self.temperature, dim=-1),
            nn.functional.softmax(teacher_logits / self.temperature, dim=-1),
            reduction='batchmean'
        ) * (self.temperature ** 2)
        
        # 硬标签损失(交叉熵)
        hard_loss = nn.functional.cross_entropy(student_logits, labels)
        
        # 加权组合
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss
    
    def train_step(self, input_ids, attention_mask, labels):
        """单步训练"""
        with torch.no_grad():
            teacher_outputs = self.teacher(input_ids, attention_mask=attention_mask)
            teacher_logits = teacher_outputs.logits
        
        student_outputs = self.student(input_ids, attention_mask=attention_mask)
        student_logits = student_outputs.logits
        
        loss = self.distillation_loss(student_logits, teacher_logits, labels)
        return loss

# 使用示例
from transformers import AutoModelForCausalLM

teacher = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b-hf")
student = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

trainer = DistillationTrainer(teacher, student)

层剪枝

直接移除Transformer的某些层:

class LayerPruner:
    def __init__(self, model, prune_ratio=0.2):
        self.model = model
        self.prune_ratio = prune_ratio
    
    def compute_layer_importance(self, eval_data):
        """基于隐藏状态变化评估层重要性"""
        layer_diffs = []
        
        for i, layer in enumerate(self.model.transformer.layers):
            # 计算输入输出差异
            def hook(module, input, output):
                diff = (output[0] - input[0]).abs().mean().item()
                layer_diffs.append(diff)
            
            handle = layer.register_forward_hook(hook)
            self.model(eval_data)
            handle.remove()
        
        return layer_diffs
    
    def prune_layers(self, eval_data):
        """移除重要性最低的层"""
        importances = self.compute_layer_importance(eval_data)
        num_to_prune = int(len(importances) * self.prune_ratio)
        
        # 移除重要性最低的层
        prune_indices = sorted(range(len(importances)), 
                             key=lambda i: importances[i])[:num_to_prune]
        
        for idx in sorted(prune_indices, reverse=True):
            del self.model.transformer.layers[idx]
        
        # 更新层数配置
        self.model.config.num_hidden_layers -= len(prune_indices)
        return self.model

量化技术

# INT8量化
def quantize_int8(model):
    """动态INT8量化"""
    import torch.quantization as quant
    model.qconfig = quant.get_default_qconfig('fbgemm')
    model_prepared = quant.prepare(model)
    # 校准
    # model_prepared(calibration_data)
    model_quantized = quant.convert(model_prepared)
    return model_quantized

# GPTQ/AWQ(参考100-int4-inference.md获取完整示例)

压缩技术对比

技术 压缩率 速度提升 精度损失 实现难度
结构化剪枝 20-40% 2-3x
非结构化剪枝 50-90% 依赖硬件
知识蒸馏 可变 取决于学生模型 低-中
层剪枝 20-50% 2-5x 中-高

模型压缩是大模型落地的关键技术,结合多种方法可以实现更优的压缩效果。