FSDP2
fsdp 详细介绍了 FSDP1 的主要设计和优化,随着模型越来越大,社区发现 FSDP1 的 FlatParameter 限制了进一步优化训练性能。为了解决这个问题,社区提出了下一代的 FSDP: Per-Parameter-Sharding FSDP,也就是 FSDP21。
相比于 FSDP1,FSDP2 主要有以下不同2:
- 相比于 FSDP1 的
FlatParameter的 sharding 设计,FSDP2 基于 DTensor3 实现了Per-Parameter Sharding,从而实现了更加简单的sharding representation,并且同时实现了和 FSDP1 类似的性能。 - 为了解决 FSDP 中的
multi-stream memory usage的问题,FSDP2 不再使用 FSDP1 中的torch.Tensor.record_stream,而是基于 CUDA Event 实现了更加精细的 GPU 显存管理,从而实现了deterministic的 memory usage,并且不再需要通过 FSDP1 中的limit_all_gathers=True来 block CPU 操作。 - FSDP2 简化了 API 接口,比如不再直接支持 full state dicts,而是可以通过
DTensor的DTensor.full_tensor()API 来 reshard 分片的 state dict 从而获得 full state dicts
FSDP1 中 FlatParameter 的限制
上述最主要的一个变化是,现在的 FSDP2 的 sharding 是基于每个 parameter 的,并且深度和目前 torch 原生系统结合,从而解决了之前 FlatParameter 设计下的限制:
- Flexible fp8 all-gather4
原来的 FSDP1 可以支持 fp8 compute,但是不能灵活的支持 fp8 all-gather 5。这是因为 FlatParameter 里面是将多个参数聚集到一起,并且摊平到一维做分片,这里是要求同一个 FlatParameter 里面的参数数据类型都一致。因此 FSDP1 不支持 fp8 和 non-fp8 的参数在一起做 all-gather。
而 FSDP2 因为是基于每个参数做 sharding,是可以灵活支持 fp8 all-gather 的4,后面会详细介绍。
- Flexible frozen parameters
在 Rethinking Pytorch Fully Sharded Data Parallel6 这篇文章中提到,
- Communication-free sharded state dict
- Future communication optimization in the compiler
FSDP2 API 演进与设计变化
FSDP2 为了简化构建逻辑,不再与 FSDP1 保持 API 兼容,新的 API 如下所示:
|
|
|
|
具体使用如下:
|
|
对比 FSDP1 和 FSDP2 的 API 变化:
| FSDP1 | FSDP2 |
|---|---|
module |
module |
process_group/device_mesh |
mesh |
sharding_strategy |
reshard_after_forward |
cpu_offload |
offload_policy |
auto_wrap_policy |
removed |
backward_prefetch |
removed |
mixed_precision |
mp_policy |
ignored_modules/ignored_states |
not yet implemented |
param_init_fn |
removed |
device_id |
removed |
sync_module_states |
removed |
forward_prefetch |
not yet implemented |
limit_all_gathers |
removed |
use_orig_params |
removed |
相比于 FSDP1 原来的 FlatParameter 和 FlatParamHandle 的抽象,FSDP2 主要的设计变化如下:
- 放弃
FlatParameter的设计,而是每个将每个sharded parameter表示成在dim-0维度 sharded 的DTensor,也就是实现了Per-Parameter-Sharding - FSDP2 不再通过
recordStream来实现不同 stream 之间的 tensor 同步,而是通过 event 来显式的手动完成同步,从而实现了更低的和 deterministic 的 GPU memory 占用,并且也不再需要 CPU 同步 (即 FSDP1 中的 all-gather rate limiter)
FSDPParam 与 FSDPParamGroup 抽象
对应于第一点变化,FSDP2 重新设计了以下组件:
FSDPModule:经过full_shard(module)之后每个 module 会实现动态替换,比如假设原来的 module 是Transformer,那么 FSDP2 会新创建一个FSDPTransformer的类。这个类继承自FSDPModule和Transformer,并且设置module.__class__为FSDPTransformerFSDPState:经过full_shard(module)之后,每个 module 会添加一个FSDPState的对象,并且可以通过full_shard.state(module)来访问。它用于维护与 FSDP 相关的状态,并且是通过@contract来 embed 到 module 的FSDPParamGroup:一个 ParamGroup 里面的所有参数用同一个通信组FSDPParam:对应于通信组里面的一个参数,包含了以 DTensor 形式存在的sharded_param
对比下 FSDP1:
fully_shard+FSDPState加起来类似于原来的FullyShardedDataParallelFSDPParamGroup类似于原来的FlatParamHandleFSDPParam类似于FlatParamter,但是现在只有一个参数
以 TorchTitan 7 训练的 LLaMA3 8B 为例,在 full_shard 操作之前,模型如下所示:
|
|
经过 FSDP2 的 full_shard 之后,32 个 TransformerBlock 被设置为 FSDPTransformerBlock,root module 从 Transformer 变为 FSDPTransformer
|
|
实际前向和反向过程中对应的 FSDPParamGroup 中的 FSDPParam:
|
|
Sharding 策略与 ZeRO 映射
在 FSDP1 中,通过 sharding_strategy 来实现不同的数据并行策略,包括 ZeRO-3/ZeRO-2 和 HSDP 等,在 FSDP2,通过 reshard_after_forward 和 mesh 来实现对不同数据并行策略的控制:
| FSDP 1 | FSDP 2 | DeepSpeed |
|---|---|---|
1 process_group + FULL_SHARD |
1D mesh + reshard_after_forward=True |
ZeRO-3 |
1 process_group + SHARD_GRAD_OP |
1D mesh + reshard_after_forward=False |
ZeRO-2 |
2 process_group s/2D device_mesh + HYBRID_SHARD |
2D mesh + reshard_after_forward=True |
MiCS |
2 process_group s/2D device_mesh + _HYBRID_SHARD_ZERO2 |
2D mesh + reshard_after_forward=False |
- |
| - | 1D/2D mesh + reshard_after_forward=8 (int) |
ZeRO++ hpZ8 |
更原生兼容 torch API
Existing FSDP supports three kinds of state dicts: full, sharded, and local. We define a “clean” fully-qualified name (FQN) to be one without any nn.Module wrapper prefixes (e.g. _fsdp_wrapped_module. for FSDP or module. for DDP).
- A full state dict maps clean FQN to unsharded
torch.Tensor. Since this may use too much GPU memory, it offersrank0_only: boolandoffload_to_cpu: booloptions. This requires all-gathering parameters. - A sharded state dict maps clean FQN to sharded
DTensor. For existing FSDP, this requires all-gathering parameters and resharding to per-parameter sharding. - A local state dict maps clean FQN to the sharded
FlatParameteras aShardedTensor. This does not require communication but cannot be loaded into a different distributed training setup.
For per-parameter-sharding FSDP, we only support sharded state dict.
- Getting a full state dict can be implemented as a post-processing step to saving the sharded state dict.
- Sharded and local state dicts are equivalent since the training and checkpointing representation match
| FSDP1 | FSDP2 |
|---|---|
model.state_dict(): full state dict |
model.state_dict(): sharded state dict (no communication) |
optim.state_dict(): local state dict |
optim.state_dict(): sharded state dict (no communication) |
summon_full_params() |
use DTensor APIs like full_tensor() |
FSDP.clip_grad_norm_() |
nn.utils.clip_grad_norm_() |
ShardedGradScaler |
amp.grad_scaler.GradScaler |
Meta-Device Initialization
在 FSDP1 的时候,meta device 上初始化方法如下:
|
|
FSDP1 需要
FSDP1 requires either
reset_parametersorparam_init_fnto materialize a module onto GPU immediately before sharding. To do this correctly without re-initializing any tensors requires care and can be unwieldy. However, FSDP2 allows materializing tensors onto GPU after sharding (taking advantage ofDTensorand a newswap_tensorspath fornn.Module._applymethods).
FSDP2 supports a new meta-device initialization flow that does not require materializing a module on GPU before sharding it, removing the need for
param_init_fn
对于 FSDP2
|
|
FSDP2 计算与通信流程
和 FSDP1 一样,FSDP2 在模型计算时需要实现:
- FWD/BWD 之前,需要通过
all-gather来 unshard 参数,并且在计算之后 reshardc释放掉 - 在 BWD 计算出 gradient 之后,需要通过
reduce-scatter来平均梯度
具体这一目标的是通过 Module 对应的 hook 以及具体 FSDPParamGroup 和 FSDPState 操作来实现:
- Root pre-forward:通过
_fsdp_state.py中的register_forward_pre_hook实现- 完成 lazy initialization, stream synchronization, move inputs to GPU 等操作
- 对每一个 state/param group:
- State pre-forward (
register_forward_pre_hook()):cast forward inputs if needed - Parameter group pre-forward: 执行 unshard 参数
- Parameter group post-forward: 执行 reshard 参数
- State post-forward (
register_forward_hook()): cast forward outputs if need
- State pre-forward (
- Root post-forward
- Root pre-backward (
register_hook): queue back - 对每一个 state/param group:
- Pre-backward (
register_hook):unshard 参数,prefetch parameters if needed - Post-backward (
autograd.Function.backward):reshard 参数,reduce-scatter 梯度
- Pre-backward (
- Root post-backward final callback:
_execution_engine.queue_callback()- 完成 finalize parameter group backwards,stream/event 同步,数据结构重置等操作
在简单看了上面的流程之后,和之前 FSDP1 中讨论的问题一样,在 FSDPParam 的设计下,我们也需要考虑内存空间管理的问题:
- FSDP2 基于 DTensor 实现的 Per-Parameter Sharding 是如何实现的,在 forward 中 all-gather 的参数重建是如何实现的?
- 在 Backward 的时候,做完 Unshard 之后,FSDP2 的 Backward 是如何和原有的 AutoGrad 适配起来,从而计算出 Grads,这些 Grads 的空间又是如何管理的,怎么做 Grads 同步呢?
对于这几个问题,我们继续深入到代码,进一步理解 FSDP2。
FSDPParam 实现
和 FSDP1 一样,为了更好的理解计算和通信流的同步,需要先了解下有哪些 stream,对比之前的 FSDP1,这里有一个特殊的 all_gather_copy_in_stream
|
|
首先我们通过前面的 LLaMA3-8B 为例,看 FSDPParam 的定义,其中有几个关键概念:
- Original Parameter:就是传递给 FSDP
full_shard之前的参数,比如 LLaMA3-8B 中一个 TransformerBlock 的参数wq/wk/wv/wo/w1/w2/w3/attn_norm/ffn_norm这些
|
|
- Sharded Parameter:将原来的参数转换为 DTensor,并且在 dim-0 做 shard (
Shard(0) Placement),对应于FSDPParam.sharded_param。
|
|
这里的将 original parameter 转换为 sharded parameter 的逻辑,是现在 FSDPParam 的 _init_sharded_param() 9 中,经过初始化,每个 FSDPParam 中的 _sharded_param 和 _sharded_local_tensor 即对应着 per-parameter sharding 的 DTensor。
|
|
- All-gather input:
- 对应于传递给 all-gather 的 input 的
torch.Tensor
- 对应于传递给 all-gather 的 input 的
- All-gather output:
- 对应于 all-gather output 产生的
torch.Tensor
- 对应于 all-gather output 产生的
实际的操作中,每次执行 all_gather 通信前,会调用 all_gather_copy_in 10 来为 all-gather output 申请一段连续的 tensor 空间,对应于下面代码里的 all_gather_input_numel * world_size:
all_gather_input_numel是一个FSDPPramGroup里面的所有sharded_param参数对应的numel总和- 输出的
all_gather_input是all_gather_output的一个 view 视图,并且是将已有的本地 tensor copy-in 到对应的 offset
|
|
- Unsharded parameter:
- 对应于
FSDPParam._unsharded_param,是将 all-gather output 经过转换得到的
- 对应于
总结一下:
|
|
再看一眼 FSDPParam 的定义:
|
|
FSDP2 集合通信设计
前面提到,FSDP 通信效率优化中很重要的一点是 batched collective communication。
对于 FSDP1,是通过 FlatParameter 的设计,将原来的多个 parameter 整合起来,摊平到一维,由 FlatParameter 管理空间和通信,然后在 FWD/BWD 计算的时候通过 view 视图来对各自参数进行计算。
对于 FSDP2 的 per-parameter sharding,可以有两种选择:
copy-in/copy-outNCCL group coalescing11
目前 (torch 2.6.0 版本) 发现 copy-in/copy-out 相比于 NCCL group coalescing 的通信效率更高一些。而且尽管有 copy,它的实现和当前 FSDP1 的集合通信效率相同,因此目前的实现是前者。
以 all-gather 为例,下图中左图即为 FlatParameter,模型拥有 3 个参数,FlatParameter 将其摊平到一维,每个 rank 拥有 FlatParameter 的 shard。在 FWD/BWD 通过 all-gather 恢复全部 FlatParameter 参数,通过 view 来 unflatten 重新计算。
FlatParameter 的设计,如上图中右图所示,每个 rank 仍然是拥有 3 个参数,只是这 3 个参数已经是 DTensor,并且在 FSDP 的维度做了 sharding。在 FWD/BWD 计算的时候,同样经过 all-gather,因为知道每个 parameter 的 shape 和 size,因此可以通过 split with sizes copy 直接重建参数,稍后会介绍其实现。
对于 backward 之后的 reduce-scatter 平均梯度也是类似。FSDP1 中计算出来的 grad 在 FlatParameter 的一维空间中,经过 Flat-parameter sharding 每个 rank 平均获得其中的一部分。对于 FSDP2,计算完梯度之后,每个 rank 的每个 Parameter 拥有自己的 grad,通过 chunk-cat 后,展平到 2 rank sharding 的空间,然后通过 reduce-scatter 平均梯度后,每个 rank 的每个参数的 grad 都是平均和 sharding 后的 grad。
这里理解一下为什么需要 copy-in/copy-out。
All-Gather
回顾下 all-gather 的前后操作:
|
|
all-gather 的具体流程为:
- 为
all-gather output申请一段连续的 tensor 空间 - 将
all-gather input设置成all-gather output的 view 视图 - 将每个
sharded_param的空间按照合适的 offset 来 copy-in 到每个 rank 的all-gather input- 这里 all-gather input 就会有一个 interleaved layout,如上图所示,
(p1 shard1, p2 shard1, p3 shard1, p1 shard2, p2 shard2, p3 shard2)
- 这里 all-gather input 就会有一个 interleaved layout,如上图所示,
- 执行
all-gather通信 - 将 tensor 从
all-gather output中 copy-out 到unsharded_param的空间,并且 make them contiguous- 这里通过 viewing 每个参数就会变成
(p1 shard1, p1 shard2), …,(p3 shard1, p3 shard2),然后就可以通过torch.cat聚集起来
- 这里通过 viewing 每个参数就会变成
这里的 copy-in 和 all-gather 即对应于前面的代码 for_all_gather,而 copy-out 即对应于 foreach_all_gather_copy_out 调用 split_with_sizes_copy 将结果从前面申请的连续空间 all-gather output copy-out 到每个 FSDPParam 的 all_gather_outputs。
|
|
实际执行完 copy-out 之后,即会初始化 _unsharded_param 并且将当前 Module 的参数设置为 unsharded 之后的参数
|
|
对应的 init_unsharded_param,即是前面的 unsharded-parameter = post-all-gather-transform(all-gather-outputs) 的变换:
|
|
ReduceScatter
reduce-scatter 的流程为:
- 为
reduce-scatter input申请一段连续的 tensor 空间 - 将每个
unsharded_param.grad按照合适的 offset 来 copy-in 到reduce-scatter input- 类似的,这里 reduce-scatter input 就会有一个 interleaved layout,如上图所示,
(g1 shard1, g2 shard1, g3 shard1, g1 shard2, g2 shard2, g3 shard2) - 这一步通过
torch._foreach_copy_来实现
- 类似的,这里 reduce-scatter input 就会有一个 interleaved layout,如上图所示,
- 为
reduce-scatter output申请一段单独的 tensor - 执行
reduce-scatter - View each new sharded parameter gradient from the
reduce-scatter outputand accumulate withsharded_param.gradas needed
FSDP2 避免 recordStream
和 FSDP1 一样,FSDP2 也申请了多个 stream 分别用于计算/通信/copy 等操作。当一个 tensor 在一个 stream 中 allocate 并且在另一个 stream 中被使用,Pytorch 需要在多个 stream 之间进行同步,FSDP1 中用的是 recordStream 12。
recordStream 将多 stream 间同步的工作交给 CUDA Caching Allocator,只有当 tensor 不会被使用的时候,tensor 对应的空间才会被释放和重用(但是这一般距离在 CPU free 操作很晚)。
FSDP & CUDACachingAllocator13 这篇文章详细介绍了 FSDP 设计中如何通过 record stream 和 rate limiter 来实现多个 stream 的同步。
对于上图中,因为 record_stream 的引入,layer i 的空间什么时候释放取决于 CCA 实际的内存管理,因此 allgather i+2 的操作和 layer i 内存空间的释放之间的顺序是 non-deterministic 的。假设 allgather i+2 的时候 layer i 的空间还没有被释放,这就会导致 active memory usage 的显著增加,如下图所示。等 layer i 空间释放之后,memory usage 的 spike 就会下降。
实际上,为了解决这个问题(以及过度申请 memory),FSDP1 中引入了 rate limiter 来限制 all-gather 的多次 prefetch。如下图所示,实际上只有当 layer i 结束计算,在 _reshard 之后,才可以发起 allgather i+2 的。因此这个时候可以保证 allgather i+2 能够复用 layer i 释放的 memory。
通过 rate limiter,FSDP1 保证了 forward 过程中不会出现之前提到的 memory spikes,如下图所示。然而,backward 的显存 spikes 依然存在。
一个可能的原因是,rate limiter 主要是面向 FSDP 集合通信设计的,在反向计算中,autograd engine 可能会在 layer i 没被释放前执行一些 malloc 操作,进而导致 memory spikes。
为了解决 FSDP1 中因为 recordStream 带来的内存管理的 non-determinism,FSDP2 选择使用了 explicit synchronization 来手动管理多个 stream 之间的依赖与同步。
record_stream introduces nondeterminism, and the nondeterminism puts the onus on the CPU to deal with it. That’s where the CPU sync comes in–the CPU is verifying the del situation before allocating new memory…
these two concepts are decoupled, and the CPU could just not “care”! By replacing the record_stream calls with stream-stream syncs, the CPU can trust that the streams will wait on each other properly, so the CPU can be carefree and schedule whatever it wants without needing to sync. In a sense, the responsibility has shifted from the CPU to the streams.
In the specific case for FSDP, removing the need for a CPU sync would require addressing the nondeterminism introduced by record_stream, and the most straightforward way to do that is to remove/replace record_stream calls.
- https://github.com/pytorch/pytorch/issues/147168
- https://github.com/pytorch/torchtitan/pull/737/files
|
|
ref
DTensor
- SimpleFSDP?
- https://christianjmills.com/posts/mastering-llms-course-notes/conference-talk-012/
- AutoFSDP?
-
RFC: Per-Parameter-Sharding FSDP, https://github.com/pytorch/pytorch/issues/114299 ↩︎
-
https://pytorch.org/docs/stable/distributed.fsdp.fully_shard.html ↩︎
-
PyTorch DTensor (Prototype Release), https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/README.md ↩︎
-
Enabling Float8 All-Gather in FSDP2, 2024-08-08, https://dev-discuss.pytorch.org/t/enabling-float8-all-gather-in-fsdp2/2359 ↩︎ ↩︎
-
https://github.com/pytorch/pytorch/pull/114733#issuecomment-1832374773 ↩︎
-
Rethinking PyTorch Fully Sharded Data Parallel (FSDP) from First Principles, 2023-01-20, https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019/9 ↩︎
-
ZeRO++: Extremely Efficient Collective Communication for Giant Model Training, https://arxiv.org/pdf/2306.10209 ↩︎
-
https://github.com/pytorch/pytorch/blob/v2.6.0/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L255 ↩︎
-
https://github.com/pytorch/pytorch/blob/v2.6.0/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L153 ↩︎
-
https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/groups.html ↩︎
-
https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html ↩︎
-
https://dev-discuss.pytorch.org/t/fsdp-cudacachingallocator-an-outsider-newb-perspective/1486 ↩︎
-
No backlinks found.