An important feature of VeRL training is token balancing. This post summarizes its mechanism and rationale.

Balancing in VeRL

Balancing Across DP Ranks

In the single-controller control logic, i.e. the trainer.fit() function, once Generation completes, calls such as actor.compute_log_prob pass all of the batch data through a Dispatch method, which distributes it to the workers that actually run the computation.

For example, actor.compute_log_prob here is dispatched to each worker via the registered DP dispatch mode. To keep the DP ranks computationally balanced, we want the number of tokens computed on each rank to be as close as possible.

# dispatch batch to all dp actor workers
@register(Dispatch.DP_COMPUTE_PROTO)
def compute_log_prob(self, data: DataProto):
    # ...
    with self.ulysses_sharding_manager:
        output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)
    # ...

As shown below, by balancing the global batch on the single-controller driver before it is handed down to the DP workers, the token counts of the per-rank batches become far more even than before.

In the code this corresponds to the _balance_batch call in the fit function:

for prompts in dataloader:
    # Stage 1: Generation
    batch = actor.generate_sequences(prompts)
    
    # Token Balancing
    # Balance the number of valid tokens across DP ranks.
    # NOTE: This usually changes the order of data in the `batch`
    # which won't affect the advantage calculation (since it's based on uid),
    # but might affect the loss calculation (due to the change of mini-batching).
    # TODO: Decouple the DP balancing and mini-batching.
    if config.balance_batch:
        _balance_batch(batch)
    
    # Stage 2: Experience Preparation
    batch = reward.compute_reward(batch)
    batch = actor.compute_log_prob(batch)
    batch = reference.compute_log_prob(batch)
    batch = critic.compute_values(batch)
    batch = compute_advantage(batch, "gae")
    
    # Stage 3: Training
    critic.update_critic(batch)
    actor.update_actor(batch)

The implementation is shown below. It mainly uses get_seqlen_balanced_partitions to split the global batch into world_size partitions (the desired number of partitions) and then reorders the original batch accordingly, so that at dispatch time each dp rank receives a balanced share of the data.

    def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"):
        """Reorder the data on single controller such that each dp rank gets similar total tokens"""
        attention_mask = batch.batch["attention_mask"]
        batch_size = attention_mask.shape[0]
        global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)
        world_size = self.actor_rollout_wg.world_size
        global_partition_lst = get_seqlen_balanced_partitions(
            global_seqlen_lst, k_partitions=world_size, equal_size=True
        )
        # reorder based on index. The data will be automatically equally partitioned by dispatch function
        global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
        batch.reorder(global_idx)
        global_balance_stats = log_seqlen_unbalance(
            seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix
        )
        metrics.update(global_balance_stats)
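
To see why reordering alone is enough: the dispatch function later chunks the reordered batch evenly across ranks, so each contiguous chunk is exactly one balanced partition. A minimal sketch of that effect (reorder_then_chunk is a hypothetical helper for illustration, not VeRL code):

```python
def reorder_then_chunk(seqlens, partitions, world_size):
    """Sketch of the driver-side balancing: flatten the partition index
    lists into one global order, then take one equal-size contiguous
    chunk per rank -- each chunk is exactly one balanced partition."""
    global_idx = [j for p in partitions for j in p]
    per_rank = len(seqlens) // world_size
    chunks = [global_idx[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]
    # per-rank token totals after the reorder
    return [sum(seqlens[j] for j in c) for c in chunks]

# toy batch: 4 sequences, 2 ranks; equal-size partitions assumed precomputed
print(reorder_then_chunk([10, 40, 30, 20], [[1, 0], [2, 3]], 2))  # -> [50, 50]
```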

How get_seqlen_balanced_partitions finds a good partitioning is via the classic Karmarkar-Karp algorithm: in essence, it divides a set of numbers (loads/tasks) among k different workers so that the difference between them is minimized.
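
The exact differencing method lives in VeRL's karmarkar_karp. As a rough illustration of the balanced-partition idea, here is a simplified longest-first greedy sketch (greedy_partition is a hypothetical stand-in, not VeRL's actual algorithm), which happens to recover the balanced split for the example batch used later in this post:

```python
import heapq

def greedy_partition(seqlen_list, k_partitions):
    """Simplified greedy (longest-processing-time) sketch of balanced
    partitioning: assign each item, longest first, to the currently
    lightest partition. VeRL's karmarkar_karp uses the more refined
    differencing method instead."""
    # min-heap of (partition_total_tokens, partition_index)
    heap = [(0, i) for i in range(k_partitions)]
    heapq.heapify(heap)
    partitions = [[] for _ in range(k_partitions)]
    order = sorted(range(len(seqlen_list)), key=lambda i: -seqlen_list[i])
    for idx in order:
        total, p = heapq.heappop(heap)
        partitions[p].append(idx)
        heapq.heappush(heap, (total + seqlen_list[idx], p))
    # return sorted index lists, like _check_and_sort_partitions
    return [sorted(p) for p in partitions]

print(greedy_partition([100, 900, 50, 950, 400, 600], 2))  # -> [[0, 2, 3, 4], [1, 5]]
```

Both resulting groups sum to 1500 tokens, the perfectly balanced split.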

def get_seqlen_balanced_partitions(seqlen_list: list[int], k_partitions: int, equal_size: bool):
    assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"

    partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size)
    return _check_and_sort_partitions(partitions)

The function get_seqlen_balanced_partitions returns a List[List[int]] representing the k partitions; each inner list holds the indices, in the original list, of the items assigned to that partition.

List[List[int]]: A list containing k_partitions lists. Each inner list contains the original indices of the items assigned to that partition. The indices within each partition list are sorted.

Balancing Across Micro Batches

In stage 2, when computing old_log_prob, and in stage 3, during update_actor, the batch on each dp worker is split into micro_batches to avoid GPU OOM.

The split is driven by max_token_len: the batch is cut into multiple micro_batches under this token budget, i.e. dynamic batching, so the number of samples per micro_batch is not fixed.
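
To make the token budget concrete, here is a naive first-fit split (split_by_token_budget is a hypothetical illustration, not VeRL's implementation). Note how micro-batch sizes vary, and how an in-order split can still be badly unbalanced, which motivates the seqlen-balanced partitioning inside rearrange_micro_batches:

```python
def split_by_token_budget(seqlens, max_token_len):
    """Naive first-fit sketch: pack sequences in order until the token
    budget would be exceeded; micro-batch sample counts therefore vary."""
    batches, cur, cur_tokens = [], [], 0
    for i, n in enumerate(seqlens):
        if cur and cur_tokens + n > max_token_len:
            batches.append(cur)
            cur, cur_tokens = [], 0
        cur.append(i)
        cur_tokens += n
    if cur:
        batches.append(cur)
    return batches

# one micro-batch ends up with 2000 tokens, the other with only 1000
print(split_by_token_budget([100, 900, 50, 950, 400, 600], 2000))  # -> [[0, 1, 2, 3], [4, 5]]
```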

def compute_log_prob(self, data: DataProto):
    select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
    non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
    data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)


    if use_dynamic_bsz:
        max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size
        micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len)
    else:
        micro_batches = data.split(micro_batch_size)

    log_probs_lst = []
    entropy_lst = []

    for micro_batch in micro_batches:
        micro_batch = micro_batch.to(get_device_id())
        model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}

        with torch.no_grad():
            entropy, log_probs = self._forward_micro_batch(model_inputs, temperature=temperature, calculate_entropy=calculate_entropy)
        log_probs_lst.append(log_probs)
        if calculate_entropy:
            entropy_lst.append(entropy)

    log_probs = torch.concat(log_probs_lst, dim=0)
    entropys = None

    if calculate_entropy:
        entropys = torch.concat(entropy_lst, dim=0)

    if use_dynamic_bsz:
        log_probs = restore_dynamic_batch(log_probs, batch_idx_list)
        if calculate_entropy:
            entropys = restore_dynamic_batch(entropys, batch_idx_list)

    return log_probs, entropys

Each micro_batch computation requires synchronization across all dp workers. So in addition to keeping the batch balanced across DP ranks, tokens also need to be balanced across micro batches.

This is what prepare_dynamic_batch implements once use_dynamic_bsz is enabled; its rearrange_micro_batches also calls get_seqlen_balanced_partitions.

def prepare_dynamic_batch(
    data: DataProto,
    max_token_len: int,
    dp_group=None,
    num_batches_divided_by=None,
    same_micro_num_in_dp=True,
    min_num_micro_batch=None,
    use_dynamic_bsz_balance=True,
) -> tuple[list[DataProto], list[list[int]]]:
    batch, batch_idx_list = rearrange_micro_batches(
        data.batch,
        max_token_len=max_token_len,
        dp_group=dp_group,
        num_batches_divided_by=num_batches_divided_by,
        same_micro_num_in_dp=same_micro_num_in_dp,
        min_num_micro_batch=min_num_micro_batch,
        use_dynamic_bsz_balance=use_dynamic_bsz_balance,
    )
    micro_batches = []
    for i, batch_idx in enumerate(batch_idx_list):
        tensors = dict(batch[i])
        non_tensors = {key: value[batch_idx] for key, value in data.non_tensor_batch.items()}
        meta_info = copy.deepcopy(data.meta_info)
        micro_batches.append(DataProto.from_dict(tensors, non_tensors, meta_info=meta_info))

    return micro_batches, batch_idx_list

Here is a simplified example. Suppose we have a batch containing 6 sequences; for simplicity we only care about their effective lengths.

  • Input batch: 6 sequences.
  • Effective sequence lengths (seq_len_effective): [100, 900, 50, 950, 400, 600]
    • The lengths vary widely: the shortest is only 50, the longest 950.
  • Function argument max_token_len: 2000
    • That is, each micro-batch produced by the split should contain at most 2000 tokens in total.
def rearrange_micro_batches(
    batch,
    max_token_len,
    dp_group=None,
    num_batches_divided_by=None,
    same_micro_num_in_dp=True,
    min_num_micro_batch=None,
    use_dynamic_bsz_balance=True,
):
    """
    Split a batch into micro-batches by total token count, with optional DP sync and padding.

    Args:
        batch (TensorDict): must include "attention_mask" (B*S); other fields are sliced similarly.
        max_token_len (int): max sum of attention_mask per micro-batch.
        dp_group (optional): torch.distributed group for data-parallel sync.
        num_batches_divided_by (optional): virtual pipeline parallel size, for megatron.
        same_micro_num_in_dp (bool): if True and dp_group set, pad all ranks to the same count.
        min_num_micro_batch (int, optional): force at least this many splits (pads empty ones).
        use_dynamic_bsz_balance (bool, optional): balance the computational workload between micro-batches

    Returns:
        List[TensorDict]: the micro-batches.
        List[List[int]]: index lists mapping each micro-batch back to original positions.
    """
    # this is per local micro_bsz
    max_seq_len = batch["attention_mask"].shape[-1]
    assert max_token_len >= max_seq_len, (
        f"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}"
    )
    seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1)
    total_seqlen = seq_len_effective.sum().item()
    # NOTE: num_microbatches <= batch_size, so take the min of this two.
    num_micro_batches = min(len(seq_len_effective), ceildiv(total_seqlen, max_token_len))
    if min_num_micro_batch is not None:
        # used to support pp
        num_micro_batches = max(min_num_micro_batch, num_micro_batches)
    if dist.is_initialized() and same_micro_num_in_dp:
        num_micro_batches = torch.tensor([num_micro_batches], device=get_device_name())
        dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group)
        num_micro_batches = num_micro_batches.cpu().item()
    if num_batches_divided_by is not None:
        num_micro_batches = roundup_divisible(num_micro_batches, num_batches_divided_by)

    seq_len_effective = seq_len_effective.tolist()
    assert num_micro_batches <= len(seq_len_effective)

    micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False)

    if use_dynamic_bsz_balance:
        # Use the sum of squared sequence lengths to approximate attention computation workload
        micro_bsz_idx.sort(
            key=lambda partition: (
                sum(seq_len_effective[idx] ** 2 for idx in partition),
                min(partition) if partition else 0,
            ),
            reverse=True,
        )

    micro_batches = []

    for partition in micro_bsz_idx:
        curr_micro_batch = []
        for idx in partition:
            curr_micro_batch.append(batch[idx : idx + 1])
        curr_micro_batch = torch.cat(curr_micro_batch)

        micro_batches.append(curr_micro_batch)

    return micro_batches, micro_bsz_idx
  1. Compute num_micro_batches
    • Total tokens: 100 + 900 + 50 + 950 + 400 + 600 = 3000
    • Required num_micro_batches: ceil(3000 / 2000) = 2
    • So to satisfy the max_token_len constraint, the 6 sequences must be split into at least 2 micro-batches.
  2. First partitioning (get_seqlen_balanced_partitions)
    • Goal: divide the 6 sequences into 2 groups so that the total token counts of the groups are as close as possible. Ideally each group sums to 3000 / 2 = 1500.
    • Naive split (the wrong way): splitting in order would give:
      • Group 1: [100, 900, 50] (original indices 0, 1, 2) -> total 1050
      • Group 2: [950, 400, 600] (original indices 3, 4, 5) -> total 1950
      • Clearly very unbalanced.
    • Balanced partitioning: the function finds a better combination via the algorithm. For this example a perfect partition exists:
      • Partition A: the sequences of lengths [950, 400, 100, 50].
      • Partition B: the sequences of lengths [900, 600].
    • Checking the token totals:
      • Partition A total: 950 + 400 + 100 + 50 = 1500
      • Partition B total: 900 + 600 = 1500
    • Perfect: the two groups have exactly equal token counts. The resulting index partition micro_bsz_idx is [[3, 4, 0, 2], [1, 5]] (the internal order may differ).
  3. Second balancing (use_dynamic_bsz_balance)
    • Now the function accounts for Attention's $O(N^2)$ complexity: the computational load is approximated as proportional to the sum of squared lengths.
    • It computes this "computational load" for the two partitions from the previous step:
      • Partition A load: 950² + 400² + 100² + 50² = 902500 + 160000 + 10000 + 2500 = 1,075,000
      • Partition B load: 900² + 600² = 810000 + 360000 = 1,170,000
    • Although the two partitions have the same token totals, Partition B's computational load is clearly higher.
    • The function sorts the partitions by this load in descending order, so the processing order becomes: Partition B first, then Partition A.
    • The final sorted index list micro_bsz_idx becomes: [[1, 5], [3, 4, 0, 2]]
  4. Producing the final micro-batches
    • First micro-batch: the function takes the sequences at indices 1 and 5 from the original batch and packs them into a new, smaller TensorDict. This micro-batch contains 2 sequences, total length 1500, computational load 1,170,000.
    • Second micro-batch: the sequences at indices 3, 4, 0, 2 are packed into a second TensorDict. This micro-batch contains 4 sequences, total length 1500, computational load 1,075,000.
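
Step 3 above can be reproduced in a few lines; this sketch mirrors the sort key that use_dynamic_bsz_balance applies in rearrange_micro_batches (sum of squared lengths, heaviest first), with sort_by_attention_workload as an illustrative wrapper name:

```python
def sort_by_attention_workload(partitions, seqlens):
    """Order micro-batch partitions by approximate attention cost
    (sum of squared sequence lengths), heaviest first, breaking ties
    by the smallest original index in the partition."""
    return sorted(
        partitions,
        key=lambda p: (sum(seqlens[i] ** 2 for i in p), min(p) if p else 0),
        reverse=True,
    )

seqlens = [100, 900, 50, 950, 400, 600]
print(sort_by_attention_workload([[3, 4, 0, 2], [1, 5]], seqlens))  # -> [[1, 5], [3, 4, 0, 2]]
```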

Micro Batch Balancing in PPO update_actor

During PPO update_actor training, the micro batches are partitioned globally in advance, instead of first splitting across DP ranks and then splitting again within each dp rank.
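
A hedged sketch of this idea, assuming the global batch is balanced directly into dp_size * micro_per_rank partitions and then dealt out to ranks (global_micro_partition and its greedy partitioner are illustrative only, not VeRL's actual code):

```python
import heapq

def global_micro_partition(seqlens, dp_size, micro_per_rank):
    """Balance the global batch straight into dp_size * micro_per_rank
    micro-batches, then hand rank r the partitions r, r + dp_size, ...
    instead of splitting across DP first and re-splitting per rank."""
    k = dp_size * micro_per_rank
    # trivial longest-first greedy as a stand-in for the balanced partitioner
    heap = [(0, i) for i in range(k)]
    heapq.heapify(heap)
    parts = [[] for _ in range(k)]
    for idx in sorted(range(len(seqlens)), key=lambda i: -seqlens[i]):
        total, p = heapq.heappop(heap)
        parts[p].append(idx)
        heapq.heappush(heap, (total + seqlens[idx], p))
    # rank r takes every dp_size-th partition
    return [parts[r::dp_size] for r in range(dp_size)]

print(global_micro_partition([400, 300, 200, 100], dp_size=2, micro_per_rank=2))
```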
