数据集策划
---
title: "数据集策划"
description: "详细介绍数据集策划的方法论,包括数据去重、质量过滤和数据多样性控制技术"
tags: ["数据策划", "去重", "质量过滤", "数据多样性"]
category: "llm"
icon: "🧠"
---
数据集策划
数据集策划概述
数据集策划(Dataset Curation)是指系统性地收集、处理和组织数据,以创建高质量训练数据集的过程。对于LLM而言,数据集策划直接影响模型的性能、泛化能力和安全性。
好的数据集策划应该关注:
- 数据质量:确保数据准确、一致、无错误
- 数据多样性:覆盖足够的领域、语言和风格
- 数据平衡:避免某些类别的样本占比过高
- 数据时效性:保持数据的更新和相关性
数据去重技术
精确去重
import hashlib
from collections import defaultdict
class ExactDeduplicator:
    """Exact deduplication: drop samples whose text matches one already seen.

    Two texts count as identical when their SHA-256 digests are equal.
    """

    def __init__(self):
        # Hex digests of every text accepted so far. The set lives on the
        # instance, so it carries over between deduplicate() calls.
        self.seen_hashes = set()

    def compute_hash(self, text):
        """Return the SHA-256 hex digest of ``text`` encoded as UTF-8."""
        hasher = hashlib.sha256()
        hasher.update(text.encode('utf-8'))
        return hasher.hexdigest()

    def deduplicate(self, dataset):
        """Return the samples whose ``"text"`` has not been seen before.

        Args:
            dataset: Iterable of mappings, each with a ``"text"`` string.

        Returns:
            A list with the first occurrence of every distinct text, in
            input order. The number of dropped samples is printed.
        """
        kept = []
        duplicates = 0
        for record in dataset:
            digest = self.compute_hash(record["text"])
            if digest in self.seen_hashes:
                duplicates += 1
                continue
            self.seen_hashes.add(digest)
            kept.append(record)
        print(f"移除 {duplicates} 个重复样本")
        return kept
模糊去重
from datasketch import MinHash, MinHashLSH
class FuzzyDeduplicator:
    """Fuzzy deduplication: approximate matching with MinHash signatures.

    Each text is reduced to a MinHash signature over its lower-cased,
    whitespace-separated words. A sample is dropped when the LSH index
    already holds a signature it considers similar at ``threshold``.
    """

    def __init__(self, threshold=0.8, num_perm=128):
        """Create the LSH index.

        Args:
            threshold: Similarity threshold handed to ``MinHashLSH``.
            num_perm: Number of permutations used by every signature.
        """
        self.threshold = threshold
        self.num_perm = num_perm
        self.lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
        # Signatures of the kept samples, keyed by the sample's position in
        # the stream of everything this instance has processed.
        self.minhashes = {}
        # Samples consumed by earlier deduplicate() calls. Offsetting keys by
        # this count keeps them unique when the instance is reused.
        self._num_seen = 0

    @staticmethod
    def _sample_tokens(words, sample_size):
        """Deterministically pick at most ``sample_size`` distinct words.

        Selection depends only on word content (smallest CRC-32 first), so
        identical texts always yield identical samples and texts sharing
        vocabulary tend to share sampled words. Random sampling has neither
        property and can make exact duplicates look dissimilar.
        """
        import heapq
        import zlib

        unique = set(words)
        if len(unique) <= sample_size:
            return list(unique)
        return heapq.nsmallest(
            sample_size,
            unique,
            # The word itself breaks CRC ties so the order is fully defined.
            key=lambda w: (zlib.crc32(w.encode('utf-8')), w),
        )

    def get_minhash(self, text, sample_size=1000):
        """Compute the MinHash signature of ``text``.

        Args:
            text: Input string; lower-cased and split on whitespace.
            sample_size: Upper bound on the number of distinct words fed
                into the signature, to bound the cost for long texts.

        Returns:
            A ``MinHash`` with ``self.num_perm`` permutations.
        """
        m = MinHash(num_perm=self.num_perm)
        for word in self._sample_tokens(text.lower().split(), sample_size):
            m.update(word.encode('utf-8'))
        return m

    def deduplicate(self, dataset):
        """Return the samples with no similar predecessor in the index.

        Args:
            dataset: Iterable of mappings, each with a ``"text"`` string.

        Returns:
            A list of the kept samples in input order. Kept samples are
            also inserted into ``self.lsh`` and ``self.minhashes``.
        """
        unique_samples = []
        offset = self._num_seen
        processed = 0
        for idx, sample in enumerate(dataset):
            processed = idx + 1
            minhash = self.get_minhash(sample["text"])
            # An empty query result means nothing similar was kept so far.
            if not self.lsh.query(minhash):
                key = offset + idx
                self.lsh.insert(str(key), minhash)
                self.minhashes[key] = minhash
                unique_samples.append(sample)
        self._num_seen = offset + processed
        return unique_samples
