# mypy: allow-untyped-defs
import torch
from collections import OrderedDict
import weakref
import warnings
from typing import Any, Tuple

__all__ = ["RemovableHandle", "unserializable_hook", "warn_if_has_hooks", "BackwardHook"]

class RemovableHandle:
    r"""
    A handle which provides the capability to remove a hook.

    Args:
        hooks_dict (dict): A dictionary of hooks, indexed by hook ``id``.
        extra_dict (Union[dict, List[dict]]): An additional dictionary or list of
            dictionaries whose keys will be deleted when the same keys are
            removed from ``hooks_dict``.
    """

    id: int
    next_id: int = 0

    def __init__(self, hooks_dict: Any, *, extra_dict: Any = None) -> None:
        self.hooks_dict_ref = weakref.ref(hooks_dict)
        self.id = RemovableHandle.next_id
        RemovableHandle.next_id += 1

        self.extra_dict_ref: Tuple = ()
        if isinstance(extra_dict, dict):
            self.extra_dict_ref = (weakref.ref(extra_dict),)
        elif isinstance(extra_dict, list):
            self.extra_dict_ref = tuple(weakref.ref(d) for d in extra_dict)

    def remove(self) -> None:
        hooks_dict = self.hooks_dict_ref()
        if hooks_dict is not None and self.id in hooks_dict:
            del hooks_dict[self.id]

        for ref in self.extra_dict_ref:
            extra_dict = ref()
            if extra_dict is not None and self.id in extra_dict:
                del extra_dict[self.id]

    def __getstate__(self):
        if self.extra_dict_ref is None:
            return (self.hooks_dict_ref(), self.id)
        else:
            return (self.hooks_dict_ref(), self.id, tuple(ref() for ref in self.extra_dict_ref))

    def __setstate__(self, state) -> None:
        if state[0] is None:
            # create a dead reference
            self.hooks_dict_ref = weakref.ref(OrderedDict())
        else:
            self.hooks_dict_ref = weakref.ref(state[0])
        self.id = state[1]
        RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)

        if len(state) < 3 or state[2] is None:
            self.extra_dict_ref = ()
        else:
            self.extra_dict_ref = tuple(weakref.ref(d) for d in state[2])

    def __enter__(self) -> "RemovableHandle":
        return self

    def __exit__(self, type: Any, value: Any, tb: Any) -> None:
        self.remove()
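
# Minimal usage sketch (illustrative only; the module and hook below are made up):
# hook-registration methods such as ``nn.Module.register_forward_hook`` return a
# ``RemovableHandle``, and the ``__enter__``/``__exit__`` methods above also let the
# handle be used as a context manager so the hook is removed automatically on exit.
#
#     module = torch.nn.Linear(2, 2)
#     handle = module.register_forward_hook(lambda mod, inp, out: None)
#     handle.remove()  # drops this hook's ``id`` from the module's hooks dict
#
#     with module.register_forward_hook(lambda mod, inp, out: None):
#         module(torch.randn(1, 2))  # hook is active only inside this block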


def unserializable_hook(f):
    """
    Mark a function as an unserializable hook with this decorator.

    This suppresses warnings that would otherwise arise if you attempt
    to serialize a tensor that has a hook.
    """
    f.__torch_unserializable__ = True
    return f


def warn_if_has_hooks(tensor):
    if tensor._backward_hooks:
        for k in tensor._backward_hooks:
            hook = tensor._backward_hooks[k]
            if not hasattr(hook, "__torch_unserializable__"):
                warnings.warn(f"backward hook {repr(hook)} on tensor will not be "
                              "serialized. If this is expected, you can "
                              "decorate the function with @torch.utils.hooks.unserializable_hook "
                              "to suppress this warning")
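
# Illustrative sketch of how the decorator above is meant to be used (the tensor and
# hook names are made up): a backward hook registered on a tensor would otherwise
# trigger the warning in ``warn_if_has_hooks`` when that tensor is serialized;
# decorating it marks the hook as intentionally unserializable.
#
#     t = torch.ones(3, requires_grad=True)
#
#     @unserializable_hook
#     def scale_grad(grad):
#         return grad * 2
#
#     t.register_hook(scale_grad)  # serializing ``t`` no longer warns about this hook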

class BackwardHook:
    """
    A wrapper class to implement nn.Module backward hooks.

    It handles:
      - Ignoring non-Tensor inputs and replacing them by None before calling the user hook
      - Generating the proper Node to capture a set of Tensor's gradients
      - Linking the gradients captured for the outputs with the gradients captured for the input
      - Calling the user hook once both output and input gradients are available
    """

    def __init__(self, module, user_hooks, user_pre_hooks):
        self.user_hooks = user_hooks
        self.user_pre_hooks = user_pre_hooks
        self.module = module

        self.grad_outputs = None
        self.n_outputs = -1
        self.output_tensors_index = None
        self.n_inputs = -1
        self.input_tensors_index = None

    def _pack_with_none(self, indices, values, size):
        res = [None] * size
        for idx, val in zip(indices, values):
            res[idx] = val

        return tuple(res)

    def _unpack_none(self, indices, values):
        res = []
        for idx in indices:
            res.append(values[idx])

        return tuple(res)

    def _set_user_hook(self, grad_fn):
        def hook(grad_input, _):
            if self.grad_outputs is None:
                # This happens because the gradient in your nn.Module flows to
                # the Module's input without passing through the Module's
                # output, e.g. when you're doing double backward.
                return
            res = self._pack_with_none(self.input_tensors_index, grad_input, self.n_inputs)

            for hook in self.user_hooks:
                out = hook(self.module, res, self.grad_outputs)

                if out is None:
                    continue

                if len(out) != len(res):
                    raise RuntimeError("Backward hook returned an invalid number of grad_input, "
                                       f"got {len(out)}, but expected {len(res)}")

                res = out

            self.grad_outputs = None

            return self._unpack_none(self.input_tensors_index, res)

        grad_fn.register_hook(hook)

    def _apply_on_tensors(self, fn, args):
        # Can be used to apply the given function to the tensors contained in the
        # args. Will return updated args and the tensors indices
        tensors_idx = []
        tensors = []

        requires_grad = False
        for i, arg in enumerate(args):
            if isinstance(arg, torch.Tensor):
                tensors_idx.append(i)
                tensors.append(arg)
                requires_grad |= arg.requires_grad

        if not (requires_grad and torch.is_grad_enabled()):
            return args, None

        new_tensors = torch.nn.modules._functions.BackwardHookFunction.apply(*tensors)
        if len(new_tensors) == 0:
            raise RuntimeError("Cannot set Module backward hook for a Module with no input Tensors.")

        grad_fns = [t.grad_fn for t in new_tensors if t.grad_fn is not None and t.grad_fn.name() == "BackwardHookFunctionBackward"]
        if len(grad_fns) == 0:
            raise RuntimeError("Error while setting up backward hooks. Please open "
                               "an issue with a code sample to reproduce this.")

        fn(grad_fns[0])

        arg_list = list(args)
        for idx, val in zip(tensors_idx, new_tensors):
            arg_list[idx] = val

        if type(args) is tuple:
            out = tuple(arg_list)
        else:
            out = type(args)(*arg_list)
        return out, tensors_idx

    def setup_input_hook(self, args):
        def fn(grad_fn):
            self._set_user_hook(grad_fn)

        res, input_idx = self._apply_on_tensors(fn, args)
        self.n_inputs = len(args)
        self.input_tensors_index = input_idx
        return res

    def setup_output_hook(self, args):
        def fn(grad_fn):
            def hook(_, grad_output):
                self.grad_outputs = self._pack_with_none(self.output_tensors_index,
                                                         grad_output,
                                                         self.n_outputs)

                if self.user_pre_hooks:
                    expected_len = len(self.grad_outputs)
                    for user_pre_hook in self.user_pre_hooks:
                        hook_grad_outputs = user_pre_hook(self.module, self.grad_outputs)
                        if hook_grad_outputs is None:
                            continue

                        actual_len = len(hook_grad_outputs)
                        if actual_len != expected_len:
                            raise RuntimeError("Backward pre hook returned an invalid number of grad_output, "
                                               f"got {actual_len}, but expected {expected_len}")
                        self.grad_outputs = hook_grad_outputs

                # We need to be able to clear self.grad_outputs but also return it
                local_grad_outputs = self.grad_outputs

                # Special case if no input required gradients, this hook should call the user
                # hook directly
                if self.input_tensors_index is None:
                    grad_inputs = self._pack_with_none([], [], self.n_inputs)
                    for user_hook in self.user_hooks:
                        res = user_hook(self.module, grad_inputs, self.grad_outputs)
                        if res is not None and not (isinstance(res, tuple) and all(el is None for el in res)):
                            raise RuntimeError("Backward hook for Modules where no input requires "
                                               "gradient should always return None or None for all gradients.")
                    self.grad_outputs = None

                if local_grad_outputs is not None:
                    assert self.output_tensors_index is not None  # mypy
                    return tuple(local_grad_outputs[i] for i in self.output_tensors_index)

            grad_fn.register_hook(hook)

        is_tuple = True
        if not isinstance(args, tuple):
            args = (args,)
            is_tuple = False

        res, output_idx = self._apply_on_tensors(fn, args)
        self.n_outputs = len(args)
        self.output_tensors_index = output_idx

        if not is_tuple:
            res = res[0]
        return res
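
# Rough sketch of how this class is driven (illustrative; ``my_hook`` is made up).
# nn.Module wires this up internally when full backward hooks are registered, so end
# users normally go through ``nn.Module.register_full_backward_hook`` rather than
# using BackwardHook directly:
#
#     bw_hook = BackwardHook(module, user_hooks=[my_hook], user_pre_hooks=[])
#     args = bw_hook.setup_input_hook(args)        # wrap Tensor inputs, record their indices
#     output = module.forward(*args)
#     output = bw_hook.setup_output_hook(output)   # wrap Tensor outputs
#     output.sum().backward()                      # my_hook(module, grad_input, grad_output)
#                                                  # fires once both sides' gradients exist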