Dynamo Overview
===============

Before you read this section, read :ref:`torch.compiler_overview`.

TorchDynamo (or simply Dynamo) is a Python-level Just-In-Time (JIT) compiler designed to make
unmodified PyTorch programs faster. Dynamo hooks into the frame evaluation
API in CPython (`PEP 523 <https://peps.python.org/pep-0523/>`__) to
dynamically modify Python bytecode right before it is executed. It
rewrites Python bytecode to extract sequences of PyTorch
operations into an `FX Graph <https://pytorch.org/docs/stable/fx.html>`__
which is then compiled with a customizable backend.
It creates this FX Graph through bytecode analysis and is designed to
mix Python execution with compiled backends to get the best of both
worlds — usability and performance.

Dynamo makes it easy to experiment with different compiler
backends to make PyTorch code faster with a single-line decorator,
``torch._dynamo.optimize()``, which is wrapped for convenience by ``torch.compile()``.

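
For instance, a minimal sketch of this workflow (the function and inputs here
are only illustrative) looks like:

.. code-block:: python

   import torch

   def fn(x):
       return torch.sin(x) + torch.cos(x)

   # torch.compile wraps torch._dynamo.optimize; "inductor" is the default backend.
   compiled_fn = torch.compile(fn)
   out = compiled_fn(torch.randn(8))
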

The following diagram demonstrates how PyTorch works with ``torch.compile``
and without it:

.. image:: _static/img/dynamo/TorchDynamo.png


`TorchInductor` is one of the backends supported by Dynamo; it compiles the
captured `FX Graph <https://pytorch.org/docs/stable/fx.html>`__ into
`Triton <https://github.com/openai/triton>`__ kernels for GPUs or
`C++/OpenMP <https://www.openmp.org/>`__ code for CPUs. We have a
`training performance dashboard <https://github.com/pytorch/torchdynamo/issues/681#issuecomment-1233828468>`__
that provides a performance comparison of different training backends. You can read
more in the `TorchInductor post on PyTorch
dev-discuss <https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747>`__.

For an in-depth overview, read the sections below, watch the deep-dive video,
and check out the dev-discuss topics.

   * `Dynamo deep-dive video <https://www.youtube.com/watch?v=egZB5Uxki0I>`__
   * `dev-discuss topics <https://dev-discuss.pytorch.org/search?q=TorchDynamo%20order%3Alatest>`__

Dynamo Internals
~~~~~~~~~~~~~~~~
**Authors**: `Jason Ansel <https://github.com/jansel>`_ and `Kaichao You <https://github.com/youkaichao>`_

This section will go over some of the Dynamo internals and will
demonstrate how Dynamo works under the hood.

What is a guard?
----------------

Dynamo operates just-in-time and specializes graphs based on
dynamic properties. Below is a basic example of how to use Dynamo.
One can decorate a function or a method using ``torchdynamo.optimize`` to enable
Dynamo optimization:

.. code-block:: python

   from typing import List
   import torch
   from torch import _dynamo as torchdynamo

   def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
       print("my_compiler() called with FX graph:")
       gm.graph.print_tabular()
       return gm.forward  # return a python callable

   @torchdynamo.optimize(my_compiler)
   def toy_example(a, b):
       x = a / (torch.abs(a) + 1)
       if b.sum() < 0:
           b = b * -1
       return x * b

   for _ in range(100):
       toy_example(torch.randn(10), torch.randn(10))


Running this example causes Dynamo to capture and compile graphs for
``toy_example`` and to attach guards to each compiled graph. For example, the
first graph has the following guards:

::

   GUARDS:
   hasattr(L['a'], '_dynamo_dynamic_indices') == False
   hasattr(L['b'], '_dynamo_dynamic_indices') == False
   utils_device.CURRENT_DEVICE == None
   ___skip_backend_check() or ___current_backend() == ___lookup_backend(140355900538256)
   check_tensor(L['a'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[10], stride=[1])
   check_tensor(L['b'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[10], stride=[1])


If any of those guards fail, the graph will be recaptured and
recompiled. The interesting guard there is ``check_tensor``, which
checks the following ``torch.Tensor`` properties:

- Python class of the tensor (tensor subclassing, etc)
- dtype
- device
- requires_grad
- dispatch_key (with thread-local includes/excludes applied)
- ndim
- sizes\*
- strides\*

The full specialization mode allows the backend compiler to assume an
entirely static graph. Unfortunately, most backends require this.
Operators which return dynamic shapes will trigger a graph break when
not in dynamic shape mode.
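
Continuing the example above, a quick sketch of these guards in action: calling
the compiled function with inputs of a different size fails the ``check_tensor``
guard, so Dynamo recompiles and ``my_compiler`` is invoked again (the shapes here
are arbitrary).

.. code-block:: python

   toy_example(torch.randn(10), torch.randn(10))  # reuses the cached graphs
   toy_example(torch.randn(20), torch.randn(20))  # size guard fails -> recompilation
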

What is Dynamo doing?
---------------------

If you want to better understand what Dynamo is doing, you can run your code with
the following environment variable set:

::

   TORCH_LOGS="+dynamo,guards,bytecode"

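
Assuming a reasonably recent PyTorch, the same logs can also be enabled from
within Python via ``torch._logging.set_logs`` (a sketch, equivalent to the
environment variable above):

.. code-block:: python

   import logging
   import torch

   # Enable verbose Dynamo logs plus the "guards" and "bytecode" artifacts.
   torch._logging.set_logs(dynamo=logging.DEBUG, guards=True, bytecode=True)
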

If you are not familiar with Python bytecode, you can add a decompiler hook
to decompile the bytecode into human-readable source code. One available
tool is `depyf <https://github.com/youkaichao/depyf>`__. If you don't have
``depyf`` already installed, run ``pip install depyf``. Then, add the
following code to install decompilation hooks before you run any code.

.. code-block:: python

   import depyf
   depyf.install()

With these settings, you get useful (but spammy) printouts.

For example, the printouts for the first graph of ``toy_example``
are:

::

   __compiled_fn_0 <eval_with_key>.1
   opcode         name     target                                                  args              kwargs
   -------------  -------  ------------------------------------------------------  ----------------  --------
   placeholder    a        a                                                       ()                {}
   placeholder    b        b                                                       ()                {}
   call_function  abs_1    <built-in method abs of type object at 0x7f9ca082f8a0>  (a,)              {}
   call_function  add      <built-in function add>                                 (abs_1, 1)        {}
   call_function  truediv  <built-in function truediv>                             (a, add)          {}
   call_method    sum_1    sum                                                     (b,)              {}
   call_function  lt       <built-in function lt>                                  (sum_1, 0)        {}
   output         output   output                                                  ((truediv, lt),)  {}

   ORIGINAL BYTECODE toy_example example.py line 12
    14           0 LOAD_FAST                0 (a)
                 2 LOAD_GLOBAL              0 (torch)
                 4 LOAD_METHOD              1 (abs)
                 6 LOAD_FAST                0 (a)
                 8 CALL_METHOD              1
                10 LOAD_CONST               1 (1)
                12 BINARY_ADD
                14 BINARY_TRUE_DIVIDE
                16 STORE_FAST               2 (x)

    15          18 LOAD_FAST                1 (b)
                20 LOAD_METHOD              2 (sum)
                22 CALL_METHOD              0
                24 LOAD_CONST               2 (0)
                26 COMPARE_OP               0 (<)
                28 POP_JUMP_IF_FALSE       19 (to 38)

    16          30 LOAD_FAST                1 (b)
                32 LOAD_CONST               3 (-1)
                34 BINARY_MULTIPLY
                36 STORE_FAST               1 (b)

    17     >>   38 LOAD_FAST                2 (x)
                40 LOAD_FAST                1 (b)
                42 BINARY_MULTIPLY
                44 RETURN_VALUE


   MODIFIED BYTECODE toy_example example.py line 12
    12           0 LOAD_GLOBAL              3 (__compiled_fn_0)
                 2 LOAD_FAST                0 (a)
                 4 LOAD_FAST                1 (b)
                 6 CALL_FUNCTION            2
                 8 UNPACK_SEQUENCE          2
                10 STORE_FAST               2 (x)
                12 POP_JUMP_IF_FALSE       12 (to 24)
                14 LOAD_GLOBAL              4 (__resume_at_30_1)
                16 LOAD_FAST                1 (b)
                18 LOAD_FAST                2 (x)
                20 CALL_FUNCTION            2
                22 RETURN_VALUE
           >>   24 LOAD_GLOBAL              5 (__resume_at_38_2)
                26 LOAD_FAST                1 (b)
                28 LOAD_FAST                2 (x)
                30 CALL_FUNCTION            2
                32 RETURN_VALUE


   possible source code:
   def toy_example(a, b):
       __temp_1 = __compiled_fn_0(a, b)
       x = __temp_1[0]
       if __temp_1[1]:
           return __resume_at_30_1(b, x)
       return __resume_at_38_2(b, x)

   If you find the decompiled code is wrong,please submit an issue at https://github.com/youkaichao/depyf/issues.


At the top you can see the FX graph.
Next, you see the original bytecode of the function, followed by the
modified bytecode generated by Dynamo and the decompiled source
code for reference. Finally, you see the guards, which we covered above.

In the modified bytecode, ``__compiled_fn_0`` is the return value of
``my_compiler()`` (the compiled graph). ``__resume_at_30_1`` and
``__resume_at_38_2`` are both generated continuation functions that pick
up execution after a graph break (at bytecode offsets 30 and 38). Each
of these functions takes the form:

::

   __resume_at_<offset>:
       ... restore stack state if needed ...
       JUMP_ABSOLUTE <offset> into toy_example
       ... original bytecode of toy_example ...

By generating this ``resume_at`` function, we force the remainder of the
function to be executed in a new Python frame, which recursively
triggers Dynamo to restart its capture once execution reaches that
point for the first time.
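
If you just want to see where graph breaks happen without reading bytecode,
``torch._dynamo.explain`` can summarize them (a sketch reusing the example
inputs above; the exact report format may differ between versions):

.. code-block:: python

   explanation = torch._dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
   print(explanation)  # reports graph count, graph break count, and break reasons
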

How to inspect artifacts generated by Dynamo?
---------------------------------------------

To inspect the artifacts generated by Dynamo, there is an API ``torch._dynamo.eval_frame._debug_get_cache_entry_list`` that retrieves compiled code and guards out of a function's ``__code__`` object. A compiled function can have several cache entries, and each cache entry consists of a generated function to check guards and a ``types.CodeType`` object that holds the code to be executed if the guarding conditions are satisfied.

.. code-block:: python

   from torch._dynamo.eval_frame import _debug_get_cache_entry_list, innermost_fn

   cache_entries = _debug_get_cache_entry_list(innermost_fn(toy_example))
   cache_entry = cache_entries[0]
   guard, code = cache_entry.check_fn, cache_entry.code
   # The guard takes the local variables of an input frame and tells whether the
   # cached code can be reused for that frame; if no cache entry matches,
   # a re-compilation is triggered.
   import dis
   dis.dis(guard)
   dis.dis(code)

If you know Python bytecode, you can understand the above output.

For the guard function, there is no need to inspect the bytecode. We can directly access its guarding conditions:

.. code-block:: python

   for code_part in guard.code_parts:
       print(code_part)

The output is:

::

   ___guarded_code.valid
   ___check_global_state()
   hasattr(L['a'], '_dynamo_dynamic_indices') == False
   hasattr(L['b'], '_dynamo_dynamic_indices') == False
   utils_device.CURRENT_DEVICE == None
   ___skip_backend_check() or ___current_backend() == ___lookup_backend(140215810860528)
   ___check_tensors(L['a'], L['b'], tensor_check_names=tensor_check_names)


Only when all of these conditions are satisfied does the guard function return true, and only then is the compiled code executed.

For the compiled code, we cannot directly access its source, so we have to decompile it.

.. code-block:: python

   from depyf import decompile
   print(decompile(code))

The output is:

::

   def toy_example(a, b):
       __temp_1 = __compiled_fn_0(a, b)
       x = __temp_1[0]
       if __temp_1[1]:
           return __resume_at_30_1(b, x)
       return __resume_at_38_2(b, x)


Some names referenced in the code are:

- Compiled functions, stored in the global namespace of the module containing the original function ``toy_example``. These include names like ``__compiled_fn_0`` / ``__resume_at_30_1`` / ``__resume_at_38_2``.

- Closure variables used for checking guards. The names can be accessed from ``guard.__code__.co_freevars``, and the values are stored in ``guard.__closure__`` (see the sketch after this list). These include names like ``___guarded_code`` / ``___is_grad_enabled`` / ``___are_deterministic_algorithms_enabled`` / ``___is_torch_function_enabled`` / ``utils_device`` / ``___check_tensors`` / ``tensor_check_names``.

- Argument ``L`` of the ``guard`` function. This is a dict mapping the argument names of ``toy_example`` to their values. It is only available when the function is called, which is where the frame evaluation API comes into play. In short, ``L`` is a ``dict`` with the structure ``{'a': value_a, 'b': value_b}``. Therefore, you can see the code uses ``L['a']`` to refer to the input variable ``a``.

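
A minimal sketch of the closure inspection mentioned above (plain standard
Python, using the ``guard`` object obtained earlier):

.. code-block:: python

   # Pair each free-variable name with the value captured in the guard's closure.
   for name, cell in zip(guard.__code__.co_freevars, guard.__closure__):
       print(name, "->", cell.cell_contents)
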

The graph break is visible in the code of the compiled ``toy_example``, where we have to use the Python interpreter to select the next graph to execute.

Note that we pass a simple ``my_compiler`` function as the backend compiler, therefore the subgraph code ``__resume_at_38_2``, ``__resume_at_30_1``, and ``__compiled_fn_0`` remains Python code. It can also be inspected (please ignore the function names, and only look at the function signatures and bodies):

.. code-block:: python

   print("source code of __compiled_fn_0:")
   print(innermost_fn(__compiled_fn_0).__self__.code)
   print("=" * 60)
   print("source code of __resume_at_30_1:")
   print(decompile(__resume_at_30_1))
   print("=" * 60)
   print("source code of __resume_at_38_2:")
   print(decompile(__resume_at_38_2))

::

   source code of __compiled_fn_0:

   def forward(self, L_a_ : torch.Tensor, L_b_ : torch.Tensor):
       l_a_ = L_a_
       l_b_ = L_b_
       abs_1 = torch.abs(l_a_)
       add = abs_1 + 1;  abs_1 = None
       truediv = l_a_ / add;  l_a_ = add = None
       sum_1 = l_b_.sum();  l_b_ = None
       lt = sum_1 < 0;  sum_1 = None
       return (truediv, lt)

   # To see more debug info, please use ``graph_module.print_readable()``
   ============================================================
   source code of __resume_at_30_1:
   def <resume in toy_example>(b, x):
       b = b * -1
       return x * b

   ============================================================
   source code of __resume_at_38_2:
   def <resume in toy_example>(b, x):
       return x * b


However, if we use other backends like the built-in ``inductor``, the subgraph code will be compiled into CUDA kernels for GPUs or C++ code for CPUs.

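
For instance, a minimal sketch of switching the example to a built-in backend
(``torch._dynamo.list_backends()`` shows what is registered; ``"inductor"`` is
the default used by ``torch.compile``; the function name here is hypothetical):

.. code-block:: python

   print(torch._dynamo.list_backends())  # e.g. includes "inductor"

   @torch.compile(backend="inductor")
   def toy_example_inductor(a, b):
       x = a / (torch.abs(a) + 1)
       if b.sum() < 0:
           b = b * -1
       return x * b

   toy_example_inductor(torch.randn(10), torch.randn(10))
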

To summarize, the compiled code is conceptually equivalent to the code below:

.. code-block:: python

   def compiled_example(a, b):
       L = {'a': a, 'b': b}
       for guard, code in get_cache_entries():
           if guard(L):
               return code(a, b)
       recompile_and_add_another_cache_entry()


The following diagram demonstrates how ``torch.compile`` transforms and optimizes user-written code: it first extracts computation graphs from the user-written function, compiles these graphs into optimized functions, and then assembles them into a new function that is functionally equivalent to the user-written code but runs faster.

.. image:: _static/img/dynamo/flowchart.jpg

To learn more about how all this is implemented internally, see :ref:`torch.compiler_dynamo_deepdive`.
351