Google 的 PaLM 和 Meta 的 LLaMA 都使用了 SwiGLU 来增强 Transformer 架构中的 FFN 层 的性能。SwiGLU 是 Gated Linear Units(GLU)激活函数的一种变体,由 Noam Shazeer 在论文《GLU Variants Improve Transformer》的论文中提出。本文主要介绍不同激活函数(如 ReLU、GELU 和 Swish)在 FFN 层中的应用。

使用 ReLU 激活的 FFN

Transformer 模型通过 MHA 和 FFN 层交替工作。FFN 层存在于 Transformer 架构的编码器和解码器部分中。例如,下方的编码器块由多头注意力层和一个 FFN 层组成。

FFN 层包括两个线性变换,中间插入一个非线性激活函数。最初的 Transformer 架构采用了 ReLU 激活函数。

$$ FFN(x, W_1, W_2, b_1, b_2) = ReLU(xW_1 + b_1)W_2 + b_2 $$

其中 ReLU 的定义为 $ReLU (x) = max (0, x)$

使用 GELU 激活的 FFN

论文《Gaussian Error Linear Units(GELUs)》提出了GELU,这是ReLU的平滑版本。

1
2
3
4
5
def gelu(x):
   return x * norm.cdf(x)

def relu(x):
   return np.maximum(0, x)

这是一个处处可微的非线性函数,因此使用 GELU 的 FFN 层可以表示为 $$ FFN(x, W_1, W_2) = GeLU(xW_1)W_2 $$ 在 GELU 论文中,作者使用了标准正态分布的累积分布函数(cdf)的近似计算来提高计算速度。

使用 Swish 激活的 FFN

论文《Swish: a Self-Gated Activation Function》提出了 Swish,这也是对带有非零负值梯度的 ReLU 平滑版本。

1
2
3
4
5
6
7
8
def gelu(x):
   return x * norm.cdf(x)

def relu(x):
   return np.maximum(0, x)

def swish(x, beta=1):
   return x * (1 / (1 + np.exp(-beta * x)))

Swish 同样是个处处可微的非线性函数,且有一个参数 beta 用于控制函数的形状,带有 Swish 的 FFN 层可以表示为 $$ FFN(x, W_1, W_2) = Swish(xW_1)W_2 $$

GLU 及其变体

GLU(Gated Linear Units)其实不算是一种激活函数,而是一种神经网络层。它是一个线性变换后面接门控机制的结构。其中门控机制是一个 sigmoid 函数用来控制信息能够通过多少。

$$ GLU(w, W, V, b, c) = \sigma(xW+b)\otimes(xV+c) $$

其中 $\sigma$ 为 sigmoid 函数,$\otimes$ 为逐元素乘。通过使用其他的激活函数我们就能够得到 GLU 的各种变体了。

比如说现在 LLM 中常用的 SwiGLU 其实就是采用 Swish 作为激活函数的 GLU 变体 $$ GLU(w, W, V, b, c) = Swish_1(xW+b)\otimes(xV+c) $$ 由于引入了更多的权重矩阵,通常会对隐藏层的大小做一个缩放,从而保证整体的参数量不变。

代码的实现如下

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        down_proj = self.dropout(down_proj)
        return down_proj

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 8192, padding_idx=0)
    (layers): ModuleList(
      (0-79): 80 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=8192, out_features=8192, bias=False)
          (k_proj): Linear4bit(in_features=8192, out_features=8192, bias=False)
          (v_proj): Linear4bit(in_features=8192, out_features=8192, bias=False)
          (o_proj): Linear4bit(in_features=8192, out_features=8192, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=8192, out_features=22016, bias=False)
          (down_proj): Linear4bit(in_features=22016, out_features=8192, bias=False)
          (up_proj): Linear4bit(in_features=8192, out_features=22016, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=8192, out_features=32000, bias=False)
)

参考资料