| 英文 | 缩写 | 中文释义 |
|---|---|---|
| BLAS | Basic Linear Algebra Subroutines | 基础线性代数程序集,一个 API 标准,用以规范发布基础线性代数操作的数值库 |
| FPGA | Field Programmable Gate Arrays | 现场可编程门阵列 |
在现代 GPU 计算中,矩阵乘法是一种很重要的计算场景。无论是像气象模拟、分子动力学、生物制药、核子物理,还是目前火热的深度学习,都会看到矩阵乘法的存在。早在 1979 年,基础线性代数程序集 BLAS1 即发布,为了提高计算性能,各个软硬件厂商都会结合自身的硬件,对 BLAS 接口进行高度性能优化,典型代表包括 Intel 的基于 CPU 指令优化的 MKL 库 2(之前主要是基于 Intel CPU 优化,目前针对 Intel 的 GPU 也有专门优化),以及 NVIDIA 的基于 GPU 优化的 cuBLAS 库3。目前以 Mathematica、MATLAB、NumPy、R 等为代表的数值计算库都会采用 BLAS-compatible 的库作为其底层线性代数计算的实现。
作为 BLAS 中最重要的代表,通用矩阵乘法 GEMM4 的优化决定了一个数值计算库的性能表现。这篇博客5 非常详细的解释了在 GEMM 中可以用到的各种软件优化算法,在进一步了解硬件优化之前,我们可以大致感受下软件算法能够讲矩阵乘法性能优化到什么程度,以下几张图片都来自该博客。
首先我们定义矩阵乘法,定义
- 维度为 $m*k$ 的矩阵 $A$
- 维度为 $k*n$ 的矩阵 $B$
如下所示:
则 $A$ 和 $B$ 经过乘法之后,得到维度为 $m*n$ 的矩阵 $C$ 如下所示:
其中:
通过图片可视化的可以表示成下图所示:
对应于伪代码可以表示为,可以看到如果是两个 $N*N$ 维度的矩阵相乘,这个最朴素的矩阵乘法时间复杂度为 $O(N^3)$。
|
|
随着矩阵乘法在现代科学计算中发挥越来越重要的作用,包括 GEMM 在深度学习的广泛应用6,人们并不满足于朴素矩阵乘法 $O(N^3)$ 的时间复杂度,尝试从各种方式对矩阵乘法的算法进行优化:
- 从数学角度:基于矩阵乘法的数学特性进行优化,典型的算法包括 Strassen 算法7 和 Coppersmith-Winograd 算法8
- 从软件角度:基于内存访问局部性原理和利用向量化指令等技术进行优化。
下图展示了从数学角度将矩阵乘法算法优化的效果,数学家们从 1969 年以前的 $O (n^3)$ 复杂度硬是拉到了 $O (n^{2.4})$ 以下,并且这个优化还在持续进行中。
对这一思路的理解,可以主要参考 Strassen 算法7,其采用分治的思路,将矩阵拆分成更小的矩阵,并引入用于辅助计算的中间矩阵,经过组合得到最终期望的矩阵。后续的优化算法延续 Strassen 算法的思路,继续尝试降低复杂度,以 Coppersmith-Winograd 算法为例,矩阵乘法复杂度降低到了 $O (n^{2.376})$,其算法证明过程比较复杂,此处不再介绍,感兴趣的同学可以参考从这篇 Wiki 8出发。
基于软件角度的矩阵乘法算法优化,会考虑到内存访问局部性,将原来简单 $MNK$ 三重循环的计算进行维度展开,拆分成多个小块计算,以提高对于输入数据的重用,减少内存访问。关于这部分的优化可以参考 How to optimize gemm。
以下图为例,原来的大矩阵被拆分成若干个 $4 * 4$ 的子块,
|
|
内存访问数变化如下,效率直接提升了 8 倍:
- 原来为 $(2 + 1 + 1) MNK$
- 矩阵 $A$ 和 $B$ 都需要读一次内存,因此访存数为 $MNK$
- 矩阵 $C$ 需要先读取内存,累加完毕再存储,因此访存数为 $2 * MNK$
- 三重维度展开后为 $M*N + \frac{1}{4}MNK + \frac{1}{4}MNK \approx \frac{1}{2}MNK$
- 矩阵 $A$ 和 $B$ 的访存均可以复用 4 次,则对应访存数为 $\frac{1}{4}MNK$
- $M$ 和 $N$ 循环共执行 $4∗4$ 次,每次存储 $4∗4$ 个 $C$ 的输出元素,因此矩阵 $C$ 的访存共 $MN$ 次,相比于 $MNK$ 可忽略
除了这种将大矩阵拆分成多个小块计算,对于 CPU 常见的优化还包括 SIMD,vectorize
Cutlas
-
BLAS, wikipedia, https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms ↩︎
-
Intel MKL Library, https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl.html ↩︎
-
NVIDIA, cuBLAS, https://developer.nvidia.com/cublas ↩︎
-
GEMM, wikipedia, https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3 ↩︎
-
: 通用矩阵乘 GEMM 优化算法,黎明灰烬, https://zhenhuaw.me/blog/2019/gemm-optimization.html ↩︎
-
https://petewarden.com/2015/04/20/why-gemm-is-at-the-heart-of-deep-learning/ ↩︎
-
Strassen algorithm, https://en.wikipedia.org/wiki/Strassen_algorithm ↩︎ ↩︎
-
Computational complexity of matrix multiplication, https://en.wikipedia.org/wiki/Computational_complexity_of_matrix_multiplication ↩︎ ↩︎
-
No backlinks found.