LLM数据增强
---
title: "LLM数据增强"
description: "详细介绍LLM数据增强技术,包括文本回译、同义词替换和指令增强等方法"
tags: ["数据增强", "文本回译", "同义词替换", "指令增强"]
category: "llm"
icon: "🧠"
---
LLM数据增强
数据增强概述
数据增强是通过变换现有数据来生成新训练样本的技术。对于LLM,数据增强可以:
- 增加数据多样性:减少过拟合
- 提升模型泛化能力:处理更多场景
- 平衡数据分布:解决类别不平衡
- 降低标注成本:从少量数据生成更多数据
文本回译
基本回译
from transformers import MarianMTModel, MarianTokenizer
class BackTranslationAugmenter:
    """Back-translation augmentation: translate src -> mid -> src with MarianMT models."""

    def __init__(self, src_lang="en", mid_lang="zh"):
        """Load the forward (src->mid) and backward (mid->src) Helsinki-NLP checkpoints.

        Args:
            src_lang: Language code of the text to augment.
            mid_lang: Pivot language code used for the round trip.
        """
        self.src_lang = src_lang
        self.mid_lang = mid_lang
        # Load translation models
        self.forward_model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{mid_lang}"
        self.backward_model_name = f"Helsinki-NLP/opus-mt-{mid_lang}-{src_lang}"
        self.forward_model = MarianMTModel.from_pretrained(self.forward_model_name)
        self.forward_tokenizer = MarianTokenizer.from_pretrained(self.forward_model_name)
        self.backward_model = MarianMTModel.from_pretrained(self.backward_model_name)
        self.backward_tokenizer = MarianTokenizer.from_pretrained(self.backward_model_name)

    def translate(self, text, model, tokenizer, **generate_kwargs):
        """Translate one string.

        Args:
            text: Input string.
            model: Seq2seq model exposing ``generate``.
            tokenizer: Tokenizer matching ``model``.
            **generate_kwargs: Extra decoding options forwarded to ``model.generate``
                (e.g. ``do_sample=True, top_k=50``). With none, the model's default
                generation settings are used, as before.

        Returns:
            The decoded text of the first generated sequence.
        """
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        outputs = model.generate(**inputs, **generate_kwargs)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    def augment(self, text, num_variations=3, top_k=50):
        """Generate up to ``num_variations`` distinct back-translations of ``text``.

        Args:
            text: Input string in ``src_lang``.
            num_variations: Number of round-trip attempts.
            top_k: Top-k cut-off used by the sampled attempts.

        Returns:
            A list of unique back-translations that differ from ``text``. It may be
            shorter than ``num_variations`` when attempts repeat an earlier result.
        """
        variations = []
        seen = {text}
        for attempt in range(num_variations):
            # Re-running the same default decoding would normally reproduce the same
            # output, so only the first attempt uses it; later attempts sample.
            decode_kwargs = {} if attempt == 0 else {"do_sample": True, "top_k": top_k}
            # Forward translation
            translated = self.translate(
                text, self.forward_model, self.forward_tokenizer, **decode_kwargs
            )
            # Backward translation
            back_translated = self.translate(
                translated, self.backward_model, self.backward_tokenizer, **decode_kwargs
            )
            if back_translated and back_translated not in seen:
                seen.add(back_translated)
                variations.append(back_translated)
        return variations
多语言回译
class MultilingualBackTranslation:
    """Back-translation through several pivot languages, one BackTranslationAugmenter each."""

    def __init__(self, intermediate_languages=None, src_lang="en"):
        """Create one augmenter per pivot language.

        Args:
            intermediate_languages: Pivot language codes; defaults to ["zh", "de", "fr"].
            src_lang: Language code of the input text (previously hard-coded to "en").
        """
        if intermediate_languages is None:
            intermediate_languages = ["zh", "de", "fr"]
        # Copy so neither the caller's list nor a shared default is aliased.
        self.intermediate_languages = list(intermediate_languages)
        self.augmenters = {
            lang: BackTranslationAugmenter(src_lang, lang)
            for lang in self.intermediate_languages
        }

    def augment(self, text, num_variations=3):
        """Collect up to ``num_variations`` unique back-translations across pivot languages.

        Languages are tried in order, one variation each. Iteration stops as soon as
        enough variations exist, which avoids running the remaining models.
        """
        variations = []
        for augmenter in self.augmenters.values():
            if len(variations) >= num_variations:
                break
            for candidate in augmenter.augment(text, num_variations=1):
                if candidate not in variations:
                    variations.append(candidate)
        return variations[:max(num_variations, 0)]
