合成数据生成
---
title: "合成数据生成"
description: "详细介绍合成数据生成技术,包括Self-Instruct方法、数据增强和LLM生成数据的实践"
tags: ["合成数据", "Self-Instruct", "数据增强", "LLM生成"]
category: "llm"
icon: "🧠"
---
合成数据生成
合成数据概述
合成数据是指通过算法或模型生成的、用于训练的数据,而非从现实世界收集的数据。对于LLM训练,合成数据具有以下优势:
- 可控性:可以精确控制数据分布和特性
- 可扩展性:可以生成大量训练数据
- 隐私保护:减少对真实用户数据的直接依赖(但生成模型仍可能复现其训练数据中的敏感信息,需要额外检查)
- 成本效益:通常比人工标注更便宜、更快
Self-Instruct方法
基本原理
Self-Instruct是一种使用LLM自身生成指令数据的方法:从少量种子任务出发,让模型仿照示例生成新的"指令-输入-输出"任务,经过质量与去重过滤后加入任务池,再作为后续生成的示例,如此循环扩充数据集。
from transformers import AutoTokenizer, AutoModelForCausalLM
class SelfInstructGenerator:
    """Self-Instruct data generator.

    Prompts a causal LM with one seed task and parses the model's
    continuation into a new ``{"instruction", "input", "output"}`` dict.
    """

    # (label, field) pairs as they appear in the prompt. Parsing accepts both
    # an ASCII ":" and a full-width "：" after the label, since models often
    # switch to full-width punctuation in Chinese output.
    _FIELD_LABELS = (("指令", "instruction"), ("输入", "input"), ("输出", "output"))

    def __init__(self, model_name="llama-7b"):
        """Load the tokenizer and causal-LM weights for ``model_name``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

    def generate_instruction(self, seed_task):
        """Generate one new task modelled on ``seed_task``.

        Args:
            seed_task: dict with "instruction", "input" and "output" keys.

        Returns:
            dict with "instruction", "input" and "output" keys parsed from the
            model's continuation; fields the model did not produce are "".
        """
        prompt = f"""以下是一个任务的示例:
