1# mypy: allow-untyped-defs 2import warnings 3from typing import Tuple, Union 4 5from torch.distributed.device_mesh import _mesh_resources 6from torch.distributed.tensor import DeviceMesh 7from torch.distributed.tensor.placement_types import Placement 8 9 10try: 11 from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling 12except Exception: 13 14 def is_torchdynamo_compiling(): # type: ignore[misc] 15 return False 16 17 18LayoutsType = Union[Placement, Tuple[Placement, ...]] 19 20 21def _deprecate_warnings(func_name: str, extra_msg: str) -> None: 22 """ 23 Inject common validation logics for `_prepare_input` funcs via this decorator. 24 25 Include verifying that input needs to be either a :class:`Tensor` or :class:`DTensor` 26 and only 1D :class:`DeviceMesh` is passed in. 27 """ 28 # TODO: Will follow up with dynamo POC to make warnings.warn working with dynamo. 29 if not is_torchdynamo_compiling(): 30 warnings.warn( 31 f"{func_name} is deprecated and will be removed soon. {extra_msg}", 32 FutureWarning, 33 stacklevel=3, 34 ) 35 36 37def _validate_tp_mesh_dim( 38 device_mesh: DeviceMesh, 39) -> None: 40 """ 41 Check whether TP mesh dimension is valid or not. 42 43 Args: 44 device_mesh (:class:`DeviceMesh`): 45 The `device_mesh` where we perform 46 Tensor Parallelism on. 47 48 Return: 49 `True` if the mesh dimension 50 is valid, `False` otherwise. 51 """ 52 if device_mesh.ndim > 1: 53 raise ValueError( 54 f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D!" 55 'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"]' 56 ) 57 58 root_mesh = _mesh_resources.get_root_mesh(device_mesh) 59 # if a root mesh is not the same as device_mesh, 60 # meaning the device_mesh is sliced out from the root mesh. 61 if root_mesh and root_mesh != device_mesh: 62 tp_mesh_dim_in_root = _mesh_resources.get_root_mesh_dim(device_mesh) 63 if tp_mesh_dim_in_root != root_mesh.ndim - 1: 64 raise RuntimeError( 65 f"Found TP device_mesh on the {tp_mesh_dim_in_root} dimension of its parent mesh.", 66 "Currently we only support intranode TP and TP needs to be the innermost dimension on its parent mesh.", 67 ) 68