1# mypy: allow-untyped-defs 2# Copyright (c) Meta Platforms, Inc. and affiliates 3import logging 4import math 5import threading 6from functools import reduce 7from itertools import chain 8from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union 9 10import torch 11from torch.distributed import is_available 12from torch.utils._typing_utils import not_none 13 14 15__all__ = ["init_device_mesh", "DeviceMesh"] 16 17 18if not is_available(): 19 import sys 20 21 # We need to create the stubs when distributed is not available. 22 # Otherwise, we would fail the doc tests (```./.ci/pytorch/docs-test.sh```), 23 # since it would try to import ``torch.distributed.device_mesh`` or 24 # ``torch.distributed.init_device_mesh`` but cannot find them. 25 26 class _DeviceMeshStub: 27 pass 28 29 def _init_device_mesh_stub(): 30 pass 31 32 sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub # type: ignore[attr-defined] 33 sys.modules[ 34 "torch.distributed.device_mesh" 35 ].init_device_mesh = _init_device_mesh_stub # type: ignore[attr-defined] 36 37 38else: 39 from torch.distributed.distributed_c10d import ( 40 _find_pg_by_ranks_and_tag, 41 _get_default_group, 42 _get_group_tag, 43 get_backend, 44 get_process_group_ranks, 45 get_rank, 46 get_world_size, 47 init_process_group, 48 is_initialized, 49 new_group, 50 ProcessGroup, 51 ) 52 53 logger = logging.getLogger(__name__) 54 55 # only import numpy typing when type checking 56 if TYPE_CHECKING: 57 try: 58 from numpy.typing import ArrayLike 59 except ImportError: 60 logger.warning( 61 "DeviceMesh requires numpy >= 1.21 to be installed for type checking" 62 ) 63 64 class _MeshEnv(threading.local): 65 def __init__(self) -> None: 66 self.mesh_stack: List[DeviceMesh] = [] 67 self.child_to_root_mapping: Dict[DeviceMesh, DeviceMesh] = {} 68 self.mesh_dim_group_options: Dict[ 69 int, Tuple[str, Optional[ProcessGroup.Options]] 70 ] = {} 71 self.root_to_flatten_mapping: Dict[DeviceMesh, Dict[str, DeviceMesh]] = {} 72 # Record flatten mesh name to its mesh dim index in root mesh. 73 self.flatten_name_to_root_dims: Dict[ 74 DeviceMesh, Dict[str, Tuple[int, ...]] 75 ] = {} 76 77 def get_current_mesh(self) -> "DeviceMesh": 78 if len(self.mesh_stack) == 0: 79 raise RuntimeError("No device mesh is currently active!") 80 return self.mesh_stack[-1] 81 82 def create_sub_mesh( 83 self, 84 device_mesh: "DeviceMesh", 85 submesh_dim_names: Tuple[str, ...], 86 submesh_dims: List[Tuple[int, ...]], 87 ) -> "DeviceMesh": 88 # Get the submesh dim size from the submesh_dims. 89 # For example, if we have a 3D mesh with mesh_shape (2, 2, 2) mesh_dim_names ("dp", "cp", "tp") and we want 90 # to slice out mesh["dp_cp"], then submesh_dims = [(0, 1), (2,)] and submesh_dim_size = [2 * 2, 2] = [4, 2]. 91 # If we want to slice out mesh["dp", "cp"], then submesh_dims = [(0,), (1,)] and submesh_dim_size = [2, 2]. 92 slice_dim_size = [ 93 reduce( 94 lambda x, y: device_mesh.mesh.size(x) * device_mesh.mesh.size(y), 95 mesh_dim, 96 ) 97 if len(mesh_dim) > 1 98 else device_mesh.mesh.size(mesh_dim[0]) 99 for mesh_dim in submesh_dims 100 ] 101 102 mesh_tensor = device_mesh.mesh 103 # slice_dim_idx could be differnt from submesh_dims, as we may need to flatten out some dims. 104 slice_dim_idx = [] 105 slice_dim_group_info = [] 106 # keep track of the number of dims that have been flattened so we can get the correct slice_dim_idx in the 107 # flattened mesh tensor. 108 num_dims_flatten = 0 109 for mesh_dim_indices, mesh_dim_name in zip(submesh_dims, submesh_dim_names): 110 # Currently, this only allows slicing out a contiguous flattened dim. 111 # TODO: we need to handle reconstructing a non-contiguous flattened dim. 112 if len(mesh_dim_indices) > 1: 113 # We need to move the start_dim and end_dim to the left if some dims are already flattened. 114 mesh_tensor = mesh_tensor.flatten( 115 start_dim=mesh_dim_indices[0] - num_dims_flatten, 116 end_dim=mesh_dim_indices[-1] - num_dims_flatten, 117 ) 118 # If some dims are already flattened, we need to adjust the slice_dim_idx accordingly. 119 # For example, if the submesh_dims = [(0, 1), (2,), (3, 4)] with 0-1 flattened and 3-4 flattened, 120 # then the final slice_dim_idx should be [0, 1, 2]. 121 slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten) 122 num_dims_flatten += len(mesh_dim_indices) - 1 123 slice_dim_group_info.append( 124 self.root_to_flatten_mapping[device_mesh][ 125 mesh_dim_name 126 ]._dim_group_infos[0] 127 ) 128 else: 129 slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten) 130 slice_dim_group_info.append( 131 device_mesh._dim_group_infos[mesh_dim_indices[0]] 132 ) 133 134 # mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now. 135 mesh_dims_remained_idx = list(range(mesh_tensor.ndim)) 136 for idx in slice_dim_idx: 137 mesh_dims_remained_idx.remove(idx) 138 139 # pg_ranks_by_dim is the size of [number of local ranks of the outermost submesh dimension, *slice_dim_idx] 140 # This means on each local rank of the outermost slice mesh dim, we have a tensor of submesh size with 141 # the pg ranks of the submesh. From this, we can extract the submesh mesh tensor contains the current rank. 142 pg_ranks_by_dim = mesh_tensor.permute( 143 *mesh_dims_remained_idx, *slice_dim_idx 144 ).reshape(-1, *slice_dim_size) 145 146 cur_rank = device_mesh.get_rank() 147 for mesh_nd in pg_ranks_by_dim: 148 submesh = DeviceMesh( 149 device_mesh.device_type, 150 mesh_nd, 151 mesh_dim_names=submesh_dim_names, 152 _init_backend=False, 153 ) 154 if cur_rank in mesh_nd: 155 res_submesh = submesh 156 157 res_submesh._dim_group_infos = slice_dim_group_info # type: ignore[possibly-undefined] 158 self.child_to_root_mapping[res_submesh] = device_mesh 159 160 return res_submesh 161 162 def create_flatten_mesh( 163 self, device_mesh: "DeviceMesh", mesh_dim_name: Optional[str] = None 164 ) -> "DeviceMesh": 165 root_mesh = _mesh_resources.get_root_mesh(device_mesh) 166 167 flatten_dims_in_root = [ 168 not_none(root_mesh.mesh_dim_names).index(flattened_mesh_dim_name) 169 for flattened_mesh_dim_name in not_none(device_mesh.mesh_dim_names) 170 ] 171 172 if not mesh_dim_name: 173 mesh_dim_name = "_".join( 174 [ 175 not_none(root_mesh.mesh_dim_names)[dim] 176 for dim in flatten_dims_in_root 177 ] 178 ) 179 180 # Check whether the mesh_dim_name for flattened mesh is valid. 181 self.flatten_name_to_root_dims.setdefault(root_mesh, {}) 182 invalid_dim_names = chain( 183 *list(not_none(root_mesh.mesh_dim_names)), 184 *self.flatten_name_to_root_dims[root_mesh].keys(), 185 ) 186 if mesh_dim_name in invalid_dim_names: 187 raise RuntimeError( 188 f"{mesh_dim_name} already exists for submesh of the {root_mesh}. ", 189 f"The mesh_dim_names of submesh and flattened mesh are {invalid_dim_names}. " 190 f"Please specify another valid mesh_dim_name.", 191 ) 192 193 # Quick return if the flatten mesh has been created before. 194 # TODO: If we decide to restrict flatten initialization once, we should remove 195 # this check and throw an error if the flatten mesh is already created before. 196 if ( 197 root_mesh in self.root_to_flatten_mapping 198 and mesh_dim_name in self.root_to_flatten_mapping[root_mesh] 199 ): 200 return self.root_to_flatten_mapping[root_mesh][mesh_dim_name] 201 202 flattened_mesh_dim_size = math.prod(device_mesh.mesh.size()) 203 204 remained_dims_in_root = list(range(root_mesh.mesh.ndim)) 205 for flatten_dim_in_root in flatten_dims_in_root: 206 remained_dims_in_root.remove(flatten_dim_in_root) 207 208 pg_ranks_by_dim = root_mesh.mesh.permute( 209 *remained_dims_in_root, *flatten_dims_in_root 210 ).reshape(-1, flattened_mesh_dim_size) 211 212 cur_rank = root_mesh.get_rank() 213 for mesh_nd in pg_ranks_by_dim: 214 # need to init backend here since the flattened pg doesn't exist in root mesh. 215 flattened_mesh = DeviceMesh( 216 root_mesh.device_type, 217 mesh_nd, 218 mesh_dim_names=(mesh_dim_name,), 219 ) 220 if cur_rank in mesh_nd: 221 res_flattened_mesh = flattened_mesh 222 self.child_to_root_mapping[res_flattened_mesh] = root_mesh # type: ignore[possibly-undefined] 223 self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = res_flattened_mesh # type: ignore[possibly-undefined] 224 self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple(flatten_dims_in_root) # type: ignore[possibly-undefined] 225 226 return res_flattened_mesh 227 228 def get_root_mesh(self, device_mesh: "DeviceMesh") -> "DeviceMesh": 229 # If a mesh could not be found in the child_to_root_mapping, it is a root mesh itself. 230 # A root mesh is not created through slicing. 231 # We considers the root mesh of a root mesh is itself. 232 root_mesh = self.child_to_root_mapping.get(device_mesh, None) 233 return device_mesh if not root_mesh else root_mesh 234 235 def get_root_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]: 236 """ 237 Returns the index of the mesh dim in the root mesh. 238 The device_mesh passed in needs to be sliced out from the root mesh 239 or submesh of the root mesh. 240 """ 241 root_mesh = self.get_root_mesh(device_mesh) 242 child_mesh_dim_names = device_mesh.mesh_dim_names 243 if root_mesh and child_mesh_dim_names: 244 assert ( 245 len(child_mesh_dim_names) == 1 246 ), "The submesh can only be a 1D mesh." 247 child_mesh_dim_name = child_mesh_dim_names[0] 248 return self.get_mesh_dim_by_name(root_mesh, child_mesh_dim_name) 249 return None 250 251 @staticmethod 252 def num_devices_per_host(device_type: str) -> int: 253 return _get_device_handle(device_type).device_count() 254 255 @staticmethod 256 def num_hosts(device_type: str) -> int: 257 # ProcessGroup can't tell us this info so we have to infer it, assume 258 # homogeneous hardware for now 259 return get_world_size() // _MeshEnv.num_devices_per_host(device_type) 260 261 def get_mesh_dim_by_name( 262 self, device_mesh: "DeviceMesh", mesh_dim_name: str 263 ) -> int: 264 if ( 265 device_mesh.mesh_dim_names is None 266 or len(device_mesh.mesh_dim_names) == 0 267 ): 268 raise KeyError( 269 "No `mesh_dim_names` found.", 270 ) 271 if mesh_dim_name not in device_mesh.mesh_dim_names: 272 raise KeyError( 273 f"Mesh dimension '{mesh_dim_name}' does not exist.", 274 f"Available mesh dimensions are: mesh_dim_names={device_mesh.mesh_dim_names}", 275 ) 276 return not_none(device_mesh.mesh_dim_names.index(mesh_dim_name)) 277 278 def _set_mesh_dim_group_options( 279 self, 280 dim: int, 281 backend: str, 282 pg_options: Optional[ProcessGroup.Options] = None, 283 ) -> None: 284 self.mesh_dim_group_options[dim] = (backend, pg_options) 285 286 def _get_slice_mesh_dims( 287 self, device_mesh, mesh_dim_names 288 ) -> List[Tuple[int, ...]]: 289 """ 290 Validate whether the mesh_dim_names is valid for slicing the given device_mesh. 291 If valid, return dim indexes of the slice mesh in the device mesh. 292 """ 293 if device_mesh != self.get_root_mesh(device_mesh): 294 raise RuntimeError("Cannot create a submesh from a submesh.") 295 296 # The slice mesh_dim_names should consist either the device_mesh's mesh_dim_names 297 # or its flattened mesh's mesh_dim_names. 298 self.flatten_name_to_root_dims.setdefault(device_mesh, {}) 299 flatten_name_to_root_dims = self.flatten_name_to_root_dims[device_mesh] 300 valid_mesh_dim_names = [ 301 *device_mesh.mesh_dim_names, 302 *flatten_name_to_root_dims, 303 ] 304 305 if not all( 306 mesh_dim_name in valid_mesh_dim_names 307 for mesh_dim_name in mesh_dim_names 308 ): 309 raise KeyError( 310 f"Invalid mesh_dim_names {mesh_dim_names} specified. " 311 f"Valid mesh_dim_names are {valid_mesh_dim_names}." 312 ) 313 314 # Validate the order of the slice mesh dim indices. 315 # This needs to be in ascending order. 316 curr_idx = -1 317 slice_mesh_dims = [] 318 for mesh_dim_name in mesh_dim_names: 319 if mesh_dim_name in flatten_name_to_root_dims: 320 mesh_indices = flatten_name_to_root_dims[mesh_dim_name] 321 # TODO: this doesn't allow non-contiguous slicing with flatten dim yet. next_idx 322 # should be mesh_indices[0] once we support non-contiguous slicing with flatten dim. 323 next_idx = mesh_indices[-1] 324 slice_mesh_dims.append(mesh_indices) 325 else: 326 next_idx = device_mesh.mesh_dim_names.index(mesh_dim_name) 327 slice_mesh_dims.append((next_idx,)) 328 if next_idx <= curr_idx: 329 raise KeyError( 330 f"Invalid mesh_dim_names {mesh_dim_names} specified. ", 331 f"Found mesh dim indices to slice: {slice_mesh_dims}. ", 332 "Mesh dim indices should be in ascending order.", 333 ) 334 curr_idx = next_idx 335 336 return slice_mesh_dims 337 338 def _get_all_submeshes( 339 self, device_mesh: "DeviceMesh", mesh_dim_name: str 340 ) -> List["DeviceMesh"]: 341 """ 342 Return all the submeshes of a given mesh dimension of the device mesh. 343 """ 344 mesh_dim = self.get_mesh_dim_by_name(device_mesh, mesh_dim_name) 345 pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape( 346 -1, device_mesh.mesh.size(mesh_dim) 347 ) 348 349 cur_rank = device_mesh.get_rank() 350 res_submeshes = [] 351 for mesh_1d in pg_ranks_by_dim: 352 submesh = DeviceMesh( 353 device_mesh.device_type, 354 mesh_1d, 355 mesh_dim_names=(mesh_dim_name,), 356 _init_backend=False, 357 ) 358 submesh._dim_group_infos = ( 359 [device_mesh._dim_group_infos[mesh_dim]] 360 if cur_rank in mesh_1d 361 else [] 362 ) 363 res_submeshes.append(submesh) 364 365 return res_submeshes 366 367 _mesh_resources: _MeshEnv = _MeshEnv() 368 369 def _get_device_handle(device_type: str = "cuda"): 370 """ 371 Get the module corresponding to the device_type which is cuda or cuda-like device. 372 For example, when the device_type is cuda, the module `torch.cuda` is returned. 373 Return None when there is no corresponding module for device_type, otherwise 374 return the corresponding module. 375 """ 376 return getattr(torch, device_type, None) 377 378 class DeviceMesh: 379 """ 380 DeviceMesh represents a mesh of devices, where layout of devices could be 381 represented as a n-d dimension array, and each value of the n-d dimensional 382 array is the global id of the default process group ranks. 383 384 DeviceMesh could be used to describe the layout of devices across the cluster, 385 and serves as a proxy for communication among the device lists within the cluster. 386 387 DeviceMesh can be used as a context manager. 388 389 .. note:: 390 DeviceMesh follows SPMD programming model, which means the same PyTorch Python program 391 is running on all processes/ranks in the cluster. Therefore, users need to make sure the 392 `mesh` array (which describes the layout of devices) should be identical across all ranks. 393 Inconsistent `mesh` will lead to silent hang. 394 395 Args: 396 device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like". 397 mesh (ndarray): A multi-dimensional array or an integer tensor describing the layout 398 of devices, where the IDs are global IDs of the default process group. 399 400 Returns: 401 DeviceMesh: A :class:`DeviceMesh` object representing the device layout. 402 403 The following program runs on each process/rank in an SPMD manner. In this example, we have 2 404 hosts with 4 GPUs each. 405 A reduction over the first dimension of mesh will reduce across 406 columns (0, 4), .. and (3, 7), a reduction over the second dimension 407 of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7). 408 409 Example:: 410 >>> # xdoctest: +SKIP("no rank") 411 >>> from torch.distributed.device_mesh import DeviceMesh 412 >>> 413 >>> # Initialize device mesh as (2, 4) to represent the topology 414 >>> # of cross-host(dim 0), and within-host (dim 1). 415 >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) 416 """ 417 418 device_type: str 419 mesh: torch.Tensor 420 mesh_dim_names: Optional[Tuple[str, ...]] 421 422 def __init__( 423 self, 424 device_type: str, 425 mesh: Union[torch.Tensor, "ArrayLike"], 426 *, 427 mesh_dim_names: Optional[Tuple[str, ...]] = None, 428 _init_backend: bool = True, 429 ) -> None: 430 self.device_type = device_type 431 if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu": 432 raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}") 433 self.mesh = ( 434 mesh.detach().to(dtype=torch.int) 435 if isinstance(mesh, torch.Tensor) 436 else torch.tensor(mesh, device="cpu", dtype=torch.int) 437 ) 438 self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None 439 440 # private field to pre-generate DeviceMesh's hash 441 self._flatten_mesh_list = tuple(self.mesh.flatten().tolist()) 442 self._thread_id = None 443 444 # Skip process group initialization if xla device or init backend is False 445 # TODO(yeounoh) implement DeviceMesh backend and register XLA backend. 446 if device_type != "xla": 447 # always try to create default (world) pg, even if it is not initialized 448 # already. The world pg is used for device mesh identity (rank) on each 449 # process (we need to know if the current global rank is in the mesh or not). 450 if _init_backend: 451 self._get_or_create_default_group() 452 self._init_process_groups() 453 454 if is_initialized() and get_backend() == "threaded": 455 self._thread_id = threading.get_ident() 456 457 # calculate the coordinates of the current global rank on the mesh 458 rank_coords = (self.mesh == get_rank()).nonzero() 459 assert rank_coords.size(0) in (0, 1) 460 self._coordinate_on_dim: Optional[List[int]] = ( 461 rank_coords[0].tolist() if rank_coords.size(0) > 0 else None 462 ) 463 464 def _get_or_create_default_group(self): 465 default_initialized = is_initialized() 466 if not default_initialized: 467 init_process_group() 468 469 world_size = get_world_size() 470 if self.mesh.numel() > world_size: 471 raise RuntimeError( 472 f"Mesh should not be bigger than default world size {world_size}, but found {self.mesh.numel()} ranks!" 473 ) 474 475 device_handle = _get_device_handle(self.device_type) 476 # TODO: if user want to pass pg_options, offer a way to do it 477 if not default_initialized and device_handle: 478 # automatically set the current cuda/cuda-like device base on num of gpu devices available in each host 479 # NOTE: This device selection would only work for homogeneous hardware. 480 num_devices_per_host = device_handle.device_count() 481 if ( 482 world_size > num_devices_per_host 483 and world_size % num_devices_per_host != 0 484 ): 485 raise RuntimeError( 486 f"DeviceMesh only support homogeneous hardware, but found " 487 f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!" 488 ) 489 device_handle.set_device(get_rank() % num_devices_per_host) 490 491 return _get_default_group() 492 493 def _init_process_groups(self): 494 # tag/ranks/group_name associated with each mesh dimension, each 495 # mesh dimension should have one sub-group per rank 496 # 497 # TODO(yifu): remove tag and ranks once we fully migrate to native 498 # functional collectives. See details in: 499 # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208 500 dim_group_infos: List[Tuple[str, List[int], str]] = [] 501 502 if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size(): 503 # Append the default pg to the first dim groups only if the default pg is compatible with `self.device_type`. 504 # Otherwise, create new pg. 505 default_group = _get_default_group() 506 ranks = list(range(get_world_size())) 507 dim_group = ( 508 new_group(backend="cpu:gloo,cuda:nccl", ranks=ranks) 509 if torch.cuda.is_available() 510 and get_backend(default_group) == "gloo" 511 else default_group 512 ) 513 dim_group_infos.append( 514 ( 515 _get_group_tag(dim_group), 516 ranks, 517 dim_group.group_name, 518 ) 519 ) 520 else: 521 # create sub pgs base on the mesh argument specified 522 for dim in range(self.mesh.ndim): 523 # swap the current dim to the last dim 524 # then reshape to flatten out other dims 525 pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape( 526 -1, self.mesh.size(dim) 527 ) 528 # multi-dim mesh, create subgroups by looping over the pg_ranks 529 # for each dim and append the groups 530 for dim_mesh in pg_ranks_by_dim: 531 subgroup_ranks = dim_mesh.tolist() 532 533 # Respect dim group options specified via _MeshEnv.set_dim_group_options(). 534 # Inherit from the parent group if no options are specified for the group. 535 if dim in _mesh_resources.mesh_dim_group_options: 536 ( 537 backend, 538 pg_options, 539 ) = _mesh_resources.mesh_dim_group_options[dim] 540 else: 541 backend, pg_options = None, None 542 543 # We temporarily revert the re-use subgroup, since it breaks two internal tests. 544 # Temporarily reverting to resolve test timeout while root-causing. 545 # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists. 546 dim_group = new_group( 547 ranks=subgroup_ranks, 548 backend=backend, 549 pg_options=pg_options, 550 ) 551 552 # only add to dim_groups if the current rank in the subgroup 553 if self.get_rank() in subgroup_ranks: 554 if len(dim_group_infos) > dim: 555 raise RuntimeError( 556 f"Each device mesh dimension should get only one process group, but got {self.get_rank()} " 557 f"in {subgroup_ranks}!" 558 ) 559 dim_group_infos.append( 560 ( 561 _get_group_tag(not_none(dim_group)), 562 subgroup_ranks, 563 dim_group.group_name, 564 ) 565 ) 566 self._dim_group_infos = dim_group_infos 567 568 def __enter__(self) -> "DeviceMesh": 569 # set this mesh as the current mesh in mesh env 570 _mesh_resources.mesh_stack.append(self) 571 return self 572 573 # pyre-fixme[2]: Parameter must be annotated. 574 def __exit__(self, exc_type, exc_value, exc_traceback) -> None: 575 # pop this mesh from mesh env 576 _mesh_resources.mesh_stack.pop() 577 578 def __repr__(self) -> str: 579 device_mesh_repr = ( 580 f"DeviceMesh('{self.device_type}', {self.mesh.tolist()})" 581 if not self.mesh_dim_names 582 else f"DeviceMesh('{self.device_type}', {self.mesh.tolist()}, mesh_dim_names={self.mesh_dim_names})" 583 ) 584 return device_mesh_repr 585 586 def __hash__(self): 587 # lazily compute hash 588 self._hash = getattr(self, "_hash", None) 589 if not self._hash: 590 self._hash = hash( 591 ( 592 self._flatten_mesh_list, 593 self.mesh.shape, 594 self.device_type, 595 self.mesh_dim_names, 596 self._thread_id, 597 ) 598 ) 599 return self._hash 600 601 def __eq__(self, other: object) -> bool: 602 if not isinstance(other, DeviceMesh): 603 return False 604 if id(self) == id(other): 605 return True 606 else: 607 return ( 608 self._flatten_mesh_list == other._flatten_mesh_list 609 and self.mesh.shape == other.mesh.shape 610 and self.device_type == other.device_type 611 and self.mesh_dim_names == other.mesh_dim_names 612 and self._thread_id == other._thread_id 613 ) 614 615 def __getitem__( 616 self, mesh_dim_names: Union[str, Tuple[str, ...]] 617 ) -> "DeviceMesh": 618 """ 619 Slice the current DeviceMesh based on the mesh_dim_names given to create a submesh. 620 The submesh created consists of the dimensions and the communicators indicated by 621 ``mesh_dim_names`` 622 623 Args: 624 mesh_dim_names (Union[str, Tuple[str]]): the name or the tuple of names of the 625 mesh dimension of the DeviceMesh to create the submesh for. 626 Returns: 627 A :class:`DeviceMesh` object 628 629 The following program runs on each process/rank in an SPMD manner in a world size of 8. 630 In the first example: 631 Calling mesh_2d["tp"] on rank 0, 1, 2, 3 returns a 1D submesh of DeviceMesh:([0, 1, 2, 3]). 632 Calling mesh_2d["tp"] on rank 4, 5, 6, 7 returns a 1D submesh of DeviceMesh:([4, 5, 6, 7]). 633 Calling mesh_2d["dp"] on rank 0, 4 returns a 1D submesh of DeviceMesh:([0, 4]). 634 Calling mesh_2d["dp"] on rank 1, 5 returns a 1D submesh of DeviceMesh:([1, 5]). 635 Calling mesh_2d["dp"] on rank 2, 6 returns a 1D submesh of DeviceMesh:([2, 6]). 636 Calling mesh_2d["dp"] on rank 3, 7 returns a 1D submesh of DeviceMesh:([3, 7]). 637 638 In the second example: 639 Calling mesh_3d["dp", "cp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 1], [4, 5]]). 640 Calling mesh_3d["dp", "cp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 3], [6, 7]]). 641 Calling mesh_3d["cp", "dp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 4], [1, 5]]). 642 Calling mesh_3d["cp", "dp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 6], [3, 7]]). 643 644 Example:: 645 >>> # xdoctest: +SKIP("no rank") 646 >>> from torch.distributed.device_mesh import DeviceMesh 647 >>> 648 >>> # Initialize a 2D device mesh as (2, 4) to represent the topology 649 >>> # of cross-host(dim 0), and within-host (dim 1). 650 >>> mesh_2d = init_device_mesh(device_type="cuda", (2,4), mesh_dim_names=("dp", "tp")) 651 >>> tp_mesh = mesh_2d["tp"] 652 >>> dp_mesh = mesh_2d["dp"] 653 >>> 654 >>> # Initialize a 3D mesh. 655 >>> mesh_3d = init_device_mesh(device_type="cuda", (2,2,2), mesh_dim_names=("dp", "pp", "cp")) 656 >>> # The order of the mesh_dim_names provided deteremines the order of dimensions in the submesh. 657 >>> dp_cp_mesh = mesh_3d["dp", "cp"] 658 >>> cp_dp_mesh = mesh_3d["cp", "dp"] 659 """ 660 if not self.mesh_dim_names: 661 raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!") 662 663 mesh_dim_names = ( 664 (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names 665 ) 666 667 if mesh_dim_names == self.mesh_dim_names: 668 return self 669 else: 670 slice_mesh_dims = _mesh_resources._get_slice_mesh_dims( 671 self, mesh_dim_names 672 ) 673 submesh = _mesh_resources.create_sub_mesh( 674 self, mesh_dim_names, slice_mesh_dims 675 ) 676 return submesh 677 678 def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup: 679 """ 680 Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the 681 DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh. 682 683 Args: 684 mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index 685 of the mesh dimension. Default is None. 686 687 Returns: 688 A :class:`ProcessGroup` object. 689 """ 690 if not hasattr(self, "_dim_group_infos"): 691 raise RuntimeError("DeviceMesh process groups not initialized!") 692 693 if self.mesh.ndim > 1 and mesh_dim is None: 694 raise RuntimeError( 695 f"Found the DeviceMesh have {self.mesh.ndim} dimensions", 696 "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", 697 "If you want to get the list of all the ProcessGroups in the DeviceMesh," 698 "please use `get_all_groups()` instead.", 699 ) 700 701 # Quick return if the current device_mesh is a 1D mesh. 702 if self.mesh.ndim == 1 and mesh_dim is None: 703 return not_none( 704 _find_pg_by_ranks_and_tag(*self._dim_group_infos[0][:2]) # type: ignore[index] 705 ) 706 707 root_mesh = _mesh_resources.get_root_mesh(self) 708 root_to_flatten_mapping = _mesh_resources.root_to_flatten_mapping.get( 709 root_mesh, None 710 ) 711 if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping.keys(): 712 dim_group_infos = root_to_flatten_mapping[mesh_dim]._dim_group_infos[0][:2] # type: ignore[index] 713 return not_none(_find_pg_by_ranks_and_tag(*dim_group_infos)) 714 else: 715 mesh_dim = ( 716 _mesh_resources.get_mesh_dim_by_name(self, mesh_dim) 717 if isinstance(mesh_dim, str) 718 else mesh_dim 719 ) 720 return not_none( 721 _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim][:2]) # type: ignore[index] 722 ) 723 724 def get_all_groups(self) -> List[ProcessGroup]: 725 """ 726 Returns a list of ProcessGroups for all mesh dimensions. 727 728 Returns: 729 A list of :class:`ProcessGroup` object. 730 """ 731 return [self.get_group(i) for i in range(self.mesh.ndim)] 732 733 @staticmethod 734 def from_group( 735 group: Union[ProcessGroup, List[ProcessGroup]], 736 device_type: str, 737 mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None, 738 *, 739 mesh_dim_names: Optional[Tuple[str, ...]] = None, 740 ) -> "DeviceMesh": 741 """ 742 Constructs a :class:`DeviceMesh` with ``device_type`` from an 743 existing :class:`ProcessGroup`. 744 745 The constructed device mesh has number of dimensions equal to the 746 number of groups passed. If more than one group is passed, then the 747 ``mesh`` argument is required. 748 """ 749 if isinstance(group, ProcessGroup): 750 group_ranks = get_process_group_ranks(group) 751 if ( 752 isinstance(mesh, torch.Tensor) and mesh.tolist() != group_ranks 753 ) or (mesh is not None and mesh != group_ranks): 754 raise ValueError( 755 f"Invalid mesh {str(mesh)} for ProcessGroup with ranks {group_ranks}" 756 ) 757 mesh = torch.tensor(group_ranks, device="cpu", dtype=torch.int) 758 device_mesh = DeviceMesh( 759 device_type, 760 mesh, 761 mesh_dim_names=mesh_dim_names, 762 _init_backend=False, 763 ) 764 device_mesh._dim_group_infos = [ 765 (_get_group_tag(group), group_ranks, group.group_name) 766 ] 767 return device_mesh 768 groups = list(group) 769 if len(groups) == 0: 770 raise ValueError("Expects at least one ProcessGroup to be passed") 771 if mesh is None: 772 raise ValueError("Must pass mesh if passing multiple ProcessGroups") 773 mesh = ( 774 mesh.detach().to(dtype=torch.int, device="cpu") 775 if isinstance(mesh, torch.Tensor) 776 else torch.tensor(mesh, device="cpu", dtype=torch.int) 777 ) 778 if mesh.ndim != len(groups): 779 raise ValueError( 780 "Expects mesh with ndim equal to number of ProcessGroups but got " 781 f"mesh {mesh.tolist()} and {len(groups)} ProcessGroups" 782 ) 783 device_mesh = DeviceMesh( 784 device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False 785 ) 786 device_mesh._dim_group_infos = [ 787 ( 788 _get_group_tag(group), 789 get_process_group_ranks(group), 790 group.group_name, 791 ) 792 for group in groups 793 ] 794 return device_mesh 795 796 def size(self, mesh_dim: Optional[int] = None) -> int: 797 return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim) 798 799 @property 800 def ndim(self) -> int: 801 return self.mesh.ndim 802 803 @property 804 def shape(self) -> Tuple[int, ...]: 805 return tuple(self.mesh.shape) 806 807 def get_rank(self) -> int: 808 """ 809 Returns the current global rank. 810 """ 811 return get_rank() 812 813 def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int: 814 """ 815 Returns the local rank of the given mesh_dim of the DeviceMesh. 816 817 Args: 818 mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index 819 of the mesh dimension. Default is None. 820 821 Returns: 822 An integer denotes the local rank. 823 824 The following program runs on each process/rank in an SPMD manner. In this example, we have 2 825 hosts with 4 GPUs each. 826 Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0. 827 Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1. 828 Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0. 829 Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1. 830 Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2. 831 Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3. 832 833 Example:: 834 >>> # xdoctest: +SKIP("no rank") 835 >>> from torch.distributed.device_mesh import DeviceMesh 836 >>> 837 >>> # Initialize device mesh as (2, 4) to represent the topology 838 >>> # of cross-host(dim 0), and within-host (dim 1). 839 >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) 840 """ 841 if self.ndim > 1 and mesh_dim is None: 842 raise RuntimeError( 843 f"Found the DeviceMesh have {self.mesh.ndim} dimensions", 844 "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", 845 ) 846 elif mesh_dim is None: 847 mesh_dim = 0 848 849 mesh_dim_group = not_none(self.get_group(mesh_dim)) 850 assert isinstance( 851 mesh_dim_group, ProcessGroup 852 ), "We expect ProcessGroup before calling `get_rank`!" 853 return not_none(get_rank(mesh_dim_group)) 854 855 def get_coordinate(self) -> Optional[List[int]]: 856 """ 857 Return the relative indices of this rank relative to all 858 dimensions of the mesh. If this rank is not part of the mesh, return None. 859 """ 860 return self._coordinate_on_dim if self._coordinate_on_dim else None 861 862 def _flatten(self, mesh_dim_name: Optional[str] = None) -> "DeviceMesh": 863 """ 864 Returns a 1D DeviceMesh by flattening the current DeviceMesh. 865 866 If no mesh_dim_name is provided, the default is a string concatentaing the mesh_dim_names of the 867 given submesh with each mesh_dim_name separated by "_". For example, if we have a 3D mesh 868 DeviceMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], mesh_dim_names=("dp", "cp", "tp")), calling 869 mesh_3d["dp", "cp"]._flatten() will create a 1D submesh DeviceMesh([0, 1, 2, 3], mesh_dim_names=("dp_cp",)) 870 on rank 0, 1, 2, 3 and a 1D submesh DeviceMesh([4, 5, 6, 7], mesh_dim_names=("dp_cp",)) on rank 4, 5, 6, 7. 871 872 After the flattened dimension is created, to access the flattened dimesnion in mesh_3d, one can use the 873 existing slicing method to obtain the flattened mesh through calling mesh_3d["dp_cp"]. 874 """ 875 if not self.mesh_dim_names: 876 raise RuntimeError( 877 "Cannot flatten a DeviceMesh without mesh_dim_names!" 878 ) 879 880 return _mesh_resources.create_flatten_mesh(self, mesh_dim_name) 881 882 def init_device_mesh( 883 device_type: str, 884 mesh_shape: Tuple[int, ...], 885 *, 886 mesh_dim_names: Optional[Tuple[str, ...]] = None, 887 ) -> DeviceMesh: 888 """ 889 Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters. 890 891 This creates a DeviceMesh with an n-dimensional array layout, where `n` is the length of `mesh_shape`. 892 If `mesh_dim_names` is provided, each dimension is labeled as `mesh_dim_names[i]`. 893 894 .. note:: 895 `init_device_mesh` follows SPMD programming model, meaning the same PyTorch Python program 896 runs on all processes/ranks in the cluster. Ensure `mesh_shape` (the dimensions of the nD array 897 describing device layout) is identical across all ranks. Inconsistent `mesh_shape` may lead to hanging. 898 899 .. note:: 900 If no process group is found, init_device_mesh will initialize distributed process group/groups 901 required for distributed communications behind the scene. 902 903 Args: 904 device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like". 905 Passing in a device type with a GPU index, such as "cuda:0", is not allowed. 906 mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array 907 describing the layout of devices. 908 mesh_dim_names (Tuple[str], optional): A tuple of mesh dimension names to assign to each dimension 909 of the multi-dimensional array describing the layout of devices. Its length must match the length 910 of `mesh_shape`. Each string in `mesh_dim_names` must be unique. 911 912 Returns: 913 DeviceMesh: A :class:`DeviceMesh` object representing the device layout. 914 915 Example:: 916 >>> # xdoctest: +SKIP("no rank") 917 >>> from torch.distributed.device_mesh import init_device_mesh 918 >>> 919 >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,)) 920 >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp")) 921 922 """ 923 if mesh_dim_names is not None: 924 if len(set(mesh_dim_names)) != len(mesh_dim_names): 925 raise RuntimeError( 926 "Each mesh_dim_name must be unique.", 927 f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}", 928 ) 929 930 if len(mesh_shape) != len(mesh_dim_names): 931 raise RuntimeError( 932 "mesh_shape and mesh_dim_names should have same length!", 933 f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.", 934 ) 935 936 # assume valid device types are all letters 937 if device_type and not device_type.isalpha(): 938 raise RuntimeError( 939 f"Device type with GPU index is not supported but got {device_type}. ", 940 "If you maintained a 'torch.device' object, it's recommended to pass in 'device.type'.", 941 ) 942 943 # Always initialize the mesh's tensor on CPU, regardless of what the 944 # external device type has been set to be (e.g. meta) 945 with torch.device("cpu"): 946 mesh = torch.arange(math.prod(mesh_shape), dtype=torch.int).view(mesh_shape) 947 device_mesh = DeviceMesh( 948 device_type=device_type, 949 mesh=mesh, 950 mesh_dim_names=mesh_dim_names, 951 ) 952 953 return device_mesh 954