# mypy: allow-untyped-defs
import copy
from typing import List, Tuple

import torch
import torch.distributed as dist
import torch.distributed._shard.sharding_spec as shard_spec
from torch._C._distributed_c10d import ProcessGroup
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharding_spec._internals import (
    get_chunked_dim_size,
    get_split_size,
)
from torch.distributed.nn.functional import all_to_all, all_to_all_single

from .shard import Shard


def get_idx_from_placements(placements, current_rank) -> int:
    """
    Return the position of the current rank in the given placements.

    Args:
        placements(List[Union[_remote_device, str]]):
            Specifies the placement of each shard of the Tensor. The size of
            the list represents the number of shards to be created. This could
            be a list of
            :class:`torch.distributed._remote_device`'s. This list
            could also contain a string which represents remote
            device as accepted by
            :class:`torch.distributed._remote_device`
        current_rank (int): rank of the current device.

    Returns:
        An int which is the position of the current device in the placement list.
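
    Example (illustrative sketch added for clarity; assumes placements are
    parsed into :class:`torch.distributed._remote_device` objects)::

        >>> from torch.distributed import _remote_device
        >>> placements = [_remote_device("rank:0/cuda:0"), _remote_device("rank:1/cuda:1")]
        >>> get_idx_from_placements(placements, current_rank=1)
        1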
    """
    for idx, placement in enumerate(placements):  # type: ignore[attr-defined]
        if current_rank == placement.rank():  # type: ignore[union-attr]
            return idx
    raise RuntimeError("current_rank not in the placement.")


def build_reshard_metadata(
    st_size: torch.Size,
    sharding_spec: shard_spec.ShardingSpec,
    world_size: int,
) -> Tuple[List[ShardMetadata], List[int]]:
    """
    Based on the given sharding spec, we calculate the offset and local shard size.
    We then build a ShardMetadata on top of the calculation result.

    Args:
        st_size (torch.Size): The size of the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor is sharded.
        world_size (int): number of ranks.

    Returns:
        A Tuple of the following:
            A List[`ShardMetadata`] which contains the metadata for the shard, including
                offsets, lengths and device placement.
            A List[int] which contains the ranks in the order of placement.
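
    Example (illustrative sketch added for clarity; assumes a
    :class:`ChunkShardingSpec` whose placements are listed in rank order)::

        >>> import torch
        >>> import torch.distributed._shard.sharding_spec as shard_spec
        >>> spec = shard_spec.ChunkShardingSpec(
        ...     dim=0,
        ...     placements=["rank:0/cuda:0", "rank:1/cuda:1", "rank:2/cuda:2", "rank:3/cuda:3"],
        ... )
        >>> metadata, ranks = build_reshard_metadata(torch.Size([16, 4]), spec, world_size=4)
        >>> [m.shard_offsets for m in metadata]
        [[0, 0], [4, 0], [8, 0], [12, 0]]
        >>> ranks
        [0, 1, 2, 3]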
    """
    shard_dim = int(sharding_spec.dim)  # type: ignore[attr-defined]
    shards_metadata = [None] * world_size
    ranks = []
    offsets = [0] * len(st_size)
    split_size = get_split_size(st_size[shard_dim], world_size)
    for idx, placement in enumerate(sharding_spec.placements):  # type: ignore[attr-defined]
        ranks.append(placement.rank())
        sharded_dim_size = get_chunked_dim_size(st_size[shard_dim], split_size, idx)
        local_tensor_size = list(st_size)
        local_tensor_size[shard_dim] = sharded_dim_size
        shards_metadata[placement.rank()] = ShardMetadata(  # type: ignore[call-overload]
            shard_offsets=copy.deepcopy(offsets),
            shard_sizes=local_tensor_size,
            placement=placement,
        )
        offsets[shard_dim] += sharded_dim_size
    return shards_metadata, ranks  # type: ignore[return-value]


def reshuffle_local_shard(
    local_shard: torch.Tensor,
    st_size: torch.Size,
    sharding_spec: shard_spec.ShardingSpec,
    resharding_spec: shard_spec.ShardingSpec,
    pg: ProcessGroup,
) -> Tuple[List[Shard], List[ShardMetadata]]:
    """
    Reshuffle the local shard directly when the reshard dim is the same as the original
    sharding dim. Logically we do this in two steps:
    1. Collect all shards based on the original sharding spec.
    2. Reshard the tensor based on the given resharding spec.

    In reality, we consolidate the two steps into one by sending each local tensor
    directly to its new shard based on the resharding spec.

    Args:
        local_shard (Tensor): Local tensor stored on the current rank.
        st_size (torch.Size): The size of the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor is sharded originally.
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor will be resharded.
        pg (ProcessGroup): The process group to aggregate on.

    Returns:
        A Tuple of the following:
            A List[`Shard`] which contains the local tensor and its metadata.
            A List[`ShardMetadata`] which contains the metadata for the shard, including
                offsets, lengths and device placement.
    """
    current_rank = dist.get_rank(pg)
    world_size = dist.get_world_size(pg)
    # Build shards_metadata first.
    shards_metadata, ranks = build_reshard_metadata(
        st_size, resharding_spec, world_size
    )
    # Get input split size for all2all.
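    # Since the shard dim is unchanged, each rank keeps a single contiguous block:
    # it sends its whole local shard to exactly one destination rank (the rank that
    # occupies the same placement position in the resharding spec) and receives its
    # new shard as a single block in return.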
    reshard_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
    split_size = get_split_size(st_size[reshard_dim], world_size)
    input_split_sizes = [0] * world_size
    idx = get_idx_from_placements(sharding_spec.placements, current_rank)  # type: ignore[attr-defined]
    new_rank = resharding_spec.placements[idx].rank()  # type: ignore[union-attr, attr-defined]
    input_split_sizes[new_rank] = local_shard.size(reshard_dim)
    # Get output split size for all2all.
    output_split_sizes = [0] * world_size
    new_idx = ranks.index(current_rank)
    sharded_dim_size = get_chunked_dim_size(st_size[reshard_dim], split_size, new_idx)
    output_split_sizes[new_rank] = sharded_dim_size
    # Get gathered_input for all2all.
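    # all_to_all_single splits the input and concatenates the output along dim 0,
    # so move reshard_dim to the front here and transpose back after the collective.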
    local_shard = local_shard.transpose(0, reshard_dim).contiguous()
    gathered_input_size = list(local_shard.size())
    gathered_input_size[0] = sharded_dim_size
    gathered_input = torch.empty(
        gathered_input_size, device=local_shard.device, dtype=local_shard.dtype
    )
    # all2all.
    local_shard = all_to_all_single(
        gathered_input,
        local_shard,
        input_split_sizes=input_split_sizes,
        output_split_sizes=output_split_sizes,
        group=pg,
    )
    local_tensor = local_shard.transpose(0, reshard_dim).contiguous()
    local_shards = [Shard(local_tensor, shards_metadata[current_rank])]
    return local_shards, shards_metadata


def reshard_local_shard(
    local_tensor: torch.Tensor,
    st_size: torch.Size,
    sharding_spec: shard_spec.ShardingSpec,
    resharding_spec: shard_spec.ShardingSpec,
    pg: ProcessGroup,
) -> Tuple[List[Shard], List[ShardMetadata]]:
    """
    Reshard a sharded tensor given the ``resharding_spec``. When the reshard dim is
    different from the original sharding dim, we logically need two steps:
    1. Collect all shards based on the original sharding spec.
    2. Reshard the tensor based on the given resharding spec.

    In reality, we consolidate the two steps into one by sending each rank its new
    shard directly based on the resharding spec.

    Args:
        local_tensor (Tensor): Local tensor stored on the current rank.
        st_size (torch.Size): The size of the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor is sharded originally.
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor will be resharded.
        pg (ProcessGroup): The process group to aggregate on.

    Returns:
        A Tuple of the following:
            A List[`Shard`] which contains the local tensor and its metadata.
            A List[`ShardMetadata`] which contains the metadata for the shard, including
                offsets, lengths and device placement.
    """
    current_rank = dist.get_rank(pg)
    world_size = dist.get_world_size(pg)
    current_sharding_dim = int(sharding_spec.dim)  # type: ignore[attr-defined]
    reshard_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]

    # Build shards_metadata first.
    shards_metadata, ranks = build_reshard_metadata(
        st_size, resharding_spec, world_size
    )

    # Compute the expected input split sizes for all2all: the chunk destined for
    # rank i has rank i's new shard size along the reshard dim.
    input_split_sizes = []
    for metadata in shards_metadata:
        input_split_sizes.append(metadata.shard_sizes[reshard_dim])
    rearrange_input = any(ranks[i] > ranks[i + 1] for i in range(len(ranks) - 1))

    if rearrange_input:
        # Need to re-arrange reshard_dim of local_tensor before all2all.
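        # all_to_all sends the i-th input chunk to rank i, so reorder the data along
        # reshard_dim from placement order into ascending rank order before splitting.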
        indices: List[int] = []
        for metadata in shards_metadata:
            offset_start_idx = metadata.shard_offsets[reshard_dim]
            split_size = metadata.shard_sizes[reshard_dim]
            indices += range(offset_start_idx, offset_start_idx + split_size)
        local_tensor = local_tensor.index_select(
            reshard_dim, torch.tensor(indices, device=local_tensor.device)
        )

    # Because reshard_dim != original shard_dim, we need to compute the
    # size of the tensor received from each rank.
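    # Each received chunk has the sender's original shard size along the old shard
    # dim and this rank's new shard size along the reshard dim.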
    output_tensor_list = [torch.tensor(1)] * world_size
    split_size = get_split_size(st_size[current_sharding_dim], world_size)
    rearrange_output_list = False
    indices = []
    for idx, placement in enumerate(sharding_spec.placements):  # type: ignore[attr-defined]
        sharded_dim_size = get_chunked_dim_size(
            st_size[current_sharding_dim], split_size, idx
        )
        output_tensor_size = list(st_size)
        output_tensor_size[current_sharding_dim] = sharded_dim_size
        output_tensor_size[reshard_dim] = input_split_sizes[current_rank]
        output_tensor_list[
            placement.rank()
        ] = torch.empty(  # type: ignore[union-attr, index]
            output_tensor_size, device=local_tensor.device, dtype=local_tensor.dtype
        )
        indices.append(placement.rank())  # type: ignore[union-attr, index, arg-type]
        if idx != placement.rank():  # type: ignore[union-attr]
            rearrange_output_list = True

    # Perform autograd enabled all2all.
    input_tensor_tuple = torch.split(local_tensor, input_split_sizes, dim=reshard_dim)
    input_tensor_list = [tensor.contiguous() for tensor in input_tensor_tuple]
    output_tensor_list = all_to_all(
        output_tensor_list,
        input_tensor_list,
        group=pg,
    )

    if rearrange_output_list:
        # Need to re-arrange original shard_dim of output_tensor_list.
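        # all_to_all returns chunks indexed by sender rank; reorder them into
        # placement order so the cat along the original shard dim below stitches
        # the slices back together in the right order.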
        output_tensor_list = [output_tensor_list[idx] for idx in indices]  # type: ignore[call-overload]
    local_tensor = torch.cat(output_tensor_list, dim=current_sharding_dim)
    local_shards = [Shard(local_tensor, shards_metadata[current_rank])]
    return local_shards, shards_metadata
