xref: /aosp_15_r20/external/pytorch/torch/distributed/_tools/mod_tracker.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import warnings
import weakref
from typing import Callable, Optional, Set

import torch
from torch.autograd.graph import register_multi_grad_hook
from torch.nn.modules.module import (
    register_module_forward_hook,
    register_module_forward_pre_hook,
)
from torch.utils._pytree import tree_flatten


__all__ = ["ModTracker"]


class ModTracker:
    """
    ``ModTracker`` is a context manager that tracks the nn.Module hierarchy during execution
    so that other systems can query which Module is currently being executed (or whose backward
    is currently being executed).

    You can access the ``parents`` attribute on this context manager to get the set of all the
    Modules currently being executed via their fqn (fully qualified name, also used as the key within
    the state_dict).
    You can access the ``is_bw`` attribute to know if you are currently running in backward or not.

    Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag
    will remain ``True`` after the forward until another Module is executed. If you need it to be
    more accurate, please submit an issue requesting this. Adding a map from fqn to the module
    instance is possible but not done yet; please submit an issue requesting this if you need it.

    Example usage

    .. code-block:: python

        mod = torch.nn.Linear(2, 2)

        with ModTracker() as tracker:
            # Access anything during the forward pass
            def my_linear(m1, m2, bias):
                print(f"Current modules: {tracker.parents}")
                return torch.mm(m1, m2.t()) + bias
            torch.nn.functional.linear = my_linear

            mod(torch.rand(2, 2))

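    A second, minimal sketch (the tensor hook below is purely illustrative) showing how
    ``is_bw`` and ``parents`` can be read from a hook that fires during the backward pass:

    .. code-block:: python

        seq = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.ReLU())

        with ModTracker() as tracker:
            out = seq(torch.rand(2, 2, requires_grad=True))
            # During backward, ``tracker.is_bw`` is True and ``tracker.parents`` should
            # already contain the fqns of the modules whose backward has started.
            out.register_hook(
                lambda g: print(f"is_bw={tracker.is_bw}, parents={tracker.parents}")
            )
            out.sum().backward()
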
    """

    parents: Set[str]
    """
    A Set containing the fqn of each module currently running its forward
    """

    def __init__(self):
        self.parents = {"Global"}
        self._active_module_cnt = {}
        self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self._seen_modules: weakref.WeakSet = weakref.WeakSet()
        self._has_callback = False
        self._user_pre_fw_hook = None
        self._user_post_fw_hook = None
        self._user_pre_bw_hook = None
        self._user_post_bw_hook = None

    def _maybe_set_engine_callback(self):
        # This assumes no concurrent calls to backward
        if self._has_callback:
            return

        def callback():
            self.parents = {"Global"}
            self._has_callback = False

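        # Queue the callback on the autograd engine so that ``parents`` is reset to
        # {"Global"} once the current backward pass finishes.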
        torch.autograd.Variable._execution_engine.queue_callback(callback)
        self._has_callback = True

    @property
    def is_bw(self):
        """
        A boolean marking whether execution is currently in the backward pass or not
        """
        return torch._C._current_graph_task_id() != -1

    def get_known_fqn(self, mod):
        """
        Return the fqn for the given module if it is known to the ``ModTracker``, otherwise ``None``.
        """
        return self._known_modules.get(mod, None)

    def register_user_hooks(
        self,
        pre_fw_hook: Optional[Callable] = None,
        post_fw_hook: Optional[Callable] = None,
        pre_bw_hook: Optional[Callable] = None,
        post_bw_hook: Optional[Callable] = None,
    ):
        """
        Registers user-specified hooks to be called before/after the forward/backward pass for each
        module tracked by the ``ModTracker``. One or more can be ``None``.

        Args:
            pre_fw_hook (Callable, optional): A hook to be called before the forward pass for the
                module. It should have the following signature:
                pre_fw_hook (module, input) -> None
            post_fw_hook (Callable, optional): A hook to be called after the forward pass for the
                module. It should have the following signature:
                post_fw_hook (module, input, output) -> None
            pre_bw_hook (Callable, optional): A multi-grad hook to be called on all the outputs of
                the module that require gradients. It should have the following signature:
                pre_bw_hook (module, grad_output) -> None
            post_bw_hook (Callable, optional): A multi-grad hook to be called on all the inputs of
                the module that require gradients. It should have the following signature:
                post_bw_hook (module, grad_input) -> None

        Raises:
            AssertionError: If a new hook is provided when one is already registered.

        Note:
            If the module is not alive during the backward pass, the pre_bw_hook and post_bw_hook
            will receive None as the module argument.
            The module fqn will be present in the ``parents`` attribute when each of the hooks is called.
            Hooks are intended to be used as markers only, not to modify the inputs/outputs.

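        Example:
            A minimal usage sketch (the hook bodies below are illustrative only, not part of
            this API):

            .. code-block:: python

                tracker = ModTracker()
                tracker.register_user_hooks(
                    pre_fw_hook=lambda mod, inp: print(f"-> {tracker.get_known_fqn(mod)}"),
                    post_fw_hook=lambda mod, inp, out: print(f"<- {tracker.get_known_fqn(mod)}"),
                )
                with tracker:
                    torch.nn.Linear(2, 2)(torch.rand(2, 2))
                tracker.clear_user_hooks()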
        """

        def set_hook(hook, user_hook, hook_name):
            if hook is not None and user_hook is not None:
                raise AssertionError(
                    f"Only one {hook_name} can be registered at a time."
                    f" Clear the existing hook by calling ``clear_user_hooks`` before registering a new one."
                )
            return hook

        self._user_pre_fw_hook = set_hook(
            pre_fw_hook, self._user_pre_fw_hook, "pre_fw_hook"
        )
        self._user_post_fw_hook = set_hook(
            post_fw_hook, self._user_post_fw_hook, "post_fw_hook"
        )
        self._user_pre_bw_hook = set_hook(
            pre_bw_hook, self._user_pre_bw_hook, "pre_bw_hook"
        )
        self._user_post_bw_hook = set_hook(
            post_bw_hook, self._user_post_bw_hook, "post_bw_hook"
        )

    def clear_user_hooks(self):
        """
        Clears the user-specified hooks registered with ``register_user_hooks``
        """
        self._user_pre_fw_hook = None
        self._user_post_fw_hook = None
        self._user_pre_bw_hook = None
        self._user_post_bw_hook = None

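    # Lazily assigns an fqn to ``mod`` (its class name for a root module,
    # ``<parent_fqn>.<child_name>`` for a submodule) and, the first time a module is seen,
    # recursively records the fqns of its submodules.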
    def _get_mod_name(self, mod):
        if mod not in self._known_modules:
            self._known_modules[mod] = type(mod).__name__
        mod_name = self._known_modules[mod]
        if mod not in self._seen_modules:
            for name, submod in mod.named_children():
                self._known_modules[submod] = f"{mod_name}.{name}"
                self._get_mod_name(submod)
            self._seen_modules.add(mod)
        return mod_name

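    # Builds the callable that marks ``name`` as currently executing: it is invoked directly
    # from the forward pre-hook (is_bw=False) and registered as a multi-grad hook on the
    # module's outputs so it also fires when their gradients are computed in backward (is_bw=True).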
    def _get_append_fn(self, w_mod, name, is_bw):
        def fn(*args):
            if is_bw:
                self._maybe_set_engine_callback()
            if name in self.parents and not self.is_bw:

                def custom_formatwarning(msg, category, filename, lineno, line=None):
                    return f"{filename}:{lineno}: {category.__name__}: {msg} \n"

                warnings.formatwarning = custom_formatwarning
                warnings.warn(
                    "The module hierarchy tracking may be messed up."
                    " Please file a bug to PyTorch if this is the case."
                )
            if name not in self.parents:
                self._active_module_cnt[name] = 1
                self.parents.add(name)
            else:
                self._active_module_cnt[name] += 1

            if self._user_pre_bw_hook is not None and is_bw:
                self._user_pre_bw_hook(w_mod(), args)

        return fn

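    # Builds the callable that marks ``name`` as done executing: it is invoked directly from
    # the forward post-hook (is_bw=False) and registered as a multi-grad hook on the module's
    # inputs so it also fires when their gradients are computed in backward (is_bw=True).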
    def _get_pop_fn(self, w_mod, name, is_bw):
        def fn(*args):
            if self._user_post_bw_hook is not None and is_bw:
                self._user_post_bw_hook(w_mod(), args)
            if name in self.parents:
                self._active_module_cnt[name] -= 1
                if self._active_module_cnt[name] == 0:
                    self.parents.remove(name)
            elif not self.is_bw:
                # Due to some input/output not requiring gradients, we cannot enforce
                # proper nesting in backward
                raise RuntimeError(
                    "The Module hierarchy tracking is wrong. Report a bug to PyTorch"
                )

        return fn

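    # Global forward pre-hook: pushes the module's fqn into ``parents`` and registers a
    # multi-grad hook on the grad-requiring inputs so the fqn is popped again once the
    # module's backward has finished.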
    def _fw_pre_hook(self, mod, input):
        name = self._get_mod_name(mod)
        w_mod = weakref.ref(mod)
        self._get_append_fn(w_mod, name, False)()
        if self._user_pre_fw_hook is not None:
            self._user_pre_fw_hook(mod, input)
        args, _ = tree_flatten(input)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if not self.is_bw and tensors:
            register_multi_grad_hook(tensors, self._get_pop_fn(w_mod, name, True))

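    # Global forward post-hook: pops the module's fqn from ``parents`` and registers a
    # multi-grad hook on the grad-requiring outputs so the fqn is pushed back when the
    # module's backward starts.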
    def _fw_post_hook(self, mod, input, output):
        name = self._get_mod_name(mod)
        w_mod = weakref.ref(mod)
        if self._user_post_fw_hook is not None:
            self._user_post_fw_hook(mod, input, output)
        self._get_pop_fn(w_mod, name, False)()
        args, _ = tree_flatten(output)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if not self.is_bw and tensors:
            register_multi_grad_hook(tensors, self._get_append_fn(w_mod, name, True))

    def __enter__(self):
        self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
        self._fw_post_handle = register_module_forward_hook(
            self._fw_post_hook, always_call=True
        )
        return self

    def __exit__(self, *args):
        self._fw_pre_handle.remove()
        self._fw_post_handle.remove()
239