← 返回首页
🧠

可复现性:确保LLM实验和结果的可复现性

📂 llm ⏱ 4 min 752 words

可复现性:确保LLM实验和结果的可复现性

为什么可复现性很重要

在大语言模型研究和开发中,实验的可复现性是科学严谨性的基石。可复现性确保其他人能够验证您的结果,并在此基础上开展进一步研究。缺乏可复现性会导致研究资源浪费和信任危机。

环境控制

完整的环境配置

import torch
import numpy as np
import random
import os
import json
from datetime import datetime

class ReproducibilityManager:
    """Seed random number generators and record the runtime environment.

    Args:
        seed: Seed applied to Python's ``random``, NumPy and PyTorch (CPU and
            CUDA) by :meth:`set_all_seeds`, and stored in the environment
            record produced by :meth:`get_environment_info`.
    """

    def __init__(self, seed=42):
        self.seed = seed
        # Filled in by get_environment_info(); empty until it is called.
        self.env_info = {}

    def set_all_seeds(self, *, deterministic_algorithms=False):
        """Seed every RNG used in a typical PyTorch run and configure cuDNN.

        Args:
            deterministic_algorithms: When True, additionally call
                ``torch.use_deterministic_algorithms(True)``. PyTorch then
                raises an error instead of silently running a nondeterministic
                kernel. This also sets ``CUBLAS_WORKSPACE_CONFIG`` (unless it is
                already set), which deterministic cuBLAS requires. It is off by
                default because enabling it can make previously working code
                raise.
        """
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        # torch.manual_seed already seeds all devices. The explicit CUDA calls
        # document intent and are silently ignored when CUDA is unavailable.
        torch.cuda.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)

        # cuDNN: request deterministic kernels and disable the autotuner,
        # whose benchmark-driven algorithm choice can differ between runs.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        if deterministic_algorithms:
            os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
            torch.use_deterministic_algorithms(True)

        # NOTE: str/bytes hash randomization is fixed when the interpreter
        # starts, so this does NOT affect the current process. Only processes
        # launched afterwards (subprocesses, "spawn" workers) inherit it. To
        # fix hashing for this process, set PYTHONHASHSEED before launching.
        os.environ['PYTHONHASHSEED'] = str(self.seed)

        print(f"Random seeds set to {self.seed}")

    def get_environment_info(self):
        """Collect library/hardware versions, cache them and return the dict."""
        self.env_info = {
            "python_version": self._get_python_version(),
            "torch_version": torch.__version__,
            "numpy_version": np.__version__,
            "cuda_version": torch.version.cuda,
            # None when PyTorch was built without / cannot load cuDNN.
            "cudnn_version": torch.backends.cudnn.version(),
            "gpu_info": self._get_gpu_info(),
            "timestamp": datetime.now().isoformat(),
            "seed": self.seed,
        }
        return self.env_info

    def _get_python_version(self):
        """Return the running interpreter version as ``"major.minor.micro"``."""
        import sys
        return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"

    def _get_gpu_info(self):
        """Describe the first CUDA device, or return a string when none exists."""
        if torch.cuda.is_available():
            return {
                "name": torch.cuda.get_device_name(0),
                "count": torch.cuda.device_count(),
                "version": torch.version.cuda
            }
        return "No GPU available"

    def save_environment(self, path):
        """Refresh the environment record and write it to ``path`` as JSON."""
        self.get_environment_info()
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.env_info, f, indent=2)
        print(f"Environment saved to {path}")

依赖锁定

# requirements.txt: pin every dependency to an exact version.
# NOTE: the string below is a bare expression with no runtime effect; it only
# illustrates what the contents of requirements.txt should look like.
"""
torch==2.0.1
transformers==4.30.0
datasets==2.13.0
tokenizers==0.13.3
accelerate==0.20.0
numpy==1.24.3
"""

# Alternatively, manage dependencies with pipenv or poetry.
# pipenv:
# pipenv install torch==2.0.1 transformers==4.30.0

# poetry:
# poetry add torch@2.0.1 transformers@4.30.0

训练配置管理

统一的配置系统

from dataclasses import dataclass, field, asdict
from typing import Dict, Any, Optional
import yaml

