整个过程可以用伪代码表示如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
    for prompt in dataloader:
        # process 1: rollout
        # given prompt from dataloader -> get response
        batch = actor_rollout_ref_wg.generate_sequences(prompt)

        # process 2: reward score
        # given prompt+response -> evaluation reward score
        rewards = reward_wg.compute_rm_score(batch)

        # process 3: old policy sampling (no_grad)
        # given prompt+response -> get old_log_prob/values/ref_log_prob
        old_log_prob = actor_rollout_ref_wg.compute_log_prob(batch)
        ref_log_prob = actor_rollout_ref_wg.compute_ref_log_prob(batch)
        values = critic_wg.compute_values(batch)

        # process 4: kl penalty
        # compute rewards. apply_kl_penalty if available
        batch, kl_metrics = apply_kl_penalty(batch)
       
        # process 5: compute advantages
        # compute_advantages is running directly on the control process
        advantages = compute_advantages(values, rewards)
        batch = batch.union(old_log_prob)
        batch = batch.union(ref_log_prob)
        batch = batch.union(values)
        batch = batch.union(rewards)
        batch = batch.union(advantages)

        # process6-10: update actor/critic
        actor_rollout_ref_wg.update_actor(output)
        critic.update_critic(output)

        weight_sync()

Old logprobs 与 Entropy

在 generation 阶段给定 prompt 生成 response 之后,接下来对于这个 batch 的数据,需要经过 Old Policy Sampling 计算得到 old_log_prob,对应于上图中的 Process 3。

log_prob 的计算

如下代码所示,veRL 中具体计算 old_log_prob 的代码如下所示,要求输入字段为:

  1. input_ids: shape [batch_size, sequence_length]。其中 input_ids 是由 prompt 和 response 拼接而来, sequence_length = prompt_length + response_length
  2. attention_mask: shape [batch_size, sequence_length]
  3. position_ids: shape [batch_size, sequence_length]

将 input_ids, attention_mask 等传递给 transformer 模型之后,输出为 logits 张量,形状为 (batch_size, sequence_length, vocab_size)logits 是模型在每个序列位置上为词汇表中每个 token 打出的原始、未归一化的分数。

一般来说,log_probs 的计算方式可以表示如下:

1
2
3
4
def logprobs_from_logits_naive(logits, labels):
    logp = F.log_softmax(logits, dim=-1)
    logpy = gather_from_labels(logp, labels)
    return logpy
  • 输入的 logits 为 (batch_size, sequence_length, vocab_size)
  • logp = F.log_softmax(logits, dim=-1)
    • 这一行将原始的 logits 转换成对数概率。
    • F.log_softmax 在最后一个维度(dim=-1,即词汇表维度)上进行操作。它在数学上等价于先对 logits 应用 softmax 函数将其转换为概率分布,然后再取对数。
    • 执行后,logp 的形状与 logits 相同,但其值现在是规范化的对数概率。
    • 到目前为止,log_probs 这一计算只需要依赖于 logits,不需要依赖于 labels
  • 那这里的 gather_from_labels 作用是什么呢?
    • 输入 logp 为 (batch_size, sequence_length, vocab_size) 表示 batch 中每个 sequence 的每个 token 对应的位置,输出 vocab 的每个 token 的对数概率
    • 也就是说这里的 logp 是一个完整的概率分布,包含了完整的各个 state 的 probability
    • 强化学习中,算法更关心 Agent 实际采取的那个 action 的 log probability$$ \begin{align*} \nabla_{\theta} J(\pi_\theta) &= \mathbb{E}{\tau \sim \pi\theta} \left[ \sum_{t=0}^{T-1} A_\pi(s_t, a_t) \nabla_{\theta} \log \pi_\theta(a_t|s_t) \right] \end{align*} $$
    • 这里的 label 为 (batch_size, sequence_length),对应的是每个位置输出的 token id,这也就是实际采取的 action_t
    • 通过 gather_from_labels,我们获得了实际采取对应 action 的 log_prob,这里的 torch.gather 的操作就是从完整分布里面取出来输出 token_i 的 log probability
    • 最终的 logpy 就是在 old policy 下采取 action 的 log prob
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
def gather_from_labels(data, label):
    """Gather the label from data. The value in label should be [0, vocab_size)

    Args:
        data: (..., vocab_size)
        label (torch.IntTensor) : (...,)

    Returns:

    """

    output = torch.gather(data, -1, label.unsqueeze(-1)).squeeze(-1)
    return output

总结下,经过 logprobs_from_logits_naive,输出了采取每个 action 的 log probability。函数返回 logpy,其形状与 labels 相同(即 (batch_size, sequence_length))。logpy 中的每个元素 logpy[i, j] 就是模型对于第 i 个样本的第 j 个位置预测出真实 token labels[i, j] 的对数概率。

