为什么要 混合精度训练

  • 传统深度学习一般默认使用单精度 (FP32)训练,但模型规模一大,FP32 的内存和计算成本大

  • 只使用 FP16 会有精度损失的问题,所以需要 混合精度训练:本文主要介绍半精度(FP16)和单精度(FP32)混合的情况

  • 混合精度训练的三大技术

    • 维护 FP32 主副本的权重:前向、反向传播时,使用 FP16,而只在更新权重 的时候,累加到 FP32 的权重副本
    • 损失缩放:如果梯度太小,FP16 可能下溢,简单的应对方法是让梯度先放大一个倍数,保证计算过程中不溢出,更新权重时再缩放回来
    • 算术精度累积:在进行向量点积、求和等操作时,部分计算结果需要在更高的精度下累积,以避免 FP 16 的精度限制导致数值误差。

纯 FP32 训练 baseline

在 LLM 训练里,如果完全不用混合精度,最直接的 baseline 就是:

  • 参数是 FP32
  • 激活值是 FP32
  • 反向传播中的梯度是 FP32
  • 优化器状态也是 FP32

如果把一个线性层写成矩阵乘 $Y = XW$ ,那么一个训练 step 里最核心的计算通常可以拆成 3 部分:

1. Fprop(Forward Propagation)

也就是前向传播。给定输入激活 X 和权重 W,计算输出激活 Y

1
Y = X @ W

在 Transformer / LLM 里,下面这些大头计算本质上都属于 Fprop:

  • Q = X @ Wq
  • K = X @ Wk
  • V = X @ Wv
  • AttnOut = P @ V
  • MLP_up = X @ W1
  • MLP_down = H @ W2

这一阶段会产生大量中间激活,供后续反向传播使用。

2. Dgrad(Data Gradient)

反向传播时,先从上一层传回来输出梯度 dY。要继续往前一层传,就要计算输入梯度 dX

1
dX = dY @ W.T

这部分叫 Dgrad,可以理解成“对输入激活的梯度”。它的作用是把误差信号继续往网络前面传。

3. Wgrad(Weight Gradient)

同样拿到 dY 之后,还要计算当前层权重的梯度 dW

1
dW = X.T @ dY

这部分叫 Wgrad,也就是“对权重的梯度”。优化器后面更新参数,依赖的就是它。

所以,从算子视角看,一个 LLM layer 的训练主干其实就是:

  1. Fprop: Y = X @ W
  2. Dgrad: dX = dY @ W.T
  3. Wgrad: dW = X.T @ dY
  4. Optimizer step: 用 dW 去更新参数和优化器状态

如果写成更贴近训练过程的伪代码,大概是这样:

1
2
3
4
5
6
7
for x, y in loader:
    optimizer.zero_grad()

    pred = model(x)        # Fprop: all major GEMMs in FP32
    loss = criterion(pred, y)
    loss.backward()          # Dgrad + Wgrad: also FP32
    optimizer.step()         # FP32 master params and FP32 optimizer states

这个 baseline 的优点是数值稳定、行为简单、最容易对齐理论公式。在模型刚开始验证正确性时,很多团队还是会先跑纯 FP32,确认:

  • loss 曲线正常
  • 梯度没有异常
  • Fprop / Dgrad / Wgrad 的实现都正确

但它也有一个非常现实的问题:太贵。当模型从几亿参数走向几十亿甚至上百亿参数时,FP32 的训练成本会迅速变得不可接受。

为什么 FP32 太慢/太贵

FP32 的“贵”,在 LLM 训练里最好不要只理解成“显存更大一点”,而要理解成:

  • Fprop 贵
  • Dgrad 贵
  • Wgrad 贵
  • 保存它们需要的激活、梯度、优化器状态也贵

也就是说,一个 step 里最贵的几段矩阵乘,全都在用 32 bit 做。

1. 显存占用高

