# mypy: allow-untyped-defs
import logging
import weakref
from typing import Set

import torch
from torch.autograd.graph import register_multi_grad_hook
from torch.nn.modules.module import (
    register_module_forward_hook,
    register_module_forward_pre_hook,
)
from torch.utils._pytree import tree_flatten


logger = logging.getLogger(__name__)


__all__ = ["ModuleTracker"]


class ModuleTracker:
    """
    ``ModuleTracker`` is a context manager that tracks the nn.Module hierarchy during execution
    so that other systems can query which Module is currently being executed (or whose backward is
    being executed).

    You can access the ``parents`` attribute on this context manager to get the set of all the
    Modules currently being executed via their fqn (fully qualified name, also used as the key within
    the state_dict).
    You can access the ``is_bw`` attribute to know whether you are currently running the backward pass.

    Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag
    will remain ``True`` after the forward until another Module is executed. If you need it to be
    more accurate, please submit an issue requesting this. Adding a map from fqn to the module instance
    is possible but not done yet; please submit an issue requesting this if you need it.

    Example usage

    .. code-block:: python

        mod = torch.nn.Linear(2, 2)

        with ModuleTracker() as tracker:
            # Access anything during the forward pass
            def my_linear(m1, m2, bias):
                print(f"Current modules: {tracker.parents}")
                return torch.mm(m1, m2.t()) + bias

            torch.nn.functional.linear = my_linear

            mod(torch.rand(2, 2))

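    The tracker can also be queried while the backward pass is running. Below is a minimal
    sketch (illustrative only, not taken from the library) that reads ``is_bw`` and
    ``parents`` from a tensor hook, which fires during backward:

    .. code-block:: python

        mod = torch.nn.Linear(2, 2)

        with ModuleTracker() as tracker:
            out = mod(torch.rand(2, 2, requires_grad=True))
            # Tensor hooks run inside the autograd engine, so is_bw is True here
            out.register_hook(
                lambda grad: print(f"is_bw={tracker.is_bw}, parents={tracker.parents}")
            )
            out.sum().backward()
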
    """

    parents: Set[str]
    """
    A Set containing the fqn of each module currently running its forward pass
    """

    def __init__(self) -> None:
        self.parents = {"Global"}
        self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self._seen_modules: weakref.WeakSet = weakref.WeakSet()
        self._has_callback = False

    def _maybe_set_engine_callback(self):
        # This assumes no concurrent calls to backward
        if self._has_callback:
            return

        # Reset the tracked hierarchy once the current backward pass has finished
        def callback():
            self.parents = {"Global"}
            self._has_callback = False

        torch.autograd.Variable._execution_engine.queue_callback(callback)
        self._has_callback = True

    @property
    def is_bw(self):
        """
        Whether the current execution is happening during the backward pass.
        """
        return torch._C._current_graph_task_id() != -1

    def _get_mod_name(self, mod):
        # Compute (and cache) the fqn for this module: the top-level module is named after
        # its class and each child is named "<parent_fqn>.<attribute_name>"
        if mod not in self._known_modules:
            self._known_modules[mod] = type(mod).__name__
        mod_name = self._known_modules[mod]
        if mod not in self._seen_modules:
            for name, submod in mod.named_children():
                self._known_modules[submod] = f"{mod_name}.{name}"
                self._get_mod_name(submod)
            self._seen_modules.add(mod)
        return mod_name

    def _get_append_fn(self, name, is_bw):
        # Returns a hook that marks ``name`` as currently running
        def fn(*args):
            if is_bw:
                self._maybe_set_engine_callback()
            if name in self.parents:
                logger.info(
                    "The module hierarchy tracking seems to be broken as this Module was already entered. %s during %s",
                    name,
                    "backward" if is_bw else "forward",
                )
            self.parents.add(name)

        return fn

    def _get_pop_fn(self, name, is_bw):
        def fn(*args):
            if name in self.parents:
                self.parents.remove(name)
            else:
                logger.info(
                    "The Module hierarchy tracking is confused as we're exiting a Module that was never entered. %s during %s",
                    name,
                    "backward" if is_bw else "forward",
                )

        return fn

    def _fw_pre_hook(self, mod, input):
        name = self._get_mod_name(mod)
        self._get_append_fn(name, False)()

        args, _ = tree_flatten(input)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if tensors:
            # Once the gradients of all inputs are computed, the backward of this module
            # is done, so pop its name again
            register_multi_grad_hook(tensors, self._get_pop_fn(name, True))

    def _fw_post_hook(self, mod, input, output):
        name = self._get_mod_name(mod)
        self._get_pop_fn(name, False)()

        args, _ = tree_flatten(output)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if tensors:
            # Once gradients reach the outputs, the backward of this module has started,
            # so add its name back
            register_multi_grad_hook(tensors, self._get_append_fn(name, True))

    def __enter__(self):
        self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
        self._fw_post_handle = register_module_forward_hook(self._fw_post_hook)
        return self

    def __exit__(self, *args):
        self._fw_pre_handle.remove()
        self._fw_post_handle.remove()

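
# A minimal, self-contained usage sketch (illustrative only, not part of the library
# surface): the model below is an assumed example, run through a forward and a backward
# pass while the tracker is active.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

    with ModuleTracker() as tracker:
        out = model(torch.rand(3, 4))
        # The forward has finished at this point, so only the "Global" entry remains
        print(f"after forward: parents={tracker.parents}, is_bw={tracker.is_bw}")
        out.sum().backward()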