← 返回首页
🧠

合成数据生成

📂 llm ⏱ 5 min 835 words

---
title: "合成数据生成"
description: "详细介绍合成数据生成技术,包括Self-Instruct方法、数据增强和LLM生成数据的实践"
tags: ["合成数据", "Self-Instruct", "数据增强", "LLM生成"]
category: "llm"
icon: "🧠"
---

合成数据生成

合成数据概述

合成数据是指通过算法或模型生成的、用于训练的数据,而非从现实世界收集的数据。对于LLM训练,合成数据具有以下优势:

- 成本低:无需大规模人工标注即可获得训练样本;
- 易扩展:可以按需批量生成大规模数据;
- 可控:能够针对特定任务、领域或难度定向生成数据;
- 隐私风险较低:通常不直接包含真实用户数据。

Self-Instruct方法

基本原理

Self-Instruct是一种使用LLM自身生成指令数据的方法:从少量人工编写的种子任务出发,每轮从任务池中采样示例作为提示,让模型生成新任务,过滤掉低质量或与已有任务过于相似的结果后再加入任务池,如此迭代扩充数据集。

from transformers import AutoTokenizer, AutoModelForCausalLM

class SelfInstructGenerator:
    """Self-Instruct data generator.

    Prompts a causal language model with one example task and asks it to
    write a new task in the same "指令/输入/输出" (instruction/input/output)
    format, then parses the model's reply back into a task dict.
    """

    # (label, field) pairs shared by the prompt and the parser.
    _FIELD_LABELS = (("指令", "instruction"), ("输入", "input"), ("输出", "output"))
    # A label may be followed by an ASCII or a full-width colon.
    _COLONS = (":", "：")

    def __init__(self, model_name="llama-7b"):
        """Load tokenizer and model via ``from_pretrained``.

        Args:
            model_name: name or path passed to ``from_pretrained``.
                NOTE(review): "llama-7b" is presumably a local checkpoint
                path rather than a canonical Hub id — confirm.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

    def generate_instruction(self, seed_task):
        """Generate one new task modelled on ``seed_task``.

        Args:
            seed_task: dict with an "instruction" key; "input" and "output"
                are optional and default to "".

        Returns:
            dict with "instruction", "input" and "output" string values
            (empty strings for fields the model did not produce).
        """
        # Build the prompt line by line so no source indentation leaks into it
        # (the model would imitate the indent and the parser would miss labels).
        prompt = "\n".join([
            "以下是一个任务的示例:",
            f"指令:{seed_task['instruction']}",
            f"输入:{seed_task.get('input', '')}",
            f"输出:{seed_task.get('output', '')}",
            "",
            "请生成一个类似的新任务,格式相同:",
            "指令:",
        ])

        inputs = self.tokenizer(prompt, return_tensors="pt")
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.7,
            do_sample=True
        )

        # For decoder-only models generate() returns prompt + continuation;
        # decode only the newly generated tokens so the seed example is not parsed.
        prompt_length = inputs["input_ids"].shape[-1]
        continuation = self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
        # The prompt ends with the "指令:" label, so restore it before parsing.
        return self.parse_instruction("指令:" + continuation)

    @classmethod
    def _match_label(cls, line):
        """Return (field, rest_of_line) if ``line`` starts with a field label, else None."""
        for label, field in cls._FIELD_LABELS:
            if line.startswith(label) and line[len(label):len(label) + 1] in cls._COLONS:
                return field, line[len(label) + 1:].strip()
        return None

    def parse_instruction(self, text):
        """Parse "指令:/输入:/输出:" labelled text into a task dict.

        Lines are stripped before matching, so indented output still parses.
        Unlabelled lines continue the most recent field and are joined with
        newlines. Parsing stops at the first repeated label, because that
        means the model started writing another task.

        Args:
            text: raw generated text.

        Returns:
            dict with "instruction", "input" and "output" keys (strings).
        """
        fields = {"instruction": [], "input": [], "output": []}
        current_field = None
        seen = set()

        for raw_line in text.strip().split("\n"):
            line = raw_line.strip()
            labelled = self._match_label(line)
            if labelled is not None:
                field, rest = labelled
                if field in seen:
                    break
                seen.add(field)
                current_field = field
                fields[field] = [rest] if rest else []
            elif current_field is not None and line:
                fields[current_field].append(line)

        return {name: "\n".join(parts) for name, parts in fields.items()}

Self-Instruct管道

class SelfInstructPipeline:
    """Complete Self-Instruct pipeline: sample a seed, generate, filter, repeat."""

    def __init__(self, seed_tasks, num_instructions=1000):
        """Store the seed pool and create the generator.

        Args:
            seed_tasks: iterable of task dicts, each with an "instruction" key.
            num_instructions: target size of the returned pool, seeds included.
        """
        self.seed_tasks = seed_tasks
        self.num_instructions = num_instructions
        self.generator = SelfInstructGenerator()

    def generate_dataset(self, max_attempts=None):
        """Grow the task pool until it holds ``num_instructions`` tasks.

        Args:
            max_attempts: optional cap on generation attempts; ``None`` (the
                default) keeps trying until the target size is reached.

        Returns:
            list of at most ``num_instructions`` task dicts (seeds first).

        Raises:
            ValueError: if there are no seed tasks but new tasks are needed.
        """
        import random

        generated_tasks = list(self.seed_tasks)
        if not generated_tasks and self.num_instructions > 0:
            raise ValueError("seed_tasks must contain at least one task")

        attempts = 0
        while len(generated_tasks) < self.num_instructions:
            if max_attempts is not None and attempts >= max_attempts:
                break
            attempts += 1

            # Sample the in-context example from the first 100 pool entries.
            seed_task = random.choice(generated_tasks[:100])
            new_task = self.generator.generate_instruction(seed_task)

            # Dedup against the whole pool, not only the seeds, so generated
            # tasks cannot accumulate near-copies of each other.
            if self.validate_task(new_task, existing_tasks=generated_tasks):
                generated_tasks.append(new_task)

        return generated_tasks[:self.num_instructions]

    def validate_task(self, task, existing_tasks=None):
        """Return True if ``task`` is complete, long enough and not a near-duplicate.

        Args:
            task: candidate task dict.
            existing_tasks: tasks to compare against; defaults to the seeds.

        Returns:
            bool.
        """
        instruction = task.get("instruction")
        if not instruction:
            return False

        if len(instruction) < 10:
            return False

        pool = self.seed_tasks if existing_tasks is None else existing_tasks
        # NOTE: one pairwise similarity per pool entry, so total cost grows
        # quadratically with the pool size.
        return all(
            self.calculate_similarity(instruction, existing.get("instruction", "")) <= 0.8
            for existing in pool
        )

    def calculate_similarity(self, text1, text2):
        """Return the TF-IDF cosine similarity of two texts in [0, 1].

        Uses character 1-2-grams. The default word tokenizer treats an
        unsegmented Chinese sentence as a single token, which makes any two
        non-identical Chinese texts score 0.
        """
        from sklearn.feature_extraction.text import TfidfVectorizer
        from sklearn.metrics.pairwise import cosine_similarity

        vectorizer = TfidfVectorizer(analyzer="char", ngram_range=(1, 2))
        try:
            tfidf_matrix = vectorizer.fit_transform([text1, text2])
        except ValueError:
            # Raised on an empty vocabulary (both texts empty): nothing to compare.
            return 0.0

        return float(cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0])

数据增强技术

文本回译

class BackTranslationAugmenter:
    """Augment text by round-trip translation (src -> mid -> src)."""

    def __init__(self, src_lang="en", mid_lang="zh", translator=None):
        """Configure the language pair and an optional translation backend.

        Args:
            src_lang: language of the input text.
            mid_lang: pivot language for the round trip.
            translator: optional callable ``(text, src_lang, tgt_lang) -> str``.
                When omitted, ``translate`` is a placeholder that returns its
                input unchanged, so ``augment`` yields no variants.
        """
        self.src_lang = src_lang
        self.mid_lang = mid_lang
        self.translator = translator

    def augment(self, text, num_variations=3):
        """Return up to ``num_variations`` distinct back-translated paraphrases.

        Empty results, round trips identical to ``text``, and repeats are
        dropped. A deterministic translator produces the same paraphrase on
        every pass, and duplicates would only inflate the dataset.
        """
        variations = []

        for _ in range(num_variations):
            # Forward translation.
            translated = self.translate(text, self.src_lang, self.mid_lang)

            # Back translation.
            back_translated = self.translate(translated, self.mid_lang, self.src_lang)

            if back_translated and back_translated != text and back_translated not in variations:
                variations.append(back_translated)

        return variations

    def translate(self, text, src_lang, tgt_lang):
        """Translate ``text`` using the configured backend.

        Without a ``translator`` this is a placeholder that returns ``text``
        unchanged; plug in a real translation API for actual use.
        """
        translator = getattr(self, "translator", None)
        if translator is not None:
            return translator(text, src_lang, tgt_lang)
        return text  # placeholder

同义词替换

import random

class SynonymReplacer:
    """Augment text by randomly swapping dictionary words for synonyms."""

    def __init__(self, synonym_dict=None):
        """Args:
            synonym_dict: mapping word -> list of synonyms. ``None`` loads a
                small built-in dictionary; an explicit empty dict is respected.
        """
        if synonym_dict is None:
            synonym_dict = self.load_default_synonyms()
        self.synonym_dict = synonym_dict

    def load_default_synonyms(self):
        """Return the small built-in Chinese synonym dictionary."""
        return {
            "好": ["优秀", "出色", "很棒", "不错"],
            "大": ["巨大", "庞大", "大型", "广大"],
            "快": ["迅速", "快速", "敏捷", "高速"],
            "重要": ["关键", "核心", "主要", "关键性"]
        }

    def augment(self, text, replacement_ratio=0.2, num_variations=3):
        """Return up to ``num_variations`` distinct synonym-substituted variants.

        Each dictionary entry found in the text is replaced (first occurrence
        only) with probability ``replacement_ratio``. Note that this is a
        per-entry probability, not a fraction of words. Matching is by plain
        substring, so e.g. "大" also matches inside "大家"; a word segmenter
        would be needed to avoid that.

        Returns:
            list of variants that differ from ``text``, without duplicates.
        """
        variations = []

        for _ in range(num_variations):
            new_text = text

            for word, synonyms in self.synonym_dict.items():
                # An empty synonym list would make random.choice raise.
                if synonyms and word in new_text and random.random() < replacement_ratio:
                    new_text = new_text.replace(word, random.choice(synonyms), 1)

            if new_text != text and new_text not in variations:
                variations.append(new_text)

        return variations

指令改写

class InstructionRephraser:
    """Augment an instruction by re-wrapping its core topic in fixed templates."""

    # Request words and template lead-ins stripped (repeatedly) to reach the
    # core topic. Longer phrases come before their own prefixes ("介绍一下" before "介绍").
    _PREFIXES = ("请", "帮我", "能否", "可以", "描述一下", "介绍一下", "介绍", "解释", "什么是", "如何理解")
    # Trailing whitespace/punctuation ignored when extracting and comparing.
    _TRAILING = " \t\r\n。.?？!！"

    def __init__(self):
        self.rephrase_templates = [
            "请解释{}",
            "描述一下{}",
            "什么是{}",
            "如何理解{}",
            "请介绍{}"
        ]

    def augment(self, instruction, num_variations=5):
        """Return template rewrites of ``instruction``'s core topic.

        Uses the first ``num_variations`` templates. Rewrites identical to the
        instruction (ignoring trailing punctuation) and duplicates are skipped.
        Returns [] when no core topic remains.
        """
        core_content = self.extract_core_content(instruction)
        if not core_content:
            return []

        normalized = instruction.strip().rstrip(self._TRAILING)
        variations = []
        for template in self.rephrase_templates[:num_variations]:
            variation = template.format(core_content)
            if variation != normalized and variation not in variations:
                variations.append(variation)

        return variations

    def extract_core_content(self, instruction):
        """Strip trailing punctuation and any chain of leading request words.

        Example: "请帮我解释机器学习。" -> "机器学习".
        """
        core = instruction.strip().rstrip(self._TRAILING)
        stripped = True
        while stripped:
            stripped = False
            for prefix in self._PREFIXES:
                if core.startswith(prefix):
                    core = core[len(prefix):].lstrip()
                    stripped = True
                    break
        return core

LLM生成数据

结构化数据生成

class StructuredDataGenerator:
    """Prompt an LLM for QA pairs or dialogues and parse the replies."""

    def __init__(self, llm_client):
        """Args:
            llm_client: object exposing ``generate(prompt) -> str`` (the only
                method used below).
        """
        self.llm = llm_client

    def generate_qa_pairs(self, context, num_pairs=10):
        """Ask the LLM for ``num_pairs`` QA pairs about ``context``.

        Returns:
            list of {"question": str, "answer": str} dicts.
        """
        prompt = f"""基于以下上下文,生成{num_pairs}个问答对:

