SFT训练:监督微调实战
---
title: "SFT训练:监督微调实战"
description: "掌握监督微调的完整流程,从数据准备到模型训练和评估"
tags: ["SFT", "监督微调", "指令微调", "模型训练"]
category: "llm"
icon: "🧠"
---
SFT训练:监督微调实战
SFT简介
SFT(Supervised Fine-Tuning)是将预训练模型通过监督学习适配到特定任务的过程。它是RLHF三阶段中的第一阶段,也是构建有用AI助手的基础。通过高质量的(指令,响应)对训练,模型学会遵循指令并生成有帮助的回答。
SFT的核心价值:
- 任务适配:将通用模型适配到特定领域
- 格式学习:教会模型遵循特定的输出格式
- 知识注入:将领域知识融入模型
- 行为塑造:定义模型的基本行为模式
数据准备
数据格式
{
"instruction": "解释什么是机器学习",
"input": "",
"output": "机器学习是人工智能的一个分支,它使计算机系统能够从数据中学习和改进,而无需显式编程。通过识别数据中的模式,机器学习算法可以做出决策或预测。"
}
数据加载与预处理
from datasets import load_dataset
from transformers import AutoTokenizer
# Load the tokenizer that belongs to the base checkpoint being fine-tuned.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# Reuse the end-of-sequence token as the padding token so padded batches can be built.
# NOTE(review): presumably this checkpoint has no dedicated pad token - confirm.
tokenizer.pad_token = tokenizer.eos_token
def format_instruction(sample):
    """Render one sample as a single training string.

    The fixed system prompt opens the ``[INST]`` block, followed by the
    instruction and, when the sample has a non-empty ``input``, that input on
    the next line. The expected ``output`` is placed after ``[/INST]``.
    """
    user_turn = sample['instruction']
    extra_context = sample.get("input")
    if extra_context:
        # Optional context goes on its own line right after the instruction.
        user_turn = f"{user_turn}\n{extra_context}"
    header = "<s>[INST] <<SYS>>\n你是一个有帮助的助手。\n<</SYS>>\n"
    return f"{header}{user_turn} [/INST] {sample['output']} </s>"
def preprocess_function(examples):
    """Tokenize instruction samples and build labels that cover only the response.

    Args:
        examples: Either the column-oriented batch that
            ``Dataset.map(..., batched=True)`` passes in (a mapping from column
            name to a list of values) or a plain iterable of sample dicts.

    Returns:
        The tokenizer output with an added ``"labels"`` entry: one list per
        sample, equal to ``input_ids`` except that prompt and padding
        positions hold -100 so the loss ignores them.
    """
    max_length = 2048
    ignore_index = -100
    marker = "[/INST]"

    # A batched map call delivers {"column": [values...]}; iterating that
    # mapping directly would yield column names, so turn it into row dicts.
    if hasattr(examples, "keys"):
        columns = list(examples.keys())
        rows = [
            dict(zip(columns, values))
            for values in zip(*(examples[name] for name in columns))
        ]
    else:
        rows = list(examples)
    texts = [format_instruction(row) for row in rows]

    # The template already spells out "<s>" and "</s>", so the tokenizer must
    # not add its own special tokens on top (that would duplicate the BOS).
    tokenized = tokenizer(
        texts,
        truncation=True,
        max_length=max_length,
        padding="max_length",
        add_special_tokens=False,
    )

    labels = []
    for text, input_ids, attention_mask in zip(
        texts, tokenized["input_ids"], tokenized["attention_mask"]
    ):
        # Everything up to and including "[/INST]" is prompt.
        # NOTE(review): this assumes tokenizing the prompt prefix on its own
        # yields the same tokens as the prefix of the full text; the boundary
        # can be off by a token for some tokenizers - confirm if it matters.
        prompt_end = text.find(marker) + len(marker)
        prompt_ids = tokenizer(
            text[:prompt_end],
            truncation=True,
            max_length=max_length,
            add_special_tokens=False,
        )["input_ids"]

        # Locate the first real token so the mask is right for either
        # padding side (left padding shifts the prompt to the right).
        first_real = (
            attention_mask.index(1) if 1 in attention_mask else len(attention_mask)
        )
        prompt_stop = first_real + len(prompt_ids)

        # Build a fresh list per sample: labels must not alias input_ids, and
        # both prompt and padding positions are excluded from the loss.
        labels.append([
            token if mask == 1 and position >= prompt_stop else ignore_index
            for position, (token, mask) in enumerate(zip(input_ids, attention_mask))
        ])

    tokenized["labels"] = labels
    return tokenized
