← 返回首页
🧠

FSDP:PyTorch全分片数据并行

📂 llm ⏱ 2 min 388 words

---
title: "FSDP:PyTorch全分片数据并行"
description: "掌握FSDP的核心原理和实践方法,实现高效的分布式模型训练"
tags: ["FSDP", "PyTorch", "分布式训练", "全分片"]
category: "llm"
icon: "🧠"
---

FSDP:PyTorch全分片数据并行

FSDP简介

FSDP(Fully Sharded Data Parallel)是PyTorch原生的全分片数据并行方案。它借鉴了DeepSpeed ZeRO-3的思想,将模型参数、梯度和优化器状态分片存储在不同的GPU上,在需要时动态聚合。

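FSDP的核心优势:

- 显存占用低:参数、梯度和优化器状态分片到各GPU,单卡只保存其中一部分
- PyTorch原生:无需安装额外依赖,直接包装现有的nn.Module
- 配置灵活:支持多种分片策略、CPU Offload和混合精度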

基本使用

简单包装

import torch
import torch.nn as nn
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy
)

# 定义模型
class MyModel(nn.Module):
    """Toy stack of equally sized linear layers used to demonstrate FSDP wrapping.

    Args:
        hidden_dim: Input and output width of every ``nn.Linear`` layer.
        num_layers: Number of linear layers applied in sequence.
    """

    def __init__(self, hidden_dim: int = 1024, num_layers: int = 12):
        super().__init__()
        # nn.ModuleList (rather than a plain list) registers each layer's
        # parameters with the parent module.
        self.layers = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)]
        )

    def forward(self, x):
        """Apply every layer in order and return the final activation."""
        for layer in self.layers:
            x = layer(x)
        return x

# Instantiate the plain (unsharded) module first.
model = MyModel()

# Use float16 for parameters, gradient reduction and buffers alike.
half_precision = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
)

# Wrap the module with FSDP using the full-sharding strategy.
fsdp_model = FSDP(
    model,
    mixed_precision=half_precision,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
)

分片策略

from torch.distributed.fsdp import ShardingStrategy

# NOTE(review): the four calls below look like alternatives shown side by
# side; a real script would presumably keep exactly one of them.

# 1. FULL_SHARD: full sharding (similar to ZeRO-3)
model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)

# 2. SHARD_GRAD_OP: shard gradients and optimizer state (similar to ZeRO-2)
model = FSDP(model, sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)

# 3. NO_SHARD: no sharding (similar to DDP)
model = FSDP(model, sharding_strategy=ShardingStrategy.NO_SHARD)

# 4. HYBRID_SHARD: hybrid sharding (full sharding within a node, replication across nodes)
model = FSDP(model, sharding_strategy=ShardingStrategy.HYBRID_SHARD)

高级配置

自动包装策略

import functools

from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    size_based_auto_wrap_policy
)
# LlamaDecoderLayer is not exported from the top-level `transformers` package;
# it is defined in the Llama modeling module.
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# FSDP calls the policy itself with (module, recurse, nonwrapped_numel), so the
# configuration arguments are bound up front with functools.partial instead of
# calling the policy function directly.

# Option A: wrap every transformer decoder layer as its own FSDP unit.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

# Option B: wrap by parameter count. This assignment replaces option A above.
auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=1_000_000,  # modules above 1M parameters get their own FSDP unit
)

model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD
)

CPU Offload

from torch.distributed.fsdp import CPUOffload

# Build the CPU-offload configuration (parameter offloading switched on),
# then hand it to the FSDP wrapper together with the full-sharding strategy.
cpu_offload = CPUOffload(offload_params=True)
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=cpu_offload,
)

混合精度

from torch.distributed.fsdp import MixedPrecision

# FP16 mixed precision: the three dtype settings all share torch.float16.
fp16_policy = MixedPrecision(
    **dict.fromkeys(("param_dtype", "reduce_dtype", "buffer_dtype"), torch.float16)
)

# BF16 mixed precision (the article recommends it for A100): same three
# settings, all torch.bfloat16.
bf16_policy = MixedPrecision(
    **dict.fromkeys(("param_dtype", "reduce_dtype", "buffer_dtype"), torch.bfloat16)
)

# Wrap the model with the BF16 policy.
model = FSDP(model, mixed_precision=bf16_policy)

完整训练流程

import os
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import LlamaForCausalLM, LlamaTokenizer

def setup():
    """Join the NCCL process group and select this process's CUDA device.

    The device index is read from the ``LOCAL_RANK`` environment variable.
    """
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

def cleanup():
    """Destroy the default process group if one is currently initialized.

    The guard makes the call safe in error paths where initialization never
    completed; without it ``destroy_process_group`` fails on an
    uninitialized group.
    """
    if dist.is_initialized():
        dist.destroy_process_group()

def train():
    """Train Llama-2-7B under FSDP for three epochs.

    Initializes the process group, loads the model in bfloat16, wraps it with
    FSDP (one unit per decoder layer, bf16 mixed precision, FULL_SHARD) and
    runs an AdamW loop over ``dataset``. The process group is torn down even
    if training raises.

    NOTE(review): ``dataset`` is not defined in this function; it is presumably
    a module-level dataset whose batches collate into a mapping of keyword
    arguments for the model -- confirm before running.
    """
    # Imported locally so the function does not depend on names that the
    # surrounding imports do not provide.
    import functools

    from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

    setup()
    try:
        local_rank = int(os.environ["LOCAL_RANK"])
        device = torch.device("cuda", local_rank)

        # Load the model
        model = LlamaForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf",
            torch_dtype=torch.bfloat16
        )

        # FSDP calls the policy with (module, recurse, nonwrapped_numel), so
        # the layer class is bound with functools.partial rather than by
        # calling the policy function directly.
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={model.model.layers[0].__class__},
        )

        model = FSDP(
            model,
            auto_wrap_policy=auto_wrap_policy,
            mixed_precision=MixedPrecision(
                param_dtype=torch.bfloat16,
                reduce_dtype=torch.bfloat16,
                buffer_dtype=torch.bfloat16
            ),
            device_id=local_rank,
            sharding_strategy=ShardingStrategy.FULL_SHARD
        )

        # Optimizer (created after wrapping so it sees the FSDP parameters)
        optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

        # Data loading
        sampler = DistributedSampler(dataset)
        dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)

        # Training loop
        model.train()
        for epoch in range(3):
            # Without set_epoch the sampler reuses the same shuffling order
            # in every epoch.
            sampler.set_epoch(epoch)
            for batch in dataloader:
                # The model lives on the GPU, so tensor inputs must be moved
                # there too; non-tensor values are passed through unchanged.
                batch = {
                    key: value.to(device) if torch.is_tensor(value) else value
                    for key, value in batch.items()
                }
                optimizer.zero_grad()
                outputs = model(**batch)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
    finally:
        cleanup()

# Script entry point: run training only when this file is executed directly.
if __name__ == "__main__":
    train()

启动训练

# Launch train.py with 4 processes on this node, as node rank 0 of a 2-node job,
# using 192.168.1.1:29500 as the master address and port.
# NOTE(review): presumably the second node runs the same command with
# --node_rank=1 -- confirm against the torchrun documentation.
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
    --master_addr="192.168.1.1" --master_port=29500 \
    train.py

检查点保存与加载

# `save_state_dict` / `load_state_dict` are not exported by
# torch.distributed.fsdp (importing them raises ImportError) and were unused
# here, so only the names this snippet needs are imported.
from torch.distributed.fsdp import (
    FullStateDictConfig,
    StateDictType,
)

# Save a checkpoint: request the full state dict, offloaded to CPU and
# populated on rank 0 only.
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)

with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    # state_dict() is called on every rank; only rank 0 writes the file.
    state_dict = model.state_dict()
    if dist.get_rank() == 0:
        torch.save(state_dict, "checkpoint.pt")

# Load a checkpoint: read the full state dict from disk and load it through
# the FSDP-wrapped model.
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    state_dict = torch.load("checkpoint.pt")
    model.load_state_dict(state_dict)

性能优化

# 1. Enable communication/computation overlap
# BackwardPrefetch was referenced without being imported.
from torch.distributed.fsdp import BackwardPrefetch

model = FSDP(
    model,
    forward_prefetch=True,  # prefetch during the forward pass
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE  # prefetch during the backward pass
)

# 2. Activation checkpointing
import functools

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing
)
# LlamaDecoderLayer is defined in the Llama modeling module, not exported from
# the top-level `transformers` package.
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# apply_activation_checkpointing has no `checkpoint_impl` parameter; the
# implementation is selected by binding it into the wrapper function that is
# passed as `checkpoint_wrapper_fn`.
non_reentrant_wrapper = functools.partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=non_reentrant_wrapper,
    # Only decoder layers are wrapped for checkpointing.
    check_fn=lambda submodule: isinstance(submodule, LlamaDecoderLayer)
)

# 3. Use torch.compile: replace `model` with its compiled wrapper.
model = torch.compile(model)

与DeepSpeed对比

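| 特性 | FSDP | DeepSpeed |
| --- | --- | --- |
| 依赖 | PyTorch原生 | 需要安装 |
| ZeRO Stage | 类似ZeRO-2/3(SHARD_GRAD_OP / FULL_SHARD) | 1/2/3 |
| CPU Offload | 支持 | 支持 |
| 易用性 | | |
| 生态集成 | PyTorch | Hugging Face |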

FSDP作为PyTorch原生方案,适合追求简洁和标准化的分布式训练场景。