from typing import Any, Dict, Iterable, List, no_type_check, Type

import torch


__all__: List[str] = []

# WeakTensorKeyDictionary to store relevant metadata for the Tensor/Parameter
# without changing its lifetime.
# NOTE: The alternative is to add the metadata as an attribute to the tensor,
# but that would serialize the metadata whenever the Tensor is serialized.
param_to_optim_hook_handle_map = torch.utils.weak.WeakTensorKeyDictionary()
param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary()


@no_type_check
def _apply_optimizer_in_backward(
    optimizer_class: Type[torch.optim.Optimizer],
    params: Iterable[torch.nn.Parameter],
    optimizer_kwargs: Dict[str, Any],
    register_hook: bool = True,
) -> None:
    """
    Upon ``backward()``, the optimizer specified for each parameter will fire after
    the gradient has been accumulated into the parameter.

    Note: gradients for these parameters will be set to ``None`` after ``backward()``.
    This means that any other optimizer not specified via ``_apply_optimizer_in_backward``
    over this parameter will be a no-op.

    Args:
        optimizer_class (Type[torch.optim.Optimizer]): Optimizer to apply to parameter
        params (Iterable[torch.nn.Parameter]): parameters to apply optimizer state to
        optimizer_kwargs (Dict[str, Any]): kwargs to pass to optimizer constructor
        register_hook (bool): whether to register a hook that runs the optimizer
            after the gradient for this parameter is accumulated. This is the default
            way that optimizer-in-backward is implemented, but specific use cases
            (such as DDP) may wish to override this to implement custom behavior.
            (Default = True)

    Example::
        params_generator = model.parameters()
        param_1 = next(params_generator)
        remainder_params = list(params_generator)

        _apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": 0.02})
        _apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": 0.04})

        model(...).sum().backward()  # after backward, parameters will already
                                     # have their registered optimizer(s) applied.

    """
    torch._C._log_api_usage_once("torch.distributed.optim.apply_optimizer_in_backward")

    @no_type_check
    def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None:
        # view_as creates a node in the autograd graph that gives us access to the
        # parameter's AccumulateGrad autograd function object. We register a
        # hook on this object to fire the optimizer when the gradient for
        # this parameter is ready (has been accumulated into the .grad field).

        # Don't create a new acc_grad if we already have one,
        # i.e. for shared parameters or when attaching multiple optimizers to a param.
        if param not in param_to_acc_grad_map:
            param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[
                0
            ][0]

        optimizer = optimizer_class([param], **optimizer_kwargs)

        if not hasattr(param, "_in_backward_optimizers"):
            param._in_backward_optimizers = []  # type: ignore[attr-defined]
            # TODO: Remove these attributes once we have a better way of accessing
            # optimizer classes and kwargs for a parameter.
            param._optimizer_classes = []  # type: ignore[attr-defined]
            param._optimizer_kwargs = []  # type: ignore[attr-defined]

        param._in_backward_optimizers.append(optimizer)  # type: ignore[attr-defined]
        param._optimizer_classes.append(optimizer_class)  # type: ignore[attr-defined]
        param._optimizer_kwargs.append(optimizer_kwargs)  # type: ignore[attr-defined]

        if not register_hook:
            return

        def optimizer_hook(*_unused) -> None:
            for opt in param._in_backward_optimizers:  # type: ignore[attr-defined]
                opt.step()

            # Free the gradient; any other optimizer over this parameter will see
            # grad=None and be a no-op.
            param.grad = None

        handle = param_to_acc_grad_map[param].register_hook(optimizer_hook)  # type: ignore[attr-defined]
        if param not in param_to_optim_hook_handle_map:
            param_to_optim_hook_handle_map[param] = []
        param_to_optim_hook_handle_map[param].append(handle)

    for param in params:
        _apply_optimizer_in_backward_to_param(param)


def _get_in_backward_optimizers(module: torch.nn.Module) -> List[torch.optim.Optimizer]:
    """
    Return a list of in-backward optimizers applied to ``module``'s parameters. Note that these
    optimizers are not intended to have their ``step`` or ``zero_grad`` methods called directly
    by the user; they are intended to be used for things like checkpointing.

    Args:
        module (torch.nn.Module): model to retrieve in-backward optimizers for

    Returns:
        List[torch.optim.Optimizer]: the in-backward optimizers.

    Example::
        _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01})
        optims = _get_in_backward_optimizers(model)
    """
    optims: List[torch.optim.Optimizer] = []
    for param in module.parameters():
        optims.extend(getattr(param, "_in_backward_optimizers", []))

    return optims
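

# The block below is a minimal, self-contained usage sketch rather than part of the
# module API: the two-layer model and the input sizes are invented purely for
# illustration. It shows that after ``backward()`` the registered in-backward
# optimizers have already stepped and the parameters' gradients have been freed.
if __name__ == "__main__":
    model = torch.nn.Sequential(
        torch.nn.Linear(8, 16),
        torch.nn.ReLU(),
        torch.nn.Linear(16, 4),
    )

    # Register an in-backward SGD optimizer for every parameter of the model.
    _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01})

    # Run forward + backward; each parameter's optimizer fires as soon as that
    # parameter's gradient has been accumulated, then its .grad is set back to None.
    loss = model(torch.randn(2, 8)).sum()
    loss.backward()

    assert all(p.grad is None for p in model.parameters())

    # The in-backward optimizers can still be enumerated, e.g. for checkpointing.
    optims = _get_in_backward_optimizers(model)
    print(f"{len(optims)} in-backward optimizers registered")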