TorchTitan 1 中将 FSDP2 (而不是 FSDP123)作为其默认的 1D 数据并行策略45。对应的,去年 7 月的时候,在 Torch 2.4 Prototype 中也正式引入了 FSDP26。尽管在大模型训练中 Megatron 和 DeepSpeed 等框架得到了广泛的应用,作为 PyTorch 官方原生的并行方案,FSDP 的演进值得更多的关注。本文尝试去梳理 FSDP1 到 FSDP2 的方案演进,尝试去理解其进一步的发展方向。

The original Fully Sharded Data Parallel (FSDP)7 is an effective implementation of ZeRO8 that offers large model training capability in PyTorch. However, the original implementation (FSDP1) in PyTorch suffers from various limitations due to its FlatParameter implementation.

Given these limitations, TorchTitan integrates a new version of Fully Sharded Data Parallel (FSDP2), which uses the per-parameter Distributed Tensor sharding representation and thus provides better composability with model parallelism techniques and other features that require the manipulation of individual parameters…

FSDP2 + Tensor Parallel
FSDP2 + Tensor Parallel

受限于篇幅,本文将主要分析 FSDP1,文中所有代码基于 PyTorch 2.6.0,主要参考了 FSDP 论文7 和相关博客91011 与 GitHub 讨论12,关于 FSDP2 后续的演进将在后续文章进一步分析。

FSDP1 的基本使用与原理

当计算撞上通信墙:RDMA 再回首,EtherNET or EtherNOT 这篇博客中,已经基本解释了 FSDP1 的基本原理,其基本对应着 ZeRO-3 的工程实现,如下所示。

FSDP workflow
FSDP workflow

具体的:

  1. 通过 FSDP wrapper,让每个 rank 初始化模型后只保存 部分模型的参数
  2. Forward 前,通过 all-gather 获取完整的参数,完成 forward 计算后 free 掉非本 rank 的参数
  3. Backward 前,通过 all-gather 获取完整的参数,完成 backward 计算后 free 掉非本 rank 的参数,算出梯度 grad,然后通过 reduce-scatter 平均
  4. Optimizer step 中,每个 rank 根据自己的 weights 和 grad 更新本地权重
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
FSDP forward pass:
    for fsdp_instance_i in fsdp_instances:
        all-gather full weights for fsdp_instance_i
        forward pass for fsdp_instance_i
        discard full weights for fsdp_instance_i

FSDP backward pass:
    for fsdp_instance_i in fsdp_instances:
        all-gather full weights for fsdp_instance_i
        backward pass for fsdp_instance_i
        discard full weights for fsdp_instance_i
        reduce-scatter gradients for fsdp_instance_i

通过 ZeRO-3 的实现可以大幅减少 GPU 显存的占用:

下图展示了 FSDP 的主要用法13,FSDP 作为 nn.Module 的 wrapper,实现自动高效地对模型参数做 shard/unshard,并提供了包括 Mixed Precision、Wrap Policy、Activation Checkpointing 等策略进一步优化训练效率。

典型代码如下所示14

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# command to run the script:
# torchrun  --standalone --nnodes=1 --nproc-per-node=4 fsdp_demo.py
import torch
import functools
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

import torch.distributed as dist
dist.init_process_group("nccl")
local_rank = dist.get_rank()

from transformers.models.gpt2.modeling_gpt2 import GPT2Block
gpt2_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            GPT2Block,
        },
    )

# init model ...
from transformers import GPT2LMHeadModel
model = GPT2LMHeadModel.from_pretrained("gpt2")
if local_rank == 0:
    print(model)

bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
)

# wrap model
model = FSDP(model,
        auto_wrap_policy=gpt2_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.FULL_SHARD)


if local_rank == 0:
    print(model)

在本地通过 torchrun 命令执行这段代码,可以看到,被 FSDP wrapper 封装之前,模型为经典的 GPT2 模型:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

总共 12 层 Transformer 的 GPT 模型
总共 12 层 Transformer 的 GPT 模型

经过 FSDP wrapper 封装之后,每一个 GPT2BlockFullyShardedDataParallel 封装成 _fsdp_wrapped_module,最外层还有一个 root fsdp module

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
FullyShardedDataParallel(
  (_fsdp_wrapped_module): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x FullyShardedDataParallel(
          (_fsdp_wrapped_module): GPT2Block(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): GPT2SdpaAttention(
              (c_attn): Conv1D()
              (c_proj): Conv1D()
              (attn_dropout): Dropout(p=0.1, inplace=False)
              (resid_dropout): Dropout(p=0.1, inplace=False)
            )
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): GPT2MLP(
              (c_fc): Conv1D()
              (c_proj): Conv1D()
              (act): NewGELUActivation()
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  )
)

这里的每一个 FSDP wrapped module 就是之前提到的 FSDP Instance。

FSDP1 Wrap 与 FlatParameter 抽象

如前所述,面对一个 GPU 上都放不下的大模型,FSDP 最主要的设计是将模型拆成多个 Shard,然后通过 all-gather 在每一个 Shard 真正要计算的时候再将需要的模型参数等聚集起来 Unshard。

具体的,这里的 Shard 到底是如何操作的呢?考虑到:

  • 每个模型可能有多个 layer
  • 每个 layer 都有自己对应的 (Model Parameter, Gradients, Optimizer State)

如下图所示15,显示了一个有 9 个 layer 的模型,被拆分成了 3 个 FSDP Unit,同时也被 shard 给了 3 个 worker (对应于 subset1/2/3)。

请注意,这里的拆分实际上是有 2 个维度:

  • 第一个纬度是将这个 9 层的模型拆成了 3 个 Unit。这里在实际的模型中分为几个 FSDP Unit 是由具体的 Wrap Policy 确定的,比如 transformer_auto_wrap_policy 或者 size_based_auto_wrap_policy ,本质上是为了通过合适的拆分结构实现更高效的通信效率,后续会进一步阐释。
  • 第二个维度是对于每个 Unit,会在多个 worker 之间平分参数权重。比如这里的 Unit 1 包含 3 层,每个 worker 获取了其中的部分权重。这里不一定是每个 worker 都必须是拿一层,实际可以做到更加平均,具体看后面的 FlatParameter。

对于 Unit 1 会在 FWD 前通过 all-gather 让每个 worker 都有 Unit 1 的完整参数,这里是一个 3 层的 layer
对于 Unit 1 会在 FWD 前通过 all-gather 让每个 worker 都有 Unit 1 的完整参数,这里是一个 3 层的 layer

然后对于每个 worker 各自执行自己的 FWD pass,生成对应的 activation
然后对于每个 worker 各自执行自己的 FWD pass,生成对应的 activation

FWD 算完之后,对于 Unit 1 每个 worker 清理掉不需要的权重
FWD 算完之后,对于 Unit 1 每个 worker 清理掉不需要的权重

对于其他的 FSDP Unit 的 FWD 和后续的 BWD,也都是类似,细节图示可以参考15,此处不再赘述。重要的,我们进一步理解对于一个模型,是如何 Wrap 起来,并 Shard 为合适的 FSDP Unit

以这个图为例,模型被分为 3 个 FSDP Unit,每个 worker 都有 3 个 FSDP Unit,但是每个 worker 只有每个 FSDP Unit 的一部分权重
以这个图为例,模型被分为 3 个 FSDP Unit,每个 worker 都有 3 个 FSDP Unit,但是每个 worker 只有每个 FSDP Unit 的一部分权重

