DeepSeek MLA
MLA(Multi-head Latent Attention / 多头潜在注意力)的基本思想是将注意力输入 $h_t$ 压缩成一个低维的潜在向量 $c^{KV}_t$ ,维度为 $d_c$,且 $d_c$ 远小于原始的维度($h_nd_h$)。在需要计算注意力时,可将这个潜在向量 $c^{KV}_t$ 映射回高维空间。因此,只需要存储潜在向量 $c^{KV}_t$ ,就可以显著减少内存的占用。
这个过程可以通过以下公式更正式地进行描述。其中 $c^{KV}_t$ 表示潜在向量;$W^{DKV}$ 是压缩矩阵(上标 $D$ 代表"下投影",即降维操作),负责将 $h_t$ 的维度从($h_n d_h$)压缩到 $d_c$;$W^{UK}$ 和 $W^{UV}$ 是上投影矩阵,负责将共享的潜在向量 $c^{KV}_t$ 映射回高维空间。只需要存储这个潜在向量 $c^{KV}_t$ ,就能获得对应不同文本特征的 Key 和 Value,而不需要对每个文本特征都存储对应的 Key 和 Value。
类似地,我们也可以将查询向量映射到一个潜在的低维向量,然后再将其映射回原始的高维空间。而且,MLA又结合了权重吸收技术,减少了计算开销。
1.1 问题
标准 Transformer 的一大障碍就是 KV Cache 的空间占用问题:多头注意力机制需要为每个注意力头单独存储历史生成的 Key 和 Value 向量(即 KV 缓存)。随着序列长度增加,KV 缓存的存储需求呈指数级增长,导致内存占用急剧上升。而 GPU 的显存空间往往非常有限,较大的 KV Cache 会导致同时处理的 request 数量变少,也即 batch size 较小;为了减少 KV 缓存需求,研究人员提出了像 Multi-Query Attention(MQA)和 Group-Query Attention(GQA)这些方法。这些方法虽然降低了缓存要求,可模型的性能也受到影响。MQA 或 GQA 算子在计算注意力的过程中,所有 KV Cache 中的数据读取后都仅参与一次或几次计算,导致该算子的 MFU 极低,并且由于每个 request 有自己的 KV Cache,这一问题无法通过提高 batch size 的方式解决。
因此,如何减少推理过程的KV Cache,从而实现在更少的设备上推理更长的Context,或者在相同的Context长度下增大batch size,实现更快的推理速度或者更大的吞吐总量,最终降低推理成本。是一个关键问题。
1.2 当前状况
我们首先总结下当前各种方案的情况来看看有没有可以改进的空间。下图给出了MHA、GQA、MQA 以及 MLA 做法。
图上从左到右依次是 MHA、GQA、MQA 以及 MLA 。图中有阴影的长条表示会缓存到显存的结果。MHA、GQA、MQA 都需要将 KVCache 缓存到显存。几种方案特点如下。
- MHA:MHA KVCache 在注意力头这个维度和 Q 矩阵一样,属于“一对一”。MHA 把一个注意力计算拆成多个注意力头,每个注意力头使用独立的 Q、K、V 进行计算,需要把 K、V 都存储下来,KV Cache 中每个 token 需要缓存的参数量为 $2n_hd_hl$。而 GQA、MQA 在注意力头的维度比 Q 矩阵小。
- MQA:所有查询头共享相同的单一键和值头,因此只需要存储共享的 K 和 V,KV Cache 中每个 token 需要缓存的参数量为 $2d_hl$。在计算注意力时,会把共享的单一 K 头和 V 头广播给每个查询头,然后分别一一计算。
- GQA:将所有的 Q 头分成 g 组,同一组的 Q 头共享一个 K 头和一个 V 头,因此 KV Cache 中每个 token 需要缓存的参数量为 $2n_gd_hl$。在计算注意力时,会把 KV 头复制给所在组的所有 Q 头进行计算。
$n_h$ 是注意力头数量,$n_g$ 是 GQA 分组数,$d_h$ 是隐藏层维度,$l$ 为模型层数,$h_t \in R^d$ 表示第 $t$ 个 token 在一个 attention 层的输入。
1.3 改进思路
MLA是对MHA、GQA、MQA方案的改进,其思路是加强信息压缩能力(对应下图标号1)和丰富信息表达能力(对应下图上标号2),其实,两个标号也对应了从输入到Q、K、V的数据流上两个关键点,也是硬币的两面:增强了矩阵的表现能力的同时,也会使得压缩能力更大。
于是就是研究人员经常遇到的困境了:既要压缩更低(降低推理过程中的KV Cache资源开销),又要表现力更强(缓解MQA、MGA对性能的损耗),或者说新方案的表现力要尽可能接近MHA。
1.3.1 增强信息压缩能力
思路
从某个角度考虑,MQA 和 GQA 也属于低秩压缩的思想,MQA 将 $2n_h$ 压缩到2,GQA 则压缩到 $2n_h/g$。但是压缩能力和性能难以兼顾,所以 GQA 效果要好于 MQA。
因此我们要思考,是不是可以在“增强信息压缩能力且兼顾效果”之上再进一步?因为 MQA 在 KV 头上已经几乎做到了极致,因此我们没法从 KV 头数量上做减少。那就势必得从 KV 本身思考。目前,不管是 GQA 还是 MQA,都需要缓存 K、V 两个值,两个值不一样。那么,是否可以把两个值合并为一个值?有没有可能每个缓存的 KV 都比之前小?从 LoRA 那里得到启发,一个 M×N 的矩阵可以近似成两个 M×k 和 k×N 矩阵的乘积,如果我把一个 K 或者 V 矩阵拆成两个小矩阵的乘积,就可以减少 KV Cache 的显存占用。
方案
MLA 的核心是对注意力键和值进行低秩联合压缩,以减少推理期间的键值(KV)缓存大小,从而提高推理效率。与 GQA、MQA 直接压缩 KVCache 头维度不同,MLA 通过使用下投影矩阵 $W^{DKV}$ 将多个注意力头的 Key 和 Value 投影到一个低维的共享潜在向量空间中,取代了传统的逐头存储方式。
具体而言,MLA 将 KV 矩阵转换为低秩形式:将原矩阵表示为两个较小矩阵(相当于潜向量)的乘积。具体而言,
- 对输入矩阵的 HiddenState 会先做低秩转换,将一个 Shape 为
[S,H]的 HiddenState 压缩到 Shape 为[S,CH]的潜在向量 $c^{KV}_t$,其中 $CH≪H$ 。H 是 token 维度。 - 将压缩后的 KV 向量 $c^{KV}_t$ 作为 KVCache 存储到显存中,这样就达到了降低 KV 大小的目的。在 V2的论文中, $K_t$ 的表达从 $W^Kh_t$ 变为 $W^{UK}W^{DKV}h_t$ , 原来缓存的是 $W^Kh_t$,而现在缓存的是 Kt 的一部分 $W^{DKV}h_t$。
问题
但这有一个问题,如果单纯的把一个大的K/V矩阵拆成2个小矩阵进行缓存,在推理的时候,还是需要计算出完整的K矩阵,这样就失去了缓存的意义,毕竟缓存的意义就是减少计算。
问题就变成:有没有一种方法,即能减少缓存大小,又不增加推理时候的计算?
1.3.2 丰富信息表达
思路
我们可以注意到,在 MQA 和 GQA 计算注意力时,只用到了简单的广播或者复制机制把 KV 头复制给对应的 Q 头进行计算。我们以 GQA 为例,GQA 目的是减少 KV Cache 占用,存储的是 KV,即 $C^{KV}$。下面公式是如何得到 k(这里省略了 v 的操作)。
- 首先它将向量对半分为两份分别作为 K、V。
- 然后每一份又均分为 $g$ 份。
- 每一份复制 $h/g$ 次,以此来“凑”够 h 个 Attention Head 所需要的 K、V。
这里的WUK是一组简单线性变化(比如简单复制)的组合,其表现能力是有限的,所以其压缩维度不大。
$$ \begin{align} \boldsymbol{k} = W^{UK} C^{KV} = W^{UK} [\boldsymbol{k}^1, \dots, \boldsymbol{k}^g, \boldsymbol{v}^1, \dots, \boldsymbol{v}^g] \
\left[ \boldsymbol{k}^1, \dots, \boldsymbol{k}^1, \boldsymbol{k}^2, \dots, \boldsymbol{k}^2, \dots, \boldsymbol{k}^g, \dots, \boldsymbol{k}^g \right] \end{align} $$
既然 MQA 和 GQA 的信息表达能力不强,那么我们是不是可以引入一个矩阵变换来替代这些简单的线性变换操作(切分、复制)?比如通过针对每个 $q$ 都去自适应学习,这样就可以让这一层的信息表达更加丰富。
方案
我们已经得到了潜在向量 $C^{KV}$,那么就可以在推理期间使用每个头的上投影矩阵 $W^{UK}$ 和 $W^{UV}$ 这个潜在向量中 $C^{KV}$ 重建 K 和 V。
具体而言,MLA 在 Decode 阶段将:
- 加载压缩的 KVCache 潜在向量 $C^{KV}$。
- 然后通过上投影矩阵 $W^{UK}$ 和 $W^{UV}$ 做两个升秩转换,分别转换为 Shape 均为
[S,H]的 K、V 矩阵,即从潜在向量中恢复出每个头的 Key 和 Value(将这个潜在向量映射回高维空间)。上投影矩阵 $W^{UK}$ 和 $W^{UV}$ 做两个升秩转换起到的作用比 GQA 的简单线性变化(比如简单复制)的组合要大得多。 - 进行 MHA 计算。这样,MLA 在推断过程中仅缓存 latent vector,而不缓存完整的 KV。
MLA 的本质是对 KV 信息的有损压缩,但 MLA 可以通过训练学习到如何在提高存储信息密度的同时尽可能保留关键细节。这规避了分组查询注意力和多查询注意力的查询的信息损失,从而在降低 KV 缓存的前提下获得更好的性能。从 MLA 算子的计算特征来看,同时解决了这两方面的问题:
- 一方面,通过低秩压缩大幅降低了推理过程中的 KV Cache 资源开销。减少推理过程的 KV Cache,从而实现在更少的设备上推理更长的 Context,或者在相同的 Context 长度下增大 batch size,实现更快的推理速度或者更大的吞吐总量,最终降低推理成本。
- 另一方面,MLA解压缩后的多头注意力机制能够提供较高的计算强度(正比于 Head 数),有助于充分利用GPU的算力资源,缓解MQA、MGA对性能的损耗。MLA 通过低秩转换方式压缩 KVCache,从公式来看引入了额外的升秩转换计算,并且需要存储升秩转换计算的激活值结果。但可以根据矩阵乘的交换率特性,将升秩转换的矩阵乘权重和其他权重融合,然后在 attention kernel 直接完成 attention 计算,无需引入额外的计算开销以及存储开销。
1.3.2 解决位置编码冲突
然而,压缩和 RoPE 位置编码是冲突的,即矩阵吸收后的 $c^{KV}_t$ 没有了位置相关信息(原因:RoPE 对 key 和 query 都是位置敏感的)。在这种情况下,只依靠 $c^{KV}_t$ 来压缩 KV-Cache 的路是行不通的,所以需要额外的信息来表达 qk 之间位置关系。为了走出这个困境,DeepSeek 提出了一种折中的方法:使用 $W^{QR}$ 和 $W^{KR}$ 两个矩阵来表征跟 ROPE 相关的特征提取,为 q 和 k 都增加一个额外的维度 $d^R_h$ 来添加 ROPE 编码,之前的 $d_h$ 维度不使用 ROPE 编码,总长度变为 $d_h+d_r$。即,MLA 采用了 MQA 的思想,构造了所有 head 共享的 cache 变量 $c^{KV}_t$ 和 $k^R_i$,这样才大幅降低了 KV Cache。其中 $c^{KV}_t$ 是参数低秩分解中 Down 处理后 Up 处理前的低维向量,而 $k^R_i$ 可视作是 MQA 版本的 RoPE。
具体参见下图。
1.4 架构图 & 流程
作为对比,下图给出了 MHA 的数学公式,对于每个 token 需要缓存2nhdhl 个元素。如果是千问72B,则需要2×80×64。在这里 $q_{t,i}, k_{t,j}, v_{t,j}$ 都是用列向量表示。t 是第 t 个 token,j 是迭代第1到 t 个 token 的序号,i 是迭代 head 的序号。
下图给出了MLA的架构图,以及公式。
图中,黄色区域公式主要是为了计算Q(即Attention中的Q矩阵)。绿色区域主要是为了计算K的位置不敏感部分。紫色区域是计算K的位置敏感部分;灰色是把K聚合起来;红色是计算V。具体流程如下:
- Query 的降维压缩:输入序列中的 t 个 Token $h_t$ 通过一个下投影矩阵 $W^{DQ}$ 压缩为压缩潜在向量 $c^Q_t$(其维度远远小于输入 Token 的维度)。此处对应图上标号37。
- Key 和 Value 的联合压缩:输入序列中的第 t 个 Token $h_t$ 通过一个下投影矩阵 $W^{DKV}$ 压缩为压缩潜在向量 $c^{KV}_t$(其维度 dc 远远小于输入 Token 的维度 d)。在推理阶段,MLA 仅需要缓存 $c^{KV}_t$,即 KV 缓存仅 $dc×l$ 个元素,其中 $l$ 为模型层数。此处对应图上标号41。
- 解耦 RoPE 策略:为提高模型对序列中上下文信息的敏感性,MLA 中应用了解耦旋转位置编码(RoPE)技术。因 RoPE 与低秩 KV 压缩矩阵不兼容,故 MLA 引入额外的查询向量 $q^R_t$ 和共享键向量 $k^R_t$ 来携带 RoPE 信息,避免了 RoPE 与低秩压缩矩阵之间的耦合问题,解决了位置信息与推理效率之间的矛盾。此处大致对应图上标号39和标号43。
- 恢复信息:进行注意力计算时,进行注意力计算时,$c^{KV}_t$ 分别通过上投影矩阵 $W^{UK}$ 和 $W^{UV}$ 还原出键和值,此处对应图上标号42和45。每个注意力头上的键再与携带了 RoPE 信息的共享键向量 $k^R_t$ 拼接形成 MHA 的键值输入,此处对应图上标号44。$c^Q_t$ 通过上投影矩阵 $W^{UQ}$ 和 $W^{QR}$ 升维还原并生成查询向量 $q^C_t$(对应图上标号38)和携带 RoPE 信息的查询向量 $q^R_t$(对应图上标号39),二者拼接形成 MHA 的查询向量输入,此处对应图上标号40。
- 注意力计算。此处对应图上标号46。
- 最终多个头的输入拼接在一起,并经过线性映射 $W^O$ 得到最终的输出。此处对应图上标号47。
从图上可以看出MLA的特色:
从定性角度看,可以节约内存,因为:
- 在进入标准 MHA 算法之前,用压缩的向量来替代之前更长的 KV 向量。之前是缓存 K 和 V 两个向量,现在只存储压缩后的一个向量。
- 不仅仅压缩了KV,而且还能重建成K和V(不是标准MHA下面的K和V)。
如果定量来可看,每个 Transformer 层,只缓存了上述公式蓝框的向量: $c^{KV}_t$ 和 $k^R_t$ ,其它的都可以利用“矩阵吸收”,重新恢复过来。 $c^{KV}_t$ 和 $k^R_t$ 这两个向量的大小分别为:
$c^{KV}_t$ : 维度为 $d_c=4×d_h$。$d_h$ 是单个头的向量维度。 $c^{KV}_t$ 是多头共享的。
$k^R_t$ :维度为 $d^R_h=d_h/2$。$k^R_t$ 是多头共享的。
对比 MQA(每层有一个 $d_h$ 维度的 𝑘 和一个 $d_h$ 维度的 𝑣 ,共 $2d_h$ 个元素),MLA 相当于增加了2.25倍的存储。对比 MHA 的 $2n_hd_h$,则 $n_h$ 会大于2.25,所以肯定减少缓存。
1.5 代码
下图给出了DeepSeek V3源码中MLA的定义部分。
|
|
很明显,MLA 算子是针对现代 GPU 硬件特点“量体裁衣”定制的一个注意力机制,通过对存储和计算的再平衡,能够充分发挥现代 GPU 的各项优势。我们接下来就对 MLA 的几个核心实现要点进行仔细分析。
0x02 核心要点
MLA的核心要点如下:
- 通过低秩KV联合压缩(Low-Rank Key-Value Joint Compression)降低了KV Cache的资源占用。在计算注意力时,对压缩后的向量进行升维变换,进而增强模型的表达能力。
- 通过权重吸收减少了向上投影的计算量。
- 通过解耦RoPE策略(Decoupled Rotary Position Embedding)来解决RoPE和权重吸收的冲突。
2.1 低秩KV联合压缩
2.1.1 低秩分解
低秩矩阵分解(Low-Rank Matrix Factorization)是一种特别有效的矩阵分解方法,用于发现数据中的低维结构。低秩矩阵分解的核心思想是将一个大矩阵分解为两个或多个更小、更简单的矩阵的乘积,这些小矩阵通常具有更低的秩。
在神经网络层中使用低秩分解一般都是用内存成本换取计算成本,这种方法的变体在LoRA微调等场景中很受欢迎,因为这些场景受限于总内存成本,而不是计算开销或推理速度。其好处是压缩后的矩阵使用的参数更少,并且在某种程度上更具表现力(层数增多)。它们最终可以大致近似或等同于一个更大的矩阵,因此在理论上,我们可以将这些矩阵的权重相乘,以恢复原始矩阵的近似值。
其缺点是,我们现在每次使用这些矩阵时都必须执行两次操作(即,对于每个压缩和解压缩的层,我们将矩阵乘法的总数翻倍,以换取使它们变得更小)。并且因为将它们限制为秩r或更低的矩阵,显然会损失原始矩阵的一部分表示能力。
2.1.2 思路
传统的注意力机制直接将输入X映射到QKV的注意力头维度,MQA和GQA通过共享机制来变相压缩KV Cache的头维度。MLA的核心思想是采用类似LoRA的方式表示KV。具体而言,是在prefill期间构建一个压缩空间,对输入矩阵的HiddenSize 维度进行压缩。即先将输入X映射到隐向量c存储起来。简单理解就是,假设有个矩阵的维度是𝑛∗𝑛,那么可以将其分解为两个𝑛∗𝑑的矩阵相乘,而𝑑≪𝑛。这样就降低了存储量。在decode阶段计算注意力之前,会通过上投影矩阵将c恢复到QKV的原始维度,这样可以减少注意力键(keys)和值(values)在推理过程中的缓存,从而提高推理效率。
其实这里还有一个问题:按照这种低秩方案,传统意义上的WQ,WK,WV全部变成了低秩矩阵。既然存在了低秩矩阵对满秩矩阵的替换,就可能存在性能问题。既然DeepSeek做了替换且效果不错,就说明WQ,WK,WV这几个原来的满秩矩阵可能就是冗余的,具备较大的低秩特性。
在实现过程中,Q、K 和 V 的权重矩阵通常会进行融合以提高 GPU 的计算和内存效率。与分别进行独立投影不同,采用组合权重矩阵可以优化运算过程。
2.1.3 向下投影
上图给出了向下投影的具体流程,其中 ht 作为输入向量, WDKV 和WDQ为压缩矩阵,用于降维, cKVt 和 cQt 分别是压缩后的KV潜向量和Q潜向量(潜向量的维度远小于输入向量的维度与自注意力头数之积)。这个 cKVt 是和具体的哪个head(索引为i)无关的,需要被缓存,相当于说,我们不再直接缓存key/value这两个维度和ht 一样的向量,而是缓存 cKVt ,并通过计算来动态的恢复 kt和 vt 。
- 对于KV,构建一个共享的降维映射矩阵WDKV用来对模型输入进行降维。
- WDKV会将输入ht(hidden state)投射到隐向量cKVt,这是key和value的联合隐向量。即将一个 Shape 为 [S,H] 的 HiddenState 压缩到 Shape 为 [S,dc],其中 cKVt的维度dc远小于多头key和value的原始维度dh。MLA 不保留完整的隐藏维度,而是缩小了它们的尺寸。
- 将压缩后的KV向量作为 KVCache 存储到显存中。推理的过程中只需要缓存每一层的隐向量cKVt(因为每一层的注意力头共享该参数)。由于cKVt的维度远小于K、V。因此在MLA中,每一步token推理产生的KV Cache参数由之前的2nhdhl变成dcl,从而大大减少 KV 缓存的内存占用。
- 对于Q,使用降维映射矩阵WDQ用来对模型输入进行降维。这与减少KV Cache无关,主要是为了减少训练期间参数量和相应激活所占的显存。这是因为在模型训练的过程中,每个输入的token会通过多头注意力机制生成对应的query、key和value。这些中间数据的维度往往非常高,因此占用的内存量也相应很大。
2.1.4 向上投影
当 Decode 阶段需要进行 MHA 时,会将加载KVCache,然后利用WUK和WUV对cKVt向上投影以恢复更大的尺寸。这个更大的尺寸既可以与原始输入 ht的维度匹配,也可以根据注意力头的配置来调整。DeepSeek是将KV的维度扩展回d=dhnh,从图上也可知,新的kCt,vCt分别被均分为nh个向量,即每个注意力头有一个单独的 𝑘,𝑣 (跟MHA的KV数量一致)。
具体参见下图。 WUK 和 WUV WUQ均为投影矩阵,用于升维。注:此处忽略了RoPE,后续会结合RoPE再进行扩充和更新。
结合向下投影和向上投影,我们可以看到, WQ,WK,WV 的矩阵实际上分别被拆分成了两个,做成了lora的形式进行信息压缩,这个形式下MLA就是MQA加上lora形式的扩展,并且计算量从dxd的复杂度减少为 2 x d x c。这种信息压缩后再恢复原维度的方式相比于之前只有一个矩阵的形式,能很好的帮助网络进一步学习到更有效的信息。实现了同样的低秩分解下更好的效果,这就是MLA比GQA更进一步压缩KV Cache的根本原因。
下图给出了如何拆分,上方和MLA,下方是作为比对的MQA。
实际上,论文“TransMLA: Multi-Head Latent Attention Is All You Need"就对MLA的表达能力做了相关分析。论文指出传统的GQA模型在计算注意力的时候,同一组里头的头都会共享相同的键值对,这就导致它在表达能力上有点受限。而MLA就不一样啦,它通过低秩分解,再加上独特的投影矩阵设计,突破了这个限制。
具体参加下图,在MLA里,就WbK拿来说,如果这里面的向量是正交的,那么每个通道在乘以XWak之后,输出在各个通道间都不一样。可GQA呢,同一组里所有头的输出都是相同的。就因为这种结构上的差别,在KV缓存大小一样的情况下,MLA的表达能力更强。说白了,MLA通过调整网络结构,优化参数更新策略,让注意力计算过程更高效,这样就能更好地捕捉复杂的语义关系,提升模型的能力。
2.1.5 完整过程
完整的对比过程如下图。图中上方是总体思路。下方是MLA和GQA的对比,其中又分为两部分,上部分是通过公式看看MLA如何增强表现力;下半部分是完整的流程。
2.2 权重吸收
2.2.1 当前状态
我们目前已经通过向下投影将压缩的隐向量进行保存,这减少了KV Cache的内存占用。也通过向上投影矩阵增强了表达能力。然而,MLA强调的是激活值也随之减少。当前我们还看不出来怎么减少激活值的数量的。因为虽然压缩之后的KV占据内存比较少,但是在每次推理的时候,都必须通过 WUK,WUV 来根据缓存的cKVt重新计算出 kt,i,vt,i,单从KV的数量和维度上看跟MHA是一个量级,比GQA和MQA都要多,上采样后的 kv cache巨大,可能导致OOM。不但内存不少(kt,i,vt,i依然存在),还引入了新的计算量,会处于计算瓶颈。没有达到用时间和算力来换取空间的目的。
2.2.2 权重吸收
既然每次计算量太大,DeepSeek就想是否可以在保存压缩的隐向量的基础上来减少这个计算量(其实也减少了新KV的内存占用),于是他们给出了权重吸收这个法宝。即其作者利用矩阵结合律的特性对这些公式进行了优化,避免了针对每个query重新计算key与value,下面是文章中的原文:
备注:矩阵吸收计算是指利用矩阵乘法的结合律或低秩分解等线性代数技巧,改变矩阵的乘法顺序,重新组合某些矩阵因子,使原本需要独立计算的矩阵乘积合并在一起,避免生成大矩阵,从而降低计算复杂度和内存开销的过程。
比如,给定三个矩阵 A∈Rm,k, B∈Rk,p, C∈Rp,n,通过矩阵乘法的可知(A×B)×C=A×(B×C),但是二者的计算复杂度是不一样的。 (A×B)×C的计算复杂度是 2×m×k×p+2×m×p×n=2×m×p×(k+n), A×(B×C) 的计算复杂度是2×m×k×n+2×k×p×n=2×n×k×(m+p)。当 n 相比 m 和 p 都显著更小的时候,第二种计算顺序的性能会远好于第一种。假设 ,m=k=p=4096,n=1 ,那么第一种计算顺序的计算复杂度是 2×4096×4096×4097,第二种方式的计算复杂度是 2×1×4096×8192,显著低于第一种。
但是,具体要如何用矩阵吸收,如何使用矩阵乘法结合律,需要权衡计算量,memory读写量和瓶颈,可以套用典型的Roofline Model进行分析。这里的核心就是 AC x CB 矩阵的最终效果和 AB 矩阵的效果有多少差异。
2.2.3 推导
KQ合并
我们来结合Dot-Attention的具体形式,看看如何通过一个简单但不失巧妙的恒等变换来规避这个问题。首先,在训练阶段还是照常进行,此时优化空间不大;然后,在推理阶段,我们利用如下公式(不带位置编码)可以看到,在推理阶段,我们把WUQ⊤WUK 合并为一个(和位置无关的)矩阵W作为Q的投影矩阵,就可以用ct代替原本的kt。这样就避免了重复计算中间结果q和k。
其中转置 ⊤ 表示交换张量形状中最后两个维度。各个张量的形状如下,这里注意 num_heads 要拎出来成为一个维度,因为最后 attention weight 的结果是头间独立的。
-
CQ:[batch_size,1,q_len,q_lora_rank]。
-
WUQ:[num_heads,q_lora_rank,qk_nope_head_num]。
-
WUK:[num_heads,kv_lora_rank,qk_nope_head_num]。
-
CK:[batch_size,1,kv_len,kv_lora_rank]
我们每次缓存的 cKVt都可以直接参与计算,而不需要显式的计算出K。而且,W 矩阵是可以事先就通过WUQ⊤WUK计算出来的(其实就是被神经网络自动计算出来)。
代码表述如下:
|
|
VO合并
另外,传统方法需要先计算 Value 向量 vCt ,然后再进行注意力计算并投影到最终的输出层。我们可以直接将 WUV吸收到 WO里,简化最终的输出计算。吸收公式如下(此处提取剧透了rope和nope分离的模式):
(p⋅(ckv⋅WUV))⋅WO=(p⋅ckv)⋅(WUV⋅WO)=(softmax(qnope⋅ckv+qpe⋅kpe)⋅ckv)⋅WUV⋅WO
可以用代码描述为:
|
|
注意我们需要小心的通过转置等手段保证数学上的恒等。参见下面图,每个注意力头都可以消融成一个矩阵,因此,实际代码中可以使用高维矩阵将所有head消融在一个矩阵里,代码表述见下面。
代码表述:
|
|
结合
把目前的合并结合起来,我们得到如下:
O=AWO=ϕ(QKT)VWO=ϕ[HWQ(CKVWUK)T]CKVWUVWO=ϕ[H(WQWUKTCKVT]CKV(WUVWO)
这样,在推理时期WUK可以和WUQ.WDQ结合,WUV和WO结合,最终只有WQ和WO。矩阵合并以后,对KV的整个计算过程都在低维空间进行,不会出现再把CKV解压缩回高维空间的情况。 而且,上述矩阵全都是模型的权重,再推理过程重是不会变的,可以看作常量。如果是部署推理服务的话,再加载模型的时候就可以把这两个矩阵乘好,为以后的每次推理节省两次矩阵乘法。实际上并无额外的算力开销。MLA就达到了克服以往方法中KV Cache过大的问题并且保留的KV Cache该有的减少重复计算的功能。
2.2.4 讨论
训练
论文中一直提到在推理阶段使用权重吸收,这点很好理解,因为此时权重矩阵固定了。
那么什么不在训练阶段直接结合WUK和WUV,其原因大致如下:
- 从梯度更新的角度来看,不做权重吸收会使得优化更加简单,即遵从下面的方式进行训练更好∇(ϕψ)=ψ∇(ϕ)+ϕ∇(ψ)。
- 从投影的角度来看,KV共享WDKV某种意义上对于空间构成了一种约束,Weight Tying 使得模型能够更好的收敛,并且提高其泛化能力,还可以提高模型的稳定性。
所以,MLA在训练阶段和MHA类似。除了多一步低秩投影以及只在部分维度加RoPE之外,MLA与Q、K的头维度由dk 换成dk+dR的MHA一样。
MHA
其次,既然权重吸收这么好,为什么MHA没有做权重吸收?
我们先看看推理阶段的特点。
首先,MHA中的计算公式如下(为了演示方便,这里先讨论单头),在标准的MHA实现中,quey、key、value的embedding是分别计算的,然后通过query embedding和key embedding来计算self-attention的权重矩阵,之后将这个权重矩阵和value embedding进行相乘得到最终的结果。但是如果我们展开公式如下。
Z=softmax(qTtki√dk)vtWO=softmax(hTt(WQ)TWKhi√dk)hiWVWO
此处看起来,(WQ)TWK和WVWO都有吸收的可能。
其次,Decode 计算时,输入的 Q 往往只有一个 token,这就天然给我们一个简化计算的机会。即这个顺序是可以交换的,即从query的embedding出发,一直向下进行计算,得到最终的结果。因为首先将比较小的query embedding参与计算,因此看起来整体计算复杂度会明显降低。而且看起来和MLA的思路非常类似,即将
K 的 projection 放到 Q projection 之后,将 V projection 放到 attention 之后,output projection 之前。
目前看起来MHA做矩阵吸收的好处颇多。然而,事实并非如下简单。我们通过qTtki 为例来进行分析为何MHA不适合吸收,以及为何MLA可以提高效率。
qTtki=(WQht)T(WKhi)=hTt(WQ)TWKhi
对于单个头,nh=1,对应矩阵乘是[1,d]×[d,dh]×[dh,d]×[d,1]。我们来看看这个矩阵乘哪些可以计算,哪些可以存储。有以下几种可能:
- 标准KV Cache。
- 存储角度:我们把WKhi存储起来,就是存储k(v和k一致),则KV Cache大小为:2nhdhl.
- 计算角度:每个头实例化参数是WQ,WK,WV,WO,大小为4ddh。
- 把(WQ)TWK结合到一起,并把结合后的权重施加到x上
- 存储角度:存储(WQ)TWKhi作为新的cache,其大小为2dnhl,与KV Cache相比扩大了nh倍。
- 计算角度:每个头实例化参数是(WQ)TWK 和 WVWO。大小为2d2。
- 把(WQ)TWK结合到一起,但是只cache x,不cache k和v的权重。
- 存储角度,需要存储的cache大小是dl,相比标准kv cache减少了一半;
- 计算角度,每个头实例化的参数为(WQ)TWK 和 WVWO。大小为 2d2
结合上面的分析,标准的kv cache已经相对而言在空间开销上和计算上是最优的了,尽管我们可以通过只 cache x减少一半的kv cache,但是结合后的矩阵放到运行时计算也增加了计算量,权衡下并不是好的方案。
我们再来看看MLA。WK做了低秩变化后,从[dh,d]变成了[dh,r]×[r,d], $ h_tT(WQ)TWKh_i变成了 h_tT(WQ)TWW^{DKV}h_i$。
对应矩阵乘是[1,d]×[d,dh]×[dh,dc]×[dc,d]×[d,1]。我们来看看这个矩阵乘哪些可以计算,哪些可以存储。 ,那么有以下几种可能:
- 从存储的角度:此时存储的kv cache就是 WDKVhi, cache大小是 dcl ,加上旋转位置编码的部分,总的kv cache是(dc+dRh)l ,和MHA进行比较,则是(dc+dRh)/2d =(512+64)/(2∗5120) =5.58%
- 从计算的角度: WUK 可以被合并(merge)到 WQ 中,类似地,WUV 可以被合并(merge)到 WO中。这样实例化的权重就变成了原来的 d/r 分之一
- 无论是存储还是计算的角度,MLA的拆分方法都优于MHA。
所以到这里我们就明白了,MLA的好处来源于两个方面,一个是kv cache的显著降低,另一个是权重的合并和吸收。
不合并
具体实施过程中需要依据实际情况进行抉择,比如 李伟华大神 在https://developnotes.readthedocs.io/zh-cn/latest/deepseek.html#id1 有精彩论述。
考虑如下运算:Y=XAB,C=AB。其中X∈Rm×d是输入的hidden states,A∈Rd×dc和B∈Rdc×n是权重矩阵,C∈Rd×n是吸收后的矩阵。
直接计算Y=XAB的flops是 2mddc+2mndc=2mdc(d+n),合并后计算C=AB的flops是2mdn。如果dc较小,则dn>dc(d+n),计算量太大,所以不一定需要进行权重吸收。
或者我们使用MLA的实际代码来看。已知配置如下:
|
|
两种情况的计算量如下:
- cQt⊤WUQ⊤WUK的计算量是:2×(q_lora_rank×hidden_size×qk_nope_head_dim+kv_lora_rank×hidden_size×qk_nope_head_dim)=2×hidden_size×qk_nope_head_dim(q_lora_rank+kv_lora_rank)=2×5120×128(1536+512)
- cQt⊤WUQK 的计算量是:2×hiddensize×qlorarank×kvlorarank=2×5120×1536×512。
可以看到,把WUQWUK合并后计算量反而增大很多。prefill 的时候其实是不要做“吸收”的,可以按 $ ({c_tQ}{W{UQ}} )(W^{UK} c_t^{KV})或者 ({c_tQ}{W{UQ}} W^{UK} )c_t^{KV}$来计算。
因此,他认为,Absorb 的真实含义其实是矩阵乘法结合律,优先结合某些矩阵,并缓存 compressed latent vector cKVt, 并不是合并权重矩阵,用 Absorb 命名有一定误导性。如果吸收,也是WUK被吸收到QC,而非WUQ。
2.3 解耦RoPE
为提高模型对序列中上下文信息的敏感性,MLA中应用了解耦旋转位置编码(RoPE)技术。而迄今为止,我们在分析中丢失了一个非常重要的步骤,即位置编码。这是因为RoPE与低秩KV压缩矩阵不兼容(与权重吸收会冲突),此时还无法无缝切换。为了解决这个问题,MLA引入额外的查询向量qRt和共享键向量kRt来携带RoPE信息。从架构图中可以发现,DeepSeek的q和k各自都有2个部分,分别是[qRt,qCt]和[kRt,kCt]。
- 1个部分是压缩部分:[qCt]和[kCt]。
- 1个部分则加上了RoPE位置编码。即有独立一路做RoPE:[qRt]和[kRt]
最终两个部分拼接成Q,K矩阵。这样就把RoPE与低秩压缩矩阵之间做了解耦,解决了位置信息与推理效率之间的矛盾。
我们接下来仔细进行剖析。
2.3.1 RoPE背景
下面代码是Llama 3计算注意力的摘要。RoPE 旋转位置编码中Query和Key都是位置相关的。在进行注意力计算前,代码是先应用WK等矩阵得到Q和K,然后在Q和K上施加RoPE(乘以一个旋转矩阵),以此在Q和K中融入相对位置信息。
|
|
2.3.2 问题
无法直接应用到低秩压缩
我们先看看是否可以把RoPE 施加到低秩压缩向量上,即RoPE直接被低秩压缩向量K和V所吸收。
因为K和V的低秩表示已经是压缩了的状态,压缩操作可能已经丢失了某些信息,而RoPE矩阵对key和value是位置敏感的,直接在cQt 和 cKVt 上应用 Rm 和 Rn 不再等价于在完整的Q和K上应用位置编码,不能直接和有效地反映原始Q和K的相对位置关系。换言之,RoPE与低秩KV压缩不兼容(RoPE is incompatible with low-rank KV compression),只能作用到原始K和V上。即只能从低秩KV压缩先还原成原始的KV,然后在原始KV上施加RoPE。之前已经学习过,这样做对性能有损失,所以采用了权重吸收。
与权重吸收不兼容
我们仔细看看RoPE作用到原始K和V上时,是否可以被权重吸收。
在RoPE的实现中,如果我们要让Q、K带上位置信息,会分别乘以相应的位置编码矩阵。
^Q=RmQ^K=RnK
如果计算QTK时,就变成了
S=QTRTmRnK
DeepSeek-V2对Q和K都进行了压缩,则整个过程变成:
S=(WUQcQt)TRTmRnWUKcKVt=cQt⊤WUQ⊤RTmRnWUKcKVt=cQt⊤WUQ⊤Rm−nWUKcKVt
这里,WUQ 和 WUK 分别是用于从低秩表示恢复到原始维度的解压缩矩阵。目前公式中间多了一个与token位置差t-i相关的矩阵Rm−n,该矩阵随着相对位置变化而变化,并不是个固定矩阵,无法提前计算好。并且矩阵乘法不遵循交换律,没办法把Rm−n挪到公式的其它地方,因此在推理时,WUQ 和 WUK 无法直接进行交互,WUK 就无法整合到 WQ 中。即WUK 和 WQ 无法合并为一个固定的投影矩阵。如果要强行降低KV Cache,则必须将参数簇Rsm−n,s=1,2,…,headnum全部全部缓存下来。这个参数簇包含了O(sequence_length2)个参数张量,实在太大。因此,这就导致DeepSeek-V2原定的权重吸收无法实现,在推理过程中需要对所有前置tokens对应的Key进行旋转位置编码的计算,这会降低推理速度。
下图给出了更加精确的阐述,上方是NoPE,下方是RoPE。
2.3.3 解决方案
为了解决MLA中的RoPE与低秩KV联合压缩不兼容的问题,DeepSeek团队提出了解耦RoPE的策略:对于一个head,用一个高维度的向量表示其文本信息,以及一个低维度的向量来表示其旋转位置编码信息。前面的高维度向量称为nope,后面的低维度向量称为rope。具体而言是,把Query和Key进行拆分为[qRt,qCt]和[kRt,kCt],其中一部分小向量进行了旋转位置编码( qRt,kRt ),一部分大向量进行压缩( qCt,kCt)。
- 信息存储部分( qCt,kCt)。这部分存储了大部分的业务信息,是被压缩的。下图的红圈和紫圈表明,我们有nh个注意力头,因此,我们需要把qCt,kCt𝑡分别均分为nh份。下标 i 表示的是第 i 个头。
- 位置信息部分( qRt,kRt )。具体又分为两部分。
- 使用共享的键(shared keys)kRt∈RdRh 来携带RoPE信息,dRh 表示解耦的queries和key的一个head的维度。共享的kRt指的是每个头的K都用这同一个kRt。注意,此处是基于 ht(输入嵌入)而不是基于向下投影的 CKVt 来生成kRt。
- 使用额外的多头查询(multi-head queries) qRt,i∈RdRh 来携带RoPE位置信息。注意,此处是基于cQt生成qRt,而且每个头会有自己的qRt,i。
最后将这四个变量分别拼接起来进行注意力计算。从而在推理时不需要对Key进行位置编码的计算,避免了RoPE与低秩压缩矩阵之间的耦合问题,解决了位置信息与推理效率之间的矛盾,提高了推理效率。具体参见下图。
最终乘积计算如图中标号4.1,其中前一项(标号4.2)按照无RoPE的情况计算,推理时只需要缓存cKVt,后者(标号4.3)则对于所有注意力头只缓存一个共享kRt。即,在推理阶段,单个Token产生的KV Cache包含了两个部分。
- 需要缓存键值的压缩潜在向量cKVt(维度为dc)。
- 携带RoPE信息的共享键向量kRt(维度为dRh)。
一共是(dc+dRh)l 个元素,l是层数。这种折中的方法保证了KV Cache的显存空间依然很小(虽然在 𝑑𝑐 的基础上增加了64维的 𝑑𝑟 ),FLOPS上有增加但是代价不大。
经过Concat过程会增加 Q 和 K 向量的维度。为了处理增加的维度,模型可以选择:
- 增加注意力头的数量:这将保持原有的每头维度,但需要更多的计算资源。
- 调整每个头的处理维度:保持头的数量不变,但提高每个头的维度,以适应Concat向量。
下图给出了清晰的对比。进行注意力计算时,cKVt分别通过上投影矩阵WUK和WUV还原出键和值,每个注意力头上的键再与携带了RoPE信息的共享键向量kRt拼接形成MHA的键值输入。cQt通过上投影矩阵WUQ和WUR还原并生成查询向量qCt和携带RoPE信息的查询向量 qRt,二者拼接形成MHA的查询向量输入。最终多个头的输入拼接在一起,并经过线性映射WO得到最终的输出。
2.3.5 和权重吸收结合
我们再看看结合权重吸收之后如何处理,这里就需要将nope和rope也加进来,公式演变如下。
2.4 资源占用
2.4.1 参数量
MLA的思路来自LoRA,LoRA强调的是参数量的减少,而MLA也确实做到了减少参数量。按DeepSeek-V3的参数配置,两个低秩矩阵参数量: 2×dc×d=2×512×7168 ,而正常MHA的参数矩阵参数量: d×d=7168×7168 。
具体参数如下:
|
|
各个矩阵的参数量如下:
-
WDKV:dim * kv_lora_rank = 7168 * 512
-
WUK:kv_lora_rank * qk_rope_head_dim * n_heads = 512 * 128 * 128
-
WUV:kv_lora_rank * qk_nope_head_dim * n_heads = 512 * 128 * 128
-
WKR: dim * qk_rope_head_dim = 7168 * 64
-
WDQ:dim * q_lora_rank = 7168 * 1536
-
WUQ: q_lora_rank * qk_nope_head_dim * n_heads = 1536 * 128 * 128
-
WQR:q_lora_rank * qk_rope_head_dim * n_heads = 1536 * 64 * 128
-
WO:n_heads * v_head_dim * hidden_size = 128 * 128 * 7168。
2.4.2 内存占用
但MLA强调的是KV-cache的减少,也就是KV的激活值减少。我们接下来继续分析。与经典的MHA和GQA,MQA比较。MLA实际缓存的向量是:
- cKVt,维度是dc。
- kRt,维度是dh/2。
如下图所示,我们可以看出,MLA在优化kv cache和保证模型效果上有很强的优越性。图中nh是注意力头数量,ng是GQA分组数,dh是隐藏层维度(低秩压缩后的维度),dc是KV压缩维度,l为block的块数。和MHA相比,Q和K的头维度变成了dc+dr,V的头维度变成了dc,对于DeepSeek-V2,dc 被设置为4dh,而dRh被设置为dh2。KV Cache的数量以元素数量来衡量(不考虑存储精度)。
- 在MHA中,推理阶段针对每个Token,需要缓存其键向量和值向量,则每个Token的缓存参数个数为2×nh×dn×l。与MHA相比,MLA占用的token数92dhl 通常要小于2nhdhl,所以MLA能获得比 MHA 更强的性能,显著降低了KV缓存的大小。
- GQA 通过分组共享 K/V 矩阵(如 LLaMA-70B 设置 g=8)减少显存占用,但压缩率有限(仅减少到 g/h 倍)。与GQA相比,MLA相当于GQA中的组数量 𝑛𝑔 =2.25,小于大多数Model里的 group数量,由此可见,其kv cache的尺寸会大大减小。即,MLA 的 KVCache 存储成本约等于GroupNum=2.25 的 GQA 的 KVCache 存储成本。
- 与MQA相比,MLA相当于增加了2.25倍的存储,但是MLA的性能和效果显著优于MQA,甚至强于MHA和GQA,真正实现了即降低推理成本,又保证了模型性能。
2.4.3 计算量
和MHA相比,MLA的Q和K的头维度变成了dc+dRh,V的头维度变成了dc。而 DeepSeek V3的一些超参数如下:
- dk(hidden dimension/模型维度):7168。
- nh(注意力头数):128。因为MLA的KV Cache大小跟nh无关,增大nh只会增加计算量和提升模型能力,但不会增加KV Cache。
- dh(每个注意力头的维度):128。
- dc(KV的压缩维度):512,即4dh。
- dRh(RoPE头相关维度):64,即dh2。
既然MLA每个头的Q/K的head size变大了不小,所以MLA的推理计算量增加了。那为什么还能提高推理效率呢?其实,MLA可以提高效率是因为结合了LLM推理的瓶颈时访存而不是计算这一特性。我们可以将LLM的推理分两部分:第一个Token的生成(Prefill)和后续每个Token的生成(Generation),Prefill阶段涉及到对输入所有Token的并行计算,然后把对应的KV Cache存下来,这部分对于计算、带宽和显存都是瓶颈,MLA虽然增大了计算量,但KV Cache的减少也降低了显存和带宽的压力。Generation阶段由于每步只计算一个Token,实际上它更多的是带宽瓶颈和显存瓶颈,因此MLA的引入理论上能明显提高Generation的速度。另一方面,由于Compressed KV在每个head中都参与了计算,DeepSeek-V2的128个heads能够提供足够的计算强度(正比于 Head 数),这样就把 LLM 解码过程的访存密集型,转换为计算密集型的操作,因此Attention部分的MFU也得到了大幅提高。
我们假设 q 的形状是(b,nh,sq,dh),cKV的形状是(b,1,skv,dc),WUK 的形状是(dc,nh,dh)。prefill阶段,sq=skv=s。
- native的计算量是:2bsdcdhnh+2bnhssdh=2bnhdhs(dc+s)。
- 吸收后的计算量是:2bsdcdhnh+2bnhssdc=2bnhdcs(dh+s)。
两者相比是:(dh(dc+s))/(dc(dh+s))。
decode阶段,sq=1,skv=s。
- 缓存K的计算量。2bdcdhnh+2bnhsdh=2bnhdh(dc+s)。
- 缓存潜向量时候的计算量。2bsdcdhnh+2bnhsdh=2bnhdh(dcs+s)。
- 吸收后的计算量。2bdcdhnh+2bnhsdc=2bnhdc(dh+s)。
2.4.4 信息转移
有研究人员再读MLA,还有多少细节是你不知道的认为,MLA的作用其实是"信息转移“,即把KV头中独有的信息转移到对应的Q头上,而把KV头中间共享的相同信息存储到KV Cache中。具体思路如下:
- 改进目的:在尽量不压缩head上K、V信息的情况下,节省kv cache。
- 改进背景:之所以要保存token对应的所有注意头上的K、V值,是因为每个k_head附带有不同的信息,它将用这份独有的信息和对应的q_head进行注意力计算。
- 改进思路(下面以K头为例,V头类似):
- 把一个token中所有K头中的共有信息抽取出来,压缩到KV Cache中,因为这些共有信息会更少,只保存它们才能减少KV Cache的大小。这个相同信息是每个tokens的所有k_heads共享一份,同时在不同tokens间共享。
- 把K中每个头上独有的信息转移到对应的Q头上。因为Q头需要承载更多信息,所以Q和K的头维度变成了dc+dr,dc 被设置为4dh。在V3上,dc是512,相当于把缓存7168维的向量降低到了缓存512维。而q压缩之后是1536维,之所以这么大,就是因为Q要承载更多的信息。
虽然从形式上来说,MLA和MQA/GQA很像,似乎都是通过压缩k/v_heads的数量来节省KV cache大小的。但MLA是压缩num_heads,不压缩信息(把信息转移到了q_heads上);而MQA/GQA则在一定程度上对信息做了压缩。具体这些相同信息、相异信息存储在何处?是在WK矩阵中?还是存储在原始token ht中?笔者目前不能确定。所以只能用下图展示。
另外,GQA 的分组数需严格匹配硬件规模(如 8 卡对应 g=8),限制了模型部署的灵活性。而 MLA 通过潜在空间投影和解耦式权重合并,可动态适配不同硬件配置(如单卡或多机集群)。GQA 为弥补性能损失需增大 FFN 层规模(如 LLaMA3-70B 的 FFN 参数量增加 20%),导致模型复杂度上升。MLA 则通过低秩投影和动态路由,无需额外补偿即可维持性能。
2.5 并行
在大模型推理的decode阶段,MLA无法使用张量并行。故在目前的一些开源实现中,主要还是基于数据并行来对MLA进行处理,即不同请求的KVCache存储到不同的GPU中。DeepSeek-V3论文提到使用张量并行和序列并行。
- 张量并行:MHA通常对head_num维度进行切分来实现张量并行。而MLA则有自己的特点,如果采用 tp 并行时,部分权重和 kvcache 都无法按 head_num 划分到不同的卡上。
- 使用张量并行部分:kv_b_proj、o_proj等模块都包括了head维度,因此可以按照head维度切分执行张量并行,将MLA计算均匀的划分到多卡上,实现并行加速。
- 难以使用张量并行部分。
- mla存储KV Cache时,对于一个token存储的是(1, 1, kv_lora_rank+qk_rope_head_dim),而不是常规MHA下的(1, kv_head_num, head_dim)。因此KVCache中只保存一份潜空间的压缩向量,并不包含head维度,没有办法按照head进行划分。导致每张卡上都要保存所有请求的的完整kvcache,其形状是(bs, 1, seq_len, kv_lora_rank),这意味着KVCache 各个卡的存储是冗余的。
- 部分权重由于head_num=1无法切分到不同的卡上,比如q_a_proj 和kv_a_proj_with_mqa不能按 head_num 切分。只有上投影矩阵才能考虑按列切分和最后输出矩阵按行切分。
- 数据并行。即按照请求切分,不同请求的潜空间的压缩向量存储到不同的GPU中。但是因为不同GPU上的请求长度可能差异很大,这样会导致显存占用不均衡,也会导致不同GPU上计算时间差异较大,进而导致性能最差的GPU拖慢整体进度。
- 序列并行:MLA会用序列并行(Sequence Parallel)来进行辅助。即,对KVCache按照序列维度进行切分,每一张卡上都使用query来做local的attention计算,然后对结果进行规约。
0x03 计算过程
我们来梳理下MLA在推理阶段的计算流程。
3.1 公式
首先,我们给出Q、K、V的变换过程对应的公式。后续会按照这个公式来进行解析。
3.2 原始流程
我们将上述公式转换为流程图,图中细节如下:
- 从上到下分为Q、K、V三路。
- Q和K又都细分为两路,“上路”绿色的权重和激活值对应隐向量/低秩部分;“下路”灰色渐变的权重和激活值对应decoupled RoPE。
- K的下路和V路的数据流向有所交错。
- “缓存” 代表在推理阶段会进行缓存的数据,具体分为两部分:
- KV联合隐向量 cKVt。
- 单独施加了RoPE的键$k_t^R 。K路位置编码模块接受的输入还是原始的ℎ_t而不是压缩后的c_t$。
此处假设头数nh为2,矩阵大小并不是完全按照比例缩放。
3.3 吸收
3.3.1 过程
接下来第二步,将论文中所说的权重吸收过程施加进去,得到下图:
-
推理阶段要缓存的东西不变。
-
WUK 吸收进 WUQ 之后。
- Q的上路计算逻辑没有变,但是权重和激活值的形状都有相应的调整。
- K的上路则直接少掉了一处线性映射的计算逻辑,变成了重复拷贝nh份,与K下路类似。
-
WUV吸收进 WO之后。
- V路由线性映射退化为重复拷贝的逻辑。
- 最后输出映射的计算逻辑不变,但是权重和激活值的形状有相应的调整。
-
红色字体公式代表了吸收对应的公式。绿色箭头表示有进一步吸收的可能。
3.3.2 吸收结果
我们对上图进行整理,得到吸收的结果如下。
3.3.3 MQA形式
MLA推理阶段的计算逻辑其实很像一个MQA,我们进行比对下(不考虑 RoPE)。
MQA和MHA的最大区别在于 K,V 是所有 head共享的,因此能够减少KV Cache的显存占用。其中 $$ Q_iTK=HT(W_iQ)TW^KH $$。
对于MLA,单独看 Attention 计算的前一部分,其中$ Q_iTK_i=HT(W{DQ})T(W_i{UQ})TW{UK}_iWH$,令 WQi=(WUKi)TWUQiWDQ,我们有 $$ Q_iTK_i=HT(W_iQ)TW^{DKV}H $$ 。可以看到这一计算公式和 Multi-Query Attention 其实是一样的,都是使用的单独的 Q 和共享的 K(CKV),等价于将single-head的KV重复拷贝若干遍再执行正常的MHA。
区别在于,这里 WQiH,WDKVH∈Rdc×l。也就是说在进行 attention 计算的时候,向量点积的维度是 dc 而不是 d。在论文中实际设置的是 dc=4d。也就是说 Multi-Head Latent Attention 其实是 head dimension 提高到4倍的 Multi-Query Attention。在论文中也提到了在推理的时候 absorb WUK into WUQ,其实就代表了这里的结合方式。因为每个head的维度提高了,所以能够计算出更加复杂的 attention分布,从而相比起 Multi-Query Attention 取得性能提升。相比起直接提高 head dimension,其优点在于所有head的 WDQ,WUQ,WUK的总参数量是 d⋅dc+d⋅dc+d⋅dc=3d⋅dc=12d⋅dh,而所有 head 的 WQ 的参数量是 d⋅dc⋅nh=4d2,节省了参数量。也就是说对 WQ 做了一个低秩分解。
但是这个提升并不是免费午餐,因为 head dimension 提高意味着 attention 的计算量也提高,而 attention 的计算量是 O(l2) 的。为了处理长文本,现在大家一般都倾向于尽可能降低 attention 计算量的常数,而这个方法是会增加常数的。以上分析没有考虑 RoPE,如果考虑 RoPE 的话,每个 head 的维度会从 4d 变成 4.5d,其中4d是没有 positional encoding的,0.5d 是使用 RoPE encoding的。其实 ChatGLM2-6B 中已经使用过类似的做法,即只在一半的 head dimension 上使用 RoPE ,目的是为了把 attention 计算分成位置相关和位置无关的两部分,与性能提升的关系并不大。
了看得更明显,我们可以把图中的一些权重进一步吸收合并,得到下图。
-
Q的计算过程退化为普通multi-head线性映射
- 每个head一部分维度保持不动,对应绿色部分
- 每个head另一部分维度施加RoPE变换,对应红色部分
-
K的计算过程退化为single-head线性映射
- 同样只对部分维度施加RoPE变换。
- 施加后进行重复拷贝(逻辑上如此呈现以便于理解,计算上当然可以优化掉)。
-
V则直接使用K中未经施加RoPE变换的部分,同样重复拷贝。
下图与标准MQA的区别是:
- QK只有部分维度施加RoPE;
- V与未施加RoPE的K共享激活值。
0x04 代码
我们主要使用V2的代码来分析,因为条理更加清晰。也需要注意的是,DeepSeek的代码在很多地方和论文不一致。V2中的DeepseekV2Attention的实现本质上和V3中的native一样,其实并没有节省KV-Cache,V3版本的非native版本是跟论文一致,节省了显存。
4.1 配置
我们摘录一些相关配置信息如下。在 Naive 实现中,512 维的 Latent KV cKV 被映射回对应 128 个 head,每个 head 128 维的 kC 和 vC,然后再拼接上位置向量 kR ,最终形成标准的 q、k、v 输入到标准的 Multi Head Attention 进行 Attetion 计算。另外,代码中也使用了norm,在论文中也有相应提及。
具体配置信息如下。其中:
- 键和值的压缩维度 dc :设置为 512 ,原始嵌入维度 𝑑=5120,比例为 1/10。由于键和值在推理时需要缓存,因此采用较大的压缩比例以显著减少内存开销。
- 查询的压缩维度 d′c :设置为 1536 ,比例为 0.3 。查询在训练时需要频繁计算,因此采用较小的压缩比例以保留更多信息,确保模型性能。
|
|
4.2 定义
给定输入向量ht∈RB×L×5120,其中B为batch size,L为sequence length。
|
|
对应的一些信息如下。把整个计算流程拆成 q_nope, k_nope, k_pe, k_nope 这四个部分就是为了把RoPE进行解耦。两个pe结尾的变量就是用于储存旋转位置编码的信息。Deepseek-V2将kv cache压缩到了同一个小矩阵中,后面再解压缩出来。
|
|
另外,https://github.com/sgl-project/sglang/discussions/3082 这里阐释了为何使用norm。
4.3 操作Q
我们把Q相关的代码都合并在一起进行分析。总的流程是:模型处理上一层计算出的隐藏状态(hidden_size=5120)时,首先会将模型的q压缩到 q_lora_rank 这一维度(设定为1536),再扩展到 q_b_proj 的输出维度(num_heads * q_head_dim),最后切分成 q_pe 和 q_nope 两个部分。
4.3.1 变量定义
MLA 中对 Q 投影矩阵WQ做了一个低秩分解,对应生成 q_a_proj 和 q_b_proj 两个矩阵。
- q_a_proj 大小为 [hidden_size, q_lora_rank] = [5120, 1536],对应公式中的WDQ,用来降维。
- q_b_proj 大小为 [q_lora_rank, num_heads * q_head_dim] = [q_lora_rank, num_attention_heads * (qk_nope_head_dim + qk_rope_head_dim)] = [1536, 128*(128+64)] = [1536, 24576] ,用来升维,对应公式中的 WUQ 和 WQR合并后的大矩阵。因为从公式来看这两个矩阵都需要和cQt计算,所以可以合并矩阵后再进行拆分。对于一个head,用一个128维度的向量表示其文本信息,以及一个64维度的向量来表示其旋转位置编码信息。前面的128维度,称为nope,后面的64维度,称为rope。
|
|
4.3.2 变量操作
在DeepSeek-V2中,Q向量也采用了低秩压缩的方式。
- 首先,将输入向量投影到一个1536维的低维空间:cQt=WDQ,ht∈RB×L×1536。对应论文第37号公式。
- 然后,将其投影到RH×128的多头向量空间上(其中H=128是heads数),得到了Q向量的第一部分:qCt=WUQcQt∈RB×L×H×128。对应第38号公式。
- 再将其投影到RH×64上并使用RoPE嵌入位置信息,得到Q向量的第二部分:qRt=RoPE(WKRht)∈RB×L×H×64。对应第39号公式。每个head有自己的旋转位置编码,每个head之间不共享。
- 将两部分拼接的到最终的Q向量:qt=[qCt,qRt]∈RB×L×H×192。对应第40号公式。
在具体的实现过程中其输入为 hidden_states 向量,对应公式中的 ht。是一个大小为 [batch_Size, sequence_length, hidden_size] 的矩阵,其中 hidden_size 具体为 5120。后续的nope指代非rope。
|
|
4.4 操作KV
我们把KV相关的代码都合并在一起进行分析。对于kv矩阵的设计,模型使用了kv压缩矩阵设计(只有576维),在训练时进行先降维再升维。在模型推理的时候,需要缓存的量变成 compressed_kv,经过 kv_b_proj 升高维度得到 k,v 的计算结果。
4.4.1 变量定义
KV向量和Q向量类似,也做了一个低秩分解,对应生成 kv_a_proj_with_mqa和 kv_b_proj 两个矩阵。
- kv_a_proj_with_mqa 大小为 [hidden_size, kv_lora_rank + qk_rope_head_dim] = [5120, 512 + 64] = [5120, 576],对应上述公式中的WDKV和WKR的合并矩阵,用来把输入先投影到一个低维的空间(对应CKVt),同时做两种降维操作(nope,rope的前置操作)。因为因为从公式来看这两个矩阵都需要和ht计算,所以可以合并矩阵计算后再进行拆分。输出的维度则是512+64=576了。前面的512维度是给kv的,后面的64维度是给key的旋转位置编码的。
- kv_b_proj 大小为 [kv_lora_rank,num_heads * (q_head_dim - qk_rope_head_dim + v_head_dim)] = [512, 128*((128+64)-64+128)] = [512, 32768],对应上述公式中的WUK和WUV的合并矩阵。由于WUK只涉及nope 的部分,所以维度中把 qk_rope_head_dim 去掉了。192-64是把key表示向量中的64维度的旋转位置编码向量从192维度中减去;然后的128维度是留给value的,因为value不需要考虑位置信息。需要考虑位置信息的只有query和key。
或者说,通过kv_a_proj_with_mqa 来对head脱敏,即得到的张量和具体的head无关;通过kv_b_proj来重新恢复成对每个head敏感,得到的是形如[1, 16, 26, 128]这样的,和具体16个head分别相关的张量。
|
|
4.4.2 变量操作
计算KV向量时,有几个和公式中不同的地方,即把某些矩阵操作打包在一起执行(同时将K,V的向量一起产出了),后续再拆分开。
-
首先需要将输入向量投影为512维的联合压缩表示:cKVt=WDKVht∈RB×L×512,对应第41号公式。
-
与Q向量的计算过程类似,K向量的第一部分是将cKVt通过投影解压缩到RH×128的多头向量空间:kCt=WUKcKVt∈RB×L×H×128,对应第42号公式。注意:此处增加了一个头维度。
-
K的第二部分是将输入向量投影到64维向量空间并施加RoPE嵌入位置信息:kRt=RoPE(WKRht)∈RB×L×64,对应第43号公式。
-
与Q不同的是,完整的K是将K的第二部分广播到每个head后与第一部分拼接得到:
kt=⎡⎢ ⎢ ⎢⎣kCt,1kRtkCt,2kRt⋮⋮⎤⎥ ⎥ ⎥⎦∈RB×L×H×192
也就是说,每个head的RoPE部分是完全相同的。此处对应第44号公式。再强调下:对于query,每个head有自己的旋转位置编码向量;key则是所有heads共享同一个旋转位置编码向量。
-
V向量的计算较为简单,直接将cKVt解压缩到RH×128即可:vt=WUVcKVt∈RB×L×H×128,对应第45号公式。
通过维度分析可以看到 kv_lora_rank 是 qk_nope_head_dim 的 4 倍且 K 和 V 共享 latent state,qk_rope_head_dim 只有 qk_nope_head_dim 的一半,结合起来 4+1/2=9/2,是 正式下图中 MLA KVCache per Token 大小的来源。
具体的代码实现如下,可以发现除了在对q做计算时涉及到gemv之外,也就是q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))),其它地方的矩阵乘运算q_len维度都是和num_heads在一起做计算,而num_heads在Deepseek2的配置里面已经是128了,导致其它的Matmul几乎都落在了计算密集的范畴。
|
|
4.5 注意力操作
4.5.1 变量定义
o_proj对应矩阵WO,大小为[num_heads * v_head_dim, hidden_states]=[128 * 128, 5120]。
|
|
4.5.2 变量操作
生成 QKV 向量之后的流程就基本上等同于标准的 MHA 计算了。唯一的区别在于只有 q_pe, k_pe 这两个部分给加上了 rope。具体流程如下:
首先计算attention score:
a=softmax(q⊤tkt+Mask√192)=softmax⎛⎝qCt⊤kCt+qRt⊤kRt+Mask√128+64⎞⎠∈RB×L×H×L
然后计算对V的加权和,并将所有head压平,得到Attention输出:
o=a⋅vt∈RB×L×H×128≅RB×L×16384
最后经过另一个矩阵的投影,就能得到MLA的最终输出:
u=WOo∈RB×L×5120
|
|
4.6 前向传播
我们把完整的前向传播代码摘录如下,大家可以更好的理解。
|
|
对应如下图例。
4.7 V3 代码
我们也给出V3代码具体如下。V3中的 native 版本其实并没有节省KV-Cache(甚至还多了存储),V3版本的非native版本是跟论文一致,节省了显存。
native 版本的实现直观、适合学习,但是不适合Decode阶段,因为Decode阶段需要用到KV Cache。针对KV Cache,native 版本的实现有两种选择:
-
① 缓存 Latent KV。缓存规模小,矩阵运算是(b,nh,1,dc)×(b,1,s,dc),假定是bfloat16精度,内存读取量是2bnhdc+2bsdc=2bdc(nh+s)。但 Latent KV 缓存不能直接送 MHA 计算,还得经过 WUK 和 WUV 的线性映射,这是两个规模不小的矩阵计算,而且每轮都得重复计算。
-
② 缓存 KV。缓存规模大,不用重复计算,性能好。标准MHA (b,nh,1,dh)×(b,nh,s,dh)的内存读取量是2bnhdh+2bnhsdh=2bdhnh(1+s)。但 MLA 的一大好处就是 KV Cache 压缩,这样显存内能缓存更多 token,支持更大的 batch 和 prefix cache。如果缓存 KV,在显存上对比 MHA 就完全没有优势了。
native 版本最终的选择是方案2。所以,Naive 实现可能会用于 Prefill阶段,但在 Decode 计算时需要更好的计算方法,也就是非native版本。在非native版本最核心的 Attention kernel 计算中,“吸收“模式下 K/V tensor Shape 中不携带 num_attn_heads 信息,计算逻辑转换成了类 MQA 计算,“不吸收”模式下 K/V tensor 仍携带 num_attn_heads,就是MHA计算。
|
|
具体比对如下图。
0x05 优化代码
DeepSeek代码并没有给出某些功能的具体方案,比如压缩优化和权重吸收。因此,我们主要以章明星老师给出的方案 https://github.com/madsys-dev/deepseekv2-profile/tree/main DeepSeek-V2 高性能推理 (1):通过矩阵吸收十倍提速 MLA 算子为例进行学习。
5.1 压缩优化
目前V2代码中,Attention中的KV Cache缓存的仍然是全量的key和value(从隐向量又解压缩出来),而并非论文中所说的压缩后的compressed_kv以及k_pe,导致其实没有减少KV Cache的缓存。
主要原因可能是:一方面复用transformers原有的Cache逻辑,方便实验和理解;另一方面这部分应该是训练代码,而推理代码会针对这部分进行优化和改进。
我们可以做如下修改,也将RoPE后的k_pe一并缓存入KV Cache中。
|
|
章明星老师给出了更详尽的方案。
|
|
5.2 权重吸收
在计算MLA的时候,仍然需要存储解压后的完整的KV Cache,这很可能引起OOM崩溃。DeepSeek-V2的论文中提出,可以将KV的解压缩矩阵吸收到Q-projection和Out-projection中,从而可以在不解压缩KV Cache的情况下直接计算最终的Attention结果。
实际上,把权重吸收理解成矩阵乘法交换律更合适。因为实际上是提前将两个参数矩阵乘起来,即把 (WUQ)TWUK 的计算结果做为新的参数矩阵,然后再跟中间张量乘,在性能上不一定比分开计算更好。
下图分别给出了MHA、MLA和权重吸收的MLA的计算示例。最右侧的两个虚线箭头,显示了在优化的计算过程中,哪些参数矩阵被交换了位置。它们能交换的原因,就是从数学上这样修改是等价的(矩阵乘法交换律)。此时,输入注意力机制的 q、k、v 形状发生了明显的变化。q 的形状由[nh×(dh+dRh)]变化成了[nh×(dc+dRh)],k 的形状由 [nh×(dh+dRh)] 变化成了 [nh×(dc+dRh)],v 的形状由 dh 变化成了 dc。这样一来,新的计算过程中只剩下 ① Latent KV 了。原来的 ② KV 就不存在了,变成可以用Latent KV表示。而且实际上 V 也不存在了,因为 V 就是 K 的前 512 维。这其实就是MQA,这实际上就是 FlashMLA 代码库解决的问题。
我们接下来依据章老师的代码和文字来继续学习。
5.2.1 absorbed_cache_compressed.py
与论文不同,此处将代码中 kv_b_proj 中属于 K 的部分权重(论文中对应WUK)吸收进 q_nope(论文中对应 qC,而且是在运行时做,非提前吸收);将代码中 kv_b_proj 中属于 V 的部分权重(论文中对应WUV)吸收进 attn_out。抽象一点的理解就是,将 Q 也映射到 KV 的低秩空间,然后在低秩空间做完整的 Attention,之后再映射回 Q 的原始空间。
WUK
对于K的吸收,在注意力分数的计算公式中,非RoPE部分可以做如下展开:
qCt⊤kCt=(WUQcQt)⊤WUKcKVt=cQt⊤WUQ⊤WUKcKVt=(cQt⊤WUQ⊤WUK)cKVt
也就是说,我们事实上不需要每次都将低维的cKVt展开为kt再计算,而是通过矩阵乘法结合律,直接将 WUK 通过结合律先和左边做乘法,改为计算,避免了解压缩出完整的K矩阵。即将前三者进行计算:
attention_weights=(cQt⊤WUQ⊤WUK)cKVt
此外,在原始版本的解压缩的过程中,由于每个token的key都需要与WUK相乘才能得到,因此计算量较大;矩阵吸收后,WUK只需要对qCt这一个向量相乘,也大大减少了浮点计算量。
|
|
除了压缩KV Cache之外,我们还可以观察到上面涉及到的2个矩阵乘法实际上都来到了计算密集的领域,例如对于 torch.einsum('bhqc,blc->bhql', q_nope, compressed_kv) 。由于不同 head 的 q_nope 部分共享了共同的 compressed_kv 部分,实际计算的是 batch_size 个 [head_num * q_len, kv_lora_rank] 和 [past_len, kv_lora_rank] 的矩阵乘法。计算等价于一个 MQA 操作,计算强度正比于 head_num 的也就是 128。因此相比 MHA,吸收后的 MLA 计算强度要大得多,可以更加充分的利用 GPU 算力。
WUV
对于V的吸收,情况稍微复杂。为表述的清楚性,我们采用Einstein求和约定描述该过程
|
|
5.2.2 Move Elision
不过,这样还不能完全发挥出MLA的威力。在原始代码中,query_states和key_states会通过拼接RoPE和非RoPE部分得到:
|
|
当我们采取了上述优化后,此处的拼接过程会产生大量无用的数据拷贝和广播,同时也会占用大量显存空间导致OOM,而且如果是concat放在框架做,但可能会增加IO,尤其是decode本就是IO瓶颈。而且,先对Latent解压缩再计算,则Attn的计算是一个实打实的Multi Head Attention,会增大计算量。
为此,我们采用MoveElision优化策略,即省略此处的拼接RoPE部分和非RoPE部分的过程,而是直接分别计算量部分的Attention Score并相加(考虑q⊤tkt=qCt⊤kCt+qRt⊤kRt)。即,将 RoPE 部分与 NoPE 部分分别做乘法,然后进行拼接的操作,改为 NoPE 部分 Attention 和 RoPE 部分 Attention 两者结果相加,这样做的好处在于节省了内存搬运操作,这种做法等效于ALiBi。我们具体推导如下。
[q⊤t,ikj,i]=[cQtWUQ⊤,qR⊤t][WUKcKVtkRt]=cQtWUQ⊤WUKcKVt+qR⊤tkRt
具体对应下面代码中的torch.matmul(q_pe, k_pe.transpose(2, 3))这行。即,分开计算了RoPE部分的q和k的注意力计算再求和。标准实现是将加上了 rope 的 q_pe/k_pe 和没加 rope 的 q_nope/k_nope 拼接起来一起。
|
|
代码比对如下:
5.2.3 Materializing Projection Matrices
DeepSeek-V2的论文中说:
不过,似乎并没有必要再改变顺序,对模型参数进行预处理,将WUK与WUQ相乘,以及将WUV与WO相乘。这是因为,WUK与WUQ相乘后的结果可以视为H个大小为1536×512的低秩(不超过128)矩阵,而WUV与WO相乘的结果可以视为H个大小为5120×512的低秩矩阵。相比用这些特别大的低秩矩阵做投影,明显不如按照低秩分解形式依次相乘来得划算。因此,章老师认为这一步的优化并不是很有必要。
因为假设有矩阵 A[m,k],B[k,n],C[n,l],B 和 C 为低秩矩阵,依次相乘 A⋅B⋅C 需要的算力: 2mkn+2mnl=2mn⋅(k+l),而提前合并 D=(B⋅C),A⋅D 需要的算力:2mkl,当 n⋅(k+l)<kl 时,提前合并低秩矩阵,反而会引入更多计算。而在 LoRA 的推理阶段,之所以能这样做,是因为本身就已经存在一个大的 pre-train weight 的矩阵,因此提前做吸收,不会增加计算量。
具体代码如下:
|
|
5.3 融合算子
另外,如果针对prefill和decode阶段进行不同处理,则在推理的时候Prefill 和Decode 走的逻辑不同。
-
推理的时候 Prefill 是不做矩阵吸收的(原因是Prefill做矩阵吸收会增加计算量),MLA计算与普通的MHA计算大致相同,唯一的区别在于需要支持q/k和v/o使用不同的head_dim。
-
Decode 是要做矩阵吸收的,矩阵吸收ops 远小于矩阵不吸收。这是因为此时Q的长度是1,原本重复在KV 上做up projection的操作转移到了Q 上,让Q 投影到kv 的latent space 上,Q的长度远小于KV的长度,不需要对KV做重复做up projection。或者说,MLA的主要思路就是通过交换矩阵计算顺序,利用decode阶段query seq_len比较小的特点,优化矩阵计算开销,进而达到只存储Multi-head attention中hidden states cache,而不是key和value两个cache,进而降低一半KVCache存储的目的。
因此Decode阶段需要单独设计高效的融合算子,以便高效地与低秩kv-cache进行attention计算。
权重吸收之后,公式如下:
(p⋅(ckv⋅WUV))⋅WO=(p⋅ckv)⋅(WUV⋅WO)=(softmax(qnope⋅ckv+qpe⋅kpe)⋅ckv)⋅WUV⋅WO
可以用代码描述如下,即可以设计一个MQA算子来实现。
|
|
FlashAttention最初设计的初衷是减少对softmax矩阵储存的开销,其大小正比于 lq⋅lkv,占整体I/O的比值为:
ratio(softmax)=11+HkvHqoDLqo+DLkv
对于推理阶段而言,lq 其实是非常小的,不融合qk和pv两阶段的计算也能取得不错的效果。但是对于MLA而言,融合是必要的,这是因为:
- MLA有较大的group ratio: Hqo/HKV=128 ,会增大softmax的占比。
- MLA复用了key和value矩阵,因此如果我们不融合两阶段的话,前后两个算子将各自访问一遍KV-Cache,如果硬件的cache不够大的话,带宽利用率将无法超过50%。
5.4 矩阵乘的重排序(增补@2025-04-19)
内容参考:DeepSeek V3推理: MLA与MOE解析 Arthur
具体特点如下:
- 方案来源:SGlang,应用于DeepSeek-V2。
- 方案特点:基于矩阵乘法结合律改变计算顺序,从而优化注意力机制计算效率。在解码阶段,能够有效减少计算量。
- 方案内容:
- 原始计算顺序:qnopeknope+qropekrope。其中qnopeknope的计算方式是qTnope(WUKc)。FLOPs为 2dc−1)hdnk+(2d−1)hnqnk。
- 改进顺序为:(qTnopeWUK)c。FLOPs为(2d−1)hnqdc+(2dc−1)hnqnk。
这种改变利用了矩阵乘法的结合律,使得计算可以在不同的维度上进行重组,在解码阶段(nq=1 ),优化后的方法可以显著减少计算量。
0x06 转换
6.1 GQA
Group Query Attention(GQA)是MHA的一种变体,旨在减少KV缓存的开销。它将查询头分成多个组,每个组共享一个键和值对。这种方法通过减少键和值头的数量来降低KV缓存的大小,但可能会牺牲模型的表达能力。可以将GQA看作是MLA的一种特例。由于GQA是通过复制产生的,而MLA不受这种限制,表达能力更强。
尽管MLA在Deepseek V2/V3/R1中已经证明了其效率和有效性,但许多主要的模型提供商仍然依赖GQA。为了促进MLA的更广泛应用,论文“TransMLA: Multi-Head Latent Attention Is All You Need"提出了TransMLA,这是一种后训练方法,可以将广泛使用的基于GQA的预训练模型(例如LLaMA、Qwen、Mixtral)转换为基于MLA的模型。转换后,模型可以进行额外的训练以增强表达能力,而不会增加KV缓存的大小。
6.1.1 思路
论文首先证明了对于相同的KV缓存开销,MLA的表达能力总是大于GQA。具体来说,任何GQA配置都可以等价地转换为MLA表示,但反之不然。这一结论为将基于GQA的模型转换为基于MLA的模型提供了理论基础。
在等价转换过程中,TransMLA方法首先将GQA中的键矩阵进行复制,以匹配查询头的数量。然后,它将这个复制后的键矩阵分解为两个较小矩阵的乘积,从而得到MLA中的低秩表示。通过这种方法,TransMLA可以在不增加KV缓存大小的情况下,将基于GQA的模型转换为基于MLA的模型。
6.1.2 方案
第一步是复制key矩阵,以匹配查询头的数量。在GQA中,为使标准多头注意力计算时,𝑄和𝐾(以及𝑉)具有相同数量的头,需要对𝐾进行扩展,从nk个头扩展到nq个头。这其实也有两种方法。
- 定义复制因子s=nqnk(nq为𝑄的头数,nk为𝐾的头数),将𝐾按列划分为nk个块K(i),通过将每个K(i)复制𝑠次并连接,得到扩展矩阵𝐾′。具体见下图(a)。
- 另一种方法是将复制操作移到参数侧(其实也是使用MHA替代GQA的方法),即在计算K之前,先复制投影矩阵WK。先将WK按列拆分为nk个部分W(i)K,然后复制每个W(i)K 𝑠次并连接,形成新的投影矩阵W′K,再应用W′K到𝑋直接得到K’=XW’K,此方法与先计算𝐾再复制其头在数学上是等效的。具体见下图(b)。
由于W′K由复制WK形成,其自由度最多为nkdh,因此它的秩最多为nkdh。为了更正式地理解这一点,通过奇异值分解(SVD)对W′K进行分解:W′K=UKSKV⊤K ,其中UK和VK是𝐷×𝐷正交矩阵,SK是包含奇异值的𝐷×𝐷对角矩阵。只有前nkdh(或更少)的奇异值可能是非零的。因此,可以截断SVD,只保留前 r 个奇异值,其中$ r \le n_kd_h。则𝑊’_𝐾=𝑊_𝐾𝑎𝑊_𝐾𝑏且𝐾′=𝑋𝑊_𝐾𝑎𝑊_𝐾𝑏$ 。这样就将GQA的“重复KV”方案解释为类似MLA的低秩分解形式,在实际缓存时,仅需存储低秩表示XWaK,在注意力计算时通过乘以WbK恢复完整维度,增强了模型的表现力。
6.2 MHA
如何使原本为 MHA 训练的 LLMs(如 Llama)快速适应 MLA 进行推理,而无需从头开始预训练,既具有意义又充满挑战。论文“Towards Economical Inference: Enabling DeepSeek’s Multi-Head Latent Attention in Any Transformer-based LLMs” 第一种数据高效的微调方法MHA2MLA,用于*从MHA转换到MLA。该方法包含两个关键组件:
-
对于partial-RoPE,论文从对注意力分数贡献较小的查询和键的维度中去除 RoPE。
-
对于低秩近似,论文基于键和值的预训练参数引入联合SVD近似。
这些精心设计的策略使 MHA2MLA 仅使用极少部分(3‰至 6‰)的数据就能恢复性能,显著降低推理成本,同时能与 KV 缓存量化等压缩技术无缝集成。
6.2.1 partial-RoPE
为实现从标准 MHA 到 MLA 的迁移,论文提出 partial-RoPE 微调策略,从目标比例的维度中去除 RoPE 并转换为 NoPE。
MHA
MHA 的 Full-RoPE 通过特定频率的旋转将位置信息编码到查询和键中,具体如下图所示。
拆解
MLA中,ki由[ki,nope;ki,rope]组成,所以我们首先需要把MHA的ki,rope也分解成这样的无RoPE编码和有RoPE两部分。
DeepSeek的MLA里面其实是在原始的每个head的不使用RoPE编码dh维度上,再增加一个使用RoPE编码的dRh维度。但是我们现在只能把全长为dh维度的ki,rope进行拆解,把里面dr,dr≪dh部分做RoPE编码。也就是r=dr2长度的2D子空间做旋转编码。
在注意力计算中,并非所有维度上的旋转位置编码(RoPE)都对结果有同等的贡献。Partial-RoPE 技术通过去除对结果贡献较小的维度上的 RoPE,减少了冗余计算。这就像是在一场考试中,抓住重点知识进行复习,避免在一些无关紧要的知识点上浪费时间。通过这种方式,Partial-RoPE 技术在不影响模型性能的前提下,有效提升了计算效率。
在从 Full-RoPE 转换到 Partial-RoPE 时,我们选择哪一部分子空间来做旋转编码呢?论文提出四种策略(主要是依据旋转的频率)来旋转 RoPE 编码的子空间。
- 高频保留:保留 r 个旋转最快(高频)的子空间,即位置最靠前的个2D子空间。
- 低频保留:保留 r 个旋转最慢(低频)的子空间。
- 均匀采样:选择间隔相等的 r 个子空间,即不管是高频还是低频,按照等距离采样,这样高低频都分别有一部分。
- 根据每个头2-norm贡献选择(Head-wise 2-norm Contribution):根据每个头中各子空间的 2-norm分数对所有子空间进行排序,选择前 r 个。第 r 个频率子空间对最终的attention logits的贡献有上界。
选择好了dh维度中的dr维度做RoPE位置编码,剩下的dh−dr部分我们就要当成当成MLA中的无位置编码部分,也就是qnope。但是要注意DeepSeek的MLA中这部分维度是dh,我们这里是dh−dr。
6.2.2 低秩近似
MHA中的ki=Wkxi,vi=Wvxi。我们已经使用上面的四种方法之一找到了需要做RoPE的部分,也就可以把Wk对应的部分取出来得到WKR。
我们也把Wk中对应非RoPE的部分参数提取出来:
ki,nope=Wk,nopexivi,nope=Wv,nopexi
我们的目标是从Wk,nope,Wv,nope中构造出MLA中的WDKV。
从 Full RoPE 转换到 Partial RoPE 后,为得到 MLA 中 KV 缓存的第二个组件ci,kv,论文提出两种基于SVD的策略:解耦 SVD和联合 SVD,具体参见下图。
- 解耦 SVD(SVDsplit):分别对Wk,nope和Wn进行截断 SVD 分解,分配dkv/2个维度给每个矩阵。
- 联合 SVD(SVDjoint):为保留Knope和V之间的交互,对连接矩阵[Wk,nope,Wv]进行联合分解。这种分解方式更加贴合MLA的标准格式。
到这里,我们就处理完了key和value部分。query部分并不像DeepSeek里面的MLA一样再做低秩分解,而是把得到的query对应key中的nope和rope部分也分解成两部分。
参考资料
-
No backlinks found.