from dataclasses import dataclass
from typing import Any, cast, List, NamedTuple, Optional, Tuple

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
    Partial,
    Placement,
    Replicate,
    Shard,
)


class TensorMeta(NamedTuple):
    # simple named tuple to represent tensor metadata,
    # intentionally kept minimal since it is only used
    # for sharding propagation purposes.
    shape: torch.Size
    stride: Tuple[int, ...]
    dtype: torch.dtype


# used internally to propagate the placements
@dataclass
class DTensorSpec:
    mesh: DeviceMesh
    placements: Tuple[Placement, ...]

    # tensor_meta will only be set during sharding propagation
    tensor_meta: Optional[TensorMeta] = None

    def __post_init__(self) -> None:
        if not isinstance(self.placements, tuple):
            self.placements = tuple(self.placements)
        self._hash: Optional[int] = None

    def __setattr__(self, attr: str, value: Any) -> None:
        super().__setattr__(attr, value)
        # Make sure to recompute the hash in case any of the hashed attributes
        # change (though we do not expect `mesh` or `placements` to change).
        if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"):
            self._hash = None

    def _hash_impl(self) -> int:
        # Hashing and equality checks for DTensorSpec are used to cache the
        # sharding propagation results. We only need to consider the mesh,
        # placements, shape, dtype, and stride.
        # Caveat: we need to keep this in mind and keep hash and eq in sync
        # if we add more fields to them.
        if self.tensor_meta is not None:
            return hash(
                (
                    self.mesh,
                    self.placements,
                    self.tensor_meta.shape,
                    self.tensor_meta.stride,
                    self.tensor_meta.dtype,
                )
            )
        return hash((self.mesh, self.placements))

    def __hash__(self) -> int:
        # We lazily cache the hash of the spec to avoid recomputing it upon
        # each use, and we make sure to invalidate the cached hash when
        # `tensor_meta` changes by overriding `__setattr__`. This must be lazy
        # so that Dynamo does not try to hash non-singleton `SymInt`s for the
        # stride.
        if self._hash is None:
            self._hash = self._hash_impl()
        return self._hash

    def __eq__(self, __o: object) -> bool:
        if not (
            isinstance(__o, DTensorSpec)
            and self.mesh == __o.mesh
            and self.placements == __o.placements
        ):
            return False
        if self.tensor_meta is None or __o.tensor_meta is None:
            return self.tensor_meta == __o.tensor_meta

        return (
            self.tensor_meta.shape == __o.tensor_meta.shape  # type: ignore[union-attr]
            and self.tensor_meta.stride == __o.tensor_meta.stride  # type: ignore[union-attr]
            and self.tensor_meta.dtype == __o.tensor_meta.dtype  # type: ignore[union-attr]
        )

    def __str__(self) -> str:
        """
        Human-readable representation of the DTensorSpec.
        """
        if len(self.placements) == 1:
            placement_str = str(self.placements[0])
        else:
            placement_str = str(self.placements)

        if self.tensor_meta is not None:
            tensor_shape = str(tuple(self.tensor_meta.shape))
        else:
            tensor_shape = "unknown shape"

        return f"Spec({placement_str} on {tensor_shape})"

    @property
    def shape(self) -> torch.Size:
        if self.tensor_meta is None:
            raise ValueError("tensor_meta is not set")
        return self.tensor_meta.shape

    @property
    def stride(self) -> Tuple[int, ...]:
        if self.tensor_meta is None:
            raise ValueError("tensor_meta is not set")
        return self.tensor_meta.stride

    @property
    def ndim(self) -> int:
        if self.tensor_meta is None:
            raise ValueError("tensor_meta is not set")
        return len(self.tensor_meta.shape)

    @property
    def num_shards(self) -> int:
        num_shards = 1
        for i, placement in enumerate(self.placements):
            if placement.is_shard():
                num_shards *= self.mesh.size(i)
        return num_shards

    @property
    def device_mesh(self) -> DeviceMesh:
        # simple alias for the mesh field; makes some checks
        # that mix DTensor/DTensorSpec easier
        return self.mesh

    @property
    def dim_map(self) -> List[int]:
        """
        dim_map is a property we derive from `placements` of
        the distributed tensor. It simply returns a list of ints
        where dim_map[i] denotes which mesh dimension tensor dim i
        is sharded on, and len(dim_map) == dist_tensor.ndim:
        dim_map[i] = -1: means tensor dim i is replicated on the mesh
        dim_map[i] = j: means tensor dim i is sharded on mesh dim j

        For example, given a dist tensor of shape [18, 20, 30], a
        device_mesh([0, 1, 2, 3]), and placements [Shard(1)], the
        dim_map of this placement would be [-1, 0, -1]. This
        representation is helpful during sharding propagation,
        where we need to know exactly whether each tensor dimension
        is sharded or not.

        Note that if placements contains `Partial`, we have to
        explicitly deal with it, so that when we create a DTensorSpec
        with dim_map, we can properly record the pending sums.
        """
        # dims mapping of dist tensor sharding: a list of length tensor ndim,
        # where -1 represents replicate and an int >= 0 represents shard on
        # that device mesh dim
        r = [-1] * self.ndim
        for i, placement in enumerate(self.placements):
            if placement.is_shard():
                shard_dim = cast(Shard, placement).dim
                if r[shard_dim] > -1:
                    raise ValueError(
                        f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]},"
                        " DTensor operator implementation does not support things like hybrid"
                        " sharding strategies yet (i.e. [Shard(0), Shard(0)])"
                    )
                r[shard_dim] = i
        return r

    @property
    def num_shards_map(self) -> List[int]:
        """
        num_shards_map is a property we derive from `placements` of
        the distributed tensor. Unlike `dim_map`, `num_shards_map`
        denotes how many shards each tensor dim has. Like `dim_map`:
        len(num_shards_map) == dist_tensor.ndim
        num_shards_map[i] = 1: means tensor dim i is not sharded
        num_shards_map[i] = j: means tensor dim i has j shards in total

        For example, given a dist tensor of shape [18, 20, 30], a
        device_mesh ([[0, 1, 2, 3], [4, 5, 6, 7]]), and placements
        ([Shard(1), Shard(0)]), the num_shards_map of this distributed
        tensor would be: [4, 2, 1].
        """
        r = [1] * self.ndim
        for i, placement in enumerate(self.placements):
            if placement.is_shard():
                shard_dim = cast(Shard, placement).dim
                r[shard_dim] *= self.mesh.size(i)

        return r

    @property
    def sums(self) -> List[int]:
        """
        sums is a property we derive from `placements` of the
        distributed tensor. It simply returns a list of the mesh dims
        on which the distributed tensor has a pending sum (Partial
        placement).
        """
        return [
            idx
            for idx, placement in enumerate(self.placements)
            if placement.is_partial()
        ]

    @classmethod
    def from_dim_map(
        cls,
        mesh: DeviceMesh,
        dim_map: List[int],
        sums: List[int],
        tensor_meta: Optional[TensorMeta] = None,
    ) -> "DTensorSpec":
        """
        Construct a DTensorSpec from a dim_map list and pending sums.

        Args:
            mesh (:class:`DeviceMesh`): device mesh to be used in the DTensorSpec
            dim_map (List[int]): a list of integers that represents the sharding on
                each tensor dimension; see the `dim_map` property doc for details
            sums (List[int]): a list of device mesh dimensions on which the dist
                tensor has a pending sum
            tensor_meta (TensorMeta): DTensor metadata

        Return:
            a :class:`DTensorSpec` object
        """
        # by default replicate on device mesh dims
        placements: List[Placement] = [Replicate() for _ in range(mesh.ndim)]

        # find all mesh dims that need pending reductions
        for s in sums:
            placements[s] = Partial()

        for i, m in enumerate(dim_map):
            if m >= 0:
                placement = placements[m]
                if placement.is_shard():
                    placement = cast(Shard, placement)
                    raise RuntimeError(
                        f"DeviceMesh dimension can't be mapped to two dimensions of the same tensor: {i} and {placement.dim}"
                    )
                elif placement.is_partial():
                    raise RuntimeError(
                        f"DeviceMesh dimension {m} cannot be both shard and partial!"
                    )
                placements[m] = Shard(i)

        return cls(mesh, tuple(placements), tensor_meta=tensor_meta)

    def is_replicated(self) -> bool:
        """
        Return True if the current DTensorSpec is replicated on all mesh dims (devices).
        """
        return all(placement.is_replicate() for placement in self.placements)

    def is_sharded(self) -> bool:
        """
        Return True if the current DTensorSpec is sharded on any mesh dim (devices).
        """
        return any(placement.is_shard() for placement in self.placements)

    def shallow_copy_with_tensor_meta(
        self, tensor_meta: Optional[TensorMeta]
    ) -> "DTensorSpec":
        """
        Shallow copy the DTensorSpec with a new tensor_meta.
        """
        assert tensor_meta is not None, "shallow copy with no tensor_meta!"
        return DTensorSpec(
            self.mesh,
            self.placements,
            tensor_meta=tensor_meta,
        )
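

# Minimal usage sketch (illustrative only, not exercised by the library): it
# assumes a distributed launch such as `torchrun --nproc-per-node=8` so that a
# 2 x 4 CPU mesh can be created via `init_device_mesh`, and shows how
# `dim_map`, `num_shards_map`, `sums`, and `from_dim_map` relate to `placements`.
if __name__ == "__main__":
    from torch.distributed.device_mesh import init_device_mesh

    # 2 x 4 mesh over 8 ranks: mesh dim 0 has size 2, mesh dim 1 has size 4
    mesh = init_device_mesh("cpu", (2, 4))

    # metadata for a contiguous float32 tensor of shape [18, 20, 30]
    meta = TensorMeta(torch.Size([18, 20, 30]), (600, 30, 1), torch.float32)

    # shard tensor dim 1 over mesh dim 0, and tensor dim 0 over mesh dim 1
    spec = DTensorSpec(mesh, (Shard(1), Shard(0)), tensor_meta=meta)
    assert spec.dim_map == [1, 0, -1]
    assert spec.num_shards_map == [4, 2, 1]
    assert spec.sums == []  # no pending partial reductions

    # dim_map + sums round-trip back to the same placements
    rebuilt = DTensorSpec.from_dim_map(mesh, spec.dim_map, spec.sums, tensor_meta=meta)
    assert rebuilt.placements == spec.placements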