# AOT Autograd - Introduction to an experimental compilation feature in Functorch

The primary compilation API we provide is called AOT Autograd. AOT Autograd is
an experimental feature that allows ahead-of-time capture of forward and
backward graphs and easy integration with compilers. This creates an
easy-to-hack, Python-based development environment for speeding up the training
of PyTorch models. AOT Autograd currently lives in the functorch.compile
namespace.

AOT Autograd is experimental and the APIs are likely to change. We are looking
for feedback. If you are interested in using AOT Autograd and need help or have
suggestions, please feel free to open an issue. We will be happy to help.

Here are some examples of how to use it:
```python
from functorch.compile import aot_function, aot_module, draw_graph
import torch.fx as fx
import torch

# This simply prints out the FX graphs of the forward and backward passes
def print_graph(name):
    def f(fx_g: fx.GraphModule, inps):
        print(name)
        print(fx_g.code)
        return fx_g
    return f

def f(x):
    return x.cos().cos()

nf = aot_function(f, fw_compiler=print_graph("forward"), bw_compiler=print_graph("backward"))
nf(torch.randn(3, requires_grad=True))

# You can do whatever you want before and after, and you can still backprop through the function.
inp = torch.randn(3, requires_grad=True)
inp = inp.cos()
out = nf(inp)
out.sin().sum().backward()

def f(x):
    return x.cos().cos()

# This draws the forward and backward graphs as SVG files
def graph_drawer(name):
    def f(fx_g: fx.GraphModule, inps):
        draw_graph(fx_g, name)
        return fx_g
    return f

aot_function(f, fw_compiler=graph_drawer("forward"), bw_compiler=graph_drawer("backward"))(torch.randn(3, requires_grad=True))

# We also have a convenience API for applying AOT Autograd to modules
from torchvision.models import resnet18
aot_module(resnet18(), print_graph("forward"), print_graph("backward"))(torch.randn(1, 3, 200, 200))
# output elided since it's very long

# In practice, you might want to speed things up by sending the captured graphs
# to TorchScript. You might also lower them to TorchScript before passing them
# to another compiler.

def f(x):
    return x.cos().cos()

def ts_compiler(fx_g: fx.GraphModule, inps):
    f = torch.jit.script(fx_g)
    print(f.graph)
    f = torch.jit.freeze(f.eval())  # Note: this eval() works fine *even* though we're using this for training
    return f

aot_function(f, ts_compiler, ts_compiler)(torch.randn(3, requires_grad=True))
```

## Documentation
* AOT Autograd [documentation](https://pytorch.org/functorch/nightly/)
* Min-cut [recomputation](https://dev-discuss.pytorch.org/t/min-cut-optimal-recomputation-i-e-activation-checkpointing-with-aotautograd/467) with AOT Autograd

## Tutorials
You can use this [tutorial](https://pytorch.org/functorch/nightly/notebooks/aot_autograd_optimizations.html) to play with AOT Autograd.
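
The min-cut recomputation pass linked above is selected through AOT Autograd's partitioner hook. The snippet below is a minimal sketch, assuming that `functorch.compile` exposes `min_cut_rematerialization_partition` and that `aot_function` accepts a `partition_fn` argument; since this is an experimental API, the exact names may differ in your version.
```python
import torch
import torch.fx as fx
from functorch.compile import aot_function, min_cut_rematerialization_partition

# Reuse the same "compiler" as above: print the captured FX graph and return it unchanged.
def print_graph(name):
    def f(fx_g: fx.GraphModule, inps):
        print(name)
        print(fx_g.code)
        return fx_g
    return f

def f(x):
    return x.cos().cos()

# partition_fn decides which intermediates are saved for the backward pass and
# which are recomputed; min-cut rematerialization trades some recomputation for
# a smaller set of saved activations.
nf = aot_function(
    f,
    fw_compiler=print_graph("forward"),
    bw_compiler=print_graph("backward"),
    partition_fn=min_cut_rematerialization_partition,
)
nf(torch.randn(3, requires_grad=True)).sum().backward()
```
The inputs of the printed backward graph show which values the partitioner chose to save rather than recompute.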
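
Since `aot_module` returns an ordinary `nn.Module`, it also drops straight into a standard training loop. The following is a hedged sketch, not taken from the examples above, using a toy model and an illustrative do-nothing compiler (`identity_compiler` is just a name chosen here).
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fx as fx
from functorch.compile import aot_module

def identity_compiler(fx_g: fx.GraphModule, inps):
    # A do-nothing "compiler": return the captured graph unchanged.
    return fx_g

model = aot_module(
    nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1)),
    identity_compiler, identity_compiler,
)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(16, 8), torch.randn(16, 1)
for _ in range(3):
    opt.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()  # gradients flow through the AOT-captured graphs
    opt.step()
```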