上图实际对应流程如下:

  • input_ids 和 attention_mask 经过 transformer layer 得到 hidden states,shape 为 (seqlen, hidden_size)
  • hidden states 经过 LMHead,得到 logits,shape 为 (seqlen, vocab_size)
  • logits 经过 log_softmax 得到 full_log_probs,shape 为 (seqlen, vocab_size)
  • 选择 policy 采取的 action,也就是 next token 作为 label,选取对应的 log prob,shape 为 (seqlen, )

比如示例中,prompt 为 where is shanghai ?,response 为 shanghai is in china,对应的 input_ids 就是 prompt + response。经过 transformer layer,? 这个 token 应该预测出 shanghai

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
# input_ids: (bsz, seqlen)
# label: left shift input_ids by 1
output = self.actor_module(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    use_cache=False,
    **extra_args,
)
logits = output.logits # (bsz, seqlen, vocab_size)
logits = logits[:, -response_length - 1 : -1, :] # (bsz, response_length, vocab_size)
# 这里的 logprobs_from_logits 与上面的 logprobs_from_logits_naive 相同,只是有了显存优化 
# 这里传入的 logits 就是 left shift logits by 1,labels 就是 response
log_probs = logprobs_from_logits(logits, micro_batch["responses"]) # (bsz, response_length)
if calculate_entropy:
    entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length)

entropy 的计算

熵 entropy 是信息论的概念,它用于衡量一个随机变量的不确定性混乱程度

  • 高熵:表示随机变量的取值非常不确定,结果难以预测。例如,抛一枚均匀的硬币,正反面概率各为50%,此时熵值最高。
  • 低熵:表示随机变量的取值比较确定,结果容易预测。例如,一枚作弊的硬币,正面概率为99%,反面为1%,此时熵值很低。

在强化学习中,我们将这个概念应用于智能体的策略 Policy上。策略 $\pi(a|s)$ 是指在状态 s 下,智能体选择动作 a 的概率分布。

因此,策略的熵衡量的是在给定状态 s下,智能体选择动作的不确定性

  • 高熵策略:Agent 倾向于探索各种不同的动作,每个动作被选中的概率相对平均。这被称为 Exploration 强的策略。
  • 低熵策略:Agent 非常确定,只倾向于选择某一个或某几个动作,概率分布很集中。这被称为 Exploitation 强的策略。 $$ H(\pi(\cdot|s)) = -\sum_{a \in \mathcal{A}} \pi(a|s) \cdot \log \pi(a|s) $$ 因此,对于每个 token,其概率分布对应于 logits shape 为 (1, vocab_size),可以直接计算出当前 token 的 entropy。

实际计算中经常会应用 The Log-Sum-Exp Trick 来加速计算:

$$ \log(p_i) = \log(\text{softmax}(\text{logits})i) = \text{logits}i - \text{logsumexp}(\text{logits}) $$ 因此,对于当前 token $p$ 在 vocab 每一个可能的 action 概率 $p_i$ ,其 entropy 计算为: $$ \begin{align*} H(p) &= - \sum{i} p_i \log(p_i) \ &= - \sum{i} p_i (\text{logits}i - \text{logsumexp}(\text{logits})) \ &= - \left( \sum{i} p_i \cdot \text{logits}i - \text{logsumexp}(\text{logits}) \sum{i} p_i \right) \ &= \text{logsumexp}(\text{logits}) - \sum_{i} p_i \cdot \text{logits}_i \end{align*} $$

1
2
3
4
5
def entropy_from_logits(logits: torch.Tensor):
    """Calculate entropy from logits."""
    pd = torch.nn.functional.softmax(logits, dim=-1)
    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)
    return entropy

Ref logprobs

ref_log_prob 的计算与 old_log_prob 计算基本一致,同样是给定 prompt + response,可以在 ref model 计算出 ref_log_prob

Old values

对于 critic,在 model 初始化的时候,会将其默认设置为 token clasification model,而不是 causal language model。

1
2
3
4
5
6
7
arch = hf_config.architectures[0]
# This logic assumes the critic is a token classification model.
# If the provided model is a CausalLM, we adapt it.
if "ForCausalLM" in arch:
    model_name = arch.split("ForCausalLM")[0]
    new_arch = f"{model_name}ForTokenClassification"
    hf_config.architectures[0] = new_arch

