← 返回首页
🧠

LLM训练数据准备

📂 llm ⏱ 3 min 519 words

---
title: "LLM训练数据准备"
description: "掌握大语言模型训练数据的收集、清洗、预处理和格式化全流程"
tags: ["数据准备", "数据收集", "预处理", "训练数据"]
category: "llm"
icon: "🧠"
---

LLM训练数据准备

数据准备概述

高质量的训练数据是大语言模型成功的关键。数据准备包括数据收集、清洗、去重、脱敏和格式化等多个步骤。本文将介绍LLM训练数据准备的完整流程。

数据准备的核心目标是:提升数据质量,去除重复内容与敏感信息,并将数据整理为训练所需的统一格式。

数据收集

开源数据集

from datasets import load_dataset

# Commonly used open-source pretraining corpora (Hugging Face Hub dataset IDs).
# NOTE(review): some of these need a config name to load (e.g. a Wikipedia
# dump date + language) or are script-based/gated on newer `datasets`
# versions -- check each dataset card before calling load_dataset on them.
datasets = {
    "Common Crawl": "allenai/c4",  # C4: cleaned Common Crawl web text
    "Wikipedia": "wikimedia/wikipedia",  # Wikipedia dumps
    "BookCorpus": "bookcorpus",  # book corpus
    "OpenWebText": "openwebtext",  # web pages linked from upvoted Reddit posts
    "The Pile": "EleutherAI/the_pile",  # mixed-source corpus
    "RedPajama": "togethercomputer/RedPajama-Data-1T",  # open LLM pretraining data
}

# Loading example
dataset = load_dataset("openwebtext", split="train")
print(f"数据集大小: {len(dataset)}")

网页爬取

import requests
from bs4 import BeautifulSoup
import trafilatura

def crawl_webpage(url, timeout=10):
    """Fetch a web page and extract its main text and title.

    Best-effort: any failure (network error, non-2xx HTTP status, parse
    error) is logged and reported as ``None`` so a batch crawl can simply
    skip bad URLs.

    Args:
        url: Page URL to fetch.
        timeout: Seconds passed to ``requests.get`` (default 10).

    Returns:
        ``{"url": ..., "title": ..., "content": ...}`` on success, else
        ``None``. ``title`` is always a string ("" when the page has no
        <title>). ``content`` is the return value of ``trafilatura.extract``
        and may be ``None`` when no main text could be extracted.
    """
    import logging

    try:
        response = requests.get(url, timeout=timeout)
        # Do not "extract content" from 404/500 error pages.
        response.raise_for_status()
        # Hand the raw bytes to the parsers so they detect the page's real
        # charset; response.text falls back to ISO-8859-1 when the server
        # omits a charset, which garbles e.g. Chinese pages.
        html = response.content

        # Main-text extraction with trafilatura
        text = trafilatura.extract(html)

        # Metadata: .string is None when <title> contains nested markup,
        # get_text() always returns a str.
        soup = BeautifulSoup(html, 'html.parser')
        title = soup.title.get_text(strip=True) if soup.title else ""

        return {
            "url": url,
            "title": title,
            "content": text
        }
    except Exception:
        # Deliberate best-effort boundary, but leave a trace instead of
        # silently dropping the page.
        logging.getLogger(__name__).warning(
            "Failed to crawl %s", url, exc_info=True
        )
        return None

# Batch crawl: fetch every URL and keep only the pages fetched successfully
urls = ["https://example.com/article1", "https://example.com/article2"]
documents = [page for page in map(crawl_webpage, urls) if page is not None]

数据清洗

文本清洗

import re
import unicodedata

def clean_text(text, keep_newlines=True):
    """Basic text cleaning for pretraining corpora.

    Steps: NFKC normalization, strip HTML tags, remove URLs and e-mail
    addresses, then normalize whitespace and trim both ends.

    Args:
        text: Raw document text.
        keep_newlines: If True (default), collapse only horizontal
            whitespace and keep line breaks (at most one blank line in a
            row), so line-based filters downstream still see the document's
            lines. If False, collapse every whitespace run -- newlines
            included -- into a single space.

    Returns:
        The cleaned, stripped string.
    """
    # 1. Unicode normalization (NFKC also folds full-width letters/digits
    #    to their ASCII forms)
    text = unicodedata.normalize('NFKC', text)

    # 2. Remove HTML tags
    text = re.sub(r'<[^>]+>', '', text)

    # 3. Remove URLs. Only printable ASCII (0x21-0x7E) is consumed, so
    #    unspaced neighbouring text such as Chinese is not swallowed.
    text = re.sub(r'https?://[\x21-\x7e]+', '', text)

    # 4. Remove e-mail addresses (ASCII local part/domain, same reason)
    text = re.sub(r'[A-Za-z0-9._%+-]+@[A-Za-z0-9-]+(?:\.[A-Za-z0-9-]+)*', '', text)

    # 5. Normalize whitespace last, so the gaps left by steps 2-4 are
    #    collapsed as well.
    if keep_newlines:
        text = re.sub(r'[^\S\n]+', ' ', text)   # spaces/tabs/CR -> one space
        text = re.sub(r' *\n *', '\n', text)     # trim around line breaks
        text = re.sub(r'\n{3,}', '\n\n', text)   # at most one blank line
    else:
        text = re.sub(r'\s+', ' ', text)

    return text.strip()

