xref: /aosp_15_r20/external/pytorch/torch/distributed/optim/apply_optimizer_in_backward.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
from typing import Any, Dict, Iterable, List, no_type_check, Type

import torch


__all__: List[str] = []

# WeakTensorKeyDictionary to store relevant metadata for the Tensor/Parameter
# without changing its lifetime.
# NOTE: An alternative is to add the metadata as an attribute to the tensor,
#       but that would serialize the metadata whenever the Tensor is serialized.
param_to_optim_hook_handle_map = torch.utils.weak.WeakTensorKeyDictionary()
param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary()
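
# A minimal, hedged illustration (not part of the original file) of why weak keys
# are used here: a WeakTensorKeyDictionary drops its entry once the keying tensor
# is garbage collected, so this bookkeeping never extends a parameter's lifetime:
#
#     p = torch.nn.Parameter(torch.zeros(1))
#     param_to_optim_hook_handle_map[p] = []
#     del p  # once ``p`` is collected, the corresponding entry disappears as well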


@no_type_check
def _apply_optimizer_in_backward(
    optimizer_class: Type[torch.optim.Optimizer],
    params: Iterable[torch.nn.Parameter],
    optimizer_kwargs: Dict[str, Any],
    register_hook: bool = True,
) -> None:
    """
    Upon ``backward()``, the optimizer specified for each parameter will fire after
    the gradient has been accumulated into the parameter.

    Note: gradients for these parameters will be set to None after ``backward()``.
    This means that any other optimizer not registered via ``_apply_optimizer_in_backward``
    for these parameters will effectively be a no-op.

    Args:
        optimizer_class: (Type[torch.optim.Optimizer]): Optimizer to apply to parameter
        params: (Iterable[torch.nn.Parameter]): parameters to apply optimizer state to
        optimizer_kwargs: (Dict[str, Any]): kwargs to pass to optimizer constructor
        register_hook: (bool): whether to register a hook that runs the optimizer
            after the gradient for this parameter is accumulated. This is the default
            way that optimizer-in-backward is implemented, but specific use cases
            (such as DDP) may wish to override this to implement custom behavior.
            (Default = True)

    Example::
        params_generator = model.parameters()
        param_1 = next(params_generator)
        remainder_params = list(params_generator)

        _apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": 0.02})
        _apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": 0.04})

        model(...).sum().backward()  # after backward, parameters will already
        # have their registered optimizer(s) applied.

    """
    torch._C._log_api_usage_once("torch.distributed.optim.apply_optimizer_in_backward")

    @no_type_check
    def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None:
        # view_as creates a node in the autograd graph that gives us access to the
        # parameter's AccumulateGrad autograd function object. We register a
        # hook on this object to fire the optimizer when the gradient for
        # this parameter is ready (i.e. has been accumulated into its .grad field).

        # Don't create a new acc_grad if we already have one,
        # i.e. for shared parameters or when attaching multiple optimizers to a param.
        if param not in param_to_acc_grad_map:
            param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[
                0
            ][0]

        optimizer = optimizer_class([param], **optimizer_kwargs)

        if not hasattr(param, "_in_backward_optimizers"):
            param._in_backward_optimizers = []  # type: ignore[attr-defined]
            # TODO: Remove these attributes once we have a better way of accessing
            # optimizer classes and kwargs for a parameter.
            param._optimizer_classes = []  # type: ignore[attr-defined]
            param._optimizer_kwargs = []  # type: ignore[attr-defined]

        param._in_backward_optimizers.append(optimizer)  # type: ignore[attr-defined]
        param._optimizer_classes.append(optimizer_class)  # type: ignore[attr-defined]
        param._optimizer_kwargs.append(optimizer_kwargs)  # type: ignore[attr-defined]

        if not register_hook:
            return

        def optimizer_hook(*_unused) -> None:
            for opt in param._in_backward_optimizers:  # type: ignore[attr-defined]
                opt.step()

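            # Free the gradient after stepping: any other optimizer that manages
            # this parameter will see ``grad=None`` and skip it (see docstring note).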
            param.grad = None

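        # Register the hook on the parameter's AccumulateGrad node and keep the
        # returned handle so the hook could later be detached via ``handle.remove()``
        # (an assumption about intended use; this module itself never removes hooks).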
        handle = param_to_acc_grad_map[param].register_hook(optimizer_hook)  # type: ignore[attr-defined]
        if param not in param_to_optim_hook_handle_map:
            param_to_optim_hook_handle_map[param] = []
        param_to_optim_hook_handle_map[param].append(handle)

    for param in params:
        _apply_optimizer_in_backward_to_param(param)

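# A hedged sketch (not part of the original file) of the ``register_hook=False`` path
# of ``_apply_optimizer_in_backward`` above: the caller fires the optimizers itself,
# reusing the AccumulateGrad node cached in ``param_to_acc_grad_map``. The names
# ``param`` and ``custom_hook`` below are illustrative only.
#
#     _apply_optimizer_in_backward(
#         torch.optim.SGD, [param], {"lr": 0.01}, register_hook=False
#     )
#
#     def custom_hook(*_unused):
#         for opt in param._in_backward_optimizers:
#             opt.step()
#         param.grad = None
#
#     param_to_acc_grad_map[param].register_hook(custom_hook)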

def _get_in_backward_optimizers(module: torch.nn.Module) -> List[torch.optim.Optimizer]:
    """
    Return a list of in-backward optimizers applied to ``module``'s parameters. Note that these
    optimizers are not intended to have their ``step`` or ``zero_grad`` methods called directly
    by the user; they are intended to be used for things like checkpointing.

    Args:
        module: (torch.nn.Module): model to retrieve in-backward optimizers for

    Returns:
        List[torch.optim.Optimizer]: the in-backward optimizers.

    Example::
        _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {'lr': 0.01})
        optims = _get_in_backward_optimizers(model)
    """
    optims: List[torch.optim.Optimizer] = []
    for param in module.parameters():
        optims.extend(getattr(param, "_in_backward_optimizers", []))

    return optims
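

# A minimal, hedged example (an assumption, not from the original file) of using the
# returned in-backward optimizers for checkpointing rather than calling ``step()``:
#
#     optims = _get_in_backward_optimizers(model)
#     checkpoint = {
#         "model": model.state_dict(),
#         "in_backward_optims": [opt.state_dict() for opt in optims],
#     }
#     torch.save(checkpoint, "checkpoint.pt")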
121