# Lazy Tensor Tutorial

## Introduction

Lazy Tensor is a brand-new tracing system in PyTorch. It provides a safety guarantee that other tracing systems (`jit.trace`) do not: it retraces and recompiles if properties of the input change, and uses a cached computation otherwise. It's easier to use than `jit.trace` and **much** easier to use than `jit.script`! Lazy Tensor traces both forward and backward passes and removes many of the Python constructs present in jit-scripted and jit-traced graphs that are difficult for hardware vendors to support.

Let's kick off our introduction to Lazy Tensor with an example that illustrates the safety guarantee, since the lack of one is among the biggest usability issues of `jit.trace`. Suppose we'd like to jit trace the following function.

```python
import torch

def add_two_maybe(t: torch.Tensor, maybe: torch.Tensor):
    if maybe:
        return t + 2
    return t
```

You may have noticed that `add_two_maybe` contains an if statement that depends on the `maybe` input.
Let's jit trace the function with the following inputs.

```python
t = torch.ones(1)
maybe_false = torch.BoolTensor([0])
good_inputs = (t, maybe_false)
jit = torch.jit.trace(add_two_maybe, good_inputs)
# let's check that the results match with eager
assert jit(*good_inputs) == add_two_maybe(*good_inputs)
```

So far, so good! We successfully traced `add_two_maybe` into `jit`, and running it gives us the same result as the original function.

Our troubles start if we change the second input and re-run the traced function.

```python
maybe_true = torch.BoolTensor([1])
assert jit(t, maybe_true) == add_two_maybe(t, maybe_true)
```

```shell
Traceback (most recent call last):
  File "/home/villedepommes/github/pytorch4/test/test_tutorial.py", line 27, in <module>
    assert jit(t, maybe_true) == add_two_maybe(t, maybe_true)
AssertionError
```

Uh oh?! What really happened here? Let's print out the graph for `jit`:

```python
print(torch.jit.last_executed_optimized_graph())

# graph(%t : Tensor,
#       %maybe : Tensor):
#   %2 : Tensor = prim::profile[profiled_type=Float(1, strides=[1], requires_grad=0, device=cpu), seen_none=0](%t)
#    = prim::profile()
#   return (%2)
```

We can see that the if statement disappeared and jit trace only traced the `else` path. In fact, jit trace can trace **only** aten operations. It's completely oblivious to any control flow such as `if`, `for`, or exceptions.
If this sounds unsafe to you, that's because it is!

Let's now learn how we can solve this issue with Lazy Tensor.

The first step is to move the inputs to the lazy device. The lazy device isn't any real hardware device. Your code still runs on the CPU, or on the GPU if you set `LTC_TS_CUDA="1"`.

The lazy device is, however, very special: it makes PyTorch "remember" every aten operation the user calls (by recording it into a graph) rather than eagerly executing it. It's lazy that way ;) get it?

So, the lazy device is the API users should use to trace their models with Lazy Tensor. It's also a PyTorch device, which is a very convenient way to implement tracing on top of the PyTorch dispatcher.

First of all, we need a little bit of setup. Lazy Tensor needs a backend to actually run the traced graphs. We implemented a TorchScript-based backend to give our users an end-to-end experience running their models with Lazy Tensor. It also serves as an example for hardware vendors looking to integrate with Lazy Tensor.

```python
import torch._lazy
import torch._lazy.ts_backend
torch._lazy.ts_backend.init()
```

Now, we can run our example:

```python
dev = "lazy"
t_lazy = torch.ones(1).to(dev)
maybe_false_lazy = torch.BoolTensor([0]).to(dev)
lazy_result = add_two_maybe(t_lazy, maybe_false_lazy)
```

This is pretty cool! Eventually, however, we would still like to execute our computation and access the result, wouldn't we?

There are a few ways to do it. Typically, PyTorch transparently triggers the execution when the user tries to access the result, e.g., by printing the tensor out, moving the tensor to a non-lazy device, etc.

Let's give it a try:

```python
lazy_result = add_two_maybe(t_lazy, maybe_false_lazy)
print(lazy_result)
assert lazy_result.cpu() == add_two_maybe(t, maybe_false)
```

This works as expected! Let's try the case jit trace couldn't handle.

```python
maybe_true_lazy = torch.BoolTensor([1]).to(dev)
lazy_result = add_two_maybe(t_lazy, maybe_true_lazy)
assert lazy_result.cpu() == add_two_maybe(t, maybe_true)
```

Woo-hoo! This works too!
Unfortunately, this flexibility comes with a few downsides. Remember that backends need to translate aten ops into much lower-level operations that an accelerator understands. The translation process may be time-consuming, although it's usually well worth it.

However, if a non-trivial model is wildly dynamic, i.e. it contains loops that run a different number of times on every invocation, or back-to-back if statements that explode into a different trace every time you run the model, the backend will spend a non-trivial amount of time compiling each trace even though each one is only used a few times.

Alright, at this point you should have learned the main ideas behind Lazy Tensor, the most common usage patterns, and the APIs.
Hopefully, you are also as inspired and motivated about Lazy Tensor as I am.

Let's now see how we can run a full training loop with an optimizer and backward pass! We will learn a few more important concepts and APIs along the way.

## MNIST MLP

We will adapt the MNIST example from [pytorch/examples](https://github.com/pytorch/examples/blob/main/mnist/main.py).

Note, you can access the full version of the script [here](https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/test_mnist.py).

First, we need to install a single dependency, `torchvision`:

```
pip install torchvision
```

`torchvision` comes with the MNIST dataset of handwritten digit images, which we will be using for training.

Here's our model definition:

```python
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
```

The model has two convolutions and two linear layers, with activations and dropout sandwiched in between.
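
Before wiring up the full training loop, it can help to sanity-check the model on the lazy device with a single forward pass. The snippet below is just a sketch: it assumes the `torch._lazy` setup from the first part of the tutorial and the usual imports used by `Net` (`torch`, `torch.nn as nn`, `torch.nn.functional as F`).

```python
model = Net().to("lazy")
model.eval()                                  # disable dropout for this quick check
dummy = torch.randn(2, 1, 28, 28).to("lazy")
out = model(dummy)                            # only recorded into a graph so far, nothing has run
print(out.cpu().shape)                        # moving to CPU triggers execution: torch.Size([2, 10])
```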

Let's set up a loader that will feed the `MNIST` training set to our model.
We are going to run the training loop for 14 epochs, which is what the original MNIST example uses.
**Note that we had to move the model to the lazy device, `Net().to(device)`. This is very similar to what we would have done had we been training this model on a GPU.**

The rest of the code is pretty standard boilerplate.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import torch._lazy
import torch._lazy.ts_backend
import torch._lazy.metrics
torch._lazy.ts_backend.init()

if __name__ == '__main__':
    bsz = 64
    device = 'lazy'
    epochs = 14
    log_interval = 10
    lr = 1
    gamma = 0.7
    train_kwargs = {'batch_size': bsz}
    # if we want to use CUDA
    if "LTC_TS_CUDA" in os.environ:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True,
                       'batch_size': bsz}
        train_kwargs.update(cuda_kwargs)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset1 = datasets.MNIST('./data', train=True, download=True,
                              transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
    for epoch in range(1, epochs + 1):
        train(log_interval, model, device, train_loader, optimizer, epoch)
        scheduler.step()
```

The training loop in `train` has one addition: `torch._lazy.mark_step()`, which deserves some elaboration. `mark_step()` instructs Lazy Tensor to break up the current trace and start executing it asynchronously. The current trace encompasses both the forward and backward passes, so the backends get the whole model graph without any pythonisms.
If we didn't stop the trace after `optimizer.step()`, it would grow to include two or more iterations, which is a lot more for the backends to chew through without much additional benefit.

Another important point is that after `mark_step()` we actually continue tracing the next iteration... and start executing the previous one at the same time! Really, nothing stops us from tracing the next iteration, and then the one after that, until we hit `if batch_idx % log_interval == 0:`, where we actually need to wait for execution to catch up so we can print out `loss`. Remember to avoid accessing intermediate results too often if you would like to extract the maximum benefit out of Lazy Tensor.

Since every iteration looks exactly like the one before it, the TS backend will be re-using the same TS compilation.

Alright, let's run it now!

```python
def train(log_interval, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad(set_to_none=True)
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        torch._lazy.mark_step()

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
```
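
You can convince yourself of this asynchrony by timing a single step: `mark_step()` returns quickly because it only cuts the trace and queues the execution, while reading the loss blocks until the computation has actually finished. The snippet below is only a sketch; it reuses `model`, `optimizer`, `train_loader`, and `device` from the script above, and the very first step also pays a one-time compilation cost.

```python
import time

data, target = next(iter(train_loader))
data, target = data.to(device), target.to(device)

start = time.perf_counter()
optimizer.zero_grad(set_to_none=True)
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
torch._lazy.mark_step()   # cut the trace and kick off asynchronous execution
print(f"trace + dispatch took {time.perf_counter() - start:.4f}s")

start = time.perf_counter()
value = loss.item()       # blocks until the queued computation is done
print(f"sync on loss.item() took {time.perf_counter() - start:.4f}s (loss={value:.4f})")
```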

After the script downloads the dataset, the model will be trained on the lazy device, as evidenced by the decreasing loss.

```shell
Train Epoch: 1 [0/60000 (0%)]     Loss: 2.343924
Train Epoch: 1 [640/60000 (1%)]   Loss: 1.760821
Train Epoch: 1 [1280/60000 (2%)]  Loss: 0.802798
Train Epoch: 1 [1920/60000 (3%)]  Loss: 0.856164
Train Epoch: 1 [2560/60000 (4%)]  Loss: 0.568396
Train Epoch: 1 [3200/60000 (5%)]  Loss: 0.399044
Train Epoch: 1 [3840/60000 (6%)]  Loss: 0.457996
Train Epoch: 1 [4480/60000 (7%)]  Loss: 0.285104
Train Epoch: 1 [5120/60000 (9%)]  Loss: 0.193083
Train Epoch: 1 [5760/60000 (10%)] Loss: 0.486165
Train Epoch: 1 [6400/60000 (11%)] Loss: 0.163996
Train Epoch: 1 [7040/60000 (12%)] Loss: 0.200323
```

Let's briefly mention a few more APIs before we wrap up. Unfortunately, LT is still very early in its development, which means it doesn't implement every single PyTorch op out there.
In fact, we implement only about a hundred of the most common ops. What happens if a model contains an op that LT does **not** implement? Lazy Tensor transparently (from the user's point of view) breaks up the current trace, waits until all inputs to the op are computed, computes the op on some other device, moves the results back onto the lazy device, and starts a new trace.
This little wrinkle means that *sometimes* LT can **not** give the backend a whole-model graph, which may have a negative impact on performance. You can see which ops LT could and couldn't handle for your model by adding the following around your training call:

```python
torch._lazy.metrics.reset()
train(...)
print(torch._lazy.metrics.counter_names())
```

If you see any counters with the `aten::` prefix, those are ops LT doesn't implement and that fell back to eager execution.
*Sometimes* you can replace such ops with similar ones that LT does support. More often than not, we will have to just live with it until LT matures.

Another handy API is `torch._lazy.wait_device_ops()`. Remember, we said that `mark_step()` breaks up the current trace and kicks off the computation asynchronously? If there are no blocking operations downstream, such as `print`, `item()`, or `to`, LT will happily continue tracing.
If you would like to time exactly how long computation and tracing took for some model, without including device transfers or printing, you can call `torch._lazy.wait_device_ops()` and take a `time.perf_counter()` reading right after it. Don't forget another `time.perf_counter()` reading before the trace starts! (See the sketch at the end of this tutorial.)

This concludes our brief introduction to LT. Hopefully, you'll remember the main takeaways:

* Backends prefer bigger graphs, ideally including both forward and backward passes, as they offer ample opportunity for performance optimizations.
* It's really tricky to produce such graphs without overburdening the user too much. Think `torch.jit.script` and `torch.jit.trace`! Also think ifs, fors, "Lions, and Tigers, and Bears, Oh My!" But we digress.


Please give LT a try and tell us what you think on GitHub! We are **eager, not lazy** (haha!) to hear from you!
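
Here is the timing sketch promised above. It reuses the MNIST objects (`model`, `train_loader`, `optimizer`, `device`, `log_interval`) from the training script; note that `train` as written still syncs and prints every `log_interval` batches, so for a pure tracing-plus-computation measurement you may want to drop that logging.

```python
import time

import torch._lazy

start = time.perf_counter()                  # take a reading before tracing starts
train(log_interval, model, device, train_loader, optimizer, epoch=1)
torch._lazy.wait_device_ops()                # wait for all queued computations to finish
elapsed = time.perf_counter() - start
print(f"tracing + computation for one epoch took {elapsed:.2f}s")
```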