xref: /aosp_15_r20/external/pytorch/torch/distributed/device_mesh.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Copyright (c) Meta Platforms, Inc. and affiliates
3import logging
4import math
5import threading
6from functools import reduce
7from itertools import chain
8from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
9
10import torch
11from torch.distributed import is_available
12from torch.utils._typing_utils import not_none
13
14
15__all__ = ["init_device_mesh", "DeviceMesh"]
16
17
18if not is_available():
19    import sys
20
21    # We need to create the stubs when distributed is not available.
22    # Otherwise, we would fail the doc tests (```./.ci/pytorch/docs-test.sh```),
23    # since it would try to import ``torch.distributed.device_mesh`` or
24    # ``torch.distributed.init_device_mesh`` but cannot find them.
25
26    class _DeviceMeshStub:
27        pass
28
29    def _init_device_mesh_stub():
30        pass
31
32    sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub  # type: ignore[attr-defined]
33    sys.modules[
34        "torch.distributed.device_mesh"
35    ].init_device_mesh = _init_device_mesh_stub  # type: ignore[attr-defined]
36
37
38else:
39    from torch.distributed.distributed_c10d import (
40        _find_pg_by_ranks_and_tag,
41        _get_default_group,
42        _get_group_tag,
43        get_backend,
44        get_process_group_ranks,
45        get_rank,
46        get_world_size,
47        init_process_group,
48        is_initialized,
49        new_group,
50        ProcessGroup,
51    )
52
53    logger = logging.getLogger(__name__)
54
55    # only import numpy typing when type checking
56    if TYPE_CHECKING:
57        try:
58            from numpy.typing import ArrayLike
59        except ImportError:
60            logger.warning(
61                "DeviceMesh requires numpy >= 1.21 to be installed for type checking"
62            )
63
64    class _MeshEnv(threading.local):
65        def __init__(self) -> None:
66            self.mesh_stack: List[DeviceMesh] = []
67            self.child_to_root_mapping: Dict[DeviceMesh, DeviceMesh] = {}
68            self.mesh_dim_group_options: Dict[
69                int, Tuple[str, Optional[ProcessGroup.Options]]
70            ] = {}
71            self.root_to_flatten_mapping: Dict[DeviceMesh, Dict[str, DeviceMesh]] = {}
72            # Record flatten mesh name to its mesh dim index in root mesh.
73            self.flatten_name_to_root_dims: Dict[
74                DeviceMesh, Dict[str, Tuple[int, ...]]
75            ] = {}
76
77        def get_current_mesh(self) -> "DeviceMesh":
78            if len(self.mesh_stack) == 0:
79                raise RuntimeError("No device mesh is currently active!")
80            return self.mesh_stack[-1]
81
82        def create_sub_mesh(
83            self,
84            device_mesh: "DeviceMesh",
85            submesh_dim_names: Tuple[str, ...],
86            submesh_dims: List[Tuple[int, ...]],
87        ) -> "DeviceMesh":
88            # Get the submesh dim size from the submesh_dims.
89            # For example, if we have a 3D mesh with mesh_shape (2, 2, 2) mesh_dim_names ("dp", "cp", "tp") and we want
90            # to slice out mesh["dp_cp"], then submesh_dims = [(0, 1), (2,)] and submesh_dim_size = [2 * 2, 2] = [4, 2].
91            # If we want to slice out mesh["dp", "cp"], then submesh_dims = [(0,), (1,)] and submesh_dim_size = [2, 2].
92            slice_dim_size = [
93                reduce(
94                    lambda x, y: device_mesh.mesh.size(x) * device_mesh.mesh.size(y),
95                    mesh_dim,
96                )
97                if len(mesh_dim) > 1
98                else device_mesh.mesh.size(mesh_dim[0])
99                for mesh_dim in submesh_dims
100            ]
101
102            mesh_tensor = device_mesh.mesh
103            # slice_dim_idx could be differnt from submesh_dims, as we may need to flatten out some dims.
104            slice_dim_idx = []
105            slice_dim_group_info = []
106            # keep track of the number of dims that have been flattened so we can get the correct slice_dim_idx in the
107            # flattened mesh tensor.
108            num_dims_flatten = 0
109            for mesh_dim_indices, mesh_dim_name in zip(submesh_dims, submesh_dim_names):
110                # Currently, this only allows slicing out a contiguous flattened dim.
111                # TODO: we need to handle reconstructing a non-contiguous flattened dim.
112                if len(mesh_dim_indices) > 1:
113                    # We need to move the start_dim and end_dim to the left if some dims are already flattened.
114                    mesh_tensor = mesh_tensor.flatten(
115                        start_dim=mesh_dim_indices[0] - num_dims_flatten,
116                        end_dim=mesh_dim_indices[-1] - num_dims_flatten,
117                    )
118                    # If some dims are already flattened, we need to adjust the slice_dim_idx accordingly.
119                    # For example, if the submesh_dims = [(0, 1), (2,), (3, 4)] with 0-1 flattened and 3-4 flattened,
120                    # then the final slice_dim_idx should be [0, 1, 2].
121                    slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten)
122                    num_dims_flatten += len(mesh_dim_indices) - 1
123                    slice_dim_group_info.append(
124                        self.root_to_flatten_mapping[device_mesh][
125                            mesh_dim_name
126                        ]._dim_group_infos[0]
127                    )
128                else:
129                    slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten)
130                    slice_dim_group_info.append(
131                        device_mesh._dim_group_infos[mesh_dim_indices[0]]
132                    )
133
134            # mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now.
135            mesh_dims_remained_idx = list(range(mesh_tensor.ndim))
136            for idx in slice_dim_idx:
137                mesh_dims_remained_idx.remove(idx)
138
139            # pg_ranks_by_dim is the size of [number of local ranks of the outermost submesh dimension, *slice_dim_idx]
140            # This means on each local rank of the outermost slice mesh dim, we have a tensor of submesh size with
141            # the pg ranks of the submesh. From this, we can extract the submesh mesh tensor contains the current rank.
142            pg_ranks_by_dim = mesh_tensor.permute(
143                *mesh_dims_remained_idx, *slice_dim_idx
144            ).reshape(-1, *slice_dim_size)
145
146            cur_rank = device_mesh.get_rank()
147            for mesh_nd in pg_ranks_by_dim:
148                submesh = DeviceMesh(
149                    device_mesh.device_type,
150                    mesh_nd,
151                    mesh_dim_names=submesh_dim_names,
152                    _init_backend=False,
153                )
154                if cur_rank in mesh_nd:
155                    res_submesh = submesh
156
157            res_submesh._dim_group_infos = slice_dim_group_info  # type: ignore[possibly-undefined]
158            self.child_to_root_mapping[res_submesh] = device_mesh
159
160            return res_submesh
161
162        def create_flatten_mesh(
163            self, device_mesh: "DeviceMesh", mesh_dim_name: Optional[str] = None
164        ) -> "DeviceMesh":
165            root_mesh = _mesh_resources.get_root_mesh(device_mesh)
166
167            flatten_dims_in_root = [
168                not_none(root_mesh.mesh_dim_names).index(flattened_mesh_dim_name)
169                for flattened_mesh_dim_name in not_none(device_mesh.mesh_dim_names)
170            ]
171
172            if not mesh_dim_name:
173                mesh_dim_name = "_".join(
174                    [
175                        not_none(root_mesh.mesh_dim_names)[dim]
176                        for dim in flatten_dims_in_root
177                    ]
178                )
179
180            # Check whether the mesh_dim_name for flattened mesh is valid.
181            self.flatten_name_to_root_dims.setdefault(root_mesh, {})
182            invalid_dim_names = chain(
183                *list(not_none(root_mesh.mesh_dim_names)),
184                *self.flatten_name_to_root_dims[root_mesh].keys(),
185            )
186            if mesh_dim_name in invalid_dim_names:
187                raise RuntimeError(
188                    f"{mesh_dim_name} already exists for submesh of the {root_mesh}. ",
189                    f"The mesh_dim_names of submesh and flattened mesh are {invalid_dim_names}. "
190                    f"Please specify another valid mesh_dim_name.",
191                )
192
193            # Quick return if the flatten mesh has been created before.
194            # TODO: If we decide to restrict flatten initialization once, we should remove
195            # this check and throw an error if the flatten mesh is already created before.
196            if (
197                root_mesh in self.root_to_flatten_mapping
198                and mesh_dim_name in self.root_to_flatten_mapping[root_mesh]
199            ):
200                return self.root_to_flatten_mapping[root_mesh][mesh_dim_name]
201
202            flattened_mesh_dim_size = math.prod(device_mesh.mesh.size())
203
204            remained_dims_in_root = list(range(root_mesh.mesh.ndim))
205            for flatten_dim_in_root in flatten_dims_in_root:
206                remained_dims_in_root.remove(flatten_dim_in_root)
207
208            pg_ranks_by_dim = root_mesh.mesh.permute(
209                *remained_dims_in_root, *flatten_dims_in_root
210            ).reshape(-1, flattened_mesh_dim_size)
211
212            cur_rank = root_mesh.get_rank()
213            for mesh_nd in pg_ranks_by_dim:
214                # need to init backend here since the flattened pg doesn't exist in root mesh.
215                flattened_mesh = DeviceMesh(
216                    root_mesh.device_type,
217                    mesh_nd,
218                    mesh_dim_names=(mesh_dim_name,),
219                )
220                if cur_rank in mesh_nd:
221                    res_flattened_mesh = flattened_mesh
222            self.child_to_root_mapping[res_flattened_mesh] = root_mesh  # type: ignore[possibly-undefined]
223            self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = res_flattened_mesh  # type: ignore[possibly-undefined]
224            self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple(flatten_dims_in_root)  # type: ignore[possibly-undefined]
225
226            return res_flattened_mesh
227
228        def get_root_mesh(self, device_mesh: "DeviceMesh") -> "DeviceMesh":
229            # If a mesh could not be found in the child_to_root_mapping, it is a root mesh itself.
230            # A root mesh is not created through slicing.
231            # We considers the root mesh of a root mesh is itself.
232            root_mesh = self.child_to_root_mapping.get(device_mesh, None)
233            return device_mesh if not root_mesh else root_mesh
234
235        def get_root_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]:
236            """
237            Returns the index of the mesh dim in the root mesh.
238            The device_mesh passed in needs to be sliced out from the root mesh
239            or submesh of the root mesh.
240            """
241            root_mesh = self.get_root_mesh(device_mesh)
242            child_mesh_dim_names = device_mesh.mesh_dim_names
243            if root_mesh and child_mesh_dim_names:
244                assert (
245                    len(child_mesh_dim_names) == 1
246                ), "The submesh can only be a 1D mesh."
247                child_mesh_dim_name = child_mesh_dim_names[0]
248                return self.get_mesh_dim_by_name(root_mesh, child_mesh_dim_name)
249            return None
250
251        @staticmethod
252        def num_devices_per_host(device_type: str) -> int:
253            return _get_device_handle(device_type).device_count()
254
255        @staticmethod
256        def num_hosts(device_type: str) -> int:
257            # ProcessGroup can't tell us this info so we have to infer it, assume
258            # homogeneous hardware for now
259            return get_world_size() // _MeshEnv.num_devices_per_host(device_type)
260
261        def get_mesh_dim_by_name(
262            self, device_mesh: "DeviceMesh", mesh_dim_name: str
263        ) -> int:
264            if (
265                device_mesh.mesh_dim_names is None
266                or len(device_mesh.mesh_dim_names) == 0
267            ):
268                raise KeyError(
269                    "No `mesh_dim_names` found.",
270                )
271            if mesh_dim_name not in device_mesh.mesh_dim_names:
272                raise KeyError(
273                    f"Mesh dimension '{mesh_dim_name}' does not exist.",
274                    f"Available mesh dimensions are: mesh_dim_names={device_mesh.mesh_dim_names}",
275                )
276            return not_none(device_mesh.mesh_dim_names.index(mesh_dim_name))
277
278        def _set_mesh_dim_group_options(
279            self,
280            dim: int,
281            backend: str,
282            pg_options: Optional[ProcessGroup.Options] = None,
283        ) -> None:
284            self.mesh_dim_group_options[dim] = (backend, pg_options)
285
286        def _get_slice_mesh_dims(
287            self, device_mesh, mesh_dim_names
288        ) -> List[Tuple[int, ...]]:
289            """
290            Validate whether the mesh_dim_names is valid for slicing the given device_mesh.
291            If valid, return dim indexes of the slice mesh in the device mesh.
292            """
293            if device_mesh != self.get_root_mesh(device_mesh):
294                raise RuntimeError("Cannot create a submesh from a submesh.")
295
296            # The slice mesh_dim_names should consist either the device_mesh's mesh_dim_names
297            # or its flattened mesh's mesh_dim_names.
298            self.flatten_name_to_root_dims.setdefault(device_mesh, {})
299            flatten_name_to_root_dims = self.flatten_name_to_root_dims[device_mesh]
300            valid_mesh_dim_names = [
301                *device_mesh.mesh_dim_names,
302                *flatten_name_to_root_dims,
303            ]
304
305            if not all(
306                mesh_dim_name in valid_mesh_dim_names
307                for mesh_dim_name in mesh_dim_names
308            ):
309                raise KeyError(
310                    f"Invalid mesh_dim_names {mesh_dim_names} specified. "
311                    f"Valid mesh_dim_names are {valid_mesh_dim_names}."
312                )
313
314            # Validate the order of the slice mesh dim indices.
315            # This needs to be in ascending order.
316            curr_idx = -1
317            slice_mesh_dims = []
318            for mesh_dim_name in mesh_dim_names:
319                if mesh_dim_name in flatten_name_to_root_dims:
320                    mesh_indices = flatten_name_to_root_dims[mesh_dim_name]
321                    # TODO: this doesn't allow non-contiguous slicing with flatten dim yet. next_idx
322                    # should be mesh_indices[0] once we support non-contiguous slicing with flatten dim.
323                    next_idx = mesh_indices[-1]
324                    slice_mesh_dims.append(mesh_indices)
325                else:
326                    next_idx = device_mesh.mesh_dim_names.index(mesh_dim_name)
327                    slice_mesh_dims.append((next_idx,))
328                if next_idx <= curr_idx:
329                    raise KeyError(
330                        f"Invalid mesh_dim_names {mesh_dim_names} specified. ",
331                        f"Found mesh dim indices to slice: {slice_mesh_dims}. ",
332                        "Mesh dim indices should be in ascending order.",
333                    )
334                curr_idx = next_idx
335
336            return slice_mesh_dims
337
338        def _get_all_submeshes(
339            self, device_mesh: "DeviceMesh", mesh_dim_name: str
340        ) -> List["DeviceMesh"]:
341            """
342            Return all the submeshes of a given mesh dimension of the device mesh.
343            """
344            mesh_dim = self.get_mesh_dim_by_name(device_mesh, mesh_dim_name)
345            pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(
346                -1, device_mesh.mesh.size(mesh_dim)
347            )
348
349            cur_rank = device_mesh.get_rank()
350            res_submeshes = []
351            for mesh_1d in pg_ranks_by_dim:
352                submesh = DeviceMesh(
353                    device_mesh.device_type,
354                    mesh_1d,
355                    mesh_dim_names=(mesh_dim_name,),
356                    _init_backend=False,
357                )
358                submesh._dim_group_infos = (
359                    [device_mesh._dim_group_infos[mesh_dim]]
360                    if cur_rank in mesh_1d
361                    else []
362                )
363                res_submeshes.append(submesh)
364
365            return res_submeshes
366
367    _mesh_resources: _MeshEnv = _MeshEnv()
368
369    def _get_device_handle(device_type: str = "cuda"):
370        """
371        Get the module corresponding to the device_type which is cuda or cuda-like device.
372        For example, when the device_type is cuda, the module `torch.cuda` is returned.
373        Return None when there is no corresponding module for device_type, otherwise
374        return the corresponding module.
375        """
376        return getattr(torch, device_type, None)
377
378    class DeviceMesh:
379        """
380        DeviceMesh represents a mesh of devices, where layout of devices could be
381        represented as a n-d dimension array, and each value of the n-d dimensional
382        array is the global id of the default process group ranks.
383
384        DeviceMesh could be used to describe the layout of devices across the cluster,
385        and serves as a proxy for communication among the device lists within the cluster.
386
387        DeviceMesh can be used as a context manager.
388
389        .. note::
390            DeviceMesh follows SPMD programming model, which means the same PyTorch Python program
391            is running on all processes/ranks in the cluster. Therefore, users need to make sure the
392            `mesh` array (which describes the layout of devices) should be identical across all ranks.
393            Inconsistent `mesh` will lead to silent hang.
394
395        Args:
396            device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like".
397            mesh (ndarray): A multi-dimensional array or an integer tensor describing the layout
398                of devices, where the IDs are global IDs of the default process group.
399
400        Returns:
401            DeviceMesh: A :class:`DeviceMesh` object representing the device layout.
402
403        The following program runs on each process/rank in an SPMD manner. In this example, we have 2
404        hosts with 4 GPUs each.
405        A reduction over the first dimension of mesh will reduce across
406        columns (0, 4), .. and (3, 7), a reduction over the second dimension
407        of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).
408
409        Example::
410            >>> # xdoctest: +SKIP("no rank")
411            >>> from torch.distributed.device_mesh import DeviceMesh
412            >>>
413            >>> # Initialize device mesh as (2, 4) to represent the topology
414            >>> # of cross-host(dim 0), and within-host (dim 1).
415            >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
416        """
417
418        device_type: str
419        mesh: torch.Tensor
420        mesh_dim_names: Optional[Tuple[str, ...]]
421
422        def __init__(
423            self,
424            device_type: str,
425            mesh: Union[torch.Tensor, "ArrayLike"],
426            *,
427            mesh_dim_names: Optional[Tuple[str, ...]] = None,
428            _init_backend: bool = True,
429        ) -> None:
430            self.device_type = device_type
431            if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
432                raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
433            self.mesh = (
434                mesh.detach().to(dtype=torch.int)
435                if isinstance(mesh, torch.Tensor)
436                else torch.tensor(mesh, device="cpu", dtype=torch.int)
437            )
438            self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
439
440            # private field to pre-generate DeviceMesh's hash
441            self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
442            self._thread_id = None
443
444            # Skip process group initialization if xla device or init backend is False
445            # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
446            if device_type != "xla":
447                # always try to create default (world) pg, even if it is not initialized
448                # already. The world pg is used for device mesh identity (rank) on each
449                # process (we need to know if the current global rank is in the mesh or not).
450                if _init_backend:
451                    self._get_or_create_default_group()
452                    self._init_process_groups()
453
454                if is_initialized() and get_backend() == "threaded":
455                    self._thread_id = threading.get_ident()
456
457                # calculate the coordinates of the current global rank on the mesh
458                rank_coords = (self.mesh == get_rank()).nonzero()
459                assert rank_coords.size(0) in (0, 1)
460                self._coordinate_on_dim: Optional[List[int]] = (
461                    rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
462                )
463
464        def _get_or_create_default_group(self):
465            default_initialized = is_initialized()
466            if not default_initialized:
467                init_process_group()
468
469            world_size = get_world_size()
470            if self.mesh.numel() > world_size:
471                raise RuntimeError(
472                    f"Mesh should not be bigger than default world size {world_size}, but found {self.mesh.numel()} ranks!"
473                )
474
475            device_handle = _get_device_handle(self.device_type)
476            # TODO: if user want to pass pg_options, offer a way to do it
477            if not default_initialized and device_handle:
478                # automatically set the current cuda/cuda-like device base on num of gpu devices available in each host
479                # NOTE: This device selection would only work for homogeneous hardware.
480                num_devices_per_host = device_handle.device_count()
481                if (
482                    world_size > num_devices_per_host
483                    and world_size % num_devices_per_host != 0
484                ):
485                    raise RuntimeError(
486                        f"DeviceMesh only support homogeneous hardware, but found "
487                        f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
488                    )
489                device_handle.set_device(get_rank() % num_devices_per_host)
490
491            return _get_default_group()
492
        def _init_process_groups(self):
            """
            Create the process group of every mesh dim and store their descriptions.

            The result is stored in ``self._dim_group_infos`` as one
            ``(group tag, ranks, group name)`` tuple per mesh dim, describing the group
            that contains the current rank.

            Raises:
                RuntimeError: if the current rank shows up in more than one subgroup of
                    the same mesh dim.
            """
            # tag/ranks/group_name associated with each mesh dimension, each
            # mesh dimension should have one sub-group per rank
            #
            # TODO(yifu): remove tag and ranks once we fully migrate to native
            # functional collectives. See details in:
            # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
            dim_group_infos: List[Tuple[str, List[int], str]] = []

            if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size():
                # A 1D mesh that spans the whole world reuses the default pg, unless CUDA is
                # available and the default pg uses the "gloo" backend; in that case a new
                # pg with the "cpu:gloo,cuda:nccl" backend is created instead.
                # NOTE(review): the ranks recorded here are always 0..world_size-1 in
                # order, independent of the order of the ranks in self.mesh -- confirm
                # this is intended.
                default_group = _get_default_group()
                ranks = list(range(get_world_size()))
                dim_group = (
                    new_group(backend="cpu:gloo,cuda:nccl", ranks=ranks)
                    if torch.cuda.is_available()
                    and get_backend(default_group) == "gloo"
                    else default_group
                )
                dim_group_infos.append(
                    (
                        _get_group_tag(dim_group),
                        ranks,
                        dim_group.group_name,
                    )
                )
            else:
                # create sub pgs base on the mesh argument specified
                for dim in range(self.mesh.ndim):
                    # swap the current dim to the last dim
                    # then reshape to flatten out other dims
                    pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
                        -1, self.mesh.size(dim)
                    )
                    # multi-dim mesh, create subgroups by looping over the pg_ranks
                    # for each dim and append the groups
                    for dim_mesh in pg_ranks_by_dim:
                        subgroup_ranks = dim_mesh.tolist()

                        # Respect dim group options recorded in
                        # _mesh_resources.mesh_dim_group_options (filled by
                        # _MeshEnv._set_mesh_dim_group_options()).
                        # backend/pg_options are passed as None if no options are specified for the dim.
                        if dim in _mesh_resources.mesh_dim_group_options:
                            (
                                backend,
                                pg_options,
                            ) = _mesh_resources.mesh_dim_group_options[dim]
                        else:
                            backend, pg_options = None, None

                        # We temporarily revert the re-use subgroup, since it breaks two internal tests.
                        # Temporarily reverting to resolve test timeout while root-causing.
                        # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists.
                        # new_group is called for every row, including rows that do not
                        # contain the current rank.
                        dim_group = new_group(
                            ranks=subgroup_ranks,
                            backend=backend,
                            pg_options=pg_options,
                        )

                        # only add to dim_groups if the current rank in the subgroup
                        if self.get_rank() in subgroup_ranks:
                            # Exactly one info per dim is expected, so finding a second
                            # matching subgroup for this dim is an error.
                            if len(dim_group_infos) > dim:
                                raise RuntimeError(
                                    f"Each device mesh dimension should get only one process group, but got {self.get_rank()} "
                                    f"in {subgroup_ranks}!"
                                )
                            dim_group_infos.append(
                                (
                                    _get_group_tag(not_none(dim_group)),
                                    subgroup_ranks,
                                    dim_group.group_name,
                                )
                            )
            self._dim_group_infos = dim_group_infos