# Apply the cleaning function to the "text" field of every record
dataset = dataset.map(lambda record: dict(text=clean_text(record["text"])))

语言检测

from langdetect import detect
from langdetect import DetectorFactory
from langdetect.lang_detect_exception import LangDetectException

# langdetect's algorithm is randomized; fixing the seed makes the same text
# always receive the same label, so the filtered dataset is reproducible.
DetectorFactory.seed = 0

def filter_by_language(text, target_lang="en"):
    """Return True if *text* is detected as language *target_lang*.

    Args:
        text: Document text.
        target_lang: Language code compared with ``langdetect.detect``'s
            output (e.g. ``"en"``).

    Returns:
        True on a match; False otherwise, including when langdetect cannot
        classify the text at all (e.g. empty text).
    """
    try:
        return detect(text) == target_lang
    except LangDetectException:
        # Raised when the text has no detectable features; treat as a
        # non-match rather than crashing the whole filter pass.
        return False

# Drop documents that are not in the target language
dataset = dataset.filter(lambda x: filter_by_language(x["text"], "en"))

质量过滤

def quality_filter(text, min_length=100, max_length=100000, *,
                   min_alpha_ratio=0.3, min_unique_line_ratio=0.5,
                   max_special_ratio=0.3):
    """Heuristic document-quality filter; return True to keep *text*.

    Checks, in order (the first failing check rejects the text):
      1. Length: ``min_length <= len(text) <= max_length``; empty text is
         always rejected.
      2. Alphabetic-character ratio >= ``min_alpha_ratio`` (``str.isalpha``
         is also True for CJK characters) -- catches garbled text.
      3. Unique-line ratio >= ``min_unique_line_ratio``, computed over
         non-blank, whitespace-stripped lines -- catches line-level
         boilerplate repetition.
      4. Ratio of characters that are neither alphanumeric nor whitespace
         <= ``max_special_ratio``.

    The keyword-only ratio thresholds default to the values that were
    previously hard-coded.
    """
    n_chars = len(text)

    # 1. Length filter. The explicit empty check keeps the ratio divisions
    #    below safe even when a caller passes min_length=0.
    if n_chars == 0 or n_chars < min_length or n_chars > max_length:
        return False

    # 2. Alphabetic ratio (filters garbled text)
    alpha_ratio = sum(c.isalpha() for c in text) / n_chars
    if alpha_ratio < min_alpha_ratio:
        return False

    # 3. Repeated-line filter. Blank lines are ignored: paragraph spacing
    #    would otherwise count as "repeated" lines and reject well-formatted
    #    documents.
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    if lines and len(set(lines)) / len(lines) < min_unique_line_ratio:
        return False

    # 4. Special-character ratio
    special_ratio = sum(
        not c.isalnum() and not c.isspace() for c in text
    ) / n_chars
    if special_ratio > max_special_ratio:
        return False

    return True

dataset = dataset.filter(lambda record: quality_filter(text=record["text"]))

去重处理

精确去重

from hashlib import md5

def exact_dedup(documents):
    """Remove exact duplicates, comparing documents by their "text" field.

    The first occurrence of each text is kept and input order is preserved.
    Texts are tracked by the MD5 hex digest of their UTF-8 encoding rather
    than stored whole; MD5 serves only as a fingerprint here, not for
    security.
    """
    fingerprints = set()
    kept = []
    for document in documents:
        digest = md5(document["text"].encode()).hexdigest()
        if digest in fingerprints:
            continue
        fingerprints.add(digest)
        kept.append(document)
    return kept

模糊去重

from datasketch import MinHash, MinHashLSH

def fuzzy_dedup(documents, threshold=0.8, num_perm=128):
    """Near-duplicate removal using MinHash + LSH over whitespace-split words.

    Documents are processed in order; a document is kept only if the LSH
    index returns no previously kept document as a candidate match for it,
    and kept documents are then added to the index.

    Args:
        documents: Iterable of mappings with a ``"text"`` key.
        threshold: Jaccard-similarity threshold passed to ``MinHashLSH``.
        num_perm: Number of permutations, used for both the LSH index and
            every MinHash so the two always agree (default 128).

    Returns:
        List of the kept documents, in input order.

    Note:
        Each document is reduced to the set of its ``str.split()`` tokens,
        so word order and repetition are ignored, and text without
        whitespace (e.g. Chinese) yields one token per whitespace-free run.
    """
    lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
    unique_docs = []

    for i, doc in enumerate(documents):
        # Build the MinHash signature from the document's word tokens
        m = MinHash(num_perm=num_perm)
        for word in doc["text"].split():
            m.update(word.encode('utf8'))

        # Keep the document only if no similar document is already indexed
        if not lsh.query(m):
            lsh.insert(str(i), m)
            unique_docs.append(doc)

    return unique_docs

