# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from functools import lru_cache
from typing import cast, List, NamedTuple, Tuple

import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._api as dtensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
    Partial,
    Placement,
    Replicate,
    Shard,
)


logger = logging.getLogger(__name__)


class _TransformInfo(NamedTuple):
    mesh_dim: int
    src_dst_placements: Tuple[Placement, Placement]
    # logical_shape on this mesh dimension
    logical_shape: List[int]


@lru_cache(maxsize=None)
def _gen_transform_infos(
    src_spec: DTensorSpec,
    dst_spec: DTensorSpec,
) -> List[_TransformInfo]:
    """
    Generate the transform infos from the source placements to the target placements.

    Transforming from the source to the target placement may take multiple steps, e.g.
    Si -> Sj may be decomposed into Si -> R -> Sj.
    This also detects misaligned/nested shardings between the src/dst placements.
    E.g. suppose the redistribution to perform is (Shard(0), Shard(0)) -> (Replicate(), Shard(0)):
    in this case Shard(0) -> Shard(0) on mesh dimension 1 actually needs resharding, because
    the former is a nested sharding of tensor dimension 0, which is already sharded on mesh
    dimension 0, whereas the latter is the first sharding on tensor dimension 0.
    """
    transform_infos: List[_TransformInfo] = []

    device_mesh = src_spec.device_mesh
    my_coordinate = device_mesh.get_coordinate()
    assert my_coordinate is not None

    # logical shape records the logical tensor shape on each mesh dimension;
    # this is useful to ensure uneven sharding gets the correct output shape
    initial_logical_shape = list(src_spec.shape)
    mesh_dims_to_logical_shape = [initial_logical_shape]

    if device_mesh.ndim == 1:
        # if device_mesh is 1D, redistribute is a simple direct transformation
        transform_infos.append(
            _TransformInfo(
                mesh_dim=0,
                src_dst_placements=(src_spec.placements[0], dst_spec.placements[0]),
                logical_shape=initial_logical_shape,
            )
        )
        return transform_infos

    # Handle multi-dim device mesh placement redistribution
    # First, we need to build the logical shape for each mesh dim
    # for correctly all-gathering uneven shards on each mesh dim (with dynamic padding)
    for i, (src, dst) in enumerate(zip(src_spec.placements, dst_spec.placements)):
        current_logical_shape = mesh_dims_to_logical_shape[i]
        if isinstance(src, Shard):
            if i < device_mesh.ndim - 1:
                # calculate and save the logical shape for this sharding
                mesh_dim_size = device_mesh.size(mesh_dim=i)
                local_shard_size, _ = src._local_shard_size_on_dim(
                    current_logical_shape[src.dim],
                    mesh_dim_size,
                    my_coordinate[i],
                )
                new_logical_shape = list(current_logical_shape)
                new_logical_shape[src.dim] = local_shard_size
                mesh_dims_to_logical_shape.append(new_logical_shape)
        else:
            mesh_dims_to_logical_shape.append(current_logical_shape)

    # Next, we need to derive the transform infos from src to dst placements,
    # here we use a greedy search with step by step state transformations
    current_placements = list(src_spec.placements)
    target_placements = list(dst_spec.placements)

    if src_spec.num_shards > 1:
        # If src_spec has sharding, that sharding could be misaligned with dst_spec;
        # a common case of this is nested sharding (i.e. (S(0), S(0)) -> (R, S(0))).
        # In those cases, we first traverse from the inner placement to the outer placement
        # to detect misaligned shardings and properly replicate the nested sharding first.
        for mesh_dim in reversed(range(len(current_placements))):
            current = current_placements[mesh_dim]
            target = target_placements[mesh_dim]
            # If target is not Shard, we can directly redistribute since we are traversing from inner
            # to outer placements here
            if isinstance(target, Shard):
                # If target is Shard, check for nested sharding on the tensor dim BEFORE the current mesh_dim
                shard_dim = target.dim
                current_mesh_sharding, target_mesh_sharding = [], []
                for i, (s, p) in enumerate(zip(current_placements, target_placements)):
                    if i >= mesh_dim:
                        break
                    if s.is_shard(shard_dim):
                        current_mesh_sharding.append(i)
                    if p.is_shard(shard_dim):
                        target_mesh_sharding.append(i)

                if current_mesh_sharding != target_mesh_sharding:
                    # if current/target_placements have misaligned sharding on the tensor dim BEFORE the current
                    # mesh_dim, we need to replicate the tensor on the mesh dim first to clear the nested sharding
                    target = Replicate()

            if current != target:
                transform_infos.append(
                    _TransformInfo(
                        mesh_dim=mesh_dim,
                        src_dst_placements=(current, target),
                        logical_shape=mesh_dims_to_logical_shape[mesh_dim],
                    )
                )
                current_placements[mesh_dim] = target

    # We always traverse from outer placement to inner placement to collect the remaining
    # needed transform infos (i.e. the replication from nested sharding might need to further
    # perform resharding to Shard again)
    for mesh_dim, (current, target) in enumerate(
        zip(current_placements, target_placements)
    ):
        if current != target:
            transform_infos.append(
                _TransformInfo(
                    mesh_dim=mesh_dim,
                    src_dst_placements=(current, target),
                    logical_shape=mesh_dims_to_logical_shape[mesh_dim],
                )
            )
            current_placements[mesh_dim] = target

    return transform_infos
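

# Illustrative sketch only (not part of the library API and never called by it):
# this shows how a nested-sharding redistribution such as
# (Shard(0), Shard(0)) -> (Replicate(), Shard(0)) decomposes into per-mesh-dim
# transform steps. It assumes an already-initialized 2D DeviceMesh that the
# calling rank participates in; the helper name and the example global shape
# (8, 4) are hypothetical.
def _example_nested_sharding_transform_infos(mesh: DeviceMesh) -> List[_TransformInfo]:
    tensor_meta = TensorMeta(
        shape=torch.Size([8, 4]), stride=(4, 1), dtype=torch.float32
    )
    src = DTensorSpec(mesh, (Shard(0), Shard(0)), tensor_meta=tensor_meta)
    dst = DTensorSpec(mesh, (Replicate(), Shard(0)), tensor_meta=tensor_meta)
    # Expected decomposition: replicate the inner Shard(0) on mesh dim 1 first to
    # undo the nested sharding, then replicate on mesh dim 0, and finally re-shard
    # tensor dim 0 on mesh dim 1.
    return _gen_transform_infos(src, dst)
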

def redistribute_local_tensor(
    local_tensor: torch.Tensor,
    current_spec: DTensorSpec,
    target_spec: DTensorSpec,
    *,
    async_op: bool = False,
    is_backward: bool = False,
) -> torch.Tensor:
    """
    Redistribute the local tensor (torch.Tensor) from the current DTensorSpec to
    the target DTensorSpec, issuing the collective calls necessary to transform
    the local shard of the DTensor from its current spec to the target spec.
    """

    if current_spec.mesh != target_spec.mesh:
        # TODO: alltoall/permute reshuffling to change device_mesh if they are not the same
        raise NotImplementedError("Cross device mesh comm not supported yet!")

    new_local_tensor = None
    device_mesh = current_spec.mesh

    my_coordinate = device_mesh.get_coordinate()

    if my_coordinate is None:
        # if rank is not part of mesh, we skip redistribute and simply return local_tensor,
        # which should be an empty tensor
        return local_tensor

    transform_infos = _gen_transform_infos(current_spec, target_spec)

    for transform_info in transform_infos:
        i = transform_info.mesh_dim
        current, target = transform_info.src_dst_placements
        num_chunks = device_mesh.size(mesh_dim=i)

        if current == target:
            # shortcut, just use the original local tensor
            new_local_tensor = local_tensor
            continue

        logger.debug("redistribute from %s to %s on mesh dim %s", current, target, i)

        if target.is_replicate():
            # Case 1: target is Replicate
            if current.is_partial():
                partial_spec = cast(Partial, current)
                new_local_tensor = partial_spec._reduce_value(
                    local_tensor, device_mesh, i
                )
            elif current.is_shard():
                current_placement = cast(Shard, current)
                new_local_tensor = current_placement._to_replicate_tensor(
                    local_tensor, device_mesh, i, transform_info.logical_shape
                )
            else:
                raise RuntimeError(
                    f"redistribute from {current} to {target} not supported yet"
                )
        elif target.is_shard():
            # Case 2: target is Shard
            target_placement = cast(Shard, target)
            target_dim = target_placement.dim
            if current.is_partial():
                partial_spec = cast(Partial, current)
                new_local_tensor = partial_spec._reduce_shard_value(
                    local_tensor, device_mesh, i, target_placement
                )
            elif current.is_replicate():
                # split the tensor and return the corresponding cloned local shard
                new_local_tensor = target_placement._replicate_to_shard(
                    local_tensor, device_mesh, i, my_coordinate[i]
                )
            else:
                assert (
                    current.is_shard()
                ), f"Current placement should be shard but found {current}"
                shard_spec = cast(Shard, current)
                if shard_spec.dim != target_placement.dim:
                    new_local_tensor = shard_spec._to_new_shard_dim(
                        local_tensor,
                        device_mesh,
                        i,
                        transform_info.logical_shape,
                        target_placement.dim,
                    )
        elif target.is_partial():
            if current.is_replicate():
                partial_spec = cast(Partial, target)
                # Skip the replicate-to-partial transformation in the backward pass and
                # keep the grad replicated. Converting the replicated gradients back to
                # partial would conform to the same logical layout, but it is useless: a
                # reduction would be needed later anyway, which is more expensive than
                # simply keeping the gradient replicated. For this reason, we keep the
                # replicate grad here.
                new_local_tensor = (
                    partial_spec._partition_value(local_tensor, device_mesh, i)
                    if not is_backward
                    else local_tensor
                )
            elif current.is_shard():
                if not is_backward:
                    raise RuntimeError(
                        f"redistribute from {current} to {target} not supported yet"
                    )
                # for backward shard -> partial, we just need to convert the shard to replicate
                current_placement = cast(Shard, current)
                new_local_tensor = current_placement._to_replicate_tensor(
                    local_tensor, device_mesh, i, transform_info.logical_shape
                )
            else:
                # partial -> partial no op, should never hit
                new_local_tensor = local_tensor

        assert new_local_tensor is not None
        local_tensor = new_local_tensor

    assert new_local_tensor is not None, "redistribute failed!"

    if not async_op and isinstance(new_local_tensor, funcol.AsyncCollectiveTensor):
        new_local_tensor = new_local_tensor.wait()

    return new_local_tensor
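

# Illustrative sketch only (not used by the library): redistribute a Shard(0)
# local shard to Replicate() by calling redistribute_local_tensor directly.
# It assumes a 1D DeviceMesh that this rank is part of and a global (logical)
# tensor shape of (8, 4); the helper name and shapes are hypothetical.
def _example_shard_to_replicate(
    mesh: DeviceMesh, local_shard: torch.Tensor
) -> torch.Tensor:
    tensor_meta = TensorMeta(
        shape=torch.Size([8, 4]), stride=(4, 1), dtype=local_shard.dtype
    )
    current_spec = DTensorSpec(mesh, (Shard(0),), tensor_meta=tensor_meta)
    target_spec = DTensorSpec(mesh, (Replicate(),), tensor_meta=tensor_meta)
    # On a 1D mesh this generates a single Shard(0) -> Replicate() transform on
    # mesh dim 0 and all-gathers the local shard (padding uneven shards as needed).
    return redistribute_local_tensor(local_shard, current_spec, target_spec)
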

class Redistribute(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        # pyre-fixme[2]: Parameter must be annotated.
        ctx,
        input: "dtensor.DTensor",
        device_mesh: DeviceMesh,
        placements: Tuple[Placement, ...],
        async_op: bool = False,
    ):
        current_spec = input._spec
        ctx.current_spec = current_spec
        ctx.async_op = async_op

        if current_spec.placements != placements:
            target_spec = DTensorSpec(
                device_mesh, placements, tensor_meta=input._spec.tensor_meta
            )

            local_tensor = input._local_tensor
            output = redistribute_local_tensor(
                local_tensor, current_spec, target_spec, async_op=async_op
            )
        else:
            # use the same local tensor if placements are the same.
            output = input._local_tensor
            target_spec = current_spec

        return dtensor.DTensor(
            output,
            target_spec,
            requires_grad=input.requires_grad,
        )

    @staticmethod
    def backward(ctx, grad_output: "dtensor.DTensor"):  # type: ignore[override]
        previous_spec = ctx.current_spec
        current_spec = grad_output._spec
        async_op = ctx.async_op

        local_tensor = grad_output._local_tensor
        output = redistribute_local_tensor(
            local_tensor,
            current_spec,
            previous_spec,
            async_op=async_op,
            is_backward=True,
        )
        # normalize the target placement to replicate if it is partial
        normalized_placements: List[Placement] = []
        for previous_placement in previous_spec.placements:
            if previous_placement.is_partial():
                # keep target placement to replicate instead of partial in this case
                normalized_placements.append(Replicate())
            else:
                normalized_placements.append(previous_placement)

        spec = DTensorSpec(
            previous_spec.device_mesh,
            tuple(normalized_placements),
            tensor_meta=TensorMeta(
                shape=grad_output.shape,
                stride=grad_output.stride(),
                dtype=grad_output.dtype,
            ),
        )
        output_dtensor = dtensor.DTensor(
            output,
            spec,
            requires_grad=grad_output.requires_grad,
        )

        return (
            output_dtensor,
            None,
            None,
            None,
        )
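

# Illustrative sketch only (not part of this module's API): invoking the
# Redistribute autograd Function directly on an existing DTensor, mirroring what a
# higher-level redistribute entry point would do. It assumes a 1D DeviceMesh; the
# helper name and the chosen placements are hypothetical.
def _example_redistribute_dtensor(
    dt: "dtensor.DTensor", mesh: DeviceMesh
) -> "dtensor.DTensor":
    # Replicate the DTensor over mesh dim 0. Gradients flowing back through this op
    # are redistributed to the original placements by Redistribute.backward, with
    # Partial placements normalized to Replicate.
    return Redistribute.apply(dt, mesh, (Replicate(),))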