← 返回首页
🧠

DPO数据:直接偏好优化训练数据的构建方法

📂 llm ⏱ 5 min 831 words

DPO数据:直接偏好优化训练数据的构建方法

DPO 算法简介

DPO(Direct Preference Optimization)是2023年由斯坦福大学提出的新型对齐训练方法。与传统的RLHF不同,DPO 将奖励建模和强化学习优化合并为一个简单的分类问题,直接使用偏好数据优化语言模型策略。

DPO 的核心思想是:通过一个巧妙的数学变换,将奖励函数表示为策略模型和参考模型的对数概率比。这意味着我们可以直接使用偏好数据训练策略模型,而无需单独训练奖励模型。

DPO 损失函数

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             reference_chosen_logps, reference_rejected_logps,
             beta=0.1):
    """
    DPO损失函数实现
    
    参数:
        policy_chosen_logps: 策略模型在chosen响应上的对数概率
        policy_rejected_logps: 策略模型在rejected响应上的对数概率
        reference_chosen_logps: 参考模型在chosen响应上的对数概率
        reference_rejected_logps: 参考模型在rejected响应上的对数概率
        beta: 温度参数,控制偏离参考模型的程度
    """
    # 计算隐式奖励
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps)
    
    # DPO损失
    loss = -F.logsigmoid(chosen_rewards - rejected_rewards)
    
    # 计算奖励准确率
    accuracy = (chosen_rewards > rejected_rewards).float().mean()
    
    # 计算奖励margin
    reward_margin = (chosen_rewards - rejected_rewards).mean()
    
    return loss.mean(), accuracy, reward_margin

DPO 数据格式规范

基础三元组格式

{
  "prompt": "请解释什么是机器学习",
  "chosen": "机器学习是人工智能的一个分支,它使计算机系统能够从数据中自动学习和改进,而无需明确编程。主要包括监督学习、无监督学习和强化学习三种类型...",
  "rejected": "机器学习就是让机器自己学习的意思。"
}

完整训练数据格式

from dataclasses import dataclass
from typing import Optional, List

@dataclass
class DPOExample:
    """DPO训练数据格式"""
    prompt: str
    chosen: str
    rejected: str
    system: Optional[str] = None
    chosen_logps: Optional[float] = None  # 可选:预计算的对数概率
    rejected_logps: Optional[float] = None
    
    def to_dict(self):
        return {
            "prompt": self.prompt,
            "chosen": self.chosen,
            "rejected": self.rejected,
            "system": self.system
        }

# 数据集格式示例
dataset_format = {
    "type": "chatml",
    "messages": [
        {"role": "system", "content": "你是一个有帮助的AI助手"},
        {"role": "user", "content": "请解释量子计算"},
        {"role": "assistant", "content": "量子计算利用量子力学原理进行计算..."}  # chosen
    ],
    "rejected": "量子计算就是用量子的计算方式。"  # rejected
}

数据构建流程

1. Prompt 收集与多样性保证

import random
from typing import List, Dict

class DPOPromptCollector:
    """DPO提示收集器"""
    
    def __init__(self):
        self.prompt_templates = {
            "factual": [
                "解释{concept}的基本原理",
                "{event}发生在什么时候?",
                "{technology}是如何工作的?"
            ],
            "creative": [
                "写一个关于{theme}的故事",
                "创作一首{style}风格的诗",
                "为{product}写一段广告文案"
            ],
            "reasoning": [
                "分析{situation}的利弊",
                "如何解决{problem}?",
                "比较{option_a}和{option_b}"
            ],
            "code": [
                "用Python实现{algorithm}",
                "调试以下代码:{code}",
                "优化{function}的性能"
            ]
        }
    
    def collect_prompts(self, n_per_category: int = 500) -> List[Dict]:
        """收集多样化的prompts"""
        prompts = []
        
        for category, templates in self.prompt_templates.items():
            for i in range(n_per_category):
                template = random.choice(templates)
                # 实际使用中需要填充具体的参数值
                prompts.append({
                    "id": f"{category}_{i}",
                    "category": category,
                    "prompt": template
                })
        
        # 打乱顺序
        random.shuffle(prompts)
        return prompts

    def ensure_diversity(self, prompts: List[Dict], min_categories: int = 4) -> bool:
        """确保prompt多样性"""
        categories = set(p["category"] for p in prompts)
        return len(categories) >= min_categories

