← 返回首页
🧠

LLM训练管道:自动化模型训练流程

📂 llm ⏱ 3 min 530 words

---
title: "LLM训练管道:自动化模型训练流程"
description: "构建端到端的LLM训练管道,从数据准备到模型评估的完整流程"
tags: ["LLM", "训练管道", "模型训练", "自动化", "MLOps"]
category: "llm"
icon: "🏋️"
---

LLM训练管道:自动化模型训练流程

训练管道概述

LLM训练管道将数据预处理、模型训练、评估和部署等环节自动化,确保训练过程可重复、可追踪、可扩展。

训练配置管理

1. 统一配置系统

import json
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Optional

@dataclass
class TrainingConfig:
    """Flat, JSON-serializable bundle of settings for one training run.

    Every field has a default, so ``TrainingConfig()`` is a valid config and
    callers override only what they need. ``save`` / ``load`` round-trip the
    config through a JSON file.
    """

    # Model settings
    model_name: str = "meta-llama/Llama-2-7b-hf"
    model_max_length: int = 2048

    # Optimization settings
    num_epochs: int = 3
    batch_size: int = 8
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1

    # Dataset locations
    train_data: str = "data/train.json"
    val_data: str = "data/val.json"

    # Output locations
    output_dir: str = "output/checkpoints"
    logging_dir: str = "output/logs"

    # Miscellaneous
    seed: int = 42
    fp16: bool = True
    gradient_accumulation_steps: int = 2

    def save(self, path) -> None:
        """Write this config to ``path`` as indented JSON.

        Missing parent directories are created first.

        Args:
            path: Destination file (``str`` or ``Path``).
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        # Explicit UTF-8 keeps the file format independent of the platform's
        # default locale encoding.
        with open(target, "w", encoding="utf-8") as f:
            json.dump(asdict(self), f, indent=2)

    @classmethod
    def load(cls, path) -> "TrainingConfig":
        """Build a config from a JSON file previously written by ``save``.

        Keys absent from the file fall back to the field defaults.

        Args:
            path: JSON file to read (``str`` or ``Path``).

        Returns:
            A new instance populated from the file.

        Raises:
            TypeError: If the file contains a key that is not a config field.
        """
        # Read as UTF-8 so hand-edited files with non-ASCII values (paths,
        # model names) decode the same way on every platform.
        with open(path, encoding="utf-8") as f:
            return cls(**json.load(f))

# Usage example: override a few defaults, then write the config to disk.
config = TrainingConfig(
    batch_size=16,
    num_epochs=5,
    model_name="meta-llama/Llama-2-7b-hf",
)
config.save("configs/training_config.json")

2. 超参数搜索

import itertools

class HyperparameterSearch:
    """Grid-search helper: expand a base config into one config per combination.

    Candidate values are registered per field with ``add_param``;
    ``generate_configs`` then returns the Cartesian product of all registered
    values, each applied on top of a fresh copy of the base config.
    """

    def __init__(self, base_config):
        """Store the dataclass instance that every generated config starts from."""
        self.base_config = base_config
        self.param_grid = {}

    def add_param(self, name, values):
        """Register the candidate ``values`` for the config field ``name``.

        Args:
            name: Field name on the base config.
            values: Iterable of candidate values for that field.

        Returns:
            ``self``, so calls can be chained.
        """
        # Materialize the iterable: a generator would otherwise be exhausted
        # by the first generate_configs() call, and a caller-owned list could
        # be mutated after registration.
        self.param_grid[name] = list(values)
        return self

    def generate_configs(self):
        """Return one config per combination of the registered values.

        Combinations follow ``itertools.product`` order over the parameters in
        registration order. With no registered parameters the result is a
        single copy of the base config.

        Raises:
            TypeError: If a registered name is not a field of the base config.
        """
        # Rebuild with the base config's own class so subclasses (or any other
        # dataclass) keep their type and extra fields.
        config_cls = type(self.base_config)
        names = list(self.param_grid)

        configs = []
        for combo in itertools.product(*self.param_grid.values()):
            # asdict() is called per combination so each generated config gets
            # its own copy of the base values.
            field_values = asdict(self.base_config)
            field_values.update(zip(names, combo))
            configs.append(config_cls(**field_values))

        return configs

# Usage example: a 3 x 3 grid over learning rate and batch size.
base = TrainingConfig()
search = (
    HyperparameterSearch(base)
    .add_param("learning_rate", [1e-5, 2e-5, 5e-5])
    .add_param("batch_size", [4, 8, 16])
)

configs = search.generate_configs()
print(f"生成 {len(configs)} 个配置")

训练流程

1. 训练监控器

import time
from datetime import datetime

class TrainingMonitor:
    """Collect per-step metrics in memory, echo them, and persist them as JSON."""

    def __init__(self, log_dir="output/logs"):
        """Create ``log_dir`` if needed and start with an empty metric history."""
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.metrics = []
        self.start_time = None

    def start(self):
        """Record the wall-clock start time used for the ``elapsed`` field."""
        self.start_time = time.time()
        print(f"训练开始: {datetime.now()}")

    def log_metrics(self, step, metrics):
        """Record and print the metrics of one training step.

        The recorded entry is a copy of ``metrics`` extended with ``step``,
        ``timestamp`` (ISO format) and ``elapsed`` (seconds since start); the
        caller's dict is left untouched.

        Args:
            step: Step number the metrics belong to.
            metrics: Mapping of metric name to value.
        """
        if self.start_time is None:
            # start() was never called; begin timing here instead of failing
            # on `float - None` below.
            self.start_time = time.time()

        # Copy first so the bookkeeping keys do not leak into the caller's
        # dict, which may be reused elsewhere (e.g. as checkpoint metadata).
        entry = dict(metrics)
        entry["step"] = step
        entry["timestamp"] = datetime.now().isoformat()
        entry["elapsed"] = time.time() - self.start_time
        self.metrics.append(entry)

        # Live echo: floats with 4 decimals, everything else as-is; the
        # timestamp is omitted to keep the line short.
        metric_str = ", ".join(
            f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}"
            for k, v in entry.items()
            if k != "timestamp"
        )
        print(f"Step {step}: {metric_str}")

    def save_metrics(self):
        """Dump the full metric history to a timestamped JSON file in ``log_dir``.

        Returns:
            Path of the file that was written.
        """
        log_file = self.log_dir / f"metrics_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(log_file, "w", encoding="utf-8") as f:
            json.dump(self.metrics, f, indent=2)
        print(f"指标已保存: {log_file}")
        return log_file

    def get_best_step(self, metric="val_loss", mode="min"):
        """Return the recorded entry with the best value of ``metric``.

        Args:
            metric: Key to compare on.
            mode: ``"min"`` picks the smallest value; any other value picks
                the largest.

        Returns:
            The best entry (a dict), or ``None`` when no recorded entry
            contains ``metric``.
        """
        # Only entries that actually carry the metric compete. A sentinel
        # default (the old `0` in max mode) would let an entry without the
        # metric beat real negative values.
        candidates = [entry for entry in self.metrics if metric in entry]
        if not candidates:
            return None

        pick = min if mode == "min" else max
        return pick(candidates, key=lambda entry: entry[metric])

2. 检查点管理

class CheckpointManager:
    """Save model checkpoints under one directory and keep only the newest few.

    Each checkpoint lives in ``<output_dir>/checkpoint-<step>`` together with
    a ``metadata.json`` file. Once more than ``max_checkpoints`` are tracked,
    the oldest ones are deleted from disk, regardless of their metrics.
    """

    def __init__(self, output_dir, max_checkpoints=3):
        """Create ``output_dir`` if needed and start with no tracked checkpoints."""
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.max_checkpoints = max_checkpoints
        self.checkpoints = []

    def save_checkpoint(self, model, step, metrics):
        """Save ``model`` plus metadata for ``step`` and prune old checkpoints.

        Args:
            model: Object exposing ``save_pretrained(directory)``.
            step: Training step; used in the directory name.
            metrics: JSON-serializable metrics stored with the checkpoint.

        Returns:
            Path of the checkpoint directory.
        """
        ckpt_dir = self.output_dir / f"checkpoint-{step}"
        # parents=True: still works if output_dir was removed after __init__.
        ckpt_dir.mkdir(parents=True, exist_ok=True)

        # Model weights
        model.save_pretrained(ckpt_dir)

        # Metadata describing the checkpoint
        metadata = {
            "step": step,
            "metrics": metrics,
            "timestamp": datetime.now().isoformat()
        }
        with open(ckpt_dir / "metadata.json", "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        # Re-saving a step reuses its directory. Drop the stale entry so a
        # later cleanup cannot delete a directory that a live entry still
        # points to.
        self.checkpoints[:] = [c for c in self.checkpoints if c["path"] != ckpt_dir]
        self.checkpoints.append({"step": step, "path": ckpt_dir, "metrics": metrics})
        self._cleanup_old_checkpoints()

        print(f"检查点已保存: {ckpt_dir}")
        return ckpt_dir

    def _cleanup_old_checkpoints(self):
        """Delete the oldest tracked checkpoints until the limit is respected."""
        import shutil

        # Loop (not a single `if`) so the limit also holds after
        # max_checkpoints has been lowered on a live manager.
        while self.checkpoints and len(self.checkpoints) > self.max_checkpoints:
            old = self.checkpoints.pop(0)
            shutil.rmtree(old["path"])
            print(f"删除旧检查点: {old['path']}")

    def load_best_checkpoint(self, metric="val_loss"):
        """Return the directory of the tracked checkpoint with the lowest ``metric``.

        Checkpoints whose metrics lack ``metric`` are ranked last. Only
        checkpoints that are still tracked (not yet pruned) are considered.

        Returns:
            Path of the best checkpoint, or ``None`` if none are tracked.
        """
        if not self.checkpoints:
            # min() over an empty list would raise ValueError.
            return None
        best = min(self.checkpoints, key=lambda x: x["metrics"].get(metric, float("inf")))
        return best["path"]

3. 自动化训练管道

class TrainingPipeline:
    """Simulated end-to-end training run wired to a monitor and checkpoint manager.

    Data loading and model setup are placeholders, and the training loop
    produces synthetic metrics; the class shows how the pieces fit together
    rather than training a real model.
    """

    def __init__(self, config):
        """Build the monitor and checkpoint manager from the config's directories."""
        self.config = config
        self.monitor = TrainingMonitor(config.logging_dir)
        self.ckpt_manager = CheckpointManager(config.output_dir)

    def prepare_data(self):
        """Prepare the training data (placeholder: returns empty splits)."""
        print("加载训练数据...")
        # A real implementation would load the tokenizer and datasets here.
        return {"train": None, "val": None}

    def setup_model(self):
        """Set up the model (placeholder: returns ``None``)."""
        print(f"加载模型: {self.config.model_name}")
        # A real implementation would load the model here.
        return None

    def train(self, steps_per_epoch=100, checkpoint_interval=50):
        """Run the simulated training loop.

        Args:
            steps_per_epoch: Number of simulated steps in each epoch.
            checkpoint_interval: Save a checkpoint every this many global
                steps; ``0`` or ``None`` disables checkpointing.
        """
        self.monitor.start()

        # Data preparation; the placeholder result is not consumed by the
        # simulated loop below.
        self.prepare_data()

        # Model setup
        model = self.setup_model()

        base_lr = self.config.learning_rate
        global_step = 0
        try:
            for epoch in range(self.config.num_epochs):
                print(f"\nEpoch {epoch + 1}/{self.config.num_epochs}")

                # Simulated training steps
                for step in range(steps_per_epoch):
                    global_step += 1

                    # Synthetic metrics: losses fall linearly, and the
                    # learning rate decays linearly within each epoch.
                    metrics = {
                        "train_loss": 2.5 - epoch * 0.3 - step * 0.001,
                        "val_loss": 2.3 - epoch * 0.2 - step * 0.0008,
                        "learning_rate": base_lr * (1 - step / steps_per_epoch)
                    }

                    # Hand the monitor its own copy so any bookkeeping keys it
                    # adds do not end up in the checkpoint metrics.
                    self.monitor.log_metrics(global_step, dict(metrics))

                    # The placeholder setup_model() returns None, and saving
                    # that would fail on None.save_pretrained, so checkpoints
                    # are skipped until a real model is present.
                    if (
                        model is not None
                        and checkpoint_interval
                        and global_step % checkpoint_interval == 0
                    ):
                        self.ckpt_manager.save_checkpoint(model, global_step, metrics)
        finally:
            # Persist whatever was logged, even if the loop was interrupted.
            self.monitor.save_metrics()
        print("\n训练完成!")

最佳实践

  1. 配置管理:使用版本控制管理训练配置
  2. 实验追踪:记录每次训练的完整指标
  3. 资源管理:监控GPU使用率和内存占用
  4. 失败恢复:支持从检查点恢复训练

总结

训练管道是LLM开发的核心组件。通过自动化的训练流程、完善的监控机制和灵活的配置管理,我们可以高效地训练出高质量的语言模型。