# mypy: allow-untyped-defs
import copy
from typing import List, Tuple

import torch
import torch.distributed as dist
import torch.distributed._shard.sharding_spec as shard_spec
from torch._C._distributed_c10d import ProcessGroup
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharding_spec._internals import (
    get_chunked_dim_size,
    get_split_size,
)
from torch.distributed.nn.functional import all_to_all, all_to_all_single

from .shard import Shard


def get_idx_from_placements(placements, current_rank) -> int:
    """
    Return the position of the current rank in the given placements.

    Args:
        placements(List[Union[_remote_device, str]]):
            Specifies the placement of each shard of the Tensor. The size of
            the list represents the number of shards to be created. This could
            be a list of
            :class:`torch.distributed._remote_device`'s. This list
            could also contain a string which represents remote
            device as accepted by
            :class:`torch.distributed._remote_device`
        current_rank (int): rank of the current device.

    Returns:
        An int which is the position of the current device in the placement list.
    """
    for idx, placement in enumerate(placements):  # type: ignore[attr-defined]
        if current_rank == placement.rank():  # type: ignore[union-attr]
            return idx
    raise RuntimeError("current_rank not in the placement.")


def build_reshard_metadata(
    st_size: torch.Size,
    sharding_spec: shard_spec.ShardingSpec,
    world_size: int,
) -> Tuple[List[ShardMetadata], List[int]]:
    """
    Based on the given sharding spec, we calculate the offset and local shard size.
    We then build a ShardMetadata on top of the calculation result.

    Args:
        st_size (torch.Size): The size of the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor is sharded.
        world_size (int): number of ranks.

    Returns:
        A Tuple of the following:
            A List[`ShardMetadata`] which contains the metadata for each shard, including
                offsets, lengths and device placement.
            A List[int] which contains the ranks in the order of placement.
    """
    shard_dim = int(sharding_spec.dim)  # type: ignore[attr-defined]
    shards_metadata = [None] * world_size
    ranks = []
    offsets = [0] * len(st_size)
    split_size = get_split_size(st_size[shard_dim], world_size)
    for idx, placement in enumerate(sharding_spec.placements):  # type: ignore[attr-defined]
        ranks.append(placement.rank())
        sharded_dim_size = get_chunked_dim_size(st_size[shard_dim], split_size, idx)
        local_tensor_size = list(st_size)
        local_tensor_size[shard_dim] = sharded_dim_size
        shards_metadata[placement.rank()] = ShardMetadata(  # type: ignore[call-overload]
            shard_offsets=copy.deepcopy(offsets),
            shard_sizes=local_tensor_size,
            placement=placement,
        )
        offsets[shard_dim] += sharded_dim_size
    return shards_metadata, ranks  # type: ignore[return-value]


def reshuffle_local_shard(
    local_shard: torch.Tensor,
    st_size: torch.Size,
    sharding_spec: shard_spec.ShardingSpec,
    resharding_spec: shard_spec.ShardingSpec,
    pg: ProcessGroup,
) -> Tuple[List[Shard], List[ShardMetadata]]:
    """
    Reshuffle the local shard directly when the reshard dim is the same as the
    original sharding dim. Logically, we do this in two steps:
    1. Collect all shards based on the original sharding spec.
    2. Reshard the tensor based on the given resharding spec.

    In reality, we consolidate the two steps into one by sending the local tensor to
    the new shard directly based on the resharding spec.

    Args:
        local_shard (Tensor): Local tensor stored in the current rank.
        st_size (torch.Size): The size of the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor is sharded originally.
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor will be resharded.
        pg (ProcessGroup): The process group to aggregate on.

    Returns:
        A Tuple of the following:
            A List[`Shard`] which contains the local tensor and its metadata.
            A List[`ShardMetadata`] which contains the metadata for each shard, including
                offsets, lengths and device placement.
    """
    current_rank = dist.get_rank(pg)
    world_size = dist.get_world_size(pg)
    # Build shards_metadata first.
    shards_metadata, ranks = build_reshard_metadata(
        st_size, resharding_spec, world_size
    )
    # Get input split size for all2all.
    reshard_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
    split_size = get_split_size(st_size[reshard_dim], world_size)
    input_split_sizes = [0] * world_size
    idx = get_idx_from_placements(sharding_spec.placements, current_rank)  # type: ignore[attr-defined]
    new_rank = resharding_spec.placements[idx].rank()  # type: ignore[union-attr, attr-defined]
    input_split_sizes[new_rank] = local_shard.size(reshard_dim)
    # Get output split size for all2all.
    output_split_sizes = [0] * world_size
    new_idx = ranks.index(current_rank)
    sharded_dim_size = get_chunked_dim_size(st_size[reshard_dim], split_size, new_idx)
    output_split_sizes[new_rank] = sharded_dim_size
    # Get gathered_input for all2all.
    local_shard = local_shard.transpose(0, reshard_dim).contiguous()
    gathered_input_size = list(local_shard.size())
    gathered_input_size[0] = sharded_dim_size
    gathered_input = torch.empty(
        gathered_input_size, device=local_shard.device, dtype=local_shard.dtype
    )
    # all2all.
    local_shard = all_to_all_single(
        gathered_input,
        local_shard,
        input_split_sizes=input_split_sizes,
        output_split_sizes=output_split_sizes,
        group=pg,
    )
    local_tensor = local_shard.transpose(0, reshard_dim).contiguous()
    local_shards = [Shard(local_tensor, shards_metadata[current_rank])]
    return local_shards, shards_metadata


def reshard_local_shard(
    local_tensor: torch.Tensor,
    st_size: torch.Size,
    sharding_spec: shard_spec.ShardingSpec,
    resharding_spec: shard_spec.ShardingSpec,
    pg: ProcessGroup,
) -> Tuple[List[Shard], List[ShardMetadata]]:
    """
    Reshard a sharded tensor given the ``resharding_spec``. When the reshard dim is
    different from the original sharding dim, we need to do two steps logically:
    1. Collect all shards based on the original sharding spec.
    2. Reshard the tensor based on the given resharding spec.

    In reality, we consolidate the two steps into one by sending each rank the new
    shard based on the resharding spec.

    Args:
        local_tensor (Tensor): Local tensor stored in the current rank.
        st_size (torch.Size): The size of the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor is sharded originally.
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor will be resharded.
        pg (ProcessGroup): The process group to aggregate on.

    Returns:
        A Tuple of the following:
            A List[`Shard`] which contains the local tensor and its metadata.
            A List[`ShardMetadata`] which contains the metadata for each shard, including
                offsets, lengths and device placement.
    """
    current_rank = dist.get_rank(pg)
    world_size = dist.get_world_size(pg)
    current_sharding_dim = int(sharding_spec.dim)  # type: ignore[attr-defined]
    reshard_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]

    # Build shards_metadata first.
    shards_metadata, ranks = build_reshard_metadata(
        st_size, resharding_spec, world_size
    )

    # Compute the input split sizes along the reshard dim for all2all.
    input_split_sizes = []
    for metadata in shards_metadata:
        input_split_sizes.append(metadata.shard_sizes[reshard_dim])
    rearrange_input = any(ranks[i] > ranks[i + 1] for i in range(len(ranks) - 1))

    if rearrange_input:
        # Need to re-arrange reshard_dim of local_tensor before all2all.
        indices: List[int] = []
        for metadata in shards_metadata:
            offset_start_idx = metadata.shard_offsets[reshard_dim]
            split_size = metadata.shard_sizes[reshard_dim]
            indices += range(offset_start_idx, offset_start_idx + split_size)
        local_tensor = local_tensor.index_select(
            reshard_dim, torch.tensor(indices, device=local_tensor.device)
        )

    # Because reshard_dim != original shard_dim, we need to compute the
    # size of the tensor received from each rank.
    output_tensor_list = [torch.tensor(1)] * world_size
    split_size = get_split_size(st_size[current_sharding_dim], world_size)
    rearrange_output_list = False
    indices = []
    for idx, placement in enumerate(sharding_spec.placements):  # type: ignore[attr-defined]
        sharded_dim_size = get_chunked_dim_size(
            st_size[current_sharding_dim], split_size, idx
        )
        output_tensor_size = list(st_size)
        output_tensor_size[current_sharding_dim] = sharded_dim_size
        output_tensor_size[reshard_dim] = input_split_sizes[current_rank]
        output_tensor_list[placement.rank()] = torch.empty(  # type: ignore[union-attr, index]
            output_tensor_size, device=local_tensor.device, dtype=local_tensor.dtype
        )
        indices.append(placement.rank())  # type: ignore[union-attr, index, arg-type]
        if idx != placement.rank():  # type: ignore[union-attr]
            rearrange_output_list = True

    # Perform autograd-enabled all2all.
    input_tensor_tuple = torch.split(local_tensor, input_split_sizes, dim=reshard_dim)
    input_tensor_list = [tensor.contiguous() for tensor in input_tensor_tuple]
    output_tensor_list = all_to_all(
        output_tensor_list,
        input_tensor_list,
        group=pg,
    )

    if rearrange_output_list:
        # Need to re-arrange original shard_dim of output_tensor_list.
        output_tensor_list = [output_tensor_list[idx] for idx in indices]  # type: ignore[call-overload]
    local_tensor = torch.cat(output_tensor_list, dim=current_sharding_dim)
    local_shards = [Shard(local_tensor, shards_metadata[current_rank])]
    return local_shards, shards_metadata
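

# ---------------------------------------------------------------------------
# Illustrative sketch (comments only): one plausible way the helpers above are
# driven by a higher-level caller (e.g. ShardedTensor.reshard()). The 2-rank
# layout, tensor size, and ChunkShardingSpec placements below are assumptions
# made for illustration, not part of this module's contract.
#
#   src_spec = shard_spec.ChunkShardingSpec(
#       dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"]
#   )
#   dst_spec = shard_spec.ChunkShardingSpec(
#       dim=1, placements=["rank:0/cuda:0", "rank:1/cuda:1"]
#   )
#   # Each rank holds a [4, 8] chunk of a global [8, 8] tensor sharded on dim 0.
#   # Since the sharding dim changes (0 -> 1), reshard_local_shard applies;
#   # reshuffle_local_shard would be used if the dim stayed the same and only
#   # the placements changed.
#   local_shards, shards_metadata = reshard_local_shard(
#       local_tensor,        # this rank's [4, 8] chunk
#       torch.Size([8, 8]),  # global size of the sharded tensor
#       src_spec,
#       dst_spec,
#       pg,                  # process group covering ranks 0 and 1
#   )
#   # local_shards[0].tensor is now this rank's [8, 4] chunk, sharded on dim 1.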