同义词替换
基于词典的替换
import random
import jieba
class SynonymReplacer:
    """Dictionary-based synonym replacement over jieba-segmented Chinese text."""

    def __init__(self, synonym_dict=None):
        """Args:
            synonym_dict: Mapping word -> list of synonyms. When None, a small built-in
                dictionary is used. An explicit empty dict is kept as-is, meaning no
                replacements, rather than silently falling back to the defaults.
        """
        if synonym_dict is None:
            synonym_dict = self.load_default_synonyms()
        self.synonym_dict = synonym_dict

    def load_default_synonyms(self):
        """Return the built-in demo synonym dictionary."""
        return {
            "好": ["优秀", "出色", "很棒", "不错", "良好"],
            "大": ["巨大", "庞大", "大型", "广大", "硕大"],
            "快": ["迅速", "快速", "敏捷", "高速", "急速"],
            "重要": ["关键", "核心", "主要", "关键性", "至关重要"],
            "学习": ["研习", "钻研", "修习", "攻读"],
            "技术": ["技艺", "技巧", "技能", "手法"],
            "方法": ["方式", "途径", "手段", "措施"],
            "问题": ["难题", "议题", "课题", "疑问"]
        }

    def augment(self, text, replacement_ratio=0.2, num_variations=3):
        """Generate up to ``num_variations`` unique variants by random synonym swaps.

        Each dictionary word is replaced independently with probability
        ``replacement_ratio``. Variants identical to ``text`` or to an earlier
        variant are dropped.
        """
        words = list(jieba.cut(text))
        variations = []
        for _ in range(num_variations):
            new_words = words.copy()
            replaced_count = 0
            for i, word in enumerate(words):
                synonyms = self.synonym_dict.get(word)
                # Skip words with an empty synonym list: random.choice([]) would raise.
                if synonyms and random.random() < replacement_ratio:
                    new_words[i] = random.choice(synonyms)
                    replaced_count += 1
            if replaced_count > 0:
                new_text = "".join(new_words)
                if new_text != text and new_text not in variations:
                    variations.append(new_text)
        return variations

    def augment_with_pos(self, text, pos_tags=None, replacement_ratio=0.2):
        """Replace dictionary words, optionally restricted by part of speech.

        Args:
            text: Input text.
            pos_tags: Optional jieba POS-tag prefix (str) or iterable of prefixes,
                e.g. ("n", "v") for nouns and verbs. When None, the original
                heuristic applies: only multi-character dictionary words are eligible.
            replacement_ratio: Per-word replacement probability.

        Returns:
            The (possibly unchanged) rewritten text.
        """
        if pos_tags is None:
            prefixes = None
            tagged = [(word, None) for word in jieba.cut(text)]
        else:
            import jieba.posseg as pseg
            prefixes = (pos_tags,) if isinstance(pos_tags, str) else tuple(pos_tags)
            tagged = [(word, flag) for word, flag in pseg.cut(text)]
        new_words = []
        for word, flag in tagged:
            synonyms = self.synonym_dict.get(word)
            eligible = len(word) > 1 if prefixes is None else flag.startswith(prefixes)
            if random.random() < replacement_ratio and eligible and synonyms:
                new_words.append(random.choice(synonyms))
            else:
                new_words.append(word)
        return "".join(new_words)
基于Word2Vec的替换
from gensim.models import Word2Vec
class Word2VecSynonymReplacer:
    """Replace words with their nearest neighbours in a word-embedding model."""

    def __init__(self, word2vec_model):
        """Args:
            word2vec_model: A gensim Word2Vec model (its ``wv`` is used), or any
                object exposing ``most_similar`` directly, such as KeyedVectors.
        """
        self.model = word2vec_model

    def get_synonyms(self, word, topn=5):
        """Return up to ``topn`` most similar words, or [] if ``word`` is out of vocabulary."""
        vectors = getattr(self.model, "wv", self.model)
        try:
            similar_words = vectors.most_similar(word, topn=topn)
        except KeyError:
            return []
        return [w for w, _ in similar_words]

    def augment(self, text, replacement_ratio=0.2, num_variations=3):
        """Generate up to ``num_variations`` unique variants via embedding neighbours.

        Only tokens containing at least one letter, digit or CJK character are
        candidates. Punctuation and whitespace tokens are kept, because their
        neighbours are usually other punctuation, which corrupts the sentence.
        Variants equal to ``text`` or to an earlier variant are dropped.
        """
        words = list(jieba.cut(text))
        variations = []
        for _ in range(num_variations):
            new_words = words.copy()
            for i, word in enumerate(words):
                if not any(ch.isalnum() for ch in word):
                    continue
                if random.random() < replacement_ratio:
                    synonyms = self.get_synonyms(word, topn=3)
                    if synonyms:
                        new_words[i] = random.choice(synonyms)
            new_text = "".join(new_words)
            if new_text != text and new_text not in variations:
                variations.append(new_text)
        return variations
