fsdp 详细介绍了 FSDP1 的主要设计和优化,随着模型越来越大,社区发现 FSDP1 的 FlatParameter 限制了进一步优化训练性能。为了解决这个问题,社区提出了下一代的 FSDP: Per-Parameter-Sharding FSDP,也就是 FSDP21

相比于 FSDP1,FSDP2 主要有以下不同2

  1. 相比于 FSDP1 的 FlatParameter 的 sharding 设计,FSDP2 基于 DTensor3 实现了 Per-Parameter Sharding,从而实现了更加简单的 sharding representation,并且同时实现了和 FSDP1 类似的性能。
  2. 为了解决 FSDP 中的 multi-stream memory usage 的问题,FSDP2 不再使用 FSDP1 中的 torch.Tensor.record_stream,而是基于 CUDA Event 实现了更加精细的 GPU 显存管理,从而实现了 deterministic 的 memory usage,并且不再需要通过 FSDP1 中的 limit_all_gathers=True 来 block CPU 操作。
  3. FSDP2 简化了 API 接口,比如不再直接支持 full state dicts,而是可以通过 DTensorDTensor.full_tensor() API 来 reshard 分片的 state dict 从而获得 full state dicts

FSDP2
FSDP2

FSDP1 中 FlatParameter 的限制

上述最主要的一个变化是,现在的 FSDP2 的 sharding 是基于每个 parameter 的,并且深度和目前 torch 原生系统结合,从而解决了之前 FlatParameter 设计下的限制:

  1. Flexible fp8 all-gather4

原来的 FSDP1 可以支持 fp8 compute,但是不能灵活的支持 fp8 all-gather 5。这是因为 FlatParameter 里面是将多个参数聚集到一起,并且摊平到一维做分片,这里是要求同一个 FlatParameter 里面的参数数据类型都一致。因此 FSDP1 不支持 fp8 和 non-fp8 的参数在一起做 all-gather。

而 FSDP2 因为是基于每个参数做 sharding,是可以灵活支持 fp8 all-gather 的4,后面会详细介绍。

  1. Flexible frozen parameters

Rethinking Pytorch Fully Sharded Data Parallel6 这篇文章中提到,

  1. Communication-free sharded state dict
  1. Future communication optimization in the compiler

FSDP2 API 演进与设计变化

FSDP2 为了简化构建逻辑,不再与 FSDP1 保持 API 兼容,新的 API 如下所示:

1
2
3
4
5
6
7
8
9
@contract(state_cls=FSDPState)
def fully_shard(
  module: nn.Module,
  *,
  mesh: Optional[DeviceMesh] = None,
  reshard_after_forward: Union[bool, int] = True,
  mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
  offload_policy: OffloadPolicy = OffloadPolicy(),
) -> nn.Module:  # return value only used by `contract` for checks
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
@dataclass
class MixedPrecisionPolicy:
  param_dtype: Optional[torch.dtype] = None
  reduce_dtype: Optional[torch.dtype] = None
  output_dtype: Optional[torch.dtype] = None
  cast_forward_inputs: bool = True

@dataclass
class OffloadPolicy:
  offload_type: Optional[str] = None  # "cpu"

具体使用如下:

1
2
3
4
for module in model.modules();
  if isinstance(module, TransformerBlock):
    fully_shard(module, **fsdp_kwargs)
fully_shard(model, **fsdp_kwargs)

对比 FSDP1 和 FSDP2 的 API 变化:

FSDP1 FSDP2
module module
process_group/device_mesh mesh
sharding_strategy reshard_after_forward
cpu_offload offload_policy
auto_wrap_policy removed
backward_prefetch removed
mixed_precision mp_policy
ignored_modules/ignored_states not yet implemented
param_init_fn removed
device_id removed
sync_module_states removed
forward_prefetch not yet implemented
limit_all_gathers removed
use_orig_params removed

相比于 FSDP1 原来的 FlatParameterFlatParamHandle 的抽象,FSDP2 主要的设计变化如下:

  • 放弃 FlatParameter 的设计,而是每个将每个 sharded parameter 表示成在 dim-0 维度 sharded 的 DTensor,也就是实现了 Per-Parameter-Sharding
  • FSDP2 不再通过 recordStream 来实现不同 stream 之间的 tensor 同步,而是通过 event 来显式的手动完成同步,从而实现了更低的和 deterministic 的 GPU memory 占用,并且也不再需要 CPU 同步 (即 FSDP1 中的 all-gather rate limiter)

FSDP 2
FSDP 2

FSDPParam 与 FSDPParamGroup 抽象

