当前的 Transformer 架构中,Attention 层的输出通常直接线性投影。作者提出在 Scaled Dot-Product Attention (SDPA) 的输出之后,直接加入一个 Sigmoid 门控机制

核心收益:

  1. 性能提升:在 15B MoE 和 1.7B Dense 模型(训练数据达 3.5T token)上,PPL 和下游任务(MMLU, GSM8k 等)均有显著提升。
  2. 训练极其稳定:该机制几乎消除了训练过程中的 Loss Spikes,使得模型可以使用更大的学习率进行训练,这对于大规模模型训练至关重要。
  3. 消除 Attention Sink:这是个意外之喜。该机制引入了 Input-dependent sparsity,使得模型不再需要将注意力强行分配给首个 Token,从而天然地消除了 Attention Sink 现象。
  4. 长窗口外推能力增强:在进行长 Context 扩展(如使用 YaRN)时,Gated Attention 的表现显著优于 Baseline。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# 来自transformer library: transformers/models/qwen3_next/modeling_qwen3_next.py
class Qwen3NextAttention(nn.Module):
    def __init__(self, config: Qwen3NextConfig, layer_idx: int):
        ...
        # 注意这里*2
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias
        )
        ...
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        # 分出来gate和query_state
        query_states, gate = torch.chunk(
            self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
        )
        gate = gate.reshape(*input_shape, -1)
        ...
        # gate sigmod后乘以attn分数
        attn_output = attn_output * torch.sigmoid(gate)

Gated Attention Layer

标准 Attention 回顾

标准的 Multi-Head Attention 计算流程通常如下:

$$Q, K, V = XW_Q, XW_K, XW_V$$

$$SDPA(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

$$Output = \text{Concat}(\text{head}_1, … , \text{head}_h)W_O$$

加入门控机制

作者尝试了在 Attention 的不同位置(Query, Key, Value, SDPA Output, Final Output)加入门控。最终发现在 SDPA 输出之后($G_1$)加入门控效果最好。

公式如下: $$Y’ = Y \odot \sigma(XW_\theta)$$ 其中:

  • $Y$ 是 SDPA 的输出。
  • $X$ 是当前的输入(通常是 Pre-norm 后的 hidden states)。
  • $\sigma$ 是 Sigmoid 激活函数。
  • $W_\theta$ 是可学习的门控参数。

关键配置:

  • 位置:在 SDPA 输出之后。
  • 粒度:Element-wise或 Head-wise。Element-wise 效果略好,且增加参数极少。
  • 激活函数:Sigmoid 优于 SiLU,因为 Sigmoid 输出在 [0, 1] 之间,具有更强的“过滤”物理含义。
  • Head-Specific:不同 Head 应该有独立的门控参数,而不是共享。

深度分析:为什么简单的 Gating 如此有效?

论文并没有止步于刷榜,而是花了大量篇幅分析“为什么”。这是本文最精彩的部分,Gating 如此有效主要归因于两点:非线性和稀疏性。

增加低秩映射的非线性

在标准 Attention 中,Value 投影($W_V$)和输出投影($W_O$)是两个连续的线性层。对于第 $k$ 个 Head,输出可以写成: $$o^k_t = \sum_j S^k_{tj} \cdot X_j (W_V^k W_O^k)$$ 由于 $d_k < d_{model}$,这里的 $W_V^k W_O^k$ 本质上是一个低秩的线性映射,限制了模型的表达能力。引入门控后,相当于在两个线性变换之间插入了一个非线性操作: $$o^k_t = \text{Non-Linearity-Map}\left(\sum_j S^k_{tj} \cdot X_j W_V^k\right)W_O^k$$ 这显著增加了 Attention 模块的表达能力。这也解释了为什么在 $W_O$ 之后加门控($G_5$)无效,因为它没有打破中间的线性瓶颈。

引入 Input-Dependent Sparsity

这是 Gating 区别于简单 LayerNorm 或激活函数的关键。

  • Query-Dependent:门控的分数 $\sigma(XW_\theta)$ 是由当前 Token 的输入 $X$ 决定的(即 Query 端的信息)。
  • 过滤机制:实验发现,训练好的门控分数大部分集中在 0 附近,具有很高的稀疏性。这意味着门控机制充当了一个“动态过滤器”,根据当前的 Query,选择性地让某些 Attention 结果通过,或者直接抑制(置零)。

意外之喜:彻底消除 Attention Sink

什么是 Attention Sink?

由于标准 Softmax 中所有值加起来必须为1,导致在当前 Query 无法在上下文中找到相关信息(即缺乏匹配的 Key)时,模型倾向于将大量注意力分数分配给第一个 Token(作为“垃圾桶”来存放多余的注意力权重),导致第一个 Token 拥有极高的 Attention Score,这就是 Attention Sink。这不仅破坏了语义分布,也严重影响了长文本外推能力。

Gating 如何解决?

论文发现,使用了 Post-SDPA Gating 的模型,Attention Sink 现象完全消失了

  • 原因推测:因为有了 Sigmoid 门控,模型有了“拒绝”的能力。当上下文无关时,模型不需要通过 Softmax 强行找一个 Sink,而是可以通过门控将输出  乘以一个接近 0 的系数,直接阻断信息流。
  • 验证:实验显示,Gated 模型的第一 Token 注意力占比从 Baseline 的 46.7% 降到了 4.8%。

 消除 Massive Activations

与此同时,Gating 还显著减少了模型内部的异常大的激活值。之前的研究认为 Massive Activations 是导致 BF16 训练不稳定的元凶之一。Gating 通过稀疏化输出,天然抑制了这些异常值,解释了为什么训练稳定性得到了巨大提升。

总结

这篇论文给了我们几个非常实用的工程建议:

  1. 架构改进性价比极高:在 SDPA 后加一个 x * sigmoid(Linear(x)) 几乎不增加计算量(Latency < 2%),也不怎么增加显存,但能换来更好的 PPL 和更稳的训练。这应该成为未来 LLM 的标准组件。
  2. 训练稳定性的新解法:如果你在训练大模型时遇到 Loss Spike 或梯度爆炸,除了查数据和 LayerNorm,可以试试加 Gating。它从原理上减少了激活值的漂移。
  3. 对 Attention 机制的再思考:Softmax 的强制归一化(Sum=1)可能是一个过于强的归纳偏置。Gating 实际上赋予了 Attention 机制 “Unnormalized” 的能力(即输出的总能量可以根据 Query 的需求动态调整),这比单纯的 Softmax 更加合理。

参考资料