← 返回首页
🧠

模型蒸馏:大模型知识迁移

📂 llm ⏱ 4 min 657 words

---
title: "模型蒸馏:大模型知识迁移"
description: "掌握模型蒸馏的原理和实现,将大模型能力迁移到小模型中"
tags: ["模型蒸馏", "知识蒸馏", "知识迁移", "模型压缩"]
category: "llm"
icon: "🧠"
---

模型蒸馏:大模型知识迁移

模型蒸馏简介

模型蒸馏(Model Distillation)是一种将大型教师模型(Teacher Model)的知识迁移到小型学生模型(Student Model)的技术。通过蒸馏,小模型可以学习大模型的行为模式,在保持较小体积的同时获得接近大模型的性能。

模型蒸馏的核心价值:在显著减小模型体积、降低推理与部署成本的同时,尽量保留教师模型的能力。

蒸馏方法

离线蒸馏(Offline Distillation)

import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    """Knowledge-distillation loss: soft-label KL term plus hard-label cross-entropy.

    The soft term is the KL divergence between the temperature-softened
    teacher and student distributions, multiplied by ``temperature ** 2``.
    The hard term is the ordinary cross-entropy of the student logits
    against ``labels``. The two are mixed as
    ``alpha * soft + (1 - alpha) * hard``.

    Args:
        temperature: Softening temperature applied to both sets of logits.
        alpha: Weight of the soft-label term; ``1 - alpha`` weights the
            hard-label term.
    """

    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, student_logits, teacher_logits, labels):
        """Return the combined distillation loss as a scalar tensor.

        Args:
            student_logits: Student logits with the class/vocabulary axis last.
            teacher_logits: Teacher logits, same shape as ``student_logits``.
            labels: Target class indices. For logits with more than two
                dimensions (e.g. per-token language-model logits) the labels
                must have the logits' shape minus the last axis.

        Returns:
            Scalar loss tensor.
        """
        if student_logits.dim() > 2:
            # Sequence-shaped logits: cross_entropy expects the class axis at
            # dim 1 and 'batchmean' divides by dim 0 only, so collapse every
            # leading axis into one "row" axis. Both terms then become
            # per-position averages. Inputs with <= 2 dims are left untouched
            # so their result is unchanged.
            num_classes = student_logits.size(-1)
            student_logits = student_logits.reshape(-1, num_classes)
            teacher_logits = teacher_logits.reshape(-1, num_classes)
            labels = labels.reshape(-1)

        # Soft-label loss (KL divergence on temperature-softened distributions).
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=-1),
            F.softmax(teacher_logits / self.temperature, dim=-1),
            reduction='batchmean'
        ) * (self.temperature ** 2)

        # Hard-label loss (cross-entropy against the ground-truth labels).
        hard_loss = F.cross_entropy(student_logits, labels)

        # Weighted combination of the two terms.
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

# Usage example: offline distillation with a frozen teacher.
# Imported here so the example is self-contained.
from transformers import AutoModelForCausalLM

teacher_model = AutoModelForCausalLM.from_pretrained("teacher_model")
student_model = AutoModelForCausalLM.from_pretrained("student_model")
# The teacher only provides targets, so keep it in inference mode.
teacher_model.eval()

criterion = DistillationLoss(temperature=4.0, alpha=0.7)
# Only the student is optimised.
optimizer = torch.optim.AdamW(student_model.parameters(), lr=1e-5)

# Training loop. `dataloader` must be provided by the surrounding script and
# yield dict batches that include a "labels" entry.
for batch in dataloader:
    # Teacher forward pass, no gradients needed.
    with torch.no_grad():
        teacher_outputs = teacher_model(**batch)
        teacher_logits = teacher_outputs.logits

    # Student forward pass.
    student_outputs = student_model(**batch)
    student_logits = student_outputs.logits

    # Distillation loss.
    loss = criterion(student_logits, teacher_logits, batch["labels"])

    # Clear stale gradients first; otherwise they accumulate across batches.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

在线蒸馏(Online Distillation)

