# JIT scripting & Autocast

<!-- @import "[TOC]" {cmd="toc" depthFrom=2 depthTo=6 orderedList=false} -->

<!-- code_chunk_output -->

- [Overview](#overview)
- [Usage](#usage)
- [Known limitations](#known-limitations)
  - [Diagnostics](#diagnostics)
  - [Autocast decorators](#autocast-decorators)
  - [Autocast argument must be a compile-time constant](#autocast-argument-must-be-a-compile-time-constant)
  - [Uncommon autocast usage patterns may not be supported](#uncommon-autocast-usage-patterns-may-not-be-supported)
  - [Limited support for promote autocast policy](#limited-support-for-promote-autocast-policy)
  - [Missing autocast policies](#missing-autocast-policies)
  - [Mixing eager mode and scripting autocast](#mixing-eager-mode-and-scripting-autocast)
  - [Mixing tracing and scripting autocast (script calling traced)](#mixing-tracing-and-scripting-autocast-script-calling-traced)
  - [Mixing tracing and scripting autocast (traced calling script)](#mixing-tracing-and-scripting-autocast-traced-calling-script)
  - [Disabling eager autocast with scripted autocast](#disabling-eager-autocast-with-scripted-autocast)
- [References](#references)

<!-- /code_chunk_output -->

## Overview

[Autocast][2] (aka Automatic Mixed Precision) is an optimization that helps
take advantage of the storage and performance benefits of narrow types
(float16) while preserving the additional range and numerical precision of
float32.

The JIT support for autocast is subject to different constraints compared to
the eager mode implementation (mostly because TorchScript is statically
typed), and this document lists the known limitations.

## Usage

Explicit `with autocast()` scopes are supported inside scripted functions and
modules (subject to the limitations described below):

```python
import torch
from torch.cuda.amp import autocast

@torch.jit.script
def func(a, b):
    with autocast():
        return torch.mm(a, b)

a_float32 = torch.rand((8, 8), dtype=torch.float32, device="cuda")
b_float32 = torch.rand((8, 8), dtype=torch.float32, device="cuda")
result = func(a_float32, b_float32)
print(result.dtype) # expecting torch.float16
```

## Known limitations

This section documents the current set of known limitations. Ideally this list
will shrink as we advance with the design and implementation, although some of
the limitations are tied to fundamental TorchScript aspects that are not easy
to change.

> One important goal is to avoid surprises (e.g. autocast annotations
> silently ignored) and to report sensible diagnostics when something deviates
> from eager mode behavior.
>
> Please [report](https://github.com/csarofeen/pytorch/issues/new/choose) any
> issues not covered here.

#### Diagnostics

The current Autocast/JIT diagnostics should be improved:
- Some errors are not specific enough or not actionable
- Not all the errors point to the Python source location

#### Autocast decorators

Using `@autocast` as a decorator is not currently supported in script mode (a
diagnostic will be emitted):

```python
import torch
from torch.cpu.amp import autocast

@autocast(enabled=True)
def helper(x):
    ...

@torch.jit.script
def foo(x):
    return helper(x) # not supported
```

Another example:

```python
import torch
from torch.cpu.amp import autocast

@torch.jit.script
@autocast() # not supported
def foo(a, b, c, d):
    ...
```
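
Where eager-mode code relies on the decorator form, a scripted version can
usually move the autocast region into an explicit `with autocast()` scope
instead, as shown in the Usage section above. A minimal sketch of that rewrite
(the function name, shapes, and CUDA device are illustrative assumptions):

```python
import torch
from torch.cuda.amp import autocast

# instead of decorating the helper, wrap its body in an explicit autocast
# scope, which script mode does support
@torch.jit.script
def helper_with_scope(x, y):
    with autocast(enabled=True):
        return torch.mm(x, y)

a = torch.rand((8, 8), dtype=torch.float32, device="cuda")
b = torch.rand((8, 8), dtype=torch.float32, device="cuda")
print(helper_with_scope(a, b).dtype)  # expecting torch.float16
```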

#### Autocast argument must be a compile-time constant

Runtime values for the autocast `enabled` argument are not supported; the
argument must be a compile-time constant:

```python
import torch
from torch.cpu.amp import autocast

@torch.jit.script
def fn(a, b, use_amp: bool):
    # runtime values for autocast enable argument are not supported
    with autocast(enabled=use_amp):
        return torch.mm(a, b)
```

#### Uncommon autocast usage patterns may not be supported

```python
import torch
from torch.cpu.amp import autocast

@torch.jit.script
def fn(a, b, c, d):
    with autocast(enabled=True) as autocast_instance: # not supported
        ...
    with autocast_instance:
        ...
```

#### Limited support for promote autocast policy

For some operations, autocast needs to [promote to the widest argument type][3].
When the concrete types are not available, the current implementation will
conservatively inject a promotion even when it may not be needed.
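
For reference, the sketch below illustrates the eager-mode promote policy (a
CUDA device is assumed; `torch.addcmul` is used because it is listed among the
[ops that promote to the widest input type][3]). Under scripting, the same
promotion may be injected conservatively when the dtypes are not statically
known:

```python
import torch
from torch.cuda.amp import autocast

t_fp32 = torch.rand((8, 8), device="cuda", dtype=torch.float32)
t_fp16 = torch.rand((8, 8), device="cuda", dtype=torch.float16)

with autocast():
    # mixed float16/float32 inputs are promoted to float32 before the op runs
    result = torch.addcmul(t_fp32, t_fp16, t_fp16)

print(result.dtype)  # expecting torch.float32
```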

#### Missing autocast policies

Also due to the lack of concrete dtype information, a few specialized
autocast policies are not yet supported with JIT scripting:
- [CastPolicy::fp32_append_dtype][5]

#### Mixing tracing and scripting autocast (script calling traced)

Calling a traced function from a scripted one mostly works, except for the case
where the traced part uses `autocast(False)`. After tracing, the `autocast` is
stripped from the TorchScript IR, so it is effectively ignored:

> This is one known limitation where we don't have a way to emit a diagnostic!

```python
import torch
from torch.cpu.amp import autocast

def helper(a, b):
    with autocast(enabled=False):
        return torch.mm(a, b) * 2.0

# example inputs for tracing
x = torch.rand((8, 8), dtype=torch.float32)
y = torch.rand((8, 8), dtype=torch.float32)
traced = torch.jit.trace(helper, (x, y))

@torch.jit.script
def fn(a, b):
    with autocast(enabled=True):
        return traced(a, b)
```

#### Mixing tracing and scripting autocast (traced calling script)

Calling a scripted function from a trace is similar to calling the scripted
function from eager mode:

```python
import torch
from torch.cpu.amp import autocast

@torch.jit.script
def fn(a, b):
    return torch.mm(a, b)

def traced(a, b):
    with autocast(enabled=True):
        return fn(a, b)

# example inputs for tracing
x = torch.rand((8, 8), dtype=torch.float32)
y = torch.rand((8, 8), dtype=torch.float32)

# running TorchScript with Autocast enabled is not supported
torch.jit.trace(traced, (x, y))
```

#### Disabling eager autocast with scripted autocast

If eager-mode autocast is enabled and we try to disable autocasting from
within a scripted function, autocasting will still occur:

```python
import torch
from torch.cuda.amp import autocast

@torch.jit.script
def fn(a, b):
    with autocast(enabled=False):
        return torch.mm(a, b)

x = torch.rand((2, 2), device='cuda', dtype=torch.float)
y = torch.rand((2, 2), device='cuda', dtype=torch.float)

# this will print a half-precision dtype
with autocast(enabled=True):
    print(fn(x, y).dtype)
```

## References

- [torch.cuda.amp Package][1]
- [Automatic Mixed Precision - Tutorial](https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html)
- [Automatic Mixed Precision - Examples](https://pytorch.org/docs/stable/notes/amp_examples.html)

[1]: https://pytorch.org/docs/stable/amp.html
[2]: https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/
[3]: https://pytorch.org/docs/stable/amp.html#ops-that-promote-to-the-widest-input-type
[4]: https://github.com/csarofeen/pytorch/blob/4d8575604ad9fa5fdfc21037490a041d8d43bcae/aten/src/ATen/autocast_mode.cpp#L94
[5]: https://github.com/csarofeen/pytorch/blob/4d8575604ad9fa5fdfc21037490a041d8d43bcae/aten/src/ATen/autocast_mode.cpp#L99
[6]: https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html#adding-autocast