NCCL 集合通信效率约束

对于 FSDP 的每个 worker,每次 FWD/BWD 前都需要通过通过 all-gatherunshard 所有参数,通过 reduce-scatter 来平均梯度,通信在其中占据了很重要的部分。对于集合通信,有两个因素十分影响通信的效率:

  1. Even Input Size
  2. Larger Input Size

Even Input Size

对于 NCCL 集合通信,如果不同 rank 的 input tensor 的 size 都是一致的,效率会高得多,如下图左图所示,在 PyTorch 中对应着 all_gather_into_tensor 的 API16,也就是图中的 All-Gather Base。相比之下,all_gather 的 API17 则不要求 input tensor 的 size 保持一致,在输入 size 不一致的时候则性能差很多。

Larger Input Size

对于 NCCL 集合通信,随着每次发起 All-Gather 时通信量 Message Size 的增加,实际传输带宽会显著提高,如下图右图所示。因此,如果能够将 FSDP 实际的通信做 batch,从而减少发起集合通信的次数,提高每次通信的带宽,则可以大幅提高通信效率。

基于这个思考,提出两个设计约束:

  • Constraint 1: FSDP should communicate even sizes across ranks to use all_gather_into_tensor / reduce_scatter_tensor .
  • Constraint 2: FSDP should batch parameters for all-gather and gradients for reduce-scatter.

这两个约束则对应着 FSDP 中 FlatParameter 的抽象。因为:

  1. 神经网络的每一层参数不可能一致保持被卡数整除,额外的 Padding 必不可少
  2. 神经网络有很多层和各种不同的 shape 的参数,很难逐个去做集合通信,合适的通信 Size 才能取得最佳的 NCCL 通信效率

这里进一步明确下 batch parameters,以 GPT 2 为例,以下为 GPT 2 模型所有的参数,FSDP 期望在设计的时候,不需要每一个 parameter 都要进行一次 all-gather/reduce-scatter 的 NCCL 通信,而是把多个 parameter 的空间放在一起,统一起来进行通信,这即是引入了 FlatParameter

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32

Name: transformer.wte.weight, Shape: torch.Size([50257, 768])
Name: transformer.wpe.weight, Shape: torch.Size([1024, 768])
Name: transformer.h.0.ln_1.weight, Shape: torch.Size([768])
Name: transformer.h.0.ln_1.bias, Shape: torch.Size([768])
Name: transformer.h.0.attn.c_attn.weight, Shape: torch.Size([768, 2304])
Name: transformer.h.0.attn.c_attn.bias, Shape: torch.Size([2304])
Name: transformer.h.0.attn.c_proj.weight, Shape: torch.Size([768, 768])
Name: transformer.h.0.attn.c_proj.bias, Shape: torch.Size([768])
Name: transformer.h.0.ln_2.weight, Shape: torch.Size([768])
Name: transformer.h.0.ln_2.bias, Shape: torch.Size([768])
Name: transformer.h.0.mlp.c_fc.weight, Shape: torch.Size([768, 3072])
Name: transformer.h.0.mlp.c_fc.bias, Shape: torch.Size([3072])
Name: transformer.h.0.mlp.c_proj.weight, Shape: torch.Size([3072, 768])
Name: transformer.h.0.mlp.c_proj.bias, Shape: torch.Size([768])

# omit middle layers...

Name: transformer.h.11.ln_1.weight, Shape: torch.Size([768])
Name: transformer.h.11.ln_1.bias, Shape: torch.Size([768])
Name: transformer.h.11.attn.c_attn.weight, Shape: torch.Size([768, 2304])
Name: transformer.h.11.attn.c_attn.bias, Shape: torch.Size([2304])
Name: transformer.h.11.attn.c_proj.weight, Shape: torch.Size([768, 768])
Name: transformer.h.11.attn.c_proj.bias, Shape: torch.Size([768])
Name: transformer.h.11.ln_2.weight, Shape: torch.Size([768])
Name: transformer.h.11.ln_2.bias, Shape: torch.Size([768])
Name: transformer.h.11.mlp.c_fc.weight, Shape: torch.Size([768, 3072])
Name: transformer.h.11.mlp.c_fc.bias, Shape: torch.Size([3072])
Name: transformer.h.11.mlp.c_proj.weight, Shape: torch.Size([3072, 768])
Name: transformer.h.11.mlp.c_proj.bias, Shape: torch.Size([768])
Name: transformer.ln_f.weight, Shape: torch.Size([768])
Name: transformer.ln_f.bias, Shape: torch.Size([768])

FlatParameter 抽象

FlatParameter 通过将同一个 FSDP Unit 的所有 parameters 统一到一个 1D 的 tensor,它包含了 nflatten original parameters ,并且还通过 right-padding 保证了对齐。

同样以 GPT 2 为例,之前的每个 GPT2Block 被封装成了一个 FSDP Unit,也就是说这个时候的一个 FlatParameter 包含一个 GPT2Block 的所有参数,如下所示。这个时候可以计算 GPT 2 这个用例下 FlatParameter 的参数量为 7,087,872。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
Name: transformer.h.0.ln_1.weight, Shape: torch.Size([768])
Name: transformer.h.0.ln_1.bias, Shape: torch.Size([768])
Name: transformer.h.0.attn.c_attn.weight, Shape: torch.Size([768, 2304])
Name: transformer.h.0.attn.c_attn.bias, Shape: torch.Size([2304])
Name: transformer.h.0.attn.c_proj.weight, Shape: torch.Size([768, 768])
Name: transformer.h.0.attn.c_proj.bias, Shape: torch.Size([768])
Name: transformer.h.0.ln_2.weight, Shape: torch.Size([768])
Name: transformer.h.0.ln_2.bias, Shape: torch.Size([768])
Name: transformer.h.0.mlp.c_fc.weight, Shape: torch.Size([768, 3072])
Name: transformer.h.0.mlp.c_fc.bias, Shape: torch.Size([3072])
Name: transformer.h.0.mlp.c_proj.weight, Shape: torch.Size([3072, 768])
Name: transformer.h.0.mlp.c_proj.bias, Shape: torch.Size([768])

FlatParameter 是 FSDP 通信的最小单位,每个的 FSDP Unit 对应着一个 FlatParameter,并且它保留着 original parameters 的底层 storage。

下图对应着一个 FSDP Unit 将 4 x 3 的 nn.Linear 层在 16 个 GPU 间 shard,其中共有 12 个 weight element和 3 个 bias element,并补充了一个 padding,每个 worker 上的 local shard 占据的空间为 1 个 element。

FSDP Full Sharding
FSDP Full Sharding

高效构建 FlatParameter 与 AutoWrapPolicy

如前所述,FSDP1 只需要将原来的 model 通过 FullyShardedDataParallel 类 wrap 下,根据传入的 auto_wrap_policy 就可以自动将原有的模型拆分成多个 FSDP Unit,也就是对应的多个 FlatParameter

