RL 工程优化: Token Balancing
VeRL 训练中,有一个重要的 feature 是 token balancing ,这里主要总结下其机制和原理。
Balancing in VeRL
Balancing Across DP Ranks
在 single controller 的控制逻辑中,也就是 trainer.fit() 函数中,在完成 Generation 之后,会依次通过这里的 actor.compute_log_prob 将所有 batch 的数据通过 Dispatch 方法分发给实际训练的 workers。
比如这里的 actor.compute_log_prob 会通过注册的 DP 分发给各自的 worker。为了保证各个 dp rank 之间计算均衡,期望每个 rank 上要计算的 token 数尽可能接近。
|
|
如下所示,通过在 single controller 的 driver 对全局 batch 做均衡,再下发给各个 DP worker 的时候,每个 rank 的 batch 的 token 数会比原来均衡得多。
这对应到代码中 fit 函数的 balance)_batch
|
|
具体的实现如下所示,主要是通过 get_seqlen_balanced_partitions 将全局 batch 按照预期要划分的 partition 份数 world_size,对原来的 batch reorder,这样在 dispatch 的时候就可以实现各个 dp rank 的数据均衡。
|
|
这里 get_seqlen_balanced_partitions 如何找到智能的分区是通过经典的 Karmarkar-Karp 算法1,本质上就是将一组数 (负载/任务)划分到 k 各不同 worker,让他们之间差距最小。
|
|
函数 get_seqlen_balanced_partitions 返回的是 List[List[int]],代表的是 k 个 partitions,每个 list 里面是本 partition 的每个 item 在原来 list 的 indices。
|
|
Balancing Across Micro Batches
在 stage 2 计算 old_log_prob 或者 stage 3 update_actor 的时候,都会将每个 dp worker 的 batch 拆分成 micro_batch,防止 GPU OOM。
这里切分方式是根据 max_token_len 拆成多个 micro_batch,就是 dynamic batching,每个 micro_batch 的样本数不确定。
|
|
在每个 micro_batch 计算中,需要在所有 dp worker 之间同步。因此,除了需要在各个 DP 之间保持 batch 的均衡,也还需要在 micro batches 之间保持 tokens 均衡。
这就对应着开启 use_dynamic_batch 之后,通过 prepare_dynamic_batch 实现均衡。其中的 rearrange_micro_batches 也调用了 get_seqlen_balanced_partitions 。
|
|
给个简化的例子,假设我们有一个 batch,里面包含 6 个 sequence。为了方便,我们只关心它们的真实长度。
- 输入
batch: 包含 6 个 sequence。 - 序列的真实长度 (
seq_len_effective):[100, 900, 50, 950, 400, 600]- 可以看到,序列长度差异很大,最短的只有 50,最长的有 950。
- 函数参数
max_token_len: 2000- 这意味着我们希望每个拆分出的微批次,其包含的 token 总数不要超过 2000。
|
|
- 计算
num_micro_batches- 总 token 数:
100 + 900 + 50 + 950 + 400 + 600 = 3000 - 所需 num_micro_batches:
ceil(3000 / 2000) = 2 - 函数计算出,为了满足
max_token_len的限制,最少需要将这 6 个序列拆分成 2 个微批次。
- 总 token 数:
- 第一次分区 (
get_seqlen_balanced_partitions)- 目标: 将 6 个序列分成 2 组,使得每组的 token 数量之和 尽可能接近。理想情况下,每组的 token 总和应该是
3000 / 2 = 1500。 - 简单的切分(错误方式): 如果我们按顺序切分,会得到:
- 组1:
[100, 900, 50](来自原始索引 0, 1, 2) -> 总长 1050 - 组2:
[950, 400, 600](来自原始索引 3, 4, 5) -> 总长 1950 - 这显然非常不均衡。
- 组1:
- 智能分区的做法: 函数会通过算法找到一个更优的组合。对于这个例子,一个完美的分区是:
- 分区 A: 包含长度为
[950, 400, 100, 50]的序列。 - 分区 B: 包含长度为
[900, 600]的序列。
- 分区 A: 包含长度为
- 我们检查一下 token 总数:
- 分区 A 总长:
950 + 400 + 100 + 50 = 1500 - 分区 B 总长:
900 + 600 = 1500
- 分区 A 总长:
- 完美!两组的 token 总数完全相等。此时,函数得到的原始索引分区
micro_bsz_idx是[[3, 4, 0, 2], [1, 5]](顺序可能不同)。
- 目标: 将 6 个序列分成 2 组,使得每组的 token 数量之和 尽可能接近。理想情况下,每组的 token 总和应该是
- 第二次均衡 (
use_dynamic_bsz_balance)2- 现在,函数考虑 Attention 的 $O(N^2)$ 复杂度,即计算负载与长度的平方和成正比。
- 它会为上一步得到的两个分区计算“计算负载”:
- 分区 A 的负载:
950² + 400² + 100² + 50² = 902500 + 160000 + 10000 + 2500 = 1,075,000 - 分区 B 的负载:
900² + 600² = 810000 + 360000 = 1,170,000
- 分区 A 的负载:
- 我们发现,虽然两个分区的 token 总数相同,但分区 B 的计算负载明显更高!
- 函数会根据这个计算负载对分区进行降序排序。所以,处理顺序被调整为:先处理分区 B,再处理分区 A。
- 最终排序后的索引列表
micro_bsz_idx变为:[[1, 5], [3, 4, 0, 2]]
- 生成最终的微批次
- 第一个微批次: 函数会从原始
batch中取出索引为1和5的序列,将它们打包成一个新的、更小的TensorDict。这个微批次包含2个序列,总长度1500,计算负载1,170,000。 - 第二个微批次: 函数会从原始
batch中取出索引为3,4,0,2的序列,打包成第二个TensorDict。这个微批次包含4个序列,总长度1500,计算负载1,075,000。
- 第一个微批次: 函数会从原始
PPO update_actor 中 micro batch balancing
在 PPO update_actor 训练的时候,提前全局划分好 micro batch,而不是先在 DP 间划分,之后再在每个 dp rank 内划分。
OminiPatch
-
No backlinks found.