# functorch

[**Why functorch?**](#why-composable-function-transforms)
| [**Install guide**](#install)
| [**Transformations**](#what-are-the-transforms)
| [**Documentation**](#documentation)
| [**Future Plans**](#future-plans)

**This library is currently under heavy development - if you have suggestions
on the API or use cases you'd like to be covered, please open a GitHub issue
or reach out. We'd love to hear about how you're using the library.**

`functorch` is [JAX-like](https://github.com/google/jax) composable function
transforms for PyTorch.

It aims to provide composable `vmap` and `grad` transforms that work with
PyTorch modules and PyTorch autograd with good eager-mode performance.

In addition, there is experimental functionality to trace through these
transformations using FX in order to capture the results of these transforms
ahead of time. This would allow us to compile the results of vmap or grad
to improve performance.

## Why composable function transforms?

There are a number of use cases that are tricky to do in PyTorch today:
- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

Composing `vmap`, `grad`, `vjp`, and `jvp` transforms allows us to express the above
without designing a separate subsystem for each. This idea of composable function
transforms comes from the [JAX framework](https://github.com/google/jax).

## Install

There are two ways to install functorch:
1. functorch from source
2. functorch beta (compatible with recent PyTorch releases)

We recommend trying out the functorch beta first.

### Installing functorch from source

<details><summary>Click to expand</summary>
<p>

#### Using Colab

Follow the instructions [in this Colab notebook](https://colab.research.google.com/drive/1CrLkqIrydBYP_svnF89UUO-aQEqNPE8x?usp=sharing)

#### Locally

As of 9/21/2022, `functorch` comes installed alongside a nightly PyTorch binary.
Please install a Preview (nightly) PyTorch binary; see https://pytorch.org/
for instructions.

Once you've done that, run a quick sanity check in Python:
```py
import torch
from functorch import vmap
x = torch.randn(3)
y = vmap(torch.sin)(x)
assert torch.allclose(y, x.sin())
```

#### functorch development setup

As of 9/21/2022, `functorch` comes installed alongside PyTorch and is in the
PyTorch source tree. Please install
[PyTorch from source](https://github.com/pytorch/pytorch#from-source); then
you will be able to `import functorch`.

Try to run some tests to make sure all is OK:
```bash
pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -v
```

AOTAutograd has some additional optional requirements. You can install them via:
```bash
pip install networkx
```

To run functorch tests, please install our test dependencies (`expecttest`, `pyyaml`).
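For example, assuming a standard pip-based setup, installing those test dependencies could look like:
```bash
pip install expecttest pyyaml
```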

</p>
</details>

### Installing functorch beta (compatible with recent PyTorch releases)

<details><summary>Click to expand</summary>
<p>

#### Using Colab

Follow the instructions [here](https://colab.research.google.com/drive/1GNfb01W_xf8JRu78ZKoNnLqiwcrJrbYG#scrollTo=HJ1srOGeNCGA)

#### pip

Prerequisite: [Install PyTorch](https://pytorch.org/get-started/locally/)

```bash
pip install functorch
```

Finally, run a quick sanity check in Python:
```py
import torch
from functorch import vmap
x = torch.randn(3)
y = vmap(torch.sin)(x)
assert torch.allclose(y, x.sin())
```

</p>
</details>

## What are the transforms?

Right now, we support the following transforms:
- `grad`, `vjp`, `jvp`
- `jacrev`, `jacfwd`, `hessian`
- `vmap`

Furthermore, we have some utilities for working with PyTorch modules:
- `make_functional(model)`
- `make_functional_with_buffers(model)`

### vmap

Note: `vmap` imposes restrictions on the code that it can be used on.
For more details, please read its docstring.

`vmap(func)(*inputs)` is a transform that adds a dimension to all Tensor
operations in `func`. `vmap(func)` returns a new function that maps `func` over
some dimension (default: 0) of each Tensor in `inputs`.

`vmap` is useful for hiding batch dimensions: one can write a function `func`
that runs on examples and then lift it to a function that can take batches of
examples with `vmap(func)`, leading to a simpler modeling experience:

```py
from functorch import vmap
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

def model(feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
result = vmap(model)(examples)
```

### grad

`grad(func)(*inputs)` assumes `func` returns a single-element Tensor. It computes
the gradients of the output of `func` w.r.t. `inputs[0]`.

```py
from functorch import grad
x = torch.randn([])
cos_x = grad(lambda x: torch.sin(x))(x)
assert torch.allclose(cos_x, x.cos())

# Second-order gradients
neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
assert torch.allclose(neg_sin_x, -x.sin())
```

When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
```py
from functorch import vmap
batch_size, feature_size = 3, 5

def model(weights, feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()  # MSELoss

weights = torch.randn(feature_size, requires_grad=True)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)
inputs = (weights, examples, targets)
grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
```
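
`grad` also accepts an `argnums` argument that selects which inputs to
differentiate with respect to. A minimal sketch (assuming the `argnums` keyword,
which mirrors JAX; pass a tuple to get gradients for several inputs at once):

```py
from functorch import grad

x, y = torch.randn([]), torch.randn([])
# Differentiate x * y with respect to both arguments; a tuple of gradients is returned.
grad_x, grad_y = grad(lambda x, y: x * y, argnums=(0, 1))(x, y)
assert torch.allclose(grad_x, y) and torch.allclose(grad_y, x)
```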

### vjp

The `vjp` transform applies `func` to `inputs` and returns a new function that
computes vjps given some `cotangents` Tensors.

```py
from functorch import vjp
outputs, vjp_fn = vjp(func, inputs)
vjps = vjp_fn(*cotangents)
```

### jvp

The `jvp` transform computes Jacobian-vector products and is also known as
"forward-mode AD". Unlike most other transforms, it is not a higher-order
function: instead of returning a new function, it directly returns the outputs
of `func(inputs)` together with the jvps.
```py
from functorch import jvp
x = torch.randn(5)
y = torch.randn(5)
f = lambda x, y: (x * y)
_, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
assert torch.allclose(output, x + y)
```

### jacrev, jacfwd, and hessian

The `jacrev` transform returns a new function that takes in `x` and returns the
Jacobian of the given function with respect to `x` using reverse-mode AD. Here
we compute the Jacobian of `torch.sin`:
```py
from functorch import jacrev
x = torch.randn(5)
jacobian = jacrev(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)
```
`jacrev` can be composed with `vmap` to produce batched Jacobians:

```py
x = torch.randn(64, 5)
jacobian = vmap(jacrev(torch.sin))(x)
assert jacobian.shape == (64, 5, 5)
```

`jacfwd` is a drop-in replacement for `jacrev` that computes Jacobians using
forward-mode AD:
```py
from functorch import jacfwd
x = torch.randn(5)
jacobian = jacfwd(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)
```

Composing `jacrev` with itself or with `jacfwd` can produce Hessians:
```py
def f(x):
    return x.sin().sum()

x = torch.randn(5)
hessian0 = jacrev(jacrev(f))(x)
hessian1 = jacfwd(jacrev(f))(x)
```

`hessian` is a convenience function that combines `jacfwd` and `jacrev`:
```py
from functorch import hessian

def f(x):
    return x.sin().sum()

x = torch.randn(5)
hess = hessian(f)(x)
```

### Tracing through the transformations

We can also trace through these transformations in order to capture the results
as new code using `make_fx`. There is also experimental integration with the
NNC compiler (it only works on CPU for now!).

```py
from functorch import make_fx, grad
def f(x):
    return torch.sin(x).sum()
x = torch.randn(100)
grad_f = make_fx(grad(f))(x)
print(grad_f.code)

def forward(self, x_1):
    sin = torch.ops.aten.sin(x_1)
    sum_1 = torch.ops.aten.sum(sin, None);  sin = None
    cos = torch.ops.aten.cos(x_1);  x_1 = None
    _tensor_constant0 = self._tensor_constant0
    mul = torch.ops.aten.mul(_tensor_constant0, cos);  _tensor_constant0 = cos = None
    return mul
```

### Working with NN modules: make_functional and friends

Sometimes you may want to perform a transform with respect to the parameters
and/or buffers of an nn.Module. This can happen, for example, in:
- model ensembling, where all of your weights and buffers have an additional
dimension
- per-sample-gradient computation, where you want to compute per-sample-grads
of the loss with respect to the model parameters

Our solution to this right now is an API that, given an nn.Module, creates a
stateless version of it that can be called like a function.

- `make_functional(model)` returns a functional version of `model` along with
`model.parameters()`
- `make_functional_with_buffers(model)` returns a functional version of
`model` along with `model.parameters()` and `model.buffers()`
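
As a minimal sketch of the calling convention (the `Sequential` toy model below
is purely illustrative), the functional module takes the parameters, and for the
with-buffers variant also the buffers, as explicit leading arguments:

```py
import torch
from functorch import make_functional_with_buffers

# Toy model with both parameters (Linear) and buffers (BatchNorm running stats).
model = torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.BatchNorm1d(3))
func_model, params, buffers = make_functional_with_buffers(model)

x = torch.randn(4, 3)
out = func_model(params, buffers, x)  # stateless call: model state is passed in explicitly
```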

Here's an example where we compute per-sample-gradients using an nn.Linear
layer:

```py
import torch
from functorch import make_functional, vmap, grad

model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)

func_model, params = make_functional(model)

def compute_loss(params, data, targets):
    preds = func_model(params, data)
    return torch.mean((preds - targets) ** 2)

per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(params, data, targets)
```

If you're making an ensemble of models, you may find
`combine_state_for_ensemble` useful.

## Documentation

For more documentation, see [our docs website](https://pytorch.org/functorch).

## Debugging

- `torch._C._functorch.dump_tensor`: dumps the dispatch keys on the stack.
- `torch._C._functorch._set_vmap_fallback_warning_enabled(False)`: silences the
vmap fallback warnings if the warning spam bothers you.

## Future Plans

In the end state, we'd like to upstream this into PyTorch once we iron out the
design details. To figure out the details, we need your help -- please send us
your use cases by starting a conversation in the issue tracker or trying our
project out.

## License

functorch has a BSD-style license, as found in the [LICENSE](LICENSE) file.

## Citing functorch

If you use functorch in your publication, please cite it by using the following
BibTeX entry.

```bibtex
@Misc{functorch2021,
  author =       {Horace He and Richard Zou},
  title =        {functorch: JAX-like composable function transforms for PyTorch},
  howpublished = {\url{https://github.com/pytorch/functorch}},
  year =         {2021}
}
```