自从 10 月份以来 Thinking Machine Labs 的 On Poilcy Distillation 博客1 之后,OPD 引发了越来越多的关注。最近 MiMo-V2-Flash 在其技术报告2 中也提到他们使用了 Multi-Teacher On-Policy Distill 作为全新的 Post-Training 范式。

本文尝试回答几个问题:

  • 什么是 On Policy Distillation?
  • OPD 相对于 SFT 和 RL 的区别和联系是什么?有什么优势?
  • OPD 常见应用场景是什么,在 LLM 训练的哪个阶段?

Notation

符号表示 数学含义
$\mu_\theta$ he student sampling policy adopted in the inference engine
$\pi_\theta$ target student policy optimized in the training engine
$\pi_{\text{domain}_x}$ the teacher policy specialized for the domain of prompt $x$ sampled from distribution $\mathcal{D}$
$\text{sg}[\cdot]$ stop-gradient operator

On Policy Distillation 是什么

一句话总结:

On-Policy Distillation 在策略蒸馏是一类将策略蒸馏 Policy Distillationon-policy 强化学习相结合的方法,主要用于多策略融合、模型压缩、稳定训练或加速学习

Knowledge Distillation

Knowledge Distillation 是一种模型压缩知识迁移技术。

核心思想:

让一个较小的模型(称为 student)去模仿一个较大的、性能更强的模型(称为 teacher),从而在保持较高性能的同时显著减少模型复杂度。

传统训练中,student 直接学习真实标签 y,这叫做 hard target 但在蒸馏中,student 还要学习 teacher 的 soft target

  • teacher 不仅告诉你哪个类别是正确的 one-hot label
  • 还告诉你每个类别的置信度分布 soft probability
    例如,真实标签是 cat
  • teacher 输出:[cat: 0.9, dog: 0.08, fox: 0.02]
  • student 不仅学到 猫是正确的,还学到 狗和猫有点相似

Forward KL vs Reverse KL

OPD in RL

$$ \mathcal{D}{\text{reverse-KL}}(\theta) = \mathbb{E}{x \sim \mathcal{D}, y_{<t} \sim \pi_\theta(\cdot|x,y_{<t})} \log \frac{\pi_\theta(y_t|x, y_{<t})} {\pi_{\text{domain}x}(y_t|x, y{<t})}$$

$$ A = - \mathcal{D}{\text{reverse-KL}}(\theta) = - \mathbb{E}{x \sim \mathcal{D}, y_{<t} \sim \pi_\theta(\cdot|x,y_{<t})} \log \frac{\pi_\theta(y_t|x, y_{<t})} {\pi_{\text{domain}x}(y_t|x, y{<t})}$$

Technical Formulation of MOPD

在通过 SFT 奠定基础,并通过领域特定的 RL 训练出专业化教师模型之后,现在将形式化multi-teacher on-policy distillation——该机制可将这些专业化能力整合到统一的学生模型中。

学生与教师之间的reverse KL divergence loss定义为:$$ \mathcal{L}{\text{reverse-KL}}(\theta) = -\mathbb{E}{x \sim \mathcal{D}, y_{<t} \sim \pi_\theta(\cdot|x,y_{<t})} \log \frac{\pi_{\text{domain}x}(y_t|x, y{<t})}{\pi_\theta(y_t|x, y_{<t})} $$ 梯度为: $$ \nabla_\theta \mathcal{L}{\text{reverse-KL}}(\theta) = -\mathbb{E}{x \sim \mathcal{D}, y_{<t} \sim \pi_\theta(\cdot|x,y_{<t})} \left[ \log \frac{\pi_{\text{domain}x}(y_t|x, y{<t})}{\pi_\theta(y_t|x, y_{<t})} \nabla_\theta \log \pi_\theta(y_t|x, y_{<t}) \right] $$ 参照 IcePop 中的方法3,采用训练-推理重要性采样,并丢弃差异较大的 token。随后定义 MOPD 的 the surrogate loss 为: $$ \mathcal{L}{\text{MOPD}}(\theta) = -\mathbb{E}{x \sim \mathcal{D}, y \sim \mu_\theta(\cdot|x)} \left[ \frac{1}{|y|} \sum_{t=1}^{|y|} w_t \hat{A}{\text{MOPD},t} \log \pi\theta(y_t|x, y_{<t}) \right]$$ 其中 $$ w_t(\theta) = \begin{cases} \text{sg}\left[ \frac{\pi_\theta(y_t|x,y_{<t})}{\mu_\theta(y_t|x,y_{<t})} \right], & \epsilon_{\text{low}} \leq \frac{\pi_\theta(y_t|x,y_{<t})}{\mu_\theta(y_t|x,y_{<t})} \leq \epsilon_{\text{high}}, \ 0, & \text{other}, \end{cases} $$ $$\quad \hat{A}{\text{MOPD}, t} = \text{sg}\left[ \log \frac{\pi{\text{domain}x}(y_t|x, y{<t})}{\pi_\theta(y_t|x, y_{<t})} \right] $$