指令:{seed_task['instruction']}
输入:{seed_task['input']}
输出:{seed_task['output']}
请生成一个类似的新任务,格式相同:
指令:"""
        inputs = self.tokenizer(prompt, return_tensors="pt")
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.7,
            do_sample=True,
        )
        # For decoder-only models generate() returns prompt + continuation.
        # Decode only the new tokens; otherwise the parser would read the seed
        # task's fields back out of the prompt.
        prompt_len = inputs["input_ids"].shape[1]
        generated = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        # The prompt ends with the "指令:" label, so the continuation begins
        # mid-field; restore the label so the parser attributes it correctly.
        return self.parse_instruction("指令:" + generated)

    def parse_instruction(self, text):
        """Parse labelled "指令/输入/输出" text into a task dict.

        - A line starting with a label (followed by ":" or "：") starts that
          field; leading/trailing whitespace is ignored.
        - Unlabelled, non-blank lines are appended to the current field,
          separated by "\\n".
        - Lines before the first label are ignored.
        - A second "指令" label means the model began another task; parsing
          stops there so only the first task is returned.

        Returns:
            dict with "instruction", "input" and "output" keys (missing -> "").
        """
        fields = {"instruction": "", "input": "", "output": ""}
        current_field = None
        seen_instruction = False
        for raw_line in text.strip().split("\n"):
            line = raw_line.strip()
            matched = None
            for label, key in self._FIELD_LABELS:
                if line.startswith((label + ":", label + "：")):
                    matched = (label, key)
                    break
            if matched is not None:
                label, key = matched
                if key == "instruction":
                    if seen_instruction:
                        break
                    seen_instruction = True
                current_field = key
                fields[key] = line[len(label) + 1:].strip()
            elif current_field is not None and line:
                previous = fields[current_field]
                fields[current_field] = f"{previous}\n{line}" if previous else line
        return fields
Self-Instruct管道
class SelfInstructPipeline:
    """End-to-end Self-Instruct loop: sample a task, generate, filter, repeat."""

    def __init__(self, seed_tasks, num_instructions=1000):
        """
        Args:
            seed_tasks: iterable of task dicts, each with an "instruction" key,
                used to bootstrap generation.
            num_instructions: target dataset size, seed tasks included.
        """
        self.seed_tasks = seed_tasks
        self.num_instructions = num_instructions
        self.generator = SelfInstructGenerator()

    def generate_dataset(self, max_attempts=None):
        """Grow the task pool until it holds ``num_instructions`` tasks.

        Args:
            max_attempts: maximum number of generator calls. Defaults to
                ``10 * num_instructions``. Without a bound the loop would spin
                forever if the model keeps producing tasks that fail validation.

        Returns:
            list of at most ``num_instructions`` tasks, seeds first. It may be
            shorter than requested if ``max_attempts`` is exhausted.

        Raises:
            ValueError: if generation is needed but there are no seed tasks.
        """
        import random

        generated_tasks = list(self.seed_tasks)
        if len(generated_tasks) < self.num_instructions and not generated_tasks:
            raise ValueError("seed_tasks must not be empty")
        if max_attempts is None:
            max_attempts = 10 * self.num_instructions

        attempts = 0
        while len(generated_tasks) < self.num_instructions and attempts < max_attempts:
            attempts += 1
            # Sample an example from the first 100 tasks of the pool (seeds first).
            seed_task = random.choice(generated_tasks[:100])
            new_task = self.generator.generate_instruction(seed_task)
            # Compare against the whole pool, not just the seeds, so generated
            # tasks cannot duplicate one another.
            if self.validate_task(new_task, generated_tasks):
                generated_tasks.append(new_task)
        return generated_tasks[:self.num_instructions]

    def validate_task(self, task, existing_tasks=None):
        """Return True if ``task`` is complete, long enough and not a near-duplicate.

        Args:
            task: dict expected to contain a non-empty "instruction".
            existing_tasks: tasks to check for near-duplicates (similarity
                > 0.8). Defaults to ``self.seed_tasks``.

        Note: each comparison fits a fresh TF-IDF vectorizer, so this is linear
        in ``len(existing_tasks)`` per call.
        """
        if existing_tasks is None:
            existing_tasks = self.seed_tasks
        instruction = task.get("instruction")
        if not instruction:
            return False
        if len(instruction) < 10:
            return False
        for existing in existing_tasks:
            if self.calculate_similarity(instruction, existing["instruction"]) > 0.8:
                return False
        return True

    def calculate_similarity(self, text1, text2):
        """TF-IDF cosine similarity of two texts over character 2-3 grams.

        TfidfVectorizer's default word tokenizer does not segment Chinese, so an
        unsegmented sentence becomes one token and only exact duplicates would
        ever score above 0. Character n-grams work for Chinese and still behave
        reasonably for space-separated languages.

        Returns 0.0 when either text is empty or whitespace-only. Such input
        would otherwise make the vectorizer raise "empty vocabulary".
        """
        if not (text1 or "").strip() or not (text2 or "").strip():
            return 0.0
        from sklearn.feature_extraction.text import TfidfVectorizer
        from sklearn.metrics.pairwise import cosine_similarity

        vectorizer = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 3))
        tfidf_matrix = vectorizer.fit_transform([text1, text2])
        return float(cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0])
数据增强技术
文本回译
class BackTranslationAugmenter:
    """Back-translation augmenter: translate src -> mid -> src to paraphrase."""

    def __init__(self, src_lang="en", mid_lang="zh", translate_fn=None):
        """
        Args:
            src_lang: language of the input text.
            mid_lang: pivot language used for the round trip.
            translate_fn: optional callable ``(text, src_lang, tgt_lang) -> str``
                used by ``translate``. When None, ``translate`` is an identity
                placeholder and ``augment`` therefore yields no variations.
        """
        self.src_lang = src_lang
        self.mid_lang = mid_lang
        self.translate_fn = translate_fn

    def augment(self, text, num_variations=3):
        """Return up to ``num_variations`` distinct back-translations of ``text``.

        Results that are empty, identical to ``text``, or repeats of an earlier
        variation are dropped. A deterministic translator therefore yields at
        most one variation.
        """
        variations = []
        for _ in range(num_variations):
            # Forward translation into the pivot language.
            translated = self.translate(text, self.src_lang, self.mid_lang)
            # Translate back into the source language.
            back_translated = self.translate(translated, self.mid_lang, self.src_lang)
            if back_translated and back_translated != text and back_translated not in variations:
                variations.append(back_translated)
        return variations

    def translate(self, text, src_lang, tgt_lang):
        """Translate ``text`` from ``src_lang`` to ``tgt_lang``.

        Delegates to ``translate_fn`` when one was supplied. Otherwise it
        returns ``text`` unchanged as a placeholder; plug in a real translation
        API or override this method.
        """
        translate_fn = getattr(self, "translate_fn", None)
        if translate_fn is None:
            return text
        return translate_fn(text, src_lang, tgt_lang)
同义词替换
import random
class SynonymReplacer:
    """Synonym-substitution augmenter working on raw substrings."""

    def __init__(self, synonym_dict=None):
        """
        Args:
            synonym_dict: mapping word -> list of synonyms. When None, the
                built-in defaults are used. An explicitly passed empty dict
                means "no synonyms" and does not fall back to the defaults.
        """
        if synonym_dict is None:
            synonym_dict = self.load_default_synonyms()
        self.synonym_dict = synonym_dict

    def load_default_synonyms(self):
        """Return the small built-in Chinese synonym table."""
        return {
            "好": ["优秀", "出色", "很棒", "不错"],
            "大": ["巨大", "庞大", "大型", "广大"],
            "快": ["迅速", "快速", "敏捷", "高速"],
            "重要": ["关键", "核心", "主要", "关键性"]
        }

    def augment(self, text, replacement_ratio=0.2, num_variations=3):
        """Generate up to ``num_variations`` variants of ``text`` by synonym swaps.

        For each variant, every dictionary key found in the (progressively
        edited) text is replaced with probability ``replacement_ratio``. Only
        its first occurrence is replaced, by a random synonym.

        - Variants with no replacement are dropped, so fewer than
          ``num_variations`` results (possibly including duplicates) may be
          returned.
        - Matching is plain substring search with no word segmentation, so
          "好" also matches inside words such as "好像".
        - Keys with an empty synonym list are skipped.
        """
        variations = []
        for _ in range(num_variations):
            new_text = text
            replaced_count = 0
            for word, synonyms in self.synonym_dict.items():
                if synonyms and word in new_text and random.random() < replacement_ratio:
                    new_text = new_text.replace(word, random.choice(synonyms), 1)
                    replaced_count += 1
            if replaced_count > 0:
                variations.append(new_text)
        return variations
指令改写
class InstructionRephraser:
    """Template-based instruction rephraser."""

    def __init__(self):
        # Each template wraps the instruction's core content via "{}".
        self.rephrase_templates = [
            "请解释{}",
            "描述一下{}",
            "什么是{}",
            "如何理解{}",
            "请介绍{}"
        ]

    def augment(self, instruction, num_variations=5):
        """Return rephrasings of ``instruction`` built from the templates.

        At most ``num_variations`` templates are used, capped at the number of
        templates. A rephrasing equal to the original instruction is skipped.
        Returns [] when nothing is left after stripping prefixes, which avoids
        emitting bare templates such as "请解释".
        """
        core_content = self.extract_core_content(instruction)
        if not core_content:
            return []
        variations = []
        for template in self.rephrase_templates[:num_variations]:
            variation = template.format(core_content)
            if variation != instruction:
                variations.append(variation)
        return variations

    def extract_core_content(self, instruction):
        """Strip request/politeness prefixes and template stems from ``instruction``.

        - Prefixes are removed repeatedly, longest first, so "请帮我X" becomes
          "X".
        - A template stem (e.g. "请解释") is also removed, so "请解释X" does not
          become "请解释解释X" when re-templated.
        - Surrounding whitespace and trailing sentence punctuation are stripped.
        """
        stems = {t.split("{}", 1)[0] for t in getattr(self, "rephrase_templates", [])}
        # Drop "" so a template beginning with "{}" cannot cause an endless loop.
        prefixes = sorted({"请", "帮我", "能否", "可以", *stems} - {""}, key=len, reverse=True)
        core = instruction.strip()
        while True:
            for prefix in prefixes:
                if core.startswith(prefix):
                    core = core[len(prefix):].lstrip()
                    break
            else:
                break
        return core.rstrip("?？。.!！ ")
LLM生成数据
结构化数据生成
class StructuredDataGenerator:
    """Generate structured training data (QA pairs, dialogues) with an LLM.

    ``llm_client`` must expose ``generate(prompt) -> str``.
    """

    def __init__(self, llm_client):
        self.llm = llm_client

    def generate_qa_pairs(self, context, num_pairs=10):
        """Ask the LLM for ``num_pairs`` QA pairs about ``context`` and parse them.

        Returns a list of ``{"question": str, "answer": str}`` dicts.
        """
        prompt = f"""基于以下上下文,生成{num_pairs}个问答对:
