xref: /aosp_15_r20/external/pytorch/torch/distributed/_shard/sharder.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import abc
2
3import torch.nn as nn
4
5
class Sharder(abc.ABC):
    """
    Interface for user-defined sharding strategies.

    Implement this interface to build more advanced sharding strategies
    that cannot easily be composed with a `ShardingSpec`.

    :class:`torch.distributed._shard.sharding_plan.ShardingPlan` can take a
    `Sharder` object and call its `shard` method to shard a module. It then
    replaces the original module with the returned sharded module.

    This class is abstract: subclasses must override :meth:`shard` before
    they can be instantiated.
    """

    @abc.abstractmethod
    def shard(self, module: nn.Module) -> nn.Module:
        """
        Shard ``module`` and return the sharded version of it.

        The sharding logic is entirely defined by the subclass's
        implementation of this method.

        Args:
            module (:class:`torch.nn.Module`):
                The module to apply sharding to.

        Returns:
            A :class:`torch.nn.Module` object that represents a module
            that has already been sharded.
        """
29        """
30