模型蒸馏:大模型知识迁移
---
title: "模型蒸馏:大模型知识迁移"
description: "掌握模型蒸馏的原理和实现,将大模型能力迁移到小模型中"
tags: ["模型蒸馏", "知识蒸馏", "知识迁移", "模型压缩"]
category: "llm"
icon: "🧠"
---
模型蒸馏:大模型知识迁移
模型蒸馏简介
模型蒸馏(Model Distillation)是一种将大型教师模型(Teacher Model)的知识迁移到小型学生模型(Student Model)的技术。通过蒸馏,小模型可以学习大模型的行为模式,在保持较小体积的同时获得接近大模型的性能。
模型蒸馏的核心价值:
- 模型压缩:显著减小模型体积
- 推理加速:小模型推理更快
- 成本降低:减少计算资源需求
- 边缘部署:支持在资源受限设备上运行
蒸馏方法
离线蒸馏(Offline Distillation)
import torch
import torch.nn as nn
import torch.nn.functional as F
class DistillationLoss(nn.Module):
    """Combined soft-label / hard-label loss for knowledge distillation.

    The soft term is the KL divergence between the temperature-softened
    teacher and student distributions, multiplied by ``temperature ** 2``.
    The hard term is the ordinary cross-entropy of the student logits
    against ``labels``.  The two are mixed as
    ``alpha * soft + (1 - alpha) * hard``.

    Logits with more than two dimensions, e.g. ``(batch, seq, vocab)`` from
    a language model, are flattened to ``(-1, vocab)`` first so that both
    terms are averaged per prediction.  Two-dimensional (and unbatched)
    inputs are passed through untouched.

    Args:
        temperature: Softening temperature applied to both sets of logits.
        alpha: Weight of the soft (teacher) term; the hard term gets
            ``1 - alpha``.
    """

    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, student_logits, teacher_logits, labels):
        """Return the scalar distillation loss.

        Args:
            student_logits: Student logits, class dimension last.
            teacher_logits: Teacher logits with the same shape.
            labels: Class indices (one fewer dimension than the logits) or
                class probabilities (same shape as the logits).
        """
        if student_logits.dim() > 2:
            # ``F.cross_entropy`` expects the class dimension at index 1 and
            # ``batchmean`` divides by ``size(0)`` only, so collapse every
            # leading dimension into a single "prediction" axis.
            if labels.dim() == student_logits.dim():
                # Class-probability targets keep their class dimension.
                labels = labels.reshape(-1, labels.size(-1))
            else:
                labels = labels.reshape(-1)
            student_logits = student_logits.reshape(-1, student_logits.size(-1))
            teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1))

        # Soft-label loss: KL(teacher || student) at temperature T.
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=-1),
            F.softmax(teacher_logits / self.temperature, dim=-1),
            reduction="batchmean",
        ) * (self.temperature ** 2)

        # Hard-label loss: plain cross-entropy at temperature 1.
        hard_loss = F.cross_entropy(student_logits, labels)

        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss
# 使用示例
# Offline distillation example: a frozen teacher supervises a trainable
# student.  `dataloader` must be supplied by the surrounding code and is
# expected to yield dicts of model inputs that include a "labels" entry.
from transformers import AutoModelForCausalLM

teacher_model = AutoModelForCausalLM.from_pretrained("teacher_model")
student_model = AutoModelForCausalLM.from_pretrained("student_model")

# `from_pretrained` returns models in eval mode; the teacher stays there
# (it only produces targets) while the student must be switched to training.
teacher_model.eval()
student_model.train()

criterion = DistillationLoss(temperature=4.0, alpha=0.7)
optimizer = torch.optim.AdamW(student_model.parameters(), lr=1e-5)