对应于第一点变化,FSDP2 重新设计了以下组件:

  • FSDPModule:经过 full_shard(module) 之后每个 module 会实现动态替换,比如假设原来的 module 是 Transformer,那么 FSDP2 会新创建一个 FSDPTransformer 的类。这个类继承自 FSDPModuleTransformer,并且设置 module.__class__FSDPTransformer
  • FSDPState:经过 full_shard(module) 之后,每个 module 会添加一个 FSDPState 的对象,并且可以通过 full_shard.state(module) 来访问。它用于维护与 FSDP 相关的状态,并且是通过 @contract 来 embed 到 module 的
  • FSDPParamGroup:一个 ParamGroup 里面的所有参数用同一个通信组
  • FSDPParam:对应于通信组里面的一个参数,包含了以 DTensor 形式存在的 sharded_param

对比下 FSDP1:

  • fully_shard + FSDPState 加起来类似于原来的 FullyShardedDataParallel
  • FSDPParamGroup 类似于原来的 FlatParamHandle
  • FSDPParam 类似于 FlatParamter,但是现在只有一个参数

以 TorchTitan 7 训练的 LLaMA3 8B 为例,在 full_shard 操作之前,模型如下所示:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
Transformer(
  (tok_embeddings): Embedding(128256, 4096)
  (layers): ModuleDict(
    (0): TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=4096, out_features=4096, bias=False)
        (wk): Linear(in_features=4096, out_features=1024, bias=False)
        (wv): Linear(in_features=4096, out_features=1024, bias=False)
        (wo): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=4096, out_features=14336, bias=False)
        (w2): Linear(in_features=14336, out_features=4096, bias=False)
        (w3): Linear(in_features=4096, out_features=14336, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
    # ...
    (31): TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=4096, out_features=4096, bias=False)
        (wk): Linear(in_features=4096, out_features=1024, bias=False)
        (wv): Linear(in_features=4096, out_features=1024, bias=False)
        (wo): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=4096, out_features=14336, bias=False)
        (w2): Linear(in_features=14336, out_features=4096, bias=False)
        (w3): Linear(in_features=4096, out_features=14336, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=4096, out_features=128256, bias=False)
)

经过 FSDP2 的 full_shard 之后,32 个 TransformerBlock 被设置为 FSDPTransformerBlock,root module 从 Transformer 变为 FSDPTransformer

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
FSDPTransformer(
  (tok_embeddings): Embedding(128256, 4096)
  (layers): ModuleDict(
    (0): FSDPTransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=4096, out_features=4096, bias=False)
        (wk): Linear(in_features=4096, out_features=1024, bias=False)
        (wv): Linear(in_features=4096, out_features=1024, bias=False)
        (wo): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=4096, out_features=14336, bias=False)
        (w2): Linear(in_features=14336, out_features=4096, bias=False)
        (w3): Linear(in_features=4096, out_features=14336, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
    # ...
    (31): FSDPTransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=4096, out_features=4096, bias=False)
        (wk): Linear(in_features=4096, out_features=1024, bias=False)
        (wv): Linear(in_features=4096, out_features=1024, bias=False)
        (wo): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=4096, out_features=14336, bias=False)
        (w2): Linear(in_features=14336, out_features=4096, bias=False)
        (w3): Linear(in_features=4096, out_features=14336, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=4096, out_features=128256, bias=False)
)