默认情况下,我们将 MOPD 的优势函数与其他类型的优势函数结合(例如基于结果奖励模型(ORM)计算的优势,包括 GRPO。设 $\hat{A}{\text{ORM}}$ 为 ORM 计算的优势函数,则最终优势函数为: $$ \hat{A}{\text{MOPD}, t} = \text{sg}\left[ \log \frac{\pi_{\text{domain}x}(y_t|x, y{<t})}{\pi_\theta(y_t|x, y_{<t})} \right] + \alpha \hat{A}_{\text{ORM}} $$

图6展示了 MOPD 相对于传统后训练方法的有效性:在数学推理(AIME 2025)和代码(LiveCodeBench)基准测试中,MOPD 成功保留并融合了多教师的专业化能力,在多数领域中达到或超过了最强教师的性能。

代码实现

1
2
3
4
5
6
7
teacher_logits = teacher(seq)[:-1].detach() # [seq_len, vocab_size] 左移一位
student_logits = student(seq)[:-1]  # [seq_len, vocab_size]
teacher_logprob = logsoftmax(teacher_logits)
student_logprob = logsoftmax(student_logits)

kl_loss = teacher_logprob.exp() * (teacher_logprob - student_logprob) # [seq_len, vocab_size]
kl_loss_token = kl_loss.sum(-1) # [seq_len,]

TML

TML 根本就没有想词表的事情,kl compute 全程只使用被选中的 token

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
# Initialize teacher client (main):
teacher_client = service_client.create_sampling_client(
    base_model=teacher_config.base_model,
    model_path=teacher_config.load_checkpoint_path,
)

# Sample trajectories (main):
trajectories = do_group_rollout(student_client, env_group_builder)
sampled_logprobs = trajectories.loss_fn_inputs["logprobs"]

# Compute reward (compute_teacher_reverse_kl):
teacher_logprobs = teacher_client.compute_logprobs(trajectories)
reverse_kl = sampled_logprobs - teacher_logprobs
trajectories["advantages"] = -reverse_kl

# Train with RL (train_step):
training_client.forward_backward(trajectories, loss_fn="importance_sampling")

其实这里相当于对完整的 kl 做了一次蒙特卡洛模拟。因为直接使用 student 做 rollout 时,选中的每个 token 的概率正比于 student_prob,那么在这个 token 上估计一次 kl,它的期望和全词表估计是一样的。也正是由于选中概率和 student 一样,所以期望和正版 reverse_kl loss 一致,所以即使根本没有加权这个过程,作者仍然把这个 loss 叫做 reverse_kl loss。

应用场景

参考资料


  1. https://thinkingmachines.ai/blog/on-policy-distillation/ ↩︎

  2. MiMo-V2-Flash Technical Report, https://github.com/XiaomiMiMo/MiMo-V2-Flash/blob/main/paper.pdf ↩︎

  3. Small leak can sink a great ship–boost rl training on moe with icepop!, Sep 2025. URL https://ringtech.notion.site/icepop ↩︎