比如配置 Critic 采用 Qwen3ForCausalLM,经过 engine 的初始化,这里就变成了 Qwen3ForTokenClassification。这两者的区别是?

  • ...ForCausalLM 模型包含一个 LMHead,实际上就是 nn.Linear (hidden_size, vocab_size)
    • Input: 接收来自基础模型的 hidden_state,其维度为 [batch_size, sequence_length, hidden_size]
    • Output: 经过 LM Head 之后,输出 logits 维度为 [batch_size, sequence_length, vocab_size]
  • ...ForTokenClassification 模型包含一个 Classification Head,这通常也是一个 nn.Linear 线性层(前面可能会有一个 Dropout 层)。
    • Input: 接收来自基础模型的 hidden_state,其维度为 [batch_size, sequence_length, hidden_size]
    • Output: 经过 Classification Head 之后,输出 logits 维度为 [batch_size, sequence_length, num_labels]

Critic 的需求:在强化学习中,Critic 模型需要对一个状态(通常由整个序列表示)给出一个单一的价值评分 (Value)。这意味着它的头部只需要输出一个标量(即 num_labels = 1)。

  • ForTokenClassification 的妙用:一个 num_labels = 1ForTokenClassification 模型,其头部恰好就是一个将 hidden_state 映射到单一维度的线性层。这在功能上与我们需要的 Value Head 完全等价。

总结下:对于 compute_values,他进行的计算为:

  • Input: input_ids [prompt + response],对应的 shape (batch_size, sequence_length)
  • Output: old_values, 对应的 shape (batch_size, sequence_length)

对应到具体代码,FSDPEngineWithValueHead 和 FSDPEngineWithLMHead 区别只在于对于 Output 的处理:

  • 因为 huggingface Qwen3ForTokenClassfication 架构的模型输出 logits shape 为 [batch_size, sequence_length, num_labels],其中 num_labels=1,因此最后 squeeze(-1) 获得每个 token 对应的 value
1
2
3
4
5
output = self.model_module(
    **model_inputs,
)
values_rmpad = output.logits # (1, total_nnz / sp_size, 1)
full_output_rmpad = values_rmpad.squeeze(0).squeeze(-1) # (total_nnz / sp_size)

KL Penalty

KL 散度 Kullback-Leibler Divergence 的标准计算公式为: $$ KL(P\parallel Q) = \sum_{x} P(x) \log \frac{P(x)}{Q(x)} = \mathbb{E}_{x \sim P} \left[ \log \frac{P(x)}{Q(x)} \right] $$ 并且,KL 散度是非对称的,也满足非负,当且仅当 P 和 Q 分布完全相同时 KL 为 0。对应于 torch 的函数12

1
torch.nn.functional.kl_div(input, target, reduction='batchmean', log_target=False)

