# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from functools import lru_cache
from typing import cast, List, NamedTuple, Tuple

import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._api as dtensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
    Partial,
    Placement,
    Replicate,
    Shard,
)


logger = logging.getLogger(__name__)


class _TransformInfo(NamedTuple):
    mesh_dim: int
    src_dst_placements: Tuple[Placement, Placement]
    # logical_shape on this mesh dimension
    logical_shape: List[int]


@lru_cache(maxsize=None)
def _gen_transform_infos(
    src_spec: DTensorSpec,
    dst_spec: DTensorSpec,
) -> List[_TransformInfo]:
    """
    Generate the transform infos from the source placements to the target placements.

    Transforming from the source to the target placement may take multiple steps, i.e.
    it might decompose Si -> Sj into Si -> R -> Sj.
    This also detects misaligned/nested shardings between the src/dst placements.
    E.g. suppose the redistribution to perform is (Shard(0), Shard(0)) -> (Replicate(), Shard(0)):
    the Shard(0) -> Shard(0) step on mesh dimension 1 actually needs resharding, because the
    former is a nested sharding of a tensor that is already sharded on tensor dimension 0,
    whereas the latter is the first sharding on tensor dimension 0.
    """
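    # Illustrative sketch of the docstring's example, assuming a 2-D device mesh:
    # (Shard(0), Shard(0)) -> (Replicate(), Shard(0)) decomposes into three steps:
    #   mesh dim 1: Shard(0)    -> Replicate()  # clear the nested sharding first
    #   mesh dim 0: Shard(0)    -> Replicate()
    #   mesh dim 1: Replicate() -> Shard(0)     # re-shard once tensor dim 0 is whole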
    transform_infos: List[_TransformInfo] = []

    device_mesh = src_spec.device_mesh
    my_coordinate = device_mesh.get_coordinate()
    assert my_coordinate is not None

    # logical shape records the logical tensor shape on each mesh dimension;
    # this is useful to ensure uneven sharding gets the correct output shape
    initial_logical_shape = list(src_spec.shape)
    mesh_dims_to_logical_shape = [initial_logical_shape]

    if device_mesh.ndim == 1:
        # if device_mesh is 1D, redistribute is a simple direct transformation
        transform_infos.append(
            _TransformInfo(
                mesh_dim=0,
                src_dst_placements=(src_spec.placements[0], dst_spec.placements[0]),
                logical_shape=initial_logical_shape,
            )
        )
        return transform_infos

    # Handle multi-dim device mesh placement redistribution
    # First, we need to build the logical shape for each mesh dim
    # so that uneven shards are all-gathered correctly on each mesh dim (with dynamic padding)
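    # Worked example (illustrative), assuming a (2, 2) mesh, global shape [9], and
    # src placements (Shard(0), Shard(0)): for mesh dim 0, the size-9 dim chunks into
    # 5 and 4, so the logical shape recorded for mesh dim 1 is [5] on ranks with
    # coordinate 0 along mesh dim 0 and [4] on ranks with coordinate 1.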
    for i, (src, dst) in enumerate(zip(src_spec.placements, dst_spec.placements)):
        current_logical_shape = mesh_dims_to_logical_shape[i]
        if isinstance(src, Shard):
            if i < device_mesh.ndim - 1:
                # calculate and save the logical shape for this sharding
                mesh_dim_size = device_mesh.size(mesh_dim=i)
                local_shard_size, _ = src._local_shard_size_on_dim(
                    current_logical_shape[src.dim],
                    mesh_dim_size,
                    my_coordinate[i],
                )
                new_logical_shape = list(current_logical_shape)
                new_logical_shape[src.dim] = local_shard_size
                mesh_dims_to_logical_shape.append(new_logical_shape)
        else:
            mesh_dims_to_logical_shape.append(current_logical_shape)

    # Next, we need to derive the transform infos from src to dst placements,
    # using a greedy search with step-by-step state transformations
    current_placements = list(src_spec.placements)
    target_placements = list(dst_spec.placements)

    if src_spec.num_shards > 1:
        # If src_spec has sharding, it could potentially be misaligned with dst_spec;
        # a common case of this is nested sharding (i.e. (S(0), S(0)) -> (R, S(0))).
        # In those cases, we first traverse from the inner placement to the outer placement
        # to detect misaligned shardings and properly replicate nested shardings first.
        for mesh_dim in reversed(range(len(current_placements))):
            current = current_placements[mesh_dim]
            target = target_placements[mesh_dim]
            # If target is not Shard, we can directly redistribute since we are traversing from inner
            # to outer placements here
            if isinstance(target, Shard):
                # If target is Shard, check for nested sharding on the tensor dim BEFORE the current mesh_dim
                shard_dim = target.dim
                current_mesh_sharding, target_mesh_sharding = [], []
                for i, (s, p) in enumerate(zip(current_placements, target_placements)):
                    if i >= mesh_dim:
                        break
                    if s.is_shard(shard_dim):
                        current_mesh_sharding.append(i)
                    if p.is_shard(shard_dim):
                        target_mesh_sharding.append(i)

                if current_mesh_sharding != target_mesh_sharding:
                    # if current/target_placements have misaligned sharding on the tensor dim BEFORE the current
                    # mesh_dim, we need to replicate the tensor on the mesh dim first to clear the nested sharding
                    target = Replicate()

            if current != target:
                transform_infos.append(
                    _TransformInfo(
                        mesh_dim=mesh_dim,
                        src_dst_placements=(current, target),
                        logical_shape=mesh_dims_to_logical_shape[mesh_dim],
                    )
                )
                current_placements[mesh_dim] = target

    # We always traverse from outer placement to inner placement to collect the remaining
    # needed transform infos (i.e. the replication from nested sharding might need to further
    # perform resharding to Shard again)
    for mesh_dim, (current, target) in enumerate(
        zip(current_placements, target_placements)
    ):
        if current != target:
            transform_infos.append(
                _TransformInfo(
                    mesh_dim=mesh_dim,
                    src_dst_placements=(current, target),
                    logical_shape=mesh_dims_to_logical_shape[mesh_dim],
                )
            )
            current_placements[mesh_dim] = target

    return transform_infos


def redistribute_local_tensor(
    local_tensor: torch.Tensor,
    current_spec: DTensorSpec,
    target_spec: DTensorSpec,
    *,
    async_op: bool = False,
    is_backward: bool = False,
) -> torch.Tensor:
    """
    This redistributes the local tensor (torch.Tensor) from the current DTensorSpec to
    the target DTensorSpec, issuing the collective calls needed to transform the local
    shard of the DTensor from its current spec to the target spec.
    """
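    # Illustrative example, assuming a 1-D mesh of size 4 and a global shape of (8, 8):
    # redistributing from Shard(0) to Replicate() all-gathers the four (2, 8) local
    # shards along tensor dim 0, so every rank ends up with the full (8, 8) tensor.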

    if current_spec.mesh != target_spec.mesh:
        # TODO: alltoall/permute reshuffling to change device_mesh if they are not the same
        raise NotImplementedError("Cross device mesh comm not supported yet!")

    new_local_tensor = None
    device_mesh = current_spec.mesh

    my_coordinate = device_mesh.get_coordinate()

    if my_coordinate is None:
        # if rank is not part of mesh, we skip redistribute and simply return local_tensor,
        # which should be an empty tensor
        return local_tensor

    transform_infos = _gen_transform_infos(current_spec, target_spec)

    for transform_info in transform_infos:
        i = transform_info.mesh_dim
        current, target = transform_info.src_dst_placements
        num_chunks = device_mesh.size(mesh_dim=i)

        if current == target:
            # shortcut, just use the original local tensor
            new_local_tensor = local_tensor
            continue

        logger.debug("redistribute from %s to %s on mesh dim %s", current, target, i)

        if target.is_replicate():
            # Case 1: target is Replicate
            if current.is_partial():
                partial_spec = cast(Partial, current)
                new_local_tensor = partial_spec._reduce_value(
                    local_tensor, device_mesh, i
                )
            elif current.is_shard():
                current_placement = cast(Shard, current)
                new_local_tensor = current_placement._to_replicate_tensor(
                    local_tensor, device_mesh, i, transform_info.logical_shape
                )
            else:
                raise RuntimeError(
                    f"redistribute from {current} to {target} not supported yet"
                )
        elif target.is_shard():
            # Case 2: target is Shard
            target_placement = cast(Shard, target)
            target_dim = target_placement.dim
            if current.is_partial():
                partial_spec = cast(Partial, current)
                new_local_tensor = partial_spec._reduce_shard_value(
                    local_tensor, device_mesh, i, target_placement
                )
            elif current.is_replicate():
                # split the tensor and return the corresponding cloned local shard
                new_local_tensor = target_placement._replicate_to_shard(
                    local_tensor, device_mesh, i, my_coordinate[i]
                )
            else:
                assert (
                    current.is_shard()
                ), f"Current placement should be shard but found {current}"
                shard_spec = cast(Shard, current)
                if shard_spec.dim != target_placement.dim:
                    new_local_tensor = shard_spec._to_new_shard_dim(
                        local_tensor,
                        device_mesh,
                        i,
                        transform_info.logical_shape,
                        target_placement.dim,
                    )
        elif target.is_partial():
            if current.is_replicate():
                partial_spec = cast(Partial, target)
                # Skip the replicate -> partial transformation when in the backward pass:
                # we keep the grad replicated because converting the replicated gradients
                # back to partial, although it logically conforms to the same layout, is
                # actually useless, as you would have to reduce them later anyway, which
                # would be more expensive than keeping them replicated. For this reason,
                # we keep the replicated grad here.
                new_local_tensor = (
                    partial_spec._partition_value(local_tensor, device_mesh, i)
                    if not is_backward
                    else local_tensor
                )
            elif current.is_shard():
                if not is_backward:
                    raise RuntimeError(
                        f"redistribute from {current} to {target} not supported yet"
                    )
                # for backward shard -> partial, we just need to convert the shard to replicate
                current_placement = cast(Shard, current)
                new_local_tensor = current_placement._to_replicate_tensor(
                    local_tensor, device_mesh, i, transform_info.logical_shape
                )
            else:
                # partial -> partial no-op, should never be hit
                new_local_tensor = local_tensor

        assert new_local_tensor is not None
        local_tensor = new_local_tensor

    assert new_local_tensor is not None, "redistribute failed!"

    if not async_op and isinstance(new_local_tensor, funcol.AsyncCollectiveTensor):
        new_local_tensor = new_local_tensor.wait()

    return new_local_tensor


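# Example usage through the public DTensor API (an illustrative sketch, assuming an
# already-initialized process group, e.g. launched via torchrun across 4 ranks):
#
#   from torch.distributed.device_mesh import init_device_mesh
#   from torch.distributed.tensor import distribute_tensor
#   from torch.distributed.tensor.placement_types import Replicate, Shard
#
#   mesh = init_device_mesh("cuda", (4,))
#   dt = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
#   # DTensor.redistribute dispatches to Redistribute.apply below, which in turn
#   # calls redistribute_local_tensor on each rank's local shard.
#   replicated = dt.redistribute(mesh, [Replicate()])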
class Redistribute(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        # pyre-fixme[2]: Parameter must be annotated.
        ctx,
        input: "dtensor.DTensor",
        device_mesh: DeviceMesh,
        placements: Tuple[Placement, ...],
        async_op: bool = False,
    ):
        current_spec = input._spec
        ctx.current_spec = current_spec
        ctx.async_op = async_op

        if current_spec.placements != placements:
            target_spec = DTensorSpec(
                device_mesh, placements, tensor_meta=input._spec.tensor_meta
            )

            local_tensor = input._local_tensor
            output = redistribute_local_tensor(
                local_tensor, current_spec, target_spec, async_op=async_op
            )
        else:
            # use the same local tensor if placements are the same.
            output = input._local_tensor
            target_spec = current_spec

        return dtensor.DTensor(
            output,
            target_spec,
            requires_grad=input.requires_grad,
        )

    @staticmethod
    def backward(ctx, grad_output: "dtensor.DTensor"):  # type: ignore[override]
        previous_spec = ctx.current_spec
        current_spec = grad_output._spec
        async_op = ctx.async_op

        local_tensor = grad_output._local_tensor
        output = redistribute_local_tensor(
            local_tensor,
            current_spec,
            previous_spec,
            async_op=async_op,
            is_backward=True,
        )
        # normalize the target placement to replicate if it is partial
        normalized_placements: List[Placement] = []
        for previous_placement in previous_spec.placements:
            if previous_placement.is_partial():
                # keep the target placement as replicate instead of partial in this case
                normalized_placements.append(Replicate())
            else:
                normalized_placements.append(previous_placement)

        spec = DTensorSpec(
            previous_spec.device_mesh,
            tuple(normalized_placements),
            tensor_meta=TensorMeta(
                shape=grad_output.shape,
                stride=grad_output.stride(),
                dtype=grad_output.dtype,
            ),
        )
        output_dtensor = dtensor.DTensor(
            output,
            spec,
            requires_grad=grad_output.requires_grad,
        )

        return (
            output_dtensor,
            None,
            None,
            None,
        )
352