可复现性:确保LLM实验和结果的可复现性
可复现性:确保LLM实验和结果的可复现性
为什么可复现性很重要
在大语言模型研究和开发中,实验的可复现性是科学严谨性的基石。可复现性确保其他人能够验证您的结果,并在此基础上进行进一步研究。缺乏可复现性会导致研究浪费和信任危机。
环境控制
完整的环境配置
import torch
import numpy as np
import random
import os
import json
from datetime import datetime
class ReproducibilityManager:
    """Pin random state and record the runtime environment of an experiment.

    Args:
        seed: Value used to seed every random number generator touched by
            :meth:`set_all_seeds`; it is also stored in the environment
            snapshot.
    """

    def __init__(self, seed=42):
        self.seed = seed
        self.env_info = {}

    def set_all_seeds(self):
        """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN.

        Prints a confirmation message once every seed has been applied.
        """
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)
        # cuDNN: select deterministic kernels and disable the auto-tuner.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so this
        # assignment only influences processes launched from here on (e.g.
        # data-loader workers). To fix str/bytes hashing of the current
        # process, export the variable before starting Python.
        os.environ['PYTHONHASHSEED'] = str(self.seed)
        print(f"Random seeds set to {self.seed}")

    def get_environment_info(self):
        """Collect a snapshot of the software and hardware environment.

        Returns:
            dict: The snapshot, which is also stored on ``self.env_info``.
        """
        self.env_info = {
            "python_version": self._get_python_version(),
            "torch_version": torch.__version__,
            # NumPy's RNG is seeded above, so its version is part of the
            # information needed to reproduce a run.
            "numpy_version": np.__version__,
            "cuda_version": torch.version.cuda,
            "gpu_info": self._get_gpu_info(),
            "timestamp": datetime.now().isoformat(),
            "seed": self.seed,
        }
        return self.env_info

    def _get_python_version(self):
        """Return the interpreter version as ``"major.minor.micro"``."""
        import sys

        version = sys.version_info
        return f"{version.major}.{version.minor}.{version.micro}"

    def _get_gpu_info(self):
        """Describe the visible CUDA devices.

        Returns:
            dict | str: A dict with the first device name (``"name"``), the
            device count, the names of all devices (``"names"``) and the CUDA
            version, or the string ``"No GPU available"``.
        """
        if torch.cuda.is_available():
            device_count = torch.cuda.device_count()
            return {
                "name": torch.cuda.get_device_name(0),
                "count": device_count,
                # Heterogeneous machines exist; record every device, not
                # only device 0.
                "names": [
                    torch.cuda.get_device_name(index)
                    for index in range(device_count)
                ],
                "version": torch.version.cuda
            }
        return "No GPU available"

    def save_environment(self, path):
        """Refresh the environment snapshot and write it to ``path`` as JSON.

        Args:
            path: Destination file; it is overwritten if it already exists.
        """
        self.get_environment_info()
        # Explicit encoding keeps the file identical across platforms.
        with open(path, 'w', encoding="utf-8") as f:
            json.dump(self.env_info, f, indent=2)
        print(f"Environment saved to {path}")