其中:

  • inputlog-probabilities(即 log Q
  • targetprobabilities(即 P
  • 计算的是:$$\sum P⋅(log⁡P−log⁡Q)\sum P \cdot (\log P - \log Q)∑P⋅(logP−logQ)$$ 然而,在实际计算中,直接计算 KL 散度可能非常困难,主要原因如下:
  • 需要对所有 x  进行求和或积分,计算成本高。
  • 计算过程中可能涉及大规模概率分布,导致内存消耗过大。

因此实际计算中,通常通过 k1/k2/k3 来估算 KL 散度,关于 kl 散度估计,可以参考 kl-divergence

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
    if kl_penalty in ("kl", "k1"):
        return logprob - ref_logprob

    if kl_penalty == "abs":
        return (logprob - ref_logprob).abs()

    if kl_penalty in ("mse", "k2"):
        return 0.5 * (logprob - ref_logprob).square()

    # J. Schulman. Approximating kl divergence, 2020.
    # # URL http://joschu.net/blog/kl-approx.html.
    if kl_penalty in ("low_var_kl", "k3"):
        kl = ref_logprob - logprob
        # For numerical stability
        kl = torch.clamp(kl, min=-20, max=20)
        ratio = torch.exp(kl)
        kld = (ratio - kl - 1).contiguous()
        return torch.clamp(kld, min=-10, max=10)

    if kl_penalty == "full":
        # so, here logprob and ref_logprob should contain the logits for every token in vocabulary
        raise NotImplementedError

    raise NotImplementedError

在 RLHF 实践中,通常会在 reward 的计算中,应用 KL Penalty 负奖励,作为 reward shaping3。加入 KL Penalty 可以避免 policy 朝着 reward model 的方向过度优化,也可以使得 policy model 不要和 ref policy 区别太大,保证训练稳定。这里的公式里面 KL 用 K1 估计。 $$ r_t = r_\varphi(q, o_{\leq t}) - \beta \mathbb{D}{KL}\left[ \pi\theta \parallel \pi_{ref} \right] = r_\varphi(q, o_{\leq t}) - \beta \log \frac{\pi_\theta(o_t | q, o_{<t})}{\pi_{ref}(o_t | q, o_{<t})}, $$

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"):
"""Apply KL penalty to the token-level rewards.
This function computes the KL divergence between the reference policy and current policy, then applies a penalty to the token-level rewards based on this divergence.
"""
    response_mask = data.batch["response_mask"]
    token_level_scores = data.batch["token_level_scores"]
    batch_size = data.batch.batch_size[0]
    # compute kl between ref_policy and current policy
    # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled.
    kld = core_algos.kl_penalty(
        data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty
    ) # (batch_size, response_length)
    kld = kld * response_mask

    beta = kl_ctrl.value
    token_level_rewards = token_level_scores - beta * kld
    current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence
    current_kl = torch.mean(current_kl, dim=0).item()

    # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837
    kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
    data.batch["token_level_rewards"] = token_level_rewards
    metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta}

    return data, metrics

Advantages

$$ \hat{A}t = \delta_t + \gamma \lambda \hat{A}{t+1} $$ 其中 $$ \delta_t = \hat{A}t^{TD} = r_t + \gamma V\pi(s_{t+1}) - V_\pi(s_t) $$

输入为:Token Reward、Old Values 输出为:Advantages、Returns

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def compute_gae_advantage_return(
    token_level_rewards: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    gamma: torch.Tensor,
    lam: torch.Tensor,
):
    """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape is (bs, response_length)
        values: `(torch.Tensor)`
            shape is (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape is (bs, response_length). [EOS] mask. The token after [EOS] have mask zero.
        gamma is `(float)`
            discounted factor used in RL
        lam: `(float)`
            lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length)

    """
    with torch.no_grad():
        nextvalues = 0
        lastgaelam = 0
        advantages_reversed = []
        gen_len = token_level_rewards.shape[-1]

        for t in reversed(range(gen_len)):
            delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
            lastgaelam_ = delta + gamma * lam * lastgaelam

            # skip values and TD-error on observation tokens
            nextvalues = values[:, t] * response_mask[:, t] + (1 - response_mask[:, t]) * nextvalues
            lastgaelam = lastgaelam_ * response_mask[:, t] + (1 - response_mask[:, t]) * lastgaelam

            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)

        returns = advantages + values
        advantages = verl_F.masked_whiten(advantages, response_mask)
    return advantages, returns

Policy Loss

Policy loss 定义如下: $$ \begin{align} L^{\text{PPO}}(\pi_\theta) &= - \underset{\tau \sim \pi_{\theta_{\text{old}}}}{E_t} \left[ \min\left{ r_t(\theta) A_{\pi_{\theta_{old}}}^{\text{GAE}}(s_t, a_t), \text{clip}\left(r_t(\theta), 1 - \epsilon, 1 + \epsilon\right) A_{\pi_{\theta_{old}}}^{\text{GAE}}(s_t, a_t) \right} \right] \end{align} $$

其中 $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}$ 是策略比率 Probability Ratio 令 $$ \begin{align} \text{pgloss1} &= -r_t(\theta)A_{\pi_{\theta_{old}}}^{GAE}(s_t, a_t) \ \text{pgloss2} &= -\text{clip}\left(r_t(\theta), 1 - \epsilon, 1 + \epsilon\right) A_{\pi_{\theta_{old}}}^{\text{GAE}}(s_t, a_t) \end{align} $$ 则 $$ \begin{align} L^{\text{PPO}}(\pi_\theta) &= \underset{\tau \sim \pi_{\theta_{\text{old}}}}{E_t} \left[ \max\left{ \text{pgloss1}, \text{pgloss2}\right} \right] \end{align} $$ 引入计算 Policy Loss 的时候,输入经常是 log_prob,则对应的 Probability Ratio 计算为: $$ r_t(\theta) = \exp\left{ \log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_{old}}(a_t \mid s_t) \right} $$

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
def compute_policy_loss(old_log_prob, log_prob, advantages, response_mask, cliprange):
    """
    old_log_prob: (bs, response_length)
    log_prob: (bs, response_length)
    advantages: (bs, response_length)
    response_mask: (bs, response_length)
    """
    negative_approx_kl = log_prob - old_log_prob
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)

    pg_loss = verl_F.masked_mean(torch.max(pg_losses1, pg_losses2), response_mask)
    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
    return pg_loss, pg_clipfrac, ppo_kl

如果还要计算 entropy loss,则可以根据 entropy 计算出 entropy loss(主要就是做了个聚合)。因为 policy 可能陷入局部最优,停止探索新动作。为了鼓励探索,可以加上 entropy loss,这样策略会保持一定的随机性,继续探索可能更好的动作。具体而言,因为我们是要最大化熵,因此我们将 entropy loss 作为负奖励。

