← 返回首页
🧠

数据集策划

📂 llm ⏱ 4 min 749 words

---
title: "数据集策划"
description: "详细介绍数据集策划的方法论,包括数据去重、质量过滤和数据多样性控制技术"
tags: ["数据策划", "去重", "质量过滤", "数据多样性"]
category: "llm"
icon: "🧠"
---

数据集策划

数据集策划概述

数据集策划(Dataset Curation)是指系统性地收集、处理和组织数据,以创建高质量训练数据集的过程。对于LLM而言,数据集策划直接影响模型的性能、泛化能力和安全性。

好的数据集策划应该关注以下三个方面:数据去重(消除完全重复和近似重复的样本)、质量过滤(基于规则、模型和困惑度剔除低质量文本),以及数据多样性控制(保持领域和语言分布的均衡)。

数据去重技术

精确去重

import hashlib
from collections import defaultdict

class ExactDeduplicator:
    """Exact deduplication: drop samples whose text is identical to one seen before.

    The digests of all accepted texts are stored on the instance, so a later
    call to ``deduplicate`` also drops texts accepted by an earlier call.
    """

    def __init__(self):
        # SHA-256 hex digests of every text accepted so far.
        self.seen_hashes = set()

    def compute_hash(self, text):
        """Return the SHA-256 hex digest of ``text`` encoded as UTF-8."""
        encoded = text.encode('utf-8')
        return hashlib.sha256(encoded).hexdigest()

    def deduplicate(self, dataset):
        """Keep the first occurrence of each distinct text.

        Args:
            dataset: Iterable of mappings that each provide a ``"text"`` entry.

        Returns:
            List of the samples whose text had not been seen before, in input
            order. The number of dropped samples is printed to stdout.
        """
        kept = []
        dropped = 0

        for record in dataset:
            digest = self.compute_hash(record["text"])
            if digest in self.seen_hashes:
                dropped += 1
                continue
            self.seen_hashes.add(digest)
            kept.append(record)

        print(f"移除 {dropped} 个重复样本")
        return kept

模糊去重

from datasketch import MinHash, MinHashLSH

class FuzzyDeduplicator:
    """Fuzzy deduplication: near-duplicate detection with MinHash + LSH.

    Each text is reduced to a MinHash signature over its lower-cased,
    whitespace-separated tokens. A sample is kept only if the LSH index holds
    no earlier signature that it reports as similar at ``threshold``.

    The index lives on the instance, so later ``deduplicate`` calls are also
    checked against samples kept by earlier calls.
    """

    def __init__(self, threshold=0.8, num_perm=128):
        """Create an empty LSH index.

        Args:
            threshold: Similarity threshold handed to ``MinHashLSH``.
            num_perm: Number of permutations used for every signature.
        """
        self.threshold = threshold
        self.num_perm = num_perm
        self.lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
        # Signatures of kept samples, keyed by the same integer used (as a
        # string) for the LSH key.
        self.minhashes = {}
        # Count of samples consumed by previous deduplicate() calls. Added to
        # the per-call index so that LSH keys never repeat across calls.
        self._key_offset = 0

    def get_minhash(self, text, sample_size=1000):
        """Compute the MinHash signature of ``text``.

        Args:
            text: Input string; tokenised with ``str.lower().split()``.
            sample_size: Maximum number of distinct tokens fed to the
                signature.

        Returns:
            A ``MinHash`` built with ``self.num_perm`` permutations.
        """
        import heapq
        import zlib

        m = MinHash(num_perm=self.num_perm)
        # A MinHash signature depends only on the set of updated values, so
        # duplicates can be dropped up front.
        tokens = set(text.lower().split())

        if len(tokens) > sample_size:
            # Deterministic "bottom-k" sampling: keep the tokens with the
            # smallest CRC32. Unlike random sampling, identical texts always
            # yield identical signatures, and similar texts tend to keep the
            # same tokens. The token itself breaks CRC ties.
            tokens = heapq.nsmallest(
                sample_size,
                tokens,
                key=lambda tok: (zlib.crc32(tok.encode('utf-8')), tok),
            )

        for token in tokens:
            m.update(token.encode('utf-8'))

        return m

    def deduplicate(self, dataset):
        """Return the samples that have no near-duplicate already indexed.

        Args:
            dataset: Iterable of mappings that each provide a ``"text"`` entry.

        Returns:
            List of kept samples in input order. Kept samples are also added
            to ``self.lsh`` and ``self.minhashes``.
        """
        unique_samples = []
        base = self._key_offset

        for idx, sample in enumerate(dataset):
            # On the first call key == idx; on later calls keys continue after
            # the ones already used, avoiding duplicate-key insertions.
            key = base + idx
            self._key_offset = key + 1

            minhash = self.get_minhash(sample["text"])

            # Any hit means a similar document is already in the index.
            if self.lsh.query(minhash):
                continue

            self.lsh.insert(str(key), minhash)
            self.minhashes[key] = minhash
            unique_samples.append(sample)

        return unique_samples

