# mypy: allow-untyped-defs
import inspect
from abc import ABC, abstractmethod
from typing import Dict, Type

from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
    _hook_then_optimizer,
    _OptimizerHookState,
)
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.optim import as_functional_optim
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer


# Contains the mappings between the regular and overlapped optimizer types.
_registered_overlapped_optims: Dict[Type, Type] = {}


def register_overlapped(optim_cls):
    def decorator(target_overlapped_optim_cls):
        if optim_cls in _registered_overlapped_optims:
            raise ValueError(
                f"{optim_cls} is already registered with overlapped optimizer "
                f"{_registered_overlapped_optims[optim_cls]}; re-registering it "
                f"for {target_overlapped_optim_cls} is not supported."
            )
        _registered_overlapped_optims[optim_cls] = target_overlapped_optim_cls
        return target_overlapped_optim_cls

    return decorator
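
# Illustrative sketch (hypothetical names, not part of this module): a custom
# overlapped optimizer would typically be registered against a concrete
# optimizer class with the ``register_overlapped`` decorator above, e.g.
#
#   @register_overlapped(MyOptimizer)
#   class _OverlappedMyOptimizer(OverlappedOptimizer):
#       def register_ddp(self, ddp: DistributedDataParallel) -> None:
#           ...
#
#       def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None:
#           ...
#
# ``_as_overlapped_optim`` at the bottom of this file resolves lookups by
# walking the MRO of the requested optimizer class, so a registration against
# a base class also covers subclasses that have no registration of their own.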


class OverlappedOptimizer(ABC):
    def __init__(self, optim_cls: Type) -> None:
        """
        Initialize the OverlappedOptimizer.

        OverlappedOptimizer is a base class that child classes implement to
        specify how different optimizers register themselves with DDP.
        """
        self.optim_cls = optim_cls

    @abstractmethod
    def register_ddp(self, ddp: DistributedDataParallel) -> None:
        """Register the overlapped optimizer with DDP."""
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support overlapped DDP."
        )

    @abstractmethod
    def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None:
        """Register the overlapped optimizer with FSDP."""
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support overlapped FSDP."
        )


@register_overlapped(Optimizer)
class _OverlappedStandardOptimizer(OverlappedOptimizer):
    """Overlaps a regular ``Optimizer``."""

    def __init__(self, optim_cls: Type, params, *optim_args, **optim_kwargs) -> None:
        super().__init__(optim_cls)
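        # ``as_functional_optim`` maps ``optim_cls`` to its functional
        # counterpart (for example, ``torch.optim.SGD`` has a functional
        # variant), which applies parameter updates from gradients handed to it
        # directly rather than reading ``.grad`` inside a separate ``step()``
        # call; ``_OptimizerHookState`` bundles that functional optimizer with
        # the parameters it should update from inside the communication hook.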
        f_optim = as_functional_optim(self.optim_cls, *optim_args, **optim_kwargs)
        self._opt_hook_state = _OptimizerHookState(f_optim, params)

    def register_ddp(self, ddp_inst: DistributedDataParallel):
        # NOTE: using a custom communication hook and fused optimizer is not
        # yet supported.
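        # The registered hook composes communication and computation: each
        # gradient bucket is first allreduced by ``allreduce_hook``, and once
        # that communication finishes the functional optimizer step runs on the
        # bucket's parameters, overlapping the optimizer with the remainder of
        # the backward pass.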
        ddp_inst.register_comm_hook(  # type: ignore[operator]
            None,  # wrapped hook state
            _hook_then_optimizer(allreduce_hook, self._opt_hook_state),
        )

    # TODO: register_fsdp once FSDP supports communication hook.
    def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None:
        """Register the overlapped optimizer with FSDP."""
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support overlapped FSDP."
        )


def _as_overlapped_optim(optim_cls: Type, params, *args, **kwargs):
    """Return a new ``OverlappedOptimizer`` instance that supports ``optim_cls``."""
    for clz in inspect.getmro(optim_cls):
        try:
            return _registered_overlapped_optims[clz](
                optim_cls, params, *args, **kwargs
            )
        except KeyError:
            pass

    # Fall back to the standard overlapped optimizer, which will raise an error
    # if the user attempts to use an unsupported optimizer.
    return _OverlappedStandardOptimizer(optim_cls, params, *args, **kwargs)
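

# Minimal usage sketch (illustrative only): ``model`` and the learning rate are
# assumptions for the example, and a default process group is assumed to have
# been initialized already; this module only provides the pieces used below.
#
#   import torch
#
#   ddp_model = DistributedDataParallel(model)
#   overlapped = _as_overlapped_optim(
#       torch.optim.SGD, list(model.parameters()), lr=0.01
#   )
#   overlapped.register_ddp(ddp_model)
#
# After registration, the allreduce communication hook also runs the functional
# optimizer step on each gradient bucket, so parameter updates overlap with the
# backward pass instead of happening in a separate ``optimizer.step()`` call.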
98