1
2
3
4
5
# add entropy loss
if entropy is not None:
    entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
    entropy_coeff = config.entropy_coeff
    policy_loss -= entropy_coeff * entropy_loss

在语言模型训练中,我们通常会计算出在一个 batch 中,每个序列(sequence)里每一个 token 的损失。这会得到一个形状为 (batch_size, sequence_length) 的损失矩阵。但为了进行反向传播和梯度更新,我们最终需要一个单一的、代表整个批次性能的标量损失值。agg_loss 函数就是用来完成这最后一步的聚合计算的。

  • token-mean:它计算所有被 loss_mask 标记为有效的 token 损失的总平均值。它将整个批次中所有有效的 token 视为一个整体,一视同仁地计算平均损失。
  • seq-mean-token-sum
    • 先对每个序列内部的所有 token 损失进行求和
    • 然后对所有序列的“总损失”进行平均
    • 这种模式下,较长的序列会因为其 token 损失总和更大,而在最终的平均值中占据更高的权重。
  • seq-mean-token-mean
    • 先计算每个序列内部的平均 token 损失
    • 然后对所有序列的“平均损失”再进行一次平均
    • 这种方式下,每个序列对最终 loss 的贡献是平等的,无论序列长短。它避免了长序列主导损失信号的问题。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str):
    """
    Aggregate the loss matrix into a scalar.

    Args:
        loss_mat: `(torch.Tensor)`:
            shape: (bs, response_length)
        loss_mask: `(torch.Tensor)`:
            shape: (bs, response_length)
        loss_agg_mode: (str) choices:
            method to aggregate the loss matrix into a scalar.
    Returns:
        loss: `a scalar torch.Tensor`
            aggregated loss
    """
    if loss_agg_mode == "token-mean":
        loss = verl_F.masked_mean(loss_mat, loss_mask)
    elif loss_agg_mode == "seq-mean-token-sum":
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum
        loss = torch.mean(seq_losses)  # seq-mean
    elif loss_agg_mode == "seq-mean-token-mean":
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1)  # token-mean
        loss = torch.mean(seq_losses)  # seq-mean
        raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")

    return loss

如果还要计算 KL loss,比如在 GRPO 中,为了防止 policy 和 ref model 相差太远,通过 k3 estimator 计算 KL Loss。

在 GRPO 中,对于每个 question $q$,GRPO 采样出一组 ouputs ${o_1, o_2, …, o_G}$ $$ \begin{align} \mathcal{J}{\text{GRPO}}(\theta) &= \mathbb{E}\left[ q \sim P(Q), {o_i}{i=1}^G \sim \pi_{\theta_{\text{old}}}(O|q) \right] \ & \Bigg{ \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left{ \min\left[ \frac{\pi_\theta(o_{i,t}|q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t}|q, o_{i,<t})} \hat{A}{i,t}, clip\left( \frac{\pi\theta(o_{i,t}|q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t}|q, o_{i,<t})}, 1-\varepsilon, 1+\varepsilon \right) \hat{A}_{i,t} \right]

  • \beta \mathbb{D}{KL}\left[ \pi\theta \parallel \pi_{ref} \right] \right}, \Bigg} \end{align} $$

$$ \begin{align} \mathbb{D}{\text{KL}} \left( \pi\theta | \pi_{\text{ref}} \right) &= \frac{\pi_{\text{ref}}(o_{i,t}|q,o_{i,<t})}{\pi_\theta(o_{i,t}|q,o_{i,<t})} - \log \frac{\pi_{\text{ref}}(o_{i,t}|q,o_{i,<t})}{\pi_\theta(o_{i,t}|q,o_{i,<t})} - 1 \end{align} $$

1
2
3
4
5
6
7
# add kl loss
if config.use_kl_loss:
    ref_log_prob = data["ref_log_prob"]
    # compute kl loss
    kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=config.kl_loss_type)
    kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=config.loss_agg_mode)
    policy_loss += kl_loss * config.kl_loss_coef

Critic Loss

Critic 的训练目标是最小化它的预测价值与实际回报之间的差距。这个差距被称为 Critic Loss。

Critic Loss 通常通过均方误差(Mean Squared Error, MSE)来计算。对于每一个状态,我们都有一个由 Critic 预测出的预期回报值 $V(s)$,以及一个真实的回报值 G(returns)。Critic Loss 就是这两个值之间差的平方。在一个批量的数据中,Critic Loss 是所有状态的这个差的平方的平均值。

$$

