← 返回首页
🧠

LLM数据增强

📂 llm ⏱ 5 min 931 words

---
title: "LLM数据增强"
description: "详细介绍LLM数据增强技术,包括文本回译、同义词替换和指令增强等方法"
tags: ["数据增强", "文本回译", "同义词替换", "指令增强"]
category: "llm"
icon: "🧠"
---

LLM数据增强

数据增强概述

数据增强是通过变换现有数据来生成新训练样本的技术。对于LLM,数据增强可以扩充训练数据规模、提高数据多样性,并有助于提升模型的鲁棒性与泛化能力。

文本回译

基本回译

from transformers import MarianMTModel, MarianTokenizer

class BackTranslationAugmenter:
    """Back-translation augmentation: src_lang -> mid_lang -> src_lang.

    Loads two MarianMT checkpoints named
    ``Helsinki-NLP/opus-mt-{src_lang}-{mid_lang}`` and the reverse direction.
    """

    def __init__(self, src_lang="en", mid_lang="zh"):
        self.src_lang = src_lang
        self.mid_lang = mid_lang

        # Checkpoint names for both translation directions
        self.forward_model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{mid_lang}"
        self.backward_model_name = f"Helsinki-NLP/opus-mt-{mid_lang}-{src_lang}"

        self.forward_model = MarianMTModel.from_pretrained(self.forward_model_name)
        self.forward_tokenizer = MarianTokenizer.from_pretrained(self.forward_model_name)

        self.backward_model = MarianMTModel.from_pretrained(self.backward_model_name)
        self.backward_tokenizer = MarianTokenizer.from_pretrained(self.backward_model_name)

    def translate(self, text, model, tokenizer, **generate_kwargs):
        """Translate one string with *model* and return the decoded first output.

        Args:
            text: input string.
            model: a seq2seq model exposing ``generate``.
            tokenizer: the tokenizer paired with *model*.
            **generate_kwargs: forwarded to ``model.generate`` (e.g.
                ``do_sample=True``); omitted means the model's default decoding.
        """
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        outputs = model.generate(**inputs, **generate_kwargs)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    def augment(self, text, num_variations=3, sampling_kwargs=None):
        """Generate up to *num_variations* distinct back-translated variants.

        The first round uses the model's default decoding. Default decoding is
        deterministic, so repeating it would only yield copies of the same
        string; later rounds therefore sample (see *sampling_kwargs*).

        Args:
            text: input string in ``src_lang``.
            num_variations: number of back-translation rounds to attempt.
            sampling_kwargs: ``generate`` kwargs for rounds after the first;
                defaults to ``{"do_sample": True, "top_k": 50}``.

        Returns:
            A list of unique, non-empty variants that differ from *text*. It may
            be shorter than *num_variations*.
        """
        if not text or num_variations <= 0:
            return []
        if sampling_kwargs is None:
            sampling_kwargs = {"do_sample": True, "top_k": 50}

        variations = []
        seen = {text}
        for round_index in range(num_variations):
            decode_kwargs = {} if round_index == 0 else sampling_kwargs
            translated = self.translate(
                text, self.forward_model, self.forward_tokenizer, **decode_kwargs
            )
            back_translated = self.translate(
                translated, self.backward_model, self.backward_tokenizer, **decode_kwargs
            )
            if back_translated and back_translated not in seen:
                seen.add(back_translated)
                variations.append(back_translated)

        return variations

多语言回译

