PyTorch 2.0 NNModule Support
============================

**Author**: `Will Constable <https://github.com/wconstab>`_

`torch.compile` has special handling for `torch.nn.Module` objects: it traces them differently than it traces
arbitrary Python classes, with the intent of producing faster code by making assumptions about their structure.

This doc describes some of the tradeoffs and edge cases that arise from this specialization.

NNModule Hooks Support
----------------------
Previously, `torch.compile` had no support for hooks on nn.Modules; if hooks were registered,
they were simply ignored in the compiled program. Many users do not
use nn.Module hooks at all, or use them only for debugging workflows, but there are valid use cases
for composing nn.Module hooks with `torch.compile`.

Hooks that are orchestrated by the `nn.Module.__call__` implementation include `_forward_pre_hooks`,
`_forward_hooks`, `_backward_pre_hooks`, and `_backward_hooks`; they will be referred to as 'call hooks'.
These hooks are partially supported by `torch.compile`, with the limitations described below.

Another category of hooks includes `_state_dict_hooks` and its `pre` and `load_` variants; these are still
unsupported by `torch.compile`.

`nn.Module.__call__` Hooks Usage and limitations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
By default, `torch.compile` traces the contents of `nn.Module.__call__`, which means it encounters
and runs forward and pre-forward hooks. If you install hooks before calling `torch.compile` and do not remove
or alter them later, your use case should be supported by default.

Backward and pre-backward hooks are generally also supported, with similar caveats: currently, dynamo graph-breaks
when accessing the backward-hook dicts, which is probably avoidable with some work.
Graph breaks also affect when backward hooks fire, since each graph segment runs as an autograd function that
produces all of its gradients at the same time. Even if dynamo could avoid graph-breaking in the presence of backward
hooks, we would still expect the backward hooks for a series of modules to fire together, after the backward of the
whole compiled graph has run.

**hooks on 'allowed modules'**
`torch.compile` treats common modules such as `torch.nn.Conv2d`, as well as modules that are difficult to trace, specially:
it calls them opaquely in the dynamo graph instead of tracing into them. For such modules, hooks
currently trigger a graph break so that the affected modules run outside of dynamo. Depending on the model, this could
introduce a significant performance regression, and additional work is required to improve this support.

**skip_nnmodule_hook_guards**
By default, `torch._dynamo.config.skip_nnmodule_hook_guards` is set to True, meaning no guards are installed
on each nn.Module hook dictionary. This improves runtime by reducing guard execution time, at the cost of not noticing
if any hook dict is changed after compilation.

If you want to be able to remove or modify hooks after compilation and have `torch.compile` react appropriately
(by recompiling), set `skip_nnmodule_hook_guards=False` and expect a runtime penalty for the added
guards.

TODO: confirm if backward/pre_backward hooks are working or not and document accordingly

state_dict Hooks
~~~~~~~~~~~~~~~~
State dict hooks are not yet supported by `torch.compile`.

TODO: warn_once if graph-breaking on hooks. warn_once to point to this doc if hooks are present.
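
Example: Call Hooks and Hook Guards
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The sketch below illustrates the call-hook behavior and the `skip_nnmodule_hook_guards` setting described in the
`nn.Module.__call__` section. It is a minimal example under stated assumptions, not a reference implementation:
`TinyNet` and `double_output` are hypothetical names; `backend="eager"` is chosen only so that dynamo's tracing and
guards are exercised without inductor code generation; and the flag is set only if the installed PyTorch release
defines it, since its availability and effect may differ between releases. Depending on the release, the hook may be
traced into the graph or run outside it; in both cases the compiled output is expected to match eager execution.

.. code-block:: python

    import torch
    import torch._dynamo
    from torch import nn


    class TinyNet(nn.Module):
        # Hypothetical user-defined module, used only for illustration.
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

        def forward(self, x):
            return torch.relu(self.linear(x))


    def double_output(module, inputs, output):
        # Hypothetical forward hook: a non-None return value replaces the module's output.
        return output * 2


    torch.manual_seed(0)
    model = TinyNet()
    x = torch.randn(2, 4)

    # Ask dynamo to guard on the hook dicts so that hook changes after compilation
    # trigger a recompile. Only set the flag if this PyTorch release defines it.
    if hasattr(torch._dynamo.config, "skip_nnmodule_hook_guards"):
        torch._dynamo.config.skip_nnmodule_hook_guards = False

    # Register the hook *before* compiling: the supported default pattern.
    handle = model.register_forward_hook(double_output)
    torch._dynamo.reset()
    compiled = torch.compile(model, backend="eager")

    with_hook = compiled(x)  # the hook's doubling is applied to this result

    # Remove the hook after compilation. With hook guards enabled, the next call
    # is expected to fail the guard and recompile without the hook.
    handle.remove()
    without_hook = compiled(x)

Registering the hook before calling `torch.compile` is the supported default case. Removing it afterwards is the
case that relies on hook guards: with `skip_nnmodule_hook_guards=False`, the second call is expected to recompile
rather than reuse a graph that still contains the hook.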