\begin{align*} V_t^{\text{CLIP}} &= \text{clip}(V_t^{\text{new}}, V_t^{\text{old}} - \epsilon, V_t^{\text{old}} + \epsilon) \ G_t &= A_t^{\text{GAE}} + V_t^{\text{old}} \ \arg\min_{V_\phi} L(V_\phi) &= \mathbb{E}_t\left[ \max\left{ (V_t^{\text{new}} - G_t)^2, (V_t^{\text{CLIP}} - G_t)^2 \right} \right] \end{align*}

$$

令 $$ \begin{align*} \text{vfloss1} &= (V_t^{\text{new}} - G_t)^2 \ \text{vfloss2} &= (V_t^{\text{CLIP}} - G_t)^2 \end{align*} $$ 则 $$ \begin{align*} L(V_\phi) &= \mathbb{E}_t\left[ \max\left{ \text{vf_loss1}, \text{vf_loss2} \right} \right] \end{align*} $$

代码实现:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def compute_value_loss(
    vpreds: torch.Tensor,
    returns: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    cliprange_value: float,
    loss_agg_mode: str = "token-mean",
):
    """
    Compute the clipped value-function loss for PPO.

    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151

    Args:
        vpreds (torch.FloatTensor):
            Predicted values from the value head, shape (batch_size, response_length).
        values (torch.FloatTensor):
            Old (baseline) values from the value head, shape (batch_size, response_length).
        returns (torch.FloatTensor):
            Ground-truth returns, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the value loss calculation.
        cliprange_value (float):
            Clip range for value prediction updates.
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".

    Returns:
        vf_loss (torch.FloatTensor):
            A scalar tensor containing the aggregated value-function loss.
        vf_clipfrac (float):
            Fraction of elements where the clipped loss was used.
    """
    vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
    vf_losses1 = (vpreds - returns) ** 2
    vf_losses2 = (vpredclipped - returns) ** 2
    clipped_vf_losses = torch.max(vf_losses1, vf_losses2)
    vf_loss = 0.5 * agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
    vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask)
    return vf_loss, vf_clipfrac

GRPO

对于每个 question $q$,GRPO 采样出一组 ouputs ${o_1, o_2, …, o_G}$,

$$ \begin{align} \mathcal{J}{\text{GRPO}}(\theta) &= \mathbb{E}\left[ q \sim P(Q), {o_i}{i=1}^G \sim \pi_{\theta_{\text{old}}}(O|q) \right] \ & \Bigg{ \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left{ \min\left[ \frac{\pi_\theta(o_{i,t}|q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t}|q, o_{i,<t})} \hat{A}{i,t}, clip\left( \frac{\pi\theta(o_{i,t}|q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t}|q, o_{i,<t})}, 1-\varepsilon, 1+\varepsilon \right) \hat{A}_{i,t} \right]

  • \beta \mathbb{D}{KL}\left[ \pi\theta \parallel \pi_{ref} \right] \right}, \Bigg} \end{align} $$

$$ \begin{align} \mathbb{D}{\text{KL}} \left( \pi\theta | \pi_{\text{ref}} \right) &= \frac{\pi_{\text{ref}}(o_{i,t}|q,o_{i,<t})}{\pi_\theta(o_{i,t}|q,o_{i,<t})} - \log \frac{\pi_{\text{ref}}(o_{i,t}|q,o_{i,<t})}{\pi_\theta(o_{i,t}|q,o_{i,<t})} - 1 \end{align} $$

where $\varepsilon$ and $\beta$ are hyper-parameters, and $A_i$ is the advantage, computed using a group of rewards ${r_1, r_2, \ldots, r_G}$ corresponding to the outputs within each group: $$A_i = \frac{r_i - \text{mean}({r_1, r_2, \cdots, r_G})}{\text{std}({r_1, r_2, \cdots, r_G})}. $$

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
                                   eos_mask: torch.Tensor,
                                   index: torch.Tensor,
                                   epsilon: float = 1e-6):
    """
    token_level_rewards: (bs, response_length)
    eos_mask:  (bs, response_length)
    """
    response_length = token_level_rewards.shape[-1]
    scores = token_level_rewards.sum(dim=-1)

    # id2score的value是一个列表,其包含了G个奖励值,key则是为每个输入x生成的id
    id2score = defaultdict(list)
    # id2mean和id2std分别存储每组的均值和标准差
    id2mean = {}
    id2std = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            id2score[index[i]].append(scores[i])
        for idx in id2score:
            if len(id2score[idx]) == 1:
                id2mean[idx] = torch.tensor(0.0)
                id2std[idx] = torch.tensor(1.0)
            elif len(id2score[idx]) > 1:
                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
                id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
            else:
                raise ValueError(f"no score in prompt index: {idx}")

        for i in range(bsz):
            # normalize
            scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
        scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask

    return scores, scores