FP32 每个数占 4 Byte。对 LLM 来说,显存消耗不只是“参数本身”,还包括:

  • 参数(weights)
  • Fprop 需要保留的激活(activations)
  • Dgrad / Wgrad 产生的梯度(gradients)
  • 优化器状态(例如 Adam 的 mv

其中最容易被忽略的一点是:Wgrad 依赖前向时保存下来的输入激活 X。这意味着:

  • Fprop 不能只是“算完就丢”
  • 为了反向做 Dgrad/Wgrad,很多中间结果必须缓存
  • sequence length、micro-batch、hidden size 一大,激活显存就会非常夸张

再加上 AdamW 这类优化器,单个参数往往至少对应:

  • 1 份参数
  • 1 份梯度
  • 2 份优化器状态

也就是说,仅参数相关状态就可能接近 4 份 FP32 张量,还没算激活、临时 buffer、通信 buffer。模型一大,显存马上爆掉。

2. Tensor Core 利用率低

NVIDIA Tensor Core:

  • FP16/BF16/FP8 吞吐极高
  • FP32 throughput 很低

例如 Hopper H 100:

精度 Tensor Core TFLOPS
FP32 ~60
BF16 ~1000
FP8 ~2000

差距巨大。

3. Memory Bandwidth 成瓶颈

FP32:每个元素 4 bytes。

大模型训练:

真正瓶颈往往不是 FLOPS,
而是:

  • HBM 带宽
  • NVLink 带宽
  • All-reduce 带宽

降低精度可以直接减轻 IO

一张图总结

FP16/BF16 混合精度训练

FP 16 混合精度训练是最早成熟的混合精度方案,其核心思想是:

能用 FP 16 的地方尽量用 FP 16 提速和省显存;真正对数值稳定性敏感的地方保留 FP 32。

1. 朴素 FP 16 训练的问题

如果简单地将所有计算都转换为 FP16,会遇到两个致命问题:

  1. 梯度下溢:大部分梯度值会小于 FP16 的最小正数值,变为 0,导致模型无法收敛
  2. 权重更新失效:权重更新量(学习率 × 梯度)通常很小,在 FP16 下无法精确表示

2. 完整 FP16 混合精度训练流程 (经典 AMP)

NVIDIA 在 2017 年提出的混合精度训练方案1 通过三个核心技术解决了上述问题:

  1. 模型前向多数算子用 FP 16
    • 线性层、卷积、attention 中的大矩阵乘,优先走 FP 16 Tensor Core
    • 激活值通常也以 FP 16 保存,减少显存
  2. 反向传播多数梯度计算仍在 FP 16 路径上进行
    • 梯度张量本身很多时候是半精度
  3. 维护 FP 32 master weight
    • 训练时真正用于更新的“主参数”保留一份 FP 32
    • 每次更新先在 FP 32 上做,再 cast 回 FP 16 供下一轮 forward 使用
  4. 部分算子强制保留 FP 32
    • 例如 softmax、layer norm、某些 reduction、loss 计算等
    • 因为这些地方更容易出现舍入误差、上溢或下溢
  5. 损失缩放 Loss Scaling 防止梯度下溢

FP16 完整流程如下2

  1. Maintain a primary copy of weights in FP32.
  2. Initialize S to a large value.
  3. For each iteration:
    1. Make an FP16 copy of the weights.
    2. Forward propagation (FP16 weights and activations).
    3. Multiply the resulting FP32 loss with the scaling factor S.
    4. Backward propagation (FP16 weights, activations, and their gradients).
    5. If there is an Inf or NaN in weight gradients: 6. Reduce S. 7. Skip the weight update and move to the next iteration.
    6. Multiply the weight gradient FP16 ->FP32 with 1/S.
    7. Complete the weight update FP32 (including gradient clipping, etc.).
    8. If there hasn’t been an Inf or NaN in the last N iterations, increase *S

经典 AMP(Automatic Mixed Precision)的收益非常直接:

  • 大量 GEMM/Conv 使用 Tensor Core,吞吐显著提升
  • 参数和激活改成半精度后,显存占用下降
  • 保留 FP32 的关键路径后,训练仍能收敛

但 FP16 并不是“免费午餐”。它虽然省显存、算得快,却有很明显的数值缺陷,最典型的就是动态范围太小,这会直接引出下面的 loss scaling

3. FP16 下的 Dynamic Loss Scaling

FP16 的最大问题不是只有“小数不够精确”,而是很小的梯度可能直接下溢为 0

深度网络反向传播时,很多梯度本来就很小。如果它们在 FP16 中落到可表示范围之外,就会变成 0,结果是:

  • 一部分参数得不到更新
  • 梯度统计失真
  • 训练可能不稳定,甚至完全不收敛

loss scaling 的思路非常简单:既然梯度太小,那就先把它们整体放大。

基本过程

  1. forward 之后得到原始 loss
  2. loss 乘上一个缩放因子 scale
  3. 对放大后的 loss 做 backward
  4. 得到的梯度会同步被放大
  5. 在 optimizer step 前,再把梯度除以同样的 scale

伪代码如下:

1
2
3
4
5
6
7
8
scaled_loss = loss * scale
scaled_loss.backward()

for p in model.parameters():
    if p.grad is not None:
        p.grad /= scale

optimizer.step()

为什么需要“动态” loss scaling

如果 scale 设得太小,还是会下溢;但如果设得太大,又可能上溢成 infnan。因此经典 AMP 通常不是用固定值,而是使用 dynamic loss scaling

  • 一段时间没有溢出,就把 scale 提高
  • 一旦发现梯度出现 inf/nan,就跳过这次更新并降低 scale
1
2
3
4
if overflow:
    scale /= 2
else:
    scale *= 2 occasionally

这套机制让 FP16 训练在工程上变得可用,但也带来了额外复杂度:

  • 需要做 overflow 检测
  • 需要跳过异常 step
  • 训练流程更复杂,调试也更麻烦 而 BF16 的崛起,恰恰就是因为它在很大程度上绕开了这个问题。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 伪代码:FP16混合精度训练
model = Model().to(torch.float32)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler(init_scale=2**16)  # 损失缩放器

for batch in dataloader:
    x, y = batch.to(torch.float32)
    
    # 1. 前向传播:自动转换为FP16
    with torch.cuda.amp.autocast(dtype=torch.float16):
        outputs = model(x)
        loss = loss_fn(outputs, y)
    
    # 2. 反向传播:先缩放损失,再计算梯度
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    
    # 3. 可选:梯度裁剪(在unscale之后)
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
    # 4. 优化器更新:自动unscale梯度并更新FP32主权重
    scaler.step(optimizer)
    
    # 5. 更新损失缩放因子
    scaler.update()

autocast 做的事情不是“全都转成半精度”,而是3

  • 给适合低精度的算子分配 BF16/FP16
  • 给敏感算子保留 FP32
  • 自动处理算子输入输出的 dtype 协调

4. 当前主流:BF16 + FP32 Master Weight + FP32 Optimizer State

BF16(bfloat16)也是 16 bit,但它和 FP16 的位分配不一样:

  • FP16:指数位少、尾数位相对多
  • BF16:指数位和 FP32 一样多,尾数位更少

这意味着 BF16 的关键优势是:

  • 动态范围接近 FP32
  • 对非常大或非常小的数更不容易溢出/下溢
  • 在深度学习训练里,通常比 FP16 更稳定

可以把两者简单理解为:

  • FP16:精度稍细,但“量程短”
  • BF16:精度粗一点,但“量程长”

训练里更怕什么?通常更怕梯度、激活、logits 在中间过程直接炸掉或归零,所以 BF16 往往比 FP16 更实用。

于是工业界逐渐形成了非常一致的结论:

  • 如果硬件支持 BF16,优先用 BF16,而不是 FP16
  • BF16 训练通常不需要 loss scaling
  • 对大模型、长序列、复杂 attention 结构,BF16 往往更稳

当然,BF16 也不是完美的。它的尾数位比 FP32 少很多,所以:

  • 仍然不适合把所有状态都一股脑降到 BF16
  • 某些需要高精度累积和更新的地方,还是应该保留 FP32

这就形成了现代混合精度训练的主流范式:BF16 负责大部分算子计算,FP32 负责关键累积与更新。

1
2
3
4
5
6
7
for x, y in loader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        pred = model(x)
        loss = criterion(pred, y)
    loss.backward()
    optimizer.step()

如果用的是 BF16,一般可以不启用 GradScaler;如果是 FP16,则常配合 GradScaler 处理 loss scaling。

把前面的内容合起来,当前业界最常见、也最稳妥的范式可以概括为:

计算侧

  • Forward 使用 BF16
  • Backward 的大部分张量计算也使用 BF16
  • GEMM/Attention 等核心算子在 Tensor Core 上完成
  • 关键 reduction 和 accumulate 使用 FP32

状态侧

  • 参数存在 BF16 计算副本
  • 更新时使用 FP32 master weight,或者等价的 FP32 update buffer
  • Adam/AdamW 的 mv 等优化器状态维持 FP32
  • 某些框架还会保留 FP32 gradient accumulation buffer

一个典型 step 会发生什么

  1. 读取 BF16 权重做 forward
  2. 生成 BF16 激活并完成 backward
  3. 对梯度做必要的 all-reduce / reduce-scatter
  4. 在 FP32 语义下执行优化器更新
  5. 将更新后的结果转换回 BF16,作为下一轮 forward 的参数

这个设计之所以成为主流,是因为它在三件事之间取得了平衡:

  • 性能:主算子走低精度 Tensor Core
  • 显存:大部分激活和计算副本是 16 bit
  • 稳定性:参数更新和优化器状态保留 FP32

需要补充的是,不同框架对 “FP 32 master weight” 的实现方式并不完全一样:

  • 有的显式维护一份独立 FP 32 参数副本
  • 有的在优化器内部保存 FP32 更新视图
  • 在 ZeRO/FSDP 中,这份 FP32 状态还可能是分片存储的

但无论实现细节如何,背后的原则基本一致:更新时不要只依赖低精度状态。

5. 一张图总结

现代框架(PyTorch / Megatron / DeepSpeed)的典型实现

FSDP:BF16 混合精度是怎么落地的

结合 PyTorch 新版 FSDP(torch.distributed.fsdp._fully_shard,也常被叫作 FSDP2)代码来看,FSDP 的 mixed precision 不是像 autocast 那样按 op 粒度决定 dtype,而是按 module 边界管理参数的 all-gather、forward/backward 计算、梯度 reduce-scatter

入口配置就是 MixedPrecisionPolicy

1
2
3
4
5
6
@dataclass(frozen=True)
class MixedPrecisionPolicy:
    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    output_dtype: Optional[torch.dtype] = None
    cast_forward_inputs: bool = True

如果我们想做典型的 BF16 训练,常见模型就是:

1
2
3
4
5
mp_config = dict(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)
  • param_dtype=torch.bfloat16:unshard 后给 forward/backward 用的参数是 BF16
  • reduce_dtype=None:梯度规约默认跟计算梯度 dtype 走,也就是通常还是 BF16
  • reduce_dtype=torch.float32:如果更在意数值稳定,也可以强制梯度规约走 FP32
  • output_dtype:需要和别的 mixed precision policy 串接时,再把 forward 输出转成指定 dtype

FSDP 的一个关键点是:它并不会把底层 sharded parameter 永久改成 BF16。 FSDPParam.init_dtype_attrs() 里会把:

  • orig_dtype = self.sharded_param.dtype
  • param_dtype = mp_policy.param_dtype
  • reduce_dtype = mp_policy.reduce_dtype

分开保存。也就是说,假设模型初始化时参数本来是 FP32,那么:

  • 常驻内存里的 sharded parameter 仍然是 FP32
  • 只有在 forward/backward 前 all-gather 成 unsharded parameter 时,才按 param_dtype 变成 BF16 供计算使用

这就是 _fsdp_api.py 注释里那句 “optimizer step uses the sharded parameter in the original dtype” 的真实含义:FSDP 不需要再额外维护一整份完整的 FP32 master weight,因为它本来就保留着分片后的原始精度参数。

把一个训练 step 展开,代码路径大致是这样的:

  1. forward 前把输入转成 BF16

    • FSDPState._pre_forward() 会在 cast_forward_inputs=Trueparam_dtype 非空时,把 floating-point inputs cast 到 param_dtype
  2. all-gather 时按 param_dtype 准备参数

    • FSDPParam.all_gather_inputs 会把本 rank 的 sharded parameter 转成 self.param_dtype
    • foreach_all_gather() 再把这些 BF16 shard all-gather 成完整参数
    • FSDPParam.init_unsharded_param() 用 all-gather 结果恢复出 unsharded parameter,dtype 是 self.param_dtype or self.orig_dtype
  3. forward/backward 主计算使用 BF16 参数

    • 因为 module 上注册的参数已经被切换成 unsharded BF16 parameter,所以线性层、attention、MLP 这些主干算子会自然跑在 BF16 路径上
    • 这就是 _fsdp_api.py 文档里说的 “module-level mixed precision”:cast 主要发生在 module 边界,而不是每个算子内部反复 cast
  4. backward 结束后,梯度按 reduce_dtype 做 reduce-scatter

    • FSDPParamGroup.post_backward() 会收集 unsharded grad
    • foreach_reduce()reduce_dtype = reduce_dtype or grad_dtype
    • 如果 reduce_dtype is None,那就直接用当前梯度 dtype 规约;在 BF16 compute 场景下,通常就是 BF16 reduce-scatter
    • 如果显式设成 torch.float32,就会先把梯度 buffer 建成 FP32,再走 FP32 规约
  5. 规约后再转回原始参数 dtype,供 optimizer step 使用

    • foreach_reduce() 在规约结束后会执行 reduce_output = _to_dtype_if_needed(reduce_output, orig_dtype)
    • 随后把结果写回 fsdp_param.sharded_param.grad
    • 这意味着如果原始分片参数是 FP32,那么优化器看到的 sharded grad 也是 FP32,参数更新就在 FP32 语义下完成

所以,FSDP 中 BF16 mixed precision 的本质可以总结成一句话:

常驻的分片参数保留原始精度(通常 FP32),临时 all-gather 出来的计算参数用 BF16,梯度通信可以用 BF16 或 FP32,最后再把梯度恢复到原始参数精度去做 optimizer step。

这和经典 AMP + full master weight 很像,但又不完全一样:

  • 经典 AMP:常见做法是保留一份完整 FP32 master weight,再派生出 FP16/BF16 计算副本
  • FSDP:保留的是分片后的原始精度参数,计算时临时 all-gather/cast,所以不用多存一份完整 FP32 副本

从这个实现也能看出,FSDP 在回答的其实是三个更具体的问题:

  1. 计算参数用什么 dtype?param_dtype 控制,典型是 BF16
  2. 通信时梯度用什么 dtype?reduce_dtype 控制,默认跟 compute dtype,也可单独升到 FP32
  3. 优化器更新用什么 dtype?orig_dtype,通常仍是 FP32

Tensor Core 上的真实计算路径

从硬件角度看,现代 GPU 上的混合精度训练并不是“整个训练都在 BF16”或“整个训练都在 FP16”,而是一个输入精度、乘法精度、累加精度、输出精度分层处理的过程。

最常见的路径是:

  • 输入张量是 BF16/FP16
  • Tensor Core 执行低精度矩阵乘
  • 累加(accumulate)在更高精度里完成,通常是 FP32
  • 输出再根据算子策略写回 BF16 或 FP32

现在我们把 Tensor Core 和你之前学的混合精度训练结合起来,看一个完整的线性层前向传播过程:

1
2
3
4
# 你以为的代码逻辑
X = X.to(torch.float16)
W = W.to(torch.float16)
Y = X @ W  # 全部在FP16下计算
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
# 实际的硬件执行逻辑
X_fp16 = X.to(torch.float16)
W_fp16 = W.to(torch.float16)

# 调用Tensor Core指令,内部执行:
acc_fp32 = 0.0
for k in range(hidden_dim):
    product_fp16 = X_fp16[:, k] * W_fp16[k, :]
    product_fp32 = product_fp16.to(torch.float32)
    acc_fp32 += product_fp32

Y_fp16 = acc_fp32.to(torch.float16)

对矩阵乘来说,这一点尤其关键。因为训练里真正占大头的是:

  • Linear
  • Attention QK^T
  • Attention Prob @ V
  • MLP 中的大 GEMM

这些操作通常都在 Tensor Core 上跑,而 Tensor Core 的高吞吐正是混合精度训练性能提升的来源。

到了 Hopper 一代,这种趋势更明显:

  • 对 BF16/FP16 的硬件支持更强
  • 对 FP8 也开始提供专门路径
  • fused attention、fused MLP、Transformer Engine 等会更积极地管理 dtype

因此,“混合精度训练”在真实 GPU 上的含义,并不是简单地把模型 .half(),而是:

把大吞吐矩阵乘放到低精度 Tensor Core,把数值敏感的累加和更新保留在更高精度。

更进一步:FP8 / MXFP8 的方向

当 BF16 已经成为主流后,行业还在继续往更低精度推进,代表方向就是 FP8

推动 FP8 的核心动力还是老问题:

  • 进一步提升 Tensor Core 吞吐
  • 进一步降低显存与带宽压力
  • 让更大的模型、batch、sequence 能塞进同样的硬件

不过 FP8 比 BF16/FP16 更激进,因此它通常不会像 BF16 那样“直接默认开启”。常见做法是:

  • 只让一部分张量使用 FP8
  • 给不同张量维护独立的 scaling factor
  • 依赖专门的 runtime 或 Transformer Engine 管理量化/反量化
  • 继续保留更高精度的累积、权重更新和优化器状态

所以 FP8 更像是在 BF16 混合精度之上的下一层精细化工程:

  • BF16 解决的是“主流训练低精度化”
  • FP8 解决的是“进一步逼近硬件吞吐上限”

对大多数团队来说,今天最值得掌握的仍然是这条主线:

  1. 先理解纯 FP32 baseline
  2. 再理解经典 FP16 AMP 与 loss scaling
  3. 最后掌握 BF16 compute + FP32 accumulate/update 这一现代主流范式

掌握这条主线之后,再去看 FP8、MXFP8、Transformer Engine、块级 scaling 等新技术,就会顺畅很多。