← 返回首页
🧠

LLM训练数据清洗技术

📂 llm ⏱ 4 min 674 words

---
title: "LLM训练数据清洗技术"
description: "深入掌握训练数据清洗的各种技术,包括文本清洗、质量过滤和去重策略"
tags: ["数据清洗", "文本处理", "质量过滤", "去重"]
category: "llm"
icon: "🧠"
---

LLM训练数据清洗技术

数据清洗的重要性

训练数据质量直接影响模型性能。"Garbage in, garbage out"在LLM领域尤为明显。研究表明,高质量的小数据集往往比低质量的大数据集效果更好。

数据清洗的主要目标是:去除噪声与无关内容(文本清洗)、剔除低质量样本(质量过滤)、消除重复数据(去重)。

文本清洗

基础清洗

import re
import unicodedata
from html import unescape

def basic_clean(text):
    """Normalize raw (possibly HTML) text into one clean, single-spaced string.

    Steps, in order:
      1. Replace HTML tags with a space.
      2. Decode HTML entities.
      3. Apply Unicode NFKC normalization.
      4. Drop characters in the Unicode ``C*`` categories (control, format,
         ...) while keeping whitespace, so word boundaries survive.
      5. Collapse every whitespace run to one space and trim both ends.

    Args:
        text: Input string.

    Returns:
        The cleaned string.
    """
    # Strip tags BEFORE decoding entities. Decoding first turns escaped
    # literals such as "1 &lt; 2 and 3 &gt; 2" into "< ... >", which the tag
    # pattern would then delete as if it were markup.
    text = re.sub(r'<[^>]+>', ' ', text)
    text = unescape(text)

    # Normalize after decoding so characters produced by entities
    # (e.g. the no-break space from "&nbsp;") are normalized as well.
    text = unicodedata.normalize('NFKC', text)

    # Remove control characters before collapsing whitespace; doing it
    # afterwards leaves double spaces where a control character sat between
    # two spaces. Whitespace such as "\n" and "\t" is itself category Cc, so
    # it is explicitly kept here and folded into a single space below.
    text = ''.join(
        char for char in text
        if char.isspace() or unicodedata.category(char)[0] != 'C'
    )

    text = re.sub(r'\s+', ' ', text)
    return text.strip()

高级清洗