2. 响应生成与质量控制

from openai import OpenAI
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple

class DPOResponseGenerator:
    """DPO响应生成器"""
    
    def __init__(self, model: str = "gpt-4"):
        self.client = OpenAI()
        self.model = model
    
    def generate_pair(self, prompt: str, temperature_high: float = 1.0, 
                      temperature_low: float = 0.3) -> Tuple[str, str]:
        """为一个prompt生成chosen和rejected对"""
        
        # 高温度生成可能较差的回答(rejected)
        rejected_response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature_high,
            max_tokens=512
        ).choices[0].message.content
        
        # 低温度生成较好的回答(chosen)
        chosen_response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "请提供详细、准确、有帮助的回答"},
                {"role": "user", "content": prompt}
            ],
            temperature=temperature_low,
            max_tokens=512
        ).choices[0].message.content
        
        return chosen_response, rejected_response
    
    def generate_batch(self, prompts: List[Dict], max_workers: int = 4) -> List[DPOExample]:
        """批量生成响应对"""
        results = []
        
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = []
            for prompt_item in prompts:
                future = executor.submit(self.generate_pair, prompt_item["prompt"])
                futures.append((prompt_item, future))
            
            for prompt_item, future in futures:
                chosen, rejected = future.result()
                results.append(DPOExample(
                    prompt=prompt_item["prompt"],
                    chosen=chosen,
                    rejected=rejected
                ))
        
        return results

3. 数据质量评估

class DPOQualityAssessor:
    """DPO数据质量评估"""
    
    def __init__(self, quality_thresholds: Dict = None):
        self.thresholds = quality_thresholds or {
            "min_length": 20,
            "max_length_diff_ratio": 3.0,
            "min_semantic_diff": 0.3,
            "max_toxicity": 0.1
        }
    
    def assess_pair(self, chosen: str, rejected: str) -> Dict:
        """评估一对chosen/rejected的质量"""
        
        # 长度检查
        length_ratio = max(len(chosen), len(rejected)) / max(min(len(chosen), len(rejected)), 1)
        length_ok = length_ratio <= self.thresholds["max_length_diff_ratio"]
        
        # 最小长度检查
        min_length_ok = min(len(chosen), len(rejected)) >= self.thresholds["min_length"]
        
        # 语义差异检查
        semantic_diff = self._compute_semantic_diff(chosen, rejected)
        semantic_ok = semantic_diff >= self.thresholds["min_semantic_diff"]
        
        # 质量差异检查
        quality_diff = self._compute_quality_diff(chosen, rejected)
        
        return {
            "length_balanced": length_ok,
            "min_length_met": min_length_ok,
            "semantic_diverse": semantic_ok,
            "quality_gap": quality_diff,
            "overall_score": sum([length_ok, min_length_ok, semantic_ok]) / 3
        }
    
    def _compute_semantic_diff(self, text1: str, text2: str) -> float:
        """计算语义差异"""
        words1 = set(text1.split())
        words2 = set(text2.split())
        
        if not words1 or not words2:
            return 0.0
        
        jaccard = len(words1 & words2) / len(words1 | words2)
        return 1 - jaccard
    
    def _compute_quality_diff(self, chosen: str, rejected: str) -> float:
        """计算质量差异(简化版)"""
        # 实际应用中可以使用更复杂的质量评估模型
        chosen_score = len(chosen.split()) / 100  # 简单启发式
        rejected_score = len(rejected.split()) / 100
        return max(0, chosen_score - rejected_score)
    
    def filter_dataset(self, dataset: List[DPOExample]) -> List[DPOExample]:
        """过滤低质量数据"""
        filtered = []
        
        for item in dataset:
            assessment = self.assess_pair(item.chosen, item.rejected)
            if assessment["overall_score"] >= 0.7:
                filtered.append(item)
        
        return filtered

高级数据构建策略

1. 生成多样化负样本