# Training loop
for batch in dataloader:
    # Teacher forward pass; no gradients are needed for the targets.
    with torch.no_grad():
        teacher_outputs = teacher_model(**batch)
        teacher_logits = teacher_outputs.logits

    # Student forward pass
    student_outputs = student_model(**batch)
    student_logits = student_outputs.logits

    # A causal LM's logits at position t predict token t + 1, so drop the
    # last logit step and the first label step to line predictions up with
    # their targets, then flatten to (tokens, vocab) for the criterion.
    # NOTE(review): assumes batch["labels"] is (batch, seq) and aligned with
    # the input ids — confirm against the data pipeline.
    vocab_size = student_logits.size(-1)
    shift_student = student_logits[:, :-1, :].reshape(-1, vocab_size)
    shift_teacher = teacher_logits[:, :-1, :].reshape(-1, teacher_logits.size(-1))
    shift_labels = batch["labels"][:, 1:].reshape(-1)

    loss = criterion(shift_student, shift_teacher, shift_labels)

    # Clear gradients left over from the previous step before accumulating
    # new ones; otherwise every update would use the sum of all past batches.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
在线蒸馏(Online Distillation)
class OnlineDistillation:
    """Online distillation in which the teacher tracks the student via EMA.

    Each call to :meth:`update` trains the student for one step against the
    teacher's softened predictions and the ground-truth labels, then moves
    the teacher's weights a small step towards the student's.

    The EMA pairs parameters positionally, so teacher and student are
    expected to share the same architecture.

    Args:
        teacher_model: Model providing the soft targets; called as
            ``teacher_model(**batch)`` and must return an object with a
            ``logits`` attribute.
        student_model: Model being trained, called the same way.
        temperature: Softening temperature for the distillation term.
        alpha: Weight of the soft (teacher) term; the hard-label term gets
            ``1 - alpha``.
        optimizer: Optimizer over the student's parameters.  When omitted,
            ``AdamW`` with ``lr=1e-5`` is created.
    """

    def __init__(self, teacher_model, student_model, temperature=4.0,
                 alpha=0.7, optimizer=None):
        self.teacher = teacher_model
        self.student = student_model
        self.temperature = temperature
        self.alpha = alpha
        # EMA decay: close to 1 so the teacher changes slowly.
        self.teacher_update_rate = 0.999
        if optimizer is None:
            optimizer = torch.optim.AdamW(self.student.parameters(), lr=1e-5)
        self.optimizer = optimizer

    def update(self, batch):
        """Run one training step and return the loss as a Python float.

        Args:
            batch: Mapping of keyword arguments for both models; must
                contain a ``"labels"`` entry.
        """
        # Student forward pass
        student_logits = self.student(**batch).logits

        # Teacher forward pass; targets need no gradient.
        with torch.no_grad():
            teacher_logits = self.teacher(**batch).logits

        loss = self._compute_loss(student_logits, teacher_logits, batch["labels"])

        # Update the student.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update the teacher from the freshly stepped student.
        self._update_teacher()
        return loss.item()

    def _compute_loss(self, student_logits, teacher_logits, labels):
        """Return ``alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE``.

        Logits with more than two dimensions are flattened to ``(-1, C)``
        (and labels to ``(-1,)``) so both terms are averaged per prediction.
        NOTE(review): no causal-LM label shift is applied here; callers
        passing decoder-only LM outputs should align labels first.
        """
        if student_logits.dim() > 2:
            student_logits = student_logits.reshape(-1, student_logits.size(-1))
            teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1))
            labels = labels.reshape(-1)

        temperature = self.temperature
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.softmax(teacher_logits / temperature, dim=-1),
            reduction="batchmean",
        ) * (temperature ** 2)
        hard_loss = F.cross_entropy(student_logits, labels)
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

    def _update_teacher(self):
        """Move each teacher parameter towards its student counterpart (EMA)."""
        decay = self.teacher_update_rate
        with torch.no_grad():
            for teacher_param, student_param in zip(
                self.teacher.parameters(), self.student.parameters()
            ):
                # teacher <- decay * teacher + (1 - decay) * student, in place.
                teacher_param.mul_(decay).add_(student_param.detach(), alpha=1 - decay)