N-gram去重
from collections import Counter
class NGramDeduplicator:
    """N-gram repetition filter: drop texts made mostly of repeated n-grams.

    The check is per text: it measures how many of a text's own word
    n-grams are repeats of an earlier n-gram in that same text.
    """

    def __init__(self, n=5, threshold=0.8):
        """Store the n-gram size and the repetition ratio cut-off.

        Args:
            n: Number of whitespace-separated words per n-gram.
            threshold: Texts whose repetition ratio reaches this value are
                removed by ``filter_repetitive``.
        """
        self.n = n
        self.threshold = threshold

    def get_ngrams(self, text):
        """Return all word n-grams of ``text`` as a list of tuples.

        The list is empty when the text has fewer than ``n`` words.
        """
        words = text.split()
        return [tuple(words[i:i + self.n]) for i in range(len(words) - self.n + 1)]

    def calculate_repetition_ratio(self, text):
        """Return the fraction of n-grams that repeat an earlier n-gram.

        Returns ``0.0`` when the text yields no n-grams.
        """
        ngrams = self.get_ngrams(text)
        if not ngrams:
            return 0.0
        # Every occurrence beyond the first of a given n-gram is a repeat,
        # so the repeat count is total occurrences minus distinct n-grams.
        repeated = len(ngrams) - len(set(ngrams))
        return repeated / len(ngrams)

    def filter_repetitive(self, dataset):
        """Return the samples whose repetition ratio is below the threshold.

        Args:
            dataset: Iterable of mappings, each with a ``"text"`` string.
        """
        return [
            sample
            for sample in dataset
            if self.calculate_repetition_ratio(sample["text"]) < self.threshold
        ]
质量过滤
基于规则的过滤
import re
class RuleBasedFilter:
    """Rule-based quality filter.

    A sample is kept only when every rule in ``self.rules`` accepts its
    text. Each rule is a method taking the text and returning a bool.
    """

    def __init__(self):
        self.rules = [
            self.check_length,
            self.check_encoding,
            self.check_language,
            self.check_special_chars,
            self.check_repetition,
        ]

    @staticmethod
    def _count_tokens(text):
        """Count length units: CJK characters plus other whitespace-separated tokens.

        Unsegmented Chinese has no spaces, so a plain ``split()`` would count
        a whole paragraph as one word. Each character in the
        ``\\u4e00-\\u9fff`` range is counted on its own. Text without such
        characters is counted exactly as ``len(text.split())``.
        """
        cjk_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
        other_tokens = re.sub(r'[\u4e00-\u9fff]', ' ', text).split()
        return cjk_chars + len(other_tokens)

    def check_length(self, text):
        """Accept texts of 10 to 10000 length units (see ``_count_tokens``)."""
        return 10 <= self._count_tokens(text) <= 10000

    def check_encoding(self, text):
        """Accept texts that can be encoded as UTF-8."""
        try:
            text.encode('utf-8')
        except UnicodeEncodeError:
            return False
        return True

    def check_language(self, text):
        """Check that either Chinese or ASCII letters dominate the text.

        Only characters in ``\\u4e00-\\u9fff`` and ``a-zA-Z`` are counted.
        The text is rejected when it has none of them, or when the two
        counts are exactly equal (the larger share is then not above 50%).
        """
        chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
        english_chars = len(re.findall(r'[a-zA-Z]', text))
        total = chinese_chars + english_chars
        if total == 0:
            return False
        primary_ratio = max(chinese_chars, english_chars) / total
        return primary_ratio > 0.5

    def check_special_chars(self, text):
        """Accept texts where under 30% of characters are special.

        A special character is one that is neither a word character nor
        whitespace. Empty text is rejected.
        """
        total_chars = len(text)
        if total_chars == 0:
            return False
        special_chars = len(re.findall(r'[^\w\s]', text))
        return special_chars / total_chars < 0.3

    def check_repetition(self, text):
        """Reject texts where 10% or more of the words repeat the previous word.

        Texts with fewer than 10 whitespace-separated words always pass.
        """
        words = text.split()
        if len(words) < 10:
            return True
        consecutive_repeats = sum(
            1 for prev, cur in zip(words, words[1:]) if prev == cur
        )
        return consecutive_repeats / len(words) < 0.1

    def filter(self, dataset):
        """Return the samples whose ``"text"`` passes every rule.

        Args:
            dataset: Iterable of mappings, each with a ``"text"`` string.
        """
        return [
            sample
            for sample in dataset
            if all(rule(sample["text"]) for rule in self.rules)
        ]