敏感信息脱敏

import re

def anonymize_text(text):
    """Replace common PII (mainland-China formats) with placeholder tags.

    Replacements, applied in this order:
      e-mail addresses          -> [EMAIL]
      11-digit mobile numbers   -> [PHONE]
      18-character ID numbers   -> [ID_CARD]
      16-19 digit card numbers  -> [BANK_CARD]

    The digit patterns only match a complete run of digits (not preceded or
    followed by another digit), so e.g. the phone pattern can no longer fire
    inside a longer ID or card number. An 18-digit card number cannot be
    told apart from an ID number and is tagged [ID_CARD].
    """
    # 1. E-mail: ASCII-only local part/domain so that adjacent unspaced
    #    text (e.g. Chinese) is not swallowed into the match
    text = re.sub(r'[A-Za-z0-9._%+-]+@[A-Za-z0-9-]+(?:\.[A-Za-z0-9-]+)+', '[EMAIL]', text)

    # 2. Mobile phone number
    text = re.sub(r'(?<!\d)1[3-9]\d{9}(?!\d)', '[PHONE]', text)

    # 3. ID card number
    text = re.sub(r'(?<!\d)\d{17}[\dXx](?!\d)', '[ID_CARD]', text)

    # 4. Bank card number
    text = re.sub(r'(?<!\d)\d{16,19}(?!\d)', '[BANK_CARD]', text)

    return text

数据格式化

对话格式

{
    "conversations": [
        {"role": "system", "content": "你是一个有帮助的助手"},
        {"role": "user", "content": "什么是机器学习?"},
        {"role": "assistant", "content": "机器学习是人工智能的一个分支..."}
    ]
}

指令格式

def format_instruction(instruction, input_text, output):
    """Render one example in the Alpaca prompt layout.

    Sections are separated by a blank line:
    "### Instruction:", optional "### Input:", then "### Response:".
    The Input section is emitted only when *input_text* is non-empty (not
    None, not whitespace-only), matching Alpaca's separate with-input /
    no-input templates instead of training on an empty Input header.

    Returns:
        The formatted prompt string.
    """
    sections = [f"### Instruction:\n{instruction}"]
    if input_text is not None and str(input_text).strip():
        sections.append(f"### Input:\n{input_text}")
    sections.append(f"### Response:\n{output}")
    return "\n\n".join(sections)

# 转换数据集
def convert_to_alpaca(example):
    """Convert one instruction record into ``{"text": <Alpaca prompt>}``.

    Reads ``example["instruction"]`` and ``example["output"]`` (required)
    and ``example["input"]`` (optional, defaults to an empty string).
    """
    prompt = format_instruction(
        instruction=example["instruction"],
        input_text=example.get("input", ""),
        output=example["output"],
    )
    return {"text": prompt}


dataset = dataset.map(convert_to_alpaca)

数据验证

def validate_dataset(dataset):
    """Compute basic quality statistics for records with a ``"text"`` key.

    Args:
        dataset: Sized iterable of mappings with a ``"text"`` string.

    Returns:
        dict with:
          ``total``           -- ``len(dataset)``
          ``avg_length``      -- mean character length (0 for an empty dataset)
          ``empty_count``     -- texts that are empty or whitespace-only
          ``duplicate_count`` -- texts identical to an earlier text
    """
    stats = {
        "total": len(dataset),
        "avg_length": 0,
        "empty_count": 0,
        "duplicate_count": 0
    }

    # Running totals instead of a per-item length list
    total_chars = 0
    n_items = 0
    seen = set()

    for item in dataset:
        text = item["text"]
        total_chars += len(text)
        n_items += 1

        if not text.strip():
            stats["empty_count"] += 1

        if text in seen:
            stats["duplicate_count"] += 1
        else:
            seen.add(text)

    # Guard: an empty dataset previously raised ZeroDivisionError here.
    if n_items:
        stats["avg_length"] = total_chars / n_items

    return stats

# Validate the final dataset and print a one-line-per-metric summary
stats = validate_dataset(dataset)
print(
    f"总样本数: {stats['total']}",
    f"平均长度: {stats['avg_length']:.0f}",
    f"空文本: {stats['empty_count']}",
    f"重复文本: {stats['duplicate_count']}",
    sep="\n",
)

高质量的数据准备是LLM训练成功的基础,需要投入足够的时间和精力。