自蒸馏(Self-Distillation)
class SelfDistillation:
    """Self-distillation: a model learns from an EMA copy of itself.

    The teacher starts as a deep copy of ``model`` in eval mode and, after
    every optimizer step, is moved slightly towards the current student
    weights.

    Args:
        model: Model to train; called as ``model(**batch)`` and must return
            an object with a ``logits`` attribute.
        temperature: Softening temperature for the distillation loss.
        ema_decay: Fraction of the teacher kept at each EMA update; the
            remaining ``1 - ema_decay`` comes from the student.
    """

    def __init__(self, model, temperature=4.0, ema_decay=0.99):
        import copy  # local import: only needed to clone the teacher

        self.model = model
        self.teacher_model = copy.deepcopy(model)
        self.teacher_model.eval()
        self.temperature = temperature
        self.ema_decay = ema_decay

    def distill(self, dataloader, epochs=3, lr=1e-5):
        """Train ``self.model`` against the EMA teacher.

        Args:
            dataloader: Iterable of keyword-argument mappings for the model.
            epochs: Number of passes over ``dataloader``.
            lr: Learning rate for the ``AdamW`` optimizer.
        """
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
        # The teacher is an eval-mode copy of the student; training mode
        # lets stochastic layers such as dropout make the two differ.
        self.model.train()

        for _ in range(epochs):
            for batch in dataloader:
                # Teacher output (no gradient)
                with torch.no_grad():
                    teacher_logits = self.teacher_model(**batch).logits

                # Student output
                student_logits = self.model(**batch).logits

                loss = self._compute_loss(student_logits, teacher_logits)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Let the teacher follow the updated student.
                self._update_teacher()

    def _compute_loss(self, student_logits, teacher_logits):
        """Return ``T^2 * KL(teacher || student)`` at ``self.temperature``.

        Logits with more than two dimensions are flattened to ``(-1, C)``
        so ``batchmean`` averages over every prediction.
        """
        if student_logits.dim() > 2:
            student_logits = student_logits.reshape(-1, student_logits.size(-1))
            teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1))

        temperature = self.temperature
        return F.kl_div(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.softmax(teacher_logits / temperature, dim=-1),
            reduction="batchmean",
        ) * (temperature ** 2)

    def _update_teacher(self):
        """Apply one EMA step: teacher <- decay * teacher + (1 - decay) * student."""
        decay = self.ema_decay
        with torch.no_grad():
            for t_param, s_param in zip(
                self.teacher_model.parameters(), self.model.parameters()
            ):
                t_param.mul_(decay).add_(s_param.detach(), alpha=1 - decay)
LLM蒸馏实践
使用GPT-4蒸馏
from openai import OpenAI
import json
# Module-level API client shared by distill_with_gpt4 below.
# NOTE(review): constructed with no arguments, so credentials presumably come
# from the environment — confirm before running.
client = OpenAI()
def distill_with_gpt4(teacher_model_name, training_data, output_file):
    """Generate distillation data by querying a chat-completion teacher.

    For every sample the teacher is asked the sample's prompt and its reply
    is recorded; the collected records are written to ``output_file`` as
    UTF-8 JSON and also returned.

    Args:
        teacher_model_name: Model identifier sent to the chat-completions
            API and stored in each record.  Falls back to ``"gpt-4"`` when
            empty or ``None``.
        training_data: Iterable of mappings that each contain a
            ``"prompt"`` entry.
        output_file: Path of the JSON file to write.

    Returns:
        A list of dicts with the keys ``"prompt"``, ``"response"`` and
        ``"teacher_model"``.
    """
    # Honour the caller's choice of teacher; the name was previously ignored
    # and "gpt-4" was always used.
    model_name = teacher_model_name or "gpt-4"

    distilled_data = []
    for sample in training_data:
        prompt = sample["prompt"]

        # Ask the teacher for a response to this prompt.
        response = client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=512,
        )

        distilled_data.append({
            "prompt": prompt,
            "response": response.choices[0].message.content,
            "teacher_model": model_name,
        })

    # Persist the dataset; ensure_ascii=False keeps non-ASCII text readable.
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(distilled_data, f, ensure_ascii=False, indent=2)

    return distilled_data
