Ulysses Sequence Parallel (USP) is an efficient sequence-parallelism scheme proposed by the DeepSpeed team. It primarily addresses the out-of-memory (OOM) problem that arises in large-model training when the sequence length grows too long for a single GPU.

Unlike Megatron-LM's traditional sequence parallelism, the core idea of Ulysses is: before the attention computation, an All-to-All communication converts the sharding along the sequence dimension into sharding along the attention-head dimension.

| Symbol | Meaning |
| --- | --- |
| $N$ | sequence length |
| $hc$ | number of attention heads |
| $l$ | number of transformer layers |
| $b$ | batch size |
| $V$ | vocabulary size |
| $d$ | hidden size |
| $h_s$ | per-head hidden size, $h_s = d / hc$ |

Core Principle: Dimension Transformation

In the Ulysses scheme, the model performs one All-to-All communication before the self-attention computation and another one after it.

Core steps

  1. Input state: the input is sharded along the seq dimension; each rank holds a sequence shard of shape [N/P, d]
  2. The corresponding $Q, K, V$ tensors have shape [N/P, d], i.e. [N/P, hc, hs]. Here $P$ is the degree of parallelism: the sequence is sharded across the GPUs.
    1. Note: this means Ulysses does not shard the model, only the sequence dimension of the data
  3. First All-to-All: redistribute the data from the sequence dimension to the head dimension. Afterwards each GPU holds the full sequence length but only part of the attention heads; the shape becomes [N, hc/P, hs], i.e. [N, d/P]
    1. This requires hc to be divisible by P, so each rank gets hc/P heads
    2. This all-to-all can be understood as a permutation, i.e. a transpose-style communication
  4. Compute attention: each GPU independently computes attention for the heads it owns, producing a result chunk of size (N, d/P)
  5. Second All-to-All: once the computation finishes, another All-to-All moves the data from the head dimension back to the sequence dimension; each rank's chunk returns to shape [N/P, d]
  6. Each GPU holds the full $W_O$ matrix; multiplying the chunk by it yields the output chunk of size (N/P, d)
  7. Enter the MLP layer: since the MLP involves no token-to-token interaction, each sequence chunk can be computed independently
  8. Repeat the above, layer by layer, until the loss is computed
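The shape bookkeeping in the steps above can be checked with a single-process sketch that simulates the $P$ ranks with a list of tensors; the helper names here are illustrative, not DeepSpeed's API:

```python
import torch

P, N, hc, hs = 4, 8, 8, 16          # ranks, seq length, heads, per-head size

# step 1: each rank holds a sequence shard of shape [N/P, hc, hs]
shards = [torch.randn(N // P, hc, hs) for _ in range(P)]

def a2a_seq_to_head(shards):
    # first all-to-all: rank r keeps head group r but gathers every seq chunk
    return [
        torch.cat([s[:, r * (hc // P):(r + 1) * (hc // P), :] for s in shards], dim=0)
        for r in range(P)
    ]  # each entry: [N, hc/P, hs]

def a2a_head_to_seq(head_shards):
    # second all-to-all: rank r keeps seq chunk r but gathers every head group
    return [
        torch.cat([h[r * (N // P):(r + 1) * (N // P)] for h in head_shards], dim=1)
        for r in range(P)
    ]  # each entry: [N/P, hc, hs]

head_shards = a2a_seq_to_head(shards)
assert head_shards[0].shape == (N, hc // P, hs)   # full seq, 1/P of the heads
restored = a2a_head_to_seq(head_shards)
assert all(torch.equal(a, b) for a, b in zip(shards, restored))
```

The round trip recovering the original shards is exactly why the pair of all-to-alls is a lossless permutation.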

Code Implementation

A simplified sketch of the attention forward pass:

```python
import torch
import torch.distributed as dist

def ulysses_attention(query, key, value, pg):
    """
    pg: process group for sequence parallelism
    P:  sequence-parallel degree (world_size)
    """
    P = dist.get_world_size(group=pg)

    # assumed input shape: [N/P, hc, hs]  (hc is the total head count)

    # 1. First All-to-All: from the seq dimension to the head dimension
    # resulting shape: [N, hc/P, hs]
    q = all_to_all_sequence_to_head(query, pg)
    k = all_to_all_sequence_to_head(key, pg)
    v = all_to_all_sequence_to_head(value, pg)

    # 2. Standard local attention
    # every GPU now holds the full sequence, so attention runs as usual
    attn_output = local_attention(q, k, v)  # [N, hc/P, hs]

    # 3. Second All-to-All: from the head dimension back to the seq dimension
    # resulting shape: [N/P, hc, hs]
    output = all_to_all_head_to_sequence(attn_output, pg)

    return output
```

Using FSDP together with Ulysses

The code below uses veRL as an example (ref. 1).

Sharding the sequence

In veRL's prepare_model_inputs, ulysses_pad_and_slice_inputs is called first to shard along the sequence dimension:

```python
if self.use_ulysses_sp:
    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
        input_ids_rmpad,
        position_ids_rmpad=position_ids_rmpad,
        sp_size=self.ulysses_sequence_parallel_size,
    )
```

The corresponding implementation:

```python
from typing import Optional

import torch

def ulysses_pad_and_slice_inputs(
    input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1
):
    """
    Pad and slice input_ids to be divisible by sp_size.
    Pad position_ids to be divisible by sp_size.

    Note both input_ids_rmpad and position_ids_rmpad will be padded and sliced.

    This is the pre-forward utility for ulysses sequence parallelism.

    Args:
        input_ids_rmpad: shape of [bsz, seqlen]
        position_ids_rmpad: shape of [bsz, seqlen], where bsz must be 1
        sp_size (int): ulysses sequence parallelism size

    Returns:
        torch.Tensor: padded and sliced input_ids
        torch.Tensor: padded and sliced position_ids
        int: pad size
    """
    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(input_ids_rmpad, position_ids_rmpad, sp_size)
    input_ids_rmpad = slice_input_tensor(input_ids_rmpad, dim=1, padding=False)
    if position_ids_rmpad is not None:
        position_ids_rmpad = slice_input_tensor(position_ids_rmpad, dim=1, padding=False)
    return input_ids_rmpad, position_ids_rmpad, pad_size
```
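The padding arithmetic can be illustrated in isolation; this is a hand-rolled sketch of the same idea, not verl's actual ulysses_pad/slice_input_tensor:

```python
import torch
import torch.nn.functional as F

sp_size, seqlen = 4, 10                              # seqlen not divisible by sp_size
pad_size = (sp_size - seqlen % sp_size) % sp_size    # 2 extra tokens needed
ids = torch.arange(seqlen).unsqueeze(0)              # [1, seqlen]; bsz must be 1
ids = F.pad(ids, (0, pad_size))                      # pad last dim to [1, 12]
local = torch.tensor_split(ids, sp_size, dim=1)      # each rank keeps one slice
assert pad_size == 2
assert all(s.shape == (1, (seqlen + pad_size) // sp_size) for s in local)
```

pad_size is returned so the padding can be stripped again after the outputs are gathered.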

Attention computation flow

  1. First compute q, k, v, with shape (bsz, seq_len/n, n_head, head_dim); note that under GQA/MQA the head counts of Q and K/V may differ. After the transpose, the shape is (bsz, n_head, seq_len/n, head_dim)
    1. Note: q_len in the code is the already-sharded seq_len/n
  2. First all_to_all: gather_seq_scatter_heads is called, giving query_states of shape (bsz, n_head/n, seq_len, head_dim)
    1. This is where full_q_len is obtained
  3. Apply RoPE
  4. repeat_kv, for the GQA/MQA case
  5. Each rank computes local attention
  6. Second all_to_all: gather_heads_scatter_seq is called, giving attn_output of shape (bsz, seq_len/n, n_head, head_dim)
  7. After o_proj, attn_output has shape (bsz, seq_len/n, hidden_size)
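As a shape-only trace of these steps, using split/cat on one tensor as a stand-in for the collective (the real all_to_all mixes data across ranks; only the shapes match here):

```python
import torch

n = 4                                                # sp size (the "n" above)
bsz, seq, n_head, head_dim = 2, 32, 8, 16
x = torch.randn(bsz, seq // n, n_head, head_dim)     # step 1: per-rank input
q = x.transpose(1, 2)                                # (bsz, n_head, seq/n, head_dim)

# step 2 stand-in: scatter heads (dim 1), gather sequence (dim 2)
q = torch.cat(torch.tensor_split(q, n, dim=1), dim=2)
assert q.shape == (bsz, n_head // n, seq, head_dim)

# step 6 stand-in: scatter sequence, gather heads back
q = torch.cat(torch.tensor_split(q, n, dim=2), dim=1)
assert q.shape == (bsz, n_head, seq // n, head_dim)
```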
```python
def qwen2_flash_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
):
    """
    Adapted from transformers 4.47.1 to support Ulysses sequence parallelism.
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

    ########## AlltoAll for Ulysses ##########
    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()

    if ulysses_sp_size > 1:
        validate_ulysses_config(self.num_heads, ulysses_sp_size)

        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)

    full_q_len = query_states.size(2)  # full seq length

    if position_embeddings is None:
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)
    dropout_rate = 0.0 if not self.training else self.attention_dropout

    # Reshape to the expected shape for Flash Attention
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    attn_output = _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        full_q_len,
        position_ids=position_ids,
        dropout=dropout_rate,
        sliding_window=None,
        is_causal=self.is_causal,
        use_top_left_mask=flash_attn_supports_top_left_mask(),
    )

    # use full_q_len to reshape
    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
    ########## AlltoAll for Ulysses ##########
    if ulysses_sp_size > 1:
        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
```

First all-to-all

Ulysses' first all-to-all converts the shape from [N/P, d] to [N, d/P]: gather on the sequence dimension, scatter on the head dimension.

```python
def gather_seq_scatter_heads(
    x: Tensor,
    seq_dim: int,
    head_dim: int,
    unpadded_dim_size: int = 0,
    group: ProcessGroup = None,
) -> Tensor:
    """
    A func to sync embedding input with alltoall in sequence parallel
    gather sequence dimension and scatter head dim:
    e.g. seq_dim: 1, head_dim: 2
    [bsz, seq/n, h, ...] -> [bsz, seq, h/n, ...]
    """
    group = get_ulysses_sequence_parallel_group() if group is None else group
    if not group:
        return x
    sp_world = get_ulysses_sequence_parallel_world_size(group)
    x = SeqAllToAll.apply(group, x, head_dim, seq_dim)
    if unpadded_dim_size and unpadded_dim_size % sp_world != 0:
        padding_size = x.size(seq_dim) - unpadded_dim_size
        x = _unpad_tensor(x, seq_dim, padding_size)
    return x
```

This uses the SeqAllToAll op, which scatters on head_dim and gathers on seq_dim.

```python
class SeqAllToAll(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_input: Tensor,
        scatter_dim: int,
        gather_dim: int,
        async_op: bool = False,
    ) -> Tensor:
        ctx.group = group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.async_op = async_op
        return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> tuple[None, Tensor, None, None]:
        input_t = torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous() if ctx.async_op else grad_output[0]
        return (
            None,
            all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False),
            None,
            None,
            None,
            None,
        )
```

The corresponding all_to_all_tensor implementation:

```python
def all_to_all_tensor(
    local_input: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: Optional[dist.ProcessGroup] = None,
    async_op: bool = False,
):
    group = get_ulysses_sequence_parallel_group() if group is None else group
    seq_world_size = dist.get_world_size(group)
    input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)]
    output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
    comm = dist.all_to_all(output_list, input_list, group=group, async_op=async_op)
    if async_op:

        def wait():
            comm.wait()
            return torch.cat(output_list, dim=gather_dim).contiguous()

        return wait
    return torch.cat(output_list, dim=gather_dim).contiguous()
```

Note that for the SeqAllToAll op, the backward pass is also an all_to_all, with gather_dim and scatter_dim simply swapped:

  • Forward: `all_to_all_tensor(input, scatter_dim, gather_dim, ...)`
  • Backward: `all_to_all_tensor(grad, gather_dim, scatter_dim, ...)`
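That swapping the two dims inverts the op can be checked with a single-tensor stand-in for all_to_all (split on one dim, concatenate on the other); this mimics only the permutation, not the cross-rank exchange:

```python
import torch

def a2a_local(x, scatter_dim, gather_dim, world=4):
    # stand-in: split into `world` chunks on scatter_dim, concat on gather_dim
    chunks = torch.tensor_split(x, world, dim=scatter_dim)
    return torch.cat([c.contiguous() for c in chunks], dim=gather_dim)

x = torch.randn(8, 12)
y = a2a_local(x, scatter_dim=0, gather_dim=1)        # shape (2, 48)
x_back = a2a_local(y, scatter_dim=1, gather_dim=0)   # swapped dims -> inverse
assert torch.equal(x, x_back)
```

This is exactly the symmetry SeqAllToAll.backward relies on.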

Second all-to-all

```python
def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int, group: ProcessGroup = None) -> Tensor:
    """
    A func to sync attention result with alltoall in sequence parallel
    gather head dimension and scatter seq dim:
    e.g. seq_dim: 1, head_dim: 2
    [bsz, seq, h/n, ...] -> [bsz, seq/n, h, ...]
    """
    group = get_ulysses_sequence_parallel_group() if group is None else group
    if not group:
        return x
    dim_size = x.size(seq_dim)
    sp_world = get_ulysses_sequence_parallel_world_size(group)
    if dim_size % sp_world != 0:
        padding_size = sp_world - (dim_size % sp_world)
        x = _pad_tensor(x, seq_dim, padding_size)
    return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)
```

Gathering to restore the sequence

In veRL's prepare_model_outputs, the computed results are gathered along the sequence dimension, so that every rank holds the complete result, just like the original full input.

```python
# gather log_prob if sp > 1
if self.use_ulysses_sp:
    # gather and unpad for the ulysses sp
    log_probs = gather_outputs_and_unpad(
        log_probs,
        gather_dim=0,
        unpad_dim=0,
        padding_size=pad_size,
    )
```

gather_outputs_and_unpad is implemented by calling the Gather op and then removing the padding:

```python
def gather_outputs_and_unpad(
    x: Tensor,
    gather_dim: int,
    unpad_dim: int = None,
    padding_size: int = 0,
    grad_scaler: bool = True,
    group: Optional[dist.ProcessGroup] = None,
):
    """
    Gather a tensor across a process group and optionally unpad its padded elements.

    Args:
        x (Tensor): Input tensor to gather.
        gather_dim (int): Dimension along which to gather across ranks.
        unpad_dim (int, optional): Dimension from which to remove padding. If None, no unpadding.
        padding_size (int): Number of padding elements to remove on `unpad_dim`. Defaults to 0.
        grad_scaler (bool): Whether to apply gradient scaling during gather. Defaults to True.
        group (ProcessGroup, optional): Process group for gathering. If None, uses
            `get_ulysses_sequence_parallel_group()`. If still None, returns `x` unchanged.

    Returns:
        Tensor: The gathered tensor, with padding removed if requested.
    """
    group = get_ulysses_sequence_parallel_group() if group is None else group
    if group is None:
        return x
    x = Gather.apply(group, x, gather_dim, grad_scaler)
    if unpad_dim is not None:
        assert isinstance(padding_size, int), "padding size is not given or is not an integer"
        if padding_size == 0:
            return x
        x = _unpad_tensor(x, unpad_dim, padding_size)
    return x
```

The corresponding op is implemented as follows:

  • forward gathers each rank's shard of the log-probs along the sequence dimension
  • backward slices the computed gradient and hands each rank its own slice
    • note that backward must apply grad scaling here
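Why the gradient needs scaling can be seen in a single-process sketch: if the loss is a mean over the gathered dimension, each element's gradient shrinks by the world size, which the multiplication in backward compensates for (illustrative only, assuming a mean-style loss):

```python
import torch

W = 4                                            # sp world size
slices = [torch.randn(2, 3, requires_grad=True) for _ in range(W)]
full = torch.cat(slices, dim=0)                  # stand-in for Gather.forward
full.mean().backward()                           # loss averages over the full tensor
# every element's grad is 1/(W*2*3); multiplying by W recovers the 1/(2*3)
# a rank would see if it had averaged only its own local slice
local_grad = slices[0].grad * W
assert torch.allclose(local_grad, torch.full((2, 3), 1.0 / (2 * 3)))
```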
```python
class Gather(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_tensor: Tensor,
        gather_dim: int,
        grad_scaler: bool = True,
        async_op=False,
    ) -> Tensor:
        ctx.group = group
        ctx.gather_dim = gather_dim
        ctx.grad_scaler = grad_scaler
        ctx.async_op = async_op

        sp_world_size = dist.get_world_size(group=group)
        ctx.sp_world_size = sp_world_size

        sp_rank = dist.get_rank(group=group)
        ctx.sp_rank = sp_rank

        local_shape = list(local_tensor.size())
        split_size = local_shape[0]
        part_size = local_shape[gather_dim]  # store original size
        ctx.part_size = part_size

        output = all_gather_tensor(local_tensor, group, async_op)
        return torch.cat(output.split(split_size, dim=0), dim=gather_dim)

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Any:
        if ctx.grad_scaler:
            grad_output = grad_output * ctx.sp_world_size
        return (
            None,
            grad_output.split(ctx.part_size, dim=ctx.gather_dim)[ctx.sp_rank].contiguous(),
            None,
            None,
            None,
            None,
        )
```

Ulysses Sharding Manager

veRL's original ulysses implementation (ref. 1) introduced FSDPUlyssesShardingManager:

```python
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
    with self.ulysses_sharding_manager:
        data = self.ulysses_sharding_manager.preprocess_data(data=data)
        # perform training
        metrics = self.actor.update_policy(data=data)
        output = DataProto(meta_info={'metrics': metrics})
        output = self.ulysses_sharding_manager.postprocess_data(data=output)
        output = output.to('cpu')
```

This is needed because when update_actor is dispatched to the workers via DP_COMPUTE_PROTO, the data is split evenly into worker_group.world_size shards along the FSDP dimension; but with Ulysses, all ranks within the same sp group must see identical data.

```python
def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):
    from verl.single_controller.base.worker_group import WorkerGroup

    assert isinstance(worker_group, WorkerGroup)
    # Note: enable auto padding for dp compute DataProto
    splitted_args, splitted_kwargs = _split_args_kwargs_data_proto_with_auto_padding(
        worker_group.world_size,
        *args,
        **kwargs,
    )
    return splitted_args, splitted_kwargs
```

For example, suppose bsz=1024, world_size=64 and sp_size=4. The actual dp_size is then 16, so each DP group should receive a shard with bsz = 1024/16 = 64; but the dispatcher splits directly into 64 shards, leaving each worker with bsz=16. preprocess_data therefore has to all_gather the shards within the sp group.
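The arithmetic, spelled out with the numbers from the example above:

```python
bsz, world_size, sp_size = 1024, 64, 4
dp_size = world_size // sp_size        # 16 actual data-parallel groups
per_dp_bsz = bsz // dp_size            # 64: what each DP group needs
naive_per_worker = bsz // world_size   # 16: what the even split hands each worker
# all-gathering across the sp group of size 4 restores 16 * 4 = 64 samples
assert naive_per_worker * sp_size == per_dp_bsz
```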

```python
def preprocess_data(self, data: DataProto) -> DataProto:
    """
    AllGather data from sp region
    This is because the data is first sharded along the FSDP dimension as we utilize the DP_COMPUTE
    In Ulysses, we need to make sure the same data is used across a SP group
    """
    if self.device_mesh is not None:
        group = self.device_mesh["sp"].get_group()

        all_gather_data_proto(data=data, process_group=group)
    return data
```

Correspondingly, once the computation finishes, postprocess_data splits the data evenly again:

```python
def postprocess_data(self, data: DataProto) -> DataProto:
    """
    Split the data to follow FSDP partition
    """
    if self.device_mesh is not None:
        sp_size = self.device_mesh["sp"].size()
        sp_rank = self.device_mesh["sp"].get_local_rank()
        data = data.chunk(chunks=sp_size)[sp_rank]
    return data
```

This is somewhat cumbersome: the device_mesh information is already available, so the extra round trip is unnecessary. A later PR (ref. 2) therefore removed the preprocess_data and postprocess_data logic, relying on make_nd_compute_dataproto_dispatch_fn instead.

```python
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
def update_actor(self, data: DataProto):
    with self.ulysses_sharding_manager:
        # perform training
        metrics = self.actor.update_policy(data=data)
        output = DataProto(meta_info={'metrics': metrics})
        output = output.to('cpu')
```

Concretely, when each worker in a worker group is initialized, it derives the corresponding dp mesh from the SP configuration and registers the mesh information via `_register_dispatch_collect_info`:

```python
# build device mesh for Ulysses Sequence Parallel
self.ulysses_device_mesh = None
self.ulysses_sequence_parallel_size = self.config.actor.get("ulysses_sequence_parallel_size", 1)
dp = world_size // self.ulysses_sequence_parallel_size
if self.ulysses_sequence_parallel_size > 1:
    self.ulysses_device_mesh = init_device_mesh(
        device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]
    )

# create training dispatch
if self.ulysses_device_mesh is not None:
    is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0
    self._register_dispatch_collect_info(
        "actor", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect
    )
else:
    self._register_dispatch_collect_info("actor", dp_rank=self.rank, is_collect=True)
```

The registration logic:

```python
def _register_dispatch_collect_info(self, mesh_name: str, dp_rank: int, is_collect: bool):
    """Register the dp_rank for a given mesh name. This function is meant to be called by the worker

    Args:
        mesh_name (str):
            Name of the mesh to register dp_rank for.
        dp_rank (int):
            dp_rank to register for the given mesh name.
        is_collect (bool):
            Whether the dp_rank is used for collect.
    """
    if mesh_name in self.__dispatch_dp_rank or mesh_name in self.__collect_dp_rank:
        raise ValueError(f"mesh_name {mesh_name} has been registered")
    self.__dispatch_dp_rank[mesh_name] = dp_rank
    self.__collect_dp_rank[mesh_name] = is_collect
```

make_nd_compute_dataproto_dispatch_fn then uses the registered information to compute the correct dp_size for dispatch:

```python
def make_nd_compute_dataproto_dispatch_fn(mesh_name):
    return {
        "dispatch_fn": partial(dispatch_lazy_compute_data_proto, mesh_name),
        "collect_fn": partial(collect_lazy_compute_data_proto, mesh_name),
    }

def dispatch_lazy_compute_data_proto(mesh_name, worker_group, *args, **kwargs):
    from verl.single_controller.base.worker_group import WorkerGroup

    assert isinstance(worker_group, WorkerGroup)

    # query dispatch info of the worker group
    if mesh_name not in worker_group._dispatch_info:
        worker_group._dispatch_info[mesh_name] = worker_group._query_dispatch_info(mesh_name)
        assert len(worker_group._dispatch_info[mesh_name]) == worker_group.world_size

    dp_rank_mapping = worker_group._dispatch_info[mesh_name]
    # perform dispatch
    dp_size = max(dp_rank_mapping) + 1
    return dispatch_nd_compute_dataproto(dp_rank_mapping, dp_size, worker_group, *args, **kwargs)
```

Collect works similarly, except that it only runs on sp rank 0 of each dp group, since the other sp ranks in the same group hold identical data.
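A sketch of the resulting layout for a hypothetical 8-GPU setup with sp_size=2, with ranks laid out row-major over the (dp, sp) mesh as in init_device_mesh above:

```python
world_size, sp_size = 8, 2
collect_ranks = []
for rank in range(world_size):
    dp_rank, sp_rank = rank // sp_size, rank % sp_size
    if sp_rank == 0:          # is_collect: only sp rank 0 reports results
        collect_ranks.append(rank)
# ranks 1, 3, 5, 7 hold data identical to their sp-group peer and are skipped
assert collect_ranks == [0, 2, 4, 6]
```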