# mypy: allow-untyped-defs
import inspect
from abc import ABC, abstractmethod
from typing import Dict, Type

from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
    _hook_then_optimizer,
    _OptimizerHookState,
)
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.optim import as_functional_optim
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer


# Maps regular optimizer types to their overlapped counterparts.
_registered_overlapped_optims: Dict[Type, Type] = {}


def register_overlapped(optim_cls):
    def decorator(target_overlapped_optim_cls):
        if target_overlapped_optim_cls in _registered_overlapped_optims:
            raise ValueError(
                f"{target_overlapped_optim_cls} already registered with optim_cls "
                f"{_registered_overlapped_optims[optim_cls]} {optim_cls}, trying to "
                f"re-register it for {optim_cls} is not supported."
            )
        _registered_overlapped_optims[optim_cls] = target_overlapped_optim_cls
        return target_overlapped_optim_cls

    return decorator


class OverlappedOptimizer(ABC):
    def __init__(self, optim_cls: Type) -> None:
        """
        Initialize the OverlappedOptimizer.

        OverlappedOptimizer is a base class that child classes can implement to
        specify how different optimizers register themselves with DDP.
        """
        self.optim_cls = optim_cls

    @abstractmethod
    def register_ddp(self, ddp: DistributedDataParallel) -> None:
        """Register the overlapped optimizer with DDP."""
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support overlapped DDP."
        )

    @abstractmethod
    def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None:
        """Register the overlapped optimizer with FSDP."""
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support overlapped FSDP."
        )


@register_overlapped(Optimizer)
class _OverlappedStandardOptimizer(OverlappedOptimizer):
    """Overlaps a regular ``Optimizer``."""

    def __init__(self, optim_cls: Type, params, *optim_args, **optim_kwargs) -> None:
        super().__init__(optim_cls)
        f_optim = as_functional_optim(self.optim_cls, *optim_args, **optim_kwargs)
        self._opt_hook_state = _OptimizerHookState(f_optim, params)

    def register_ddp(self, ddp_inst: DistributedDataParallel):
        # NOTE: using a custom communication hook together with a fused
        # optimizer is not yet supported.
        ddp_inst.register_comm_hook(  # type: ignore[operator]
            None,  # wrapped hook state
            _hook_then_optimizer(allreduce_hook, self._opt_hook_state),
        )

    # TODO: register_fsdp once FSDP supports communication hooks.
    def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None:
        """Register the overlapped optimizer with FSDP."""
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support overlapped FSDP."
        )


def _as_overlapped_optim(optim_cls: Type, params, *args, **kwargs):
    """Return a new ``OverlappedOptimizer`` instance that supports ``optim_cls``."""
    for clz in inspect.getmro(optim_cls):
        try:
            return _registered_overlapped_optims[clz](
                optim_cls, params, *args, **kwargs
            )
        except KeyError:
            pass

    # Fall back to the standard overlapped optimizer, which raises an error if
    # the user attempts to use an unsupported optimizer.
    return _OverlappedStandardOptimizer(optim_cls, params, *args, **kwargs)
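

# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of this module's API): how an overlapped
# optimizer is typically wired into DDP so that the optimizer step runs inside
# the communication hook during backward. Assumes a single-process "gloo"
# process group, a toy CPU model, and that ``as_functional_optim`` has a
# functional mapping for ``torch.optim.SGD``; the TCP port below is arbitrary.
# Guarded under ``__main__`` so importing this module stays side-effect free.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    import torch.distributed as dist
    import torch.nn as nn

    # Single-process process group so DDP can be constructed locally.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29501",
        rank=0,
        world_size=1,
    )

    model = nn.Linear(8, 4)
    ddp_model = DistributedDataParallel(model)

    # Resolve SGD to its overlapped wrapper (falls back through the MRO to the
    # ``Optimizer`` registration, i.e. ``_OverlappedStandardOptimizer``) and
    # fuse its step into DDP's allreduce communication hook.
    overlapped_sgd = _as_overlapped_optim(
        torch.optim.SGD, list(ddp_model.parameters()), lr=0.01
    )
    overlapped_sgd.register_ddp(ddp_model)

    # The functional optimizer now runs inside the comm hook during backward;
    # no explicit ``optimizer.step()`` call is issued here.
    loss = ddp_model(torch.randn(2, 8)).sum()
    loss.backward()

    dist.destroy_process_group()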