实际前向和反向过程中对应的 FSDPParamGroup 中的 FSDPParam:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
fsdp params: [FSDPParam(fqn=tok_embeddings.weight, orig_size=torch.Size([128256, 4096])), FSDPParam(fqn=norm.weight, orig_size=torch.Size([4096])), FSDPParam(fqn=output.weight, orig_size=torch.Size([128256, 4096]))]
fsdp params: [FSDPParam(fqn=layers.0.attention.wq.weight, orig_size=torch.Size([4096, 4096])), FSDPParam(fqn=layers.0.attention.wk.weight, orig_size=torch.Size([1024, 4096])), FSDPParam(fqn=layers.0.attention.wv.weight, orig_size=torch.Size([1024, 4096])), FSDPParam(fqn=layers.0.attention.wo.weight, orig_size=torch.Size([4096, 4096])), FSDPParam(fqn=layers.0.feed_forward.w1.weight, orig_size=torch.Size([14336, 4096])), FSDPParam(fqn=layers.0.feed_forward.w2.weight, orig_size=torch.Size([4096, 14336])), FSDPParam(fqn=layers.0.feed_forward.w3.weight, orig_size=torch.Size([14336, 4096])), FSDPParam(fqn=layers.0.attention_norm.weight, orig_size=torch.Size([4096])), FSDPParam(fqn=layers.0.ffn_norm.weight, orig_size=torch.Size([4096]))]
fsdp params: [FSDPParam(fqn=layers.1.attention.wq.weight, orig_size=torch.Size([4096, 4096])), FSDPParam(fqn=layers.1.attention.wk.weight, orig_size=torch.Size([1024, 4096])), FSDPParam(fqn=layers.1.attention.wv.weight, orig_size=torch.Size([1024, 4096])), FSDPParam(fqn=layers.1.attention.wo.weight, orig_size=torch.Size([4096, 4096])), FSDPParam(fqn=layers.1.feed_forward.w1.weight, orig_size=torch.Size([14336, 4096])), FSDPParam(fqn=layers.1.feed_forward.w2.weight, orig_size=torch.Size([4096, 14336])), FSDPParam(fqn=layers.1.feed_forward.w3.weight, orig_size=torch.Size([14336, 4096])), FSDPParam(fqn=layers.1.attention_norm.weight, orig_size=torch.Size([4096])), FSDPParam(fqn=layers.1.ffn_norm.weight, orig_size=torch.Size([4096]))]
# ...
fsdp params: [FSDPParam(fqn=layers.30.attention.wq.weight, orig_size=torch.Size([4096, 4096])), FSDPParam(fqn=layers.30.attention.wk.weight, orig_size=torch.Size([1024, 4096])), FSDPParam(fqn=layers.30.attention.wv.weight, orig_size=torch.Size([1024, 4096])), FSDPParam(fqn=layers.30.attention.wo.weight, orig_size=torch.Size([4096, 4096])), FSDPParam(fqn=layers.30.feed_forward.w1.weight, orig_size=torch.Size([14336, 4096])), FSDPParam(fqn=layers.30.feed_forward.w2.weight, orig_size=torch.Size([4096, 14336])), FSDPParam(fqn=layers.30.feed_forward.w3.weight, orig_size=torch.Size([14336, 4096])), FSDPParam(fqn=layers.30.attention_norm.weight, orig_size=torch.Size([4096])), FSDPParam(fqn=layers.30.ffn_norm.weight, orig_size=torch.Size([4096]))]
fsdp params: [FSDPParam(fqn=layers.31.attention.wq.weight, orig_size=torch.Size([4096, 4096])), FSDPParam(fqn=layers.31.attention.wk.weight, orig_size=torch.Size([1024, 4096])), FSDPParam(fqn=layers.31.attention.wv.weight, orig_size=torch.Size([1024, 4096])), FSDPParam(fqn=layers.31.attention.wo.weight, orig_size=torch.Size([4096, 4096])), FSDPParam(fqn=layers.31.feed_forward.w1.weight, orig_size=torch.Size([14336, 4096])), FSDPParam(fqn=layers.31.feed_forward.w2.weight, orig_size=torch.Size([4096, 14336])), FSDPParam(fqn=layers.31.feed_forward.w3.weight, orig_size=torch.Size([14336, 4096])), FSDPParam(fqn=layers.31.attention_norm.weight, orig_size=torch.Size([4096])), FSDPParam(fqn=layers.31.ffn_norm.weight, orig_size=torch.Size([4096]))]
fsdp params: [FSDPParam(fqn=layers.30.attention.wq.weight, orig_size=torch.Size([4096, 4096])), FSDPParam(fqn=layers.30.attention.wk.weight, orig_size=torch.Size([1024, 4096])), FSDPParam(fqn=layers.30.attention.wv.weight, orig_size=torch.Size([1024, 4096])), FSDPParam(fqn=layers.30.attention.wo.weight, orig_size=torch.Size([4096, 4096])), FSDPParam(fqn=layers.30.feed_forward.w1.weight, orig_size=torch.Size([14336, 4096])), FSDPParam(fqn=layers.30.feed_forward.w2.weight, orig_size=torch.Size([4096, 14336])), FSDPParam(fqn=layers.30.feed_forward.w3.weight, orig_size=torch.Size([14336, 4096])), FSDPParam(fqn=layers.30.attention_norm.weight, orig_size=torch.Size([4096])), FSDPParam(fqn=layers.30.ffn_norm.weight, orig_size=torch.Size([4096]))]
fsdp params: [FSDPParam(fqn=layers.29.attention.wq.weight, orig_size=torch.Size([4096, 4096])), FSDPParam(fqn=layers.29.attention.wk.weight, orig_size=torch.Size([1024, 4096])), FSDPParam(fqn=layers.29.attention.wv.weight, orig_size=torch.Size([1024, 4096])), FSDPParam(fqn=layers.29.attention.wo.weight, orig_size=torch.Size([4096, 4096])), FSDPParam(fqn=layers.29.feed_forward.w1.weight, orig_size=torch.Size([14336, 4096])), FSDPParam(fqn=layers.29.feed_forward.w2.weight, orig_size=torch.Size([4096, 14336])), FSDPParam(fqn=layers.29.feed_forward.w3.weight, orig_size=torch.Size([14336, 4096])), FSDPParam(fqn=layers.29.attention_norm.weight, orig_size=torch.Size([4096])), FSDPParam(fqn=layers.29.ffn_norm.weight, orig_size=torch.Size([4096]))]
# ...
fsdp params: [FSDPParam(fqn=layers.1.attention.wq.weight, orig_size=torch.Size([4096, 4096])), FSDPParam(fqn=layers.1.attention.wk.weight, orig_size=torch.Size([1024, 4096])), FSDPParam(fqn=layers.1.attention.wv.weight, orig_size=torch.Size([1024, 4096])), FSDPParam(fqn=layers.1.attention.wo.weight, orig_size=torch.Size([4096, 4096])), FSDPParam(fqn=layers.1.feed_forward.w1.weight, orig_size=torch.Size([14336, 4096])), FSDPParam(fqn=layers.1.feed_forward.w2.weight, orig_size=torch.Size([4096, 14336])), FSDPParam(fqn=layers.1.feed_forward.w3.weight, orig_size=torch.Size([14336, 4096])), FSDPParam(fqn=layers.1.attention_norm.weight, orig_size=torch.Size([4096])), FSDPParam(fqn=layers.1.ffn_norm.weight, orig_size=torch.Size([4096]))]
fsdp params: [FSDPParam(fqn=layers.0.attention.wq.weight, orig_size=torch.Size([4096, 4096])), FSDPParam(fqn=layers.0.attention.wk.weight, orig_size=torch.Size([1024, 4096])), FSDPParam(fqn=layers.0.attention.wv.weight, orig_size=torch.Size([1024, 4096])), FSDPParam(fqn=layers.0.attention.wo.weight, orig_size=torch.Size([4096, 4096])), FSDPParam(fqn=layers.0.feed_forward.w1.weight, orig_size=torch.Size([14336, 4096])), FSDPParam(fqn=layers.0.feed_forward.w2.weight, orig_size=torch.Size([4096, 14336])), FSDPParam(fqn=layers.0.feed_forward.w3.weight, orig_size=torch.Size([14336, 4096])), FSDPParam(fqn=layers.0.attention_norm.weight, orig_size=torch.Size([4096])), FSDPParam(fqn=layers.0.ffn_norm.weight, orig_size=torch.Size([4096]))]
fsdp params: [FSDPParam(fqn=tok_embeddings.weight, orig_size=torch.Size([128256, 4096])), FSDPParam(fqn=norm.weight, orig_size=torch.Size([4096])), FSDPParam(fqn=output.weight, orig_size=torch.Size([128256, 4096]))]

