长文的挑战

当 sequence length 增加时,模型计算量和存储占用复杂度随着序列长度呈二次方增长,具体计算在前面已经推导过了,此处不再详述,可以参考这篇文章[^24]。本节采用的数学符号表示如下:

符号表示 数学含义
$N$ sequence length, could be 4 K or larger
$d$ head dimension, typically 128
$Q \in R^{N\times d}$ Query Matrix
$K \in R^{N\times d}$ Key Matrix
$V \in R ^{N\times d}$ Value Matrix
$S = QK^T \in R^{N\times N}$ Attention Score
$P=softmax(S) \in R^{N\times N}$ Row-wise Softmax by S
$O = PV \in R^{N\times d}$ Output Value

Attention 的 $N^2$ 计算复杂度

计算复杂度: 参考 https://www.harmdevries.com/post/context-length/ ,随着 sequence length 增加,Attention 计算成为主要部分。

Attention 的 $N^2$ 存储复杂度

内存占用 $N^2$ 的关键,就在于 Attention Score $S = Q^KT$ 的计算,当 sequence length 增加时,显存占用以 $O(N^2)$ 比例增加。FlashAttention 优化的思路,就是通过 Tiling 切块,将大矩阵切成小矩阵运算,减少中间内存占用。Tiling 是矩阵计算中常见的优化思路,Attention 的主要问题是 Softmax 的计算限制了必需要等待整行的计算结果才能够继续计算,这大大限制了 Attention 的分块并行计算。

长文训练加速技术

训练阶段

  • longCT

Sequence Parallel

算子层面

Attention 优化