def advanced_clean(text):
    """Strip contact/identifier noise and normalize punctuation.

    Removes URLs, e-mail addresses, phone-number-like digit groups and
    IPv4-like dotted quads, collapses repeated ``.``, ``!`` or ``?`` into a
    single character, and converts typographic (curly) quotes to ASCII quotes.
    Removed spans are replaced by the empty string, so surrounding spaces are
    left untouched.

    Args:
        text: Input string.

    Returns:
        The cleaned string with leading/trailing whitespace stripped.
    """
    # URLs
    text = re.sub(r'https?://\S+|www\.\S+', '', text)

    # E-mail addresses
    text = re.sub(r'\S+@\S+\.\S+', '', text)

    # Phone numbers (3 digits, 3-4 digits, 4 digits, optional - or . between)
    text = re.sub(r'\b\d{3}[-.]?\d{3,4}[-.]?\d{4}\b', '', text)

    # IPv4-like addresses
    text = re.sub(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', '', text)

    # Collapse runs of the same sentence punctuation ("!!!" -> "!")
    text = re.sub(r'([.!?])\1+', r'\1', text)

    # Normalize typographic quotes. The code points are spelled as escapes so
    # the pattern cannot silently degrade into ASCII-only (a no-op) if the
    # file is re-encoded or the quotes are "straightened" by an editor.
    text = re.sub(r'[\u201c\u201d\u201e\u201f]', '"', text)
    text = re.sub(r'[\u2018\u2019\u201a\u201b]', "'", text)

    return text.strip()

语言检测与过滤

from langdetect import detect, DetectorFactory
DetectorFactory.seed = 0  # Fix the seed so detection results are reproducible

def detect_language(text, sample_size=1000):
    """Detect the language of ``text`` from its leading characters.

    Args:
        text: Input string.
        sample_size: Number of leading characters passed to the detector.

    Returns:
        Whatever ``detect`` returns for the sample, or ``"unknown"`` when the
        sample is empty/whitespace-only or detection fails.
    """
    sample = text[:sample_size]

    # Nothing to detect on; skip the detector instead of relying on it to fail.
    if not sample.strip():
        return "unknown"

    try:
        return detect(sample)
    except Exception:
        # Best-effort: any detector failure maps to "unknown". Unlike a bare
        # ``except:``, this lets KeyboardInterrupt/SystemExit propagate so a
        # long-running cleaning job can still be interrupted.
        return "unknown"

def filter_by_language(dataset, target_lang="en", threshold=0.9):
    """Keep only the examples whose detected language equals ``target_lang``.

    Args:
        dataset: Object exposing ``filter(predicate)``; each example passed to
            the predicate is indexed with ``["text"]``.
        target_lang: Language code compared against ``detect_language``'s
            result.
        threshold: Accepted for interface compatibility.
            NOTE(review): this value is not read anywhere in the function, so
            no confidence cut-off is currently applied.

    Returns:
        The result of ``dataset.filter(...)``.
    """
    return dataset.filter(
        lambda example: detect_language(example["text"]) == target_lang
    )

质量过滤

规则过滤

def rule_based_filter(text):
    """Return True when ``text`` passes every heuristic quality rule.

    Rules (any failure rejects the text):
      * length between 50 and 100000 characters inclusive;
      * at least 50% alphabetic characters (rejects gibberish);
      * at most 30% digits (rejects data dumps);
      * at most 20% characters that are neither alphanumeric nor whitespace;
      * at least 50% distinct lines;
      * average line length of at least 10 characters.
    """
    total = len(text)
    if not 50 <= total <= 100000:
        return False

    # Single pass over the characters to gather all three tallies.
    alpha_count = digit_count = special_count = 0
    for ch in text:
        if ch.isalpha():
            alpha_count += 1
        if ch.isdigit():
            digit_count += 1
        if not (ch.isalnum() or ch.isspace()):
            special_count += 1

    if alpha_count / total < 0.5:
        return False
    if digit_count / total > 0.3:
        return False
    if special_count / total > 0.2:
        return False

    # str.split always yields at least one element, so the divisor is >= 1.
    rows = text.split('\n')
    row_count = len(rows)

    distinct_share = len(set(rows)) / row_count
    if distinct_share < 0.5:
        return False

    mean_row_length = sum(map(len, rows)) / row_count
    return not mean_row_length < 10

模型过滤

from transformers import pipeline

# Text classifier used to filter out low-quality content.
# Created once at module import from the "unitary/toxic-bert" checkpoint.
classifier = pipeline("text-classification", model="unitary/toxic-bert")

def model_based_filter(text, max_toxicity=0.5):
    """Return False when the classifier flags ``text`` as toxic.

    Only the first 512 characters of ``text`` are sent to the module-level
    ``classifier``; the first prediction it returns is inspected.

    Args:
        text: Input string.
        max_toxicity: Score above which a ``"toxic"`` prediction rejects the
            text.

    Returns:
        False if the first prediction is labelled ``"toxic"`` with a score
        greater than ``max_toxicity``; True otherwise.
    """
    top_prediction = classifier(text[:512])[0]
    confidence = top_prediction["score"]
    predicted_label = top_prediction["label"]

    flagged = predicted_label == "toxic" and confidence > max_toxicity
    return not flagged

# Perplexity-based filtering.
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

# Language model and matching tokenizer, both loaded once at module import
# from the "gpt2" checkpoint.
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

def perplexity_filter(text, max_perplexity=1000):
    """Return True when the language-model perplexity of ``text`` is low enough.

    The text is tokenized with the module-level ``tokenizer`` (truncated to
    512 tokens), scored with the module-level ``model`` using the input ids
    as labels, and the perplexity is taken as ``exp(loss)``.

    Args:
        text: Input string.
        max_perplexity: Exclusive upper bound on the accepted perplexity.

    Returns:
        True if the perplexity is strictly below ``max_perplexity``; False
        otherwise, including when the text is too short to be scored.
    """
    # Imported locally: the snippet uses ``torch`` but never imports it at
    # module level, which made the original raise NameError on first call.
    import torch

    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)

    # A next-token loss needs at least one prediction target, i.e. two
    # tokens. Reject shorter inputs explicitly rather than relying on a NaN
    # loss (or an error for zero tokens) to fall through the comparison.
    if inputs["input_ids"].shape[1] < 2:
        return False

    # Scoring only: skip gradient tracking to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs, labels=inputs["input_ids"])

    perplexity = torch.exp(outputs.loss).item()
    return perplexity < max_perplexity

去重策略

精确去重

from hashlib import md5, sha256