Sharding 策略与 ZeRO 映射

在 FSDP1 中,通过 sharding_strategy 来实现不同的数据并行策略,包括 ZeRO-3/ZeRO-2 和 HSDP 等,在 FSDP2,通过 reshard_after_forwardmesh 来实现对不同数据并行策略的控制:

FSDP 1 FSDP 2 DeepSpeed
1 process_group + FULL_SHARD 1D mesh + reshard_after_forward=True ZeRO-3
1 process_group + SHARD_GRAD_OP 1D mesh + reshard_after_forward=False ZeRO-2
2 process_group s/2D device_mesh + HYBRID_SHARD 2D mesh + reshard_after_forward=True MiCS
2 process_group s/2D device_mesh + _HYBRID_SHARD_ZERO2 2D mesh + reshard_after_forward=False -
- 1D/2D mesh + reshard_after_forward=8 (int) ZeRO++ hpZ8

更原生兼容 torch API

Existing FSDP supports three kinds of state dicts: full, sharded, and local. We define a “clean” fully-qualified name (FQN) to be one without any nn.Module wrapper prefixes (e.g. _fsdp_wrapped_module. for FSDP or module. for DDP).

  • A full state dict maps clean FQN to unsharded torch.Tensor. Since this may use too much GPU memory, it offers rank0_only: bool and offload_to_cpu: bool options. This requires all-gathering parameters.
  • A sharded state dict maps clean FQN to sharded DTensor. For existing FSDP, this requires all-gathering parameters and resharding to per-parameter sharding.
  • A local state dict maps clean FQN to the sharded FlatParameter as a ShardedTensor. This does not require communication but cannot be loaded into a different distributed training setup.

For per-parameter-sharding FSDP, we only support sharded state dict.

  • Getting a full state dict can be implemented as a post-processing step to saving the sharded state dict.
  • Sharded and local state dicts are equivalent since the training and checkpointing representation match
FSDP1 FSDP2
model.state_dict(): full state dict model.state_dict(): sharded state dict (no communication)
optim.state_dict(): local state dict optim.state_dict(): sharded state dict (no communication)
summon_full_params() use DTensor APIs like full_tensor()
FSDP.clip_grad_norm_() nn.utils.clip_grad_norm_()
ShardedGradScaler amp.grad_scaler.GradScaler

Meta-Device Initialization

在 FSDP1 的时候,meta device 上初始化方法如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
with torch.device("meta"):
    model = Transformer()
policy = ModuleWrapPolicy({TransformerBlock})

# Call `reset_parameters()` on every module
model = FSDP(model, auto_wrap_policy=policy)

# Call `param_init_fn` on every module
def param_init_fn(module: nn.Module) -> None: ...
model = FSDP(model, auto_wrap_policy=policy, param_init_fn=param_init_fn)

FSDP1 需要

FSDP1 requires either reset_parameters or param_init_fn to materialize a module onto GPU immediately before sharding. To do this correctly without re-initializing any tensors requires care and can be unwieldy. However, FSDP2 allows materializing tensors onto GPU after sharding (taking advantage of DTensor and a new swap_tensors path for nn.Module._apply methods).

