# mypy: allow-untyped-defs
import warnings
from typing import Tuple, Union

from torch.distributed.device_mesh import _mesh_resources
from torch.distributed.tensor import DeviceMesh
from torch.distributed.tensor.placement_types import Placement


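# Prefer dynamo's own check for whether we are currently compiling; on builds
# where torch._dynamo is unavailable, fall back to a stub that reports False,
# so the deprecation warning below always fires eagerly.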
try:
    from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling
except Exception:

    def is_torchdynamo_compiling():  # type: ignore[misc]
        return False


LayoutsType = Union[Placement, Tuple[Placement, ...]]


def _deprecate_warnings(func_name: str, extra_msg: str) -> None:
    """
    Emit a ``FutureWarning`` that ``func_name`` is deprecated and will be removed.

    The warning is skipped while compiling with torchdynamo, since
    ``warnings.warn`` is not yet supported under dynamo.
    """
    # TODO: Will follow up with dynamo POC to make warnings.warn work with dynamo.
    if not is_torchdynamo_compiling():
        warnings.warn(
            f"{func_name} is deprecated and will be removed soon. {extra_msg}",
            FutureWarning,
            stacklevel=3,
        )


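# Usage sketch (illustrative): a deprecated helper would call
# _deprecate_warnings at its entry point before delegating to its
# replacement. `_example_deprecated_helper` below is hypothetical and not
# part of this module.
def _example_deprecated_helper() -> None:
    _deprecate_warnings(
        "_example_deprecated_helper", "Please use parallelize_module instead."
    )

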
def _validate_tp_mesh_dim(
    device_mesh: DeviceMesh,
) -> None:
    """
    Check whether the TP mesh dimension is valid.

    Args:
        device_mesh (:class:`DeviceMesh`):
            The ``device_mesh`` on which we perform
            Tensor Parallelism.

    Raises:
        ValueError: If ``device_mesh`` is not 1D.
        RuntimeError: If ``device_mesh`` was sliced out of a root mesh but
            is not the innermost dimension of that root mesh.
    """
    if device_mesh.ndim > 1:
        raise ValueError(
            f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D! "
            'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"].'
        )

    root_mesh = _mesh_resources.get_root_mesh(device_mesh)
    # If the root mesh is not the same as device_mesh, then device_mesh was
    # sliced out from the root mesh.
    if root_mesh and root_mesh != device_mesh:
        tp_mesh_dim_in_root = _mesh_resources.get_root_mesh_dim(device_mesh)
        if tp_mesh_dim_in_root != root_mesh.ndim - 1:
            raise RuntimeError(
                f"Found TP device_mesh on the {tp_mesh_dim_in_root} dimension of its parent mesh. "
                "Currently we only support intranode TP and TP needs to be "
                "the innermost dimension on its parent mesh."
            )
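

# Usage sketch (illustrative; the mesh shape and dim names below are
# assumptions, not part of this module):
#
#     from torch.distributed.device_mesh import init_device_mesh
#
#     mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
#     _validate_tp_mesh_dim(mesh_2d["tp"])  # OK: 1D and innermost in its root
#     _validate_tp_mesh_dim(mesh_2d)        # raises ValueError: mesh is 2D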