​ 除了优势计算需要重写外,整个 loss 的计算直接复用 PPO 的即可,只要指定每个输入的重复采样次数 actor_rollout_ref.rollout.n 即可。

DAPO

移除 KL 散度

上面的 PPO 及 GRPO 目标函数中都存在 Actor 模型与参考模型的 KL 散度,KL 散度的意义也说过了,就是不想让训练的模型与最初始模型分布差距太大。

但 DAPO 的训练方案应对场景是有长思维链输出(带思考过程)情况,长输出也就代表着对于输出 token 分布调整更大,那么训练后的模型就必然会与原始模型存在很大差异,因为目标就是让他们有差异,因此KL散度的约束反而不是必需的了,所以可以移除。

$$ \begin{gather*} \mathcal{J}{\text{DAPO}}(\theta) = \mathbb{E}{(q,a) \sim \mathcal{D}, {oi}{i=1}^G \sim \pi*{\theta*{\text{old}}}(\cdot \mid q)} \left[ \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \min\left( r_{i,t}(\theta) \hat{A}{i,t}, \text{clip}\left( r{i,t}(\theta), 1 - \varepsilon_{\text{low}}, 1 + \varepsilon_{\text{high}} \right) \hat{A}_{i,t} \right) \right] \ \text{s.t.} \quad 0 < \left| { oi \mid \text{is_equivalent}(a, o_i) } \right| < G, \ r{i,t}(\theta) = \frac{\pi*\theta(o*{i,t} \mid q, o*{i,<t})}{\pi*{\theta*{\text{old}}}(o*{i,t} \mid q, o*{i,<t})}, \ \hat{A}{i,t} = \frac{Ri - \text{mean}({ R_i }{i=1}^G)}{\text{std}({ Ri }{i=1}^G)}. \end{gather} $$

CLIP Higher

首先改动是 clip 函数中的 $\epsilon$ ,在 PPO 及 GRPO 中都使用一个固定的超参数(一般是 0.2),但 DAPO 中分化成了和,DAPO 论文中叫提高上限——Clip-Higher

裁剪的对象是重要性采样比值,这个比值是新策略模型生成某个 token 与旧策略模型生成某个 token 的比值,当旧策略模型生成某个 token 的概率本身比较高时,其被裁剪的概率就会变低;而如果旧策略模型生成某个 token 的概率本身比较低时,其被裁剪的概率就会变高。

举个例子,比如旧策略采样到某个 token 的概率是 0.9,按照裁剪上限 1.2 计算(1+ $\epsilon$),则新策略采样到该 token 的概率是接近于 1(0.9 * 1.2 但最大为 1)

新策略是我们训练目标函数期望的概率分布,也就是说旧策略中本身高概率的那些 token 是不容易被上限裁剪的,哪怕新策略下这个 token 采样概率很高了也容易被裁剪。

反之当旧策略采样到某个 token 概率是 0.1 时(一般情况下低概率 token 不容易被采样到,但强化训练 Rollout 会具备一定随机性),如果同样的现在裁剪上限是 1.2,那么新策略下这个 token 最高的采样概率也就是 0.1 * 1.2=0.12

也就是说对于旧策略概率低的 token,即使训练后这个 token 的采样概率也不会有很大提升,因为提升上限被裁剪限制了。这样一对比,(0.99-0.9)»(0.12-0.1)是不是就很明显的看到差距了。

这也是为什么 DAPO 要提升裁剪上限,因为不这样做的话,本来旧策略模型采样概率高的 token 会随着训练变得采样概率越来越高,而低的 token 只会有很小的提升,那么结果就是模型输出的分布越来越尖锐,也就使得分布的熵变低,造成熵坍塌现象。

Dynamic Sampling

前面对 GRPO 思考中提到过,如果 GRPO 对某一组输出的结果全是错误,或全是正确的情况下,这样组内的每个样本序列计算后的优势都是 0,因为本身就是每个样本序列奖励值与组内均值的差值最标准化,均值就等于样本奖励值的时候那么就没有优势了。

那就会造成这一组的训练不会对模型梯度变化有任何贡献,就代表本组训练没意义了。但一个问题就是当我们训练到后期就会面临一个组内很可能全正确的情况,这种情况出现的很自然,因为我们训练目标就是让模型输出序列的奖励值更高。这就使得后期的训练中有很多的组是没有意义的,白耗费训练资源。

除了耗费训练训练资源之外还会带来一个问题,假设我们有每个批次有 N 个指令来进行 GRPO 训练,这 N 个指令有的容易一些,有的难一些,模型在训练后期很可能在这 N 个指令中有 50% 的指令输出的组内序列全部正确,也就代表有一半的指令训练是没有意义。