# Load the dataset from the local JSON file "train.json".
dataset = load_dataset("json", data_files="train.json")
# Run preprocess_function over the dataset in batches (batched=True).
tokenized_dataset = dataset.map(preprocess_function, batched=True)
模型配置
LoRA微调配置
from peft import LoraConfig, get_peft_model, TaskType
# LoRA adapter configuration for a causal language model.
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
# Adapter rank, scaling factor and dropout applied inside the adapters.
r=16,
lora_alpha=32,
lora_dropout=0.1,
# Names of the modules that receive adapters.
# NOTE(review): presumably the attention and MLP projections of the Llama
# architecture - confirm against the model's module names.
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"
],
bias="none"
)
# Load the base model in half precision. torch is imported here because
# torch.float16 is referenced below and nothing earlier in the file imports it.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    device_map="auto",
)
# Wrap the base model with the LoRA adapters defined in lora_config.
model = get_peft_model(model, lora_config)
# Print how many parameters are trainable after wrapping.
model.print_trainable_parameters()
# With r=16 on all seven projection modules the expected output is roughly:
# trainable params: 39,976,960 || all params: 6,778,392,576 || trainable%: 0.5898
QLoRA配置
# torch is imported here because torch.float16 is referenced below and
# nothing earlier in the file imports it.
import torch
from peft import prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig

# 4-bit quantization settings: NF4 weights, fp16 compute, double quantization.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
# Load the base model with the quantization settings applied.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto",
)
# Prepare the quantized model for training before adding adapters, as the
# PEFT quantization guide recommends for k-bit models.
model = prepare_model_for_kbit_training(model)
# Wrap the prepared model with the LoRA adapters defined in lora_config.
model = get_peft_model(model, lora_config)
训练配置
TrainingArguments
from transformers import TrainingArguments
training_args = TrainingArguments(
    output_dir="./sft_output",
    # Training schedule: 4 samples per device x 4 accumulation steps per update.
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    # Optimizer hyper-parameters
    learning_rate=2e-4,
    weight_decay=0.01,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-8,
    max_grad_norm=1.0,
    # Learning-rate schedule. Only warmup_steps is given: when warmup_steps and
    # warmup_ratio are both set, warmup_steps takes precedence, so the former
    # warmup_ratio=0.03 had no effect and only obscured the real warmup length.
    lr_scheduler_type="cosine",
    warmup_steps=100,
    # Mixed precision
    fp16=True,
    bf16=False,
    # Logging and checkpointing
    logging_steps=10,
    logging_dir="./logs",
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    # Evaluation every eval_steps steps.
    # NOTE(review): step-based evaluation presumably needs an eval_dataset to
    # be passed to the Trainer - confirm one is supplied.
    evaluation_strategy="steps",
    eval_steps=500,
    # Miscellaneous
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",
    report_to="tensorboard",
    seed=42,
)
使用Trainer训练
from transformers import (
    DataCollatorForLanguageModeling,
    DataCollatorForSeq2Seq,
    Trainer,
)

# DataCollatorForLanguageModeling(mlm=False) rebuilds "labels" from input_ids,
# which would throw away the prompt masking computed during preprocessing and,
# because the pad token is the EOS token here, also mask the real EOS.
# DataCollatorForSeq2Seq keeps the provided labels and pads them with -100.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    padding=True,
    label_pad_token_id=-100,
)

train_dataset = tokenized_dataset["train"]
eval_dataset = tokenized_dataset.get("validation")
if eval_dataset is None:
    # The JSON file is loaded as a single "train" split, but training_args asks
    # for step-based evaluation, which needs an evaluation set. Hold out 5%.
    held_out = train_dataset.train_test_split(test_size=0.05, seed=42)
    train_dataset = held_out["train"]
    eval_dataset = held_out["test"]

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
)
# Start training
trainer.train()
# Save the final model
trainer.save_model("./sft_final")
使用TRL训练
from trl import SFTTrainer, SFTConfig

# The raw dataset only has instruction/input/output columns, so build the
# single "text" column that dataset_text_field refers to.
sft_dataset = dataset["train"].map(
    lambda sample: {"text": format_instruction(sample)},
    remove_columns=dataset["train"].column_names,
)

