FSDP:PyTorch全分片数据并行
---
title: "FSDP:PyTorch全分片数据并行"
description: "掌握FSDP的核心原理和实践方法,实现高效的分布式模型训练"
tags: ["FSDP", "PyTorch", "分布式训练", "全分片"]
category: "llm"
icon: "🧠"
---
FSDP:PyTorch全分片数据并行
FSDP简介
FSDP(Fully Sharded Data Parallel)是PyTorch原生的全分片数据并行方案。它借鉴了DeepSpeed ZeRO-3的思想,将模型参数、梯度和优化器状态分片存储在不同的GPU上;只有在某一层执行前向/反向计算时,才通过all-gather临时聚合出该层的完整参数,计算完成后再释放,从而大幅降低单卡显存占用。
FSDP的核心优势:
- 原生PyTorch:无需额外依赖,与PyTorch生态无缝集成
- 灵活分片:支持多种分片粒度
- CPU Offload:支持将参数(及其梯度)卸载到CPU内存
- 混合精度:原生支持FP16/BF16训练
基本使用
简单包装
import torch
import torch.nn as nn
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
ShardingStrategy
)
# Define the model: a toy stack of square Linear layers.
class MyModel(nn.Module):
    """Stack of ``num_layers`` ``nn.Linear(hidden_size, hidden_size)`` layers.

    The layers are applied one after another in ``forward``. The defaults
    (12 layers of width 1024) reproduce the original example exactly.

    Args:
        num_layers: Number of Linear layers in the stack.
        hidden_size: Input and output width of every layer.
    """

    def __init__(self, num_layers=12, hidden_size=1024):
        super().__init__()
        # Kept as an nn.ModuleList named ``layers`` so FSDP auto-wrap policies
        # and parameter names ("layers.<i>.weight") stay the same as before.
        self.layers = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)
        )

    def forward(self, x):
        """Apply every layer to ``x`` in order and return the result.

        The last dimension of ``x`` must equal ``hidden_size``.
        """
        for layer in self.layers:
            x = layer(x)
        return x
# NOTE: FSDP needs an initialized default process group (e.g. launched with
# torchrun, then dist.init_process_group("nccl") and
# torch.cuda.set_device(local_rank)) before the model is wrapped.
model = MyModel()
# Wrap the model with FSDP
fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    ),
    # The model was built on the CPU; device_id tells FSDP to move it to this
    # process's current GPU while wrapping (otherwise it stays on the CPU).
    device_id=torch.cuda.current_device(),
)
分片策略
from torch.distributed.fsdp import ShardingStrategy
# The four strategies below are ALTERNATIVES: wrap the (still unwrapped)
# model exactly once with the strategy you choose. Running the lines one
# after another would wrap an FSDP model inside another FSDP model.

# 1. FULL_SHARD: shard parameters, gradients and optimizer state (like ZeRO-3)
model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)

# 2. SHARD_GRAD_OP: shard gradients and optimizer state; parameters are
#    gathered before forward and only resharded after backward (like ZeRO-2)
# model = FSDP(model, sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)

# 3. NO_SHARD: no sharding, everything is replicated (like DDP)
# model = FSDP(model, sharding_strategy=ShardingStrategy.NO_SHARD)

# 4. HYBRID_SHARD: FULL_SHARD within each node, replicate across nodes
# model = FSDP(model, sharding_strategy=ShardingStrategy.HYBRID_SHARD)
高级配置
自动包装策略
import functools

from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Both wrap-policy functions are called by FSDP as
# policy(module, recurse, nonwrapped_numel), so their configuration must be
# bound with functools.partial rather than by calling them directly.

# Transformer-layer based auto wrapping: every LlamaDecoderLayer becomes its
# own FSDP unit.
transformer_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

# Parameter-count based auto wrapping: submodules with at least 1M
# not-yet-wrapped parameters become separate FSDP units.
size_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=int(1e6),
)

# Pick one of the two policies.
auto_wrap_policy = transformer_policy  # or: size_policy

model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
)
CPU Offload
from torch.distributed.fsdp import CPUOffload
# Enable CPU offload. With offload_params=True, FSDP keeps parameters (and
# their gradients) on the CPU while they are not involved in computation,
# which means the optimizer step also runs on the CPU.
cpu_offload = CPUOffload(offload_params=True)
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=cpu_offload,
)
混合精度
from torch.distributed.fsdp import MixedPrecision
# Build two mixed-precision policies; each uses one dtype for parameters,
# gradient reduction and buffers:
#   fp16_policy -> torch.float16. FP16 training normally needs loss scaling;
#       with FSDP use torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler.
#   bf16_policy -> torch.bfloat16 (recommended on A100 / Ampere or newer GPUs).
fp16_policy, bf16_policy = (
    MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
    for dtype in (torch.float16, torch.bfloat16)
)

model = FSDP(model, mixed_precision=bf16_policy)
完整训练流程
import functools
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils.data import DataLoader, DistributedSampler
from transformers import LlamaForCausalLM, LlamaTokenizer

MODEL_NAME = "meta-llama/Llama-2-7b-hf"


def setup():
    """Initialize the NCCL process group and bind this process to its GPU.

    Must be launched by torchrun, which sets LOCAL_RANK and the rendezvous
    environment variables read by init_process_group.

    Returns:
        The local rank (GPU index on this node) of the current process.
    """
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return local_rank


def cleanup():
    """Destroy the default process group."""
    dist.destroy_process_group()


def build_dataset(tokenizer):
    """Build the training dataset (placeholder: replace with your own data).

    Each item must be a dict of tensors accepted by
    ``LlamaForCausalLM.forward`` (e.g. ``input_ids``, ``attention_mask``,
    ``labels``) so that ``model(**batch)`` returns an output with ``.loss``.
    """
    raise NotImplementedError("build_dataset() must be implemented for your data")


def train():
    """Fine-tune Llama-2-7B with FSDP (FULL_SHARD + BF16 mixed precision)."""
    local_rank = setup()
    try:
        # Every rank loads the full checkpoint into CPU memory before
        # sharding. bfloat16 halves that footprint, but it also makes the
        # sharded master weights and AdamW state bf16; load in float32 if
        # full-precision master weights are required.
        model = LlamaForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.bfloat16,
        )
        tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME)
        dataset = build_dataset(tokenizer)

        # Wrap each decoder layer as its own FSDP unit. FSDP calls the policy
        # as policy(module, recurse, nonwrapped_numel), so the layer class is
        # bound with functools.partial instead of calling the policy directly.
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={model.model.layers[0].__class__},
        )
        model = FSDP(
            model,
            auto_wrap_policy=auto_wrap_policy,
            mixed_precision=MixedPrecision(
                param_dtype=torch.bfloat16,
                reduce_dtype=torch.bfloat16,
                buffer_dtype=torch.bfloat16,
            ),
            device_id=local_rank,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
        )

        # The optimizer must be created AFTER wrapping so that it holds the
        # sharded FSDP parameters.
        optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

        sampler = DistributedSampler(dataset)
        dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)

        model.train()
        for epoch in range(3):
            # Reseed the sampler; without this every epoch sees the same order.
            sampler.set_epoch(epoch)
            for batch in dataloader:
                batch = {key: value.to(local_rank) for key, value in batch.items()}
                optimizer.zero_grad()
                outputs = model(**batch)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
    finally:
        # Tear down the process group even if training raised.
        cleanup()


if __name__ == "__main__":
    train()
启动训练
# Launch on node 0 (the master node, 192.168.1.1): 2 nodes x 4 processes each.
# On node 1 run the same command with --node_rank=1 (same master_addr/master_port).
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
--master_addr="192.168.1.1" --master_port=29500 \
train.py
检查点保存与加载
# save_state_dict / load_state_dict are not exported by torch.distributed.fsdp
# (importing them from there raises ImportError) and are not needed for the
# full-state-dict workflow below.
from torch.distributed.fsdp import (
    FullStateDictConfig,
    StateDictType,
)

# Save a checkpoint: gather the full (unsharded) state dict, offloaded to CPU
# and materialized only on rank 0.
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    # Gathering the full state dict is collective: every rank must call it.
    state_dict = model.state_dict()
if dist.get_rank() == 0:
    torch.save(state_dict, "checkpoint.pt")
# Ensure rank 0 has finished writing before any rank reads the file.
dist.barrier()

# Load a checkpoint: every rank reads the full state dict (the file must be
# visible to all ranks, e.g. on a shared filesystem); loading onto the CPU
# avoids materializing the whole model on one GPU.
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    state_dict = torch.load("checkpoint.pt", map_location="cpu")
    model.load_state_dict(state_dict)
性能优化
import functools

from torch.distributed.fsdp import BackwardPrefetch

# 1. Overlap communication with computation
model = FSDP(
    model,
    forward_prefetch=True,  # issue the next all-gather early during forward
    # Prefetch the next parameters before computing the current gradients
    # (most overlap, at the cost of higher peak memory).
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    use_orig_params=True,  # needed for torch.compile in step 3
)

# 2. Activation checkpointing
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# apply_activation_checkpointing has no checkpoint_impl argument; the
# implementation is selected by binding it into checkpoint_wrapper_fn.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=functools.partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    ),
    check_fn=lambda submodule: isinstance(submodule, LlamaDecoderLayer),
)

# 3. Use torch.compile
model = torch.compile(model)
与DeepSpeed对比
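| 特性 | FSDP | DeepSpeed |
|---|---|---|
| 依赖 | PyTorch原生 | 需要额外安装deepspeed包 |
| ZeRO Stage | 类似2/3(SHARD_GRAD_OP / FULL_SHARD) | 1/2/3 |
| CPU Offload | 支持 | 支持 |
| 易用性 | 高 | 中 |
| 生态集成 | PyTorch原生;Hugging Face Accelerate/Trainer支持 | Hugging Face Accelerate/Trainer支持 |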
FSDP作为PyTorch原生方案,适合追求简洁和标准化的分布式训练场景。