xref: /aosp_15_r20/external/pytorch/torch/utils/hooks.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import torch
from collections import OrderedDict
import weakref
import warnings
from typing import Any, Tuple

__all__ = ["RemovableHandle", "unserializable_hook", "warn_if_has_hooks", "BackwardHook"]

class RemovableHandle:
    r"""
    A handle which provides the capability to remove a hook.

    Args:
        hooks_dict (dict): A dictionary of hooks, indexed by hook ``id``.
        extra_dict (Union[dict, List[dict]]): An additional dictionary or list of
            dictionaries whose keys will be deleted when the same keys are
            removed from ``hooks_dict``.
    """

    id: int
    next_id: int = 0

    def __init__(self, hooks_dict: Any, *, extra_dict: Any = None) -> None:
        self.hooks_dict_ref = weakref.ref(hooks_dict)
        self.id = RemovableHandle.next_id
        RemovableHandle.next_id += 1

        self.extra_dict_ref: Tuple = ()
        if isinstance(extra_dict, dict):
            self.extra_dict_ref = (weakref.ref(extra_dict),)
        elif isinstance(extra_dict, list):
            self.extra_dict_ref = tuple(weakref.ref(d) for d in extra_dict)

    def remove(self) -> None:
        hooks_dict = self.hooks_dict_ref()
        if hooks_dict is not None and self.id in hooks_dict:
            del hooks_dict[self.id]

        for ref in self.extra_dict_ref:
            extra_dict = ref()
            if extra_dict is not None and self.id in extra_dict:
                del extra_dict[self.id]

    def __getstate__(self):
        if self.extra_dict_ref is None:
            return (self.hooks_dict_ref(), self.id)
        else:
            return (self.hooks_dict_ref(), self.id, tuple(ref() for ref in self.extra_dict_ref))

    def __setstate__(self, state) -> None:
        if state[0] is None:
            # create a dead reference
            self.hooks_dict_ref = weakref.ref(OrderedDict())
        else:
            self.hooks_dict_ref = weakref.ref(state[0])
        self.id = state[1]
        RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)

        if len(state) < 3 or state[2] is None:
            self.extra_dict_ref = ()
        else:
            self.extra_dict_ref = tuple(weakref.ref(d) for d in state[2])

    def __enter__(self) -> "RemovableHandle":
        return self

    def __exit__(self, type: Any, value: Any, tb: Any) -> None:
        self.remove()

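
# Illustrative usage sketch (not part of the original file; the _example_* helper
# below is hypothetical).  RemovableHandle is what hook-registration methods such
# as nn.Module.register_forward_hook hand back to the caller: remove() deletes the
# hook from the module's hook dictionary, and the handle also works as a context
# manager.
def _example_removable_handle():
    import torch.nn as nn

    module = nn.Linear(2, 2)

    def hook(mod, inputs, output):
        print("forward hook fired")

    handle = module.register_forward_hook(hook)  # returns a RemovableHandle
    module(torch.randn(1, 2))                    # hook fires
    handle.remove()                              # hook is gone
    module(torch.randn(1, 2))                    # nothing is printed

    # The handle is also a context manager: the hook is removed on exit.
    with module.register_forward_hook(hook):
        module(torch.randn(1, 2))
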

def unserializable_hook(f):
    """
    Mark a function as an unserializable hook with this decorator.

    This suppresses warnings that would otherwise arise if you attempt
    to serialize a tensor that has a hook.
    """
    f.__torch_unserializable__ = True
    return f


def warn_if_has_hooks(tensor):
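    """Warn if ``tensor`` has backward hooks that will not be serialized."""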
    if tensor._backward_hooks:
        for k in tensor._backward_hooks:
            hook = tensor._backward_hooks[k]
            if not hasattr(hook, "__torch_unserializable__"):
                warnings.warn(f"backward hook {repr(hook)} on tensor will not be "
                              "serialized.  If this is expected, you can "
                              "decorate the function with @torch.utils.hooks.unserializable_hook "
                              "to suppress this warning")

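
# Illustrative sketch (not part of the original file; the _example_* helper below is
# hypothetical).  Tensor.register_hook stores its hooks in tensor._backward_hooks,
# which is exactly what warn_if_has_hooks inspects: hooks decorated with
# @unserializable_hook are skipped, any other backward hook triggers the warning above.
def _example_unserializable_hook():
    @unserializable_hook
    def silent_hook(grad):
        return grad

    def noisy_hook(grad):
        return grad

    t = torch.randn(3, requires_grad=True)
    t.register_hook(silent_hook)
    warn_if_has_hooks(t)   # no warning: the hook is marked unserializable

    t.register_hook(noisy_hook)
    warn_if_has_hooks(t)   # warns that noisy_hook will not be serialized
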
class BackwardHook:
    """
    A wrapper class to implement nn.Module backward hooks.

    It handles:
      - Ignoring non-Tensor inputs and replacing them with None before calling the user hooks
      - Generating the proper Node to capture a set of Tensors' gradients
      - Linking the gradients captured for the outputs with the gradients captured for the inputs
      - Calling the user hooks once both output and input gradients are available
    """

    def __init__(self, module, user_hooks, user_pre_hooks):
        self.user_hooks = user_hooks
        self.user_pre_hooks = user_pre_hooks
        self.module = module

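        # Bookkeeping: the index/count fields are filled in by setup_input_hook and
        # setup_output_hook during the forward pass; grad_outputs is captured during
        # the backward pass by the output hook and consumed by the input hook.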
        self.grad_outputs = None
        self.n_outputs = -1
        self.output_tensors_index = None
        self.n_inputs = -1
        self.input_tensors_index = None

    def _pack_with_none(self, indices, values, size):
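        # Scatter `values` into a tuple of length `size`: each value goes back to its
        # original position (given by `indices`); every other slot stays None.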
        res = [None] * size
        for idx, val in zip(indices, values):
            res[idx] = val

        return tuple(res)

    def _unpack_none(self, indices, values):
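        # Inverse of _pack_with_none: keep only the entries at `indices`, dropping
        # the None placeholders that stand in for non-Tensor arguments.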
        res = []
        for idx in indices:
            res.append(values[idx])

        return tuple(res)

    def _set_user_hook(self, grad_fn):
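        # Attach a hook to the autograd Node guarding the module inputs.  It fires
        # once the gradients for those inputs are computed, re-packs them into the
        # full argument layout and calls the user hooks with both the input and the
        # previously captured output gradients.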
        def hook(grad_input, _):
            if self.grad_outputs is None:
                # This happens because the gradient in your nn.Module flows to
                # the Module's input without passing through the Module's
                # output, e.g. when you're doing double backward.
                return
            res = self._pack_with_none(self.input_tensors_index, grad_input, self.n_inputs)

            for hook in self.user_hooks:
                out = hook(self.module, res, self.grad_outputs)

                if out is None:
                    continue

                if len(out) != len(res):
                    raise RuntimeError("Backward hook returned an invalid number of grad_input, "
                                       f"got {len(out)}, but expected {len(res)}")

                res = out

            self.grad_outputs = None

            return self._unpack_none(self.input_tensors_index, res)

        grad_fn.register_hook(hook)

    def _apply_on_tensors(self, fn, args):
        # Wrap the Tensors contained in args with BackwardHookFunction, call `fn` on
        # the resulting grad_fn, and return the updated args together with the
        # indices of the Tensors that were found.
        tensors_idx = []
        tensors = []

        requires_grad = False
        for i, arg in enumerate(args):
            if isinstance(arg, torch.Tensor):
                tensors_idx.append(i)
                tensors.append(arg)
                requires_grad |= arg.requires_grad

        if not (requires_grad and torch.is_grad_enabled()):
            return args, None

        new_tensors = torch.nn.modules._functions.BackwardHookFunction.apply(*tensors)
        if len(new_tensors) == 0:
            raise RuntimeError("Cannot set Module backward hook for a Module with no input Tensors.")

        grad_fns = [t.grad_fn for t in new_tensors if t.grad_fn is not None and t.grad_fn.name() == "BackwardHookFunctionBackward"]
        if len(grad_fns) == 0:
            raise RuntimeError("Error while setting up backward hooks. Please open "
                               "an issue with a code sample to reproduce this.")

        fn(grad_fns[0])

        arg_list = list(args)
        for idx, val in zip(tensors_idx, new_tensors):
            arg_list[idx] = val

        if type(args) is tuple:
            out = tuple(arg_list)
        else:
            out = type(args)(*arg_list)
        return out, tensors_idx

    def setup_input_hook(self, args):
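        # Forward-pass side: wrap the module inputs in BackwardHookFunction so that a
        # grad_fn exists on which the user hooks can be attached (see _set_user_hook).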
        def fn(grad_fn):
            self._set_user_hook(grad_fn)

        res, input_idx = self._apply_on_tensors(fn, args)
        self.n_inputs = len(args)
        self.input_tensors_index = input_idx
        return res

    def setup_output_hook(self, args):
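        # Forward-pass side: wrap the module outputs so that their gradients are
        # captured (and optionally rewritten by the backward pre-hooks) at the start
        # of the backward pass.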
        def fn(grad_fn):
            def hook(_, grad_output):
                self.grad_outputs = self._pack_with_none(self.output_tensors_index,
                                                         grad_output,
                                                         self.n_outputs)

                if self.user_pre_hooks:
                    expected_len = len(self.grad_outputs)
                    for user_pre_hook in self.user_pre_hooks:
                        hook_grad_outputs = user_pre_hook(self.module, self.grad_outputs)
                        if hook_grad_outputs is None:
                            continue

                        actual_len = len(hook_grad_outputs)
                        if actual_len != expected_len:
                            raise RuntimeError("Backward pre hook returned an invalid number of grad_output, "
                                               f"got {actual_len}, but expected {expected_len}")
                        self.grad_outputs = hook_grad_outputs

                # We need to be able to clear self.grad_outputs but also return it
                local_grad_outputs = self.grad_outputs

                # Special case: if no input required gradients, this hook calls the
                # user hooks directly (there is no input grad_fn to attach them to).
                if self.input_tensors_index is None:
                    grad_inputs = self._pack_with_none([], [], self.n_inputs)
                    for user_hook in self.user_hooks:
                        res = user_hook(self.module, grad_inputs, self.grad_outputs)
                        if res is not None and not (isinstance(res, tuple) and all(el is None for el in res)):
                            raise RuntimeError("Backward hook for Modules where no input requires "
                                               "gradient should always return None or None for all gradients.")
                    self.grad_outputs = None

                if local_grad_outputs is not None:
                    assert self.output_tensors_index is not None  # mypy
                    return tuple(local_grad_outputs[i] for i in self.output_tensors_index)

            grad_fn.register_hook(hook)

        is_tuple = True
        if not isinstance(args, tuple):
            args = (args,)
            is_tuple = False

        res, output_idx = self._apply_on_tensors(fn, args)
        self.n_outputs = len(args)
        self.output_tensors_index = output_idx

        if not is_tuple:
            res = res[0]
        return res
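

# Illustrative end-to-end sketch (not part of the original file; the _example_*
# helper below is hypothetical).  BackwardHook is assumed to be the machinery
# behind nn.Module's "full" backward hooks; user code goes through the public
# registration methods below rather than instantiating BackwardHook directly.
def _example_full_backward_hook():
    import torch.nn as nn

    module = nn.Linear(4, 2)

    def pre_hook(mod, grad_output):
        # Runs before the module's backward; may return replacement grad_output.
        print("grad_output shapes:", [g.shape for g in grad_output])

    def hook(mod, grad_input, grad_output):
        # Runs once both input and output gradients are available.
        print("grad_input shapes:",
              [None if g is None else g.shape for g in grad_input])

    module.register_full_backward_pre_hook(pre_hook)
    module.register_full_backward_hook(hook)

    out = module(torch.randn(3, 4, requires_grad=True))
    out.sum().backward()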
254