FSDP2 supports a new meta-device initialization flow that does not require materializing a module on GPU before sharding it, removing the need for param_init_fn

对于 FSDP2

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
with torch.device("meta"):
    model = Transformer()
for module in model.modules():
    if isinstance(module, TransformerBlock):
        fully_shard(module)
fully_shard(model)

for tensor in itertools.chain(model.parameters(), model.buffers()):
    assert tensor.device == torch.device("meta")

# Allocate buffers and sharded parameters on GPU
model.to_empty(device="cuda")
# Run user-defined initializers
model.init_weights() # or `model.apply(init_weights)`

FSDP2 计算与通信流程

和 FSDP1 一样,FSDP2 在模型计算时需要实现:

  1. FWD/BWD 之前,需要通过 all-gather 来 unshard 参数,并且在计算之后 reshardc释放掉
  2. 在 BWD 计算出 gradient 之后,需要通过 reduce-scatter 来平均梯度

具体这一目标的是通过 Module 对应的 hook 以及具体 FSDPParamGroupFSDPState 操作来实现:

  1. Root pre-forward:通过 _fsdp_state.py 中的 register_forward_pre_hook 实现
    1. 完成 lazy initialization, stream synchronization, move inputs to GPU 等操作
  2. 对每一个 state/param group:
    1. State pre-forward (register_forward_pre_hook()):cast forward inputs if needed
    2. Parameter group pre-forward: 执行 unshard 参数
    3. Parameter group post-forward: 执行 reshard 参数
    4. State post-forward (register_forward_hook()): cast forward outputs if need
  3. Root post-forward
  4. Root pre-backward (register_hook): queue back
  5. 对每一个 state/param group:
    1. Pre-backward (register_hook):unshard 参数,prefetch parameters if needed
    2. Post-backward (autograd.Function.backward):reshard 参数,reduce-scatter 梯度
  6. Root post-backward final callback:_execution_engine.queue_callback()
    1. 完成 finalize parameter group backwards,stream/event 同步,数据结构重置等操作

在简单看了上面的流程之后,和之前 FSDP1 中讨论的问题一样,在 FSDPParam 的设计下,我们也需要考虑内存空间管理的问题:

  1. FSDP2 基于 DTensor 实现的 Per-Parameter Sharding 是如何实现的,在 forward 中 all-gather 的参数重建是如何实现的?
  2. 在 Backward 的时候,做完 Unshard 之后,FSDP2 的 Backward 是如何和原有的 AutoGrad 适配起来,从而计算出 Grads,这些 Grads 的空间又是如何管理的,怎么做 Grads 同步呢?

对于这几个问题,我们继续深入到代码,进一步理解 FSDP2。

FSDPParam 实现

和 FSDP1 一样,为了更好的理解计算和通信流的同步,需要先了解下有哪些 stream,对比之前的 FSDP1,这里有一个特殊的 all_gather_copy_in_stream

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class FSDPCommContext:

    def lazy_init(self, device: torch.device):  
        self.device_handle = _get_device_handle(device.type)  
        # Setting the all-gather/reduce-scatter streams to be higher priority  
        # can help avoid some issues where their copies in/out are delayed and        
        # block computation (this is different from high-pri NCCL streams)        
        high_priority = -1  
        # All-gather state and copy-in stream allow overlapping the next 
        # copy-in with the current all-gather in forward; copy-in overlaps with        
        # reduce-scatter in backward without the separate copy-in stream 
        self.all_gather_copy_in_stream = self.device_handle.Stream(  
            priority=high_priority  
        )  
        # All-gather stream allows overlapping next all-gather with current forward compute        
        self.all_gather_stream = self.device_handle.Stream(priority=high_priority)  
        # Reduce-scatter stream gives separate execution "thread" for post-  
        # backward logic like pre/post-gradient division and reduce-scatter        
        self.reduce_scatter_stream = self.device_handle.Stream(priority=high_priority)  
        # Run the HSDP all-reduces concurrently with all-gather/reduce-scatter  
        # since collectives use different network resources and can overlap        
        # in the typical intra-node sharding / inter-node replication case        
        self.all_reduce_stream = self.device_handle.Stream()  
        # All-gather/reduce-scatter states keep references to collective  
        # tensors produced in one stream and used in another and accompanying        
        # CUDA events for synchronization        
        self.all_gather_state: Optional[AllGatherState] = None  
        self.reduce_scatter_state: Optional[ReduceScatterState] = None  
        # Post-forward order for explicit backward prefetching  
        self.post_forward_order: List[FSDPParamGroup] = []  # will cause ref cycles