依赖锁定
# requirements.txt 固定版本
"""
torch==2.0.1
transformers==4.30.0
datasets==2.13.0
tokenizers==0.13.3
accelerate==0.20.0
numpy==1.24.3
"""
# 使用pipenv或poetry管理依赖
# pipenv:
# pipenv install torch==2.0.1 transformers==4.30.0
# poetry:
# poetry add torch@2.0.1 transformers@4.30.0
训练配置管理
统一的配置系统
from dataclasses import dataclass, field, asdict
from typing import Dict, Any, Optional
import yaml
@dataclass
class TrainingConfig:
    """Single source of truth for every setting of a training run.

    All fields have defaults, so ``TrainingConfig()`` is a valid
    configuration. Instances can be written to and restored from YAML with
    :meth:`save` and :meth:`load`.
    """

    # Model
    model_name: str = "gpt2"
    vocab_size: int = 50257
    hidden_size: int = 768
    num_layers: int = 12
    num_heads: int = 12
    # Training
    batch_size: int = 8
    learning_rate: float = 5e-5
    num_epochs: int = 3
    max_seq_length: int = 512
    warmup_steps: int = 100
    # Optimizer
    optimizer: str = "adamw"
    weight_decay: float = 0.01
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0
    # Data
    train_file: str = "train.jsonl"
    val_file: str = "val.jsonl"
    test_file: str = "test.jsonl"
    # Output
    output_dir: str = "./output"
    logging_steps: int = 10
    save_steps: int = 500
    eval_steps: int = 500
    # Reproducibility. Declared last so that positional construction of the
    # fields above keeps working; storing the seed here means it is saved
    # together with the rest of the configuration.
    seed: int = 42

    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as a plain ``dict``."""
        return asdict(self)

    def save(self, path: str):
        """Write the configuration to ``path`` as block-style YAML.

        Args:
            path: Destination file (a ``str`` or any path-like object accepted
                by ``open``); it is overwritten if it exists.
        """
        with open(path, 'w', encoding="utf-8") as f:
            yaml.dump(self.to_dict(), f, default_flow_style=False)

    @classmethod
    def load(cls, path: str) -> 'TrainingConfig':
        """Build a configuration from a YAML file written by :meth:`save`.

        Keys missing from the file fall back to the field defaults; an empty
        file therefore yields the default configuration.

        Args:
            path: YAML file to read.

        Returns:
            TrainingConfig: The restored configuration.

        Raises:
            TypeError: If the file contains a key that is not a field.
        """
        with open(path, 'r', encoding="utf-8") as f:
            # safe_load returns None for an empty document.
            config_dict = yaml.safe_load(f) or {}
        return cls(**config_dict)
实验命名与组织
from pathlib import Path
import hashlib
from datetime import datetime
class ExperimentOrganizer:
    """Create and lay out one directory per experiment.

    Args:
        base_dir: Directory that holds all experiments; it is created
            (including parents) when missing.
    """

    def __init__(self, base_dir: str = "./experiments"):
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

    def create_experiment(self, config: "TrainingConfig", name: Optional[str] = None):
        """Create the directory tree for a new experiment.

        The directory receives ``checkpoints``, ``logs`` and ``results``
        sub-directories, the configuration as ``config.yaml`` and the state
        of the git checkout as ``git_info.json``.

        Args:
            config: Configuration of the run; it must provide ``to_dict()``
                and ``save(path)``.
            name: Directory name. When omitted, a name of the form
                ``exp_<timestamp>_<config hash>`` is generated.

        Returns:
            Path: The experiment directory.
        """
        if name is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            # MD5 is used purely as a short fingerprint of the settings,
            # not for security.
            config_hash = hashlib.md5(
                json.dumps(config.to_dict(), sort_keys=True).encode()
            ).hexdigest()[:8]
            name = f"exp_{timestamp}_{config_hash}"
        exp_dir = self.base_dir / name
        exp_dir.mkdir(parents=True, exist_ok=True)
        for subdir in ("checkpoints", "logs", "results"):
            (exp_dir / subdir).mkdir(exist_ok=True)
        config.save(exp_dir / "config.yaml")
        self._save_git_info(exp_dir / "git_info.json")
        return exp_dir

    def _save_git_info(self, path):
        """Write commit, branch and dirty-state of the current checkout to ``path``.

        Collecting git metadata is best effort: when the working directory
        is not a git checkout, or git is not installed, a JSON object with
        an ``"error"`` key is written instead and no exception is raised.

        Args:
            path: Destination JSON file.
        """
        import subprocess

        def run_git(*args):
            # stderr is discarded so that git's own diagnostics do not end
            # up in the training console.
            return subprocess.check_output(
                ["git", *args], stderr=subprocess.DEVNULL
            ).decode().strip()

        try:
            commit_hash = run_git("rev-parse", "HEAD")
            branch = run_git("rev-parse", "--abbrev-ref", "HEAD")
            # Non-empty porcelain output means there are uncommitted changes.
            status = run_git("status", "--porcelain")
            git_info = {
                "commit_hash": commit_hash,
                "branch": branch,
                "has_uncommitted_changes": bool(status),
                "status": status[:500] if status else None
            }
        except subprocess.CalledProcessError:
            git_info = {"error": "Not a git repository"}
        except OSError:
            # Raised (as FileNotFoundError) when the git executable is
            # missing; this must not abort experiment creation.
            git_info = {"error": "git executable not available"}
        with open(path, 'w', encoding="utf-8") as f:
            json.dump(git_info, f, indent=2)
数据版本控制
import hashlib
import json
from pathlib import Path
class DataVersionManager:
    """Track the contents of a data directory through a hash manifest.

    The manifest is stored as ``manifest.json`` inside the data directory.

    Args:
        data_dir: Directory whose files are tracked.
    """

    def __init__(self, data_dir: str):
        self.data_dir = Path(data_dir)
        self.manifest_path = self.data_dir / "manifest.json"

    def create_manifest(self):
        """Hash every file below the data directory and write the manifest.

        Returns:
            dict: The manifest, mapping each relative path (with ``/``
            separators) to its SHA-256 hash, size and modification time.
        """
        manifest = {
            "files": {},
            "created_at": datetime.now().isoformat()
        }
        # Sorted traversal makes the manifest layout deterministic.
        for file_path in sorted(self.data_dir.rglob("*")):
            if not file_path.is_file():
                continue
            # The manifest must not describe itself: it is rewritten below,
            # so a recorded hash would be stale immediately and every later
            # verification would fail.
            if file_path == self.manifest_path:
                continue
            stat = file_path.stat()
            relative = file_path.relative_to(self.data_dir).as_posix()
            manifest["files"][relative] = {
                "hash": self._hash_file(file_path),
                "size": stat.st_size,
                "modified": datetime.fromtimestamp(stat.st_mtime).isoformat()
            }
        with open(self.manifest_path, 'w', encoding="utf-8") as f:
            json.dump(manifest, f, indent=2)
        return manifest

    def verify_integrity(self):
        """Compare the files on disk with the stored manifest.

        Files added after the manifest was created are not reported; only
        the recorded files are checked.

        Returns:
            tuple[bool, str]: ``(True, "All files verified")`` on success,
            otherwise ``False`` and a message naming the first problem.
        """
        if not self.manifest_path.exists():
            return False, "No manifest found"
        with open(self.manifest_path, 'r', encoding="utf-8") as f:
            manifest = json.load(f)
        for file_path, info in manifest["files"].items():
            full_path = self.data_dir / file_path
            if full_path == self.manifest_path:
                # Tolerate manifests that recorded themselves; that entry
                # can never match.
                continue
            if not full_path.exists():
                return False, f"Missing file: {file_path}"
            if self._hash_file(full_path) != info["hash"]:
                return False, f"File modified: {file_path}"
        return True, "All files verified"

    def _hash_file(self, file_path):
        """Return the SHA-256 hex digest of ``file_path``.

        The file is read in fixed-size blocks so that large data files do
        not have to fit in memory.
        """
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for byte_block in iter(lambda: f.read(65536), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()
训练过程记录
class TrainingLogger:
    """Persist metrics, hyperparameters and model statistics of a run.

    Args:
        log_dir: Directory for all log files; it is created (including
            parents) when missing. Metrics go to ``training_log.jsonl``
            inside it.
    """

    def __init__(self, log_dir: str):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.log_file = self.log_dir / "training_log.jsonl"

    @staticmethod
    def _to_jsonable(value):
        """Convert array-like metric values that ``json`` cannot encode.

        Used as the ``default`` hook of ``json.dumps``, so it is only
        consulted for values that are not natively serialisable.

        Raises:
            TypeError: If ``value`` offers neither ``tolist()`` nor ``item()``.
        """
        for converter in ("tolist", "item"):
            method = getattr(value, converter, None)
            if callable(method):
                return method()
        raise TypeError(
            f"Object of type {type(value).__name__} is not JSON serializable"
        )

    def log_metrics(self, step: int, metrics: Dict[str, float]):
        """Append one record with the metrics of ``step`` to the JSONL log.

        Args:
            step: Training step the metrics belong to.
            metrics: Metric name to value. Values exposing ``tolist()`` or
                ``item()`` (NumPy scalars, tensors) are converted to plain
                Python numbers instead of raising ``TypeError``.
        """
        entry = {
            "step": step,
            "timestamp": datetime.now().isoformat(),
            **metrics
        }
        with open(self.log_file, 'a', encoding="utf-8") as f:
            f.write(json.dumps(entry, default=self._to_jsonable) + '\n')

    def log_hyperparameters(self, config: "TrainingConfig"):
        """Write ``config.to_dict()`` to ``hyperparameters.json``.

        Args:
            config: Configuration object providing ``to_dict()``.
        """
        hp_path = self.log_dir / "hyperparameters.json"
        with open(hp_path, 'w', encoding="utf-8") as f:
            json.dump(config.to_dict(), f, indent=2)

    def log_model_info(self, model):
        """Write parameter counts and parameter memory to ``model_info.json``.

        Args:
            model: Object whose ``parameters()`` yields tensors offering
                ``numel()``, ``element_size()`` and ``requires_grad``.
        """
        total_parameters = 0
        trainable_parameters = 0
        total_bytes = 0
        # Single pass over the parameters instead of three.
        for param in model.parameters():
            count = param.numel()
            total_parameters += count
            if param.requires_grad:
                trainable_parameters += count
            total_bytes += count * param.element_size()
        model_info = {
            "total_parameters": total_parameters,
            "trainable_parameters": trainable_parameters,
            # Parameters only; buffers are not included.
            "model_size_mb": total_bytes / (1024 * 1024)
        }
        info_path = self.log_dir / "model_info.json"
        with open(info_path, 'w', encoding="utf-8") as f:
            json.dump(model_info, f, indent=2)
验证可复现性
def verify_reproducibility(model_class, config, num_runs=3):
    """Check that re-seeding yields bit-identical initial model parameters.

    The model is built ``num_runs`` times, each time after re-applying the
    same seed, and the initial parameters of all runs are compared through
    a digest of their raw bytes.

    Args:
        model_class: Callable invoked as ``model_class(config)``; the result
            must provide ``named_parameters()``.
        config: Configuration passed to ``model_class``. Its ``seed``
            attribute is used when present, otherwise 42.
        num_runs: Number of times the model is initialised.

    Returns:
        bool: ``True`` when every run produced the same initial parameters.
    """
    # Not every config object carries a seed; fall back to the default
    # seed of ReproducibilityManager instead of raising AttributeError.
    seed = getattr(config, "seed", 42)
    digests = []
    for run in range(num_runs):
        print(f"\n=== Run {run + 1}/{num_runs} ===")
        manager = ReproducibilityManager(seed=seed)
        manager.set_all_seeds()
        model = model_class(config)
        # Hash the raw parameter bytes. The textual repr of a tensor is
        # rounded and elided for large tensors, so hashing str(...) can
        # miss real differences. Keeping only a digest also avoids holding
        # a full copy of the parameters for every run.
        digest = hashlib.sha256()
        for name, param in model.named_parameters():
            digest.update(name.encode("utf-8"))
            digest.update(str(param.dtype).encode("utf-8"))
            digest.update(str(tuple(param.shape)).encode("utf-8"))
            raw = param.detach().cpu().contiguous().reshape(-1).view(torch.uint8)
            digest.update(raw.numpy().tobytes())
        digests.append(digest.hexdigest())
    if len(set(digests)) == 1:
        print("✓ 所有运行的初始状态一致")
        return True
    else:
        print("✗ 初始状态不一致,可复现性有问题")
        return False
通过实施这些可复现性措施,可以确保您的LLM实验结果可靠且可验证,为科学研究和工程开发奠定坚实基础。