From Tensor to DTensor

Tensor

from typing import Any, Dict, NamedTuple, Optional, Tuple

import torch

class TensorMetadata(NamedTuple):  
    # TensorMetadata is a structure containing pertinent information  
    # about a tensor within a PyTorch program.
    
    # General Tensor metadata
    shape: torch.Size  
    dtype: torch.dtype  
    requires_grad: bool  
    stride: Tuple[int, ...]  
    memory_format: Optional[torch.memory_format]  
  
    # Quantization metadata  
    is_quantized: bool  
    qparams: Dict[str, Any]
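As a quick illustration (a sketch, assuming a CPU build of PyTorch is installed; `extract_metadata` is a hypothetical helper, not a torch API), the fields above can be read off any ordinary tensor:

```python
from typing import Any, Dict, NamedTuple, Optional, Tuple

import torch

class TensorMetadata(NamedTuple):
    shape: torch.Size
    dtype: torch.dtype
    requires_grad: bool
    stride: Tuple[int, ...]
    memory_format: Optional[torch.memory_format]
    is_quantized: bool
    qparams: Dict[str, Any]

def extract_metadata(t: torch.Tensor) -> TensorMetadata:
    # Populate the structure from a live tensor; memory_format is
    # left as None here for simplicity.
    return TensorMetadata(
        shape=t.shape,
        dtype=t.dtype,
        requires_grad=t.requires_grad,
        stride=t.stride(),
        memory_format=None,
        is_quantized=t.is_quantized,
        qparams={},
    )

meta = extract_metadata(torch.zeros(6, 3, dtype=torch.int64))
print(meta.shape, meta.stride)  # torch.Size([6, 3]) (3, 1)
```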

DeviceMesh

DeviceMesh provides an abstraction for describing the layout of a group of devices, which can be represented as a multi-dimensional array; it also provides support for communication among the devices within the mesh.

A DeviceMesh can be initialized with init_device_mesh:

from torch.distributed.device_mesh import init_device_mesh
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

# Users can access the underlying process group through the `get_group` API.
dp_group = mesh_2d.get_group(mesh_dim="dp")
tp_group = mesh_2d.get_group(mesh_dim="tp")
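To see how the 2×4 mesh above groups ranks, the mapping from ranks to mesh coordinates can be sketched in plain Python, no GPUs needed (assuming DeviceMesh's row-major rank layout):

```python
# Ranks 0..7 laid out row-major on a (2, 4) mesh: mesh[dp][tp] = rank.
DP, TP = 2, 4
mesh = [[dp * TP + tp for tp in range(TP)] for dp in range(DP)]

# A "tp" group holds the ranks of one row (varying along the tp dim);
# a "dp" group holds the ranks of one column (varying along the dp dim).
tp_groups = mesh                               # [[0, 1, 2, 3], [4, 5, 6, 7]]
dp_groups = [list(col) for col in zip(*mesh)]  # [[0, 4], [1, 5], [2, 6], [3, 7]]

print(tp_groups)
print(dp_groups)
```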

Placement

Shard

Shard or split the tensor on a specific tensor dimension across devices.

Replicate

Replicate the tensor across devices; each rank gets the exact same tensor.

Partial

A type of tensor that has the same shape across devices but only partial values on each device. It can be further reduced (e.g. sum/min/max) to obtain the full DistributedTensor. This is often useful as an intermediate representation.
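The Partial placement can be simulated in a single process (a sketch, assuming torch is installed): split the contraction dimension of a matmul across two "ranks". Each local result has the full output shape but only partial values, and summing the partials recovers the full result.

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 6)
B = torch.randn(6, 5)

# Shard the contraction dim across two "ranks": rank 0 holds A[:, :3]
# and B[:3, :], rank 1 holds A[:, 3:] and B[3:, :].
partial0 = A[:, :3] @ B[:3, :]  # same shape as A @ B, but partial values
partial1 = A[:, 3:] @ B[3:, :]

# Reducing (summing) the partials yields the replicated full result.
full = partial0 + partial1
assert full.shape == (4, 5)
assert torch.allclose(full, A @ B, atol=1e-5)
```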

For comparison, the corresponding concepts in other frameworks:

PT-D DistributedTensor | OneFlow's SBP | GSPMD's tensor sharding
Shard                  | Split         | Tiled
Replicate              | Broadcast     | Replicated
Partial                | Partial       | Partially tiled (= tiled + replicated)

DTensorSpec

A DTensorSpec fully describes the metadata of a DTensor and consists of the following three parts:

  • a DeviceMesh object describing the DTensor's mesh,
  • a Tuple[Placement, ...] describing the placement strategy,
  • a TensorMetadata object describing the meta information of the global tensor.
class DTensorSpec:  
    mesh: DeviceMesh  
    placements: Tuple[Placement, ...]  
  
    # tensor meta will only be set during sharding propagation  
    tensor_meta: Optional[TensorMeta] = None

A concrete example:

DTensorSpec(
    mesh=DeviceMesh([0, 1]),
    placements=[Shard(dim=0)],
    tensor_meta=TensorMetadata(
        shape=torch.Size([6, 3]), dtype=torch.int64, requires_grad=False,
        stride=(3, 1), memory_format=None, is_quantized=False, qparams={},
    ),
)

DTensor

Torch's DTensor is a thin wrapper around the torch.Tensor type:

class DTensor(torch.Tensor):  
    _local_tensor: torch.Tensor  
    _spec: DTensorSpec  
    __slots__ = ["_local_tensor", "_spec"]  
  
    # _op_dispatcher instance as a class attribute to handle runtime dispatching logic  
    _op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher()

It mainly consists of _local_tensor and _spec:

  • _local_tensor is the actual torch.Tensor stored on each rank.
  • _spec stores all of the DTensor's metadata, corresponding to the DTensorSpec fields:
    • the DeviceMesh, the sharding strategy (Placements), and the conventional tensor attributes such as shape and dtype.
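The relationship between the per-rank _local_tensor and the global shape recorded in _spec can be sketched in plain Python (`global_shape` is a hypothetical helper for illustration, not the DTensor API): with Shard(dim) over N ranks, the global size on the sharded dim is N times the local size, while Replicate and Partial keep the full shape on every rank.

```python
def global_shape(local_shape, num_ranks, placement):
    # placement: ("shard", dim), ("replicate",), or ("partial",).
    # Assumes the sharded dim divides evenly across ranks.
    shape = list(local_shape)
    if placement[0] == "shard":
        shape[placement[1]] *= num_ranks
    # replicate / partial: every rank already holds the full shape
    return tuple(shape)

print(global_shape((3, 3), 2, ("shard", 0)))    # (6, 3)
print(global_shape((6, 3), 2, ("replicate",)))  # (6, 3)
```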

__torch_dispatch__

The PyTorch operator dispatch flow: