← 返回首页
🧠

SFT训练:监督微调实战

📂 llm ⏱ 3 min 467 words

--- title: "SFT训练:监督微调实战" description: "掌握监督微调的完整流程,从数据准备到模型训练和评估" tags: ["SFT", "监督微调", "指令微调", "模型训练"] category: "llm" icon: "🧠"

SFT训练:监督微调实战

SFT简介

SFT(Supervised Fine-Tuning)是将预训练模型通过监督学习适配到特定任务的过程。它是RLHF三阶段中的第一阶段,也是构建有用AI助手的基础。通过高质量的(指令,响应)对训练,模型学会遵循指令并生成有帮助的回答。

SFT的核心价值:

数据准备

数据格式

{
    "instruction": "解释什么是机器学习",
    "input": "",
    "output": "机器学习是人工智能的一个分支,它使计算机系统能够从数据中学习和改进,而无需显式编程。通过识别数据中的模式,机器学习算法可以做出决策或预测。"
}

数据加载与预处理

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token

def format_instruction(sample):
    """格式化指令数据"""
    if sample.get("input"):
        text = f"""<s>[INST] <<SYS>>
你是一个有帮助的助手。
<</SYS>>

{sample['instruction']}

{sample['input']} [/INST] {sample['output']} </s>"""
    else:
        text = f"""<s>[INST] <<SYS>>
你是一个有帮助的助手。
<</SYS>>

{sample['instruction']} [/INST] {sample['output']} </s>"""
    return text

def preprocess_function(examples):
    """预处理训练数据"""
    texts = [format_instruction(x) for x in examples]
    tokenized = tokenizer(
        texts,
        truncation=True,
        max_length=2048,
        padding="max_length"
    )
    
    # 标签:只在响应部分计算损失
    tokenized["labels"] = tokenized["input_ids"].copy()
    
    # 将指令部分的标签设为-100
    for i in range(len(texts)):
        instruction_end = texts[i].find("[/INST]") + len("[/INST]")
        instruction_tokens = tokenizer(
            texts[i][:instruction_end],
            truncation=True,
            max_length=2048
        )
        tokenized["labels"][i][:len(instruction_tokens["input_ids"])] = -100
    
    return tokenized

# 加载数据集
dataset = load_dataset("json", data_files="train.json")
tokenized_dataset = dataset.map(preprocess_function, batched=True)

模型配置

LoRA微调配置

from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    bias="none"
)

# 加载基础模型
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    device_map="auto"
)

# 应用LoRA
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# trainable params: 13,631,488 || all params: 6,755,571,712 || trainable%: 0.2018

QLoRA配置

from transformers import BitsAndBytesConfig

# 4-bit量化配置
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

# 加载量化模型
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto"
)

# 应用LoRA
model = get_peft_model(model, lora_config)

训练配置

TrainingArguments

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./sft_output",
    
    # 训练参数
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    
    # 优化器参数
    learning_rate=2e-4,
    weight_decay=0.01,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-8,
    max_grad_norm=1.0,
    
    # 学习率调度
    lr_scheduler_type="cosine",
    warmup_steps=100,
    warmup_ratio=0.03,
    
    # 精度
    fp16=True,
    bf16=False,
    
    # 日志和保存
    logging_steps=10,
    logging_dir="./logs",
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    
    # 评估
    evaluation_strategy="steps",
    eval_steps=500,
    
    # 其他
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",
    report_to="tensorboard",
    seed=42
)

使用Trainer训练

from transformers import Trainer, DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset.get("validation"),
    data_collator=data_collator,
    tokenizer=tokenizer
)

# 开始训练
trainer.train()

# 保存模型
trainer.save_model("./sft_final")

使用TRL训练

from trl import SFTTrainer, SFTConfig

sft_config = SFTConfig(
    output_dir="./sft_output",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    max_seq_length=2048,
    dataset_text_field="text",
    packing=True,  # 打包短样本
    fp16=True,
    optim="adamw_torch"
)

trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=dataset["train"],
    tokenizer=tokenizer
)

trainer.train()

训练监控

import wandb

# 初始化WandB
wandb.init(project="sft_training", name="llama2-7b-sft")

# 自定义回调
from transformers import TrainerCallback

class LoggingCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs:
            wandb.log({
                "train_loss": logs.get("loss"),
                "learning_rate": logs.get("learning_rate"),
                "epoch": state.epoch
            })
    
    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if metrics:
            wandb.log({
                "eval_loss": metrics.get("eval_loss"),
                "perplexity": 2 ** metrics.get("eval_loss")
            })

trainer = Trainer(
    ...,
    callbacks=[LoggingCallback()]
)

评估与测试

def evaluate_sft_model(model, tokenizer, test_data):
    """评估SFT模型"""
    model.eval()
    results = []
    
    for sample in test_data:
        prompt = f"""[INST] <<SYS>>
你是一个有帮助的助手。
<</SYS>>

{sample['instruction']} [/INST]"""
        
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=256,
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )
        
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        results.append({
            "instruction": sample["instruction"],
            "expected": sample["output"],
            "generated": response
        })
    
    return results

# 运行评估
test_results = evaluate_sft_model(model, tokenizer, test_data)

# 打印示例
for result in test_results[:3]:
    print(f"指令: {result['instruction']}")
    print(f"期望: {result['expected'][:100]}...")
    print(f"生成: {result['generated'][:100]}...")
    print("-" * 50)

常见问题与解决方案

过拟合

# 1. 增加数据量或使用数据增强
# 2. 减少训练轮数
training_args = TrainingArguments(num_train_epochs=1)

# 3. 增加dropout
lora_config = LoraConfig(lora_dropout=0.2)

# 4. 使用早停
from transformers import EarlyStoppingCallback
trainer = Trainer(..., callbacks=[EarlyStoppingCallback(early_stopping_patience=3)])

训练不稳定

# 1. 降低学习率
training_args = TrainingArguments(learning_rate=1e-5)

# 2. 增加warmup
training_args = TrainingArguments(warmup_steps=500)

# 3. 使用梯度裁剪
training_args = TrainingArguments(max_grad_norm=0.5)

SFT是构建有用AI助手的基础步骤,高质量的监督微调能够显著提升模型的指令遵循能力。