随着模型训练到后期,每个批次中全为 1 的指令占比会更高。这会使得强化学习训练方差变大,因为我们输入指令让模型产生组内输出的过程实际是生成旧策略模型输出分布的过程。

只有指令足够多,旧策略模型输出的序列足够多才能更准确的表示旧策略的输出分布,当有效指令变少的时候,旧策略模型输出的分布也就存在一定的偏移,换句话说就是存在方差,也就是说 GRPO 越到训练后期训练的方差偏移越大。

关于解决这个问题 DAPO 也是简单粗暴,对于批次内生成的组内序列全部正确或错误的指令直接剔除掉,使用新的输出组中不全是错误或正确的指令来补充上,直到补全这个批次。

这个方法粗看会影响训练效率,因为你需要让每条指令去生成一个组,再去使用奖励函数判断才能知道输出的组中序列是否全部正确或错误,但是作者实验发现这种方法可以更快的让模型收敛,也就是说可以平衡掉 GRPO 耗费的资源,甚至更优。

Token -level loss

GRPO loss 计算是 sample/trajectory level:每个 sample 内先对 token 的 loss 做 mean,再 sample 间 mean,所有 sample share the same weight Token-level:一个 batch 里面全部 token 的 loss 算 mean

Soft Overlong Reward Shaping

VAPO

GSPO

GSPO 提到的一个 GRPO 关键性缺陷在于重要性采样修正项使用的粒度不对,GRPO 中是对序列中每个 token 进行的重要性采样,因为动作的粒度是 token,但是奖励却是对整个序列的奖励,这样会造成一种逻辑冲突的问题。

我们训练优化的目标是 token 的采样概率,这一点使用 token 作为动作粒度可以理解。但 GRPO 却是对一个序列整体奖惩,优化单元与奖惩单元粒度上不一致时的模型训练就容易出现偏差。

其实上面这个观点如果站在未简化的强化学习算法的角度上来讲(PPO 或者 A2C),优势应该是动作级别的,但是对于 LLM 的自回归输出场景来说动作级别的优势是不太容易计算的,上面介绍 RLHF 的时候也提到了,需要有奖励模型和价值模型。

但 GRPO 的简化方案使得每个 token 共享该序列整体的优势值,那粒度上对不齐也是必然的。从理论上来说带来的结果就是无法从 token 粒度上更快速的让模型提升正确 token 的采样概率。

$$

\begin{equation*} \begin{split} \mathcal{J}{\text{GSPO}}(\theta) &= \mathbb{E}{x \sim \mathcal{D}, {y_i}{i=1}^G \sim \pi{\theta_{\text{old}}}(\cdot \mid x)} \left[ \frac{1}{G} \sum_{i=1}^G \min\left( s_i(\theta) \widehat{A}_i, \text{clip}\left( s_i(\theta), 1 - \varepsilon, 1 + \varepsilon \right) \widehat{A}_i \right) \right] \end{split} \end{equation*}

$$ 其中 $$

\begin{gather*} \widehat{A}i = \frac{r(x, y_i) - \text{mean}\left( { r(x, y_i) }{i=1}^G \right)}{\text{std}\left( { r(x, y_i) }_{i=1}^G \right)} \ \end{gather*}

$$ 基于 sequence likelyhood 定义 importance ratio $s_i(\theta)$ $$ s_i(\theta) = \left( \frac{\pi_\theta(y_i \mid x)}{\pi_{\theta_{\text{old}}}(y_i \mid x)} \right)^{\frac{1}{|y_i|}} = \exp\left( \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} \log \frac{\pi_\theta(y_{i,t} \mid x, y_{i,<t})}{\pi_{\theta_{\text{old}}}(y_{i,t} \mid x, y_{i,<t})} \right) $$

1
2
3
4
5
log_ratio = per_token_logps - old_per_token_logps  # 每个token的log概率差
# 按句子平均:总log概率差 / 有效token数(避免padding影响)
log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
log_importance_weights = log_importance_weights.unsqueeze(-1)  # 扩展为(batch_size, 1)
coef_1 = torch.exp(log_importance_weights)  # 整个句子的平均概率比

Token-Level vs Sequence Level

Reference

https://mp.weixin.qq.com/s/w-DBnIcgHfl7F5vSFsAbxQ https://mp.weixin.qq.com/s/lCeIxUfEh4U8cMCM09TZsw https://bytedance.larkoffice.com/docx/KCsLd3SxXo0lxpxliQucIA52nTf https://limoncc.com/post/c0a3be9c86b2b4cd/ https://zhuanlan.zhihu.com/p/1891244454612009240 https://zhuanlan.zhihu.com/p/1932791271363154917 https://huggingface.co/blog/NormalUhr/rlhf-pipeline