基于模型的过滤
from transformers import pipeline
class ModelBasedFilter:
    """Model-based quality filter: keep samples scored at or above a threshold.

    NOTE(review): the default checkpoint, "unitary/toxic-bert", is named as
    a toxicity classifier, not a quality classifier. If so, the returned
    score is the confidence of a toxicity label and ``score >= threshold``
    would keep the wrong side. Confirm the intended model or pass a
    quality-scoring checkpoint via ``model_name``.
    """

    def __init__(self, quality_threshold=0.5, model_name="unitary/toxic-bert"):
        """Load the text-classification pipeline.

        Args:
            quality_threshold: Minimum score a sample needs to be kept.
            model_name: Checkpoint handed to the ``text-classification``
                pipeline. Defaults to the previously hard-coded model.
        """
        self.quality_threshold = quality_threshold
        self.quality_classifier = pipeline(
            "text-classification",
            model=model_name,
        )

    def calculate_quality_score(self, text):
        """Return the score of the classifier's first result for ``text``.

        Only the first 512 characters are scored. A character slice does not
        bound the token count, so the tokenizer is also asked to truncate.
        """
        result = self.quality_classifier(text[:512], truncation=True)
        return result[0]["score"]

    def filter_low_quality(self, dataset):
        """Return the samples whose score is at least ``quality_threshold``.

        Args:
            dataset: Iterable of mappings, each with a ``"text"`` string.
        """
        return [
            sample
            for sample in dataset
            if self.calculate_quality_score(sample["text"]) >= self.quality_threshold
        ]
基于困惑度的过滤
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
class PerplexityFilter:
    """Perplexity-based quality filter using the pretrained ``gpt2`` model.

    Samples whose perplexity exceeds ``max_perplexity`` are removed.
    """

    def __init__(self, max_perplexity=100):
        """Load the ``gpt2`` model and tokenizer and switch to eval mode.

        Args:
            max_perplexity: Largest perplexity a kept sample may have.
        """
        self.max_perplexity = max_perplexity
        self.model = GPT2LMHeadModel.from_pretrained('gpt2')
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.model.eval()

    def calculate_perplexity(self, text):
        """Return the perplexity of ``text`` under the language model.

        The text is truncated to 512 tokens. Perplexity is the exponential
        of the loss the model reports when the input ids are also the labels.

        Returns:
            The perplexity as a float. ``float('inf')`` is returned when the
            text tokenizes to fewer than two tokens: a next-token loss needs
            at least one (context, target) pair, so such input has no
            defined perplexity and should never pass the filter.
        """
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
        if inputs['input_ids'].shape[-1] < 2:
            return float('inf')
        with torch.no_grad():
            outputs = self.model(**inputs, labels=inputs['input_ids'])
        return torch.exp(outputs.loss).item()

    def filter_by_perplexity(self, dataset):
        """Return the samples whose perplexity is at most ``max_perplexity``.

        Args:
            dataset: Iterable of mappings, each with a ``"text"`` string.
        """
        return [
            sample
            for sample in dataset
            if self.calculate_perplexity(sample["text"]) <= self.max_perplexity
        ]
数据多样性控制
领域多样性
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
class DomainDiversityController:
    """Domain diversity control: cluster texts and cap samples per cluster.

    Clusters come from K-means over TF-IDF vectors and serve as a proxy
    for domains.
    """

    def __init__(self, num_clusters=10):
        """Store the cluster count and create the TF-IDF vectorizer.

        Args:
            num_clusters: Desired number of K-means clusters.
        """
        self.num_clusters = num_clusters
        self.vectorizer = TfidfVectorizer(max_features=10000)

    def cluster_dataset(self, dataset):
        """Assign every sample to a cluster.

        Only the first 1000 characters of each text are vectorized.

        Args:
            dataset: Iterable of mappings, each with a ``"text"`` string.

        Returns:
            One cluster id per sample, in input order. An empty dataset
            yields an empty list without calling the vectorizer.
        """
        texts = [sample["text"][:1000] for sample in dataset]
        if not texts:
            return []
        tfidf_matrix = self.vectorizer.fit_transform(texts)
        # K-means cannot form more clusters than there are samples.
        n_clusters = min(self.num_clusters, len(texts))
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        return kmeans.fit_predict(tfidf_matrix)

    def balance_domains(self, dataset, target_samples_per_cluster=1000):
        """Keep at most ``target_samples_per_cluster`` samples per cluster.

        Within each cluster the earliest samples in input order are kept.

        Args:
            dataset: Iterable of mappings, each with a ``"text"`` string.
            target_samples_per_cluster: Cap on samples taken per cluster.

        Returns:
            A list of the retained samples in input order.
        """
        # Materialize once: the data is walked for clustering and again
        # below, which a one-shot iterator would not survive.
        samples = list(dataset)
        if not samples:
            return []
        clusters = self.cluster_dataset(samples)
        balanced_dataset = []
        cluster_counts = {}
        for sample, cluster_id in zip(samples, clusters):
            seen = cluster_counts.get(cluster_id, 0) + 1
            cluster_counts[cluster_id] = seen
            if seen <= target_samples_per_cluster:
                balanced_dataset.append(sample)
        return balanced_dataset
