← 返回首页
🧠

DPO算法:直接偏好优化

📂 llm ⏱ 3 min 420 words

--- title: "DPO算法:直接偏好优化" description: "掌握DPO的原理和实现,无需奖励模型的简化对齐方法" tags: ["DPO", "偏好优化", "RLHF替代", "模型对齐"] category: "llm" icon: "🧠"

DPO算法:直接偏好优化

DPO简介

DPO(Direct Preference Optimization)是一种简化的模型对齐方法,直接使用偏好数据优化策略,无需训练独立的奖励模型。DPO由斯坦福大学研究团队提出,通过数学推导将RLHF目标转化为简单的分类损失。

DPO的核心优势:

原理推导

RLHF目标的重参数化

传统RLHF目标:

max E[reward] - β * KL(π || π_ref)

DPO的关键洞察:最优策略可以解析表达为:

π*(y|x) = π_ref(y|x) * exp(r(x,y) / β) / Z(x)

反解奖励函数:

r(x,y) = β * log(π*(y|x) / π_ref(y|x)) + β * log(Z(x))

DPO损失函数

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             reference_chosen_logps, reference_rejected_logps,
             beta=0.1):
    """DPO损失函数"""
    # 计算log比率
    chosen_log_ratios = policy_chosen_logps - reference_chosen_logps
    rejected_log_ratios = policy_rejected_logps - reference_rejected_logps
    
    # DPO损失
    logits = beta * (chosen_log_ratios - rejected_log_ratios)
    loss = -F.logsigmoid(logits).mean()
    
    # 准确率(可选)
    chosen_rewards = beta * chosen_log_ratios
    rejected_rewards = beta * rejected_log_ratios
    accuracy = (chosen_rewards > rejected_rewards).float().mean()
    
    return loss, accuracy

实现细节

数据格式

# 偏好数据格式
preference_data = [
    {
        "prompt": "什么是机器学习?",
        "chosen": "机器学习是人工智能的一个分支,它使计算机系统能够从数据中学习...",
        "rejected": "机器学习就是让电脑学习东西。"
    },
    {
        "prompt": "如何学习编程?",
        "chosen": "学习编程可以按照以下步骤:1. 选择一门编程语言...",
        "rejected": "多写代码就行了。"
    }
]

# 转换为模型输入
def format_dpo_data(sample, tokenizer):
    """格式化DPO数据"""
    # chosen序列
    chosen_text = f"Human: {sample['prompt']}\n\nAssistant: {sample['chosen']}"
    chosen_tokens = tokenizer(chosen_text, truncation=True, max_length=512)
    
    # rejected序列
    rejected_text = f"Human: {sample['prompt']}\n\nAssistant: {sample['rejected']}"
    rejected_tokens = tokenizer(rejected_text, truncation=True, max_length=512)
    
    return {
        "chosen_input_ids": chosen_tokens["input_ids"],
        "chosen_attention_mask": chosen_tokens["attention_mask"],
        "rejected_input_ids": rejected_tokens["input_ids"],
        "rejected_attention_mask": rejected_tokens["attention_mask"],
    }

完整训练代码

from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer, DPOConfig
from datasets import load_dataset

# 加载模型和分词器
model = AutoModelForCausalLM.from_pretrained("sft_model")
tokenizer = AutoTokenizer.from_pretrained("sft_model")
tokenizer.pad_token = tokenizer.eos_token

# DPO配置
dpo_config = DPOConfig(
    output_dir="./dpo_output",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=5e-7,
    beta=0.1,  # KL惩罚系数
    loss_type="sigmoid",  # 损失类型
    max_length=1024,
    max_prompt_length=512,
    remove_unused_columns=False,
    logging_steps=10,
    save_steps=500,
    fp16=True,
    optim="adamw_torch"
)

# 加载偏好数据
dataset = load_dataset("json", data_files="preference_data.json")

# 创建DPO训练器
dpo_trainer = DPOTrainer(
    model=model,
    ref_model=None,  # 可选:如果设置,则使用参考模型
    args=dpo_config,
    train_dataset=dataset["train"],
    tokenizer=tokenizer
)

# 训练
dpo_trainer.train()

# 保存模型
dpo_trainer.save_model("./dpo_model")

高级技巧

参考模型处理

# 方法1:使用独立参考模型
ref_model = AutoModelForCausalLM.from_pretrained("sft_model")
dpo_trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,  # 独立参考模型
    ...
)

# 方法2:不使用参考模型(假设policy ≈ reference)
dpo_trainer = DPOTrainer(
    model=model,
    ref_model=None,
    ...
)

损失函数变体

def dpo_loss_variant(policy_chosen_logps, policy_rejected_logps,
                     reference_chosen_logps, reference_rejected_logps,
                     beta=0.1, loss_type="sigmoid"):
    """多种DPO损失变体"""
    chosen_log_ratios = policy_chosen_logps - reference_chosen_logps
    rejected_log_ratios = policy_rejected_logps - reference_rejected_logps
    
    logits = beta * (chosen_log_ratios - rejected_log_ratios)
    
    if loss_type == "sigmoid":
        loss = -F.logsigmoid(logits).mean()
    elif loss_type == "hinge":
        loss = torch.relu(1 - logits).mean()
    elif loss_type == "ipo":
        loss = (logits ** 2).mean()
    elif loss_type == "kto":
        loss = (1 - torch.tanh(logits)).mean()
    
    return loss

数据质量控制

def filter_preference_data(data, min_length=50, max_length=1000):
    """过滤偏好数据"""
    filtered = []
    for item in data:
        # 长度过滤
        if len(item["chosen"]) < min_length or len(item["chosen"]) > max_length:
            continue
        if len(item["rejected"]) < min_length or len(item["rejected"]) > max_length:
            continue
        
        # 差异性过滤(chosen和rejected应该有足够差异)
        if item["chosen"] == item["rejected"]:
            continue
        
        filtered.append(item)
    
    return filtered

与RLHF对比

# RLHF流程
rlhf_steps = [
    "1. 收集偏好数据",
    "2. 训练奖励模型",
    "3. PPO训练策略模型"
]

# DPO流程
dpo_steps = [
    "1. 收集偏好数据",
    "2. 直接训练策略模型"
]

# 性能对比
comparison = {
    "训练稳定性": {"RLHF": "中等", "DPO": "高"},
    "计算成本": {"RLHF": "高", "DPO": "低"},
    "实现复杂度": {"RLHF": "高", "DPO": "低"},
    "最终效果": {"RLHF": "优秀", "DPO": "优秀"},
    "超参数敏感度": {"RLHF": "高", "DPO": "低"}
}

在LLaMA中的应用

from transformers import LlamaForCausalLM
from trl import DPOTrainer

# 加载LLaMA
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16
)

# DPO训练
dpo_trainer = DPOTrainer(
    model=model,
    args=dpo_config,
    train_dataset=preference_dataset,
    tokenizer=tokenizer,
    max_length=2048,
    max_prompt_length=1024
)

dpo_trainer.train()

DPO通过简化RLHF流程,使得模型对齐变得更加容易实现和部署,成为当前最流行的对齐方法之一。