指令增强
指令改写
class InstructionRephraser:
    """Rephrase Chinese instructions with question templates and phrase substitutions."""

    # Polite/request openers removed when extracting the instruction's core topic.
    _REQUEST_PREFIXES = ("请", "帮我", "能否", "可以", "我想", "我要")
    # Trailing characters trimmed from the extracted core topic.
    _TRAILING_PUNCT = "?？。.!！ "

    def __init__(self):
        self.rephrase_templates = [
            "请解释{}",
            "描述一下{}",
            "什么是{}",
            "如何理解{}",
            "请介绍{}",
            "能否说明{}",
            "帮我解释{}",
            "详细说明{}"
        ]
        self.paraphrase_patterns = [
            ("什么是", "请解释"),
            ("如何", "怎样"),
            ("为什么", "为何"),
            ("请", "帮我"),
            ("可以", "能否")
        ]

    def augment(self, instruction, num_variations=5):
        """Return up to ``num_variations`` unique rephrasings of ``instruction``.

        Template rewrites of the extracted core topic come first, followed by
        single phrase substitutions applied to the original instruction. Candidates
        equal to ``instruction`` or already produced are skipped, and all templates
        are tried, so the requested count is reached whenever enough candidates exist.
        """
        if num_variations <= 0:
            return []
        core_content = self.extract_core_content(instruction)
        candidates = []
        if core_content:
            candidates.extend(template.format(core_content) for template in self.rephrase_templates)
        for pattern, replacement in self.paraphrase_patterns:
            if pattern in instruction:
                candidates.append(instruction.replace(pattern, replacement, 1))
        variations = []
        for candidate in candidates:
            if candidate != instruction and candidate not in variations:
                variations.append(candidate)
                if len(variations) == num_variations:
                    break
        return variations

    def extract_core_content(self, instruction):
        """Strip request openers and template lead-ins to get the instruction's topic.

        Both the polite openers ("请", "帮我", ...) and the lead-ins of the rephrase
        templates ("请解释", "什么是", ...) are removed repeatedly, longest match
        first. Without this, re-templating yields doubled phrases such as
        "请解释什么是X". Trailing question/period marks are trimmed. The result may
        be "" when the instruction consists only of such prefixes.
        """
        lead_ins = [t[:-2] for t in self.rephrase_templates if t.endswith("{}")]
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        prefixes = sorted(dict.fromkeys([*self._REQUEST_PREFIXES, *lead_ins]), key=len, reverse=True)
        core = instruction.strip()
        stripped = True
        while stripped:
            stripped = False
            for prefix in prefixes:
                if prefix and core.startswith(prefix):
                    core = core[len(prefix):]
                    stripped = True
                    break
        return core.rstrip(self._TRAILING_PUNCT)
指令扩展
class InstructionExpander:
    """Expand an instruction by prefixing a randomly chosen context phrase."""

    def __init__(self, contexts=None, expansion_templates=None):
        """Args:
            contexts: Context phrases; defaults to the built-in list below.
            expansion_templates: Two-slot format strings ("{}" context, then "{}"
                instruction); defaults to the built-in list below.
        """
        if expansion_templates is None:
            expansion_templates = [
                "在{}的背景下,{}",
                "对于{}领域,{}",
                "从{}的角度来看,{}",
                "考虑到{},{}",
                "在{}的情况下,{}"
            ]
        if contexts is None:
            contexts = [
                "实际应用",
                "理论研究",
                "教学场景",
                "工业实践",
                "学术研究"
            ]
        self.expansion_templates = list(expansion_templates)
        self.contexts = list(contexts)

    def augment(self, instruction, num_variations=3):
        """Return up to ``num_variations`` distinct context-expanded instructions.

        (template, context) pairs are drawn without replacement, so no expansion
        repeats. The count is capped at len(templates) * len(contexts).
        """
        combos = [(template, context) for template in self.expansion_templates for context in self.contexts]
        k = max(0, min(num_variations, len(combos)))
        return [template.format(context, instruction) for template, context in random.sample(combos, k)]