N-gram去重

from collections import Counter

class NGramDeduplicator:
    """N-gram repetition filter: drop texts dominated by repeated word n-grams."""

    def __init__(self, n=5, threshold=0.8):
        """Configure the filter.

        Args:
            n: Number of whitespace-separated words per n-gram.
            threshold: Texts whose repetition ratio is at or above this value
                are removed by ``filter_repetitive``.
        """
        self.n = n
        self.threshold = threshold

    def get_ngrams(self, text):
        """Return all word n-grams of ``text`` as a list of tuples.

        Texts with fewer than ``n`` words yield an empty list.
        """
        words = text.split()
        return [tuple(words[i:i + self.n]) for i in range(len(words) - self.n + 1)]

    def calculate_repetition_ratio(self, text):
        """Return the fraction of n-grams that repeat an earlier n-gram.

        Every occurrence of an n-gram after its first counts as repeated, so
        the result lies in ``[0, 1)``. Returns ``0.0`` when the text has no
        n-grams at all.
        """
        ngrams = self.get_ngrams(text)
        if not ngrams:
            return 0.0

        # Counts must come from the Counter; the n-gram list itself has no
        # .values().
        ngram_counts = Counter(ngrams)
        repeated = sum(count - 1 for count in ngram_counts.values() if count > 1)

        return repeated / len(ngrams)

    def filter_repetitive(self, dataset):
        """Keep the samples whose repetition ratio is below ``threshold``.

        Args:
            dataset: Iterable of mappings that each provide a ``"text"`` entry.

        Returns:
            List of retained samples in input order.
        """
        return [
            sample
            for sample in dataset
            if self.calculate_repetition_ratio(sample["text"]) < self.threshold
        ]

质量过滤

基于规则的过滤

import re

class RuleBasedFilter:
    """Rule-based quality filter.

    A sample passes only if every predicate in ``self.rules`` returns a truthy
    value for its text. Rules run in list order and evaluation stops at the
    first failing rule.
    """

    def __init__(self):
        self.rules = [
            self.check_length,
            self.check_encoding,
            self.check_language,
            self.check_special_chars,
            self.check_repetition,
        ]

    def check_length(self, text):
        """Accept texts with 10 to 10000 whitespace-separated words (inclusive)."""
        # NOTE(review): words are split on whitespace, so unsegmented Chinese
        # text counts as very few words -- confirm this is intended.
        word_count = len(text.split())
        return not (word_count < 10 or word_count > 10000)

    def check_encoding(self, text):
        """Accept texts that can be encoded as UTF-8."""
        try:
            text.encode('utf-8')
        except UnicodeEncodeError:
            return False
        else:
            return True

    def check_language(self, text):
        """Accept texts where Chinese or English letters form a strict majority.

        Only CJK characters in U+4E00..U+9FFF and ASCII letters are counted.
        Texts containing neither are rejected.
        """
        counts = (
            len(re.findall(r'[\u4e00-\u9fff]', text)),
            len(re.findall(r'[a-zA-Z]', text)),
        )
        total = sum(counts)

        if total == 0:
            return False

        # The larger share is never below one half, so this only fails on an
        # exact tie between the two scripts.
        return max(counts) / total > 0.5

    def check_special_chars(self, text):
        """Accept non-empty texts with under 30% non-word, non-space characters."""
        special_count = len(re.findall(r'[^\w\s]', text))
        length = len(text)

        if length == 0:
            return False

        return special_count / length < 0.3

    def check_repetition(self, text):
        """Accept texts where under 10% of words repeat the word before them.

        Texts shorter than 10 words always pass.
        """
        tokens = text.split()
        if len(tokens) < 10:
            return True

        # Count positions whose word equals its immediate predecessor.
        adjacent_equal = sum(
            1 for previous, current in zip(tokens, tokens[1:]) if previous == current
        )
        return adjacent_equal / len(tokens) < 0.1

    def filter(self, dataset):
        """Return the samples of ``dataset`` whose ``"text"`` passes every rule.

        Args:
            dataset: Iterable of mappings that each provide a ``"text"`` entry.

        Returns:
            List of passing samples in input order.
        """
        passed = []
        for record in dataset:
            body = record["text"]
            if not all(rule(body) for rule in self.rules):
                continue
            passed.append(record)
        return passed

基于模型的过滤

from transformers import pipeline

