# Copyright (c) Meta Platforms, Inc. and affiliates
from fnmatch import fnmatch
from typing import Dict, Union

import torch
import torch.distributed.tensor._random as random
import torch.nn as nn
from torch.distributed.tensor import DeviceMesh
from torch.distributed.tensor._random import (
    is_rng_supported_mesh,
    TensorParallelRNGTracker,
)
from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
from torch.distributed.tensor.parallel.style import ParallelStyle


__all__ = [
    "parallelize_module",
]


def parallelize_module(  # type: ignore[return]
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]],
) -> nn.Module:
    """
    Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.

    We parallelize the module or its sub-modules based on a parallelize_plan. The parallelize_plan
    contains :class:`ParallelStyle` objects, which indicate how the user wants the module or
    sub-module to be parallelized.

    Users can also specify a different parallel style per module fully qualified name (FQN).

    Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`. If you have a 2-D or N-D
    :class:`DeviceMesh`, slice it to a 1-D sub-DeviceMesh first and then pass it to this API
    (i.e. ``device_mesh["tp"]``).

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology
            of devices for the DTensor.
        parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]]):
            The plan used to parallelize the module. It can be either a
            :class:`ParallelStyle` object, which describes how
            to prepare input/output for Tensor Parallelism, or a dict
            mapping module FQNs to their corresponding :class:`ParallelStyle` objects.
    Return:
        The parallelized :class:`nn.Module` object.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>>
        >>> # Define the module.
        >>> m = Model(...)
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
        >>>
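        >>> # The keys of a dict plan are matched against submodule FQNs with `fnmatch`,
        >>> # so they may contain wildcards. Illustrative sketch only: it assumes the model
        >>> # has a `layers` ModuleList whose blocks each own a `wq` Linear projection.
        >>> m = parallelize_module(m, tp_mesh, {"layers.*.wq": ColwiseParallel()})
        >>>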

    .. note:: For complex module architectures like Attention and MLP layers, we recommend composing
        different ParallelStyles together (e.g. ``ColwiseParallel`` and ``RowwiseParallel``) and passing
        them as a parallelize_plan to achieve the desired sharding computation.
    """
    torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")

    _validate_tp_mesh_dim(device_mesh)

    # instantiate a TP RNG state tracker if it's not there
    if is_rng_supported_mesh(device_mesh) and not isinstance(
        random._rng_tracker, TensorParallelRNGTracker
    ):
        random._rng_tracker = TensorParallelRNGTracker(device_mesh.device_type)
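        # The tracker coordinates per-rank RNG state so that DTensor random ops
        # (e.g. dropout, random init) behave correctly across the TP mesh.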
        # TODO: we should allow users to pass in the default seed from a config
        random._rng_tracker._manual_seed(device_mesh, base_seed=1234)
        # By default we execute random ops in the non-tensor-parallel region. If users want
        # to execute them in the tensor-parallel region, they can manually set this field to
        # True after parallelizing the model.
        random._rng_tracker.distribute_region_enabled = False

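    # A single ParallelStyle applies to `module` itself; a dict maps submodule FQNs
    # (which may contain `fnmatch`-style wildcards) to the style to apply to them.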
    if isinstance(parallelize_plan, ParallelStyle):
        return parallelize_plan._apply(module, device_mesh)
    elif isinstance(parallelize_plan, dict):
        for module_path, parallelize_style in parallelize_plan.items():
            if not module_path:
                raise ValueError(
                    "Expect module path to be non-empty, but got empty string!"
                )
            path_splits = module_path.split(".")
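            # Consume the FQN one atom at a time: each atom is matched against the
            # names of `module`'s direct children with `fnmatch` (so atoms may be glob
            # patterns), and the remaining path is applied recursively to each match.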
            while path_splits:
                atom = path_splits.pop(0)
                matched_children = filter(
                    # `t[0]` is child name
                    lambda t: fnmatch(t[0], atom),
                    module.named_children(),
                )
                # apply the plan to all matched submodules
                for _, submodule in matched_children:
                    if path_splits:
                        # we haven't reached the leaf, apply in dict style
                        leaf_path = ".".join(
                            path_splits
                        )  # rest of the path after `atom`
                        parallelize_module(
                            submodule, device_mesh, {leaf_path: parallelize_style}
                        )
                    else:
                        # otherwise, directly apply style to this submodule
                        parallelize_module(submodule, device_mesh, parallelize_style)
        return module
    else:
        raise TypeError(  # pyre-ignore[7]
            "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
            f" parallelize_plan, {type(parallelize_plan)} found!"
        )