class OnlineDistillation:
    """Online distillation: the student learns from a teacher that slowly tracks it.

    Each ``update`` computes a distillation loss of the student against the
    teacher, back-propagates it through the student, and then moves the
    teacher parameters towards the student parameters with an exponential
    moving average (EMA).

    Args:
        teacher_model: Model called as ``teacher_model(**batch)`` whose
            result exposes ``.logits``.
        student_model: Model called the same way; its parameters are paired
            positionally with the teacher's for the EMA update.
        temperature: Softening temperature for the soft-label term.
        optimizer: Optional optimizer over the student parameters. When
            given, ``update`` zeroes gradients before ``backward`` and steps
            afterwards. When ``None``, ``update`` only calls ``backward`` and
            the caller must step and zero the gradients itself.
        alpha: Weight of the soft-label term; ``1 - alpha`` weights the
            hard-label cross-entropy.
    """

    def __init__(self, teacher_model, student_model, temperature=4.0,
                 optimizer=None, alpha=0.7):
        self.teacher = teacher_model
        self.student = student_model
        self.temperature = temperature
        self.optimizer = optimizer
        self.alpha = alpha

        # EMA decay: the teacher keeps this fraction of its own weights per
        # update, so it changes slowly.
        self.teacher_update_rate = 0.999

    def update(self, batch):
        """Run one training step on ``batch`` and return the loss as a float.

        Args:
            batch: Mapping unpacked into both models; must contain "labels".
        """
        # Student forward pass.
        student_logits = self.student(**batch).logits

        # Teacher forward pass; it is never trained by back-propagation.
        with torch.no_grad():
            teacher_logits = self.teacher(**batch).logits

        loss = self._compute_loss(student_logits, teacher_logits, batch["labels"])

        # Update the student.
        if self.optimizer is not None:
            self.optimizer.zero_grad()
        loss.backward()
        if self.optimizer is not None:
            self.optimizer.step()

        # Update the teacher (exponential moving average).
        self._update_teacher()

        return loss.item()

    def _compute_loss(self, student_logits, teacher_logits, labels):
        """Return ``alpha * soft KL + (1 - alpha) * cross-entropy`` as a scalar tensor."""
        if student_logits.dim() > 2:
            # Collapse leading axes so cross_entropy sees (rows, classes) and
            # 'batchmean' averages over every position.
            num_classes = student_logits.size(-1)
            student_logits = student_logits.reshape(-1, num_classes)
            teacher_logits = teacher_logits.reshape(-1, num_classes)
            labels = labels.reshape(-1)

        soft_loss = F.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=-1),
            F.softmax(teacher_logits / self.temperature, dim=-1),
            reduction='batchmean'
        ) * (self.temperature ** 2)
        hard_loss = F.cross_entropy(student_logits, labels)
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

    def _update_teacher(self):
        """Blend the student parameters into the teacher (EMA), in place."""
        rate = self.teacher_update_rate
        with torch.no_grad():
            # zip pairs parameters positionally and stops at the shorter
            # model, so both models are expected to share one architecture.
            for teacher_param, student_param in zip(
                self.teacher.parameters(), self.student.parameters()
            ):
                teacher_param.mul_(rate).add_(student_param, alpha=1 - rate)

自蒸馏(Self-Distillation)

class SelfDistillation:
    """Self-distillation: a model learns from an EMA copy of itself.

    A deep copy of ``model`` is kept in eval mode as the teacher. The model
    is trained to match the teacher's temperature-softened outputs, and the
    teacher is moved towards the model once per epoch.

    Args:
        model: Model called as ``model(**batch)`` whose result exposes
            ``.logits``.
        temperature: Softening temperature for the KL term.
        ema_decay: Fraction of its own weights the teacher keeps on each
            teacher update.
    """

    def __init__(self, model, temperature=4.0, ema_decay=0.99):
        # Function-scope import: the visible top-of-file imports do not
        # include `copy`.
        import copy

        self.model = model
        self.teacher_model = copy.deepcopy(model)
        self.teacher_model.eval()
        self.temperature = temperature
        self.ema_decay = ema_decay

    def distill(self, dataloader, epochs=3, lr=1e-5):
        """Train ``self.model`` against the teacher for ``epochs`` passes.

        Args:
            dataloader: Iterable of mappings that are unpacked into both models.
            epochs: Number of passes over ``dataloader``.
            lr: Learning rate of the AdamW optimizer created for the model.
        """
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)

        for _ in range(epochs):
            for batch in dataloader:
                # Teacher output, gradients stopped.
                with torch.no_grad():
                    teacher_logits = self.teacher_model(**batch).logits

                # Student output.
                student_logits = self.model(**batch).logits

                loss = self._compute_loss(student_logits, teacher_logits)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Refresh the teacher once per epoch (EMA).
            self._update_teacher()

    def _compute_loss(self, student_logits, teacher_logits):
        """Return the temperature-scaled KL divergence of student vs. teacher."""
        if student_logits.dim() > 2:
            # Collapse leading axes so 'batchmean' averages over every
            # position rather than over the first axis only.
            num_classes = student_logits.size(-1)
            student_logits = student_logits.reshape(-1, num_classes)
            teacher_logits = teacher_logits.reshape(-1, num_classes)

        return F.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=-1),
            F.softmax(teacher_logits / self.temperature, dim=-1),
            reduction='batchmean'
        ) * (self.temperature ** 2)

    def _update_teacher(self):
        """Blend the current model parameters into the teacher (EMA), in place."""
        decay = self.ema_decay
        with torch.no_grad():
            for t_param, s_param in zip(
                self.teacher_model.parameters(), self.model.parameters()
            ):
                t_param.mul_(decay).add_(s_param, alpha=1 - decay)