class ModelBasedFilter:
    """Model-based quality filter built on a text-classification pipeline.

    A sample is kept when the score of the classifier's first returned
    prediction is at least ``quality_threshold``.
    """

    def __init__(self, quality_threshold=0.5):
        """Load the classifier.

        Args:
            quality_threshold: Minimum score a sample needs to be kept.
        """
        self.quality_threshold = quality_threshold
        # NOTE(review): the checkpoint name suggests a toxicity classifier,
        # not a quality model. If its score measures toxicity, keeping
        # samples with score >= threshold would retain the toxic ones --
        # confirm the intended model and comparison direction.
        self.quality_classifier = pipeline(
            "text-classification",
            model="unitary/toxic-bert"
        )

    def calculate_quality_score(self, text):
        """Return the score of the classifier's first prediction for ``text``.

        Only the first 512 characters are scored. ``truncation=True`` is also
        passed because a character slice does not bound the token count: 512
        characters can tokenize to more tokens than the model accepts once
        special tokens are added.
        """
        result = self.quality_classifier(text[:512], truncation=True)
        return result[0]["score"]

    def filter_low_quality(self, dataset):
        """Keep the samples whose score is at least ``quality_threshold``.

        Args:
            dataset: Iterable of mappings that each provide a ``"text"`` entry.

        Returns:
            List of retained samples in input order.
        """
        return [
            sample
            for sample in dataset
            if self.calculate_quality_score(sample["text"]) >= self.quality_threshold
        ]

基于困惑度的过滤

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

class PerplexityFilter:
    """Perplexity-based quality filter using a pretrained GPT-2 language model.

    Samples whose perplexity under the model exceeds ``max_perplexity`` are
    dropped.
    """

    def __init__(self, max_perplexity=100):
        """Load the ``gpt2`` model and tokenizer and switch to eval mode.

        Args:
            max_perplexity: Largest perplexity a sample may have to be kept.
        """
        self.max_perplexity = max_perplexity
        self.model = GPT2LMHeadModel.from_pretrained('gpt2')
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.model.eval()

    def calculate_perplexity(self, text):
        """Return the perplexity of ``text`` over at most its first 512 tokens.

        Perplexity is ``exp`` of the model's language-modelling loss with the
        input ids used as labels. Texts that give the model nothing to predict
        (empty, or fewer than two tokens) return ``float('inf')`` so that they
        are rejected instead of crashing the model or producing NaN.
        """
        if not text:
            return float('inf')

        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=512)

        # The loss is computed on next-token predictions, which needs at
        # least two tokens.
        if inputs['input_ids'].shape[-1] < 2:
            return float('inf')

        with torch.no_grad():
            outputs = self.model(**inputs, labels=inputs['input_ids'])
            loss = outputs.loss

        return torch.exp(loss).item()

    def filter_by_perplexity(self, dataset):
        """Keep the samples whose perplexity is at most ``max_perplexity``.

        Args:
            dataset: Iterable of mappings that each provide a ``"text"`` entry.

        Returns:
            List of retained samples in input order.
        """
        return [
            sample
            for sample in dataset
            if self.calculate_perplexity(sample["text"]) <= self.max_perplexity
        ]

数据多样性控制

领域多样性

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans

class DomainDiversityController:
    """Domain diversity control via TF-IDF features and k-means clustering.

    Clusters stand in for domains; ``balance_domains`` caps how many samples
    each cluster may contribute.
    """

    def __init__(self, num_clusters=10):
        """Configure the controller.

        Args:
            num_clusters: Upper bound on the number of k-means clusters.
        """
        self.num_clusters = num_clusters
        self.vectorizer = TfidfVectorizer(max_features=10000)

    def cluster_dataset(self, dataset):
        """Assign every sample to a cluster.

        Only the first 1000 characters of each text are vectorised.

        Args:
            dataset: Iterable of mappings that each provide a ``"text"`` entry.

        Returns:
            One cluster id per sample, in input order. An empty dataset yields
            an empty list without fitting anything.
        """
        texts = [sample["text"][:1000] for sample in dataset]
        if not texts:
            # Nothing to vectorise or cluster.
            return []

        tfidf_matrix = self.vectorizer.fit_transform(texts)

        # k-means cannot form more clusters than there are samples, so small
        # datasets use fewer clusters instead of raising.
        n_clusters = min(self.num_clusters, len(texts))
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        clusters = kmeans.fit_predict(tfidf_matrix)

        return clusters

    def balance_domains(self, dataset, target_samples_per_cluster=1000):
        """Cap the number of samples taken from each cluster.

        Args:
            dataset: Re-iterable collection of samples; it is traversed once
                for clustering and once for selection.
            target_samples_per_cluster: Maximum samples kept per cluster.

        Returns:
            List holding, for each cluster, its first
            ``target_samples_per_cluster`` samples in input order.
        """
        clusters = self.cluster_dataset(dataset)

        balanced_dataset = []
        cluster_counts = {}

        for sample, cluster_id in zip(dataset, clusters):
            seen = cluster_counts.get(cluster_id, 0) + 1
            cluster_counts[cluster_id] = seen

            if seen <= target_samples_per_cluster:
                balanced_dataset.append(sample)

        return balanced_dataset

