PyTorch 2.0 NNModule Support
============================

**Author**: `Will Constable <https://github.com/wconstab>`_

`torch.compile` has special handling for torch.nn.Module objects, tracing them differently than it traces
arbitrary Python classes, with the intent of producing faster code by making assumptions about the structure.

This doc describes some of the tradeoffs or edge cases that come up due to this specialization.

NNModule Hooks Support
----------------------
Previously, `torch.compile` had no support for hooks on nn.Modules, and if hooks were registered
they would simply be ignored in the compiled program. Indeed, many users do not
use nn.Module hooks at all, or only use them for debugging workflows, but there are valid use cases
for composing nn.Module hooks with `torch.compile`.

Hooks that are orchestrated via the nn.Module.__call__ implementation include `_forward_pre_hooks`,
`_forward_hooks`, `_backward_pre_hooks`, and `_backward_hooks`, and will be referred to as 'call hooks'.
These hooks are partially supported by `torch.compile`, with limitations described below.

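These hook dictionaries are populated by the public registration APIs on nn.Module. A minimal sketch of
installing one hook of each kind (the hook bodies here are illustrative placeholders):

.. code-block:: python

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)

    # Forward pre-hook: runs before forward(); may replace the inputs.
    model.register_forward_pre_hook(lambda module, args: (args[0] * 2,))

    # Forward hook: runs after forward(); may replace the output.
    model.register_forward_hook(lambda module, args, output: output + 1)

    # Backward pre-hook: runs before the module's grads are computed.
    model.register_full_backward_pre_hook(lambda module, grad_output: None)

    # Backward hook: runs after the module's input grads are computed.
    model.register_full_backward_hook(lambda module, grad_input, grad_output: None)
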
Another category of hooks includes `_state_dict_hooks` and its `pre` and `load_` variants; these remain
unsupported by `torch.compile`.

`nn.Module.__call__` Hooks Usage and Limitations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
By default, `torch.compile` will trace the contents of `nn.Module.__call__`, which means it will encounter
and run forward/pre-forward hooks.  If you install hooks before calling `torch.compile` and then do not remove
or alter the hooks later, your use case should be supported by default.

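A minimal sketch of this supported pattern, where a hook is installed before compilation and left unchanged
afterward:

.. code-block:: python

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8)

    # Install the hook *before* compiling, and do not remove or alter it later.
    def double_output(module, args, output):
        return output * 2

    model.register_forward_hook(double_output)

    compiled = torch.compile(model)

    # The hook's logic is traced and runs as part of the compiled program.
    out = compiled(torch.randn(2, 8))
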
Backward/Pre-backward hooks are generally also supported, with similar caveats: currently, graph-breaks in dynamo
occur when accessing backward_hooks dicts, which is probably avoidable with some work.  Graph-breaks also impact the
timing of firing backward hooks, since graph-segments are run as autograd-functions which produce all their grads at
the same time.  Assuming it were possible for dynamo to not graph-break on the presence of backward hooks, we would
still expect the backward hooks for a series of modules to all fire together after the whole compiled graph's backward
ran.

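To illustrate the timing caveat, consider backward hooks on several submodules; under a compiled graph, the hooks
may all fire together after the compiled backward runs, rather than interleaving per-module as in eager mode. A
sketch:

.. code-block:: python

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))

    def make_hook(name):
        def hook(module, grad_input, grad_output):
            print(f"backward hook fired for {name}")
        return hook

    for name, submodule in model.named_children():
        submodule.register_full_backward_hook(make_hook(name))

    compiled = torch.compile(model)
    loss = compiled(torch.randn(2, 4)).sum()

    # In eager mode these hooks interleave with each module's grad computation;
    # with a compiled graph segment, they fire after that segment's backward.
    loss.backward()
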
**hooks on 'allowed modules'**
`torch.compile` treats common modules such as `torch.nn.Conv2d`, as well as modules that are difficult to trace,
specially by allowing them to be called opaquely in the dynamo graph instead of traced into by dynamo.  For such
modules, hooks currently trigger a graph-break so that the affected modules run outside of dynamo.  Depending on
the model, this could introduce a significant performance regression, and additional work is required to improve
this support.

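One way to check whether hooks are causing such graph-breaks is `torch._dynamo.explain`; the sketch below assumes
the PyTorch 2.1+ calling convention and a hook on a conv module:

.. code-block:: python

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(3, 8, kernel_size=3)
    conv.register_forward_hook(lambda module, args, output: output * 2)

    x = torch.randn(1, 3, 16, 16)

    # ExplainOutput summarizes captured graphs and the reasons for any breaks;
    # a hook on an 'allowed module' like Conv2d is expected to show up here.
    explanation = torch._dynamo.explain(conv)(x)
    print(explanation.graph_break_count)
    print(explanation.break_reasons)
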
**skip_nnmodule_hook_guards**
By default, `torch._dynamo.config.skip_nnmodule_hook_guards` is set to True, meaning no guards will be installed
on each nn.Module hook dictionary, improving runtime by reducing guard execution time, at the cost of not noticing
if any hook dict is changed after compilation.

If you want to be able to remove or modify hooks after compilation and have `torch.compile` react appropriately
(by recompiling), then you need to set `skip_nnmodule_hook_guards=False` and expect a runtime penalty for the added
guards.

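A sketch of opting into the hook guards; the recompilation on hook removal is the expected behavior when the
guards are enabled:

.. code-block:: python

    import torch
    import torch._dynamo

    # Guard on hook dicts so that hook changes are noticed after compilation.
    torch._dynamo.config.skip_nnmodule_hook_guards = False

    model = torch.nn.Linear(4, 4)
    handle = model.register_forward_hook(lambda module, args, output: output + 1)

    compiled = torch.compile(model)
    x = torch.randn(2, 4)
    compiled(x)      # compiles with the hook in place

    handle.remove()  # the guard on the hook dict now fails...
    compiled(x)      # ...triggering a recompile without the hook
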
TODO: confirm if backward/pre_backward hooks are working or not and document accordingly

state_dict Hooks
~~~~~~~~~~~~~~~~
State dict hooks are not yet supported in `torch.compile`.

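Note that this limitation applies to handling such hooks inside compiled code; a state dict hook registered on a
module still fires when `state_dict()` is called as ordinary eager code. A sketch, assuming the public
`register_state_dict_pre_hook` API:

.. code-block:: python

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)

    def on_state_dict(module, prefix, keep_vars):
        print(f"state_dict requested for prefix {prefix!r}")

    model.register_state_dict_pre_hook(on_state_dict)

    compiled = torch.compile(model)
    compiled(torch.randn(2, 4))

    # state_dict() here is plain eager code, so the hook fires normally.
    sd = model.state_dict()
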
TODO: warn_once if graph-breaking on hooks.  warn_once to point to this doc if hooks are present.