# functorch

[**Why functorch?**](#why-composable-function-transforms)
| [**Install guide**](#install)
| [**Transformations**](#what-are-the-transforms)
| [**Documentation**](#documentation)
| [**Future Plans**](#future-plans)

**This library is currently under heavy development. If you have suggestions
on the API or use cases you'd like to be covered, please open a GitHub issue
or reach out. We'd love to hear about how you're using the library.**

`functorch` is [JAX-like](https://github.com/google/jax) composable function
transforms for PyTorch.

It aims to provide composable `vmap` and `grad` transforms that work with
PyTorch modules and PyTorch autograd with good eager-mode performance.

In addition, there is experimental functionality to trace through these
transformations using FX in order to capture the results of these transforms
ahead of time. This would allow us to compile the results of `vmap` or `grad`
to improve performance.

## Why composable function transforms?

There are a number of use cases that are tricky to do in
PyTorch today:
- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner-loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

Composing `vmap`, `grad`, `vjp`, and `jvp` transforms allows us to express the above
without designing a separate subsystem for each. This idea of composable function
transforms comes from the [JAX framework](https://github.com/google/jax).

## Install

There are two ways to install functorch:
1. functorch from source
2. functorch beta (compatible with recent PyTorch releases)

We recommend trying out the functorch beta first.

### Installing functorch from source

<details><summary>Click to expand</summary>
<p>

#### Using Colab

Follow the instructions [in this Colab notebook](https://colab.research.google.com/drive/1CrLkqIrydBYP_svnF89UUO-aQEqNPE8x?usp=sharing).

#### Locally

As of 9/21/2022, `functorch` comes installed alongside a nightly PyTorch binary.
Please install a Preview (nightly) PyTorch binary; see https://pytorch.org/
for instructions.

Once you've done that, run a quick sanity check in Python:
```py
import torch
from functorch import vmap
x = torch.randn(3)
y = vmap(torch.sin)(x)
assert torch.allclose(y, x.sin())
```

#### functorch development setup

As of 9/21/2022, `functorch` comes installed alongside PyTorch and is in the
PyTorch source tree. Please install
[PyTorch from source](https://github.com/pytorch/pytorch#from-source); you
will then be able to `import functorch`.

Run some tests to make sure all is OK:
```bash
pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -v
```

AOTAutograd has some additional optional requirements. You can install them via:
```bash
pip install networkx
```

To run functorch tests, please install our test dependencies (`expecttest`, `pyyaml`).

</p>
</details>

### Installing functorch beta (compatible with recent PyTorch releases)

<details><summary>Click to expand</summary>
<p>

#### Using Colab

Follow the instructions [here](https://colab.research.google.com/drive/1GNfb01W_xf8JRu78ZKoNnLqiwcrJrbYG#scrollTo=HJ1srOGeNCGA).

#### pip

Prerequisite: [Install PyTorch](https://pytorch.org/get-started/locally/)

```bash
pip install functorch
```

Finally, run a quick sanity check in Python:
```py
import torch
from functorch import vmap
x = torch.randn(3)
y = vmap(torch.sin)(x)
assert torch.allclose(y, x.sin())
```

</p>
</details>

## What are the transforms?

Right now, we support the following transforms:
- `grad`, `vjp`, `jvp`
- `jacrev`, `jacfwd`, `hessian`
- `vmap`

Furthermore, we have some utilities for working with PyTorch modules:
- `make_functional(model)`
- `make_functional_with_buffers(model)`

### vmap

Note: `vmap` imposes restrictions on the code that it can be used on.
For more details, please read its docstring.

`vmap(func)(*inputs)` is a transform that adds a dimension to all Tensor
operations in `func`. `vmap(func)` returns a new function that maps `func` over
some dimension (default: 0) of each Tensor in `inputs`.

`vmap` is useful for hiding batch dimensions: one can write a function `func`
that runs on single examples and then lift it to a function that takes batches
of examples with `vmap(func)`, leading to a simpler modeling experience:

```py
import torch
from functorch import vmap
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

def model(feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
result = vmap(model)(examples)
```
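
The mapped dimension can also be chosen per input through the `in_dims` argument. The following is a minimal sketch, assuming a variant of the toy model above that takes the weights as an explicit second argument (that signature is an assumption made for illustration). It maps over the examples while sharing one weight vector, and checks the result against a plain Python loop:

```python
import torch
from functorch import vmap

batch_size, feature_size = 3, 5
weights = torch.randn(feature_size)
examples = torch.randn(batch_size, feature_size)

def model(feature_vec, w):
    # Operates on a single example; knows nothing about batching.
    return feature_vec.dot(w).relu()

# Map over dim 0 of `examples`; `None` means `weights` is shared, not mapped.
batched = vmap(model, in_dims=(0, None))(examples, weights)

# Reference: apply the per-example function in a loop.
looped = torch.stack([model(e, weights) for e in examples])
assert batched.shape == (batch_size,)
assert torch.allclose(batched, looped, atol=1e-6)

# If the batch lives on dim 1 instead (feature-major layout), say so in in_dims.
batched_t = vmap(model, in_dims=(1, None))(examples.t(), weights)
assert torch.allclose(batched_t, batched, atol=1e-6)
```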

### grad

`grad(func)(*inputs)` assumes `func` returns a single-element Tensor. It computes
the gradients of the output of `func` w.r.t. `inputs[0]`.

```py
import torch
from functorch import grad
x = torch.randn([])
cos_x = grad(lambda x: torch.sin(x))(x)
assert torch.allclose(cos_x, x.cos())

# Second-order gradients
neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
assert torch.allclose(neg_sin_x, -x.sin())
```

When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
```py
import torch
from functorch import grad, vmap
batch_size, feature_size = 3, 5

def model(weights, feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()  # MSELoss

weights = torch.randn(feature_size, requires_grad=True)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)
inputs = (weights, examples, targets)
grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
```
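
One way to sanity-check per-sample gradients is to compare them against a plain Python loop that calls `grad` once per example. The sketch below assumes the same toy loss as above, with the model inlined to keep it short:

```python
import torch
from functorch import grad, vmap

batch_size, feature_size = 3, 5

def compute_loss(weights, example, target):
    # Toy linear model with ReLU, followed by a squared error.
    y = example.dot(weights).relu()
    return ((y - target) ** 2).mean()

weights = torch.randn(feature_size)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)

# Vectorized: one gradient w.r.t. `weights` per (example, target) pair.
per_sample = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, examples, targets)

# Reference: the same computation, one example at a time.
looped = torch.stack(
    [grad(compute_loss)(weights, e, t) for e, t in zip(examples, targets)]
)

assert per_sample.shape == (batch_size, feature_size)
assert torch.allclose(per_sample, looped, atol=1e-6)
```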

### vjp

The `vjp` transform applies `func` to `inputs` and returns a new function that
computes vjps given some `cotangents` Tensors. The `cotangents` must have the
same structure as `outputs`.
```py
from functorch import vjp
# Schematic usage: `func`, `inputs`, and `cotangents` are placeholders.
outputs, vjp_fn = vjp(func, *inputs)
vjps = vjp_fn(cotangents)
```
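
As a concrete illustration, here is a minimal sketch with hypothetical toy functions (`torch.sin` and an elementwise product, chosen only because their derivatives are easy to check by hand):

```python
import torch
from functorch import vjp

# Single input: the VJP of sin with a cotangent of ones is cos(x).
x = torch.randn(5)
outputs, vjp_fn = vjp(torch.sin, x)
(grad_x,) = vjp_fn(torch.ones_like(outputs))  # one entry per input
assert torch.allclose(outputs, x.sin())
assert torch.allclose(grad_x, x.cos())

# Two inputs: vjp_fn returns one VJP per input, in order.
def f(a, b):
    return a * b

y = torch.randn(5)
out, vjp_fn2 = vjp(f, x, y)
grad_a, grad_b = vjp_fn2(torch.ones_like(out))
assert torch.allclose(grad_a, y)  # d(a*b)/da = b
assert torch.allclose(grad_b, x)  # d(a*b)/db = a
```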

### jvp

The `jvp` transform computes Jacobian-vector-products and is also known as
"forward-mode AD". Unlike most other transforms, it is not a higher-order
function: it directly returns the outputs of `func(*inputs)` as well as the `jvp`s.
```py
import torch
from functorch import jvp
x = torch.randn(5)
y = torch.randn(5)
f = lambda x, y: (x * y)
_, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
assert torch.allclose(output, x + y)
```

### jacrev, jacfwd, and hessian

`jacrev(func)` returns a new function that takes in `x` and returns the
Jacobian of `func` with respect to `x`, computed using reverse-mode AD. For
example, with `torch.sin`:
```py
import torch
from functorch import jacrev
x = torch.randn(5)
jacobian = jacrev(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)
```
`jacrev` can be composed with `vmap` to produce batched Jacobians:

```py
import torch
from functorch import jacrev, vmap
x = torch.randn(64, 5)
jacobian = vmap(jacrev(torch.sin))(x)
assert jacobian.shape == (64, 5, 5)
```
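
To see what the batched result contains, it can be compared with a per-row loop and with the analytic answer. This is a small sketch; the use of `torch.diag_embed` to build the expected batch of diagonal matrices is our own choice for the check:

```python
import torch
from functorch import jacrev, vmap

x = torch.randn(64, 5)
batched = vmap(jacrev(torch.sin))(x)

# Reference: compute one 5x5 Jacobian per row and stack them.
looped = torch.stack([jacrev(torch.sin)(row) for row in x])
assert torch.allclose(batched, looped)

# sin is elementwise, so each Jacobian is diag(cos(row)).
assert torch.allclose(batched, torch.diag_embed(x.cos()))
```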

`jacfwd` is a drop-in replacement for `jacrev` that computes Jacobians using
forward-mode AD:
```py
import torch
from functorch import jacfwd
x = torch.randn(5)
jacobian = jacfwd(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)
```

Composing `jacrev` with itself or with `jacfwd` produces Hessians:
```py
import torch
from functorch import jacfwd, jacrev

def f(x):
    return x.sin().sum()

x = torch.randn(5)
hessian0 = jacrev(jacrev(f))(x)
hessian1 = jacfwd(jacrev(f))(x)
```

`hessian` is a convenience function that combines `jacfwd` and `jacrev`:
```py
import torch
from functorch import hessian

def f(x):
    return x.sin().sum()

x = torch.randn(5)
hess = hessian(f)(x)
```
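
For this particular `f`, the Hessian is known in closed form: the gradient is `cos(x)`, so the Hessian is `diag(-sin(x))`. The sketch below checks that the three ways of computing it agree with that value (the tolerance is an arbitrary choice for float32):

```python
import torch
from functorch import hessian, jacfwd, jacrev

def f(x):
    return x.sin().sum()

x = torch.randn(5)
expected = torch.diag(-x.sin())  # analytic Hessian of sum(sin(x))

hess = hessian(f)(x)
hess_rev_rev = jacrev(jacrev(f))(x)
hess_fwd_rev = jacfwd(jacrev(f))(x)

assert hess.shape == (5, 5)
assert torch.allclose(hess, expected, atol=1e-6)
assert torch.allclose(hess_rev_rev, expected, atol=1e-6)
assert torch.allclose(hess_fwd_rev, expected, atol=1e-6)
```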

### Tracing through the transformations

We can also trace through these transformations in order to capture the results
as new code using `make_fx`. There is also experimental integration with the
NNC compiler (only works on CPU for now!).

```py
import torch
from functorch import make_fx, grad
def f(x):
    return torch.sin(x).sum()
x = torch.randn(100)
grad_f = make_fx(grad(f))(x)
print(grad_f.code)

# Printed output (the exact code may differ between versions):
def forward(self, x_1):
    sin = torch.ops.aten.sin(x_1)
    sum_1 = torch.ops.aten.sum(sin, None);  sin = None
    cos = torch.ops.aten.cos(x_1);  x_1 = None
    _tensor_constant0 = self._tensor_constant0
    mul = torch.ops.aten.mul(_tensor_constant0, cos);  _tensor_constant0 = cos = None
    return mul
```
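
The object returned by `make_fx` is a callable graph module, so the traced gradient can be run on new inputs. A minimal sketch, assuming the new input has the same shape as the one used for tracing:

```python
import torch
from functorch import grad, make_fx

def f(x):
    return torch.sin(x).sum()

x = torch.randn(100)
traced = make_fx(grad(f))(x)  # trace grad(f) once, using x as the example input

# Run the captured graph on a fresh input of the same shape.
y = torch.randn(100)
traced_out = traced(y)
assert torch.allclose(traced_out, grad(f)(y))
assert torch.allclose(traced_out, y.cos())
```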

### Working with NN modules: make_functional and friends

Sometimes you may want to perform a transform with respect to the parameters
and/or buffers of an `nn.Module`. This can happen, for example, in:
- model ensembling, where all of your weights and buffers have an additional
  dimension
- per-sample-gradient computation, where you want to compute per-sample-grads
  of the loss with respect to the model parameters

Our solution to this right now is an API that, given an `nn.Module`, creates a
stateless version of it that can be called like a function.

- `make_functional(model)` returns a functional version of `model` and the
  `model.parameters()`.
- `make_functional_with_buffers(model)` returns a functional version of
  `model` and the `model.parameters()` and `model.buffers()`.

Here's an example where we compute per-sample-gradients using an `nn.Linear`
layer:

```py
import torch
from functorch import make_functional, vmap, grad

model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)

func_model, params = make_functional(model)

def compute_loss(params, data, targets):
    preds = func_model(params, data)
    return torch.mean((preds - targets) ** 2)

per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(params, data, targets)
```
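
The buffer-aware variant follows the same pattern, with the buffers passed as a separate argument. The sketch below uses a hypothetical two-layer model with a `BatchNorm1d` layer (chosen only because it has buffers) and keeps it in eval mode so that the running statistics are read but not updated:

```python
import torch
from functorch import make_functional_with_buffers

model = torch.nn.Sequential(
    torch.nn.Linear(3, 4),
    torch.nn.BatchNorm1d(4),
).eval()  # eval mode: BatchNorm uses its stored running statistics
data = torch.randn(8, 3)

func_model, params, buffers = make_functional_with_buffers(model)

# The functional model takes parameters and buffers explicitly.
out = func_model(params, buffers, data)
assert torch.allclose(out, model(data))
```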

If you're making an ensemble of models, you may find
`combine_state_for_ensemble` useful.
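
A minimal sketch of that workflow, assuming a hypothetical ensemble of identically shaped `nn.Linear` models that all see the same minibatch: `combine_state_for_ensemble` stacks the parameters and buffers along a new leading dimension, and `vmap` then maps over that dimension.

```python
import torch
from functorch import combine_state_for_ensemble, vmap

num_models, batch_size = 5, 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]
data = torch.randn(batch_size, in_features)

# Stack the state of all models; each parameter gains a leading model dimension.
fmodel, params, buffers = combine_state_for_ensemble(models)

# Map over the model dimension of params/buffers; share `data` across models.
output = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

# Each slice matches the corresponding model run on its own.
assert torch.allclose(output[0], models[0](data), atol=1e-6)
```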

## Documentation

For more documentation, see [our docs website](https://pytorch.org/functorch).

## Debugging

- `torch._C._functorch.dump_tensor`: dumps dispatch keys on the stack.
- `torch._C._functorch._set_vmap_fallback_warning_enabled(False)`: disables the
  vmap fallback warning, if the warning spam bothers you.

## Future Plans

In the end state, we'd like to upstream this into PyTorch once we iron out the
design details. To figure out the details, we need your help -- please send us
your use cases by starting a conversation in the issue tracker or trying our
project out.

## License

Functorch has a BSD-style license, as found in the [LICENSE](LICENSE) file.

## Citing functorch

If you use functorch in your publication, please cite it by using the following BibTeX entry.

```bibtex
@Misc{functorch2021,
  author =       {Horace He, Richard Zou},
  title =        {functorch: JAX-like composable function transforms for PyTorch},
  howpublished = {\url{https://github.com/pytorch/functorch}},
  year =         {2021}
}
```