PyTorch Native Parallelism: From Tensor to DTensor
Tensor
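As the starting point, an ordinary torch.Tensor lives on a single device and carries no information about distribution. A minimal illustration:

```python
import torch

# A plain tensor: it exists on exactly one device and knows nothing
# about other ranks, sharding, or replication.
x = torch.randn(8, 8)
print(x.shape, x.device)  # torch.Size([8, 8]) cpu
```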
DeviceMesh
DeviceMesh provides an abstraction for describing the layout of a group of devices, expressible as a multi-dimensional array, and it also provides support for communication among the devices inside the mesh.
A DeviceMesh can be initialized via init_device_mesh:
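A minimal sketch, assuming the script is launched with 8 processes (e.g. via torchrun); the 2 x 4 shape and the dim names "dp"/"tp" are illustrative choices:

```python
from torch.distributed.device_mesh import init_device_mesh

# Arrange 8 ranks as a 2 x 4 mesh; launch with e.g.
#   torchrun --nproc_per_node=8 mesh_demo.py
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

# A mesh dimension can be addressed by name; the resulting sub-mesh
# backs communication among the ranks along that dimension.
dp_mesh = mesh["dp"]  # 1-D sub-mesh of size 2
tp_mesh = mesh["tp"]  # 1-D sub-mesh of size 4
print(mesh.get_rank(), dp_mesh.size(), tp_mesh.size())
```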
Placement
- `Shard`: shard or split the tensor on a specific tensor dimension across devices.
- `Replicate`: replicate the tensor across devices; each rank gets the exact same tensor.
- `Partial`: a tensor that has the same shape across devices but only holds partial values on each device; it can be further reduced (e.g. sum/min/max) to recover the full tensor value. This is often useful as an intermediate representation.
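For illustration, a sketch of how the three placements are used with distribute_tensor; this assumes PyTorch >= 2.5, where these names are public under torch.distributed.tensor (older releases keep them in torch.distributed._tensor):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard, Replicate

# Run under torchrun with 4 processes.
mesh = init_device_mesh("cuda", (4,))  # 4 ranks on a 1-D mesh
big = torch.randn(8, 8)

# Shard(0): split along dim 0 -> each rank stores a 2 x 8 slice.
sharded = distribute_tensor(big, mesh, placements=[Shard(0)])

# Replicate(): every rank stores the identical 8 x 8 tensor.
replicated = distribute_tensor(big, mesh, placements=[Replicate()])

# Partial is usually produced by an op (e.g. a matmul whose reduction
# dim is sharded) rather than constructed by hand; reducing the
# per-rank partial values yields the actual result.
```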
For comparison, PyTorch's placements map onto OneFlow's SBP and GSPMD's tensor sharding as follows:
| PT-D DistributedTensor | OneFlow’s SBP | GSPMD’s tensor sharding |
|---|---|---|
| Shard | Split | Tiled |
| Replicate | Broadcast | Replicated |
| Partial | Partial | Partially tiled = tiled + replicated |
DTensorSpec
A DTensorSpec fully describes a DTensor's metadata and consists of three parts:
- a `DeviceMesh` object describing the DTensor's mesh;
- a `Tuple[Placement]` describing how the tensor is placed on that mesh;
- a `TensorMetadata` object describing the global tensor's metadata.
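A condensed sketch of that structure; the field names follow the upstream dataclass, but the real definition lives in a private module (torch/distributed/tensor/_dtensor_spec.py in recent releases) and may differ in detail:

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import Placement

@dataclass
class TensorMeta:
    """Metadata of the *global* (logical) tensor."""
    shape: torch.Size
    stride: Tuple[int, ...]
    dtype: torch.dtype

@dataclass
class DTensorSpec:
    mesh: DeviceMesh                    # device layout the tensor lives on
    placements: Tuple[Placement, ...]   # one placement per mesh dimension
    tensor_meta: Optional[TensorMeta] = None  # global shape/stride/dtype
```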
A concrete example:
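For instance, sharding an 8 x 8 tensor along dim 0 over a 4-rank mesh yields a spec along the following lines (_spec is a private attribute; the exact repr varies by version):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

# Run under torchrun with 4 processes.
mesh = init_device_mesh("cuda", (4,))
dt = distribute_tensor(torch.randn(8, 8), mesh, placements=[Shard(0)])

print(dt._spec.mesh)        # DeviceMesh('cuda', [0, 1, 2, 3])
print(dt._spec.placements)  # (Shard(dim=0),)
print(dt._spec.tensor_meta) # global shape torch.Size([8, 8]), stride, dtype
print(dt.to_local().shape)  # torch.Size([2, 8]): this rank's local shard
```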
DTensor
PyTorch's DTensor is a thin wrapper around the torch.Tensor type:
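A condensed sketch of the wrapper; the actual class sits in a private module (torch/distributed/tensor/_api.py in recent releases) and carries much more machinery:

```python
import torch

class DTensor(torch.Tensor):
    # The tensor physically stored on this rank (the local shard/replica).
    _local_tensor: torch.Tensor
    # All distribution metadata: mesh, placements, global shape/dtype/...
    _spec: "DTensorSpec"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Every ATen op on a DTensor lands here; the real implementation
        # propagates shardings and runs the op on _local_tensor.
        ...
```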
It mainly consists of _local_tensor and _spec:
- `_local_tensor` is the `torch.Tensor` actually stored on each rank.
- `_spec` stores all of the DTensor's metadata, corresponding to the `DTensorSpec` fields: the DeviceMesh, the sharding strategy (placements), and conventional tensor attributes such as shape and dtype.
__torch_dispatch__
The PyTorch operator dispatch flow: every op called on a tensor subclass travels down the dispatcher and bottoms out in __torch_dispatch__, which is the hook DTensor uses to inject its sharding logic.
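A minimal self-contained sketch of this mechanism (LoggingTensor is an illustrative name, not part of the public PyTorch API):

```python
import torch
from torch.utils._pytree import tree_map

class LoggingTensor(torch.Tensor):
    """Wrapper subclass that logs every ATen op reaching the dispatcher,
    mirroring how DTensor intercepts ops before running them on its
    _local_tensor."""

    @staticmethod
    def __new__(cls, elem):
        # Create a wrapper subclass sharing elem's shape/dtype/device.
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device
        )

    def __init__(self, elem):
        self._local = elem  # analogous to DTensor._local_tensor

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"dispatching {func}")  # DTensor would propagate shardings here
        # Unwrap to plain tensors, run the real op, re-wrap the result.
        unwrap = lambda t: t._local if isinstance(t, LoggingTensor) else t
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        wrap = lambda t: LoggingTensor(t) if isinstance(t, torch.Tensor) else t
        return tree_map(wrap, out)

x = LoggingTensor(torch.randn(4, 4))
y = (x + x).sum()  # prints aten.add.Tensor, then aten.sum.default
```

DTensor performs the same interception, but instead of logging it consults the DTensorSpec of each input to decide how to re-shard operands and which collectives to issue before computing on the local shards.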