# mypy: allow-untyped-defs
import logging
import math
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional

import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._dtensor_spec as dtensor_spec
from torch._C._distributed_c10d import _resolve_process_group
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.distributed_c10d import (
    _get_group_size_by_name,
    broadcast,
    get_global_rank,
    get_group_rank,
    get_rank,
    GroupMember,
    ProcessGroup,
    scatter,
    Work,
)


logger = logging.getLogger(__name__)


if not torch._running_with_deploy():

    @torch.library.register_fake("_dtensor::shard_dim_alltoall")
    def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name):
        group_size = _get_group_size_by_name(group_name)
        stacked_list = [torch.empty_like(input) for _ in range(group_size)]
        group = _resolve_process_group(group_name)
        group_rank = get_group_rank(group, get_rank())

        return torch.cat(stacked_list, dim=gather_dim).chunk(group_size, dim=shard_dim)[
            group_rank
        ]

else:
    import warnings

    warnings.warn(
        "PyTorch Distributed functional collectives do not work with torch::deploy."
    )


def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim):
    if mesh.device_type == "cpu":
        # Gloo does not support alltoall, so fall back to allgather + chunk.

        # TODO: This logs way too much
        logger.warning(
            "CPU process group does not support alltoall yet, falling back to allgather + chunk!"
        )
        out = funcol.all_gather_tensor(input, gather_dim, (mesh, mesh_dim))
        if isinstance(out, funcol.AsyncCollectiveTensor):
            # keep the same behavior as the alltoall case; remove this once alltoall is async-enabled
            out = out.wait()
        out = torch.chunk(out, mesh.size(mesh_dim), dim=shard_dim)[
            mesh.get_local_rank(mesh_dim)
        ]
        return out.contiguous() if not out.is_contiguous() else out

    group_name = funcol._resolve_group_name((mesh, mesh_dim))
    # TODO: enable async op for shard_dim_alltoall
    return torch.ops._dtensor.shard_dim_alltoall(
        input, gather_dim, shard_dim, group_name
    )
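
# Example (illustrative sketch, assuming `mesh` is an initialized 1-D DeviceMesh
# of 2 devices): for a Shard(0) -> Shard(1) transition, each rank's local call
# would look roughly like:
#
#     local = torch.randn(4, 8)   # this rank's local shard
#     out = shard_dim_alltoall(local, gather_dim=0, shard_dim=1, mesh=mesh, mesh_dim=0)
#     # out has shape (8, 4): dim 0 is gathered across the mesh dim, dim 1 is re-sharded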


def mesh_scatter(
    output: torch.Tensor,
    scatter_list: List[torch.Tensor],
    mesh: DeviceMesh,
    mesh_dim: int = 0,
    async_op: bool = False,
) -> Optional[Work]:
    """
    Scatter a list of tensors to a device mesh dimension. By default we use
    the first rank of the mesh dimension as the source of truth, i.e. for a
    2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, the tensor list
    on rank 0 is scattered to ranks 0/1 and the tensor list on rank 2 is
    scattered to ranks 2/3.

    Args:
        output (torch.Tensor): the tensor to receive the scattered list.
        scatter_list (List[torch.Tensor]): the tensor list to be scattered.
        mesh (DeviceMesh): the device mesh to scatter on.
        mesh_dim (int, optional): the mesh dimension to scatter on; the first
            rank on this mesh dimension is used as the source of truth.
        async_op (bool, optional): whether to run the scatter asynchronously.

    Returns:
        A :class:`Work` object if ``async_op=True``, otherwise ``None``.
    """
    # TODO: Ideally we should use the meta tensor way
    # (to register a meta kernel for the collective op)
    # so that it would avoid the communication. Need to
    # remove the check below once that is done.
    if output.is_meta:
        return None
    dim_group = mesh.get_group(mesh_dim)
    assert isinstance(dim_group, ProcessGroup)
    # src needs to be a global rank
    src_for_dim = 0

    if dim_group is not GroupMember.WORLD:
        src_for_dim = get_global_rank(dim_group, 0)

    if src_for_dim == get_rank():
        fut = scatter(
            output,
            scatter_list=scatter_list,
            src=src_for_dim,
            group=dim_group,
            async_op=async_op,
        )
    else:
        fut = scatter(
            output,
            scatter_list=None,
            src=src_for_dim,
            group=dim_group,
            async_op=async_op,
        )

    return fut
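
# Example (illustrative sketch): scatter one chunk of a tensor to each rank of a
# hypothetical 1-D mesh of 4 ranks; non-source ranks can build the same list, but
# it is ignored for them since this helper forwards scatter_list=None:
#
#     mesh = DeviceMesh("cuda", list(range(4)))
#     output = torch.empty(2, 8, device="cuda")
#     scatter_list = list(torch.randn(8, 8, device="cuda").chunk(4, dim=0))
#     mesh_scatter(output, scatter_list, mesh, mesh_dim=0)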


def mesh_broadcast(
    tensor: torch.Tensor,
    mesh: DeviceMesh,
    mesh_dim: int = 0,
    async_op: bool = False,
) -> Optional[Work]:
    """
    Broadcast the tensor to a device mesh dimension. By default we use the
    first rank of the mesh dimension as the source of truth, i.e. for a 2d
    mesh [[0, 1], [2, 3]], if we broadcast on mesh_dim = 1, the tensor on
    rank 0 is broadcast to ranks 0/1 and the tensor on rank 2 is broadcast
    to ranks 2/3.

    Args:
        tensor (torch.Tensor): tensor to broadcast.
        mesh (DeviceMesh): the device mesh to broadcast on.
        mesh_dim (int, optional): the mesh dimension to broadcast on; the
            first rank on this mesh dimension is used as the source of truth.
        async_op (bool, optional): whether to run the broadcast asynchronously.

    Returns:
        A :class:`Work` object if ``async_op=True``, otherwise ``None``.
    """
    # TODO: Ideally we should use the meta tensor way
    # (to register a meta kernel for the collective op)
    # so that it would avoid the communication. Need to
    # remove the check below once that is done.
    if tensor.is_meta:
        return None
    dim_group = mesh.get_group(mesh_dim)
    assert isinstance(dim_group, ProcessGroup)
    # src needs to be a global rank
    src_for_dim = 0
    if dim_group is not GroupMember.WORLD:
        src_for_dim = get_global_rank(dim_group, 0)

    return broadcast(tensor, src=src_for_dim, group=dim_group, async_op=async_op)
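
# Example (illustrative sketch): make every rank along a mesh dimension hold the
# data of that dimension's first rank:
#
#     mesh = DeviceMesh("cuda", [[0, 1], [2, 3]])
#     param = torch.randn(16, 16, device="cuda")
#     mesh_broadcast(param, mesh, mesh_dim=1)
#     # ranks 0/1 now hold rank 0's data; ranks 2/3 hold rank 2's data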


def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
    # Pad `tensor` with `pad_size` zeros at the end of dimension `pad_dim`.
    if pad_size == 0:
        return tensor
    pad = [0, 0] * (tensor.ndim - pad_dim)
    pad[-1] = pad_size
    return torch.nn.functional.pad(tensor, pad)


def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
    # Remove the trailing `pad_size` elements from dimension `pad_dim`.
    if pad_size == 0:
        return tensor
    return tensor.narrow(
        pad_dim,
        start=0,
        length=tensor.size(pad_dim) - pad_size,
    )
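
# Example (illustrative sketch): pad/unpad round-trip, as used when a dimension is
# not evenly divisible by the number of shards:
#
#     t = torch.arange(6).reshape(2, 3)
#     padded = pad_tensor(t, pad_dim=1, pad_size=1)           # shape (2, 4), zero appended
#     restored = unpad_tensor(padded, pad_dim=1, pad_size=1)  # shape (2, 3) again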


def fill_empty_tensor_to_shards(
    shards: List[torch.Tensor], shard_dim: int, num_empty_tensors: int
) -> List[torch.Tensor]:
    # Append `num_empty_tensors` tensors that are empty along `shard_dim`
    # (size 0 on that dim) so that `shards` has one entry per rank.
    if num_empty_tensors == 0:
        return shards
    tensor_size = list(shards[0].size())
    tensor_size = [
        size if idx != shard_dim else 0 for idx, size in enumerate(tensor_size)
    ]
    tensor = shards[0].new_zeros(tensor_size)
    for _ in range(num_empty_tensors):
        shards.append(tensor)
    return shards
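
# Example (illustrative sketch): with 2 real shards but 4 ranks on a mesh dim, two
# size-0 shards are appended so every rank has an entry:
#
#     shards = list(torch.randn(2, 4).chunk(2, dim=0))
#     shards = fill_empty_tensor_to_shards(shards, shard_dim=0, num_empty_tensors=2)
#     # len(shards) == 4; shards[2].shape == shards[3].shape == (0, 4)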


def check_tensor_meta(
    local_tensor, check_shape_stride=False
) -> Optional["dtensor_spec.TensorMeta"]:
    # Verify that tensor metadata (dtype, requires_grad, and optionally
    # shape/stride) is consistent across all ranks; raise otherwise.
    local_metadata = {
        "dtype": local_tensor.dtype,
        "requires_grad": local_tensor.requires_grad,
    }

    if check_shape_stride:
        local_metadata.update(
            {"shape": local_tensor.shape, "stride": local_tensor.stride()}
        )

    gathered_metadata = [None for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather_object(gathered_metadata, local_metadata)

    # Check that the metadata is consistent across ranks
    if not all(meta == local_metadata for meta in gathered_metadata):
        raise ValueError(
            "Inconsistent tensor metadata (including shape and stride) across ranks."
        )
    return None
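
# Example (illustrative sketch): this is a collective, so every rank must call it
# with its local tensor; it raises ValueError if, e.g., dtypes differ across ranks:
#
#     local = torch.randn(4, 4, dtype=torch.float32)
#     check_tensor_meta(local, check_shape_stride=True)  # returns None if all ranks agree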


def spec_to_bytes(spec: "dtensor_spec.DTensorSpec") -> int:
    assert spec.tensor_meta is not None, "spec should have tensor meta defined!"
    return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape)


@dataclass
class MeshTopoInfo:
    """
    Mesh information for collective cost estimation
    """

    mesh: DeviceMesh
    mesh_dim_devices: List[int]
    mesh_dim_bandwidth: List[float]
    mesh_dim_latency: List[float]

    @staticmethod
    @lru_cache(None)
    def build_from_mesh(mesh: DeviceMesh) -> "MeshTopoInfo":
        # Generate mesh topology info for the intra-host/inter-host communication pattern.
        # Note that we make a bunch of assumptions for simplicity:
        # 1. we assume the mesh is homogeneous and uses the GPU/NCCL model
        # 2. we assume the GPU arch is Ampere or Hopper
        # 3. we assume all collectives use ring-based algorithms for now
        num_devices_per_host = _mesh_resources.num_devices_per_host(mesh.device_type)
        # the base bandwidth number (intra-node), in GB/s
        base_bw = 87.7
        mesh_dim_bandwidth = [base_bw] * mesh.ndim
        # the latency in us (intra-node, NVLink)
        mesh_dim_latency = [0.6] * mesh.ndim
        mesh_dim_devices = [1] * mesh.ndim

        total_num_devices = 1
        for mesh_dim in reversed(range(mesh.ndim)):
            num_devices = mesh.size(mesh_dim)
            mesh_dim_devices[mesh_dim] = num_devices
            total_num_devices *= num_devices
            if total_num_devices > num_devices_per_host:
                # magic number for the inter-host communication bandwidth/latency factor
                # This number assumes the latest GPU arch, i.e. Ampere or Hopper
                # TODO: see if we need to tweak this or offer a way for the user
                # to specify the bandwidths/latency
                mesh_dim_bandwidth[mesh_dim] *= 0.22
                # set to ethernet latency for inter-host
                mesh_dim_latency[mesh_dim] = 2.7

        return MeshTopoInfo(
            mesh, mesh_dim_devices, mesh_dim_bandwidth, mesh_dim_latency
        )
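
# Example (illustrative sketch): for a hypothetical 2-D DeviceMesh `mesh_2d` of
# shape (2, 8) on hosts with 8 GPUs each, the outer dim crosses hosts while the
# inner dim stays intra-host, so roughly:
#
#     topo = MeshTopoInfo.build_from_mesh(mesh_2d)
#     # topo.mesh_dim_devices   == [2, 8]
#     # topo.mesh_dim_bandwidth == [87.7 * 0.22, 87.7]  (inter-host dim discounted)
#     # topo.mesh_dim_latency   == [2.7, 0.6]           (ethernet vs NVLink latency)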


def allgather_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
    num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
    mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
    num_hops = num_devices_on_mesh_dim - 1
    # base latency + comm latency
    latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]  # us
    bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth  # s
    return latency + bw * 1e6  # rescale to us


def allreduce_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
    num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
    mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
    # allreduce has almost 2x the comm bytes compared to allgather/reduce_scatter
    num_hops = 2 * num_devices_on_mesh_dim - 1

    latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]
    bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth
    return latency + bw * 1e6


def reduce_scatter_cost(
    bytes_gb: float,
    mesh_topo: MeshTopoInfo,
    mesh_dim: int,
) -> float:
    num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
    mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
    num_hops = num_devices_on_mesh_dim - 1
    # base latency + comm latency
    latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]
    bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth
    return latency + bw * 1e6
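
# Example (illustrative sketch): with the intra-host defaults above (87.7 GB/s,
# 0.6 us latency), gathering 1 GB across 8 devices on one mesh dim costs roughly:
#
#     # latency: 6.6 + (8 - 1) * 0.6          = 10.8 us
#     # bandwidth: (1.0 * 7 / 8) / 87.7 * 1e6 ~ 9977 us
#     # allgather_cost(1.0, topo, mesh_dim)   ~ 9988 us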


def redistribute_cost(
    current_spec: "dtensor_spec.DTensorSpec",
    target_spec: "dtensor_spec.DTensorSpec",
) -> float:
    """
    This function returns the cost of redistributing from the current to the target
    DTensorSpec.

    NOTE:
    1. Only communication cost is considered here, since the computation cost of a
       redistribute is quite trivial (i.e. we only need a narrow or a simple division).
    2. Only redistribute cost on the same mesh is considered; cross-mesh communication
       cost is not needed for operator strategy estimation/selection.
    """
    if current_spec.mesh != target_spec.mesh:
        # make the cost infinite if the meshes are not the same
        # TODO: see if we want to support this once there's cross mesh communication
        return float("inf")

    if current_spec.is_replicated():
        # short-cut:
        # comm cost is 0 if the current spec is already fully replicated
        return 0.0

    mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh)
    cost = 0.0
    comm_bytes_gb = (
        spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
    )
    # Transformations considered for redistribute cost:
    # 1. allgather 2. alltoall
    # 3. allreduce 4. reduce_scatter
    for i, (current, target) in enumerate(
        zip(current_spec.placements, target_spec.placements)
    ):
        if current == target:
            continue

        num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i]
        if current.is_shard() and target.is_replicate():
            # allgather gives larger comm bytes
            comm_bytes_gb *= num_devices_on_mesh_dim
            # add up the allgather comm cost
            cost += allgather_cost(comm_bytes_gb, mesh_topo, i)
        elif current.is_shard() and target.is_shard():
            # this should be an alltoall comm; since we haven't implemented it yet,
            # add a penalty to favor allgather instead
            cost += allgather_cost(comm_bytes_gb, mesh_topo, i) + 1.0
        elif current.is_partial() and target.is_replicate():
            # add up the allreduce comm cost
            cost += allreduce_cost(comm_bytes_gb, mesh_topo, i)
        elif current.is_partial() and target.is_shard():
            # add up the reduce_scatter comm cost
            cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i)
            # after reduce_scatter, the comm bytes for further collectives are
            # divided by the number of devices on this mesh dim
            comm_bytes_gb /= num_devices_on_mesh_dim
        elif current.is_shard() and target.is_partial():
            # ban shard -> partial as it does not make sense to perform
            # this redistribute
            return float("inf")

    return cost
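
# Example (illustrative sketch, assuming `shard_spec` and `replicate_spec` are
# DTensorSpecs on the same 1-D mesh with Shard(0) and Replicate() placements):
#
#     cost = redistribute_cost(shard_spec, replicate_spec)
#     # a single allgather on mesh dim 0 is estimated, i.e. roughly
#     # allgather_cost(comm_bytes_gb * num_devices_on_dim_0, mesh_topo, 0)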