语言多样性

class LanguageDiversityController:
    """Language diversity control: sample a dataset toward a target language mix."""

    def __init__(self, target_distribution=None):
        """Store the target mix.

        Args:
            target_distribution: Mapping from language label (``"zh"``,
                ``"en"``, ``"other"``) to its share of the output. A falsy
                value selects the default 40% / 40% / 20% split.
        """
        if target_distribution:
            self.target_distribution = target_distribution
        else:
            self.target_distribution = {
                "zh": 0.4,
                "en": 0.4,
                "other": 0.2
            }

    def detect_language(self, text):
        """Classify ``text`` as ``"zh"``, ``"en"`` or ``"other"``.

        Counts CJK characters in U+4E00..U+9FFF against ASCII letters; the
        script holding a strict majority wins. Texts with neither script, or
        with an exact tie, are ``"other"``.
        """
        han_count = len(re.findall(r'[\u4e00-\u9fff]', text))
        latin_count = len(re.findall(r'[a-zA-Z]', text))
        letters = han_count + latin_count

        if letters == 0:
            return "other"
        if han_count / letters > 0.5:
            return "zh"
        if latin_count / letters > 0.5:
            return "en"
        return "other"

    def balance_languages(self, dataset, total_samples=10000):
        """Pick samples per language according to ``target_distribution``.

        Args:
            dataset: Iterable of mappings that each provide a ``"text"`` entry.
            total_samples: Output size the per-language quotas are derived from.

        Returns:
            List with, for each language in ``target_distribution`` order, up
            to ``int(total_samples * share)`` samples in input order. The list
            is shorter than ``total_samples`` when a language has too few
            samples.

        Raises:
            KeyError: If ``target_distribution`` names a language other than
                ``"zh"``, ``"en"`` or ``"other"``.
        """
        buckets = {"zh": [], "en": [], "other": []}
        for record in dataset:
            label = self.detect_language(record["text"])
            buckets[label].append(record)

        selected = []
        for lang in self.target_distribution:
            quota = int(total_samples * self.target_distribution[lang])
            selected += buckets[lang][:quota]

        return selected

数据集评估

def evaluate_dataset_quality(dataset):
    """Collect summary quality metrics for ``dataset``.

    Args:
        dataset: Sized, re-iterable collection of mappings that each provide a
            ``"text"`` entry.

    Returns:
        Dict with keys ``"size"`` (sample count), ``"avg_length"`` (mean
        number of whitespace-separated words per sample, ``0.0`` for an empty
        dataset), and ``"diversity"``, ``"quality"`` and ``"balance"`` as
        returned by the corresponding module-level scoring functions.
    """
    size = len(dataset)
    total_words = sum(len(s["text"].split()) for s in dataset)

    # NOTE(review): calculate_quality_score and calculate_balance_score are
    # not defined in the visible part of this file -- confirm they exist at
    # module level and accept the whole dataset.
    metrics = {
        "size": size,
        # Guard the mean so an empty dataset does not divide by zero here.
        "avg_length": total_words / size if size else 0.0,
        "diversity": calculate_diversity_score(dataset),
        "quality": calculate_quality_score(dataset),
        "balance": calculate_balance_score(dataset)
    }
    return metrics

def calculate_diversity_score(dataset):
    """Score lexical diversity as distinct word trigrams per sample.

    The number of distinct whitespace-token trigrams across the whole dataset
    is divided by ``len(dataset) * 100``.

    Args:
        dataset: Sized, iterable collection of mappings that each provide a
            ``"text"`` entry.

    Returns:
        The diversity score as a float; ``0.0`` for an empty dataset.
    """
    if not dataset:
        # No samples: avoid dividing by zero and report no diversity.
        return 0.0

    all_ngrams = set()
    for sample in dataset:
        words = sample["text"].split()
        # Texts with fewer than three words contribute no trigrams.
        all_ngrams.update(tuple(words[i:i + 3]) for i in range(len(words) - 2))

    return len(all_ngrams) / (len(dataset) * 100)

总结

数据集策划是构建高质量LLM的关键步骤。通过精确和模糊去重消除冗余,使用规则和模型过滤低质量数据,以及控制领域和语言多样性,可以创建出高质量的训练数据集。