# mypy: allow-untyped-defs
import warnings
import weakref
from typing import Callable, Optional, Set

import torch
from torch.autograd.graph import register_multi_grad_hook
from torch.nn.modules.module import (
    register_module_forward_hook,
    register_module_forward_pre_hook,
)
from torch.utils._pytree import tree_flatten


__all__ = ["ModTracker"]


class ModTracker:
    """
    ``ModTracker`` is a context manager that tracks the nn.Module hierarchy during execution
    so that other systems can query which Module is currently being executed (or whose backward
    pass is currently being executed).

    You can access the ``parents`` attribute on this context manager to get the set of all the
    Modules currently being executed via their fqn (fully qualified name, also used as the key within
    the state_dict).
    You can access the ``is_bw`` attribute to know if you are currently running in backward or not.

    Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag
    will remain ``True`` after the forward pass until another Module is executed. If you need it to be
    more accurate, please submit an issue requesting this. Adding a map from fqn to the module instance
    is possible but not done yet; please submit an issue requesting this if you need it.

    Example usage

    .. code-block:: python

        mod = torch.nn.Linear(2, 2)

        with ModTracker() as tracker:
            # Access anything during the forward pass
            def my_linear(m1, m2, bias):
                print(f"Current modules: {tracker.parents}")
                return torch.mm(m1, m2.t()) + bias

            torch.nn.functional.linear = my_linear

            mod(torch.rand(2, 2))

    """

    parents: Set[str]
    """
    A set containing the fqn of each module currently running its forward pass
    """

    def __init__(self):
        self.parents = {"Global"}
        self._active_module_cnt = {}
        self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self._seen_modules: weakref.WeakSet = weakref.WeakSet()
        self._has_callback = False
        self._user_pre_fw_hook = None
        self._user_post_fw_hook = None
        self._user_pre_bw_hook = None
        self._user_post_bw_hook = None

    def _maybe_set_engine_callback(self):
        # This assumes no concurrent calls to backward
        if self._has_callback:
            return

        def callback():
            self.parents = {"Global"}
            self._has_callback = False

        torch.autograd.Variable._execution_engine.queue_callback(callback)
        self._has_callback = True

    @property
    def is_bw(self):
        """
        A boolean indicating whether execution is currently in the backward pass
        """
        return torch._C._current_graph_task_id() != -1

    def get_known_fqn(self, mod):
        """
        Return the fqn for the given module if it is known to the ``ModTracker``, otherwise ``None``.
        """
        return self._known_modules.get(mod, None)

    def register_user_hooks(
        self,
        pre_fw_hook: Optional[Callable] = None,
        post_fw_hook: Optional[Callable] = None,
        pre_bw_hook: Optional[Callable] = None,
        post_bw_hook: Optional[Callable] = None,
    ):
        """
        Registers user-specified hooks to be called before/after the forward/backward pass for each
        module tracked by the ``ModTracker``. One or more can be ``None``.

        Args:
            pre_fw_hook (Callable, optional): A hook to be called before the forward pass for the
                module. It should have the following signature:
                pre_fw_hook (module, input) -> None
            post_fw_hook (Callable, optional): A hook to be called after the forward pass for the
                module. It should have the following signature:
                post_fw_hook (module, input, output) -> None
            pre_bw_hook (Callable, optional): A multi-grad hook to be called on all the outputs of
                the module that require gradients. It should have the following signature:
                pre_bw_hook (module, grad_output) -> None
            post_bw_hook (Callable, optional): A multi-grad hook to be called on all the inputs of
                the module that require gradients. It should have the following signature:
                post_bw_hook (module, grad_input) -> None

        Raises:
            AssertionError: If a new hook is provided when one is already registered.

        Note:
            If the module is not alive during the backward pass, the pre_bw_hook and post_bw_hook
            will receive None as the module argument.
            The module fqn will be present in the ``parents`` attribute when each of the hooks is called.
            Hooks are intended to be used as markers only, not to modify the inputs/outputs.
        """

        def set_hook(hook, user_hook, hook_name):
            if hook is not None and user_hook is not None:
                raise AssertionError(
                    f"Only one {hook_name} can be registered at a time."
                    " Clear the existing hook by calling ``clear_user_hooks`` before registering a new one"
                )
            return hook

        self._user_pre_fw_hook = set_hook(
            pre_fw_hook, self._user_pre_fw_hook, "pre_fw_hook"
        )
        self._user_post_fw_hook = set_hook(
            post_fw_hook, self._user_post_fw_hook, "post_fw_hook"
        )
        self._user_pre_bw_hook = set_hook(
            pre_bw_hook, self._user_pre_bw_hook, "pre_bw_hook"
        )
        self._user_post_bw_hook = set_hook(
            post_bw_hook, self._user_post_bw_hook, "post_bw_hook"
        )

    def clear_user_hooks(self):
        """
        Clears the user-specified hooks registered with ``register_user_hooks``
        """
        self._user_pre_fw_hook = None
        self._user_post_fw_hook = None
        self._user_pre_bw_hook = None
        self._user_post_bw_hook = None

    def _get_mod_name(self, mod):
        # Assign an fqn to the module (and, recursively, to its children) the first time it is seen.
        if mod not in self._known_modules:
            self._known_modules[mod] = type(mod).__name__
        mod_name = self._known_modules[mod]
        if mod not in self._seen_modules:
            for name, submod in mod.named_children():
                self._known_modules[submod] = f"{mod_name}.{name}"
                self._get_mod_name(submod)
            self._seen_modules.add(mod)
        return mod_name

    def _get_append_fn(self, w_mod, name, is_bw):
        # Returns a callable that marks ``name`` as currently executing, for forward or backward.
        def fn(*args):
            if is_bw:
                self._maybe_set_engine_callback()
            if name in self.parents and not self.is_bw:

                def custom_formatwarning(msg, category, filename, lineno, line=None):
                    return f"{filename}:{lineno}: {category.__name__}: {msg} \n"

                warnings.formatwarning = custom_formatwarning
                warnings.warn(
                    "The module hierarchy tracking may be messed up."
                    " Please file a bug with PyTorch if this is the case."
                )
            if name not in self.parents:
                self._active_module_cnt[name] = 1
                self.parents.add(name)
            else:
                self._active_module_cnt[name] += 1

            if self._user_pre_bw_hook is not None and is_bw:
                self._user_pre_bw_hook(w_mod(), args)

        return fn

    def _get_pop_fn(self, w_mod, name, is_bw):
        # Returns a callable that marks ``name`` as done executing, for forward or backward.
        def fn(*args):
            if self._user_post_bw_hook is not None and is_bw:
                self._user_post_bw_hook(w_mod(), args)
            if name in self.parents:
                self._active_module_cnt[name] -= 1
                if self._active_module_cnt[name] == 0:
                    self.parents.remove(name)
            elif not self.is_bw:
                # Due to some input/output not requiring gradients, we cannot enforce
                # proper nesting in backward
                raise RuntimeError(
                    "The Module hierarchy tracking is wrong. Report a bug to PyTorch"
                )

        return fn

    def _fw_pre_hook(self, mod, input):
        name = self._get_mod_name(mod)
        w_mod = weakref.ref(mod)
        self._get_append_fn(w_mod, name, False)()
        if self._user_pre_fw_hook is not None:
            self._user_pre_fw_hook(mod, input)
        args, _ = tree_flatten(input)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if not self.is_bw and tensors:
            # Stop tracking the module during backward once gradients for all of its inputs are ready.
            register_multi_grad_hook(tensors, self._get_pop_fn(w_mod, name, True))

    def _fw_post_hook(self, mod, input, output):
        name = self._get_mod_name(mod)
        w_mod = weakref.ref(mod)
        if self._user_post_fw_hook is not None:
            self._user_post_fw_hook(mod, input, output)
        self._get_pop_fn(w_mod, name, False)()
        args, _ = tree_flatten(output)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if not self.is_bw and tensors:
            # Start tracking the module during backward once gradients for all of its outputs are ready.
            register_multi_grad_hook(tensors, self._get_append_fn(w_mod, name, True))

    def __enter__(self):
        self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
        self._fw_post_handle = register_module_forward_hook(
            self._fw_post_hook, always_call=True
        )
        return self

    def __exit__(self, *args):
        self._fw_pre_handle.remove()
        self._fw_post_handle.remove()
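

# The block below is an illustrative sketch, not part of the module's API: it shows one way to
# combine ``ModTracker`` with ``register_user_hooks`` to log which modules are active during the
# forward and backward passes. The model, tensor shapes, and hook bodies are arbitrary choices
# made for the demo; it only runs when this file is executed directly.
if __name__ == "__main__":
    tracker = ModTracker()

    def _log(phase):
        # Build a hook usable for all four signatures; the extra arguments (input/output/grads)
        # are ignored here since we only want the module identity and the tracker state.
        def hook(mod, *_unused):
            # ``mod`` can be None in backward if the module is no longer alive.
            mod_name = type(mod).__name__ if mod is not None else "<gone>"
            print(
                f"[{phase}] {mod_name} | active: {sorted(tracker.parents)} | is_bw={tracker.is_bw}"
            )

        return hook

    tracker.register_user_hooks(
        pre_fw_hook=_log("pre-fw"),
        post_fw_hook=_log("post-fw"),
        pre_bw_hook=_log("pre-bw"),
        post_bw_hook=_log("post-bw"),
    )

    model = torch.nn.Sequential(
        torch.nn.Linear(4, 4), torch.nn.ReLU(), torch.nn.Linear(4, 2)
    )
    with tracker:
        out = model(torch.rand(8, 4, requires_grad=True))
        out.sum().backward()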