LLM蒸馏实践

使用GPT-4蒸馏

from openai import OpenAI
import json

# Module-level API client used by distill_with_gpt4; built with no explicit arguments.
# NOTE(review): presumably credentials are read from the environment — confirm before running.
client = OpenAI()

def distill_with_gpt4(teacher_model_name, training_data, output_file):
    """Generate teacher responses for each prompt and save them as JSON.

    Args:
        teacher_model_name: Name of the chat model to query. A falsy value
            falls back to "gpt-4".
        training_data: Iterable of mappings, each with a "prompt" entry.
        output_file: Path of the JSON file the records are written to.

    Returns:
        List of dicts with keys "prompt", "response" and "teacher_model".
    """
    # Honour the caller's model choice; it used to be ignored in favour of a
    # hard-coded name in both the request and the saved record.
    model_name = teacher_model_name or "gpt-4"
    distilled_data = []

    for sample in training_data:
        prompt = sample["prompt"]

        # Ask the teacher model for a response to this prompt.
        response = client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=512
        )

        distilled_data.append({
            "prompt": prompt,
            "response": response.choices[0].message.content,
            "teacher_model": model_name
        })

    # Persist the dataset; ensure_ascii=False keeps non-ASCII text readable.
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(distilled_data, f, ensure_ascii=False, indent=2)

    return distilled_data

使用T5蒸馏

from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