回响式增强
输入-输出对生成
class EchoAugmenter:
    """LLM-based ("echo") augmentation: paraphrase question and answer separately."""

    def __init__(self, llm_client):
        """Args:
            llm_client: Object with a ``generate(prompt) -> str`` method.
        """
        self.llm = llm_client

    def augment_qa(self, question, answer, num_variations=3):
        """Produce up to ``num_variations`` paraphrased QA pairs.

        Each attempt paraphrases the question and the answer independently. The new
        answer is NOT conditioned on the new question; the pair relies on both
        paraphrases preserving the original meaning. Pairs with an empty side, pairs
        identical to the original, and duplicate pairs are dropped.

        Returns:
            A list of {"question": ..., "answer": ...} dicts.
        """
        variations = []
        seen = {(question, answer)}
        for _ in range(num_variations):
            similar_question = self.generate_similar_question(question)
            similar_answer = self.generate_similar_answer(answer)
            pair = (similar_question, similar_answer)
            if not similar_question or not similar_answer or pair in seen:
                continue
            seen.add(pair)
            variations.append({
                "question": similar_question,
                "answer": similar_answer
            })
        return variations

    def generate_similar_question(self, question):
        """Ask the LLM for a differently worded question; returns "" for an empty reply."""
        prompt = (
            "请生成一个与以下问题相似但表述不同的问题:\n"
            f"原问题:{question}\n"
            "相似问题:"
        )
        response = self.llm.generate(prompt)
        # Treat a None reply as empty instead of crashing on .strip().
        return (response or "").strip()

    def generate_similar_answer(self, answer):
        """Ask the LLM to restate the answer; returns "" for an empty reply."""
        prompt = (
            "请用不同的方式表达以下答案的含义:\n"
            f"原答案:{answer}\n"
            "重新表述:"
        )
        response = self.llm.generate(prompt)
        return (response or "").strip()
数据增强管道
完整增强管道
class DataAugmentationPipeline:
    """Combine back-translation, synonym replacement and instruction rephrasing over dict samples."""

    def __init__(self):
        self.back_translator = BackTranslationAugmenter()
        self.synonym_replacer = SynonymReplacer()
        self.instruction_rephraser = InstructionRephraser()

    def augment_sample(self, sample, max_samples=None):
        """Return augmented shallow copies of one sample.

        The "text" field (when present and non-empty) gets one back-translation and
        one synonym-replacement variant. The "instruction" field (when present and
        non-empty) gets one rephrasing. Each copy records its method under
        "augmentation_method"; all other fields, including "label", are copied as-is.

        Args:
            sample: A dict sample.
            max_samples: Optional cap on the number of returned copies.
        """
        augmented = []

        def add(field, value, method):
            # Shallow copy keeps every other field of the original sample.
            new_sample = sample.copy()
            new_sample[field] = value
            new_sample["augmentation_method"] = method
            augmented.append(new_sample)

        text = sample.get("text")
        if text:
            for value in self.back_translator.augment(text, num_variations=1):
                add("text", value, "back_translation")
            for value in self.synonym_replacer.augment(text, num_variations=1):
                add("text", value, "synonym_replacement")
        instruction = sample.get("instruction")
        if instruction:
            for value in self.instruction_rephraser.augment(instruction, num_variations=1):
                add("instruction", value, "instruction_rephrasing")
        if max_samples is not None:
            augmented = augmented[:max(0, max_samples)]
        return augmented

    def augment_dataset(self, dataset, augmentation_factor=3):
        """Return the originals followed by up to ``augmentation_factor`` augmented copies per sample."""
        # Materialize once: ``dataset`` may be a one-shot iterator.
        originals = list(dataset)
        augmented_dataset = list(originals)
        for sample in originals:
            augmented_dataset.extend(self.augment_sample(sample, max_samples=augmentation_factor))
        return augmented_dataset

    def balance_classes(self, dataset, target_count=1000):
        """Top up every class with fewer than ``target_count`` samples via augmentation.

        Samples without a "label" key are grouped under "unknown". Each class is
        filled up to ``target_count`` without overshooting. At most one
        ``augment_sample`` call is made per missing sample, so the loop stays
        bounded even when augmenters return nothing (the class may then stay short).
        """
        from collections import defaultdict

        originals = list(dataset)
        by_label = defaultdict(list)
        for sample in originals:
            by_label[sample.get("label", "unknown")].append(sample)
        balanced = list(originals)
        for class_samples in by_label.values():
            deficit = target_count - len(class_samples)
            remaining = deficit
            for i in range(max(deficit, 0)):
                if remaining <= 0:
                    break
                # class_samples is non-empty: groups only exist for seen labels.
                original_sample = class_samples[i % len(class_samples)]
                new_samples = self.augment_sample(original_sample, max_samples=remaining)
                balanced.extend(new_samples)
                remaining -= len(new_samples)
        return balanced