首先我们通过前面的 LLaMA3-8B 为例,看 FSDPParam 的定义,其中有几个关键概念:

  • Original Parameter:就是传递给 FSDP full_shard 之前的参数,比如 LLaMA3-8B 中一个 TransformerBlock 的参数 wq/wk/wv/wo/w1/w2/w3/attn_norm/ffn_norm 这些
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
    (0): FSDPTransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=4096, out_features=4096, bias=False)
        (wk): Linear(in_features=4096, out_features=1024, bias=False)
        (wv): Linear(in_features=4096, out_features=1024, bias=False)
        (wo): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=4096, out_features=14336, bias=False)
        (w2): Linear(in_features=14336, out_features=4096, bias=False)
        (w3): Linear(in_features=4096, out_features=14336, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  • Sharded Parameter:将原来的参数转换为 DTensor,并且在 dim-0 做 shard (Shard(0) Placement),对应于 FSDPParam.sharded_param
1
2
3
4
5
6
7
8
param: DTensor(local_tensor=tensor([[-0.0063, -0.0038, -0.0414, ..., -0.0324, -0.0264, -0.0044],
[ 0.0356, -0.0099, -0.0039, ..., 0.0284, 0.0378, -0.0176],
[ 0.0352, -0.0240, 0.0145, ..., 0.0285, 0.0062, 0.0391],
...,
[ 0.0022, -0.0216, -0.0166, ..., -0.0003, -0.0007, -0.0065],
[-0.0556, -0.0252, 0.0283, ..., 0.0152, -0.0019, -0.0150],
[-0.0151, 0.0231, -0.0011, ..., 0.0109, -0.0019, -0.0229]],
device='cuda:0'), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3, 4, 5, 6, 7], mesh_dim_names=('dp_shard_cp',)), placements=(Shard(dim=0),))

这里的将 original parameter 转换为 sharded parameter 的逻辑,是现在 FSDPParam_init_sharded_param() 9 中,经过初始化,每个 FSDPParam 中的 _sharded_param_sharded_local_tensor 即对应着 per-parameter sharding 的 DTensor。

1
2
def _sharded_local_tensor(self) -> torch.Tensor:  
    return cast(DTensor, self.sharded_param)._local_tensor
  • All-gather input:
    • 对应于传递给 all-gather 的 input 的 torch.Tensor
  • All-gather output:
    • 对应于 all-gather output 产生的 torch.Tensor

实际的操作中,每次执行 all_gather 通信前,会调用 all_gather_copy_in 10 来为 all-gather output 申请一段连续的 tensor 空间,对应于下面代码里的 all_gather_input_numel * world_size

  • all_gather_input_numel 是一个 FSDPPramGroup 里面的所有 sharded_param 参数对应的 numel 总和
  • 输出的 all_gather_inputall_gather_output 的一个 view 视图,并且是将已有的本地 tensor copy-in 到对应的 offset
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
@torch.library.impl(lib, "all_gather_copy_in", "CUDA")  
@torch.library.impl(lib, "all_gather_copy_in", "CPU")  
def all_gather_copy_in_cuda(  
    all_gather_inputs: List[torch.Tensor],  
    inp_split_sizes: List[int],  
    all_gather_input_numel: int,  
    world_size: int,  
    rank: int,  
    dtype: torch.dtype,  
    device: torch.device,  
) -> Tuple[torch.Tensor, torch.Tensor]:  
    all_gather_output = torch.empty(  
        (all_gather_input_numel * world_size,), dtype=dtype, device=device  
    )  
    all_gather_input = all_gather_output.narrow(  
        0, all_gather_input_numel * rank, all_gather_input_numel  
    )  
    foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)  
    with torch.no_grad():  
        torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs)  
    return all_gather_input, all_gather_output
  • Unsharded parameter:
    • 对应于 FSDPParam._unsharded_param,是将 all-gather output 经过转换得到的

总结一下:

1
2
all-gather-inputs = pre-all-gather-transform(sharded-parameter)  
unsharded-parameter = post-all-gather-transform(all-gather-outputs)

再看一眼 FSDPParam 的定义:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class FSDPParam:

    orig_dtype: torch.dtype  
    param_dtype: Optional[torch.dtype]  
    reduce_dtype: Optional[torch.dtype]  
    _orig_size: torch.Size  # ND  
    sharded_size: torch.Size  # ND  
    contiguous_sharded_stride: Tuple[int, ...]  
    padded_sharded_param_size: torch.Size  # ND  
    sharded_post_forward_size: torch.Size  # ND  
    contiguous_sharded_post_forward_stride: Tuple[int, ...]  
    _sharded_param_data: torch.Tensor  # 1D  
    sharded_param: nn.Parameter  # ND  
    _sharded_post_forward_param_data: Optional[torch.Tensor]  # 1D  
    _sharded_post_forward_param: Optional[nn.Parameter]  # ND  
    _unsharded_param: nn.Parameter  # ND  
    unsharded_accumulated_grad: Optional[torch.Tensor]  # ND  
    _sharding_spec: DTensorSpec  
    # DTensor attributes (only defined for DTensor `param`):  
    _tp_spec: DTensorSpec  
    all_gather_outputs: List[torch.Tensor]  # 1D  
    # All-gather extension attributes    _extensions_data: ExtensionsData  
    _unsharded_inner_tensors: List[torch.Tensor]

