Training Inference Mismatch
训推不一致问题,在 RL 训练上实际上是不可忽视的,主要由于训练侧(FSDP、Megatron 等)和推理侧(vllm、sglang、trt 等)这类 kernel 差异、计算实现路径差异、硬件侧针对两边不同优化这类问题导致的系统性偏差。而且这类偏差,在数学上可能会导致:
Bias: 训练优化器会主动走向一个错误的结果。
Variance:优化器会完全停滞,让训练中止。
在后面算法的章节,在理论上也会对不一致问题对 RL 影响进行数学上的分析。
有研究1 指出 train / inference engine 之间的不一致会隐性导致 on-policy 假设的 RL 实际变成 off-policy。所以当我们追求”真正的” on-policy RL 训练时,需要知道:
如果不能从两个完全一致的推理请求中获取 bitwise 相等的结果,那么当然也无法保障训推之间的 bitwise 一致性。所以基于之前我们对确定性推理实现讨论,直觉上可以知道如果保证了确定性推理,那么通过修改训练这部分 stack,也能够实现在 bitwise 上训推的一致性,从而实现真正的 on-policy RL 训练。
而业界对这个问题的解决思路上主要分为两种:
- 在训练引擎侧,基于推理引擎(vllm/sglang)确定性推理内核前向实现,进行反向传递的实现,通过对齐 kernel 的实现,做到训练和采样部分的 bitwise 一致性(i.e. 0 KL divergence)。
- 拥抱训推分布的不一致(考虑到训练 bitwise 实现在工程上的工作量,和不同模型适配的工作量),在算法上为 off-policy 做 off-policy correction,进行训推 KL 散度的偏差抑制,在大多数场景也能实现 RL 训练的平滑和目标效果。
后续会分别着重分析这两种解决思路。
数学基础
三策略 TRPO 视角下的最小统一理解
上面列的这些工作,看上去各自解决的是:
- 算法层:PPO / GRPO 的目标怎么写,token-level 还是 sequence-level,用 clip 还是 mask;
- 系统层:推理框架和训练框架怎样对齐;
- 模型层:MoE 模型路由问题如何放大训练不稳定,等等。
但如果我们把“行为策略 vs 参考策略”这条线拉直,会发现相当一部分问题,其实都可以放到一个相对简单的理论框架里理解:三策略 TRPO。
下面这节我会用尽量简单的数学,把这个三策略版 TRPO 摊开——它可以被看作是“TRPO + 三角不等式”的一个小扩展,但在分析大模型 RL 里的训推不一致时非常好用:
- 一方面让我们重新理解“训推不一致”和“异步训练框架”到底在影响什么;
- 另一方面,也帮我们统一理解 TIS、IcePop、sequence-level MIS 等,在本文的视角下,它们其实都是在实施下文的“约束 2”。
三个策略
沿用前文的记号,我们在一个折扣 MDP 上工作,折扣因子为 $\gamma \in (0,1)$:
- 状态 $s \in \mathcal{S}$,动作 $a \in \mathcal{A}$;
- 策略 $\pi(a|s)$;
- 折扣状态分布:$\boldsymbol{d}\pi(s) := (1-\gamma)\sum{t=0}^\infty \gamma^t \Pr_\pi(s_t = s)$。
- 回报(episode 视角):$\mathcal{J}(\pi) := \mathbb{E}\pi\left[\sum{t=0}^\infty \gamma^t r_t\right]$。
- 值函数 / 优势函数:$V_\pi(s)$,$Q_\pi(s,a)$,$\mathcal{A}\pi(s,a) := Q\pi(s,a) - V_\pi(s)$。
稍微赘述一下,在“三策略”设定里,我们有:
- 行为策略(behavior policy):$\boldsymbol{\mu}$,真正用来 rollout 的策略;数据 $(s,a,r,\dots)$ 都是从它来的。
- 参考策略(reference policy):$\pi_{\boldsymbol{\theta}_{\text{old}}}$,优化目标里拿来做 ratio、clip 或 KL 约束的那一份“旧策略”。
- 目标策略(target policy):$\pi_{\boldsymbol{\theta}}$,我们这一步想要优化的策略。
在理想设定里我们默认 $\boldsymbol{\mu} = \pi_{\boldsymbol{\theta}_{\text{old}}}$;现实系统里二者往往不等,这就是“训推不一致”的数学影子。
两策略 TRPO
TRPO 的所有理论保证,都是建立在某个“基准策略”的优势函数之上的。既然实际能算清楚的只有 $\mathcal{A}{\boldsymbol{\mu}}$(数据是按 $\boldsymbol{\mu}$ 采的),那我们就直接把 $\boldsymbol{\mu}$ 当成基准。一个经典的结论是 性能差分引理(Performance Difference Lemma): 对任意两策略 $\boldsymbol{\mu}$ 和 $\pi{\boldsymbol{\theta}}$,有 $$ \mathcal{J}(\pi_{\boldsymbol{\theta}}) - \mathcal{J}(\boldsymbol{\mu}) = \frac{1}{1-\gamma} \mathbb{E}{s \sim d{\pi_{\boldsymbol{\theta}}}, a \sim \pi_{\boldsymbol{\theta}}} \left[ \mathcal{A}_{\boldsymbol{\mu}}(s,a) \right]。 $$ 直觉非常简单:
- $\mathcal{A}_{\boldsymbol{\mu}}(s,a)$ 就是在说“如果在 $s$ 里本来按 $\boldsymbol{\mu}$ 行动,现在换成动作 $a$,长期回报会多或少多少”;
- 把所有时刻、所有状态、所有动作的“增益”累积起来,就得到新策略比行为策略总共赚了多少。
TRPO 的问题在于,我们没法准确算 $$ \mathbb{E}{s \sim d{\pi_{\boldsymbol{\theta}}}, a \sim \pi_{\boldsymbol{\theta}}} \left[ \mathcal{A}{\boldsymbol{\mu}}(s,a) \right], $$ 因为 $d{\pi_{\boldsymbol{\theta}}}$ 是“新策略”的状态分布,我们没有在它下面采样过。
于是 TRPO 引入了一个替代目标:把状态分布换成行为策略的: $$ \mathcal{L}{\boldsymbol{\mu}}(\pi{\boldsymbol{\theta}}) := \mathcal{J}(\boldsymbol{\mu}) + \frac{1}{1-\gamma} \mathbb{E}{s \sim d{\boldsymbol{\mu}}, a \sim \pi_{\boldsymbol{\theta}}} \left[ \mathcal{A}{\boldsymbol{\mu}}(s,a) \right]。 $$ $\mathcal{L}{\boldsymbol{\mu}}$ 的直觉解释是:在行为策略的状态分布下,让新策略试着去选动作,看优势有多大。 从性能差分引理出发,两者之差是: $$ \mathcal{J}(\pi_{\boldsymbol{\theta}}) - \mathcal{L}{\boldsymbol{\mu}}(\pi{\boldsymbol{\theta}}) = \frac{1}{1-\gamma} \sum_{s} \left( d_{\pi_{\boldsymbol{\theta}}}(s) - d_{\boldsymbol{\mu}}(s) \right) \mathbb{E}{a \sim \pi{\boldsymbol{\theta}}(\cdot|s)} \left[ \mathcal{A}{\boldsymbol{\mu}}(s,a) \right]。 $$ 如果我们定义 $$ \epsilon{\boldsymbol{\mu}} := \max_{s,a} |\mathcal{A}_{\boldsymbol{\mu}}(s,a)|, $$ 那么有一个直接的上界:
Lemma 1 $$ \left| \mathcal{J}(\pi_{\boldsymbol{\theta}}) - \mathcal{L}{\boldsymbol{\mu}}(\pi{\boldsymbol{\theta}}) \right| \leq \frac{\epsilon_{\boldsymbol{\mu}}}{1-\gamma} \left| d_{\pi_{\boldsymbol{\theta}}} - d_{\boldsymbol{\mu}} \right|_1。 $$
这里出现了第一个关键量:
状态分布偏移 $|d_{\pi_{\boldsymbol{\theta}}} - d_{\boldsymbol{\mu}}|_1$,也就是“新策略和行为策略看到的世界,到底差了多少”。
我们通常不会直接对 $|d_{\pi_{\boldsymbol{\theta}}} - d_{\boldsymbol{\mu}}|1$ 施加约束,反而是对“每一步 action 分布”的差异施加约束,比如 trust region、KL、clip 等。记总变差距离(total variation): $$ D{\text{TV}}(p,q) := \frac{1}{2} | p - q |_1。 $$ 假设存在常数 $\beta$,使得
对所有 $s$,行为策略和目标策略之间的 TV 被 $\beta$ 上界: $$ D_{\text{TV}}(\boldsymbol{\mu}(\cdot|s), \pi_{\boldsymbol{\theta}}(\cdot|s)) \leq \beta。 $$
直观含义:在任意状态里,“新策略”和“生成数据的策略”选动作的分布都不会离太远。一个经典结果(可以用 coupling 证明)是:
Lemma 2:在上述条件下有 $$ | d_{\pi_{\boldsymbol{\theta}}} - d_{\boldsymbol{\mu}} |_1 \leq \frac{2\gamma}{1-\gamma} \beta。 $$
把它和 Lemma 1 结合: $$ \left| \mathcal{J}(\pi_{\boldsymbol{\theta}}) - \mathcal{L}{\boldsymbol{\mu}}(\pi{\boldsymbol{\theta}}) \right| \leq \frac{\epsilon_{\boldsymbol{\mu}}}{1-\gamma} \cdot \frac{2\gamma}{1-\gamma} \beta = \frac{2\epsilon_{\boldsymbol{\mu}} \gamma}{(1-\gamma)^2} \beta。 $$ 于是我们得到一个形式上相当简洁的两策略 TRPO 下界(基准为行为策略):
Theorem 1(两策略 TRPO) $$ \mathcal{J}(\pi_{\boldsymbol{\theta}}) \geq \mathcal{L}{\boldsymbol{\mu}}(\pi{\boldsymbol{\theta}}) - \frac{2\epsilon_{\boldsymbol{\mu}} \gamma}{(1-\gamma)^2} \beta。 $$
这说明:
- 真正决定“替代目标 $\mathcal{L}{\boldsymbol{\mu}}$ 靠不靠谱”的,是行为策略 $\boldsymbol{\mu}$ 和目标策略 $\pi{\boldsymbol{\theta}}$ 的差异:$\beta = \max_s D_{\text{TV}}(\boldsymbol{\mu}(\cdot|s), \pi_{\boldsymbol{\theta}}(\cdot|s))$。
如果你能直接约束住这个 $\beta$,就能直接把 TRPO 的单调性保证搬到行为策略视角下。
三策略 TRPO
现实问题在于:大模型强化学习训练里我们可能无法直接控制 $\beta$ 本身。
在大部分 PPO / GRPO / GSPO / 现有 RLHF 框架里,实际发生的是:
- rollout 数据是由某个行为策略 $\boldsymbol{\mu}$ 产生的(推理引擎里的“那一版参数”+ 若干系统细节);
- 更新时,我们希望利用参考策略 $\pi_{\boldsymbol{\theta}{\text{old}}}$ 来限制目标策略 $\pi{\boldsymbol{\theta}}$ 的更新幅度。
也就是说,实际可以“动手”的是两个量:
- 参考 vs 目标:我们可以通过 KL / clip 等手段控制 $D_{\text{TV}}(\pi_{\boldsymbol{\theta}{\text{old}}}(\cdot|s), \pi{\boldsymbol{\theta}}(\cdot|s))$;
- 行为 vs 参考:我们希望间接控制 $D_{\text{TV}}(\boldsymbol{\mu}(\cdot|s), \pi_{\boldsymbol{\theta}_{\text{old}}}(\cdot|s))$。
于是自然就定义两个“proxy 差异”:
- 约束 1:参考 vs 目标 $\alpha_0 := \max_s D_{\text{TV}}(\pi_{\boldsymbol{\theta}{\text{old}}}(\cdot|s), \pi{\boldsymbol{\theta}}(\cdot|s))$;
- 约束 2:行为 vs 参考 $\alpha_1 := \max_s D_{\text{TV}}(\boldsymbol{\mu}(\cdot|s), \pi_{\boldsymbol{\theta}_{\text{old}}}(\cdot|s))$。
直觉上:
- $\alpha_0$:新策略到底离“你宣称的那份旧策略”有多远——这就是 trust region 控制的那部分;
- $\alpha_1$:你用来训练的参考策略,到底跟真实采样时的行为策略差了多少——这就是训推不一致或异步的影子。
现在,可以把这两个量塞回 TRPO 的下界里。
对任意状态 $s$,有
$$
D_{\text{TV}}(\boldsymbol{\mu}(\cdot|s), \pi_{\boldsymbol{\theta}}(\cdot|s)) \leq D_{\text{TV}}(\boldsymbol{\mu}(\cdot|s), \pi_{\boldsymbol{\theta}{\text{old}}}(\cdot|s)) + D{\text{TV}}(\pi_{\boldsymbol{\theta}{\text{old}}}(\cdot|s), \pi{\boldsymbol{\theta}}(\cdot|s))。
$$
对 $s$ 取上确界:$$
\beta := \max_s D_{\text{TV}}(\boldsymbol{\mu}(\cdot|s), \pi_{\boldsymbol{\theta}}(\cdot|s)) \leq \alpha_1 + \alpha_0。
$$ 把这个不等式塞回两策略 TRPO 的结论(Theorem 1)里,记 $$
C := \frac{2\epsilon_{\boldsymbol{\mu}} \gamma}{(1-\gamma)^2},
$$
即得到: $$
\mathcal{J}(\pi_{\boldsymbol{\theta}}) \geq \mathcal{L}{\boldsymbol{\mu}}(\pi{\boldsymbol{\theta}}) - C \beta \geq \mathcal{L}{\boldsymbol{\mu}}(\pi{\boldsymbol{\theta}}) - C (\alpha_0 + \alpha_1)。
$$
于是,我们得到一个非常直接的三策略 TRPO 下界:
Theorem 2(三策略 TRPO)
记 $$
\epsilon_{\boldsymbol{\mu}} := \max_{s,a} |\mathcal{A}{\boldsymbol{\mu}}(s,a)|, \quad C := \frac{2\epsilon{\boldsymbol{\mu}} \gamma}{(1-\gamma)^2},
$$
以及 $$
\alpha_0 := \max_s D_{\text{TV}}(\pi_{\boldsymbol{\theta}{\text{old}}}(\cdot|s), \pi{\boldsymbol{\theta}}(\cdot|s))), \quad \alpha_1 := \max_s D_{\text{TV}}(\boldsymbol{\mu}(\cdot|s), \pi_{\boldsymbol{\theta}_{\text{old}}}(\cdot|s))。
$$ 则对任意目标策略 $\pi_{\boldsymbol{\theta}}$ 有 $$
\mathcal{J}(\pi_{\boldsymbol{\theta}}) \geq \mathcal{L}{\boldsymbol{\mu}}(\pi{\boldsymbol{\theta}}) - C (\alpha_0 + \alpha_1)
$$
其中
$$
\mathcal{L}{\boldsymbol{\mu}}(\pi{\boldsymbol{\theta}}) := \mathcal{J}(\boldsymbol{\mu}) + \frac{1}{1-\gamma} \mathbb{E}{s \sim d{\boldsymbol{\mu}}, a \sim \pi_{\boldsymbol{\theta}}} \left[ \mathcal{A}_{\boldsymbol{\mu}}(s,a) \right]。
$$ 这个结论的含义其实很直接:这个结论的含义其实很直接:
- 替代目标 $\mathcal{L}{\boldsymbol{\mu}}(\pi{\boldsymbol{\theta}})$ 与真实性能 $\mathcal{J}(\pi_{\boldsymbol{\theta}})$ 之间的 gap,可以拆成两部分:
- 参考 vs 目标的偏移 $\alpha_0$;
- 行为 vs 参考的偏移 $\alpha_1$。
只要这两个量都小,优化 $\mathcal{L}_{\boldsymbol{\mu}}$ 就有希望有效提升 $\mathcal{J}$。
这两个差异各自怎么约束?
现在,我们可以从 Theorem 2 回头看各种实际方法:
- 绝大多数 “PPO / GRPO / GSPO” 类工作,其实是在控制 约束 1:$\alpha_0$;
- 绝大多数 “TIS / IcePop / MIS” 类工作,在本文的统一视角下,可以理解为主要是在控制 约束 2:$\alpha_1$。
本文下面只讨论 约束 2。
约束 2 的目标是:保证用来训练的数据,尽可能来自“接近参考策略”的行为策略。
这里通常既有系统层的机制,也有**算法层(importance sampling)**的机制。
-
系统层:让行为策略别飘太远
- 异步框架:给每个样本打上策略版本号,只能用与 $\pi_{\theta_{old}}$ 相差不大的参数版本采样的数据;
- 训推对齐:强调训练框架和推理框架用相同精度、相同算子、相近的内核 / kernel 行为。 这些机制的目标是:从“算法外部”让 $\mu$ 和 $\pi_{\theta_{old}}$ 靠近,从而压缩 $\alpha_1$。
-
算法层:样本修正 在算法层,我们不再试图“纠正整个行为策略”,而是用重要性采样比率在样本层面做筛选和重加权,让“真正参与训练的样本子集”上的行为策略尽量接近参考策略,或者减小差异较大的样本在训练上的权重。
具体来说,就是下面这些方法,它们本质上都可以看作是“实现约束 2 的不同方式”。
重要性采样与掩码:四种约束2实现
下面延续前文的记号体系来写这三种方法的目标函数,只聚焦在“行为策略 vs 参考策略”这一维的设计。token级的 PPO / GRPO 风格更新项为
$$
g_\theta(t) = \min\left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right),
$$
其中 $$ r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_{\text{old}}}(a_t | s_t)}, \quad (s_t, a_t) \sim \boldsymbol{\mu}, \quad A_t := \mathcal{A}_{\boldsymbol{\mu}}(s_t, a_t)。 $$ 也就是说:
- $r_t(\theta)$ 是目标 vs 参考的比率(对应约束1);
- $A_t$ 基于行为策略采样的数据,是我们能估到的优势函数。
为了把token级的 $(s_t, a_t)$ 与序列级的 $(x, y)$ 记号打通,在以 RLHF(reinforcement learning from human feedback,人类反馈强化学习)为代表的 LLM-RL 设定中,我们约定:
- prompt 记为 $x$;回复记为 $y = (y_1, \dots, y_{|y|})$;
- token级状态 $s_t := (x, y_{<t})$,动作 $a_t := y_t$;
- 因此行为策略和参考策略在序列上的分布可写成 $\boldsymbol{\mu}(y | x) = \prod_{t=1}^{|y|} \boldsymbol{\mu}(a_t = y_t | s_t)$,$\pi_{\theta_{\text{old}}}(y | x) = \prod_{t=1}^{|y|} \pi_{\theta_{\text{old}}}(a_t = y_t | s_t)$。
此外,为了描述“参考 vs 行为”的偏移,统一定义 token级重要性比率 $$
\rho_t^{(\text{ref-beh})} := \frac{\pi_{\theta_{\text{old}}}(a_t | s_t)}{\boldsymbol{\mu}(a_t | s_t)},
$$
以及其对应的序列级版本 $$
\rho(y | x) := \frac{\pi_{\theta_{\text{old}}}(y | x)}{\boldsymbol{\mu}(y | x)} = \prod_{t=1}^{|y|} \rho_t^{(\text{ref-beh})}。
$$
接下来,TIS / IcePop / MIS 的区别,就体现在“如何利用这些 $\rho$ 来实现约束2”。
1. TIS:token-level 截断 IS
TIS 直接对上述 $\rho_t^{(\text{ref-beh})}$ 做截断,记
$$ \color{blue}{w_t = \min\big(\rho_t^{(\text{ref}\leftarrow\text{beh})},\ C_{\text{IS}}\big)}。 $$
更新目标写成 $$ L_{\text{TIS}}(\theta) = - \mathbb{E}{(s_t,a_t)\sim\mu}\big[,\color{blue}{w_t}; g\theta(t)\big]。 $$
- 蓝色的 $w_t$ 是被截断的 IS 权重:极端大的比率被压到常数 $C_{\text{IS}}$。
- 从三策略 TRPO 的角度看,这相当于在 token 分布上“软削弱”行为策略和参考策略严重不一致的样本,从而在梯度中有效减小那部分样本对 $\alpha_1$ 的贡献。
标准的重要性采样(IS),通过引入重要性权重 $\frac{\pi}{\boldsymbol{\mu}}$ 来校正分布差异,使梯度重新变得无偏:
$$
\nabla_\theta J_{\text{pg-is}}(x) = \mathbb{E}{y \sim \boldsymbol{\mu}(\cdot|x,\theta’)} \left[ \frac{\pi(y|x,\theta)}{\boldsymbol{\mu}(y|x,\theta’)} \nabla\theta \log \pi(y|x,\theta) \cdot A(x,y) \right] \tag{5}
$$
其中 $A(x,y)$ 是优势函数。缺点:对于长文本,重要性权重方差极大,训练极不稳定。
为了降低方差,研究者提出了两种牺牲少量偏差换取稳定性的变体:
- 截断重要性采样(Truncated IS, TIS):设置上限 $C$。
$$
\nabla_\theta J_{\text{pg-tis}}(x) = \mathbb{E}{y \sim \boldsymbol{\mu}} \left[ \min\left( \frac{\pi}{\boldsymbol{\mu}}, C \right) \cdot \nabla\theta \log \pi \cdot A \right] \tag{6}
$$
- 掩码重要性采样(Masked IS, MIS):如果权重过大直接丢弃。
$$
\nabla_\theta J_{\text{pg-mis}}(x) = \mathbb{E}{y \sim \boldsymbol{\mu}} \left[ \frac{\pi}{\boldsymbol{\mu}} \cdot \boldsymbol{1}\left{ \frac{\pi}{\boldsymbol{\mu}} \leq C \right} \cdot \nabla\theta \log \pi \cdot A \right] \tag{7}
$$
2. IcePop:MoE 场景下的 token-level 双侧 Mask
IcePop 同样以 $\rho_t^{(\text{ref-beh})}$ 为度量,但采用双侧掩码:
$$ \color{blue}{m_t = \mathbf{1}\big[C_{\text{low}} \le \rho_t^{(\text{ref}\leftarrow\text{beh})} \le C_{\text{high}}\big]}。 $$
更新目标写成
$$ L_{\text{IcePop}}(\theta) = - \mathbb{E}{(s_t,a_t)\sim\mu}\big[,\color{blue}{m_t}; g\theta(t)\big]。 $$
- 蓝色的 $m_t$ 决定某个 token 是否参与更新:比率太大或太小的 token 直接被丢弃。
- 这相当于硬性裁掉行为策略和参考策略极度不一致的 token,只在 $\rho_t$ 适中的区域上优化,从样本集合层面实施更强的“约束2”。
3. sequence-level MIS:按整条序列 Mask 的重要性采样
MIS 的核心操作是:只保留 IS 比率不超过阈值 $C$ 的序列,其余序列的损失直接置零。写成
$$ \color{blue}{ \rho(y\mid x) \leftarrow \rho(y\mid x),\mathbf{1}{\rho(y\mid x)\le C} } $$
在统一的损失形式下,可以写成 $$ L_{\text{MIS}}(\theta) =-,\mathbb{E}{(x,y)\sim\mu} \Big[ \color{blue}{\rho(y\mid x),\mathbf{1}{\rho(y\mid x)\le C}} ;\cdot; \sum{t=1}^{|y|}g_\theta(t) \Big], $$
简而言之:
- 对于 IS 比率较小的序列:保留完整的 $\rho(y|x)$ 权重,正常做 off-policy 修正;
- 对于 IS 比率超过阈值 $C$ 的序列:整个序列的 policy loss 被 mask 掉(权重变成 0)。
从三策略 TRPO 的角度看,MIS 不再在 token 上做截断,而是直接在序列级筛选“行为策略和参考策略严重不一致”的轨迹,只在 $\rho(y|x) \leq C$ 的子分布上优化,从而在 trajectory 粒度上实现对“约束2”($\boldsymbol{\mu}$ vs $\pi_{\theta_{\text{old}}}$ 偏移)的控制。
4. Worst Token Reject Sampling:按最差 token 拒绝整条序列
verI 中的 veto 机制与 INTELLECT-3 分别在各自的训练框架中采用了一种可统称为 Worst Token Reject Sampling(WTRS) 的拒绝采样策略:
- verI Token Veto:在其 rollout correction 模块中,若轨迹中存在任意 token 使得 $\min_t \rho_t < \tau_{\text{veto}}$,则通过
response_mask将整条序列剔除。阈值 $\tau_{\text{veto}}$ 可由用户配置。 - INTELLECT-3 Token Masking:在其异步分布式 RL 框架中,若任意 token 的比率低于 $10^{-5}$,则对整条轨迹进行 masking。
二者的核心操作一致:若轨迹中存在任意 token 的 IS 比率低于阈值 $\tau$,则将整条序列从训练中剔除。写成 $$ \color{blue}{ m(y\mid x) = \mathbf{1}\Big{\min_{t=1}^{|y|} \rho_t^{(\text{ref}\leftarrow\text{beh})} \ge \tau\Big} } $$
在统一的损失形式下,可以写成 $$ L_{\text{WTRS}}(\theta) =-,\mathbb{E}{(x,y)\sim\mu} \Big[ \color{blue}{m(y\mid x)} ;\cdot; \sum{t=1}^{|y|}g_\theta(t) \Big], $$
简而言之:
- 对于所有 token 的 IS 比率均不低于 $\tau$ 的序列:正常参与训练;
- 对于存在任意 token 的 IS 比率低于 $\tau$ 的序列:整条序列的 policy loss 被 mask 掉。
从三策略 TRPO 的角度看,WTRS 采用了“token 粒度检测、sequence 粒度否决”的混合策略:在 token-level 检测极端不一致的信号,一旦发现则在 sequence-level 执行拒绝。这种“一票否决”的设计体现了一种保守思路——当轨迹中存在“行为策略生成但参考策略几乎不可能生成”的 token 时,整条轨迹的可信度都将受到质疑,从而在 trajectory 粒度上实现对“约束2”($\boldsymbol{\mu}$ vs $\pi_{\theta_{\text{old}}}$ 偏移)的控制。
训练推理不一致问题分析
1 中 Rollout-Training Mismatch Analysis 章节从实验的角度来对 rollout-training 不一致问题进行了分析,主要得出的结论是,不同的并行策略以及更长的响应长度会增大二者之间的 mismatch,而选择不同的推理后端的影响比较小。
这些消融实验可以带来一些经验的归纳,也就是说明现象。但是笔者认为并不能让我们完全理解 mismatch 产生的原因。笔者认为,不一致性的主要来源,还是因为训练(FSDP、Megatron 等)和推理(vLLM、SGLang 等)是针对不同计算 pattern 进行了各自侧重的优化,不论是前向的 kernel 算子差异带来的数值精度误差累积,还是切分策略带来的通信算子规约顺序带来的精度误差累计,都是 mismatch 的原因一部分。
而像 MoE 模型的稀疏以及动态路由特性,会带来比 Dense 模型更大的 mismatch,因为路由机制本身就是数值精度敏感的,一些微小的数值差异,会带来差异巨大的专家激活。除此之外,MoE 模型本身的稀疏特性,和 Dense 模型相比一般规模会更大,而现代的推理引擎,通常针对 MoE 模型有独特的优化手段(计算、通信),也会放大训推引擎之间的不一致性。
而字节团队的这篇文章2,对训推不一致问题进行了更加深入的理论、实验分析,针对不一致的现象,也提出了更 genearal 的叙述:
To achieve the massive throughput required, modern inference engines (e.g., vLLM, SGLang, TensorRT-LLM) employ aggressive optimization strategies like speculative decoding, low-precision computation (INT8/FP8), and specialized, batch-variant CUDA kernels. While maintaining sampling fidelity, the primary objective of modern inference engines is to maximize throughput, often measured in tokens per second. Conversely, training frameworks (e.g., FSDP, DeepSpeed, Megatron-LM) must strike a different balance, prioritizing numerical stability and precision for gradient computation, often using higher-precision formats like FP32 for master weights and optimizer states. This divergence in optimization priorities and constraints creates an inevitable training-inference mismatch.
因此,我们可以回到一开始提到的业界解决 on-policy RL 训推不一致问题的两个思路,实际上是在性能和一致性上 trade-off 的取舍,如果希望对齐训推计算(例如之前讨论的 batch invariant),势必会带来性能上的劣化。
从这篇文档,能得到很多有用的 takeaways,比如实验中衡量不一致性的用的是下面的 vllm-kl metric:
$$
\small{\mathbb{E}{s\sim d{\textcolor{red}{\pi^\text{vllm}\theta}}}\left[\text{KL}\left(\textcolor{red}{\pi^\text{vllm}\theta}\left(\cdot|s\right),\textcolor{blue}{\pi^\text{fsdp}\theta}\left(\cdot|s\right)\right)\right] = \mathbb{E}{s\sim d_{\textcolor{red}{\pi^\text{vllm}\theta}},a\sim {\textcolor{red}{\pi^\text{vllm}\theta}\left(\cdot|s\right)}} \left[\log\left(\frac{\textcolor{red}{\pi^\text{vllm}\theta}(a|s)}{\textcolor{blue}{\pi^\text{fsdp}\theta}(a|s)}\right)\right],}
$$
而 vllm-kl 的 spike 同时会导致 fsdp-ppl 和 gradient norm 的爆炸性波动,这表示 FSDP engine 给推理引擎采样的得到的 tokens 设置特别小的概率,导致梯度爆炸,从而让 RL 训练崩溃。
以及 mismatch 不是均匀分布的,如果推理引擎得到的 token 概率越接近 0,那在训练侧这个 token 的概率会更严重地被压小,让 mismatch 更大。
所以综上所述,在当前的 RL 框架中,训推引擎之间的不一致,是一个不可避免的问题,如果不一致问题非常严重,容易导致训练崩溃这样的严重后果(特别在长稳训练下)。
算法解决
https://zhuanlan.zhihu.com/p/1973802307188717016
TIS
$$ \theta \gets \theta + \mu \cdot \mathbb{E}{\underbrace{a \sim{\pi}(\theta)}{rollout}} [R(a)\cdot \underbrace{\nabla_\theta \log {\pi}(a, \theta)}_{\tiny{training}}]. $$
作者在算法层面提出使用截断重要性采样(truncated importance sampleing, TIS)的方法,矫正采样模型和训练模型之间的偏差 $$ \mathbb{E}{a \sim \pi{\text{sampler}}(\theta)} \left[ \underbrace{\min\left( \frac{\color{blue}\pi_{\text{learner}}(a, \theta)}{\color{red}\pi_{\text{sampler}}(a, \theta)}, C \right)}{\text{truncated importance ratio}} \cdot R(a) \cdot \nabla\theta \log \color{blue}\pi_{\text{learner}}(a, \theta) \right], $$
如果是 PPO 算法,则可以实现为
$$ \mathbb{E}{a \sim \pi{\text{sampler}}(\theta_{\text{old}})} \left[ \underbrace{\min\left( \frac{\pi_{\text{learner}}(a, \theta_{\text{old}})}{\pi_{\text{sampler}}(a, \theta_{\text{old}})}, C \right)}{\text{truncated importance ratio}} \cdot \nabla\theta \min\left( \frac{\pi_{\text{learner}}(a, \theta)}{\pi_{\text{learner}}(a, \theta_{\text{old}})} \hat{A}, \text{clip}\left( \frac{\pi_{\text{learner}}(a, \theta)}{\pi_{\text{learner}}(a, \theta_{\text{old}})}, 1-\epsilon, 1+\epsilon \right) \hat{A} \right) \right] $$ 通过对 Qwen 等模型实验发现,TIS 能有效解决不匹配问题,提升性能。此外,还对比了 TIS 与 PPO - IS、vanilla - IS 的效果,TIS 表现更优。同时研究了 rollout 量化对训练稳定性影响,发现量化 rollout 会导致训练不稳定,TIS 可使其稳定。也探讨了引入 TIS 虽可能使奖励记录变差,但能提升下游性能。最后通过实验确定并行策略差异和长响应长度会加剧 rollout 生成与梯度计算的差异,而采样器后端选择影响较小。
IcePop
作者发现训推不一致的情况在MoE模型中更为显著,这是因为expert是使用top-k策略选择,微小的专家选择差异也会带来很大的prob差异
$$ \small{\begin{align*}\mathcal{J}{{\text{IcePop}}}(\theta) &= \mathbb{E}{x \sim \mathcal{D}, {y_i}{i=1}^G \sim \pi{\textcolor{red}{\text{infer}}}(\cdot \mid x; \theta_{\rm old})} \left[ \frac{1}{G} \sum_{i=1}^G \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} \Big[\mathcal{M}\Bigl(\frac{\pi_{\textcolor{blue}{\text{train}}}(y_{i,t} \mid x, y_{i,<t};\theta_{\text{old}})}{\pi_{\textcolor{red}{\text{infer}}}(y_{i,t} \mid x, y_{i,<t}; \theta_{\mathrm{old}})}; \alpha, \beta\Bigr) \right. \ &\left. \qquad \qquad \qquad \qquad \quad \qquad \cdot \min \left( r_{i,t}\widehat{A}{i,t}, \text{clip} \left( r{i,t}, 1 - \varepsilon, 1 + \varepsilon \right) \widehat{A}{i,t} \right) \right]\Bigg] &\end{align*}} $$ $\text{where } r{i,t} = \frac{\pi_{\text{train}}(y_t | x_i, y_{<t}; \theta)}{\pi_{\text{train}}(y_t | x_i, y_{<t}; \theta_{\text{old}})}, \text{ with the masking function below,}$
$$ \begin{equation} \mathcal{M}(k) =\begin{cases} k & \text{if \ } k \in [\alpha, \beta] \ 0 & \text{otherwise}\end{cases} \end{equation} $$
The gradient of IcePop: $$ \small{\nabla_\theta \mathcal{J}{\text{IcePop}}(\theta) \sim \small{\begin{equation}\mathbb{E}{a \sim \textcolor{red}{\pi_{\text{infer}}}(\theta_{\text{old}})} \Bigg[\mathcal{M}\Bigg(\frac{\textcolor{blue}{\pi_{\text{train}}}(a;\theta_{\text{old}})}{\textcolor{red}{\pi_{\text{infer}}}(a;\theta_{\text{old}})}\Bigg ) \cdot \nabla_\theta \log \textcolor{blue}{\pi_{\text{train}}}(a;\theta) \cdot \hat{A} \cdot r(a)\Bigg)\Bigg].\end{equation}}} $$
Routing Replay
BF16 切换为 FP16
来自 Sea Lab 的论文声称3,我们不需要上面这些复杂的算法修正,只需要把精度从 BF16 切回 FP16,不一致带来的问题就会解决。
研究认为,虽然 FP16 和 BF16 都使用总计 16 位,单比特分布不一样,尾数项和指数项的差异,导致 FP16 的精度为 BF16 的 8 倍。但带来的 trade-off 就是 FP16 能表征的数值范围就小的多,容易会 overflow,可能需要额外的技巧比如 loss scaling(反向之前把计算得到的 loss 乘大的缩放因子,反向时用放大了的 loss 计算梯度,落入 fp16 可表征的范围,利用梯度更新权重之前把梯度再还原回去)等稳定性技术进行缓解。
文章作者认为,在 RL 后训练的阶段,模型权重基本已经经过了 pre-training,数值分布相对稳定,配合 loss scalling 等技巧 fp16 可以保证训练稳定。实际上这个是一个非常重要的场景预设,也是这个研究的出发点。因为 BF16 格式在 LLM 训练中,被广泛使用的原因,就是其既能够表征 FP32 的数值范围(指数位都是 8),从而在预训练的时候,甚至不需要 loss scaling 这种增加复杂度的工程实现,也可能做到训练的稳定性,而且显存占用和全精度比也省一半。这种流行,是建立在 pre-training 对精度问题的“宽容度”比较高这个前提上的:
- 在 pre-training 阶段,模型在海量的数据上学习通用的统计规律,pre-training 过程有很强的鲁棒性,对数值噪声不敏感。 和 pre-training 阶段只关心“预测的 next token 是否准确”相对的,RL post-training 更关注的是策略的更新幅度,例如 PPO 就依赖新旧策略的的 ratio,通过限制 ratio 不要偏离 1 太远来保证训练稳定,这对训练时的精度需求就会更为严苛,如果由于精度表征问题,导致本来有差异的新旧策略,被截断成了 ratio=1,那么策略就不更新了。而且针对 KL 散度约束这种设计到 log 计算的,精度差异带来的影响会在对数域被放大,那 KL 算不准,reward model 给出奖励信号可能导致模型迅速过拟合到某个错误的 pattern 上,或者训练非常不稳定。
因此这里的 takeaway 简单来说,就是在 RL 场景下,梯度的数值范围通常没有 pre-training 方差那么大,对梯度的精度要求更高,所谓为了获得比 bf16 高 8 倍的精度表征,用 fp16 + loss scaling 增加的工程成本是值得的。
确定性推理
总结
许多“大模型 RL 训推不一致”和“异步训练”问题,在本文的视角下,其实都可以理解为:在 TRPO 框架下,当行为策略 $\boldsymbol{\mu}$ 和参考策略 $\pi_{\boldsymbol{\theta}_{\text{old}}}$ 不一致时,二者之间的偏移($\alpha_1$)被严重低估了。
从两策略到三策略,我们做的事情其实很简单:
- 把 TRPO 的下界从“旧策略 vs 新策略”的叙述,改写成“行为策略 - 参考策略 - 目标策略”三者的关系;
- 显式地拆出了两个 TV 距离:
- 约束 1:参考 vs 目标 $\alpha_0$,对应 PPO / GRPO / GSPO 等工作里最常见的 KL / clip / trust region;
- 约束 2:行为 vs 参考 $\alpha_1$,对应异步框架、训推差异、MoE 路由、kernel 非确定性等现实因素;
- 得到了一个非常直接的结论:替代目标 $\mathcal{L}{\boldsymbol{\mu}}(\pi{\boldsymbol{\theta}})$ 和真实性能 $\mathcal{J}(\pi_{\boldsymbol{\theta}})$ 的 gap 正比于 $\alpha_0 + \alpha_1$。
在这个视角下(当然这只是众多可能视角之一):
- Decoupled PPO / Areal 可以被看作是在形式上承认“三策略存在”,并尝试在目标函数上将“行为分布”和“参考策略”解耦;
- TIS、IcePop、MIS、WTRS 则是通过 IS 或者掩码机制在样本层面实施“约束 2”:
- TIS:用 token-level 截断权重削弱比率过大样本的影响;
- IcePop:在 MoE 场景下用 token-level 双侧掩码硬性丢弃“极端不一致”的 token;
- MIS:在 sequence-level 直接屏蔽整条“比率过大”的轨迹;
- WTRS:在 token-level 检测比率过小的信号,一旦发现则在 sequence-level 拒绝整条轨迹;
- routing replay(路由回放)在三策略 TRPO 的视角下更像是“改写 surrogate objective”而非“直接实现约束”:无论回放行为路由(R3 类)还是回放参考路由,它们都把原本的 $\mathcal{L}{\boldsymbol{\mu}}(\pi{\boldsymbol{\theta}})$ 改成了一个路由被条件化/替换后的 surrogate objective,用一定的目标偏差与路由学习自由度的收缩换取降低方差与提升稳定性。因此它并不会真正收缩 $\alpha_0$ 或 $\alpha_1$,而是让路由不一致在 loss 中“不可见”;
- 《RL 老训崩?训推差异是基石》、以及前文提到的 Defeating Nondeterminism in LLM Inference 等工程经验,则可以理解为在系统侧和数值实现侧,尽可能把 $\alpha_1$ 压低,让算法层的假设不至于完全失效。
从这个统一视角出发,也许有助于回答几个实际问题(这里只是抛几个开放性问题):
- 在什么条件下,我们还能把“大模型 RL 训练”理解成某种意义上的“近似 TRPO / PPO”?
- 对一个具体的 RL 系统,我们究竟应该把主要精力花在:
- 收紧 $\alpha_0$(更强的 KL / 更稳的 sequence-level 目标),还是
- 压低 $\alpha_1$(更一致的训推框架、更激进的 MIS / TIS / IcePop)?
- 在 MoE、异步采样、复杂 agent workflow 这些现实设定下,我们还能安全地假装“$\boldsymbol{\mu} \approx \pi_{\boldsymbol{\theta}_{\text{old}}}$”多久?
本文只是在 TRPO 这个老框架上做了一个非常“最小化”的延展,把“三策略”显式写出来,并用它来整理现有的一些工作。难免有理解偏差或遗漏之处,如果你也关注实际大模型 RL 训练的情况,欢迎把自己的设定抽象成“$\boldsymbol{\mu}, \pi_{\boldsymbol{\theta}{\text{old}}}, \pi{\boldsymbol{\theta}}$ 三者的关系”,再回头看看 Theorem 2 里的那条不等式,或许会有不一样的直观感受。
参考资料
- 思路: https://zhuanlan.zhihu.com/p/1980312064305169845
- 近期关于解决 LLM RL 中训推不一致的工作总结 https://zhuanlan.zhihu.com/p/1973802307188717016
- RL 训练总是崩?别怪算法了,罪魁祸首可能是 BF16! https://zhuanlan.zhihu.com/p/1976720631467877312
- https://zhuanlan.zhihu.com/p/1973206684907365344
- https://verl.readthedocs.io/en/latest/algo/rollout_corr_math.html
-
Your Efficient RL Framework Secretly Brings You Off-Policy RL Training, https://fengyao.notion.site/off-policy-rl ↩︎ ↩︎
-
https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda ↩︎
-
Defeating the Training-Inference Mismatch via FP16, https://arxiv.org/pdf/2510.26788 ↩︎
-
No backlinks found.