{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "903e2f76",
   "metadata": {},
   "source": [
    "# Whirlwind Tour\n",
    "\n",
    "<a href=\"https://colab.research.google.com/github/pytorch/pytorch/blob/master/functorch/notebooks/whirlwind_tour.ipynb\">\n",
    "  <img style=\"width: auto\" src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>\n",
    "\n",
    "## What is functorch?\n",
    "\n",
    "functorch is a library for [JAX](https://github.com/google/jax)-like composable function transforms in PyTorch.\n",
    "- A \"function transform\" is a higher-order function that accepts a numerical function and returns a new function that computes a different quantity.\n",
    "- functorch has auto-differentiation transforms (`grad(f)` returns a function that computes the gradient of `f`), a vectorization/batching transform (`vmap(f)` returns a function that computes `f` over batches of inputs), and others.\n",
    "- These function transforms can compose with each other arbitrarily. For example, composing `vmap(grad(f))` computes a quantity called per-sample-gradients that stock PyTorch cannot efficiently compute today.\n",
    "\n",
21    "Furthermore, we also provide an experimental compilation transform in the `functorch.compile` namespace. Our compilation transform, named AOT (ahead-of-time) Autograd, returns to you an [FX graph](https://pytorch.org/docs/stable/fx.html) (that optionally contains a backward pass), of which compilation via various backends is one path you can take.\n",
22    "\n",
23    "\n",
24    "## Why composable function transforms?\n",
25    "There are a number of use cases that are tricky to do in PyTorch today:\n",
26    "- computing per-sample-gradients (or other per-sample quantities)\n",
27    "- running ensembles of models on a single machine\n",
28    "- efficiently batching together tasks in the inner-loop of MAML\n",
29    "- efficiently computing Jacobians and Hessians\n",
30    "- efficiently computing batched Jacobians and Hessians\n",
31    "\n",
32    "Composing `vmap`, `grad`, `vjp`, and `jvp` transforms allows us to express the above without designing a separate subsystem for each.\n",
33    "\n",
34    "## What are the transforms?\n",
35    "\n",
36    "### grad (gradient computation)\n",
37    "\n",
38    "`grad(func)` is our gradient computation transform. It returns a new function that computes the gradients of `func`. It assumes `func` returns a single-element Tensor and by default it computes the gradients of the output of `func` w.r.t. to the first input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f920b923",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from functorch import grad\n",
    "x = torch.randn([])\n",
    "cos_x = grad(lambda x: torch.sin(x))(x)\n",
    "assert torch.allclose(cos_x, x.cos())\n",
    "\n",
    "# Second-order gradients\n",
    "neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)\n",
    "assert torch.allclose(neg_sin_x, -x.sin())"
   ]
  },
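  {
   "cell_type": "markdown",
   "id": "0c4d1a2b",
   "metadata": {},
   "source": [
    "`grad` differentiates w.r.t. the first input by default; its `argnums` argument selects other (or multiple) inputs. As a small sketch: for `f(x, y) = x * y`, the gradients w.r.t. `x` and `y` are `y` and `x` respectively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d5e2b3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "x, y = torch.randn([]), torch.randn([])\n",
    "# Differentiate w.r.t. both inputs at once via argnums\n",
    "dx, dy = grad(lambda x, y: x * y, argnums=(0, 1))(x, y)\n",
    "assert torch.allclose(dx, y)\n",
    "assert torch.allclose(dy, x)"
   ]
  },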
  {
   "cell_type": "markdown",
   "id": "ef3b2d85",
   "metadata": {},
   "source": [
    "### vmap (auto-vectorization)\n",
    "\n",
    "Note: `vmap` imposes restrictions on the code it can be used on. For more details, please read its docstring.\n",
    "\n",
    "`vmap(func)(*inputs)` is a transform that adds a dimension to all Tensor operations in `func`. `vmap(func)` returns a new function that maps `func` over some dimension (default: 0) of each Tensor in `inputs`.\n",
    "\n",
    "`vmap` is useful for hiding batch dimensions: one can write a function `func` that runs on a single example, then lift it to a function that takes batches of examples with `vmap(func)`, leading to a simpler modeling experience:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ebac649",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from functorch import vmap\n",
    "batch_size, feature_size = 3, 5\n",
    "weights = torch.randn(feature_size, requires_grad=True)\n",
    "\n",
    "def model(feature_vec):\n",
    "    # Very simple linear model with activation\n",
    "    assert feature_vec.dim() == 1\n",
    "    return feature_vec.dot(weights).relu()\n",
    "\n",
    "examples = torch.randn(batch_size, feature_size)\n",
    "result = vmap(model)(examples)"
   ]
  },
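  {
   "cell_type": "markdown",
   "id": "2e6f3c4d",
   "metadata": {},
   "source": [
    "Semantically, `vmap(model)(examples)` computes the same values as mapping `model` over the batch with a Python loop, just without the loop overhead. A quick sanity check:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f7a4d5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The vmapped result matches an explicit per-example loop\n",
    "expected = torch.stack([model(example) for example in examples])\n",
    "assert torch.allclose(result, expected)"
   ]
  },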
  {
   "cell_type": "markdown",
   "id": "5161e6d2",
   "metadata": {},
   "source": [
    "When composed with `grad`, `vmap` can be used to compute per-sample-gradients:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffb2fcb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functorch import grad, vmap\n",
    "batch_size, feature_size = 3, 5\n",
    "\n",
    "def model(weights, feature_vec):\n",
    "    # Very simple linear model with activation\n",
    "    assert feature_vec.dim() == 1\n",
    "    return feature_vec.dot(weights).relu()\n",
    "\n",
    "def compute_loss(weights, example, target):\n",
    "    y = model(weights, example)\n",
    "    return ((y - target) ** 2).mean()  # MSELoss\n",
    "\n",
    "weights = torch.randn(feature_size, requires_grad=True)\n",
    "examples = torch.randn(batch_size, feature_size)\n",
    "targets = torch.randn(batch_size)\n",
    "inputs = (weights, examples, targets)\n",
    "grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)"
   ]
  },
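  {
   "cell_type": "markdown",
   "id": "4a8b5e6f",
   "metadata": {},
   "source": [
    "As before, the vmapped per-sample gradients match what an explicit (but slower) loop over the batch would compute:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b9c6f7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare against calling grad once per example\n",
    "expected = torch.stack([\n",
    "    grad(compute_loss)(weights, examples[i], targets[i])\n",
    "    for i in range(batch_size)\n",
    "])\n",
    "assert torch.allclose(grad_weight_per_example, expected)"
   ]
  },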
  {
   "cell_type": "markdown",
   "id": "11d711af",
   "metadata": {},
   "source": [
    "### vjp (vector-Jacobian product)\n",
    "\n",
    "The `vjp` transform applies `func` to `inputs` and returns a new function that computes the vector-Jacobian product (vjp) given some `cotangents` Tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad48f9d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functorch import vjp\n",
    "\n",
    "inputs = torch.randn(3)\n",
    "func = torch.sin\n",
    "cotangents = (torch.randn(3),)\n",
    "\n",
    "outputs, vjp_fn = vjp(func, inputs)\n",
    "vjps = vjp_fn(*cotangents)"
   ]
  },
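  {
   "cell_type": "markdown",
   "id": "6cad7a8b",
   "metadata": {},
   "source": [
    "For elementwise `sin`, the Jacobian at `inputs` is `diag(cos(inputs))`, so the vjp is just the cotangent scaled elementwise by `cos(inputs)`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dbe8b9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# vjp returns func(inputs) plus one vjp per primal input\n",
    "assert torch.allclose(outputs, inputs.sin())\n",
    "assert torch.allclose(vjps[0], cotangents[0] * inputs.cos())"
   ]
  },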
  {
   "cell_type": "markdown",
   "id": "e0221270",
   "metadata": {},
   "source": [
    "### jvp (Jacobian-vector product)\n",
    "\n",
    "The `jvp` transform computes Jacobian-vector products and is also known as \"forward-mode AD\". Unlike most other transforms, it is not a higher-order function: it returns the outputs of `func(inputs)` as well as the jvps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3772f43",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functorch import jvp\n",
    "x = torch.randn(5)\n",
    "y = torch.randn(5)\n",
    "f = lambda x, y: (x * y)\n",
    "_, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))\n",
    "assert torch.allclose(output, x + y)"
   ]
  },
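  {
   "cell_type": "markdown",
   "id": "8ecf9cad",
   "metadata": {},
   "source": [
    "The same holds for arbitrary tangents: by the product rule, the jvp of `x * y` with tangents `(tx, ty)` is `y * tx + x * ty`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fdaadbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "tx, ty = torch.randn(5), torch.randn(5)\n",
    "_, output = jvp(f, (x, y), (tx, ty))\n",
    "assert torch.allclose(output, y * tx + x * ty)"
   ]
  },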
  {
   "cell_type": "markdown",
   "id": "7b00953b",
   "metadata": {},
   "source": [
    "### jacrev, jacfwd, and hessian\n",
    "\n",
    "The `jacrev` transform takes a function and returns a new function that, given `x`, computes the Jacobian of that function\n",
    "with respect to `x` using reverse-mode AD."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20f53be2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functorch import jacrev\n",
    "x = torch.randn(5)\n",
    "jacobian = jacrev(torch.sin)(x)\n",
    "expected = torch.diag(torch.cos(x))\n",
    "assert torch.allclose(jacobian, expected)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9007c88",
   "metadata": {},
   "source": [
    "As shown above, `jacrev` computes the Jacobian; composing it with `vmap` produces batched Jacobians:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97d6c382",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.randn(64, 5)\n",
    "jacobian = vmap(jacrev(torch.sin))(x)\n",
    "assert jacobian.shape == (64, 5, 5)"
   ]
  },
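  {
   "cell_type": "markdown",
   "id": "a0eb1fc2",
   "metadata": {},
   "source": [
    "Each per-example Jacobian is again the diagonal matrix `diag(cos(x[i]))`, so the batched result matches an explicit loop:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1fc2ad3",
   "metadata": {},
   "outputs": [],
   "source": [
    "expected = torch.stack([torch.diag(xi.cos()) for xi in x])\n",
    "assert torch.allclose(jacobian, expected)"
   ]
  },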
  {
   "cell_type": "markdown",
   "id": "cda642ec",
   "metadata": {},
   "source": [
    "`jacfwd` is a drop-in replacement for `jacrev` that computes Jacobians using forward-mode AD:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8c1dedb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functorch import jacfwd\n",
    "x = torch.randn(5)\n",
    "jacobian = jacfwd(torch.sin)(x)\n",
    "expected = torch.diag(torch.cos(x))\n",
    "assert torch.allclose(jacobian, expected)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39f85b50",
   "metadata": {},
   "source": [
    "Composing `jacrev` with itself or with `jacfwd` can produce Hessians:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e511139",
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x):\n",
    "    return x.sin().sum()\n",
    "\n",
    "x = torch.randn(5)\n",
    "hessian0 = jacrev(jacrev(f))(x)\n",
    "hessian1 = jacfwd(jacrev(f))(x)"
   ]
  },
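  {
   "cell_type": "markdown",
   "id": "c2ad3be4",
   "metadata": {},
   "source": [
    "For `f(x) = x.sin().sum()`, the Hessian is the diagonal matrix `diag(-x.sin())`, and both compositions agree with it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3be4cf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "expected = torch.diag(-x.sin())\n",
    "assert torch.allclose(hessian0, expected)\n",
    "assert torch.allclose(hessian1, expected)"
   ]
  },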
  {
   "cell_type": "markdown",
   "id": "18efdc65",
   "metadata": {},
   "source": [
    "`hessian` is a convenience function that combines `jacfwd` and `jacrev`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd1765df",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functorch import hessian\n",
    "\n",
    "def f(x):\n",
    "    return x.sin().sum()\n",
    "\n",
    "x = torch.randn(5)\n",
    "hess = hessian(f)(x)"
   ]
  },
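  {
   "cell_type": "markdown",
   "id": "e4cf5da6",
   "metadata": {},
   "source": [
    "### functorch.compile (experimental AOT Autograd)\n",
    "\n",
    "As mentioned at the start of this tour, `functorch.compile` offers an experimental AOT Autograd transform that captures FX graphs of the forward (and optionally backward) pass. Below is a minimal sketch assuming the `aot_function` API with `fw_compiler`/`bw_compiler` callbacks: each callback receives an FX `GraphModule` plus example inputs and returns the callable to run, so printing the graph's code and returning the module unchanged lets us inspect what was captured."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5da6eb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functorch.compile import aot_function\n",
    "\n",
    "def fn(a, b):\n",
    "    return (a * b).sin().sum()\n",
    "\n",
    "# A do-nothing compiler: print the captured FX graph, then run it as-is\n",
    "def print_compile_fn(fx_module, example_inputs):\n",
    "    print(fx_module.code)\n",
    "    return fx_module\n",
    "\n",
    "aot_fn = aot_function(fn, fw_compiler=print_compile_fn, bw_compiler=print_compile_fn)\n",
    "a = torch.randn(3, requires_grad=True)\n",
    "b = torch.randn(3)\n",
    "out = aot_fn(a, b)  # the compiler callbacks print the captured graphs\n",
    "out.backward()"
   ]
  },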
  {
   "cell_type": "markdown",
   "id": "b597d7ad",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "Check out our other tutorials (in the left bar) for more detailed explanations of how to apply functorch transforms for various use cases. `functorch` is very much a work in progress and we'd love to hear how you're using it -- we encourage you to start a conversation at our [issues tracker](https://github.com/pytorch/functorch) to discuss your use case."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}