使用T5蒸馏
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
def distill_t5(teacher_name, student_name, train_dataset, eval_dataset):
    """Distil a T5 teacher into a T5 student with a custom ``Seq2SeqTrainer``.

    Args:
        teacher_name: Checkpoint name or path of the teacher; its tokenizer
            is also used for training.
        student_name: Checkpoint name or path of the student.
        train_dataset: Training dataset passed to the trainer.
        eval_dataset: Evaluation dataset passed to the trainer.

    Returns:
        The trained student model.
    """
    # Load models and tokenizer.
    teacher = T5ForConditionalGeneration.from_pretrained(teacher_name)
    student = T5ForConditionalGeneration.from_pretrained(student_name)
    tokenizer = T5Tokenizer.from_pretrained(teacher_name)

    class DistillationTrainer(Seq2SeqTrainer):
        """Trainer whose loss mixes teacher KL with the student's own loss."""

        def __init__(self, teacher_model, temperature=4.0, alpha=0.7, **kwargs):
            super().__init__(**kwargs)
            self.teacher = teacher_model
            # The trainer only places the student on the training device;
            # the teacher must follow or its forward pass would receive
            # inputs living on a different device.
            self.teacher.to(self.args.device)
            self.teacher.eval()
            self.temperature = temperature
            self.alpha = alpha

        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            """Return ``alpha * soft + (1 - alpha) * hard`` for one batch."""
            # Student output
            student_outputs = model(**inputs)
            student_logits = student_outputs.logits

            # Teacher output (no gradient)
            with torch.no_grad():
                teacher_logits = self.teacher(**inputs).logits

            # Flatten to (tokens, vocab) so the KL term is averaged per
            # token, on the same scale as the student's cross-entropy.
            flat_student = student_logits.reshape(-1, student_logits.size(-1))
            flat_teacher = teacher_logits.reshape(-1, teacher_logits.size(-1))

            # Skip positions the hard loss ignores (labels padded with -100).
            labels = inputs.get("labels")
            if labels is not None:
                keep = labels.reshape(-1) != -100
                flat_student = flat_student[keep]
                flat_teacher = flat_teacher[keep]

            if flat_student.size(0) == 0:
                # Nothing to distil; a graph-connected zero avoids the NaN
                # that `batchmean` would give on an empty batch.
                soft_loss = flat_student.sum()
            else:
                soft_loss = F.kl_div(
                    F.log_softmax(flat_student / self.temperature, dim=-1),
                    F.softmax(flat_teacher / self.temperature, dim=-1),
                    reduction="batchmean",
                ) * (self.temperature ** 2)

            hard_loss = student_outputs.loss
            loss = self.alpha * soft_loss + (1 - self.alpha) * hard_loss

            return (loss, student_outputs) if return_outputs else loss

    # Training arguments.
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy` — confirm against the installed version.
    training_args = Seq2SeqTrainingArguments(
        output_dir="./distilled_t5",
        num_train_epochs=3,
        per_device_train_batch_size=8,
        learning_rate=1e-4,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        predict_with_generate=True,
        # Mixed precision needs a CUDA device; fall back to fp32 without one.
        fp16=torch.cuda.is_available(),
    )

    # Build the trainer.
    trainer = DistillationTrainer(
        teacher_model=teacher,
        model=student,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
    )

    # Train
    trainer.train()

    return student
LoRA蒸馏
from peft import LoraConfig, get_peft_model
def distill_with_lora(teacher_model, student_base_model, train_dataset,
                      epochs=3, lr=1e-4):
    """Distil ``teacher_model`` into a LoRA-adapted student.

    LoRA adapters are attached to ``student_base_model`` and trained with
    ``DistillationLoss``.

    Args:
        teacher_model: Model providing the soft targets.
        student_base_model: Base model that receives the LoRA adapters.
        train_dataset: Iterable of batches; each is a mapping of model
            keyword arguments that includes a ``"labels"`` entry.
        epochs: Number of passes over ``train_dataset``.
        lr: Learning rate for the ``AdamW`` optimizer.

    Returns:
        The PEFT-wrapped student model.
    """
    # LoRA configuration
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.1,
    )

    # Attach LoRA adapters to the student.
    student = get_peft_model(student_base_model, lora_config)

    # The teacher only produces targets; the student needs training mode so
    # that LoRA dropout is active.
    teacher_model.eval()
    student.train()

    criterion = DistillationLoss(temperature=4.0, alpha=0.7)
    # Optimize only the parameters still marked as requiring gradients.
    trainable_params = [p for p in student.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr)

    for _ in range(epochs):
        for batch in train_dataset:
            # Teacher output (no gradient)
            with torch.no_grad():
                teacher_logits = teacher_model(**batch).logits

            # Student output
            student_logits = student(**batch).logits

            # Flatten to (predictions, classes): cross-entropy expects the
            # class dimension at index 1 and cannot take (batch, seq, vocab)
            # logits directly.
            # NOTE(review): no causal-LM shift is applied; if both models
            # are decoder-only LMs the labels presumably need aligning.
            loss = criterion(
                student_logits.reshape(-1, student_logits.size(-1)),
                teacher_logits.reshape(-1, teacher_logits.size(-1)),
                batch["labels"].reshape(-1),
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return student
评估蒸馏效果
def evaluate_distillation(teacher_model, student_model, test_dataset, tokenizer):
    """Compare a distilled student with its teacher.

    Args:
        teacher_model: Teacher model exposing ``generate`` and ``parameters``.
        student_model: Student model with the same interface.
        test_dataset: Sized iterable of samples, each with ``"inputs"``
            (keyword arguments for ``generate``) and ``"label"`` (expected
            decoded text).
        tokenizer: Tokenizer used to decode generated ids.

    Returns:
        A dict with:
            ``teacher_perplexity`` / ``student_perplexity``: values from
                ``calculate_perplexity``.
            ``teacher_accuracy`` / ``student_accuracy``: exact-match rate,
                ``0.0`` for an empty dataset.
            ``accuracy_comparison``: student accuracy minus teacher accuracy.
            ``size_reduction``: ``1 - student_params / teacher_params``.
            ``speedup``: teacher generation time divided by student
                generation time (wall-clock, approximate), ``0`` when the
                student time is zero.
    """
    import time  # local import: only used for the timing below

    results = {
        "teacher_perplexity": 0,
        "student_perplexity": 0,
        "accuracy_comparison": 0,
        "size_reduction": 0,
        "speedup": 0,
    }

    # Perplexity comparison
    results["teacher_perplexity"] = calculate_perplexity(teacher_model, tokenizer, test_dataset)
    results["student_perplexity"] = calculate_perplexity(student_model, tokenizer, test_dataset)

    # Accuracy and generation-time comparison
    teacher_correct = 0
    student_correct = 0
    teacher_seconds = 0.0
    student_seconds = 0.0

    for sample in test_dataset:
        label = sample["label"]
        if isinstance(label, str):
            label = label.strip()

        # Teacher prediction
        started = time.perf_counter()
        teacher_output = teacher_model.generate(**sample["inputs"])
        teacher_seconds += time.perf_counter() - started
        # Drop special tokens (padding, end-of-sequence) so the text can
        # match the label exactly.
        # NOTE(review): decoder-only models presumably return the prompt as
        # part of the output; strip it before comparing if that applies.
        teacher_pred = tokenizer.decode(teacher_output[0], skip_special_tokens=True).strip()

        # Student prediction
        started = time.perf_counter()
        student_output = student_model.generate(**sample["inputs"])
        student_seconds += time.perf_counter() - started
        student_pred = tokenizer.decode(student_output[0], skip_special_tokens=True).strip()

        if teacher_pred == label:
            teacher_correct += 1
        if student_pred == label:
            student_correct += 1

    # Guard against an empty dataset instead of dividing by zero.
    total = len(test_dataset)
    results["teacher_accuracy"] = teacher_correct / total if total else 0.0
    results["student_accuracy"] = student_correct / total if total else 0.0
    results["accuracy_comparison"] = (
        results["student_accuracy"] - results["teacher_accuracy"]
    )

    # Model size comparison
    teacher_params = sum(p.numel() for p in teacher_model.parameters())
    student_params = sum(p.numel() for p in student_model.parameters())
    if teacher_params:
        results["size_reduction"] = 1 - student_params / teacher_params

    # Speed comparison
    if student_seconds > 0:
        results["speedup"] = teacher_seconds / student_seconds

    return results
最佳实践
# 蒸馏最佳实践
# Rules of thumb for a distillation run, keyed by topic.  Built from ordered
# (topic, advice) pairs so the insertion order matches the article.
best_practices = dict([
    ("温度选择", "通常在3-10之间,需要实验调优"),
    ("损失权重", "alpha控制软硬标签权重,通常0.5-0.8"),
    ("数据质量", "使用高质量的训练数据"),
    ("教师模型", "选择与任务相关的教师模型"),
    ("学生架构", "设计合适的学生模型架构"),
])
模型蒸馏是实现LLM轻量化部署的重要技术,通过合理应用可以显著降低部署成本。