FSDP2 集合通信设计

前面提到,FSDP 通信效率优化中很重要的一点是 batched collective communication

对于 FSDP1,是通过 FlatParameter 的设计,将原来的多个 parameter 整合起来,摊平到一维,由 FlatParameter 管理空间和通信,然后在 FWD/BWD 计算的时候通过 view 视图来对各自参数进行计算。

对于 FSDP2 的 per-parameter sharding,可以有两种选择:

  1. copy-in/copy-out
  2. NCCL group coalescing 11

目前 (torch 2.6.0 版本) 发现 copy-in/copy-out 相比于 NCCL group coalescing 的通信效率更高一些。而且尽管有 copy,它的实现和当前 FSDP1 的集合通信效率相同,因此目前的实现是前者。

all-gather 为例,下图中左图即为 FlatParameter,模型拥有 3 个参数,FlatParameter 将其摊平到一维,每个 rank 拥有 FlatParameter 的 shard。在 FWD/BWD 通过 all-gather 恢复全部 FlatParameter 参数,通过 viewunflatten 重新计算。

FSDP2 则抛弃了 FlatParameter 的设计,如上图中右图所示,每个 rank 仍然是拥有 3 个参数,只是这 3 个参数已经是 DTensor,并且在 FSDP 的维度做了 sharding。在 FWD/BWD 计算的时候,同样经过 all-gather,因为知道每个 parameter 的 shape 和 size,因此可以通过 split with sizes copy 直接重建参数,稍后会介绍其实现。

对于 backward 之后的 reduce-scatter 平均梯度也是类似。FSDP1 中计算出来的 grad 在 FlatParameter 的一维空间中,经过 Flat-parameter sharding 每个 rank 平均获得其中的一部分。对于 FSDP2,计算完梯度之后,每个 rank 的每个 Parameter 拥有自己的 grad,通过 chunk-cat 后,展平到 2 rank sharding 的空间,然后通过 reduce-scatter 平均梯度后,每个 rank 的每个参数的 grad 都是平均和 sharding 后的 grad。

这里理解一下为什么需要 copy-in/copy-out

All-Gather

回顾下 all-gather 的前后操作:

1
2
all-gather-inputs = pre-all-gather-transform(sharded-parameter)  
unsharded-parameter = post-all-gather-transform(all-gather-outputs)

all-gather 的具体流程为:

  1. all-gather output 申请一段连续的 tensor 空间
  2. all-gather input 设置成 all-gather output 的 view 视图
  3. 将每个 sharded_param 的空间按照合适的 offset 来 copy-in 到每个 rank 的 all-gather input
    1. 这里 all-gather input 就会有一个 interleaved layout,如上图所示,(p1 shard1, p2 shard1, p3 shard1, p1 shard2, p2 shard2, p3 shard2)
  4. 执行 all-gather 通信
  5. 将 tensor 从 all-gather output 中 copy-out 到 unsharded_param 的空间,并且 make them contiguous
    1. 这里通过 viewing 每个参数就会变成 (p1 shard1, p1 shard2), …, (p3 shard1, p3 shard2),然后就可以通过 torch.cat 聚集起来

这里的 copy-in 和 all-gather 即对应于前面的代码 for_all_gather,而 copy-out 即对应于 foreach_all_gather_copy_out 调用 split_with_sizes_copy 将结果从前面申请的连续空间 all-gather output copy-out 到每个 FSDPParam 的 all_gather_outputs

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
def init_all_gather_outputs(  
    self,  
    all_gather_input_numels: List[int],  
    all_gather_input_dtypes: List[torch.dtype],  
    world_size: int,  
    device: torch.device,  
    force_recreate: bool = False,  
):  
    if not force_recreate and len(self.all_gather_outputs) > 0:  
        return  # already initialized  
    self.all_gather_outputs = [  
        torch.empty(torch.Size([numel * world_size]), dtype=dtype, device=device)  
        for numel, dtype in zip(all_gather_input_numels, all_gather_input_dtypes)  
    ]

实际执行完 copy-out 之后,即会初始化 _unsharded_param 并且将当前 Module 的参数设置为 unsharded 之后的参数

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
def wait_for_unshard(self):     
    # ...
    with record_function(self._with_fqn("FSDP::all_gather_copy_out")):  
        foreach_all_gather_copy_out(  
            self._all_gather_result,  
            self.fsdp_params,  
            self._all_gather_process_group,  
        )  
    for fsdp_param in self.fsdp_params:  
        fsdp_param.init_unsharded_param()  
    self._to_unsharded()  
    all_gather_copy_out_event = self.device_handle.Event()
    # ...