class MultilingualBackTranslation:
    """Back-translation through several pivot languages.

    One ``BackTranslationAugmenter(src_lang, lang)`` is built per pivot language.
    """

    def __init__(self, intermediate_languages=None, src_lang="en"):
        # None instead of a list literal: a mutable default would be shared
        # across every instance of the class.
        if intermediate_languages is None:
            intermediate_languages = ["zh", "de", "fr"]
        self.intermediate_languages = list(intermediate_languages)
        self.augmenters = {
            lang: BackTranslationAugmenter(src_lang, lang)
            for lang in self.intermediate_languages
        }

    def augment(self, text, num_variations=3):
        """Return up to *num_variations* distinct variants across all pivots.

        Each pivot is asked for ``ceil(num_variations / n_pivots)`` variants.
        This allows *num_variations* to exceed the number of pivot languages.
        Variants equal to *text*, or already produced via another pivot, are
        dropped.
        """
        if num_variations <= 0 or not self.augmenters:
            return []

        # Ceiling division without importing math
        per_language = -(-num_variations // len(self.augmenters))

        variations = []
        seen = {text}
        for augmenter in self.augmenters.values():
            for candidate in augmenter.augment(text, num_variations=per_language):
                if candidate not in seen:
                    seen.add(candidate)
                    variations.append(candidate)

        return variations[:num_variations]

同义词替换

基于词典的替换

import random
import jieba

class SynonymReplacer:
    """Dictionary-based synonym replacement for Chinese text (jieba tokenization)."""

    def __init__(self, synonym_dict=None):
        # Compare with None so that an explicitly passed empty dict is honoured
        # rather than silently replaced by the defaults.
        if synonym_dict is None:
            synonym_dict = self.load_default_synonyms()
        self.synonym_dict = synonym_dict

    def load_default_synonyms(self):
        """Return a small built-in word -> list-of-synonyms mapping."""
        return {
            "好": ["优秀", "出色", "很棒", "不错", "良好"],
            "大": ["巨大", "庞大", "大型", "广大", "硕大"],
            "快": ["迅速", "快速", "敏捷", "高速", "急速"],
            "重要": ["关键", "核心", "主要", "关键性", "至关重要"],
            "学习": ["研习", "钻研", "修习", "攻读"],
            "技术": ["技艺", "技巧", "技能", "手法"],
            "方法": ["方式", "途径", "手段", "措施"],
            "问题": ["难题", "议题", "课题", "疑问"]
        }

    def augment(self, text, replacement_ratio=0.2, num_variations=3):
        """Generate up to *num_variations* distinct synonym-replaced variants.

        Each dictionary word is replaced with probability *replacement_ratio*.
        Rounds that replace nothing are discarded. Results that duplicate an
        earlier variant or equal *text* are also discarded.
        """
        words = list(jieba.cut(text))
        variations = []

        for _ in range(num_variations):
            new_words = list(words)
            replaced = False

            for i, word in enumerate(words):
                synonyms = self.synonym_dict.get(word)
                # Skip words without synonyms; an empty list would make random.choice raise
                if synonyms and random.random() < replacement_ratio:
                    new_words[i] = random.choice(synonyms)
                    replaced = True

            if replaced:
                new_text = "".join(new_words)
                if new_text != text and new_text not in variations:
                    variations.append(new_text)

        return variations

    def augment_with_pos(self, text, pos_tags=None, replacement_ratio=0.2):
        """Replace dictionary words with synonyms, optionally filtered by POS.

        Args:
            text: input text.
            pos_tags: optional iterable of jieba POS-tag prefixes, e.g.
                ``("n", "v")`` for nouns and verbs. When given, only tokens
                whose tag starts with one of them are eligible. When omitted
                or empty, the simple heuristic applies: only multi-character
                words are eligible.
            replacement_ratio: per-token replacement probability.

        Returns:
            A single string, possibly identical to *text*.
        """
        if pos_tags:
            import jieba.posseg as pseg  # jieba's bundled POS tagger

            prefixes = tuple(pos_tags)
            tagged = [(pair.word, pair.flag) for pair in pseg.cut(text)]
        else:
            prefixes = ()
            tagged = [(word, None) for word in jieba.cut(text)]

        new_words = []
        for word, flag in tagged:
            synonyms = None
            if random.random() < replacement_ratio:
                eligible = flag.startswith(prefixes) if pos_tags else len(word) > 1
                if eligible:
                    synonyms = self.synonym_dict.get(word)
            new_words.append(random.choice(synonyms) if synonyms else word)

        return "".join(new_words)

基于Word2Vec的替换

from gensim.models import Word2Vec

class Word2VecSynonymReplacer:
    """Synonym replacement using nearest neighbours in a word-embedding space."""

    def __init__(self, word2vec_model):
        # Either a trained gensim Word2Vec model or its KeyedVectors (`.wv`)
        self.model = word2vec_model

    def get_synonyms(self, word, topn=5):
        """Return up to *topn* nearest-neighbour words, or [] if *word* is unknown."""
        # A full Word2Vec model keeps its vectors on `.wv`; a bare KeyedVectors
        # object exposes most_similar directly.
        vectors = getattr(self.model, "wv", self.model)
        try:
            similar_words = vectors.most_similar(word, topn=topn)
        except KeyError:  # out-of-vocabulary word
            return []
        return [w for w, _ in similar_words]

    def augment(self, text, replacement_ratio=0.2, num_variations=3):
        """Generate up to *num_variations* distinct embedding-neighbour variants.

        Each word token is replaced with probability *replacement_ratio* by one
        of its 3 nearest neighbours. Whitespace and punctuation tokens are left
        alone. Variants that equal *text* or repeat an earlier variant are
        dropped.
        """
        words = list(jieba.cut(text))
        synonym_cache = {}  # avoid repeated most_similar calls for the same token
        variations = []

        for _ in range(num_variations):
            new_words = list(words)

            for i, word in enumerate(words):
                # Tokens without any letter/digit (spaces, punctuation) carry no meaning worth swapping
                if not any(ch.isalnum() for ch in word):
                    continue
                if random.random() >= replacement_ratio:
                    continue
                if word not in synonym_cache:
                    synonym_cache[word] = self.get_synonyms(word, topn=3)
                synonyms = synonym_cache[word]
                if synonyms:
                    new_words[i] = random.choice(synonyms)

            new_text = "".join(new_words)
            if new_text != text and new_text not in variations:
                variations.append(new_text)

        return variations

指令增强

指令改写

class InstructionRephraser:
    """Instruction paraphrasing via fixed templates and keyword substitution."""

    # Polite/request lead-ins removed before re-templating an instruction.
    _POLITE_PREFIXES = ("请", "帮我", "能否", "可以", "我想", "我要")
    # Trailing characters stripped from the core (question marks, full stops, the particle 吗).
    _TRAILING_CHARS = "？?。.！!吗"

    def __init__(self):
        self.rephrase_templates = [
            "请解释{}",
            "描述一下{}",
            "什么是{}",
            "如何理解{}",
            "请介绍{}",
            "能否说明{}",
            "帮我解释{}",
            "详细说明{}"
        ]

        self.paraphrase_patterns = [
            ("什么是", "请解释"),
            ("如何", "怎样"),
            ("为什么", "为何"),
            ("请", "帮我"),
            ("可以", "能否")
        ]

    def augment(self, instruction, num_variations=5):
        """Return up to *num_variations* distinct rephrasings of *instruction*.

        Template rephrasings of the extracted core come first, followed by
        single keyword substitutions. Results equal to *instruction*, and
        duplicates, are skipped. An empty or blank instruction yields [].
        """
        if num_variations <= 0 or not instruction or not instruction.strip():
            return []

        variations = []
        core_content = self.extract_core_content(instruction)

        for template in self.rephrase_templates:
            if len(variations) >= num_variations:
                break
            variation = template.format(core_content)
            if variation != instruction and variation not in variations:
                variations.append(variation)

        for pattern, replacement in self.paraphrase_patterns:
            if pattern in instruction:
                variation = instruction.replace(pattern, replacement, 1)
                if variation != instruction and variation not in variations:
                    variations.append(variation)

        return variations[:num_variations]

    def extract_core_content(self, instruction):
        """Strip request phrasing from *instruction* and return the topic.

        Trailing punctuation or "吗" is removed. Then, repeatedly, the longest
        matching leading phrase is removed, drawn from the template lead-ins
        (e.g. "请解释", "什么是") and the polite prefixes. This keeps
        re-templating from producing doubled verbs such as "请解释解释X".
        Bare verbs like "解释" are intentionally NOT stripped, since they
        begin compound words such as "解释器". A prefix is never stripped if
        that would leave nothing. If only punctuation remains, *instruction*
        is returned unchanged.
        """
        core = instruction.strip().rstrip(self._TRAILING_CHARS)
        if not core:
            return instruction

        lead_ins = {template.split("{}", 1)[0] for template in self.rephrase_templates}
        lead_ins.update(self._POLITE_PREFIXES)
        lead_ins.discard("")
        prefixes = sorted(lead_ins, key=len, reverse=True)  # longest match first

        stripped = True
        while stripped:
            stripped = False
            for prefix in prefixes:
                if core.startswith(prefix) and len(core) > len(prefix):
                    core = core[len(prefix):]
                    stripped = True
                    break

        return core

指令扩展

class InstructionExpander:
    """Instruction expansion: prepend a context clause to an instruction."""

    def __init__(self):
        # Each template is formatted with (context, instruction) in that order.
        self.expansion_templates = [
            "在{}的背景下,{}",
            "对于{}领域,{}",
            "从{}的角度来看,{}",
            "考虑到{},{}",
            "在{}的情况下,{}"
        ]

        self.contexts = [
            "实际应用",
            "理论研究",
            "教学场景",
            "工业实践",
            "学术研究"
        ]

    def augment(self, instruction, num_variations=3):
        """Return up to *num_variations* distinct context-expanded instructions.

        (template, context) pairs are sampled without replacement, so results
        never repeat. At most ``len(expansion_templates) * len(contexts)``
        results can be produced.
        """
        combinations = [
            (template, context)
            for template in self.expansion_templates
            for context in self.contexts
        ]
        count = max(0, min(num_variations, len(combinations)))
        return [
            template.format(context, instruction)
            for template, context in random.sample(combinations, count)
        ]

回响式增强

输入-输出对生成

class EchoAugmenter:
    """Echo-style augmentation: ask an LLM to paraphrase question/answer pairs.

    ``llm_client`` is only used through ``llm_client.generate(prompt)``, which
    is expected to return text (or None).
    """

    def __init__(self, llm_client):
        self.llm = llm_client

    def augment_qa(self, question, answer, num_variations=3):
        """Return *num_variations* paraphrased {"question", "answer"} dicts.

        The answer paraphrase is conditioned on the original question so that
        the rewritten answer still answers it.
        """
        variations = []

        for _ in range(num_variations):
            similar_question = self.generate_similar_question(question)
            similar_answer = self.generate_similar_answer(answer, question=question)
            variations.append({
                "question": similar_question,
                "answer": similar_answer
            })

        return variations

    def generate_similar_question(self, question):
        """Ask the LLM for a differently worded version of *question*."""
        prompt = (
            "请生成一个与以下问题相似但表述不同的问题:\n\n"
            f"原问题:{question}\n\n"
            "相似问题:"
        )
        return self._ask(prompt)

    def generate_similar_answer(self, answer, question=None):
        """Ask the LLM to restate *answer*.

        Args:
            answer: the answer text to paraphrase.
            question: optional original question. When given, it is included
                in the prompt so the paraphrase stays a valid answer to it.
        """
        if question:
            prompt = (
                "请用不同的方式表达以下答案的含义,并确保它仍然能够准确回答原问题:\n\n"
                f"原问题:{question}\n\n"
                f"原答案:{answer}\n\n"
                "重新表述:"
            )
        else:
            prompt = (
                "请用不同的方式表达以下答案的含义:\n\n"
                f"原答案:{answer}\n\n"
                "重新表述:"
            )
        return self._ask(prompt)

    def _ask(self, prompt):
        """Send *prompt* to the LLM; return the stripped reply ("" if None)."""
        response = self.llm.generate(prompt)
        return (response or "").strip()

数据增强管道

完整增强管道

class DataAugmentationPipeline:
    """End-to-end pipeline combining back-translation, synonym replacement and
    instruction rephrasing.

    Samples are dicts. The text augmenters read ``sample["text"]`` and the
    instruction rephraser reads ``sample["instruction"]``, each only when present.
    """

    def __init__(self):
        self.back_translator = BackTranslationAugmenter()
        self.synonym_replacer = SynonymReplacer()
        self.instruction_rephraser = InstructionRephraser()

    def augment_sample(self, sample, augmentation_factor=3):
        """Return augmented copies of a single sample.

        Each copy keeps every field of *sample* (including any ``label``), has
        one field rewritten, and records the technique in
        ``augmentation_method``.

        Args:
            sample: dict with optional ``"text"`` and ``"instruction"`` fields.
            augmentation_factor: maximum number of copies to return.

        Returns:
            A list of at most *augmentation_factor* new dicts.
        """
        augmented = []

        text = sample.get("text")
        if text:
            text_augmenters = (
                ("back_translation", self.back_translator),
                ("synonym_replacement", self.synonym_replacer),
            )
            for method, augmenter in text_augmenters:
                for new_text in augmenter.augment(text, num_variations=1):
                    augmented.append(self._derive(sample, "text", new_text, method))

        instruction = sample.get("instruction")
        if instruction:
            for new_instruction in self.instruction_rephraser.augment(
                instruction, num_variations=1
            ):
                augmented.append(
                    self._derive(sample, "instruction", new_instruction, "instruction_rephrasing")
                )

        return augmented[:max(augmentation_factor, 0)]

    def augment_dataset(self, dataset, augmentation_factor=3):
        """Return the original samples followed by their augmented copies.

        Args:
            dataset: iterable of sample dicts. It is materialised once, so
                generators work.
            augmentation_factor: maximum number of augmented copies per sample.
        """
        originals = list(dataset)
        augmented_dataset = list(originals)

        for sample in originals:
            augmented_dataset.extend(self.augment_sample(sample, augmentation_factor))

        return augmented_dataset

    def balance_classes(self, dataset, target_count=1000):
        """Oversample minority classes with augmented copies up to *target_count*.

        Samples without a ``label`` are grouped under "unknown" and are left
        unlabelled. A class already at or above *target_count* is untouched.
        If the augmenters produce nothing for a class, that class stays below
        target rather than looping forever.
        """
        from collections import Counter

        samples = list(dataset)
        class_counts = Counter(self._label_of(sample) for sample in samples)
        augmented_dataset = list(samples)

        for class_label, count in class_counts.items():
            needed = target_count - count
            if needed <= 0:
                continue
            # Use the same key function as the count, so "unknown" finds unlabeled samples
            class_samples = [s for s in samples if self._label_of(s) == class_label]
            augmented_dataset.extend(self._generate_for_class(class_samples, needed))

        return augmented_dataset

    def _generate_for_class(self, class_samples, needed):
        """Cycle over *class_samples*, collecting up to *needed* augmented copies."""
        generated = []
        while len(generated) < needed:
            before = len(generated)
            for original_sample in class_samples:
                generated.extend(self.augment_sample(original_sample))
                if len(generated) >= needed:
                    break
            if len(generated) == before:  # augmenters yielded nothing: stop
                break
        return generated[:needed]

    @staticmethod
    def _label_of(sample):
        """Class key of a sample; unlabeled samples map to "unknown"."""
        return sample.get("label", "unknown")

    @staticmethod
    def _derive(sample, field, value, method):
        """Copy *sample*, set *field* to *value* and tag the augmentation method."""
        derived = sample.copy()
        derived[field] = value
        derived["augmentation_method"] = method
        return derived

质量控制

增强质量验证

class AugmentationQualityValidator:
    """Heuristic quality checks comparing an augmented text with its original."""

    def __init__(self):
        self.quality_checks = [
            self.check_diversity,
            self.check_coherence,
            self.check_similarity
        ]

    @staticmethod
    def _tokenize(text):
        """Split *text* into tokens: each CJK ideograph alone, other word runs whole.

        Unsegmented Chinese has no spaces, so ``str.split()`` would treat a
        whole sentence as a single token. This tokenizer works for Chinese
        and space-separated languages alike. Punctuation is dropped.
        """
        import re

        return re.findall(r"[\u4e00-\u9fff]|[^\W\u4e00-\u9fff]+", text)

    def validate_augmentation(self, original, augmented, threshold=0.7):
        """Run every check in ``self.quality_checks`` and aggregate the results.

        Args:
            original: source text.
            augmented: augmented text.
            threshold: the mean pass rate must be strictly greater than this.

        Returns:
            dict with ``results`` (check name -> bool), ``validation_score``
            (fraction of checks passed; 0.0 when there are no checks) and
            ``passed``. With the three default checks and threshold 0.7, all
            three checks must pass.
        """
        results = {
            check.__name__: check(original, augmented) for check in self.quality_checks
        }
        validation_score = sum(results.values()) / len(results) if results else 0.0

        return {
            "results": results,
            "validation_score": validation_score,
            "passed": validation_score > threshold
        }

    def check_diversity(self, original, augmented):
        """True when the augmented text differs from the original."""
        return original != augmented

    def check_coherence(self, original, augmented):
        """True when 30%-90% (exclusive) of the original's tokens survive."""
        original_tokens = set(self._tokenize(original))
        augmented_tokens = set(self._tokenize(augmented))

        overlap_ratio = len(original_tokens & augmented_tokens) / max(len(original_tokens), 1)

        # Too little overlap: meaning likely lost; too much: barely changed
        return 0.3 < overlap_ratio < 0.9

    def check_similarity(self, original, augmented):
        """True when the TF-IDF cosine similarity lies strictly in (0.3, 0.9)."""
        from sklearn.feature_extraction.text import TfidfVectorizer
        from sklearn.metrics.pairwise import cosine_similarity

        # Reuse the CJK-aware tokenizer; the default token pattern would treat
        # an unsegmented Chinese sentence as one token.
        vectorizer = TfidfVectorizer(tokenizer=self._tokenize, token_pattern=None)
        try:
            tfidf_matrix = vectorizer.fit_transform([original, augmented])
        except ValueError:
            # Raised when neither text yields any token (empty vocabulary)
            return False
        similarity = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0]

        return 0.3 < similarity < 0.9

实践配置

# Data augmentation settings: one sub-dict per technique, each with an on/off
# switch plus that technique's parameters.
augmentation_config = dict(
    back_translation=dict(
        enabled=True,
        intermediate_languages=["zh", "de", "fr"],
        variations_per_sample=2,
    ),
    synonym_replacement=dict(
        enabled=True,
        replacement_ratio=0.2,
        variations_per_sample=2,
    ),
    instruction_rephrasing=dict(
        enabled=True,
        variations_per_sample=3,
    ),
    quality_validation=dict(
        enabled=True,
        min_quality_score=0.7,
    ),
)

总结

LLM数据增强是扩展训练数据的有效方法。文本回译、同义词替换和指令增强提供了多种增强途径。关键是建立质量控制机制,确保增强数据的质量和多样性。