# mypy: allow-untyped-defs
import logging
import weakref
from typing import Set

import torch
from torch.autograd.graph import register_multi_grad_hook
from torch.nn.modules.module import (
    register_module_forward_hook,
    register_module_forward_pre_hook,
)
from torch.utils._pytree import tree_flatten


logger = logging.getLogger(__name__)


__all__ = ["ModuleTracker"]


class ModuleTracker:
    """
    ``ModuleTracker`` is a context manager that tracks the nn.Module hierarchy during execution
    so that other systems can query which Module is currently being executed (or whose backward is
    being executed).

    You can access the ``parents`` attribute on this context manager to get the set of all the
    Modules currently being executed via their fqn (fully qualified name, also used as the key within
    the state_dict).
    You can access the ``is_bw`` attribute to know if you are currently running in backward or not.

    Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag
    will remain ``True`` after the forward until another Module is executed. If you need it to be
    more accurate, please submit an issue requesting this. Adding a map from fqn to the module
    instance is possible but not done yet; please submit an issue requesting it if you need this.

    Example usage

    .. code-block:: python

        mod = torch.nn.Linear(2, 2)

        with ModuleTracker() as tracker:
            # Access anything during the forward pass
            def my_linear(m1, m2, bias):
                print(f"Current modules: {tracker.parents}")
                return torch.mm(m1, m2.t()) + bias

            torch.nn.functional.linear = my_linear

            mod(torch.rand(2, 2))

    """

    parents: Set[str]
    """
    A Set containing the fqn for each module currently running its forward
    """

    def __init__(self) -> None:
        self.parents = {"Global"}
        self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self._seen_modules: weakref.WeakSet = weakref.WeakSet()
        self._has_callback = False

    def _maybe_set_engine_callback(self):
        # This assumes no concurrent calls to backward
        if self._has_callback:
            return

        def callback():
            # Reset the hierarchy once the current backward pass is done
            self.parents = {"Global"}
            self._has_callback = False

        torch.autograd.Variable._execution_engine.queue_callback(callback)
        self._has_callback = True

    @property
    def is_bw(self):
        """
        A boolean marking if this is currently running during the backward pass or not
        """
        return torch._C._current_graph_task_id() != -1

    def _get_mod_name(self, mod):
        # Cache this module's name and pre-compute the fqn of its children so
        # that nested modules resolve to "Parent.child" style names.
        if mod not in self._known_modules:
            self._known_modules[mod] = type(mod).__name__
        mod_name = self._known_modules[mod]
        if mod not in self._seen_modules:
            for name, submod in mod.named_children():
                self._known_modules[submod] = f"{mod_name}.{name}"
                self._get_mod_name(submod)
            self._seen_modules.add(mod)
        return mod_name

    def _get_append_fn(self, name, is_bw):
        def fn(*args):
            if is_bw:
                self._maybe_set_engine_callback()
            if name in self.parents:
                logger.info(
                    "The module hierarchy tracking seems to be broken as this Module was already entered. %s during %s",
                    name,
                    "backward" if is_bw else "forward",
                )
            self.parents.add(name)

        return fn

    def _get_pop_fn(self, name, is_bw):
        def fn(*args):
            if name in self.parents:
                self.parents.remove(name)
            else:
                logger.info(
                    "The Module hierarchy tracking is confused as we're exiting a Module that was never entered. %s during %s",
                    name,
                    "backward" if is_bw else "forward",
                )

        return fn

    def _fw_pre_hook(self, mod, input):
        name = self._get_mod_name(mod)
        self._get_append_fn(name, False)()

        # Mirror the "exit" during backward: once the gradients for all the
        # inputs have been computed, this module's backward is done.
        args, _ = tree_flatten(input)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if tensors:
            register_multi_grad_hook(tensors, self._get_pop_fn(name, True))

    def _fw_post_hook(self, mod, input, output):
        name = self._get_mod_name(mod)
        self._get_pop_fn(name, False)()

        # Mirror the "enter" during backward: once gradients start flowing for
        # the outputs, this module's backward has started.
        args, _ = tree_flatten(output)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if tensors:
            register_multi_grad_hook(tensors, self._get_append_fn(name, True))

    def __enter__(self):
        self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
        self._fw_post_handle = register_module_forward_hook(self._fw_post_hook)
        return self

    def __exit__(self, *args):
        self._fw_pre_handle.remove()
        self._fw_post_handle.remove()