语言多样性
class LanguageDiversityController:
    """Language diversity control: sample a dataset toward a target language mix."""

    def __init__(self, target_distribution=None):
        """Store the target share per language.

        Args:
            target_distribution: Mapping from language label to the share of
                ``total_samples`` it should receive. Defaults to 40% ``"zh"``,
                40% ``"en"`` and 20% ``"other"``.
        """
        self.target_distribution = target_distribution or {
            "zh": 0.4,
            "en": 0.4,
            "other": 0.2,
        }

    def detect_language(self, text):
        """Classify ``text`` as ``"zh"``, ``"en"`` or ``"other"``.

        Characters in ``\\u4e00-\\u9fff`` are compared against ``a-zA-Z``.
        The label is whichever group holds more than half of their combined
        count; it is ``"other"`` on an exact tie or when neither appears.
        """
        chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
        english_chars = len(re.findall(r'[a-zA-Z]', text))
        total = chinese_chars + english_chars
        if total == 0:
            return "other"
        if chinese_chars / total > 0.5:
            return "zh"
        if english_chars / total > 0.5:
            return "en"
        return "other"

    def balance_languages(self, dataset, total_samples=10000):
        """Select samples per language according to ``target_distribution``.

        For each language the first ``int(total_samples * ratio)`` samples
        detected as that language are taken. A language with fewer samples
        contributes what it has; the shortfall is not redistributed.

        Args:
            dataset: Iterable of mappings, each with a ``"text"`` string.
            total_samples: Overall sample budget that the ratios apply to.

        Returns:
            A list of selected samples, grouped by language in the iteration
            order of ``target_distribution``.
        """
        language_buckets = {"zh": [], "en": [], "other": []}
        for sample in dataset:
            language_buckets[self.detect_language(sample["text"])].append(sample)

        balanced_dataset = []
        for lang, target_ratio in self.target_distribution.items():
            target_count = int(total_samples * target_ratio)
            # A custom distribution may name a language that detect_language
            # never produces. It has no samples, so it contributes nothing.
            balanced_dataset.extend(language_buckets.get(lang, [])[:target_count])
        return balanced_dataset
数据集评估
def evaluate_dataset_quality(dataset):
    """Collect summary quality metrics for a dataset.

    Args:
        dataset: Sized collection of mappings, each with a ``"text"`` string.

    Returns:
        A dict with these keys:
        ``"size"``: number of samples.
        ``"avg_length"``: mean whitespace-separated word count per sample,
        or ``0.0`` for an empty dataset.
        ``"diversity"``: result of ``calculate_diversity_score(dataset)``.
        ``"quality"``: result of ``calculate_quality_score(dataset)``.
        ``"balance"``: result of ``calculate_balance_score(dataset)``.

    NOTE(review): ``calculate_quality_score`` and ``calculate_balance_score``
    are expected as module-level functions taking the whole dataset; confirm
    they are defined where this function is used.
    """
    size = len(dataset)
    total_words = sum(len(s["text"].split()) for s in dataset)
    metrics = {
        "size": size,
        # Guard the mean so an empty dataset does not divide by zero.
        "avg_length": total_words / size if size else 0.0,
        "diversity": calculate_diversity_score(dataset),
        "quality": calculate_quality_score(dataset),
        "balance": calculate_balance_score(dataset),
    }
    return metrics
def calculate_diversity_score(dataset, n=3, ngrams_per_sample=100):
    """Score lexical diversity as distinct word n-grams per sample.

    The score is the number of distinct n-grams across the whole dataset
    divided by ``len(dataset) * ngrams_per_sample``.

    Args:
        dataset: Sized collection of mappings, each with a ``"text"`` string.
        n: Words per n-gram. Defaults to 3.
        ngrams_per_sample: Normalizing constant per sample. Defaults to 100.

    Returns:
        The diversity score as a float; ``0.0`` for an empty dataset.
    """
    if not dataset:
        return 0.0
    all_ngrams = set()
    for sample in dataset:
        words = sample["text"].split()
        all_ngrams.update(
            tuple(words[i:i + n]) for i in range(len(words) - n + 1)
        )
    return len(all_ngrams) / (len(dataset) * ngrams_per_sample)
总结
数据集策划是构建高质量LLM的关键步骤。通过精确和模糊去重消除冗余,使用规则和模型过滤低质量数据,以及控制领域和语言多样性,可以创建出高质量的训练数据集。