# JIT scripting & Autocast

<!-- @import "[TOC]" {cmd="toc" depthFrom=2 depthTo=6 orderedList=false} -->

<!-- code_chunk_output -->

- [Overview](#overview)
- [Usage](#usage)
- [Known limitations](#known-limitations)
    - [Diagnostics](#diagnostics)
    - [Autocast decorators](#autocast-decorators)
    - [Autocast argument must be a compile-time constant](#autocast-argument-must-be-a-compile-time-constant)
    - [Uncommon autocast usage patterns may not be supported](#uncommon-autocast-usage-patterns-may-not-be-supported)
    - [Limited support for promote autocast policy](#limited-support-for-promote-autocast-policy)
    - [Missing autocast policies](#missing-autocast-policies)
    - [Mixing tracing and scripting autocast (script calling traced)](#mixing-tracing-and-scripting-autocast-script-calling-traced)
    - [Mixing tracing and scripting autocast (traced calling script)](#mixing-tracing-and-scripting-autocast-traced-calling-script)
    - [Disabling eager autocast with scripted autocast](#disabling-eager-autocast-with-scripted-autocast)
- [References](#references)

<!-- /code_chunk_output -->

## Overview

[Autocast][2] (aka Automatic Mixed Precision) is an optimization which helps
take advantage of the storage and performance benefits of narrow types
(float16) while preserving the additional range and numerical precision of
float32.

The JIT support for autocast is subject to different constraints compared to the
eager mode implementation (mostly related to the fact that TorchScript is
statically typed), and this document attempts to list the known limitations.

## Usage

Explicit `with autocast()` scopes are supported inside scripted functions and
modules (subject to the limitations described below):

```python
import torch
from torch.cuda.amp import autocast

@torch.jit.script
def func(a, b):
    with autocast():
        return torch.mm(a, b)

a_float32 = torch.rand((8, 8), dtype=torch.float32, device="cuda")
b_float32 = torch.rand((8, 8), dtype=torch.float32, device="cuda")
result = func(a_float32, b_float32)
print(result.dtype) # expecting torch.float16
```
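
The same pattern works inside a scripted module's `forward`; a minimal sketch
(the `MatMul` module is a hypothetical example, not a PyTorch API):

```python
import torch
from torch.cuda.amp import autocast

class MatMul(torch.nn.Module):
    def forward(self, a, b):
        # the autocast scope is scripted along with the rest of forward()
        with autocast():
            return torch.mm(a, b)

scripted = torch.jit.script(MatMul())

a = torch.rand((8, 8), dtype=torch.float32, device="cuda")
b = torch.rand((8, 8), dtype=torch.float32, device="cuda")
print(scripted(a, b).dtype) # expecting torch.float16
```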

## Known limitations

This section documents the current set of known limitations. Ideally this list
will shrink as we advance with the design and implementation, although some of
the limitations are related to fundamental TorchScript aspects that are not easy
to change.

> One important goal is to avoid surprises (e.g. autocast annotations
> silently ignored) and to report sensible diagnostics when something deviates
> from eager mode behavior.
>
> Please [report](https://github.com/csarofeen/pytorch/issues/new/choose) any
> issues not covered here.

#### Diagnostics

The current Autocast/JIT diagnostics should be improved:
- Some errors are not specific enough or not actionable
- Not all the errors point to the Python source location

#### Autocast decorators

Using `@autocast` is not currently supported in script mode (a diagnostic
will be emitted):

```python
import torch
from torch.cpu.amp import autocast

@autocast(enabled=True)
def helper(x):
    ...

@torch.jit.script
def foo(x):
    return helper(x) # not supported
```

Another example:

```python
import torch
from torch.cpu.amp import autocast

@torch.jit.script
@autocast() # not supported
def foo(a, b, c, d):
    ...
```
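
The explicit `with autocast()` form described under [Usage](#usage) is the
supported alternative; a minimal sketch, relying on TorchScript's recursive
compilation of called functions:

```python
import torch
from torch.cpu.amp import autocast

def helper(x):
    # an explicit scope instead of the @autocast decorator
    with autocast(enabled=True):
        return torch.mm(x, x)

@torch.jit.script
def foo(x):
    return helper(x) # helper is compiled recursively, scope included
```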

#### Autocast argument must be a compile-time constant

```python
import torch
from torch.cpu.amp import autocast

@torch.jit.script
def fn(a, b, use_amp: bool):
    # runtime values for the autocast enable argument are not supported
    with autocast(enabled=use_amp):
        return torch.mm(a, b)
```
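
Since only compile-time constants are accepted, one workaround is to resolve
the flag in eager mode and dispatch to a scripted variant that uses a literal.
A sketch; the `fn_amp`/`fn_no_amp` split is just one possible structure, not a
prescribed pattern:

```python
import torch
from torch.cpu.amp import autocast

@torch.jit.script
def fn_amp(a, b):
    with autocast(enabled=True): # literal: compile-time constant
        return torch.mm(a, b)

@torch.jit.script
def fn_no_amp(a, b):
    with autocast(enabled=False): # literal: compile-time constant
        return torch.mm(a, b)

def fn(a, b, use_amp: bool):
    # branch on the runtime flag outside of TorchScript
    return fn_amp(a, b) if use_amp else fn_no_amp(a, b)
```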

#### Uncommon autocast usage patterns may not be supported

```python
import torch
from torch.cpu.amp import autocast

@torch.jit.script
def fn(a, b, c, d):
    with autocast(enabled=True) as autocast_instance: # not supported
        ...
        with autocast_instance:
            ...
```
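
In contrast, opening a fresh `autocast` scope at each point of use is the
common, supported pattern; a minimal sketch, assuming nested scopes behave as
they do in eager mode:

```python
import torch
from torch.cpu.amp import autocast

@torch.jit.script
def fn(a, b):
    with autocast(enabled=True):
        x = torch.mm(a, b)
        # open a new scope instead of reusing a saved instance
        with autocast(enabled=False):
            y = torch.mm(a, b)
    return x, y
```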

#### Limited support for promote autocast policy

For some operations, autocast needs to [promote to the widest argument type][3].
When the concrete types are not available, the current implementation will
conservatively inject a promotion even when it may not be needed.
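
For reference, this is the eager-mode promote behavior the policy mirrors,
illustrated with `torch.addcmul` (one of the [ops that promote to the widest
input type][3]); a small sketch assuming a CUDA device:

```python
import torch
from torch.cuda.amp import autocast

a = torch.rand((8,), device="cuda", dtype=torch.float16)
b = torch.rand((8,), device="cuda", dtype=torch.float32)
c = torch.rand((8,), device="cuda", dtype=torch.float32)

with autocast():
    # mixed float16/float32 arguments are promoted to the widest type
    print(torch.addcmul(a, b, c).dtype) # torch.float32
```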

#### Missing autocast policies

Also related to the lack of concrete dtype availability, a few specialized
autocast policies are not yet supported with JIT scripting:
- [CastPolicy::fp32_append_dtype][5]

#### Mixing tracing and scripting autocast (script calling traced)

Calling a traced function from a scripted one mostly works, except for the case
where the traced part uses `autocast(enabled=False)`. After tracing, the
`autocast` is stripped from the TorchScript IR so it's effectively ignored:

> This is one known limitation where we don't have a way to emit a diagnostic!

```python
import torch
from torch.cpu.amp import autocast

def helper(a, b):
    with autocast(enabled=False):
        return torch.mm(a, b) * 2.0

# example inputs for tracing
x = torch.rand((2, 2), dtype=torch.float32)
y = torch.rand((2, 2), dtype=torch.float32)

traced = torch.jit.trace(helper, (x, y))

@torch.jit.script
def fn(a, b):
    with autocast(enabled=True):
        return traced(a, b)
```
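
In this example the `enabled=False` scope inside `helper` disappears during
tracing, so the `torch.mm` call still runs under the autocast-enabled scope
of `fn`.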

#### Mixing tracing and scripting autocast (traced calling script)

Calling a scripted function from a trace is similar to calling the scripted
function from eager mode:

```python
import torch
from torch.cpu.amp import autocast

@torch.jit.script
def fn(a, b):
    return torch.mm(a, b)

def traced(a, b):
    with autocast(enabled=True):
        return fn(a, b)

# example inputs for tracing
x = torch.rand((2, 2), dtype=torch.float32)
y = torch.rand((2, 2), dtype=torch.float32)

# running TorchScript with Autocast enabled is not supported
torch.jit.trace(traced, (x, y))
```

#### Disabling eager autocast with scripted autocast

If eager-mode autocast is enabled and we try to disable autocasting from
within a scripted function, autocasting will still occur.

```python
import torch
from torch.cuda.amp import autocast

@torch.jit.script
def fn(a, b):
    with autocast(enabled=False):
        return torch.mm(a, b)

x = torch.rand((2, 2), device='cuda', dtype=torch.float)
y = torch.rand((2, 2), device='cuda', dtype=torch.float)

# this prints torch.float16, even though fn tries to disable autocast
with autocast(enabled=True):
    print(fn(x, y).dtype)
```
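
One possible workaround (continuing the example above) is to disable autocast
at the eager-mode call site rather than inside the scripted function; a sketch:

```python
# disabling autocast around the scripted call works as expected
with autocast(enabled=False):
    print(fn(x, y).dtype) # torch.float32
```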

## References

- [torch.cuda.amp Package][1]
- [Automatic Mixed Precision - Tutorial](https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html)
- [Automatic Mixed Precision - Examples](https://pytorch.org/docs/stable/notes/amp_examples.html)

[1]: https://pytorch.org/docs/stable/amp.html
[2]: https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/
[3]: https://pytorch.org/docs/stable/amp.html#ops-that-promote-to-the-widest-input-type
[4]: https://github.com/csarofeen/pytorch/blob/4d8575604ad9fa5fdfc21037490a041d8d43bcae/aten/src/ATen/autocast_mode.cpp#L94
[5]: https://github.com/csarofeen/pytorch/blob/4d8575604ad9fa5fdfc21037490a041d8d43bcae/aten/src/ATen/autocast_mode.cpp#L99
[6]: https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html#adding-autocast