# Lazy Tensor Tutorial

## Introduction

Lazy Tensor is a brand-new tracing system in PyTorch. It provides a safety guarantee that other tracing systems (jit.trace) do not: it retraces and recompiles if properties of the input change, and otherwise reuses a cached computation. It's easier to use than jit.trace and **much** easier to use than jit.script! Lazy Tensor traces both the forward and backward passes and removes many Python features present in jit-scripted and jit-traced graphs that are difficult for hardware vendors to support.

Let's kick off our introduction to Lazy Tensor with an example that illustrates the safety guarantee, since the lack of one is among the biggest usability issues of jit.trace. Suppose we'd like to jit trace the following function.

```python
import torch

def add_two_maybe(t: torch.Tensor, maybe: torch.Tensor):
    if maybe:
        return t + 2
    return t
```

You may have noticed that `add_two_maybe` contains an if statement that depends on the `maybe` input.
Let's jit trace the function with the following inputs.

```python
t = torch.ones(1)
maybe_false = torch.BoolTensor([0])
good_inputs = (t, maybe_false)
jit = torch.jit.trace(add_two_maybe, good_inputs)
# let's check that the results match with eager
assert jit(*good_inputs) == add_two_maybe(*good_inputs)
```

So far, so good! We successfully traced `add_two_maybe` into `jit` and running it gives us the same result as the original function.

Our troubles start if we change the second input and re-run the traced function.

```python
maybe_true = torch.BoolTensor([1])
assert jit(t, maybe_true) == add_two_maybe(t, maybe_true)
```

```shell
Traceback (most recent call last):
  File "/home/villedepommes/github/pytorch4/test/test_tutorial.py", line 27, in <module>
    assert jit(t, maybe_true) == add_two_maybe(t, maybe_true)
AssertionError
```

Uh oh?! What really happened here? Let's print out the graph for `jit`:

```python
print(torch.jit.last_executed_optimized_graph())

# graph(%t : Tensor,
#       %maybe : Tensor):
#   %2 : Tensor = prim::profile[profiled_type=Float(1, strides=[1], requires_grad=0, device=cpu), seen_none=0](%t)
#    = prim::profile()
#   return (%2)
```

We can see that the if statement disappeared and jit trace captured only the path that was taken during tracing, i.e. the one that returns `t` unchanged. In fact, jit trace can trace **only** aten operations. It is completely oblivious to any control flow such as `if`, `for`, or exceptions.
If this sounds unsafe to you, that's because it is!

Let's now learn how we can solve this issue with Lazy Tensors.

The first step is to move the inputs to the lazy device. The lazy device isn't a real hardware device; your code still runs either on CPU, or on GPU if you set `LTC_TS_CUDA="1"`.

The lazy device is, however, very special: it makes PyTorch "remember" every aten operation the user calls (recording it into a graph) rather than eagerly executing it. It's lazy that way ;) get it?

So, the lazy device is the API users should use to trace their models with Lazy Tensor. It's also a regular PyTorch device, which makes it a very convenient way to implement tracing on top of the PyTorch dispatcher.

First of all, we need a little bit of setup. Lazy Tensor needs a backend to actually run traced graphs. We implemented a TorchScript-based backend to give our users an end-to-end experience running their models with Lazy Tensor. It also serves as an example for hardware vendors looking to integrate with Lazy Tensor.

```python
import torch._lazy
import torch._lazy.ts_backend
torch._lazy.ts_backend.init()
```

Now, we can run our example:

```python
dev = "lazy"
t_lazy = torch.ones(1).to(dev)
maybe_false_lazy = torch.BoolTensor([0]).to(dev)
lazy_result = add_two_maybe(t_lazy, maybe_false_lazy)
```

This is pretty cool! Eventually, however, we would still like to execute our computation and access the result, wouldn't we?

There are a few ways to do it. Typically, PyTorch transparently triggers execution when the user tries to access the result, e.g., prints a tensor, moves the tensor to a non-lazy device, etc.

Let's give it a try:

```python
lazy_result = add_two_maybe(t_lazy, maybe_false_lazy)
print(lazy_result)
assert lazy_result.cpu() == add_two_maybe(t, maybe_false)
```

This works as expected! Let's try the case jit trace couldn't handle.

```python
maybe_true_lazy = torch.BoolTensor([1]).to(dev)
lazy_result = add_two_maybe(t_lazy, maybe_true_lazy)
assert lazy_result.cpu() == add_two_maybe(t, maybe_true)
```

Woo-hoo! This works too!
Unfortunately, this flexibility comes with a few downsides. Remember that backends need to translate aten ops into much lower-level operations that an accelerator understands. The translation process may be time-consuming, although it is usually well worth it.

However, if a non-trivial model is wildly dynamic, say it contains loops that run a different number of times on every call, or chains of if statements that produce a different trace on every run, the backend will spend a non-trivial amount of time compiling each trace even though each trace is used only a few times.
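
To make the recompilation cost concrete, here is a minimal, illustrative sketch (the `step` helper and the shapes are made up for this example, and the TS backend is assumed to be initialized as shown earlier): every new input shape produces a new trace, so the backend has to compile a fresh graph for each shape it encounters.

```python
import torch
import torch._lazy

# Assumes torch._lazy.ts_backend.init() has already been called, as shown earlier.

def step(x):
    # These ops are recorded into the current trace instead of running eagerly.
    y = (x * 2 + 1).sum()
    # Cut the trace here and kick off asynchronous compilation and execution.
    torch._lazy.mark_step()
    return y

# Three distinct input shapes mean three distinct traces,
# so the backend compiles three graphs instead of reusing one.
for n in (4, 8, 16):
    x = torch.randn(n).to("lazy")
    print(step(x).cpu())
```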

Alright, at this point you should have learned the main ideas behind Lazy Tensor, the most common usage patterns, and the APIs.
Hopefully, you are also as inspired and motivated about Lazy Tensor as I am.

Let's now see how we can run a full training loop with an optimizer and a backward pass! We will learn a few more important concepts and APIs.

## MNIST MLP

We will adapt the following MNIST example from [pytorch/examples](https://github.com/pytorch/examples/blob/main/mnist/main.py).

Note that you can access the full version of the script [here](https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/test_mnist.py).

First, we need to install a single dependency, `torchvision`:

```shell
pip install torchvision
```

`torchvision` comes with the MNIST dataset of handwritten digit images, which we will be using for training.

Here's our model definition:

```python
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
```

We are using a multi-layer model with two convolutions, two linear layers, and activations sandwiched in between.

Let's set up a loader that feeds the `MNIST` dataset in `train` to our model.
We are going to run the training loop for 14 epochs, which is what the original MNIST example uses.
**Note that we had to move the model to the lazy device, `Net().to(device)`. This is very similar to what we would have done had we been training this model on a GPU.**

The rest of the code is pretty standard boilerplate.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import torch._lazy
import torch._lazy.ts_backend
import torch._lazy.metrics
torch._lazy.ts_backend.init()

if __name__ == '__main__':
    bsz = 64
    device = 'lazy'
    epochs = 14
    log_interval = 10
    lr = 1
    gamma = 0.7
    train_kwargs = {'batch_size': bsz}
    # if we want to use CUDA
    if "LTC_TS_CUDA" in os.environ:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True,
                       'batch_size': bsz}
        train_kwargs.update(cuda_kwargs)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])
    dataset1 = datasets.MNIST('./data', train=True, download=True,
                        transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
    for epoch in range(1, epochs + 1):
        train(log_interval, model, device, train_loader, optimizer, epoch)
        scheduler.step()
```

The training loop in `train` also has one addition: `torch._lazy.mark_step()`, which deserves some elaboration on our part. `mark_step()` instructs Lazy Tensor to break up the current trace and start executing it asynchronously. The current trace encompasses both the forward and backward passes and provides the backends with the whole model graph without any pythonisms.
If we don't stop the trace after `optimizer.step()`, it will include two or more iterations, which is far more for the backends to chew through without much additional benefit.

Another important point is that after `mark_step()` we actually continue tracing the next iteration, and start executing the previous one at the same time! Really, nothing stops us from tracing the next iteration... and then the one after that, until we hit `if batch_idx % log_interval == 0:`, where
we actually need to wait for execution to catch up so we can print out `loss`. Remember to avoid accessing intermediate results too often if you would like to extract the maximum benefit out of Lazy Tensor.

Since every iteration looks exactly like the one before it, the TS backend will reuse the same TS compilation.

Alright, let's run it now!

```python
def train(log_interval, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad(set_to_none=True)
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        torch._lazy.mark_step()

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
```

After the script downloads the dataset, the model will be trained on the lazy device, as evidenced by the decreasing loss.

```shell
Train Epoch: 1 [0/60000 (0%)]   Loss: 2.343924
Train Epoch: 1 [640/60000 (1%)] Loss: 1.760821
Train Epoch: 1 [1280/60000 (2%)]        Loss: 0.802798
Train Epoch: 1 [1920/60000 (3%)]        Loss: 0.856164
Train Epoch: 1 [2560/60000 (4%)]        Loss: 0.568396
Train Epoch: 1 [3200/60000 (5%)]        Loss: 0.399044
Train Epoch: 1 [3840/60000 (6%)]        Loss: 0.457996
Train Epoch: 1 [4480/60000 (7%)]        Loss: 0.285104
Train Epoch: 1 [5120/60000 (9%)]        Loss: 0.193083
Train Epoch: 1 [5760/60000 (10%)]       Loss: 0.486165
Train Epoch: 1 [6400/60000 (11%)]       Loss: 0.163996
Train Epoch: 1 [7040/60000 (12%)]       Loss: 0.200323
```

Let's briefly mention a few more APIs before we wrap this up. Unfortunately, LT is still very early in its development, which means it doesn't implement every single PyTorch op out there.
In fact, we implement only about a hundred of the most common ops. What happens if a model contains an op that LT does **not** implement? Lazy Tensor transparently (from the user's point of view) breaks up the current trace, waits until all inputs to the op are computed, computes the op on some other device, moves the results back onto the lazy device, and starts a new trace.
This little wrinkle means that *sometimes* LT can **not** give the backend a whole model graph, which may have a negative impact on performance. You can check which ops LT could and could not handle for your model by adding the following to your script:

```python
torch._lazy.metrics.reset()
train(...)
print(torch._lazy.metrics.counter_names())
```

If you see any ops with the `aten::` prefix, those are ops LT does not implement natively and had to fall back on.
*Sometimes* you can replace such ops with similar ones that LT does support. More often than not, we will just have to live with it until LT matures.

Another handy API is `torch._lazy.wait_device_ops()`. Remember, we said that `mark_step()` breaks up the current trace and kicks off the computation asynchronously? If there are no blocking operations downstream, such as `print`, `item()`, or `to`, LT will happily continue tracing.
If you would like to time exactly how long computation and tracing took for some model, without including device transfers or printing, you can stick `torch._lazy.wait_device_ops()` into your script with `time.perf_counter()` right after it. Don't forget another `time.perf_counter()` before the trace starts!
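
For example, here is a minimal timing sketch along those lines (it assumes `model` and `data` already live on the lazy device, as in the training loop above):

```python
import time

import torch._lazy

start = time.perf_counter()      # before the trace starts
output = model(data)             # traced lazily; nothing has executed yet
torch._lazy.mark_step()          # cut the trace and kick off asynchronous execution
torch._lazy.wait_device_ops()    # block until the device has finished all queued work
elapsed = time.perf_counter() - start
print(f"tracing + computation took {elapsed:.4f}s")
```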

This concludes our brief introduction to LT. Hopefully, you'll remember the main takeaways:

* Backends prefer bigger graphs, preferably including both the forward and backward passes, as they offer ample opportunity for performance optimizations.
* It's really tricky to produce such graphs without overburdening the user too much. Think torch.jit.script, torch.jit.trace! Also, think ifs, fors, "Lions, and Tigers, and Bears, Oh My!" We digress.

Please give LT a try and tell us what you think on GitHub! We are **eager, not lazy** (haha!) to hear from you!