上下文:
{context}
请按照以下格式生成:
问题1:...
答案1:...
问题2:...
答案2:...
"""
        response = self.llm.generate(prompt)
        return self.parse_qa_pairs(response)

    def generate_conversations(self, topic, num_turns=5):
        """Ask the LLM for a ``num_turns``-turn dialogue about ``topic`` and parse it.

        Returns the list produced by ``parse_conversation``.
        """
        prompt = f"""生成一段关于{topic}的{num_turns}轮对话。
用户:...
助手:...
请确保对话自然流畅。
"""
        response = self.llm.generate(prompt)
        return self.parse_conversation(response)

    def parse_qa_pairs(self, text):
        """Parse "问题N:/答案N:" (or "Q:/A:", "Question:/Answer:") text into pairs.

        - A label must be followed by an optional number and then ":" or "：".
          Ordinary lines that merely start with "A"/"Q" (e.g. "Also ...") are
          treated as continuation text rather than labels.
        - Continuation lines are appended, joined with "\\n", to the current
          answer, or to the current question if no answer has started yet.
        - Pairs with an empty question or answer are dropped.
        """
        import re

        question_re = re.compile(r"^(?:问题|Question|Q)\s*\d*\s*[:：]\s*(.*)$")
        answer_re = re.compile(r"^(?:答案|Answer|A)\s*\d*\s*[:：]\s*(.*)$")

        pairs = []
        question = None
        answer = None
        for raw_line in text.strip().split("\n"):
            line = raw_line.strip()
            if not line:
                continue
            q_match = question_re.match(line)
            if q_match:
                if question and answer:
                    pairs.append({"question": question, "answer": answer})
                question, answer = q_match.group(1).strip(), None
                continue
            a_match = answer_re.match(line)
            if a_match:
                answer = a_match.group(1).strip()
            elif answer is not None:
                answer = f"{answer}\n{line}" if answer else line
            elif question is not None:
                question = f"{question}\n{line}" if question else line
        if question and answer:
            pairs.append({"question": question, "answer": answer})
        return pairs

    def parse_conversation(self, text):
        """Parse "用户:/助手:" (or "User:/Assistant:") lines into dialogue turns.

        Returns a list of ``{"role": "user" | "assistant", "content": str}``
        dicts in order. Either ":" or "：" may follow the speaker label.
        Unlabelled, non-blank lines are appended to the previous turn, joined
        with "\\n". Lines before the first label are ignored.
        """
        import re

        turn_re = re.compile(r"^(用户|助手|User|Assistant)\s*[:：]\s*(.*)$", re.IGNORECASE)
        roles = {"用户": "user", "user": "user", "助手": "assistant", "assistant": "assistant"}

        turns = []
        for raw_line in text.strip().split("\n"):
            line = raw_line.strip()
            if not line:
                continue
            match = turn_re.match(line)
            if match:
                turns.append({"role": roles[match.group(1).lower()], "content": match.group(2).strip()})
            elif turns:
                previous = turns[-1]["content"]
                turns[-1]["content"] = f"{previous}\n{line}" if previous else line
        return turns
多样性保证
class DiversityEnsurer:
    """Reject generated texts that are near-duplicates of ones already accepted."""

    def __init__(self):
        # Texts accepted so far (see add_text).
        self.generated_texts = set()

    def check_diversity(self, new_text, threshold=0.7):
        """Return True if ``new_text``'s similarity to every stored text is <= ``threshold``.

        Each comparison fits a fresh TF-IDF vectorizer, so the cost is linear
        in the number of stored texts. Exact duplicates are rejected up front
        with a set lookup (an identical text has similarity 1.0, which exceeds
        any threshold below 1).
        """
        if threshold < 1 and new_text in self.generated_texts:
            return False
        for existing_text in self.generated_texts:
            if self.calculate_similarity(new_text, existing_text) > threshold:
                return False
        return True

    def calculate_similarity(self, text1, text2):
        """TF-IDF cosine similarity of two texts over character 2-3 grams.

        TfidfVectorizer's default word tokenizer does not segment Chinese, so an
        unsegmented sentence becomes one token and near-duplicates score 0.
        Character n-grams avoid that. Returns 0.0 when either text is empty or
        whitespace-only, which would otherwise raise "empty vocabulary".
        """
        if not (text1 or "").strip() or not (text2 or "").strip():
            return 0.0
        from sklearn.feature_extraction.text import TfidfVectorizer
        from sklearn.metrics.pairwise import cosine_similarity

        vectorizer = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 3))
        tfidf_matrix = vectorizer.fit_transform([text1, text2])
        return float(cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0])

    def add_text(self, text):
        """Record ``text`` as accepted so later candidates are compared against it."""
        self.generated_texts.add(text)
质量控制
合成数据验证
class SyntheticDataValidator:
    """Heuristic quality checks for synthetic samples with "input"/"output" keys."""

    def __init__(self):
        # Each check takes a sample dict and returns a bool.
        self.quality_checks = [
            self.check_completeness,
            self.check_coherence,
            self.check_accuracy,
            self.check_safety
        ]

    def validate(self, sample, threshold=0.8):
        """Run every quality check on ``sample``.

        Args:
            sample: dict with "input" and "output" text.
            threshold: the sample passes when the fraction of passed checks is
                strictly greater than this. With the four default checks and
                0.8, all checks must pass.

        Returns:
            dict with "results" (check name -> bool), "validation_score"
            (fraction of checks passed) and "passed" (bool).
        """
        results = {check.__name__: check(sample) for check in self.quality_checks}
        # max(..., 1) guards against an empty check list.
        validation_score = sum(results.values()) / max(len(results), 1)
        return {
            "results": results,
            "validation_score": validation_score,
            "passed": validation_score > threshold
        }

    def check_completeness(self, sample):
        """True when both "input" and "output" are present and non-empty.

        NOTE(review): tasks with an empty "input" (which SelfInstructGenerator
        can produce) fail this check. Confirm that is intended.
        """
        required_fields = ["input", "output"]
        return all(field in sample and sample[field] for field in required_fields)

    def check_coherence(self, sample):
        """True when more than 30% of the input's distinct tokens appear in the output.

        Tokens are single CJK ideographs, or runs of other non-whitespace
        characters (matching str.split() on non-CJK text). Plain split() leaves
        unsegmented Chinese as one token, so Chinese samples would never
        overlap and would always fail.
        """
        import re

        token_pattern = r"[\u3400-\u4dbf\u4e00-\u9fff]|[^\s\u3400-\u4dbf\u4e00-\u9fff]+"
        input_tokens = set(re.findall(token_pattern, sample.get("input") or ""))
        output_tokens = set(re.findall(token_pattern, sample.get("output") or ""))
        overlap = input_tokens & output_tokens
        return len(overlap) / max(len(input_tokens), 1) > 0.3

    def check_accuracy(self, sample):
        """False if the output contains an error-indicating phrase.

        Crude heuristic: it also rejects legitimate outputs that merely discuss
        errors.
        """
        output = sample.get("output", "")
        error_patterns = ["错误", "不对", "不正确"]
        return not any(pattern in output for pattern in error_patterns)

    def check_safety(self, sample):
        """False if the output contains any keyword from a small unsafe-topic list."""
        output = sample.get("output", "")
        unsafe_patterns = ["暴力", "歧视", "色情", "违法"]
        return not any(pattern in output for pattern in unsafe_patterns)
实践案例
# Synthetic-data generation settings, grouped by technique:
# Self-Instruct, rule-based augmentation, and direct LLM generation.
synthetic_config = dict(
    self_instruct=dict(
        seed_tasks=100,
        target_size=10000,
        quality_threshold=0.8,
    ),
    data_augmentation=dict(
        back_translation=True,
        synonym_replacement=True,
        instruction_rephrasing=True,
        augmentation_factor=3,
    ),
    llm_generation=dict(
        qa_generation=True,
        conversation_generation=True,
        diversity_check=True,
    ),
)
总结
合成数据生成是扩展LLM训练数据的有效方法。Self-Instruct、数据增强和LLM生成提供了多种数据生成途径。关键是要建立完善的质量控制机制,确保合成数据的质量和多样性。