sft_config = SFTConfig(
    output_dir="./sft_output",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    max_seq_length=2048,
    dataset_text_field="text",
    packing=True,  # pack short samples together into full-length sequences
    fp16=True,
    optim="adamw_torch",
)
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=sft_dataset,
    tokenizer=tokenizer,
)
trainer.train()
训练监控
import wandb
# Initialize a WandB run for project "sft_training" named "llama2-7b-sft".
wandb.init(project="sft_training", name="llama2-7b-sft")
# 自定义回调
from transformers import TrainerCallback
class LoggingCallback(TrainerCallback):
    """Trainer callback that forwards training and evaluation metrics to WandB."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log the training loss, learning rate and epoch when they are present."""
        if not logs:
            return
        payload = {
            "train_loss": logs.get("loss"),
            "learning_rate": logs.get("learning_rate"),
            "epoch": state.epoch,
        }
        # Not every log event carries a loss or learning rate; skip missing
        # values instead of sending None to WandB.
        payload = {key: value for key, value in payload.items() if value is not None}
        if payload:
            wandb.log(payload)

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Log the evaluation loss and the perplexity derived from it."""
        import math

        if not metrics:
            return
        eval_loss = metrics.get("eval_loss")
        if eval_loss is None:
            # Without a loss there is nothing to derive a perplexity from.
            return
        # The cross-entropy loss is measured in nats, so perplexity is
        # exp(loss), not 2 ** loss.
        try:
            perplexity = math.exp(eval_loss)
        except OverflowError:
            perplexity = float("inf")
        wandb.log({
            "eval_loss": eval_loss,
            "perplexity": perplexity,
        })
# Register the callback when constructing the Trainer. The "..." is a
# placeholder for the remaining Trainer arguments, which this snippet omits.
trainer = Trainer(
...,
callbacks=[LoggingCallback()]
)
评估与测试
def evaluate_sft_model(model, tokenizer, test_data):
    """Generate a response for every test sample and pair it with the reference.

    Args:
        model: Model exposing ``eval()``, ``device`` and ``generate()``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        test_data: Iterable of dicts with ``"instruction"`` and ``"output"``
            keys and an optional ``"input"`` key.

    Returns:
        A list of dicts with the keys ``"instruction"``, ``"expected"`` and
        ``"generated"``; ``"generated"`` holds only the newly generated text.
    """
    # Imported locally: nothing earlier in the file imports torch.
    import torch

    model.eval()
    results = []
    for sample in test_data:
        # Mirror the training template: optional input goes on its own line
        # after the instruction.
        user_turn = sample['instruction']
        if sample.get("input"):
            user_turn = f"{user_turn}\n{sample['input']}"
        prompt = f"[INST] <<SYS>>\n你是一个有帮助的助手。\n<</SYS>>\n{user_turn} [/INST]"

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_length = inputs["input_ids"].shape[1]
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=256,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
            )
        # generate() returns the prompt followed by the continuation; drop the
        # prompt tokens so only the model's answer is reported.
        new_tokens = outputs[0][prompt_length:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        results.append({
            "instruction": sample["instruction"],
            "expected": sample["output"],
            "generated": response,
        })
    return results
# Run the evaluation.
# NOTE(review): test_data is not defined in the visible part of this file;
# presumably a list of dicts with "instruction" and "output" keys - confirm.
test_results = evaluate_sft_model(model, tokenizer, test_data)
# Print the first three results, cutting long texts to 100 characters.
separator = "-" * 50
preview = test_results[:3]
for item in preview:
    expected_head = item['expected'][:100]
    generated_head = item['generated'][:100]
    print(f"指令: {item['instruction']}")
    print(f"期望: {expected_head}...")
    print(f"生成: {generated_head}...")
    print(separator)
常见问题与解决方案
过拟合
# 1. Add more data or use data augmentation
# 2. Train for fewer epochs
training_args = TrainingArguments(num_train_epochs=1)
# 3. Increase dropout
lora_config = LoraConfig(lora_dropout=0.2)
# 4. Use early stopping
# NOTE(review): EarlyStoppingCallback presumably also needs periodic evaluation,
# load_best_model_at_end=True and a metric_for_best_model in TrainingArguments -
# confirm against the installed transformers version.
from transformers import EarlyStoppingCallback
trainer = Trainer(..., callbacks=[EarlyStoppingCallback(early_stopping_patience=3)])
训练不稳定
# Each line below shows one setting in isolation; in practice these are
# combined into a single TrainingArguments call.
# 1. Lower the learning rate
training_args = TrainingArguments(learning_rate=1e-5)
# 2. Use a longer warmup
training_args = TrainingArguments(warmup_steps=500)
# 3. Clip gradients more tightly
training_args = TrainingArguments(max_grad_norm=0.5)
SFT是构建有用AI助手的基础步骤,高质量的监督微调能够显著提升模型的指令遵循能力。