上下文:
{context}

请按照以下格式生成:
问题1:...
答案1:...

问题2:...
答案2:...
"""

        response = self.llm.generate(prompt)
        return self.parse_qa_pairs(response)

    def generate_conversations(self, topic, num_turns=5):
        """Ask the LLM for a ``num_turns``-turn dialogue about ``topic``.

        Returns:
            list of {"role": "user" | "assistant", "content": str} dicts.
        """
        prompt = f"""生成一段关于{topic}的{num_turns}轮对话。

用户:...
助手:...

请确保对话自然流畅。
"""

        response = self.llm.generate(prompt)
        return self.parse_conversation(response)

    def parse_qa_pairs(self, text):
        """Parse "问题N:/答案N:" (or "QN:/AN:") labelled text into QA dicts.

        A label must be followed by an optional number and an ASCII or
        full-width colon. A bare "A"/"Q" at the start of an ordinary line
        (e.g. "AI 是…") therefore does not start a new field. Unlabelled lines
        continue the current answer (or the question, before any answer) and
        are joined with newlines. Pairs missing either part are dropped.
        """
        import re

        question_label = re.compile(r"^(?:问题|Q)\s*\d*\s*[:：]\s*(.*)$", re.IGNORECASE)
        answer_label = re.compile(r"^(?:答案|A)\s*\d*\s*[:：]\s*(.*)$", re.IGNORECASE)

        pairs = []
        question = None
        answer = None

        for raw_line in (text or "").strip().split("\n"):
            line = raw_line.strip()
            if not line:
                continue

            q_match = question_label.match(line)
            if q_match:
                if question and answer:
                    pairs.append({"question": question, "answer": answer})
                question = q_match.group(1).strip()
                answer = None
                continue

            a_match = answer_label.match(line)
            if a_match:
                answer = a_match.group(1).strip()
            elif answer is not None:
                answer = f"{answer}\n{line}" if answer else line
            elif question is not None:
                question = f"{question}\n{line}" if question else line

        if question and answer:
            pairs.append({"question": question, "answer": answer})

        return pairs

    def parse_conversation(self, text):
        """Parse "用户:/助手:" (or "User:/Assistant:") labelled text into turns.

        Unlabelled lines continue the previous turn (joined with newlines);
        text before the first label is ignored.

        Returns:
            list of {"role": "user" | "assistant", "content": str} dicts.
        """
        import re

        role_label = re.compile(r"^(用户|助手|User|Assistant)\s*[:：]\s*(.*)$", re.IGNORECASE)
        role_names = {"用户": "user", "user": "user", "助手": "assistant", "assistant": "assistant"}

        turns = []
        for raw_line in (text or "").strip().split("\n"):
            line = raw_line.strip()
            if not line:
                continue

            match = role_label.match(line)
            if match:
                turns.append({
                    "role": role_names[match.group(1).lower()],
                    "content": match.group(2).strip(),
                })
            elif turns:
                previous = turns[-1]
                previous["content"] = f"{previous['content']}\n{line}" if previous["content"] else line

        return turns

多样性保证

class DiversityEnsurer:
    """Reject candidate texts that are too similar to previously added texts."""

    def __init__(self):
        # Texts registered via add_text(); check_diversity() compares against these.
        self.generated_texts = set()

    def check_diversity(self, new_text, threshold=0.7):
        """Return True if ``new_text`` is at most ``threshold`` similar to every stored text.

        Does not store ``new_text``; call ``add_text`` once it is accepted.
        """
        return all(
            self.calculate_similarity(new_text, existing_text) <= threshold
            for existing_text in self.generated_texts
        )

    def calculate_similarity(self, text1, text2):
        """Return the TF-IDF cosine similarity of two texts in [0, 1].

        Uses character 1-2-grams. The default word tokenizer treats an
        unsegmented Chinese sentence as a single token, which makes any two
        non-identical Chinese texts score 0.
        """
        from sklearn.feature_extraction.text import TfidfVectorizer
        from sklearn.metrics.pairwise import cosine_similarity

        vectorizer = TfidfVectorizer(analyzer="char", ngram_range=(1, 2))
        try:
            tfidf_matrix = vectorizer.fit_transform([text1, text2])
        except ValueError:
            # Raised on an empty vocabulary (both texts empty): nothing to compare.
            return 0.0
        return float(cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0])

    def add_text(self, text):
        """Register an accepted text for future diversity checks."""
        self.generated_texts.add(text)

质量控制

合成数据验证

class SyntheticDataValidator:
    """Score a synthetic sample with a fixed set of boolean heuristic checks."""

    def __init__(self):
        # Each check takes the sample dict and returns a bool; validate() averages them.
        self.quality_checks = [
            self.check_completeness,
            self.check_coherence,
            self.check_accuracy,
            self.check_safety
        ]

    def validate(self, sample, threshold=0.8):
        """Run every check on ``sample``.

        Args:
            sample: dict with (at least) "input" and "output" keys.
            threshold: the sample passes when the fraction of passed checks is
                strictly greater than this. With four checks, the default 0.8
                requires all of them to pass.

        Returns:
            dict with "results" (check name -> bool), "validation_score"
            (fraction passed) and "passed" (bool).
        """
        results = {check.__name__: check(sample) for check in self.quality_checks}

        validation_score = sum(results.values()) / len(results)

        return {
            "results": results,
            "validation_score": validation_score,
            "passed": validation_score > threshold
        }

    @staticmethod
    def _text_field(sample, name):
        """Return ``sample[name]`` as a string; missing or None becomes ""."""
        value = sample.get(name)
        return "" if value is None else str(value)

    @staticmethod
    def _tokenize(text):
        """Split into case-folded tokens: each CJK character, or each run of other word characters.

        Whitespace splitting would turn a whole unsegmented Chinese sentence
        into one token.
        """
        import re
        return set(re.findall(r"[\u4e00-\u9fff]|[^\W\u4e00-\u9fff]+", text.casefold()))

    def check_completeness(self, sample):
        """Return True if "input" and "output" are both present and non-empty."""
        required_fields = ["input", "output"]
        return all(field in sample and sample[field] for field in required_fields)

    def check_coherence(self, sample):
        """Return True if more than 30% of the input's tokens reappear in the output.

        An empty input has no tokens and therefore fails.
        """
        input_tokens = self._tokenize(self._text_field(sample, "input"))
        output_tokens = self._tokenize(self._text_field(sample, "output"))

        overlap = input_tokens & output_tokens
        return len(overlap) / max(len(input_tokens), 1) > 0.3

    def check_accuracy(self, sample):
        """Return False if the output contains an error-signalling keyword (crude heuristic)."""
        output = self._text_field(sample, "output")
        error_patterns = ["错误", "不对", "不正确"]
        return not any(pattern in output for pattern in error_patterns)

    def check_safety(self, sample):
        """Return False if the output contains a blocked keyword (crude keyword filter)."""
        output = self._text_field(sample, "output")
        unsafe_patterns = ["暴力", "歧视", "色情", "违法"]
        return not any(pattern in output for pattern in unsafe_patterns)

实践案例

# Settings for the synthetic-data pipeline, grouped by technique.
synthetic_config = dict(
    # Self-Instruct: seed pool size, target dataset size, minimum validation score.
    self_instruct=dict(
        seed_tasks=100,
        target_size=10_000,
        quality_threshold=0.8,
    ),
    # Which augmentation passes to run and how many variants each sample gets.
    data_augmentation=dict(
        back_translation=True,
        synonym_replacement=True,
        instruction_rephrasing=True,
        augmentation_factor=3,
    ),
    # Which LLM-driven generators to run and whether to filter for diversity.
    llm_generation=dict(
        qa_generation=True,
        conversation_generation=True,
        diversity_check=True,
    ),
)

总结

合成数据生成是扩展LLM训练数据的有效方法。Self-Instruct、数据增强和LLM生成提供了多种数据生成途径。关键是要建立完善的质量控制机制,确保合成数据的质量和多样性。