质量控制
增强质量验证
class AugmentationQualityValidator:
    """Heuristic quality checks for an (original, augmented) text pair."""

    def __init__(self):
        # Each check takes (original, augmented) and returns a bool.
        self.quality_checks = [
            self.check_diversity,
            self.check_coherence,
            self.check_similarity
        ]

    def validate_augmentation(self, original, augmented, min_score=0.7):
        """Run every check and summarize.

        Returns:
            {"results": {check_name: bool}, "validation_score": fraction of checks
            passed, "passed": validation_score > min_score}. With the default three
            checks and min_score 0.7, all three must pass. With no checks, the score
            is 0.0 and the pair does not pass.
        """
        results = {check.__name__: check(original, augmented) for check in self.quality_checks}
        validation_score = sum(results.values()) / len(results) if results else 0.0
        return {
            "results": results,
            "validation_score": validation_score,
            "passed": validation_score > min_score
        }

    @staticmethod
    def _has_cjk(text):
        """True if ``text`` contains a CJK Unified Ideograph (U+4E00..U+9FFF)."""
        return any("\u4e00" <= ch <= "\u9fff" for ch in text)

    @staticmethod
    def _tokenize(text, char_level):
        """Whitespace tokens, or non-space characters when ``char_level``.

        Chinese is written without spaces, so ``str.split`` would return the whole
        sentence as a single token; characters are used instead.
        """
        if char_level:
            return [ch for ch in text if not ch.isspace()]
        return text.split()

    def check_diversity(self, original, augmented):
        """The augmented sample must differ from the original."""
        return original != augmented

    def check_coherence(self, original, augmented):
        """Token overlap (relative to the original's token set) must lie strictly in (0.3, 0.9)."""
        char_level = self._has_cjk(original) or self._has_cjk(augmented)
        original_tokens = set(self._tokenize(original, char_level))
        augmented_tokens = set(self._tokenize(augmented, char_level))
        overlap_ratio = len(original_tokens & augmented_tokens) / max(len(original_tokens), 1)
        return 0.3 < overlap_ratio < 0.9

    def check_similarity(self, original, augmented):
        """TF-IDF cosine similarity must lie strictly in (0.3, 0.9).

        Text containing CJK characters is analysed at character level; other text
        uses TfidfVectorizer's default word analyzer. Returns False when neither
        text yields any token (empty vocabulary).
        """
        from sklearn.feature_extraction.text import TfidfVectorizer
        from sklearn.metrics.pairwise import cosine_similarity
        if self._has_cjk(original) or self._has_cjk(augmented):
            vectorizer = TfidfVectorizer(analyzer=lambda doc: self._tokenize(doc, True))
        else:
            vectorizer = TfidfVectorizer()
        try:
            tfidf_matrix = vectorizer.fit_transform([original, augmented])
        except ValueError:
            # sklearn raises ValueError when the vocabulary is empty.
            return False
        similarity = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0]
        return 0.3 < similarity < 0.9
实践配置
# Data-augmentation settings: one section per method (back-translation, synonym
# replacement, instruction rephrasing) plus the quality-validation threshold.
augmentation_config = dict(
    back_translation=dict(
        enabled=True,
        intermediate_languages=["zh", "de", "fr"],
        variations_per_sample=2,
    ),
    synonym_replacement=dict(
        enabled=True,
        replacement_ratio=0.2,
        variations_per_sample=2,
    ),
    instruction_rephrasing=dict(
        enabled=True,
        variations_per_sample=3,
    ),
    quality_validation=dict(
        enabled=True,
        min_quality_score=0.7,
    ),
)
总结
LLM数据增强是扩展训练数据的有效方法。文本回译、同义词替换和指令增强提供了多种增强途径。关键是建立质量控制机制,确保增强数据的质量和多样性。