具体实现如下所示,这里的 _auto_wrap 会根据传入的 root module 后续遍历对每个 submodule 去拆分,碰到 auto_wrap_policy 返回 True 则意味着当前 submodule 满足成为一个 FSDP Unit 的条件,即会返回对应封装好的 FSDP Module。每个 FSDP Module 都会继续往下走,去 _init_param_handle_from_module 初始化自己对应的 FlatParamHandle 和将对应的 params flatten 成 1D 的 tensor。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class FullyShardedDataParallel(nn.Module, _FSDPState):
    def __init__(  
        self,  
        module: nn.Module,
        # ...
    ):
        # ...
        # init process group for collective communication
        _init_process_group_state(  
            self,  
            process_group,  
            sharding_strategy,  
            auto_wrap_policy,  
            device_mesh,  
        )
        # wrap 初始化
        if auto_wrap_policy is not None:
            # ...
            # 对所有的 submodule 递归的做 wrap
            _auto_wrap(
                module,
                auto_wrap_policy,
                self._ignored_modules,
                self._ignored_params,
                root_kwargs,
                FullyShardedDataParallel,
            )
        # ...
        # 初始化 FlatParamHandle, 对 parameters 做 flatten
        _init_param_handle_from_module(...)

这里对前面说的 post-order traversal 进一步阐释下,下面是一个典型的简单的模型,它继承自 nn.Module,对应的数据结构里面封装了很多其他的 nn.Module,是一个典型的 nested 的结构,对应于树这种数据结构。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight


    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb) # input for dropout of shape (b, t, n_embd)
        for block in self.transformer.h:
            x = block(x)
            x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
	# ...

对上面的论述再总结下:

  1. FullyShardedDataParallel 是一个 module wrapper,它通过模型初始化时能访问到的静态模型结构,来指引如何构建 FlatParameter
  2. 训练模型的作者需要根据自己模型的情况,将原来模型的参数 params 根据 desired locality 分组到同一个 module 或者 module subtrees

Rule 1: If the user wraps fsdp_module = FullyShardedDataParallel(module) *, then every parameter in module not already flattened is flattened into a single new FlatParameter and assigned to fsdp_module .

下图是一个简单示例,module 0 为 root module,通过 FullyShardedDataParallel 封装。经过递归 wrap 之后,黄色 module 为 non-directly-wrapped module,红色 module 为 FSDP wrapped module,虚线框内表示框里面的 module 会一起组成一个 FlatParameter,分别被 assigned 给 module 0/1/3/7。

上面的 _auto_wrap 实质上调用的是 _recursive_wrap,如下所示,因为 _auto_wrap 是从下往上的,因此 wrap 的树结构也是从下往上的。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def _recursive_wrap(
    module: nn.Module, # Module to recursively wrap.
    auto_wrap_policy: Callable,
    wrapper_cls: Callable,
    ignored_modules: Set[nn.Module],
    ignored_params: Set[nn.Parameter],
    only_wrap_children: bool = False,
    **kwargs: Any,
) -> Tuple[nn.Module, int]:
    # ...

    # We count all params, assuming none of them are already wrapped.
    nonwrapped_numel = sum(
        p.numel() for p in module.parameters() if p not in ignored_params
    )

    if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel):
        total_wrapped_numel = 0
        # Iterate through the children, recursively wrap if necessary
        for name, child in module.named_children():
            if child in ignored_modules:
                continue
            wrapped_child, num_wrapped_params = _recursive_wrap(
                module=child,
                auto_wrap_policy=auto_wrap_policy,
                wrapper_cls=wrapper_cls,
                ignored_modules=ignored_modules,
                ignored_params=ignored_params,
                **kwargs,
            )
            # 这里返回的是已经被 wrapped 的 module
            # 经过 setattr 后 child_module 已经变成 FSDP Module
            setattr(module, name, wrapped_child)
            # Keep track of how many parameters have been wrapped
            total_wrapped_numel += num_wrapped_params
        # decide if we need to wrap the current module,
        # since the left over parameters exceed the number of params to wrap
        remainder = nonwrapped_numel - total_wrapped_numel
        if not only_wrap_children and auto_wrap_policy(
            module=module, recurse=False, nonwrapped_numel=remainder
        ):
            # Leaf node or final wrapping of the remainder both happen here.
            # 这里的 wrapper_cls 就是 FSDP class
            # _wrap 函数返回的就是 wrapper_cls(module, **kwargs)
            return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
        else:
            return module, total_wrapped_numel
    return module, 0

