# functorch

[**Why functorch?**](#why-composable-function-transforms)
| [**Install guide**](#install)
| [**Transformations**](#what-are-the-transforms)
| [**Documentation**](#documentation)
| [**Future Plans**](#future-plans)

**This library is currently under heavy development. If you have suggestions
on the API or use cases you'd like to be covered, please open a GitHub issue
or reach out. We'd love to hear about how you're using the library.**

`functorch` provides [JAX-like](https://github.com/google/jax) composable function
transforms for PyTorch.

It aims to provide composable `vmap` and `grad` transforms that work with
PyTorch modules and PyTorch autograd with good eager-mode performance.

In addition, there is experimental functionality to trace through these
transformations using FX in order to capture the results of these transforms
ahead of time. This would allow us to compile the results of `vmap` or `grad`
to improve performance.

## Why composable function transforms?

There are a number of use cases that are tricky to do in
PyTorch today:
- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

Composing `vmap`, `grad`, `vjp`, and `jvp` transforms allows us to express the above
without designing a separate subsystem for each. This idea of composable function
transforms comes from the [JAX framework](https://github.com/google/jax).

## Install

There are two ways to install functorch:
1. functorch from source
2. functorch beta (compatible with recent PyTorch releases)

We recommend trying out the functorch beta first.

### Installing functorch from source

<details><summary>Click to expand</summary>
<p>

#### Using Colab

Follow the instructions [in this Colab notebook](https://colab.research.google.com/drive/1CrLkqIrydBYP_svnF89UUO-aQEqNPE8x?usp=sharing).

#### Locally

As of 9/21/2022, `functorch` comes installed alongside a nightly PyTorch binary.
Please install a Preview (nightly) PyTorch binary; see https://pytorch.org/
for instructions.

Once you've done that, run a quick sanity check in Python:
```py
import torch
from functorch import vmap
x = torch.randn(3)
y = vmap(torch.sin)(x)
assert torch.allclose(y, x.sin())
```

#### functorch development setup

As of 9/21/2022, `functorch` comes installed alongside PyTorch and is in the
PyTorch source tree.
Please install
[PyTorch from source](https://github.com/pytorch/pytorch#from-source); after that,
you will be able to `import functorch`.

Try running some tests to make sure everything works:
```bash
pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -v
```

AOTAutograd has some additional optional requirements. You can install them via:
```bash
pip install networkx
```

To run the functorch tests, please install our test dependencies (`expecttest`, `pyyaml`).

</p>
</details>

### Installing functorch beta (compatible with recent PyTorch releases)

<details><summary>Click to expand</summary>
<p>

#### Using Colab

Follow the instructions [in this Colab notebook](https://colab.research.google.com/drive/1GNfb01W_xf8JRu78ZKoNnLqiwcrJrbYG#scrollTo=HJ1srOGeNCGA).

#### pip

Prerequisite: [Install PyTorch](https://pytorch.org/get-started/locally/)

```bash
pip install functorch
```

Finally, run a quick sanity check in Python:
```py
import torch
from functorch import vmap
x = torch.randn(3)
y = vmap(torch.sin)(x)
assert torch.allclose(y, x.sin())
```

</p>
</details>

## What are the transforms?

Right now, we support the following transforms:
- `grad`, `vjp`, `jvp`
- `jacrev`, `jacfwd`, `hessian`
- `vmap`

Furthermore, we have some utilities for working with PyTorch modules:
- `make_functional(model)`
- `make_functional_with_buffers(model)`

### vmap

Note: `vmap` imposes restrictions on the code that it can be used on.
For more details, please read its docstring.

`vmap(func)(*inputs)` is a transform that adds a dimension to all Tensor
operations in `func`. `vmap(func)` returns a new function that maps `func` over
some dimension (default: 0) of each Tensor in `inputs`.
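
When the batch is not stored along dimension 0, the `in_dims` argument of `vmap` selects the mapped dimension for each input. The following is a minimal sketch: the shapes and the choice of `torch.dot` are illustrative assumptions, and the explicit loop exists only to check the result.

```python
import torch
from functorch import vmap

# Illustrative shapes: 4 vectors of length 5, stored along different dimensions
x = torch.randn(5, 4)  # batch along dim 1
y = torch.randn(4, 5)  # batch along dim 0 (the default)

# in_dims=(1, 0): map over dim 1 of x and dim 0 of y,
# so each call to torch.dot sees two 5-element vectors
batched_dot = vmap(torch.dot, in_dims=(1, 0))
out = batched_dot(x, y)

# Explicit Python loop computing the same thing, for comparison
expected = torch.stack([torch.dot(x[:, i], y[i]) for i in range(4)])
assert out.shape == (4,)
assert torch.allclose(out, expected)
```

The mapped dimension does not appear in the inputs that `torch.dot` sees; it shows up only as the leading dimension of `out`.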

`vmap` is useful for hiding batch dimensions: one can write a function `func`
that runs on examples and then lift it to a function that can take batches of
examples with `vmap(func)`, leading to a simpler modeling experience:

```py
import torch
from functorch import vmap

batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

def model(feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
result = vmap(model)(examples)
```

### grad

`grad(func)(*inputs)` assumes `func` returns a single-element Tensor. It computes
the gradient of the output of `func` with respect to `inputs[0]`.
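
Because `grad` differentiates with respect to `inputs[0]` by default, the remaining arguments are treated as constants. A minimal sketch, using the `argnums` keyword of `functorch.grad` and a made-up two-argument function `f`, shows how to pick a different input or several inputs at once:

```python
import torch
from functorch import grad

def f(x, y):
    # Scalar output, as grad requires
    return (x * y).sum()

x = torch.randn(4)
y = torch.randn(4)

# Default: gradient with respect to the first input only
dx = grad(f)(x, y)
# argnums=1 selects the second input instead
dy = grad(f, argnums=1)(x, y)
# A tuple of argnums returns a tuple of gradients
dx2, dy2 = grad(f, argnums=(0, 1))(x, y)

# d/dx sum(x * y) = y and d/dy sum(x * y) = x
assert torch.allclose(dx, y)
assert torch.allclose(dy, x)
assert torch.allclose(dx2, y) and torch.allclose(dy2, x)
```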

```py
import torch
from functorch import grad

x = torch.randn([])
cos_x = grad(lambda x: torch.sin(x))(x)
assert torch.allclose(cos_x, x.cos())

# Second-order gradients
neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
assert torch.allclose(neg_sin_x, -x.sin())
```

When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
```py
import torch
from functorch import vmap, grad

batch_size, feature_size = 3, 5

def model(weights, feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()  # MSELoss

weights = torch.randn(feature_size, requires_grad=True)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)
inputs = (weights, examples, targets)
# weights are shared across the batch (None); examples and targets are mapped over dim 0
grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
```

### vjp

The `vjp` transform applies `func` to `inputs` and returns a new function that
computes vector-Jacobian products (vjps) given some `cotangents` Tensors.
```py
from functorch import vjp
# Schematic: `func`, `inputs`, and `cotangents` are placeholders
outputs, vjp_fn = vjp(func, *inputs)
vjps = vjp_fn(*cotangents)
```

### jvp

The `jvp` transform computes Jacobian-vector products and is also known as
"forward-mode AD". Unlike most other transforms, it is not a higher-order function:
it returns the outputs of `func(inputs)` as well as the `jvp`s.
```py
import torch
from functorch import jvp

x = torch.randn(5)
y = torch.randn(5)
f = lambda x, y: (x * y)
# With all-ones tangents, the jvp of x * y is 1 * y + x * 1 = x + y
_, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
assert torch.allclose(output, x + y)
```

### jacrev, jacfwd, and hessian

The `jacrev` transform returns a new function that takes in `x` and returns the
Jacobian of the original function with respect to `x`, computed using reverse-mode AD.
For example, for `torch.sin`:
```py
import torch
from functorch import jacrev

x = torch.randn(5)
jacobian = jacrev(torch.sin)(x)
# sin acts elementwise, so its Jacobian is diagonal with entries cos(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)
```
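
The `vjp` and `jacrev` transforms are closely related: each row of a Jacobian is a vector-Jacobian product with a one-hot cotangent. The sketch below makes the schematic `vjp` call concrete using a hypothetical, non-elementwise function `f` and a plain Python loop for readability. It illustrates the relationship only and is not a description of how `jacrev` is implemented internally.

```python
import torch
from functorch import vjp, jacrev

def f(x):
    # Hypothetical non-elementwise function, so the Jacobian is not diagonal
    return torch.stack([x[0] * x[1], x[1].sin(), x.sum()])

x = torch.randn(2)
out, vjp_fn = vjp(f, x)

# vjp_fn returns a tuple with one gradient per primal input (a single input here).
# Feeding the i-th one-hot cotangent yields row i of the Jacobian.
rows = torch.stack([vjp_fn(e)[0] for e in torch.eye(out.numel())])

jacobian = jacrev(f)(x)
assert rows.shape == (3, 2)
assert torch.allclose(rows, jacobian)
```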
`jacrev` can be composed with `vmap` to produce batched Jacobians:

```py
import torch
from functorch import vmap, jacrev

x = torch.randn(64, 5)
jacobian = vmap(jacrev(torch.sin))(x)
assert jacobian.shape == (64, 5, 5)
```

`jacfwd` is a drop-in replacement for `jacrev` that computes Jacobians using
forward-mode AD:
```py
import torch
from functorch import jacfwd

x = torch.randn(5)
jacobian = jacfwd(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)
```

Composing `jacrev` with itself or with `jacfwd` can produce Hessians:
```py
import torch
from functorch import jacrev, jacfwd

def f(x):
    return x.sin().sum()

x = torch.randn(5)
hessian0 = jacrev(jacrev(f))(x)
hessian1 = jacfwd(jacrev(f))(x)
```

`hessian` is a convenience function that combines `jacfwd` and `jacrev`:
```py
import torch
from functorch import hessian

def f(x):
    return x.sin().sum()

x = torch.randn(5)
hess = hessian(f)(x)
```

### Tracing through the transformations
We can also trace through these transformations in order to capture the results as new code using `make_fx`. There is also experimental integration with the NNC compiler (only works on CPU for now!).

```py
import torch
from functorch import make_fx, grad

def f(x):
    return torch.sin(x).sum()

x = torch.randn(100)
grad_f = make_fx(grad(f))(x)
print(grad_f.code)

# Example output:
# def forward(self, x_1):
#     sin = torch.ops.aten.sin(x_1)
#     sum_1 = torch.ops.aten.sum(sin, None); sin = None
#     cos = torch.ops.aten.cos(x_1); x_1 = None
#     _tensor_constant0 = self._tensor_constant0
#     mul = torch.ops.aten.mul(_tensor_constant0, cos); _tensor_constant0 = cos = None
#     return mul
```

### Working with NN modules: make_functional and friends

Sometimes you may want to perform a transform with respect to the parameters
and/or buffers of an nn.Module. This can happen, for example, in:
- model ensembling, where all of your weights and buffers have an additional
  dimension
- per-sample-gradient computation, where you want to compute per-sample-grads
  of the loss with respect to the model parameters

Our solution to this right now is an API that, given an nn.Module, creates a
stateless version of it that can be called like a function.

- `make_functional(model)` returns a functional version of `model` and the
  `model.parameters()`
- `make_functional_with_buffers(model)` returns a functional version of
  `model` and the `model.parameters()` and `model.buffers()`.
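
As a minimal sketch of the buffers variant: the small `Linear` + `BatchNorm1d` model below is an arbitrary illustrative choice, picked because `BatchNorm1d` carries buffers (running statistics). The functional model takes the parameters and buffers explicitly and, in eval mode, reproduces the original module's output.

```python
import torch
from functorch import make_functional_with_buffers

# Illustrative model: BatchNorm1d has buffers in addition to parameters
model = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.BatchNorm1d(4))
model.eval()  # use running statistics, so buffers are read but not updated

func_model, params, buffers = make_functional_with_buffers(model)

x = torch.randn(8, 3)
# Parameters and buffers are passed explicitly, before the regular inputs
out = func_model(params, buffers, x)
expected = model(x)

# Linear: weight, bias; BatchNorm1d: weight, bias
assert len(params) == 4
# BatchNorm1d buffers: running_mean, running_var, num_batches_tracked
assert len(buffers) == 3
assert torch.allclose(out, expected)
```

Eval mode is used here so that both calls read the same running statistics; in training mode, batch norm would update its buffers as a side effect.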

Here's an example where we compute per-sample-gradients using an nn.Linear
layer:

```py
import torch
from functorch import make_functional, vmap, grad

model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)

func_model, params = make_functional(model)

def compute_loss(params, data, targets):
    preds = func_model(params, data)
    return torch.mean((preds - targets) ** 2)

# params are shared (None); data and targets are mapped over dim 0
per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(params, data, targets)
```

If you're making an ensemble of models, you may find
`combine_state_for_ensemble` useful.

## Documentation

For more documentation, see [our docs website](https://pytorch.org/functorch).

## Debugging
- `torch._C._functorch.dump_tensor`: dumps the dispatch keys on the stack.
- `torch._C._functorch._set_vmap_fallback_warning_enabled(False)`: turns off the vmap fallback warnings if they get too noisy.

## Future Plans

In the end state, we'd like to upstream this into PyTorch once we iron out the
design details. To figure out the details, we need your help: please send us
your use cases by starting a conversation in the issue tracker or by trying our
project out.

## License
Functorch has a BSD-style license, as found in the [LICENSE](LICENSE) file.

## Citing functorch

If you use functorch in your publication, please cite it by using the following BibTeX entry.

```bibtex
@Misc{functorch2021,
  author = {Horace He and Richard Zou},
  title = {functorch: JAX-like composable function transforms for PyTorch},
  howpublished = {\url{https://github.com/pytorch/functorch}},
  year = {2021}
}
```