@dataclass
class TrainingConfig:
    """Training hyperparameters, serializable to and from YAML.

    Every field has a default, so ``TrainingConfig()`` is a complete
    configuration and a YAML file only needs to list the fields it overrides.
    """
    # Model
    model_name: str = "gpt2"
    vocab_size: int = 50257
    hidden_size: int = 768
    num_layers: int = 12
    num_heads: int = 12

    # Training
    batch_size: int = 8
    learning_rate: float = 5e-5
    num_epochs: int = 3
    max_seq_length: int = 512
    warmup_steps: int = 100

    # Optimizer
    optimizer: str = "adamw"
    weight_decay: float = 0.01
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0

    # Data
    train_file: str = "train.jsonl"
    val_file: str = "val.jsonl"
    test_file: str = "test.jsonl"

    # Output
    output_dir: str = "./output"
    logging_steps: int = 10
    save_steps: int = 500
    eval_steps: int = 500

    # Reproducibility: random seed read as ``config.seed`` by
    # verify_reproducibility. Appended last so existing positional
    # construction keeps working.
    seed: int = 42

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dict."""
        return asdict(self)

    def save(self, path: str):
        """Write the config to ``path`` as block-style YAML."""
        with open(path, 'w', encoding='utf-8') as f:
            yaml.dump(self.to_dict(), f, default_flow_style=False)

    @classmethod
    def from_dict(cls, config_dict: Optional[Dict[str, Any]]) -> 'TrainingConfig':
        """Build a config from a mapping, using defaults for missing keys.

        ``None`` (what ``yaml.safe_load`` returns for an empty file) yields
        the default configuration.

        Raises:
            TypeError: if ``config_dict`` is not a dict, or contains keys
                that are not fields of this class (typos would otherwise be
                reported only as an obscure ``__init__`` error).
        """
        from dataclasses import fields

        if config_dict is None:
            return cls()
        if not isinstance(config_dict, dict):
            raise TypeError(
                f"TrainingConfig expects a mapping, got {type(config_dict).__name__}"
            )
        known = {f.name for f in fields(cls)}
        unknown = sorted(str(key) for key in config_dict if key not in known)
        if unknown:
            raise TypeError(f"Unknown TrainingConfig field(s): {', '.join(unknown)}")
        return cls(**config_dict)

    @classmethod
    def load(cls, path: str) -> 'TrainingConfig':
        """Read a YAML config written by :meth:`save` (or by hand).

        Uses ``yaml.safe_load``, so arbitrary Python objects are never
        constructed from the file.
        """
        with open(path, 'r', encoding='utf-8') as f:
            config_dict = yaml.safe_load(f)
        return cls.from_dict(config_dict)

实验命名与组织

from pathlib import Path
import hashlib
from datetime import datetime

class ExperimentOrganizer:
    """Create one directory per experiment under a common base directory.

    Each experiment directory receives ``checkpoints/``, ``logs/`` and
    ``results/`` subdirectories, the saved config (``config.yaml``) and a
    snapshot of the git state (``git_info.json``).
    """

    def __init__(self, base_dir: str = "./experiments"):
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

    def create_experiment(self, config: "TrainingConfig", name: Optional[str] = None):
        """Create (or reuse) the directory for one experiment and return it.

        Args:
            config: Object providing ``to_dict()`` and ``save(path)``.
            name: Directory name. When omitted, ``exp_<timestamp>_<hash>`` is
                used, where ``<hash>`` is the first 8 hex chars of an MD5 of the
                config's sorted-key JSON (a fingerprint, not a security hash).

        Returns:
            Path of the experiment directory. An existing directory with the
            same name is reused and its config/git files are overwritten.
        """
        if name is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            config_hash = hashlib.md5(
                json.dumps(config.to_dict(), sort_keys=True).encode()
            ).hexdigest()[:8]
            name = f"exp_{timestamp}_{config_hash}"

        exp_dir = self.base_dir / name
        exp_dir.mkdir(parents=True, exist_ok=True)
        for subdir in ("checkpoints", "logs", "results"):
            (exp_dir / subdir).mkdir(exist_ok=True)

        config.save(exp_dir / "config.yaml")
        self._save_git_info(exp_dir / "git_info.json")
        return exp_dir

    def _save_git_info(self, path):
        """Write the current commit, branch and working-tree status as JSON.

        Runs git in the process's current working directory. Never raises
        for git problems: outside a repository, or when git is not
        installed, an ``{"error": ...}`` record is written instead.
        """
        import subprocess

        def _git(*args):
            # stderr is discarded so "fatal: not a git repository" does not
            # clutter the console; failures surface as exceptions instead.
            return subprocess.check_output(
                ["git", *args], stderr=subprocess.DEVNULL
            ).decode().strip()

        try:
            commit_hash = _git("rev-parse", "HEAD")
            branch = _git("rev-parse", "--abbrev-ref", "HEAD")
            # Non-empty porcelain output means uncommitted changes exist.
            status = _git("status", "--porcelain")
        except subprocess.CalledProcessError:
            git_info = {"error": "Not a git repository"}
        except OSError:
            # e.g. FileNotFoundError when the git executable is missing.
            git_info = {"error": "git executable not available"}
        else:
            git_info = {
                "commit_hash": commit_hash,
                "branch": branch,
                "has_uncommitted_changes": bool(status),
                "status": status[:500] if status else None
            }

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(git_info, f, indent=2)

数据版本控制

import hashlib
import json
from pathlib import Path

class DataVersionManager:
    """Record and verify SHA-256 fingerprints of every file in a data directory.

    The manifest is stored as ``manifest.json`` inside ``data_dir`` and is
    excluded from its own listing.
    """

    def __init__(self, data_dir: str):
        self.data_dir = Path(data_dir)
        self.manifest_path = self.data_dir / "manifest.json"

    def create_manifest(self):
        """Hash every file under ``data_dir`` and write ``manifest.json``.

        Keys of ``manifest["files"]`` are POSIX-style paths relative to
        ``data_dir``, so a manifest written on one OS verifies on another.
        Files are visited in sorted order, so the output is deterministic.

        Returns:
            The manifest dict that was written.
        """
        manifest = {
            "files": {},
            "created_at": datetime.now().isoformat()
        }

        for file_path in sorted(self.data_dir.rglob("*")):
            # Skip the manifest itself: it is overwritten below, so the hash
            # of its old contents would always fail the next verification.
            if file_path == self.manifest_path or not file_path.is_file():
                continue
            stat = file_path.stat()
            manifest["files"][file_path.relative_to(self.data_dir).as_posix()] = {
                "hash": self._hash_file(file_path),
                "size": stat.st_size,
                "modified": datetime.fromtimestamp(stat.st_mtime).isoformat()
            }

        with open(self.manifest_path, 'w', encoding='utf-8') as f:
            json.dump(manifest, f, indent=2)

        return manifest

    def verify_integrity(self):
        """Check every file listed in the manifest against its recorded hash.

        Only files named in the manifest are checked; files added since the
        manifest was created are not reported.

        Returns:
            ``(True, "All files verified")`` on success, otherwise
            ``(False, reason)`` for the first missing or modified file, or
            ``(False, "No manifest found")``.
        """
        if not self.manifest_path.exists():
            return False, "No manifest found"

        with open(self.manifest_path, 'r', encoding='utf-8') as f:
            manifest = json.load(f)

        for file_path, info in manifest["files"].items():
            full_path = self.data_dir / file_path
            # Manifests may list manifest.json itself; its contents changed
            # when it was written, so it cannot be meaningfully verified.
            if full_path == self.manifest_path:
                continue
            if not full_path.exists():
                return False, f"Missing file: {file_path}"
            if self._hash_file(full_path) != info["hash"]:
                return False, f"File modified: {file_path}"

        return True, "All files verified"

    def _hash_file(self, file_path, chunk_size=1 << 20):
        """Return the hex SHA-256 of ``file_path``, read in ``chunk_size`` pieces."""
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for byte_block in iter(lambda: f.read(chunk_size), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

训练过程记录

class TrainingLogger:
    """Write training metrics, hyperparameters and model statistics to disk.

    All files live under ``log_dir``, which is created if missing.
    """

    def __init__(self, log_dir: str):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        # One JSON object per line; appended to, never truncated.
        self.log_file = self.log_dir / "training_log.jsonl"

    @staticmethod
    def _json_default(obj):
        """Encode zero-dimensional array/tensor scalars via ``.item()``.

        ``json`` calls this only for objects it cannot encode natively, such
        as ``np.float32`` or a 0-dim ``torch.Tensor`` loss. Anything else is
        rejected with the same kind of TypeError ``json`` itself raises.
        """
        if getattr(obj, "ndim", None) == 0 and hasattr(obj, "item"):
            return obj.item()
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    def log_metrics(self, step: int, metrics: Dict[str, float]):
        """Append one record ``{"step", "timestamp", **metrics}`` to the log.

        Keys in ``metrics`` named ``step`` or ``timestamp`` override the
        generated values. Scalar NumPy/PyTorch values are converted to
        Python numbers.

        Raises:
            TypeError: if a metric value cannot be serialized; nothing is
                written in that case.
        """
        entry = {
            "step": step,
            "timestamp": datetime.now().isoformat(),
            **metrics
        }
        # Serialize before opening the file so a bad value leaves it untouched.
        line = json.dumps(entry, default=self._json_default) + '\n'

        with open(self.log_file, 'a', encoding='utf-8') as f:
            f.write(line)

    def log_hyperparameters(self, config: "TrainingConfig"):
        """Write ``config.to_dict()`` to ``hyperparameters.json``."""
        hp_path = self.log_dir / "hyperparameters.json"
        with open(hp_path, 'w', encoding='utf-8') as f:
            json.dump(config.to_dict(), f, indent=2)

    def log_model_info(self, model):
        """Write parameter counts and parameter memory (MiB) to ``model_info.json``.

        Only ``model.parameters()`` are counted; buffers are not included in
        ``model_size_mb``.
        """
        total = trainable = size_bytes = 0
        for param in model.parameters():
            count = param.numel()
            total += count
            if param.requires_grad:
                trainable += count
            size_bytes += count * param.element_size()

        model_info = {
            "total_parameters": total,
            "trainable_parameters": trainable,
            "model_size_mb": size_bytes / (1024 * 1024)
        }

        info_path = self.log_dir / "model_info.json"
        with open(info_path, 'w', encoding='utf-8') as f:
            json.dump(model_info, f, indent=2)

验证可复现性

def verify_reproducibility(model_class, config, num_runs=3):
    """Check that seeding makes model initialization bit-identical across runs.

    Builds ``model_class(config)`` ``num_runs`` times and reseeds all RNGs
    with ``ReproducibilityManager`` before each build. Every named parameter
    of each run is compared against the first run with ``torch.equal``.

    Only initialization is checked; training-time determinism (data order,
    nondeterministic kernels) is not exercised.

    Args:
        model_class: Callable taking ``config`` and returning a module that
            provides ``named_parameters()``.
        config: Passed to ``model_class``. Its ``seed`` attribute is used when
            present; otherwise 42, the ReproducibilityManager default.
        num_runs: Number of independent constructions; must be >= 1.

    Returns:
        True if all runs produced identical parameter names and values,
        otherwise False.

    Raises:
        ValueError: if ``num_runs`` is less than 1.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    # Configs without a `seed` attribute fall back instead of crashing.
    seed = getattr(config, "seed", 42)
    reference = None
    consistent = True

    for run in range(num_runs):
        print(f"\n=== Run {run + 1}/{num_runs} ===")

        # Reseed identically before every construction.
        manager = ReproducibilityManager(seed=seed)
        manager.set_all_seeds()

        model = model_class(config)

        # Compare exact values with torch.equal rather than hashing
        # str(tensor). PyTorch abbreviates tensors above its print threshold
        # with "...", so string comparison can miss real differences.
        # Detached CPU copies keep the snapshot independent of the model.
        snapshot = {
            name: param.detach().to("cpu", copy=True)
            for name, param in model.named_parameters()
        }
        del model

        if reference is None:
            # Only the first run is kept in memory as the reference.
            reference = snapshot
        elif snapshot.keys() != reference.keys() or not all(
            torch.equal(snapshot[name], reference[name]) for name in reference
        ):
            consistent = False

    if consistent:
        print("✓ 所有运行的初始状态一致")
        return True
    print("✗ 初始状态不一致,可复现性有问题")
    return False

通过实施这些可复现性措施,可以显著提高LLM实验结果的可靠性与可验证性,为科学研究和工程开发奠定坚实基础。需要注意的是,即使固定了随机种子并启用了确定性设置,PyTorch也不保证在不同版本、不同平台或CPU与GPU之间得到逐位相同的结果;因此还应记录并复用完全相同的软硬件环境。