def distill_t5(teacher_name, student_name, train_dataset, eval_dataset):
    """Distil a T5 teacher into a T5 student with a custom Seq2SeqTrainer.

    Args:
        teacher_name: Checkpoint name or path of the teacher; also used to
            load the tokenizer.
        student_name: Checkpoint name or path of the student.
        train_dataset: Training dataset handed to the trainer.
        eval_dataset: Evaluation dataset handed to the trainer.

    Returns:
        The trained student model.
    """
    # Load both models and the shared tokenizer.
    teacher = T5ForConditionalGeneration.from_pretrained(teacher_name)
    student = T5ForConditionalGeneration.from_pretrained(student_name)

    tokenizer = T5Tokenizer.from_pretrained(teacher_name)

    class DistillationTrainer(Seq2SeqTrainer):
        """Seq2SeqTrainer whose loss mixes soft-label KL with the student's own loss."""

        def __init__(self, teacher_model, temperature=4.0, alpha=0.7, **kwargs):
            super().__init__(**kwargs)
            self.teacher = teacher_model
            # The trainer only places `model` on the training device; the
            # teacher must be moved explicitly or its forward pass fails
            # with a device mismatch on GPU.
            self.teacher.to(self.args.device)
            self.teacher.eval()
            self.temperature = temperature
            self.alpha = alpha

        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            # Student forward pass.
            student_outputs = model(**inputs)
            student_logits = student_outputs.logits

            # Teacher forward pass, no gradients.
            with torch.no_grad():
                teacher_outputs = self.teacher(**inputs)
                teacher_logits = teacher_outputs.logits

            # Flatten to (positions, vocab) so 'batchmean' averages per token
            # instead of summing over the sequence axis.
            vocab_size = student_logits.size(-1)
            flat_student = student_logits.reshape(-1, vocab_size)
            flat_teacher = teacher_logits.reshape(-1, vocab_size)

            # Drop positions whose label is -100 (the index the hard loss
            # ignores) so padding does not contribute to the soft loss.
            labels = inputs.get("labels")
            if labels is not None:
                keep = labels.reshape(-1) != -100
                flat_student = flat_student[keep]
                flat_teacher = flat_teacher[keep]

            soft_loss = F.kl_div(
                F.log_softmax(flat_student / self.temperature, dim=-1),
                F.softmax(flat_teacher / self.temperature, dim=-1),
                reduction='batchmean'
            ) * (self.temperature ** 2)

            # Hard-label loss as computed by the student model itself.
            hard_loss = student_outputs.loss

            loss = self.alpha * soft_loss + (1 - self.alpha) * hard_loss

            return (loss, student_outputs) if return_outputs else loss

    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy` and the trainer's `tokenizer` argument to
    # `processing_class`; adjust to the installed version.
    training_args = Seq2SeqTrainingArguments(
        output_dir="./distilled_t5",
        num_train_epochs=3,
        per_device_train_batch_size=8,
        learning_rate=1e-4,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        predict_with_generate=True,
        # Mixed precision needs a CUDA device; fall back to full precision
        # when none is available.
        fp16=torch.cuda.is_available()
    )

    trainer = DistillationTrainer(
        teacher_model=teacher,
        model=student,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer
    )

    trainer.train()

    return student

LoRA蒸馏

from peft import LoraConfig, get_peft_model

def distill_with_lora(teacher_model, student_base_model, train_dataset,
                      epochs=3, lr=1e-4):
    """Distil a teacher into a LoRA-adapted student.

    Args:
        teacher_model: Model called as ``teacher_model(**batch)`` whose
            result exposes ``.logits``.
        student_base_model: Base model that receives the LoRA adapters.
        train_dataset: Iterable of batch mappings, each with a "labels" entry.
        epochs: Number of passes over ``train_dataset``.
        lr: Learning rate of the AdamW optimizer over the trainable parameters.

    Returns:
        The PEFT-wrapped student model.
    """
    # LoRA configuration.
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.1
    )

    # Wrap the student with LoRA adapters.
    student = get_peft_model(student_base_model, lora_config)

    # The teacher only supplies targets, so keep it in inference mode.
    teacher_model.eval()

    criterion = DistillationLoss(temperature=4.0, alpha=0.7)
    # The loop needs an optimizer and none was defined before. Optimise only
    # the parameters that still require gradients after LoRA wrapping.
    optimizer = torch.optim.AdamW(
        (p for p in student.parameters() if p.requires_grad), lr=lr
    )

    for _ in range(epochs):
        for batch in train_dataset:
            # Teacher output, no gradients.
            with torch.no_grad():
                teacher_logits = teacher_model(**batch).logits

            # Student output.
            student_logits = student(**batch).logits

            loss = criterion(student_logits, teacher_logits, batch["labels"])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return student

评估蒸馏效果

def evaluate_distillation(teacher_model, student_model, test_dataset, tokenizer):
    """Compare a distilled student against its teacher.

    Args:
        teacher_model: Teacher model; must support ``generate`` and
            ``parameters``.
        student_model: Student model with the same interface.
        test_dataset: Sized iterable of samples, each with "inputs" (keyword
            arguments for ``generate``) and "label" (expected decoded text).
        tokenizer: Tokenizer used to decode generated ids.

    Returns:
        Dict with "teacher_perplexity", "student_perplexity",
        "teacher_accuracy", "student_accuracy", "accuracy_comparison"
        (student accuracy minus teacher accuracy), "size_reduction"
        (fraction of teacher parameters removed) and "speedup" (left at 0;
        inference time is not measured here).
    """
    results = {
        "teacher_perplexity": 0,
        "student_perplexity": 0,
        "accuracy_comparison": 0,
        "size_reduction": 0,
        "speedup": 0
    }

    # Perplexity comparison.
    results["teacher_perplexity"] = calculate_perplexity(teacher_model, tokenizer, test_dataset)
    results["student_perplexity"] = calculate_perplexity(student_model, tokenizer, test_dataset)

    # Exact-match accuracy comparison.
    teacher_correct = 0
    student_correct = 0

    for sample in test_dataset:
        label = sample["label"]

        # Skip special tokens (padding, end-of-sequence markers) when
        # decoding; otherwise the text can never equal a plain label.
        # NOTE(review): for decoder-only models the generated ids presumably
        # also contain the prompt — confirm and strip it if so.
        teacher_output = teacher_model.generate(**sample["inputs"])
        teacher_pred = tokenizer.decode(teacher_output[0], skip_special_tokens=True).strip()

        student_output = student_model.generate(**sample["inputs"])
        student_pred = tokenizer.decode(student_output[0], skip_special_tokens=True).strip()

        if teacher_pred == label:
            teacher_correct += 1
        if student_pred == label:
            student_correct += 1

    # Guard against an empty test set instead of dividing by zero.
    num_samples = len(test_dataset)
    teacher_accuracy = teacher_correct / num_samples if num_samples else 0.0
    student_accuracy = student_correct / num_samples if num_samples else 0.0
    results["teacher_accuracy"] = teacher_accuracy
    results["student_accuracy"] = student_accuracy
    # This key used to stay at its placeholder 0.
    results["accuracy_comparison"] = student_accuracy - teacher_accuracy

    # Model size comparison.
    teacher_params = sum(p.numel() for p in teacher_model.parameters())
    student_params = sum(p.numel() for p in student_model.parameters())
    results["size_reduction"] = 1 - student_params / teacher_params

    # Speed comparison.
    # ... inference time measurement is not implemented.

    return results

最佳实践

# Distillation best practices: topic -> one-line recommendation.
best_practices = dict((
    ("温度选择", "通常在3-10之间,需要实验调优"),
    ("损失权重", "alpha控制软硬标签权重,通常0.5-0.8"),
    ("数据质量", "使用高质量的训练数据"),
    ("教师模型", "选择与任务相关的教师模型"),
    ("学生架构", "设计合适的学生模型架构"),
))

模型蒸馏是实现LLM轻量化部署的重要技术,通过合理应用可以显著降低部署成本。