对应的 init_unsharded_param,即是前面的 unsharded-parameter = post-all-gather-transform(all-gather-outputs) 的变换:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
def init_unsharded_param(self):
    inner_tensor = self._sharded_local_tensor
    unsharded_tensor = self.all_gather_outputs[0]
    unsharded_param = torch.as_strided(  
        unsharded_tensor,  
        self._orig_size,  
        self._contiguous_orig_stride,  
        storage_offset=0,  
    )
    # ...
    self._unsharded_param = nn.Parameter(  
        unsharded_param, requires_grad=self.sharded_param.requires_grad  
    )

def to_unsharded(self) -> None:  
    # Assume that the data has been allocated and all-gathered  
    set_requires_grad_if_needed(self.sharded_param, self._unsharded_param)  
    self._setattr_on_modules(self._unsharded_param)  
    # ...
    self.sharded_state = ShardedState.UNSHARDED

ReduceScatter

reduce-scatter 的流程为:

  1. reduce-scatter input 申请一段连续的 tensor 空间
  2. 将每个 unsharded_param.grad 按照合适的 offset 来 copy-in 到 reduce-scatter input
    1. 类似的,这里 reduce-scatter input 就会有一个 interleaved layout,如上图所示,(g1 shard1, g2 shard1, g3 shard1, g1 shard2, g2 shard2, g3 shard2)
    2. 这一步通过 torch._foreach_copy_ 来实现
  3. reduce-scatter output 申请一段单独的 tensor
  4. 执行 reduce-scatter
  5. View each new sharded parameter gradient from the reduce-scatter output and accumulate with sharded_param.grad as needed

FSDP2 避免 recordStream

和 FSDP1 一样,FSDP2 也申请了多个 stream 分别用于计算/通信/copy 等操作。当一个 tensor 在一个 stream 中 allocate 并且在另一个 stream 中被使用,Pytorch 需要在多个 stream 之间进行同步,FSDP1 中用的是 recordStream 12

recordStream 将多 stream 间同步的工作交给 CUDA Caching Allocator,只有当 tensor 不会被使用的时候,tensor 对应的空间才会被释放和重用(但是这一般距离在 CPU free 操作很晚)。

FSDP & CUDACachingAllocator13 这篇文章详细介绍了 FSDP 设计中如何通过 record streamrate limiter 来实现多个 stream 的同步。

没有 rate limiter 的情况下,record_stream 引入 nondeterminism
没有 rate limiter 的情况下,record_stream 引入 nondeterminism

对于上图中,因为 record_stream 的引入,layer i 的空间什么时候释放取决于 CCA 实际的内存管理,因此 allgather i+2 的操作和 layer i 内存空间的释放之间的顺序是 non-deterministic 的。假设 allgather i+2 的时候 layer i 的空间还没有被释放,这就会导致 active memory usage 的显著增加,如下图所示。等 layer i 空间释放之后,memory usage 的 spike 就会下降。

实际上,为了解决这个问题(以及过度申请 memory),FSDP1 中引入了 rate limiter 来限制 all-gather 的多次 prefetch。如下图所示,实际上只有当 layer i 结束计算,在 _reshard 之后,才可以发起 allgather i+2 的。因此这个时候可以保证 allgather i+2 能够复用 layer i 释放的 memory。

通过 rate limiter,FSDP1 保证了 forward 过程中不会出现之前提到的 memory spikes,如下图所示。然而,backward 的显存 spikes 依然存在。

一个可能的原因是,rate limiter 主要是面向 FSDP 集合通信设计的,在反向计算中,autograd engine 可能会在 layer i 没被释放前执行一些 malloc 操作,进而导致 memory spikes。

为了解决 FSDP1 中因为 recordStream 带来的内存管理的 non-determinism,FSDP2 选择使用了 explicit synchronization 来手动管理多个 stream 之间的依赖与同步。

record_stream introduces nondeterminism, and the nondeterminism puts the onus on the CPU to deal with it. That’s where the CPU sync comes in–the CPU is verifying the del situation before allocating new memory…

these two concepts are decoupled, and the CPU could just not “care”! By replacing the record_stream calls with stream-stream syncs, the CPU can trust that the streams will wait on each other properly, so the CPU can be carefree and schedule whatever it wants without needing to sync. In a sense, the responsibility has shifted from the CPU to the streams.

In the specific case for FSDP, removing the need for a CPU sync would require addressing the nondeterminism introduced by record_stream, and the most straightforward way to do that is to remove/replace record_stream calls.

delay the del i to be after layer i+1’s compute has been scheduled
delay the del i to be after layer i+1’s compute has been scheduled

1
2
3
4
5
default_stream            # default compute stream
all_gather_copy_in_stream # overlap copy-in with forward compute and reduce-scatter
all_gather_stream         # overlap all-gather with forward/backward computation
reduce_scatter_stream     # overlap with backward computation; convenient to queue additional work like division
all_reduce_stream         # (for HSDP) overlap with all-gather/reduce-scatter/backward computation

implicit prefetch
implicit prefetch

explicit prefetch
explicit prefetch

ref

DTensor

https://ml.byteintl.net/development/profiling/preview/1fac3c49513c2293a4305d7373e052ac8a3ba243029350e94cb2b5c52db6cc51