# AOT Autograd - Introduction to an experimental compilation feature in Functorch

The primary compilation API we provide is called AOT Autograd. AOT Autograd is
an experimental feature that allows ahead-of-time capture of forward and
backward graphs and easy integration with compilers. This creates an
easy-to-hack, Python-based development environment for speeding up the training
of PyTorch models. AOT Autograd currently lives in the `functorch.compile`
namespace.

AOT Autograd is experimental and the APIs are likely to change. We are looking
for feedback. If you are interested in using AOT Autograd and need help or have
suggestions, please feel free to open an issue. We will be happy to help.

Here are some examples of how to use it:
```python
from functorch.compile import aot_function, aot_module, draw_graph
import torch.fx as fx
import torch

# This simply prints out the FX graph of the forwards and the backwards
def print_graph(name):
    def f(fx_g: fx.GraphModule, inps):
        print(name)
        print(fx_g.code)
        return fx_g
    return f

def f(x):
    return x.cos().cos()

nf = aot_function(f, fw_compiler=print_graph("forward"), bw_compiler=print_graph("backward"))
nf(torch.randn(3, requires_grad=True))
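# aot_function returns a new callable with the same signature as f. The first
# call traces f, hands the resulting forward and backward FX graphs to the
# compilers above, and caches what they return, so the graphs print only once.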

# You can do whatever you want before and after, and you can still backprop through the function.
inp = torch.randn(3, requires_grad=True)
inp = inp.cos()
out = nf(inp)
out.sin().sum().backward()
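# The .backward() call above runs the compiled backward graph; gradients flow
# through nf and into the preceding inp.cos() exactly as they would for f.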

def f(x):
    return x.cos().cos()

# This draws out the forwards and the backwards graphs as svg files
def graph_drawer(name):
    def f(fx_g: fx.GraphModule, inps):
        draw_graph(fx_g, name)
        return fx_g
    return f

aot_function(f, fw_compiler=graph_drawer("forward"), bw_compiler=graph_drawer("backward"))(torch.randn(3, requires_grad=True))

# We also have a convenience API for applying AOTAutograd to modules
from torchvision.models import resnet18
aot_module(resnet18(), print_graph("forward"), print_graph("backward"))(torch.randn(1,3,200,200))
# output elided since it's very long
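# aot_module returns an nn.Module wrapper, so the compiled model can be used as
# a drop-in replacement for the original module in a training loop.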

# In practice, you might want to speed it up by lowering it to TorchScript,
# either as the final compiler or as a step before handing the graph off to another compiler.

def f(x):
    return x.cos().cos()

def ts_compiler(fx_g: fx.GraphModule, inps):
    f = torch.jit.script(fx_g)
    print(f.graph)
    f = torch.jit.freeze(f.eval()) # Note: This eval() works fine *even* though we're using this for training
    return f

aot_function(f, ts_compiler, ts_compiler)(torch.randn(3, requires_grad=True))
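# functorch.compile also ships a ready-made TorchScript compiler; if your
# version exports it, the call above is roughly equivalent to:
#   from functorch.compile import ts_compile
#   aot_function(f, ts_compile, ts_compile)(torch.randn(3, requires_grad=True))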
```

## Documentation
* AOT Autograd [documentation](https://pytorch.org/functorch/nightly/)
* Min-cut [recomputation](https://dev-discuss.pytorch.org/t/min-cut-optimal-recomputation-i-e-activation-checkpointing-with-aotautograd/467) with AOT Autograd (see the sketch below).
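
If you want to experiment with the min-cut recomputation strategy described in the post above, the sketch below shows roughly how to plug it in. It assumes your functorch build exports `min_cut_rematerialization_partition` and that `aot_function` accepts a `partition_fn` keyword; check the exports of `functorch.compile` for your version.

```python
import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition

def f(x):
    return x.cos().cos()

def noop_compiler(fx_g, inps):
    # Pass-through "compiler": return the captured FX graph unchanged.
    return fx_g

# partition_fn controls how the captured joint graph is split into forward and
# backward graphs; the min-cut partitioner recomputes cheap ops in the backward
# instead of saving their outputs, trading a little compute for less memory.
nf = aot_function(
    f,
    fw_compiler=noop_compiler,
    bw_compiler=noop_compiler,
    partition_fn=min_cut_rematerialization_partition,
)
nf(torch.randn(3, requires_grad=True)).sum().backward()
```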

## Tutorials
You can use this [tutorial](https://pytorch.org/functorch/nightly/notebooks/aot_autograd_optimizations.html) to play with AOT Autograd.