def exact_dedup_by_hash(documents):
    """Drop documents whose ``"text"`` is byte-for-byte identical to an earlier one.

    Args:
        documents: Iterable of mappings with a ``"text"`` string.

    Returns:
        A new list holding the first document seen for each distinct text, in
        their original order. The document objects themselves are reused.
    """
    # Dicts keep insertion order, so the values come out in first-seen order.
    first_by_digest = {}
    for document in documents:
        digest = sha256(document["text"].encode()).hexdigest()
        first_by_digest.setdefault(digest, document)
    return list(first_by_digest.values())

def exact_dedup_by_line(documents, min_unique_ratio=0.5):
    """Remove lines already seen in earlier kept documents.

    For each document, lines that appeared in a previously kept document are
    dropped. The document is kept only if the share of its lines that survive
    is strictly greater than ``min_unique_ratio``; otherwise the whole
    document is discarded and its lines are not recorded as seen.

    Lines repeated inside a single document are not removed by this pass,
    because the seen-set is only updated after a document has been processed.

    Args:
        documents: Iterable of mappings with a ``"text"`` string.
        min_unique_ratio: Exclusive lower bound on the fraction of surviving
            lines required to keep a document.

    Returns:
        A list of new dicts. Each carries every field of the source document
        (so ids/metadata are not lost) with ``"text"`` replaced by the
        surviving lines joined with newlines.
    """
    seen_lines = set()
    unique_docs = []

    for doc in documents:
        # str.split always yields at least one element, so the ratio below
        # never divides by zero.
        lines = doc["text"].split('\n')
        fresh_lines = [line for line in lines if line not in seen_lines]

        if len(fresh_lines) / len(lines) > min_unique_ratio:
            seen_lines.update(fresh_lines)
            # Copy the document so other fields are preserved and the
            # caller's object is not mutated.
            unique_docs.append({**doc, "text": '\n'.join(fresh_lines)})

    return unique_docs

模糊去重

from datasketch import MinHash, MinHashLSH

def fuzzy_dedup(documents, threshold=0.8, num_perm=128, ngram_size=3):
    """Drop near-duplicate documents using MinHash signatures and LSH.

    Each document is shingled into word n-grams, hashed into a MinHash, and
    looked up in an LSH index. A document is kept, and inserted into the
    index, only when the lookup returns no candidates.

    Args:
        documents: Iterable of mappings with a ``"text"`` string.
        threshold: Similarity threshold passed to ``MinHashLSH``.
        num_perm: Number of permutations for both ``MinHash`` and
            ``MinHashLSH``.
        ngram_size: Number of whitespace-separated words per shingle.

    Returns:
        List of the kept documents, in input order.

    Raises:
        ValueError: If ``ngram_size`` is smaller than 1.
    """
    if ngram_size < 1:
        raise ValueError("ngram_size must be >= 1")

    lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
    unique_docs = []

    for i, doc in enumerate(documents):
        words = doc["text"].split()

        if len(words) >= ngram_size:
            shingles = [
                ' '.join(words[j:j + ngram_size])
                for j in range(len(words) - ngram_size + 1)
            ]
        else:
            # Too short for a full n-gram: hash the whole text as a single
            # shingle. Without this the MinHash stays empty, every short
            # document gets the same signature, and all but the first are
            # wrongly discarded as duplicates of each other.
            shingles = [' '.join(words)]

        m = MinHash(num_perm=num_perm)
        for shingle in shingles:
            m.update(shingle.encode('utf8'))

        # Keep the document only if the index holds no similar candidate.
        if not lsh.query(m):
            lsh.insert(str(i), m)
            unique_docs.append(doc)

    return unique_docs

Suffix Array去重

def suffix_array_dedup(text, min_match=50):
    """Placeholder for suffix-array based removal of repeated substrings.

    Builds a suffix array and the longest-common-prefix (LCP) length between
    neighbouring suffixes, which is enough to tell whether ``text`` contains a
    repeated substring longer than ``min_match``.

    NOTE: the removal step itself is not implemented; the function currently
    returns ``text`` unchanged in every case. The suffix sort compares full
    suffix slices, so it is only suitable for short inputs.

    Args:
        text: Input string.
        min_match: Repeats must be strictly longer than this to be considered.

    Returns:
        ``text``, unmodified.
    """
    length = len(text)

    # Suffix start positions ordered by the suffix they begin.
    suffixes = sorted(range(length), key=lambda i: text[i:])

    # lcp[k] = length of the common prefix of the k-th and (k-1)-th suffixes.
    lcp = [0] * length
    for rank in range(1, length):
        first, second = suffixes[rank - 1], suffixes[rank]
        limit = length - max(first, second)
        shared = 0
        while shared < limit and text[first + shared] == text[second + shared]:
            shared += 1
        lcp[rank] = shared

    # ``default=0`` keeps empty input from raising ValueError in max().
    max_lcp = max(lcp, default=0)
    if max_lcp > min_match:
        # TODO: locate the repeated spans and drop the later occurrences.
        pass

    return text