以之前的 transformer_auto_wrap_policy 为例,当 wrap policy 发现当前 module 为传入的 Transformer Block,比如 T5Block,则会返回 True。也就是说 wrap policy 指定了最小的封装单元,经过这里 wrap 后会返回一个 FullyShardedDataParallel 封装的 T5Block ,如下所示,对应的 FSDP Unit 包含 7079808 个参数。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
recursive wrap module wrapped_child=FullyShardedDataParallel(
  (_fsdp_wrapped_module): T5Block(
    (layer): ModuleList(
      (0): T5LayerSelfAttention(
        (SelfAttention): T5Attention(
          (q): Linear(in_features=768, out_features=768, bias=False)
          (k): Linear(in_features=768, out_features=768, bias=False)
          (v): Linear(in_features=768, out_features=768, bias=False)
          (o): Linear(in_features=768, out_features=768, bias=False)
          (relative_attention_bias): Embedding(32, 12)
        )
        (layer_norm): T5LayerNorm()
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (1): T5LayerFF(
        (DenseReluDense): T5DenseActDense(
          (wi): Linear(in_features=768, out_features=3072, bias=False)
          (wo): Linear(in_features=3072, out_features=768, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
          (act): ReLU()
        )
        (layer_norm): T5LayerNorm()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
), num_wrapped_params=7079808

确定好每一个 FlatParameter 之后,就可以明确在实际计算中需要进行的 unshard 和 reshard:

Rule 2: For a given FlatParameter and forward/backward pass, FSDP only unshards and reshards the FlatParameter once.

对于每个 FlatParameter 对应的 Module:

  • fwd 前会有 pre-forward Hook 来做 unshard,fwd 后会有 post-forward hook 来做 reshard
  • bwd 前有 pre-backward Hook 来做 unshard,bwd 后会有 post-backward hook 来做 reshard
  • 对于整个计算,还会有 root pre-forwardpost-backward final callback Hook

以前向代码为例18

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic."""
        handle = self._handle
        with torch.autograd.profiler.record_function(
            "FullyShardedDataParallel.forward"
        ):
            args, kwargs = _root_pre_forward(self, self, args, kwargs)
            unused = None
            args, kwargs = _pre_forward(
                self,
                handle,
                _pre_forward_unshard,
                self._fsdp_wrapped_module,
                args,
                kwargs,
            )
            if handle:
                _p_assert(
                    handle.flat_param.device == self.compute_device,
                    "Expected `FlatParameter` to be on the compute device "
                    f"{self.compute_device} but got {handle.flat_param.device}",
                )
            # 这里已经完成 unshard,正常执行 module 的 forward
            output = self._fsdp_wrapped_module(*args, **kwargs)
            return _post_forward(
                self, handle, _post_forward_reshard, self, unused, output
            )

FlatParamHandle 抽象

在前面的 Rule 1 中提到每个 FSDP Instance 都关联到一个 FlatParameter,这里的 FSDP Instance 也可以叫做 FSDP Unit 或者 FSDP Module。为了更好的管理 FlatParameterdata management 包括 unshard/reshard 这些,torch 抽象出来了 FlatParamHandle 的数据结构,它和 FlatParameter 一对一,也就是和 FSDP Instance 一对一。

通过 FlatParamHandle 的抽象,可以将数据管理等相关的部分从 FSDP Module 抽离出来,这样 FlatParameter 也和一个普通的 nn.Parameter 差距不大了。

FlatParameter 管理原有 Parameter Storage

前面提到,FlatParameter 会将同一个 FSDP Unit 里面的所有 Parameters 的统一到一个 1D 的 Tensor,并自动 Padding。这自然引出相关的问题:

  1. 在初始化模型的时候,每个 worker 拥有一个 FSDP Unit 的 Shard,这里的空间是如何管理的,与原有的 Parameters 的 Tensor Storage 的关系是什么
  2. 在 Forward 的时候,做完 Unshard 之后,是如何从 1D Tensor 变成原来的 Parameters 的形式,从而进行前向计算
  3. 在 Backward 的时候,做完 Unshard 之后,FSDP 的 Backward 是如何和原有的 AutoGrad 适配起来,从而计算出 Grads,这些 Grads 的空间又是如何管理的,怎么做 Grads 同步呢

对于这几个问题,我们沿着 FSDP 模型的计算流程依次来看。

FSDP1 计算与通信流程

如前所述,对于每一个 FSDP Instance,其执行的流程如下图所示:

接下来本小节将以 IBM 的 FMS FSDP 19 训练 LLaMA3 8B为例查看实际的每一个步骤的详细流程。

下图为前向流程,因为 LLaMA3 8B 包含 32 层,每层都是一个 LLaMABlock,类似于上面的 GPT2 模型,整个模型被拆分成了 33 个 FSDP Unit,也就是一个 Root FSDP Unit 和 32 个 FSDP LLaMABlock Unit。

具体的看前向流程,先 Pre-Forward 执行 Unshard 来同步参数

接下来是实际的 Attention 和 MLP 等计算

后面是 _post_forward 来释放不需要的参数,每个 worker 继续只拥有一部分 shard

接下来是反向计算,因为开启了 Selective Activation Checkpointing 20,并且设置 p = 0.5 也就是每隔一个 block 开启 SAC,需要重新执行前向计算激活。

对于每一个 block 的 backward,需要先执行 Pre-BackwardUnshard,接下来是具体的 Backward 计算,最后 Post-Backward 通过 Reshard 来同步梯度。

Pre-Backward 通过 all-gather 来同步参数:

具体的 backward 计算,这里可以看到 sqrt 和 FlashAttention 等 backward 计算
Post-Backward 中通过 reduce-scatter 来同步梯度

这里通过 trace 看了具体的流程,接下来会通过代码来获得进一步的理解。

模型与通信初始化

FSDP 初始化的时候,会初始化不同的 stream21,其中 default stream 用于前向和反向的计算。此外,还分别有 _unshard_stream_pre_unshard_stream 用于 forward 过程中的逻辑,以及 _post_backward_stream 用于 backward overlap 的逻辑,以及 _all_reduce_stream 用于 HSDP 的逻辑。

1
2
3
4
5
_default_stream       # Default stream for computation
_unshard_stream       # Stream for unshard logic, including allocating the all-gather destination tensors and the all-gathers themselves
_post_backward_stream # Stream for overlapping gradient reduction with the backward pass gradient computation(i.e. reduce-scatter)
_pre_unshard_stream   # Stream for pre-unshard logic, namely allocations and writes for CPU offloading (H2D copy) and mixed precision (low precision cast)
_all_reduce_stream    # Stream to run HSDP's all-reduce as async (if using HSDP)

注意,这里除了 _default_stream 之外的几个 stream 并不是 communication stream,为什么要创建这几个 stream 后面会进一步讨论。真正的 communication stream 是用的 process group 自带的 stream:

  • 对于 FULL_SHARDall_gatherreduce_scatter 用的是同一个 stream
  • 对于 HYBRID_SHARD 会分别创建 intra-node process groupinter-node process group,其中 intra-node 用的是默认的 process group,inter-node 用的是新建的一个 process group

Process Group 初始化的逻辑可以参考这里22

Pre-Forward 与 UnShard

Pre-Forward 最关键的步骤是执行 Unshard 来重建 Parameters,实际执行的是 FlatParamHandleunshard 方法23

1
2
3
4
5
6
7
8
9
def unshard(self):
    # ...

    # 申请 padded_unsharded_flat_param
    unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
    # 执行 all_gather_into_tensor
    padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
    # rebuild 原有的 parameters
    self._use_unsharded_flat_param(padded_unsharded_flat_param)

为了更好的理解 FSDP 在 Unshard/Reshard 过程中对于内存空间的管理,我们需要进一步看下这里的 FlatParameter 和重建后的 parameters 的空间关系。

查看 FlatParameter 的定义24,对于 sharded 的参数,其数据存在 _local_shard 这个 tensor 上,在 unshard 的时候会为 _full_param_padded 这个 tensor 申请空间,从而通过 all-gather 将完整的参数存在这个地方,然后经过 fwd/bwd 计算后将对应的空间释放。

比如,FlatParameter 初始化的时候,会申请 padded_unsharded_numel 大小的空间,然后释放掉对应的内存空间:

1
2
3
4
5
6
7
8
padded_unsharded_numel = flat_param.numel() * self.world_size
flat_param._full_param_padded = torch.empty(
    padded_unsharded_numel,
    device=self.device,
    dtype=unsharded_param_dtype,
)
flat_param._padded_unsharded_size = flat_param._full_param_padded.size()
_free_storage(flat_param._full_param_padded)

在前面的 unshard 函数中的 _alloc_padded_unsharded_flat_param,又会找到这个 tensor,重新为其从 CUDA Caching Allocator 申请 GPU Memory:

1
2
3
4
5
6
def _alloc_padded_unsharded_flat_param(self):
    flat_param = self.flat_param
    unsharded_flat_param = self._get_padded_unsharded_flat_param()
    self._check_storage_freed(unsharded_flat_param)
    _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size)  # type: ignore[attr-defined]
    return unsharded_flat_param

这里的 _free_storage/_alloc_storage 实际上调用的就是 tensor.resize() 申请到对应大小的空间。

对应的 _all_gather_flat_param 主要就是调用 all-gather 去 unshard 参数了

1
2
3
4
5
6
7
8
def _all_gather_flat_param(self, padded_unsharded_flat_param: Tensor,): -> Tensor:
    sharded_flat_param = self.flat_param.data  # 对应于 _local_shard 的空间 
    expected_numel = sharded_flat_param.numel() * self.world_size
    dist.all_gather_into_tensor(  
        padded_unsharded_flat_param,  # dst
        sharded_flat_param,           # src 
        self.process_group,
    )

前面提到,经过 all-gather 之后就会通过 _use_unsharded_flat_param 来重建原来的参数。

1
2
3
4
5
6
7
8
def _use_unsharded_flat_param(  
    self,  
    padded_unsharded_flat_param: torch.Tensor,  
) -> None:
    unsharded_size = self.flat_param._unpadded_unsharded_size  
    flat_param_part = padded_unsharded_flat_param[: unsharded_size.numel()]
    self.flat_param.data = flat_param_part
    self._use_unsharded_views(as_params=False)

实际上因为之前已经通过 FlatParamShardMetadata 记录了这个 FlatParameter 管理的包括 shape、strides 等所有 parameter 的 metadata 信息,因此可以通过 FlatParameterview 变换,实现参数重建。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
class FlatParamShardMetadata(NamedTuple):
    """
    This holds metadata specific to this rank's shard of the flat parameter.

    Attributes:
        param_names (Tuple[str, ...]): Prefixed parameter names of this rank's
            shard of the parameters; see :class:`FlatParameter`.
        param_shapes (Tuple[torch.Size, ...]): Parameter shapes of this rank's
            shard of the parameters; see :class:`FlatParameter`.
        param_strides (Tuple[torch.Size, ...]): Parameter strides of this rank's
            shard of the parameters; see :class:`FlatParameter`.
        param_contiguities (Tuple[bool, ...]): Parameter `.contiguous` call results
            of this rank's shard of the parameters; see :class:`FlatParameter`.
        param_numels (Tuple[int, ...]): Parameter numels of this rank's shard
            of the parameters; see :class:`FlatParameter`.
        param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in
            units of numels) giving this rank's part of each flattened
            original parameter.
    """

重建完之后,就可以通过让当前 module _setattr_tensor 来让 AutoGrad 系统感知和兼容,从而在实际的前向计算中用到的空间仍然是 FlatParameter_full_param_padded 空间。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
@torch.enable_grad()
def _use_unsharded_views(self, as_params: bool) -> None:  
    """  
    Unflatten the unsharded flat parameter by setting the original parameter variables to be views into it.
    """
    flat_param = self.flat_param
    views = self._get_unflat_views()
    for i, (view, (param_name, module, _)) in enumerate(  
        zip(views, flat_param._param_infos)  
    ):
        # ...
        param_var: Tensor = view
        self._setattr_tensor(module, param_name, param_var)
    # ...    

仔细展开这里的 _get_unflat_views 实际上就是根据原有参数的 metadata 信息,进行 torch.split()torch.view() 的变换。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
    def _get_unflat_views_unaligned(
        self,
        tensor: Optional[torch.Tensor] = None,
    ) -> Iterator[Tensor]:
        flat_param = self.flat_param
        if tensor is None:
            tensor = flat_param
        views = (
            _ext_post_unflatten_transform(
                subtensor.view(shape)
                if contiguous
                else subtensor.as_strided(shape, stride),
                param_extension,
                self._fsdp_extension,
            )
            for (subtensor, shape, stride, contiguous, param_extension) in zip(
                torch.split(tensor, flat_param._numels, dim=0),
                flat_param._shapes,
                flat_param._strides,
                flat_param._contiguities,
                flat_param._param_extensions,
            )
        )
        return views

FSDP 的一个最大优势在于它能够和 torch 原生的 AutoGrad 系统紧密的结合在一起。

FlatParameter_full_param_padded 这个 tensor 对应保留着实际 FWD/BWD 过程中的参数和梯度的内存空间,虽然 FlatParameter 在 FWD/BWD 经过了上述的参数重建流程,但是 torch.split()torch.view() 都是 AutoGrad 可见的,因此在反向梯度计算的时候,自动微分引擎会将对应参数的梯度自动写到 flat_param.grad 对应的 offset 上。也就是说,经过 backward 之后,flat_param.grad 就会自动保存好计算好的梯度了。

Post-Forward 与 ReShard

Forward 执行完后,在 _post_forward_reshard 中实际调用了 FlatParamHandlereshard

1
2
3
4
def reshard(self, free_unsharded_flat_param: bool):
    self._use_sharded_flat_param()
    if free_unsharded_flat_param:
        self._free_unsharded_flat_param()

这里 _use_sharded_flat_param 本质上是将 FlatParameter 的数据重新指向 _local_shard Tensor 对应的空间,而不再使用 _full_param_padded 对应的空间,同时释放 _full_param_padded 对应的内存。

1
    flat_param.data = flat_param._local_shard

Pre-Backward 与 UnShard

Pre-BackwardPre-Forward 类似,也是先对 Param 做 Unshard,此处不再赘述。

Post-Backward 与 ReShard

执行完 backward 之后,继续执行 _post_backward_hook,实际会对 param 进行 reshard,如上所述,释放 flat parameter 的空间,然后通过 _reduce_grad 来平均梯度。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
def _reduce_grad(state: _FSDPState, handle: FlatParamHandle) -> None:
    unsharded_grad = flat_param.grad.data  
    flat_param.grad = None  
    padded_unsharded_grad, new_sharded_grad = _get_reduce_scatter_tensors(  
        state, unsharded_grad  
    )
    # ...
    dist.reduce_scatter_tensor(  
        new_sharded_grad,  
        padded_unsharded_grad,  
        group=pg,  
    )
    if uses_hybrid_sharded_strategy:
        dist.all_reduce(new_sharded_grad, group=state._inter_node_pg)
        # ...
        return

其中 reduce-scatter 实际同步的数据如下所示,每个 rank 根据自己的数据算出来的完整梯度,经过 reduce-scatter 后每个 rank 保存平均后的梯度的 shard。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
def _get_reduce_scatter_tensors(  
    state: _FSDPState, unsharded_grad: torch.Tensor  
) -> Tuple[torch.Tensor, torch.Tensor]:  
    chunks = list(unsharded_grad.chunk(state.world_size))  
    numel_to_pad = state.world_size * chunks[0].numel() - unsharded_grad.numel()  
    padded_unsharded_grad = (  
        F.pad(unsharded_grad, [0, numel_to_pad]) if numel_to_pad > 0 else unsharded_grad  
    )  
    new_sharded_grad = torch.empty_like(chunks[0])  # padded  
    return padded_unsharded_grad, new_sharded_grad

FSDP1 与 Hybrid Shard

注意到,这里如果用了 HYBRID_SHARD,还会在 _inter_node_pg 这个 process group 执行 all_reduce 来在多个 replicate group 里同步梯度。具体解释下这里的几个概念:

  • Sharding Group: 在这个 process group 中,每个 rank 都只拥有 shard 的参数和梯度,它的 size 对应着有多少个 rank 来对 FlatParameter 做拆分
  • Replication Group: 在这个 process group 中,每个 rank 拥有的参数是相同的,传入的数据不同,从而算出不同的梯度,需要通过 all_reduce 来平均梯度
    FSDP Hybrid Sharding
    FSDP Hybrid Sharding
    对应于上图,sharding group size 为 8,对应于一个有 8 个 GPU 的节点,8 个 rank 做 sharding,通过高速的 NVLink 来通过 reduce-scatter 同步梯度,也就是 fsdp 代码里面的 _intra_node_group

因为 world size 为 16,因此还有另外的 8 个 replication group,每个 replication group 的 size 为 2,这也就是代码里面的 _inter_node_group。在 backward 的 _intra_node_group 梯度平均之后,再通过 all_reduce_inter_node_group 之间平均就好。

对于 torch 2.2 以下的版本,可以通过以下代码更新 process group,然后在 FSDP 初始化的时候传入对应的 process_group 参数。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def update_process_group(self):
	# create fsdp shard group and replicate group,
    # if world size = 8, rank_id = [0,1,2,3,4,5,6,7],
    # for hybrid_shard_group_size=4,
    # the shard group would be[[0,1,2,3],[4,5,6,7]]
    # and the replicate group would be [[0,4],[1,5],[2,6],[3,7]]
	shard_group = None
	replicate_group = None
	# create shard_group
	hybrid_shard_group_size = self._world_size // self._num_shard_group
	for i in range(self._num_shard_group):
	    rank_ids = [
    		i * hybrid_shard_group_size + j
    		for j in range(hybrid_shard_group_size)
	    ]
	    tmp_group = dist.new_group(
    		rank_ids,
    		backend="nccl",
	    )
	    if self._rank in rank_ids:
    		shard_group = tmp_group
	# create replicate group
	for i in range(hybrid_shard_group_size):
	    rank_ids = [
    		i + j * hybrid_shard_group_size
    		for j in range(self._num_shard_group)
	    ]
	    tmp_group = dist.new_group(
    		rank_ids,
    		backend="nccl",
	    )
	    if self._rank in rank_ids:
    		replicate_group = tmp_group
	self._process_group = tuple([shard_group, replicate_group])

对于 torch 2.2 之后,Device Mesh 的引入大大简化了这个过程,可以直接传入一个 2D 的 Device Mesh,就可以自动设置当前 rank 所在的 replication group 还是 sharding group,详细使用可以参考25,对应通信的初始化在这里26

1
2
3
4
5
6
7
8
# 第一个维度是 replica group 的 size,第二个 group 是 sharding group 的 size
# 即对应着上图的拆分方法
device_mesh = init_device_mesh("cuda", (2, 8), mesh_dim_names=("replicate", "shard"))

model = FSDP(model,
    #...
    device_mesh=device_mesh,
)

另外需要注意的是,虽然这里命名 process group 的时候按照 intra node 还是 inter node,实际上不仅仅限于 sharding group 必须在一个节点内,可以根据需要来设置,比如 1024 卡训练的时候,可以设置 device mesh 为 (8, 128),也就是在 128 卡内做 shard,然后对应的每 8 个 rank 做 replicate。这样可以在规模较大的时候比较好的降低通信的瓶颈。

FSDP1 计算与通信优化

前面我们详细介绍了 FSDP 的计算与通信流程,接下来介绍 FSDP 中的计算与通信优化。

Overlapping Communication and Computation

计算与通信 overlap 的一个简单想法就是,在当前 FSDP Unit 在做 FWD 计算的时候,提前发起下一个 FSDP Unit 的 AllGather。这样当上一个 FSDP Unit 完成 FWD 计算的时候,下一个 FSDP Unit 的参数已经准备好,可以直接开始下一个 FSDP Unit 的 FWD 计算,而不再需要等待通信的完成,保证 GPU 能一直占满。

FSDP Overlap Communication and Computation
FSDP Overlap Communication and Computation

在进一步聊这里的 overlap 之前,先回顾一下之前提到的 FSDP 多个 stream 的问题。前面提到,FSDP 在初始化的时候会初始化下面这些 stream,其中 _default_stream 用于 FWD/BWD 的计算,正如前面的 trace 所示。那么为什么还要分别有 _unshard_stream_post_backward_stream 呢?

1
2
3
4
5
_default_stream       # Default stream for computation
_unshard_stream       # Stream for unshard logic, including allocating the all-gather destination tensors and the all-gathers themselves
_post_backward_stream # Stream for overlapping gradient reduction with the backward pass gradient computation(i.e. reduce-scatter)
_pre_unshard_stream   # Stream for pre-unshard logic, namely allocations and writes for CPU offloading (H2D copy) and mixed precision (low precision cast)
_all_reduce_stream    # Stream to run HSDP's all-reduce as async (if using HSDP)

如 PyTorch FSDP 论文所说:

For general correctness, ProcessGroupNCCL synchronizes the internal stream with the current stream before running the collective.

ProcessGroupNCCL 发起 all_gather 通信之前,会先和当前 (default) stream 同步。也就是说只有 default stream 里面的计算完成之后,才会真正发起 all_gather,这样就根本无法 overlap 了。

具体的实现对应到 torch 代码就是这里的 syncStream 27,而每个 collective 在真正开始之前会执行这里的 syncStream 实现同步28

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
// [Sync Streams] Helper that lets the input ncclStreams to wait for the current  
// stream. NCCL communications run on ncclStreams, but input tensors are 
// allocated on different streams (i.e., current streams). Communications on  
// ncclStreams cannot start before pending input tensor ops on current streams  
// finish. Otherwise, ops on two streams might read/write same tensors  
// concurrently.  
//  
// The synchronization above alone is not enough. We also need to make sure  
// input tensors are not freed before their usages on ncclStreams finish. This  
// can be achieved by calling c10::cuda::CUDACachingAllocator::recordStream,  
// which remembers the usage stream (ncclStream), creates an event on the usage  
// stream when GC attempts to free the input tensor, and delays GC until that  
// event is done.  
void syncStream(  
    at::Device& device,  
    at::cuda::CUDAEvent& ncclEvent,  
    at::cuda::CUDAStream& ncclStream) {  
  ncclEvent.record(at::cuda::getCurrentCUDAStream(device.index()));  
  ncclEvent.block(ncclStream);  
}

但是实际上这里下一个 FSDP Unit 的 all_gather 和当前 FSDP Unit 并没有依赖关系,只是因为目前 ProcessGroupNCCL 的实现让其必须先等当前 stream 计算完成,也就是论文里面说的 false dependency

为了解决这个问题,FSDP 将发起 all_gather 的 stream 独立出来,而不是在默认的 _default_stream 中发起,这就对应着 _unshard_stream_post_backward_stream 也是同理,不再赘述。

Forward Prefetching

对于 forward 和 all_gather 的 overlap,主要包括两种:

  1. Implicit forward prefetch: 默认开启
  2. Explicit forward prefetch:需要在 FSDP 初始化的时候设置 forward_prefetch=True

对于第一种默认开启的 prefetch,实际上也就是前面提到的用一个单独的 CUDA stream _unshard_stream 来发起 all_gather 从而实现 overlap,对应的时间线如下,这里的 layer0 forward 就可以和 layer1 all-gather 实现 overlap。

1
layer0 all-gather -> layer0 forward -> layer1 all-gather -> layer1 forward -> ...

FSDP Overlap Communication and Computation
FSDP Overlap Communication and Computation

对于第二种 forward prefetch,FSDP 文档里面指明了这种只适用于 CPU-bound 并且是 static-graph 的模型。

if forward_prefetch=True, then FSDP explicitly prefetches the next forward-pass all-gather before the current forward computation. This is only useful for CPU-bound workloads, in which case issuing the next all-gather earlier may improve overlap. This should only be used for static-graph models since the prefetching follows the first iteration’s execution order

具体的 prefetch 方式就是在当前 FSDP Unit all-gather 之后,在当前 FSDP Unit forward 之前,如下所示:

1
layer0 all-gather -> layer1 all-gather -> layer0 forward compute -> ...

对应于代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
def _pre_forward_unshard(state: _FSDPState,  handle: Optional[FlatParamHandle],) -> None:  
    if not handle._prefetched:  
        _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)  
    handle._needs_pre_forward_unshard = False  
    # Don't wait during trace  

    # ...
    with torch.profiler.record_function(  
        "FullyShardedDataParallel._pre_forward_prefetch"  
    ):  
        _prefetch_handle(state, handle, _PrefetchMode.FORWARD)

这里注明是需要限制为 static-graph 的模型,是因为在 eagar 模式下,没有办法知道需要 prefetch 哪一个 FlatParameter,因此需要模型的执行顺序是固定的,也就是静态图的模型。

这个优化只适用于 slow cpu 的场景,并且 prefetch 也会带来更多的 GPU 显存开销,对于 LLM 场景下,大多数不需要打开这个参数。

Backward Prefetching

对于 backward,继续看下面这张图。因为 FSDP 对 all-gatherreduce-scatter 只用了一个 NCCL process group,这意味着 all-gatherreduce-scatter 只能串行执行。因此,当 current FSDP Unit 反向执行完成,需要通过 reduce-scatter 来同步梯度,然后 all-gather 来重建下一个 FSDP Unit 的参数,从而继续下一个 FSDP Unit 的反向计算。

因此,尽管计算和通信能够 overlap,但是因为反向中当前 FSDP Unit 的 reduce-scatter 和下一个 FSDP Unit 的 all-gather 只能串行,会导致下一个 FSDP Unit 必须等待,从而产生 bubble,如下图的 BWD1 执行所示。

FSDP Overlap Communication and Computation
FSDP Overlap Communication and Computation

为了解决这个问题,FSDP 默认设置可以执行 backward prefetch29。这里的 backward prefetch 指的是提前 unshard 下一个 FSDP Unit 的参数。根据 unshard 的时机不同,分为两种:

  • BACKWARD_PRE:在 pre_backward_hook 执行,也就是在当前 FSDP Unit 的 backward unshard 之后,在实际计算 backward 之前执行。也就是 overlap next all-gather 和 current grad 计算,这种为 FSDP 默认值。
  • BACKWARD_POST:在 post_backward_hook 执行,也就是在当前 FSDP Unit backward 之后,在梯度 reduce-scatter 之前执行。这里的 all-gatherreduce-scatter 仍然是串行的,但是提前发起 unshard 可以让下一个 FSDP unit 提前开始计算,从而与当前 reduce-scatter overlap。

具体看代码,BACKWARD_PRE 的 prefetch 对应于 _pre_backward_hook

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def _pre_backward_hook(
    state: _FSDPState,
    module: nn.Module,
    handle: FlatParamHandle,
    grad,
    *unused: Any,
) -> Any:
    # ...

    with torch.profiler.record_function("FullyShardedDataParallel._pre_backward_hook"):
        # ...
        handle._training_state = HandleTrainingState.BACKWARD_PRE

        if handle._needs_pre_backward_unshard:
            # If the handles have been prefetched, then there is no need to
            # call `_unshard()` again
            if not handle._prefetched:
                _unshard(
                    state,
                    handle,
                    state._unshard_stream,
                    state._pre_unshard_stream,
                )
            # Don't wait during trace
            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
                state._device_handle.current_stream().wait_stream(state._unshard_stream)

        # Set this to `False` to ensure that a mistargeted prefetch does not
        # actually unshard these handles
        handle._needs_pre_backward_unshard = False
        with torch.profiler.record_function(
            "FullyShardedDataParallel._pre_backward_prefetch"
        ):
            _prefetch_handle(state, handle, _PrefetchMode.BACKWARD)
        handle.prepare_gradient_for_backward()
        handle._ran_pre_backward_hook = True
        return grad

BACKWARD_PRE 对应的 一个实际的的 trace:

对于 BACKWARD_POST 的 prefetch,对应于代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
def _post_backward_reshard(
    state: _FSDPState,
    handle: FlatParamHandle,
    *unused: Any,
) -> None:
    free_unsharded_flat_param = _should_free_in_backward(state, handle)
    _reshard(state, handle, free_unsharded_flat_param)

    with torch.profiler.record_function(
        "FullyShardedDataParallel._post_backward_prefetch"
    ):
        _prefetch_handle(state, handle, _PrefetchMode.BACKWARD)

AllGather Rate Limit

在讨论 FSDP 的 fwd 和 bwd 的 prefetch 的时候,有一个重要的影响因素还没有提到,这就是 GPU 显存。文章29 详细分析了 FSDP 过程中显存占用,具体的,以一个每层 1.6B 的 Transformer Block 并且 8 个 GPU 之间 shard 的前向计算 fp32 为例。计算当前 FSDP Unit 的时候,需要保存当前 sharded param 和 unsharded param,对应于是 1.6 * 4 / 8 = 0.8GB 的 sharded param 和 1.6 * 4=6.4GB 的 unshareded param,总的是 7.2GB 的参数显存。默认前向的 prefetch 中,需要提前 prefetch 下个 FSDP Unit 的 param,对应总的显存位 7.2GB * 2 = 14.4GB 的参数显存。

FSDP & CUDACachingAllocator 这篇文章30 详细分析了对于 FSDP 这种多 stream 场景与 torch 的 CUDACachingAllocator 交互中可能碰到的问题。

具体的,因为 CUDACachingAllocator 工作在 CPU Thread,并且其分配的 block 都是面向于每个 stream 的。对于单 stream 场景下能够保持顺序语义,CCA 能够直接 reuse 可以复用的 memory blocks。但是在 FSDP 场景下,GPU Memory 的分配和使用是多 stream 的,比如 FWD 中会在 _unshard_stream 调用 _alloc_padded_unsharded_flat_param 去申请 GPU Memory,然后在 _default_stream 复用这段内存进行前向计算。

也就是说,_unshard_stream 作为 producer stream,_default_stream 作为 consumer stream。对于这种多 stream 的场景,无法保证 inter-stream 之间的顺序,从而只有当这个 block 上的最后一个 GPU Kernel 完成之后, CCA 才可以确定一个 block 是否可以被 reuse。这里的确认一个 tensor 没有被任何 kernel 使用,是通过 Tensor.record_stream(stream) 的 API31

因此,如果 CPU thread 一直 run far ahead 在 GPU 执行之前,CPU 会发起多个后续 FSDP Unit 的 prefetch。但是 prefetch 过程中 producer stream 需要申请 GPU memory,这样会导致在 consumer stream 中本来可以申请更多的 memory 不再能实现(比如用于 default stream 的 activation)。这种情况下 CCA 必须先执行 cudaFree 再重新执行 cudaMalloc 来重新利用可用的 GPU memory,导致性能大幅下降。

为了解决这个问题,FSDP 实现了一个 rate limiter 3233,能够故意阻塞 CPU thread 来保证合理的 CCA memory block 的使用。

rate limiter 通过一个 free event queue 来实现,在 _unshard 的时候检查当前 event queue 里面是否有 event,判断是否需要等待,在 FWD 执行结束后的 _reshard 中来向 event queue 里 enqueue event,表示 CPU 已经提前发起了计算的 kernel。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
def _unshard(state: _FSDPState, handle: FlatParamHandle, unshard_stream: torch.Stream, pre_unshard_stream: torch.Stream) -> None:
    # ...
    with state._device_handle.stream(pre_unshard_stream):
        ran_pre_unshard = handle.pre_unshard()
    if ran_pre_unshard:
        unshard_stream.wait_stream(pre_unshard_stream)
    if state.limit_all_gathers:
        event = state._free_event_queue.dequeue_if_needed()
        if event:
            with torch.profiler.record_function(
                "FullyShardedDataParallel.rate_limiter"
            ):
                event.synchronize()
    with state._device_handle.stream(unshard_stream):
        handle.unshard()
        handle.post_unshard()

最开始 event queue 队列为空,所以 CPU 可以直接发起前两个 all_gather 不需要等待。然后执行 CPU 执行到 reshard 的时候,记录对应的 event,等 FWD 计算中对对应 parameter 计算结束之后,对应的 event 也就结束了。因此当 CPU 执行到第三个 FSDP Unit 来的时候,因为 GPU 还没执行完,这个时候 event queue 里面有 event,CPU 必须执行 event.synchronize() 来做 CPU 阻塞。等对应的 GPU kernel 执行完毕,synchronize 结束,CPU 可以继续下发对应的 all_gather

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
def _reshard(state: _FSDPState, handle: FlatParamHandle, free_unsharded_flat_param: bool):
    handle.reshard(free_unsharded_flat_param)
    if state.limit_all_gathers and free_unsharded_flat_param:
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            # We don't run a even queue for freeing under torch compile atm
            # But maybe we need to? TODO(voz): Look into this
            free_event = state._device_handle.Event()
            free_event.record()
            state._free_event_queue.enqueue(free_event)
    handle.post_reshard()
    # Flat parameter freed or not, we always have to "unshard" the parameter
    # upon next access to get its shape correct.
    handle._prefetched = False

下图是之前的 FSDP 训练 LLaMA 的 trace,详细展示了上述的 rate liiter 逻辑:

前面提到 reshard 的时候会释放 FlatParameter 的空间,这个时候就会调用 record_stream 。因为这个空间是在 _unshard_stream 上分配的,正在跨 stream 被 _default_stream 使用,如果不 record_stream,下个 FSDP Unit _unshard 的时候 CCA 可能会直接分配这段空间。但是实际上可能 _default_stream 还没计算完,如果这时候复用就会造成问题。

1
2
3
4
5
6
7
8
9
def _free_unsharded_flat_param(self):  
    self._check_sharded_strategy()  
    unsharded_flat_param = self._get_padded_unsharded_flat_param()  
    self._check_on_compute_device(unsharded_flat_param)  
    # Do not free the memory until all ops in the current stream finish  
    _no_dispatch_record_stream(  
        unsharded_flat_param, self._device_handle.current_stream()  
    )  
    _free_storage(unsharded_flat_param)

关于这部分,更详细的分析,可以参考这篇文档30

FSDP2 API 变化与演进

本文前面详细介绍了 FSDP1 的主要设计和优化,随着模型越来越大,社区发现 FSDP1 的 FlatParameter 限制了进一步优化训练性能。为了解决这个问题,社区提出了下一代的 FSDP: Per-Parameter-Sharding FSDP,也就是 FSDP212

相比于 FSDP1,FSDP2 能够实现:

  1. Flexible fp8 all-gather34
  2. Flexible frozen parameters
  3. Communication-free sharded state dict
  4. Future communication optimization in the compiler

相比于原来的 FlatParameter 设计,FSDP2 3536基于 DTensor 实现了 Per-Parameter Sharding,如下图所示,后续将会进一步分析其设计。

FSDP 2
FSDP 2


  1. https://github.com/pytorch/torchtitan ↩︎

  2. Introducing PyTorch Fully Sharded Data Parallel (FSDP) API, 2022-03-14, https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/ ↩︎

  3. Fully Sharded Data Parallel: faster AI training with fewer GPUs, 2021-07-15, https://engineering.fb.com/2021/07/15/open-source/fsdp/ ↩︎

  4. TorchTitan: One-stop PyTorch native solution for production ready LLM pre-training, https://arxiv.org/abs/2410.06511 ↩︎

  5. FSDP1 -> FSDP2, 2024-10-19, https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md ↩︎

  6. PyTorch 2.4 Release Blog, 2024-07-24, https://pytorch.org/blog/pytorch2-4/ ↩︎

  7. PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel, 2023-04-21, https://arxiv.org/abs/2304.11277 ↩︎ ↩︎

  8. ZeRO: Memory Optimizations Toward Training Trillion Parameter Models, 2019-10-04, https://arxiv.org/abs/1910.02054 ↩︎

  9. Maximizing Training Throughput Using PyTorch FSDP and Torch.compile, 2024-05-21, https://pytorch.org/blog/maximizing-training-throughput/ ↩︎

  10. Rethinking PyTorch Fully Sharded Data Parallel (FSDP) from First Principles, 2023-01-20, https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019 ↩︎

  11. https://pytorch.org/blog/maximizing-training/ ↩︎

  12. RFC: Per-Parameter-Sharding FSDP, 2023-11-23, https://github.com/pytorch/pytorch/issues/114299 ↩︎ ↩︎

  13. https://pytorch.s3.amazonaws.com/posters/ptc2022/E03.pdf ↩︎

  14. https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html ↩︎

  15. https://blog.clika.io/fsdp-1/ ↩︎ ↩︎

  16. https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather_into_tensor ↩︎

  17. https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather ↩︎

  18. https://github.com/pytorch/pytorch/blob/v2.6.0/torch/distributed/fsdp/fully_sharded_data_parallel.py#L842-L867 ↩︎

  19. https://github.com/foundation-model-stack/fms-fsdp ↩︎

  20. https://github.com/foundation-model-stack/fms-fsdp/blob/main/fms_fsdp/policies/ac_handler.py#L16 ↩︎

  21. https://github.com/pytorch/pytorch/blob/v2.6.0/torch/distributed/fsdp/_runtime_utils.py#L235-L269 ↩︎

  22. https://github.com/pytorch/pytorch/blob/v2.6.0/torch/distributed/fsdp/_init_utils.py#L104-L153 ↩︎

  23. https://github.com/pytorch/pytorch/blob/v2.6.0/torch/distributed/fsdp/_flat_param.py#L1332-L1356 ↩︎

  24. https://github.com/pytorch/pytorch/blob/v2.6.0/torch/distributed/fsdp/_flat_param.py#L1260-L1266 ↩︎

  25. https://pytorch.org/tutorials/recipes/distributed_device_mesh.html ↩︎

  26. https://github.com/pytorch/pytorch/blob/v2.6.0/torch/distributed/fsdp/_init_utils.py#L161-L167 ↩︎

  27. https://github.com/pytorch/pytorch/blob/v2.6.0/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L230-L249 ↩︎

  28. https://github.com/pytorch/pytorch/blob/v2.6.0/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L3078 ↩︎

  29. https://pytorch.org/docs/stable/notes/fsdp.html ↩︎ ↩︎

  30. https://dev-discuss.pytorch.org/t/fsdp-cudacachingallocator-an-outsider-newb-perspective/1486 ↩︎ ↩︎

  31. https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html ↩︎

  32. https://pytorch.org/blog/scaling-pytorch-fsdp-for-training-foundation-models-on-ibm-cloud/ ↩︎

  33. https://github.com/pytorch/pytorch/pull/83917 ↩︎

  34. Enabling Float8 All-Gather in FSDP2, 2024-08-08, https://dev-discuss.pytorch.org/t/enabling-float8-all-gather-in-fsdp2/2359 ↩︎

  35. The Past, Present, and Future of PyTorch, 2024-09-18, https://pytorch2024.sched.com/event/1iy3h/keynote-pytorch-technical-deep-dive-piotr-bialecki-nvidia-peng-wu-will-constable-kartikay-khandelwal-mengtao-martin-yuan-meta ↩︎

  36. https://pytorch2024.sched.com/event/1fHn3/torchtitan-large-scale-llm-training-using-native-pytorch-3d-parallelism-wanchao-liang-meta-linsong-chu-ibm-research ↩︎