class NegativeSampler:
    """多样化负样本生成"""
    
    def generate_negatives(self, prompt: str, chosen: str, n_negatives: int = 3) -> List[str]:
        """为每个chosen生成多种类型的负样本"""
        negatives = []
        
        # 1. 不完整回答
        incomplete = chosen[:len(chosen)//3] + "..."
        negatives.append(incomplete)
        
        # 2. 低质量回答(添加填充词)
        low_quality = chosen.replace("。", ",嗯,我觉得,").replace(",", ",嗯,")
        negatives.append(low_quality)
        
        # 3. 偏题回答
        off_topic = "这是一个很好的问题。让我想想..."
        negatives.append(off_topic)
        
        return negatives[:n_negatives]

    def create_multi_negative_dataset(self, dataset: List[DPOExample]) -> List[DPOExample]:
        """创建多负样本数据集"""
        sampler = NegativeSampler()
        augmented = []
        
        for item in dataset:
            negatives = sampler.generate_negatives(item.prompt, item.chosen)
            for neg in negatives:
                augmented.append(DPOExample(
                    prompt=item.prompt,
                    chosen=item.chosen,
                    rejected=neg
                ))
        
        return augmented

2. 数据增强技术

class DPODataAugmenter:
    """DPO数据增强"""
    
    def augment_prompt(self, original_prompt: str) -> List[str]:
        """增强prompt"""
        augmented = [
            original_prompt,
            f"请详细说明:{original_prompt}",
            f"关于{original_prompt},你能解释一下吗?",
            f"我需要了解{original_prompt}的信息"
        ]
        return augmented
    
    def augment_with_paraphrase(self, text: str, n_variants: int = 2) -> List[str]:
        """通过释义增强文本"""
        # 实际应用中使用专门的释义模型
        variants = [text]
        return variants
    
    def cross_augment(self, dataset: List[DPOExample]) -> List[DPOExample]:
        """交叉增强"""
        augmented = []
        
        for item in dataset:
            # 保持chosen不变,生成多个rejected
            negatives = NegativeSampler().generate_negatives(
                item.prompt, item.chosen, n_negatives=3
            )
            
            for neg in negatives:
                augmented.append(DPOExample(
                    prompt=item.prompt,
                    chosen=item.chosen,
                    rejected=neg
                ))
        
        return augmented

数据加载与预处理

from torch.utils.data import Dataset
from transformers import AutoTokenizer

class DPODataset(Dataset):
    """DPO训练数据集"""
    
    def __init__(self, data: List[DPOExample], tokenizer, max_length: int = 512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        # 编码prompt
        prompt_tokens = self.tokenizer(
            item.prompt,
            max_length=self.max_length // 2,
            truncation=True,
            return_tensors="pt"
        )
        
        # 编码chosen
        chosen_text = f"{item.prompt}\n\n{item.chosen}"
        chosen_tokens = self.tokenizer(
            chosen_text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        
        # 编码rejected
        rejected_text = f"{item.prompt}\n\n{item.rejected}"
        rejected_tokens = self.tokenizer(
            rejected_text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        
        return {
            "prompt_input_ids": prompt_tokens["input_ids"].squeeze(),
            "prompt_attention_mask": prompt_tokens["attention_mask"].squeeze(),
            "chosen_input_ids": chosen_tokens["input_ids"].squeeze(),
            "chosen_attention_mask": chosen_tokens["attention_mask"].squeeze(),
            "rejected_input_ids": rejected_tokens["input_ids"].squeeze(),
            "rejected_attention_mask": rejected_tokens["attention_mask"].squeeze()
        }

# 数据加载示例
def load_dpo_dataset(file_path: str, tokenizer):
    """加载DPO数据集"""
    with open(file_path, 'r', encoding='utf-8') as f:
        raw_data = json.load(f)
    
    dataset = [DPOExample(**item) for item in raw_data]
    return DPODataset(dataset, tokenizer)

训练配置建议

dpo_training_config = {
    "beta": 0.1,  # 温度参数
    "learning_rate": 5e-7,
    "batch_size": 4,
    "gradient_accumulation_steps": 4,
    "num_epochs": 3,
    "max_length": 1024,
    "warmup_ratio": 0.1,
    "weight_decay": 0.01,
    "lr_scheduler_type": "cosine",
    "bf16": True
}

最佳实践总结

  1. 数据质量优先:确保chosen明显优于rejected
  2. 多样性保障:覆盖不同场景、难度和风格
  3. 平衡正负样本:避免正负样本质量差距过大或过小
  4. 适当的数据量:DPO通常比RLHF需要更少的数据
  5. 迭代优化:根据训练效果调整数据分布
  6. 避免数据泄露:确保训练集和测试集不重叠

DPO方法的简洁性使得数据准备成为关键成功因素。高质量的偏好数据是DPO训练成功的基础。