567
568        def __enter__(self) -> "DeviceMesh":
569            # set this mesh as the current mesh in mesh env
570            _mesh_resources.mesh_stack.append(self)
571            return self
572
573        # pyre-fixme[2]: Parameter must be annotated.
574        def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
575            # pop this mesh from mesh env
576            _mesh_resources.mesh_stack.pop()
577
578        def __repr__(self) -> str:
579            device_mesh_repr = (
580                f"DeviceMesh('{self.device_type}', {self.mesh.tolist()})"
581                if not self.mesh_dim_names
582                else f"DeviceMesh('{self.device_type}', {self.mesh.tolist()}, mesh_dim_names={self.mesh_dim_names})"
583            )
584            return device_mesh_repr
585
586        def __hash__(self):
587            # lazily compute hash
588            self._hash = getattr(self, "_hash", None)
589            if not self._hash:
590                self._hash = hash(
591                    (
592                        self._flatten_mesh_list,
593                        self.mesh.shape,
594                        self.device_type,
595                        self.mesh_dim_names,
596                        self._thread_id,
597                    )
598                )
599            return self._hash
600
601        def __eq__(self, other: object) -> bool:
602            if not isinstance(other, DeviceMesh):
603                return False
604            if id(self) == id(other):
605                return True
606            else:
607                return (
608                    self._flatten_mesh_list == other._flatten_mesh_list
609                    and self.mesh.shape == other.mesh.shape
610                    and self.device_type == other.device_type
611                    and self.mesh_dim_names == other.mesh_dim_names
612                    and self._thread_id == other._thread_id
613                )
614
615        def __getitem__(
616            self, mesh_dim_names: Union[str, Tuple[str, ...]]
617        ) -> "DeviceMesh":
618            """
619            Slice the current DeviceMesh based on the mesh_dim_names given to create a submesh.
620            The submesh created consists of the dimensions and the communicators indicated by
621            ``mesh_dim_names``
622
623            Args:
624                mesh_dim_names (Union[str, Tuple[str]]): the name or the tuple of names of the
625                mesh dimension of the DeviceMesh to create the submesh for.
626            Returns:
627                A :class:`DeviceMesh` object
628
629            The following program runs on each process/rank in an SPMD manner in a world size of 8.
630            In the first example:
631                Calling mesh_2d["tp"] on rank 0, 1, 2, 3 returns a 1D submesh of DeviceMesh:([0, 1, 2, 3]).
632                Calling mesh_2d["tp"] on rank 4, 5, 6, 7 returns a 1D submesh of  DeviceMesh:([4, 5, 6, 7]).
633                Calling mesh_2d["dp"] on rank 0, 4 returns a 1D submesh of  DeviceMesh:([0, 4]).
634                Calling mesh_2d["dp"] on rank 1, 5 returns a 1D submesh of  DeviceMesh:([1, 5]).
635                Calling mesh_2d["dp"] on rank 2, 6 returns a 1D submesh of  DeviceMesh:([2, 6]).
636                Calling mesh_2d["dp"] on rank 3, 7 returns a 1D submesh of  DeviceMesh:([3, 7]).
637
638            In the second example:
639                Calling mesh_3d["dp", "cp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 1], [4, 5]]).
640                Calling mesh_3d["dp", "cp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 3], [6, 7]]).
641                Calling mesh_3d["cp", "dp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 4], [1, 5]]).
642                Calling mesh_3d["cp", "dp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 6], [3, 7]]).
643
644            Example::
645                >>> # xdoctest: +SKIP("no rank")
646                >>> from torch.distributed.device_mesh import DeviceMesh
647                >>>
648                >>> # Initialize a 2D device mesh as (2, 4) to represent the topology
649                >>> # of cross-host(dim 0), and within-host (dim 1).
650                >>> mesh_2d = init_device_mesh(device_type="cuda", (2,4), mesh_dim_names=("dp", "tp"))
651                >>> tp_mesh = mesh_2d["tp"]
652                >>> dp_mesh = mesh_2d["dp"]
653                >>>
654                >>> # Initialize a 3D mesh.
655                >>> mesh_3d = init_device_mesh(device_type="cuda", (2,2,2), mesh_dim_names=("dp", "pp", "cp"))
656                >>> # The order of the mesh_dim_names provided deteremines the order of dimensions in the submesh.
657                >>> dp_cp_mesh = mesh_3d["dp", "cp"]
658                >>> cp_dp_mesh = mesh_3d["cp", "dp"]
659            """
660            if not self.mesh_dim_names:
661                raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!")
662
663            mesh_dim_names = (
664                (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names
665            )
666
667            if mesh_dim_names == self.mesh_dim_names:
668                return self
669            else:
670                slice_mesh_dims = _mesh_resources._get_slice_mesh_dims(
671                    self, mesh_dim_names
672                )
673                submesh = _mesh_resources.create_sub_mesh(
674                    self, mesh_dim_names, slice_mesh_dims
675                )
676                return submesh
677
678        def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup:
679            """
680            Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the
681            DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh.
682
683            Args:
684                mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index
685                of the mesh dimension. Default is None.
686
687            Returns:
688                A :class:`ProcessGroup` object.
689            """
690            if not hasattr(self, "_dim_group_infos"):
691                raise RuntimeError("DeviceMesh process groups not initialized!")
692
693            if self.mesh.ndim > 1 and mesh_dim is None:
694                raise RuntimeError(
695                    f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
696                    "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
697                    "If you want to get the list of all the ProcessGroups in the DeviceMesh,"
698                    "please use `get_all_groups()` instead.",
699                )
700
701            # Quick return if the current device_mesh is a 1D mesh.
702            if self.mesh.ndim == 1 and mesh_dim is None:
703                return not_none(
704                    _find_pg_by_ranks_and_tag(*self._dim_group_infos[0][:2])  # type: ignore[index]
705                )
706
707            root_mesh = _mesh_resources.get_root_mesh(self)
708            root_to_flatten_mapping = _mesh_resources.root_to_flatten_mapping.get(
709                root_mesh, None
710            )
711            if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping.keys():
712                dim_group_infos = root_to_flatten_mapping[mesh_dim]._dim_group_infos[0][:2]  # type: ignore[index]
713                return not_none(_find_pg_by_ranks_and_tag(*dim_group_infos))
714            else:
715                mesh_dim = (
716                    _mesh_resources.get_mesh_dim_by_name(self, mesh_dim)
717                    if isinstance(mesh_dim, str)
718                    else mesh_dim
719                )
720                return not_none(
721                    _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim][:2])  # type: ignore[index]
722                )
723
724        def get_all_groups(self) -> List[ProcessGroup]:
725            """
726            Returns a list of ProcessGroups for all mesh dimensions.
727
728            Returns:
729                A list of :class:`ProcessGroup` object.
730            """
731            return [self.get_group(i) for i in range(self.mesh.ndim)]
732
733        @staticmethod
734        def from_group(
735            group: Union[ProcessGroup, List[ProcessGroup]],
736            device_type: str,
737            mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None,
738            *,
739            mesh_dim_names: Optional[Tuple[str, ...]] = None,
740        ) -> "DeviceMesh":
741            """
742            Constructs a :class:`DeviceMesh` with ``device_type`` from an
743            existing :class:`ProcessGroup`.
744
745            The constructed device mesh has number of dimensions equal to the
746            number of groups passed. If more than one group is passed, then the
747            ``mesh`` argument is required.
748            """
749            if isinstance(group, ProcessGroup):
750                group_ranks = get_process_group_ranks(group)
751                if (
752                    isinstance(mesh, torch.Tensor) and mesh.tolist() != group_ranks
753                ) or (mesh is not None and mesh != group_ranks):
754                    raise ValueError(
755                        f"Invalid mesh {str(mesh)} for ProcessGroup with ranks {group_ranks}"
756                    )
757                mesh = torch.tensor(group_ranks, device="cpu", dtype=torch.int)
758                device_mesh = DeviceMesh(
759                    device_type,
760                    mesh,
761                    mesh_dim_names=mesh_dim_names,
762                    _init_backend=False,
763                )
764                device_mesh._dim_group_infos = [
765                    (_get_group_tag(group), group_ranks, group.group_name)
766                ]
767                return device_mesh
768            groups = list(group)
769            if len(groups) == 0:
770                raise ValueError("Expects at least one ProcessGroup to be passed")
771            if mesh is None:
772                raise ValueError("Must pass mesh if passing multiple ProcessGroups")
773            mesh = (
774                mesh.detach().to(dtype=torch.int, device="cpu")
775                if isinstance(mesh, torch.Tensor)
776                else torch.tensor(mesh, device="cpu", dtype=torch.int)
777            )
778            if mesh.ndim != len(groups):
779                raise ValueError(
780                    "Expects mesh with ndim equal to number of ProcessGroups but got "
781                    f"mesh {mesh.tolist()} and {len(groups)} ProcessGroups"
782                )
783            device_mesh = DeviceMesh(
784                device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False
785            )
786            device_mesh._dim_group_infos = [
787                (
788                    _get_group_tag(group),
789                    get_process_group_ranks(group),
790                    group.group_name,
791                )
792                for group in groups
793            ]
794            return device_mesh
795
796        def size(self, mesh_dim: Optional[int] = None) -> int:
797            return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim)
798
799        @property
800        def ndim(self) -> int:
801            return self.mesh.ndim
802
803        @property
804        def shape(self) -> Tuple[int, ...]:
805            return tuple(self.mesh.shape)
806
807        def get_rank(self) -> int:
808            """
809            Returns the current global rank.
810            """
811            return get_rank()
812
813        def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
814            """
815            Returns the local rank of the given mesh_dim of the DeviceMesh.
816
817            Args:
818                mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index
819                of the mesh dimension. Default is None.
820
821            Returns:
822                An integer denotes the local rank.
823
824            The following program runs on each process/rank in an SPMD manner. In this example, we have 2
825            hosts with 4 GPUs each.
826            Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0.
827            Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1.
828            Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0.
829            Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1.
830            Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2.
831            Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3.
832
833            Example::
834                >>> # xdoctest: +SKIP("no rank")
835                >>> from torch.distributed.device_mesh import DeviceMesh
836                >>>
837                >>> # Initialize device mesh as (2, 4) to represent the topology
838                >>> # of cross-host(dim 0), and within-host (dim 1).
839                >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
840            """
841            if self.ndim > 1 and mesh_dim is None:
842                raise RuntimeError(
843                    f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
844                    "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
845                )
846            elif mesh_dim is None:
847                mesh_dim = 0
848
849            mesh_dim_group = not_none(self.get_group(mesh_dim))
850            assert isinstance(
851                mesh_dim_group, ProcessGroup
852            ), "We expect ProcessGroup before calling `get_rank`!"
853            return not_none(get_rank(mesh_dim_group))
854
855        def get_coordinate(self) -> Optional[List[int]]:
856            """
857            Return the relative indices of this rank relative to all
858            dimensions of the mesh. If this rank is not part of the mesh, return None.
859            """
860            return self._coordinate_on_dim if self._coordinate_on_dim else None
861
862        def _flatten(self, mesh_dim_name: Optional[str] = None) -> "DeviceMesh":
863            """
864            Returns a 1D DeviceMesh by flattening the current DeviceMesh.
865
866            If no mesh_dim_name is provided, the default is a string concatentaing the mesh_dim_names of the
867            given submesh with each mesh_dim_name separated by "_". For example, if we have a 3D mesh
868            DeviceMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], mesh_dim_names=("dp", "cp", "tp")), calling
869            mesh_3d["dp", "cp"]._flatten() will create a 1D submesh DeviceMesh([0, 1, 2, 3], mesh_dim_names=("dp_cp",))
870            on rank 0, 1, 2, 3 and a 1D submesh DeviceMesh([4, 5, 6, 7], mesh_dim_names=("dp_cp",)) on rank 4, 5, 6, 7.
871
872            After the flattened dimension is created, to access the flattened dimesnion in mesh_3d, one can use the
873            existing slicing method to obtain the flattened mesh through calling mesh_3d["dp_cp"].
874            """
875            if not self.mesh_dim_names:
876                raise RuntimeError(
877                    "Cannot flatten a DeviceMesh without mesh_dim_names!"
878                )
879
880            return _mesh_resources.create_flatten_mesh(self, mesh_dim_name)
881
882    def init_device_mesh(
883        device_type: str,
884        mesh_shape: Tuple[int, ...],
885        *,
886        mesh_dim_names: Optional[Tuple[str, ...]] = None,
887    ) -> DeviceMesh:
888        """
889        Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters.
890
891        This creates a DeviceMesh with an n-dimensional array layout, where `n` is the length of `mesh_shape`.
892        If `mesh_dim_names` is provided, each dimension is labeled as `mesh_dim_names[i]`.
893
894        .. note::
895            `init_device_mesh` follows SPMD programming model, meaning the same PyTorch Python program
896            runs on all processes/ranks in the cluster. Ensure `mesh_shape` (the dimensions of the nD array
897            describing device layout) is identical across all ranks. Inconsistent `mesh_shape` may lead to hanging.
898
899        .. note::
900            If no process group is found, init_device_mesh will initialize distributed process group/groups
901            required for distributed communications behind the scene.
902
903        Args:
904            device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like".
905                Passing in a device type with a GPU index, such as "cuda:0", is not allowed.
906            mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array
907                describing the layout of devices.
908            mesh_dim_names (Tuple[str], optional): A tuple of mesh dimension names to assign to each dimension
909                of the multi-dimensional array describing the layout of devices. Its length must match the length
910                of `mesh_shape`. Each string in `mesh_dim_names` must be unique.
911
912        Returns:
913            DeviceMesh: A :class:`DeviceMesh` object representing the device layout.
914
915        Example::
916            >>> # xdoctest: +SKIP("no rank")
917            >>> from torch.distributed.device_mesh import init_device_mesh
918            >>>
919            >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
920            >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
921
922        """
923        if mesh_dim_names is not None:
924            if len(set(mesh_dim_names)) != len(mesh_dim_names):
925                raise RuntimeError(
926                    "Each mesh_dim_name must be unique.",
927                    f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}",
928                )
929
930            if len(mesh_shape) != len(mesh_dim_names):
931                raise RuntimeError(
932                    "mesh_shape and mesh_dim_names should have same length!",
933                    f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.",
934                )
935
936        # assume valid device types are all letters
937        if device_type and not device_type.isalpha():
938            raise RuntimeError(
939                f"Device type with GPU index is not supported but got {device_type}. ",
940                "If you maintained a 'torch.device' object, it's recommended to pass in 'device.type'.",
941            )
942
943        # Always initialize the mesh's tensor on CPU, regardless of what the
944        # external device type has been set to be (e.g. meta)
945        with torch.device("cpu"):
946            mesh = torch.arange(math.prod(mesh_shape), dtype=torch.int).view(mesh_shape)
947        device_mesh = DeviceMesh(
948            device_type=device_type,
949            mesh=mesh,
950            mesh_dim_names=mesh_dim_names,
951        )
952
953        return device_mesh
954