英文 缩写 中文释义
BLAS Basic Linear Algebra Subroutines 基础线性代数程序集,一个 API 标准,用以规范发布基础线性代数操作的数值库
FPGA Field Programmable Gate Arrays 现场可编程门阵列

在现代 GPU 计算中,矩阵乘法是一种很重要的计算场景。无论是像气象模拟、分子动力学、生物制药、核子物理,还是目前火热的深度学习,都会看到矩阵乘法的存在。早在 1979 年,基础线性代数程序集 BLAS1 即发布,为了提高计算性能,各个软硬件厂商都会结合自身的硬件,对 BLAS 接口进行高度性能优化,典型代表包括 Intel 的基于 CPU 指令优化的 MKL 库 2(之前主要是基于 Intel CPU 优化,目前针对 Intel 的 GPU 也有专门优化),以及 NVIDIA 的基于 GPU 优化的 cuBLAS 库3。目前以 Mathematica、MATLAB、NumPy、R 等为代表的数值计算库都会采用 BLAS-compatible 的库作为其底层线性代数计算的实现。

作为 BLAS 中最重要的代表,通用矩阵乘法 GEMM4 的优化决定了一个数值计算库的性能表现。这篇博客5 非常详细的解释了在 GEMM 中可以用到的各种软件优化算法,在进一步了解硬件优化之前,我们可以大致感受下软件算法能够讲矩阵乘法性能优化到什么程度,以下几张图片都来自该博客。

首先我们定义矩阵乘法,定义

  • 维度为 $m*k$ 的矩阵 $A$
  • 维度为 $k*n$ 的矩阵 $B$

如下所示:

$$A = \begin{pmatrix} a_{11} & a_{12} & \dots & a_{1k} \\ a_{21} & a_{22} & \dots & a_{2k} \\ \vdots & \vdots & \ddots & \vdots \\ a_{m1} & a_{m2} & \dots & a_{mk} \\ \end{pmatrix} , B = \begin{pmatrix} b_{11} & b_{12} & \dots & b_{1n} \\ b_{21} & b_{22} & \dots & b_{2n} \\ \vdots & \vdots & \ddots & \vdots \\ b_{k1} & b_{k2} & \dots & b_{kn} \\ \end{pmatrix}$$

则 $A$ 和 $B$ 经过乘法之后,得到维度为 $m*n$ 的矩阵 $C$ 如下所示:

$$C = \begin{pmatrix} c_{11} & c_{12} & \dots & c_{1n} \\ c_{21} & c_{22} & \dots & c_{2n} \\ \vdots & \vdots & \ddots & \vdots \\ c_{m1} & c_{m2} & \dots & c_{mn} \\ \end{pmatrix}$$

其中:

$$c_{ij} = a_{i1}b_{1j}+a_{i2}b_{2j}+\dots +a_{ik}b_{kj}=\sum_{t=1}^{k} a_{it}b_{tj}$$

通过图片可视化的可以表示成下图所示:

Source:
Source:

对应于伪代码可以表示为,可以看到如果是两个 $N*N$ 维度的矩阵相乘,这个最朴素的矩阵乘法时间复杂度为 $O(N^3)$。

1
2
3
4
5
6
7
8
for (int m = 0; m < M; m++) {
  for (int n = 0; n < N; n++) {
    C[m][n] = 0;
    for (int k = 0; k < K; k++) {
      C[m][n] += A[m][k] * B[k][n];
    }
  }
}

随着矩阵乘法在现代科学计算中发挥越来越重要的作用,包括 GEMM 在深度学习的广泛应用6,人们并不满足于朴素矩阵乘法 $O(N^3)$ 的时间复杂度,尝试从各种方式对矩阵乘法的算法进行优化:

  • 从数学角度:基于矩阵乘法的数学特性进行优化,典型的算法包括 Strassen 算法7 和 Coppersmith-Winograd 算法8
  • 从软件角度:基于内存访问局部性原理和利用向量化指令等技术进行优化。

下图展示了从数学角度将矩阵乘法算法优化的效果,数学家们从 1969 年以前的 $O (n^3)$ 复杂度硬是拉到了 $O (n^{2.4})$ 以下,并且这个优化还在持续进行中。

对这一思路的理解,可以主要参考 Strassen 算法7,其采用分治的思路,将矩阵拆分成更小的矩阵,并引入用于辅助计算的中间矩阵,经过组合得到最终期望的矩阵。后续的优化算法延续 Strassen 算法的思路,继续尝试降低复杂度,以 Coppersmith-Winograd 算法为例,矩阵乘法复杂度降低到了 $O (n^{2.376})$,其算法证明过程比较复杂,此处不再介绍,感兴趣的同学可以参考从这篇 Wiki 8出发。

Strassen 算法
Strassen 算法

基于软件角度的矩阵乘法算法优化,会考虑到内存访问局部性,将原来简单 $MNK$ 三重循环的计算进行维度展开,拆分成多个小块计算,以提高对于输入数据的重用,减少内存访问。关于这部分的优化可以参考 How to optimize gemm

以下图为例,原来的大矩阵被拆分成若干个 $4 * 4$ 的子块,

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
for (int m = 0; m < M; m += 4) {
  for (int n = 0; n < N; n += 4) {
    C[m + 0..3][n + 0..3] = 0;
    C[m + 0..3][n + 0..3] = 0;
    C[m + 0..3][n + 0..3] = 0;
    C[m + 0..3][n + 0..3] = 0;
    for (int k = 0; k < K; k += 4) {
      C[m + 0..3][n + 0..3] += A[m + 0..3][k + 0] * B[k + 0][n + 0..3];
      C[m + 0..3][n + 0..3] += A[m + 0..3][k + 1] * B[k + 1][n + 0..3];
      C[m + 0..3][n + 0..3] += A[m + 0..3][k + 2] * B[k + 2][n + 0..3];
      C[m + 0..3][n + 0..3] += A[m + 0..3][k + 3] * B[k + 3][n + 0..3];
    }
  }
}

内存访问数变化如下,效率直接提升了 8 倍:

  • 原来为 $(2 + 1 + 1) MNK$
    • 矩阵 $A$ 和 $B$ 都需要读一次内存,因此访存数为 $MNK$
    • 矩阵 $C$ 需要先读取内存,累加完毕再存储,因此访存数为 $2 * MNK$
  • 三重维度展开后为 $M*N + \frac{1}{4}MNK + \frac{1}{4}MNK \approx \frac{1}{2}MNK$
    • 矩阵 $A$ 和 $B$ 的访存均可以复用 4 次,则对应访存数为 $\frac{1}{4}MNK$
    • $M$ 和 $N$ 循环共执行 $4∗4$ 次,每次存储 $4∗4$ 个 $C$ 的输出元素,因此矩阵 $C$ 的访存共 $MN$ 次,相比于 $MNK$ 可忽略

除了这种将大矩阵拆分成多个小块计算,对于 CPU 常见的优化还包括 SIMD,vectorize

两个 4×4 矩阵相乘的向量化
两个 4×4 矩阵相乘的向量化

Cutlas