import abc
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import torch.nn as nn
from torch.distributed._shard.sharder import Sharder
from torch.distributed._shard.sharding_spec import ShardingSpec


@dataclass
class ShardingPlan:
12    """
13    Representation of a sharding plan, describes how to shard a module
14    across hosts. `plan` is used to shard module parameters according to the spec provided,
15    `output_plan` and `return_local_tensor` are optional, they are used to specify the output
16    layout of a module with a spec, and when to convert back to data parallel fashion.
17
18    Args:
19        plan (Dict[str, Union[:class:`torch.distributed._shard.sharding_spec.ShardingSpec`,
20              :class:`torch.distributed._shard.sharder.Sharder`]):
21            a dict describes how to shard a module, there're currently two ways to shard a module:
22                1. directly shard a module parameter by a `ShardingSpec`, keyed by the name of
23                   a parameter to a `ShardingSpec`.
24                2. shard a submodule by applying a `Sharder` on it, keyed by the name of a module
25                   to a `Sharder` object.
26        output_plan (Dict[str, :class:`torch.distributed._shard.sharding_spec.ShardingSpec`), optional):
27            a dict specifies the layout of a module's output which produces a ShardedTensor,
28            keyed by the name of module to ShardingSpec("" in key means the root module).
29            Default: `None`
30        return_local_tensor (List[str], optional): a list of string, each element enables
31            a module's sharded output to be returned as a Tensor from its local shards to
32            ensure further processing in a data parallel fashion. ("" in list means the
33            root module).
34            Default: None
    Example:
      Suppose we want to shard a module with two linear layers and then run it with DDP. We also
      want to convert the output of the second linear layer back to DDP. We can do that as follows:

        >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
        >>> class MyModule(nn.Module):
        >>>     def __init__(self) -> None:
        >>>         super().__init__()
        >>>         self.fc1 = nn.Linear(16, 16)
        >>>         self.gelu = nn.GELU()
        >>>         self.fc2 = nn.Linear(16, 16)
        >>>         self.relu = nn.ReLU()
        >>>
        >>>     def forward(self, input):
        >>>         return self.relu(self.fc2(self.gelu(self.fc1(input))))


        >>> # xdoctest: +SKIP("Undefined spec1, spec2")
        >>> sharding_plan = ShardingPlan(
        >>>    plan={
        >>>        "fc1.weight": spec1,
        >>>        "fc2.weight": spec2
        >>>    },
        >>>    output_plan={
        >>>        "fc2": output_spec
        >>>    },
        >>>    return_local_tensor=["fc2"]
        >>> )
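
      The resulting plan can then be applied to the module, e.g. with
      `torch.distributed._shard.shard_module` (shown here only as a sketch; it
      assumes the default process group has already been initialized):

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> from torch.distributed._shard import shard_module
        >>> module = MyModule()
        >>> shard_module(module, sharding_plan)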
63    """
64
65    plan: Dict[str, Union[ShardingSpec, Sharder]]
66    output_plan: Optional[Dict[str, ShardingSpec]] = None
67    return_local_tensor: Optional[List[str]] = None
68
69
class ShardingPlanner(abc.ABC):
    """
    Default ShardingPlanner interface, which can be extended to
    implement advanced sharding strategies.
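
    Example:
      As a minimal sketch, a planner could shard every 2-D parameter of a module
      with a single caller-provided spec (`MyShardingPlanner` and its `spec`
      argument are illustrative only, not an existing planner):

        >>> # xdoctest: +SKIP("Illustrative sketch, spec is not defined here")
        >>> class MyShardingPlanner(ShardingPlanner):
        >>>     def __init__(self, spec):
        >>>         self.spec = spec
        >>>
        >>>     def build_plan(self, module: nn.Module) -> ShardingPlan:
        >>>         # Shard every 2-D parameter (e.g. Linear weights) with the same spec.
        >>>         return ShardingPlan(plan={
        >>>             name: self.spec
        >>>             for name, param in module.named_parameters()
        >>>             if param.ndim == 2
        >>>         })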
74    """
75
    @abc.abstractmethod
    def build_plan(self, module: nn.Module) -> ShardingPlan:
        """
        Given an nn.Module, define how to shard the module across
        ranks and return a ShardingPlan.

        Args:
            module (:class:`torch.nn.Module`):
                The module to apply sharding to.

        Returns:
            A :class:`torch.distributed._shard.sharding_plan.ShardingPlan` object that
            represents how to shard the module.
        """