特殊内容处理

代码块处理

def process_code_blocks(text):
    """Extract fenced code blocks whose curly braces are balanced.

    A block is a triple-backtick fence with an optional language tag on the
    opening line. Balance is judged only by comparing the number of ``{`` and
    ``}`` characters in the body.

    Args:
        text: Input string.

    Returns:
        List of ``(language, code)`` tuples for the blocks that pass the
        check; ``language`` is ``''`` when the fence has no tag.
    """
    fenced_blocks = re.findall(r'```(\w+)?\n(.*?)```', text, re.DOTALL)
    return [
        (language, body)
        for language, body in fenced_blocks
        if body.count('{') == body.count('}')
    ]

表格处理

def _split_table_row(row):
    """Split one ``| a | b |`` line into its stripped cell strings."""
    row = row.strip()
    # Drop only the single outer pipe on each side so empty edge cells survive.
    if row.startswith('|'):
        row = row[1:]
    if row.endswith('|'):
        row = row[:-1]
    return [cell.strip() for cell in row.split('|')]


def _is_separator_row(cells):
    """Return True for a Markdown header separator such as ``|---|:--:|``."""
    return all(re.fullmatch(r':?-+:?', cell) for cell in cells)


def process_tables(text):
    """Parse pipe-delimited (Markdown-style) tables into lists of row dicts.

    A table is a run of consecutive lines that each contain at least three
    ``|`` characters. The first line of a run supplies the column headers; a
    following ``---`` separator line, if present, is skipped; every remaining
    line becomes one dict mapping header to cell. Empty cells are kept as
    ``''`` so values stay aligned with their headers. Rows in which every
    cell is empty are ignored.

    Args:
        text: Input string.

    Returns:
        One list per table, each holding that table's data rows as dicts.
    """
    row_pattern = re.compile(r'\|.*\|.*\|')

    # Group consecutive table-looking lines. The pattern alone can only match
    # a single line ('.' stops at newlines), so matching it against the whole
    # text would treat every row as its own one-line "table".
    tables = []
    current = []
    for line in text.split('\n'):
        if row_pattern.search(line):
            current.append(line)
        elif current:
            tables.append(current)
            current = []
    if current:
        tables.append(current)

    processed_tables = []
    for rows in tables:
        headers = _split_table_row(rows[0])
        body = rows[1:]
        if body and _is_separator_row(_split_table_row(body[0])):
            body = body[1:]

        data_rows = []
        for row in body:
            cells = _split_table_row(row)
            if any(cells):
                data_rows.append(dict(zip(headers, cells)))
        processed_tables.append(data_rows)

    return processed_tables

数据清洗流水线

class DataCleaningPipeline:
    """Ordered chain of named dataset-to-dataset cleaning steps."""

    def __init__(self):
        # (name, callable) pairs, applied in insertion order by process().
        self.steps = []

    def add_step(self, name, func):
        """Register ``func`` under ``name``; returns ``self`` to allow chaining."""
        step = (name, func)
        self.steps.append(step)
        return self

    def process(self, dataset):
        """Run every step in order, feeding each step's output to the next.

        Prints the step name before running it and the length of its result
        afterwards.

        Returns:
            The output of the last step, or ``dataset`` itself when no steps
            are registered.
        """
        current = dataset
        for step_name, step_func in self.steps:
            print(f"执行: {step_name}")
            current = step_func(current)
            print(f"剩余样本: {len(current)}")
        return current

# Usage example.
# NOTE(review): `Dataset` and `raw_dataset` are not defined in this file —
# presumably `datasets.Dataset` and a dataset loaded by the caller; confirm.
# NOTE(review): this assignment rebinds the name `pipeline`, which an earlier
# snippet imports from `transformers`.
pipeline = (
    DataCleaningPipeline()
    .add_step("基础清洗", lambda ds: ds.map(lambda x: {"text": basic_clean(x["text"])}))
    .add_step("质量过滤", lambda ds: ds.filter(lambda x: rule_based_filter(x["text"])))
    .add_step("语言过滤", lambda ds: filter_by_language(ds, "en"))
    .add_step("去重", lambda ds: Dataset.from_list(fuzzy_dedup(list(ds))))
)

clean_dataset = pipeline.process(raw_dataset)

系统化的数据清洗流程是获得高质量训练数据的关键。