.. _torch.export:

torch.export
=====================

.. warning::
    This feature is a prototype under active development and there WILL BE
    BREAKING CHANGES in the future.


Overview
--------

:func:`torch.export.export` takes an arbitrary Python callable (a
:class:`torch.nn.Module`, a function, or a method) and produces a traced graph
representing only the Tensor computation of the function in an Ahead-of-Time
(AOT) fashion, which can subsequently be executed with different inputs or
serialized.

::

    import torch
    from torch.export import export

    class Mod(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            a = torch.sin(x)
            b = torch.cos(y)
            return a + b

    example_args = (torch.randn(10, 10), torch.randn(10, 10))

    exported_program: torch.export.ExportedProgram = export(
        Mod(), args=example_args
    )
    print(exported_program)

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[10, 10], arg1_1: f32[10, 10]):
                # code: a = torch.sin(x)
                sin: f32[10, 10] = torch.ops.aten.sin.default(arg0_1);

                # code: b = torch.cos(y)
                cos: f32[10, 10] = torch.ops.aten.cos.default(arg1_1);

                # code: return a + b
                add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos);
                return (add,)

        Graph signature: ExportGraphSignature(
            parameters=[],
            buffers=[],
            user_inputs=['arg0_1', 'arg1_1'],
            user_outputs=['add'],
            inputs_to_parameters={},
            inputs_to_buffers={},
            buffers_to_mutate={},
            backward_signature=None,
            assertion_dep_token=None,
        )
        Range constraints: {}

``torch.export`` produces a clean intermediate representation (IR) with the
following invariants. More specifications about the IR can be found
:ref:`here <export.ir_spec>`.

* **Soundness**: It is guaranteed to be a sound representation of the original
  program, and maintains the same calling conventions as the original program.

* **Normalized**: There are no Python semantics within the graph. Submodules
  from the original program are inlined to form one fully flattened
  computational graph.

* **Graph properties**: The graph is purely functional, meaning it does not
  contain operations with side effects such as mutations or aliasing. It does
  not mutate any intermediate values, parameters, or buffers.

* **Metadata**: The graph contains metadata captured during tracing, such as a
  stack trace from the user's code.

Under the hood, ``torch.export`` leverages the following technologies:

* **TorchDynamo (torch._dynamo)** is an internal API that uses a CPython feature
  called the Frame Evaluation API to safely trace PyTorch graphs. This
  provides a massively improved graph capturing experience, with far fewer
  rewrites needed in order to fully trace the PyTorch code.

* **AOT Autograd** provides a functionalized PyTorch graph and ensures the graph
  is decomposed/lowered to the ATen operator set.

* **Torch FX (torch.fx)** is the underlying representation of the graph,
  allowing flexible Python-based transformations.


Existing frameworks
^^^^^^^^^^^^^^^^^^^

:func:`torch.compile` also utilizes the same PT2 stack as ``torch.export``, but
is slightly different:

* **JIT vs. AOT**: :func:`torch.compile` is a JIT compiler, and is not intended
  to be used to produce compiled artifacts outside of deployment. In contrast,
  ``torch.export`` produces its artifact ahead of time, so the result can be
  saved and later run independently of the original Python program.

* **Partial vs. Full Graph Capture**: When :func:`torch.compile` runs into an
  untraceable part of a model, it will "graph break" and fall back to running
  the program in the eager Python runtime. In comparison, ``torch.export`` aims
  to get a full graph representation of a PyTorch model, so it will error out
  when something untraceable is reached. Since ``torch.export`` produces a full
  graph disjoint from any Python features or runtime, this graph can then be
  saved, loaded, and run in different environments and languages.

* **Usability tradeoff**: Since :func:`torch.compile` is able to fall back to the
  Python runtime whenever it reaches something untraceable, it is a lot more
  flexible. ``torch.export`` will instead require users to provide more
  information or rewrite their code to make it traceable.

Compared to :func:`torch.fx.symbolic_trace`, ``torch.export`` traces using
TorchDynamo, which operates at the Python bytecode level, giving it the ability
to trace arbitrary Python constructs not limited by what Python operator
overloading supports. Additionally, ``torch.export`` keeps fine-grained track of
tensor metadata, so that conditionals on things like tensor shapes do not
fail tracing. In general, ``torch.export`` is expected to work on more user
programs, and produce lower-level graphs (at the ``torch.ops.aten`` operator
level). Note that users can still use :func:`torch.fx.symbolic_trace` as a
preprocessing step before ``torch.export``.

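Because a :class:`torch.fx.GraphModule` produced by :func:`torch.fx.symbolic_trace`
is itself a :class:`torch.nn.Module`, it can be passed directly to
``torch.export``. A minimal sketch of such a preprocessing flow (``M`` here is a
placeholder for your own module):

::

    import torch
    from torch.export import export

    class M(torch.nn.Module):
        def forward(self, x):
            return x * 2

    # First trace with FX (for example, to apply an FX-based transformation) ...
    graph_module = torch.fx.symbolic_trace(M())
    # ... then export the resulting GraphModule like any other nn.Module.
    exported_program = export(graph_module, args=(torch.randn(3),))
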
Compared to :func:`torch.jit.script`, ``torch.export`` does not capture Python
control flow or data structures, but it supports more Python language features
than TorchScript (as it is easier to have comprehensive coverage over Python
bytecodes). The resulting graphs are simpler and only have straight line control
flow (except for explicit control flow operators).

Compared to :func:`torch.jit.trace`, ``torch.export`` is sound: it is able to
trace code that performs integer computation on sizes and records all of the
side-conditions necessary to show that a particular trace is valid for other
inputs.

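For example, here is a minimal sketch of the kind of integer computation on
sizes that ``torch.export`` can record symbolically when a dimension is marked
dynamic, whereas ``torch.jit.trace`` would bake the concrete sizes into the
trace:

::

    import torch
    from torch.export import Dim, export

    class M(torch.nn.Module):
        def forward(self, x):
            # Integer computation on sizes: flatten the first two dimensions.
            return x.reshape(x.shape[0] * x.shape[1], -1)

    # With a dynamic first dimension, the exported graph keeps the symbolic
    # size relationship (roughly s0*4) instead of hard-coding 12.
    exported_program = export(
        M(), (torch.randn(3, 4, 5),), dynamic_shapes=({0: Dim("batch")},)
    )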

Exporting a PyTorch Model
-------------------------

An Example
^^^^^^^^^^

The main entrypoint is through :func:`torch.export.export`, which takes a
callable (:class:`torch.nn.Module`, function, or method) and sample inputs, and
captures the computation graph into an :class:`torch.export.ExportedProgram`. An
example:

::

    import torch
    from torch.export import export

    # Simple module for demonstration
    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(
                in_channels=3, out_channels=16, kernel_size=3, padding=1
            )
            self.relu = torch.nn.ReLU()
            self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

        def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
            a = self.conv(x)
            a.add_(constant)
            return self.maxpool(self.relu(a))

    example_args = (torch.randn(1, 3, 256, 256),)
    example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}

    exported_program: torch.export.ExportedProgram = export(
        M(), args=example_args, kwargs=example_kwargs
    )
    print(exported_program)

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256], arg3_1: f32[1, 16, 256, 256]):

                # code: a = self.conv(x)
                convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(
                    arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
                );

                # code: a.add_(constant)
                add: f32[1, 16, 256, 256] = torch.ops.aten.add.Tensor(convolution, arg3_1);

                # code: return self.maxpool(self.relu(a))
                relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(add);
                max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default(
                    relu, [3, 3], [3, 3]
                );
                getitem: f32[1, 16, 85, 85] = max_pool2d_with_indices[0];
                return (getitem,)

        Graph signature: ExportGraphSignature(
            parameters=['L__self___conv.weight', 'L__self___conv.bias'],
            buffers=[],
            user_inputs=['arg2_1', 'arg3_1'],
            user_outputs=['getitem'],
            inputs_to_parameters={
                'arg0_1': 'L__self___conv.weight',
                'arg1_1': 'L__self___conv.bias',
            },
            inputs_to_buffers={},
            buffers_to_mutate={},
            backward_signature=None,
            assertion_dep_token=None,
        )
        Range constraints: {}

Inspecting the ``ExportedProgram``, we can note the following:

* The :class:`torch.fx.Graph` contains the computation graph of the original
  program, along with records of the original code for easy debugging.

* The graph contains only ``torch.ops.aten`` operators found `here <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml>`__
  and custom operators, and is fully functional, without any inplace operators
  such as ``torch.Tensor.add_``.

* The parameters (the conv weight and bias) are lifted as inputs to the graph,
  resulting in no ``get_attr`` nodes in the graph, which previously existed in
  the result of :func:`torch.fx.symbolic_trace`.

* The :class:`torch.export.ExportGraphSignature` models the input and output
  signature, along with specifying which inputs are parameters.

* The resulting shape and dtype of the tensors produced by each node in the
  graph are noted. For example, the ``convolution`` node will result in a tensor
  of dtype ``torch.float32`` and shape (1, 16, 256, 256).

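These pieces can also be inspected programmatically. A minimal sketch, assuming
the ``exported_program`` from the example above:

::

    print(exported_program.graph)              # the underlying torch.fx.Graph
    print(exported_program.graph_signature)    # input/output specs, including parameters
    print(exported_program.state_dict.keys())  # lifted parameters and buffers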

.. _Non-Strict Export:

Non-Strict Export
^^^^^^^^^^^^^^^^^

In PyTorch 2.3, we introduced a new mode of tracing called **non-strict mode**.
It's still going through hardening, so if you run into any issues, please file
them on GitHub with the "oncall: export" tag.

In *non-strict mode*, we trace through the program using the Python interpreter.
Your code will execute exactly as it would in eager mode; the only difference is
that all Tensor objects will be replaced by ProxyTensors, which will record all
their operations into a graph.

In *strict* mode, which is currently the default, we first trace through the
program using TorchDynamo, a bytecode analysis engine. TorchDynamo does not
actually execute your Python code. Instead, it symbolically analyzes it and
builds a graph based on the results. This analysis allows ``torch.export`` to
provide stronger guarantees about safety, but not all Python code is supported.

One case where you might want to use non-strict mode is if you run into an
unsupported TorchDynamo feature that cannot easily be worked around, and you
know the Python code is not needed for the underlying tensor computation. For
example:

::

    import torch
    from torch.export import export

    class ContextManager:
        def __init__(self):
            self.count = 0
        def __enter__(self):
            self.count += 1
        def __exit__(self, exc_type, exc_value, traceback):
            self.count -= 1

    class M(torch.nn.Module):
        def forward(self, x):
            with ContextManager():
                return x.sin() + x.cos()

    export(M(), (torch.ones(3, 3),), strict=False)  # Non-strict traces successfully
    export(M(), (torch.ones(3, 3),))  # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager

In this example, the first call using non-strict mode (through the
``strict=False`` flag) traces successfully, whereas the second call using strict
mode (the default) results in a failure, since TorchDynamo does not support
context managers. One option is to rewrite the code (see :ref:`Limitations of torch.export <Limitations of
torch.export>`), but seeing as the context manager does not affect the tensor
computations in the model, we can go with the non-strict mode's result.


Expressing Dynamism
^^^^^^^^^^^^^^^^^^^

By default ``torch.export`` will trace the program assuming all input shapes are
**static**, and specialize the exported program to those dimensions. However,
some dimensions, such as a batch dimension, can be dynamic and vary from run to
run. Such dimensions must be specified by using the
:func:`torch.export.Dim` API to create them and by passing them into
:func:`torch.export.export` through the ``dynamic_shapes`` argument. An example:

::

    import torch
    from torch.export import Dim, export

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

            self.branch1 = torch.nn.Sequential(
                torch.nn.Linear(64, 32), torch.nn.ReLU()
            )
            self.branch2 = torch.nn.Sequential(
                torch.nn.Linear(128, 64), torch.nn.ReLU()
            )
            self.buffer = torch.ones(32)

        def forward(self, x1, x2):
            out1 = self.branch1(x1)
            out2 = self.branch2(x2)
            return (out1 + self.buffer, out2)

    example_args = (torch.randn(32, 64), torch.randn(32, 128))

    # Create a dynamic batch size
    batch = Dim("batch")
    # Specify that the first dimension of each input is that batch size
    dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

    exported_program: torch.export.ExportedProgram = export(
        M(), args=example_args, dynamic_shapes=dynamic_shapes
    )
    print(exported_program)

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[32, 64], arg1_1: f32[32], arg2_1: f32[64, 128], arg3_1: f32[64], arg4_1: f32[32], arg5_1: f32[s0, 64], arg6_1: f32[s0, 128]):

                # code: out1 = self.branch1(x1)
                permute: f32[64, 32] = torch.ops.aten.permute.default(arg0_1, [1, 0]);
                addmm: f32[s0, 32] = torch.ops.aten.addmm.default(arg1_1, arg5_1, permute);
                relu: f32[s0, 32] = torch.ops.aten.relu.default(addmm);

                # code: out2 = self.branch2(x2)
                permute_1: f32[128, 64] = torch.ops.aten.permute.default(arg2_1, [1, 0]);
                addmm_1: f32[s0, 64] = torch.ops.aten.addmm.default(arg3_1, arg6_1, permute_1);
                relu_1: f32[s0, 64] = torch.ops.aten.relu.default(addmm_1);  addmm_1 = None

                # code: return (out1 + self.buffer, out2)
                add: f32[s0, 32] = torch.ops.aten.add.Tensor(relu, arg4_1);
                return (add, relu_1)

        Graph signature: ExportGraphSignature(
            parameters=[
                'branch1.0.weight',
                'branch1.0.bias',
                'branch2.0.weight',
                'branch2.0.bias',
            ],
            buffers=['L__self___buffer'],
            user_inputs=['arg5_1', 'arg6_1'],
            user_outputs=['add', 'relu_1'],
            inputs_to_parameters={
                'arg0_1': 'branch1.0.weight',
                'arg1_1': 'branch1.0.bias',
                'arg2_1': 'branch2.0.weight',
                'arg3_1': 'branch2.0.bias',
            },
            inputs_to_buffers={'arg4_1': 'L__self___buffer'},
            buffers_to_mutate={},
            backward_signature=None,
            assertion_dep_token=None,
        )
        Range constraints: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)}

Some additional things to note:

* Through the :func:`torch.export.Dim` API and the ``dynamic_shapes`` argument, we specified the first
  dimension of each input to be dynamic. Looking at the inputs ``arg5_1`` and
  ``arg6_1``, they have a symbolic shape of (s0, 64) and (s0, 128), instead of
  the (32, 64) and (32, 128) shaped tensors that we passed in as example inputs.
  ``s0`` is a symbol indicating that this dimension can take on a range of
  values.

* ``exported_program.range_constraints`` describes the ranges of each symbol
  appearing in the graph. In this case, we see that ``s0`` has the range
  [2, inf]. For technical reasons that are difficult to explain here, exported
  dimensions are assumed to not be 0 or 1. This is not a bug, and does not
  necessarily mean that the exported program will not work for dimensions 0 or
  1. See
  `The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`_
  for an in-depth discussion of this topic.


We can also specify more expressive relationships between input shapes, such as
a pair of shapes that differ by one, a shape that is double another, or a shape
that is even. An example:

::

    class M(torch.nn.Module):
        def forward(self, x, y):
            return x + y[1:]

    x, y = torch.randn(5), torch.randn(6)
    dimx = torch.export.Dim("dimx", min=3, max=6)
    dimy = dimx + 1

    exported_program = torch.export.export(
        M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}),
    )
    print(exported_program)

.. code-block::

    ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[s0]", arg1_1: "f32[s0 + 1]"):
            # code: return x + y[1:]
            slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(arg1_1, 0, 1, 9223372036854775807);  arg1_1 = None
            add: "f32[s0]" = torch.ops.aten.add.Tensor(arg0_1, slice_1);  arg0_1 = slice_1 = None
            return (add,)

    Graph signature: ExportGraphSignature(
        input_specs=[
            InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None),
            InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)
        ],
        output_specs=[
            OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]
    )
    Range constraints: {s0: ValueRanges(lower=3, upper=6, is_bool=False), s0 + 1: ValueRanges(lower=4, upper=7, is_bool=False)}

Some things to note:

* By specifying ``{0: dimx}`` for the first input, we see that the resulting
  shape of the first input is now dynamic, being ``[s0]``. Similarly, by
  specifying ``{0: dimy}`` for the second input, we see that its resulting shape
  is also dynamic. However, because we expressed ``dimy = dimx + 1``,
  ``arg1_1``'s shape does not contain a new symbol; instead it is expressed in
  terms of the same symbol ``s0`` used in ``arg0_1``: the relationship
  ``dimy = dimx + 1`` shows up as ``s0 + 1``.

* Looking at the range constraints, we see that ``s0`` has the range [3, 6],
  which was specified initially, and that ``s0 + 1`` has the solved range of
  [4, 7].


Serialization
^^^^^^^^^^^^^

To save and load the ``ExportedProgram``, users can use the
:func:`torch.export.save` and :func:`torch.export.load` APIs. A convention is to
save the ``ExportedProgram`` using a ``.pt2`` file extension.

An example:

::

    import torch

    class MyModule(torch.nn.Module):
        def forward(self, x):
            return x + 10

    exported_program = torch.export.export(MyModule(), (torch.randn(5),))

    torch.export.save(exported_program, 'exported_program.pt2')
    saved_exported_program = torch.export.load('exported_program.pt2')

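The loaded ``ExportedProgram`` can then be called through the callable returned
by its ``.module()`` method. A minimal usage sketch continuing the example
above:

::

    # Run the deserialized program on a new input.
    output = saved_exported_program.module()(torch.randn(5))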

Specializations
^^^^^^^^^^^^^^^

A key concept in understanding the behavior of ``torch.export`` is the
difference between *static* and *dynamic* values.

A *dynamic* value is one that can change from run to run. These behave like
normal arguments to a Python function—you can pass different values for an
argument and expect your function to do the right thing. Tensor *data* is
treated as dynamic.

A *static* value is a value that is fixed at export time and cannot change
between executions of the exported program. When the value is encountered during
tracing, the exporter will treat it as a constant and hard-code it into the
graph.

When an operation is performed (e.g. ``x + y``) and all inputs are static, then
the output of the operation will be directly hard-coded into the graph, and the
operation won’t show up (i.e. it will get constant-folded).

When a value has been hard-coded into the graph, we say that the graph has been
*specialized* to that value.

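As a minimal sketch of this specialization behavior (the module here is purely
illustrative), consider a ``forward`` that computes with plain Python ints:

::

    import torch
    from torch.export import export

    class Mod(torch.nn.Module):
        def forward(self, x):
            scale = 2 * 3       # both operands are static Python ints ...
            return x * scale    # ... so the graph simply contains x * 6

    # The printed graph will contain something like
    # torch.ops.aten.mul.Tensor(arg0_1, 6); the 2 * 3 has been constant-folded.
    print(export(Mod(), (torch.randn(4),)))
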
The following values are static:

Input Tensor Shapes
~~~~~~~~~~~~~~~~~~~

By default, ``torch.export`` will trace the program specializing on the input
tensors' shapes, unless a dimension is specified as dynamic via the
``dynamic_shapes`` argument to ``torch.export``. This means that if there exists
shape-dependent control flow, ``torch.export`` will specialize on the branch
that is being taken with the given sample inputs. For example:

::

    import torch
    from torch.export import export

    class Mod(torch.nn.Module):
        def forward(self, x):
            if x.shape[0] > 5:
                return x + 1
            else:
                return x - 1

    example_inputs = (torch.rand(10, 2),)
    exported_program = export(Mod(), example_inputs)
    print(exported_program)

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[10, 2]):
                add: f32[10, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
                return (add,)

The conditional (``x.shape[0] > 5``) does not appear in the
``ExportedProgram`` because the example inputs have the static
shape of (10, 2). Since ``torch.export`` specializes on the inputs' static
shapes, the else branch (``x - 1``) will never be reached. To preserve the dynamic
branching behavior based on the shape of a tensor in the traced graph,
:func:`torch.export.Dim` will need to be used to specify the dimension
of the input tensor (``x.shape[0]``) to be dynamic, and the source code will
need to be :ref:`rewritten <Data/Shape-Dependent Control Flow>`.

Note that tensors that are part of the module state (e.g. parameters and
buffers) always have static shapes.

Python Primitives
~~~~~~~~~~~~~~~~~

``torch.export`` also specializes on Python primitives,
such as ``int``, ``float``, ``bool``, and ``str``. However, they do have dynamic
variants such as ``SymInt``, ``SymFloat``, and ``SymBool``.

For example:

::

    import torch
    from torch.export import export

    class Mod(torch.nn.Module):
        def forward(self, x: torch.Tensor, const: int, times: int):
            for i in range(times):
                x = x + const
            return x

    example_inputs = (torch.rand(2, 2), 1, 3)
    exported_program = export(Mod(), example_inputs)
    print(exported_program)

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[2, 2], arg1_1, arg2_1):
                add: f32[2, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
                add_1: f32[2, 2] = torch.ops.aten.add.Tensor(add, 1);
                add_2: f32[2, 2] = torch.ops.aten.add.Tensor(add_1, 1);
                return (add_2,)

Because integers are specialized, the ``torch.ops.aten.add.Tensor`` operations
are all computed with the hard-coded constant ``1``, rather than ``arg1_1``. If
a user passes a different value for ``arg1_1`` at runtime (for example, 2) than
the one used during export (1), this will result in an error.
Additionally, the ``times`` iterator used in the ``for`` loop is also "inlined"
in the graph through the 3 repeated ``torch.ops.aten.add.Tensor`` calls, and the
input ``arg2_1`` is never used.

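A minimal usage sketch of this behavior, assuming the ``exported_program`` from
the example above:

::

    # Re-running with the same static values used at export time works ...
    exported_program.module()(torch.rand(2, 2), 1, 3)

    # ... but passing a different value for const (e.g. 2) is expected to fail,
    # since the value 1 was burned into the graph at export time.
    # exported_program.module()(torch.rand(2, 2), 2, 3)
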
Python Containers
~~~~~~~~~~~~~~~~~

Python containers (``List``, ``Dict``, ``NamedTuple``, etc.) are considered to
have static structure.

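For example, the length of a list input is treated as part of the program's
structure rather than as data. A minimal sketch:

::

    import torch
    from torch.export import export

    class Mod(torch.nn.Module):
        def forward(self, xs):
            # xs is a list of tensors; its *length* is static structure.
            return torch.stack(xs).sum(dim=0)

    exported_program = export(Mod(), ([torch.randn(3), torch.randn(3)],))

    # Calling with a list of the same structure (two tensors) works ...
    exported_program.module()([torch.randn(3), torch.randn(3)])
    # ... but a list with a different number of elements is expected to error.
    # exported_program.module()([torch.randn(3), torch.randn(3), torch.randn(3)])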

.. _Limitations of torch.export:

Limitations of torch.export
---------------------------

Graph Breaks
^^^^^^^^^^^^

As ``torch.export`` is a one-shot process for capturing a computation graph from
a PyTorch program, it might ultimately run into untraceable parts of programs, as
it is nearly impossible to support tracing all PyTorch and Python features. In
the case of ``torch.compile``, an unsupported operation will cause a "graph
break" and the unsupported operation will be run with default Python evaluation.
In contrast, ``torch.export`` will require users to provide additional
information or rewrite parts of their code to make it traceable. As the
tracing is based on TorchDynamo, which evaluates at the Python
bytecode level, there will be significantly fewer rewrites required compared to
previous tracing frameworks.

When a graph break is encountered, :ref:`ExportDB <torch.export_db>` is a great
resource for learning about the kinds of programs that are supported and
unsupported, along with ways to rewrite programs to make them traceable.

One way to get past such graph breaks is to use
:ref:`non-strict export <Non-Strict Export>`.

.. _Data/Shape-Dependent Control Flow:

Data/Shape-Dependent Control Flow
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Graph breaks can also be encountered on data-dependent control flow (``if
x.shape[0] > 2``) when shapes are not being specialized, as a tracing compiler
cannot possibly deal with such cases without generating code for a
combinatorially exploding number of paths. In such cases, users will need to
rewrite their code using special control flow operators. Currently, we support
:ref:`torch.cond <cond>` to express if-else like control flow (more coming soon!).

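As a minimal sketch, the shape-dependent branch from the earlier specialization
example might be rewritten with ``torch.cond`` so that both branches are
captured in the exported graph:

::

    import torch
    from torch.export import Dim, export

    class Mod(torch.nn.Module):
        def forward(self, x):
            def true_fn(x):
                return x + 1

            def false_fn(x):
                return x - 1

            # Both branches are traced and preserved in the exported graph.
            return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))

    exported_program = export(
        Mod(), (torch.rand(10, 2),), dynamic_shapes=({0: Dim("batch")},)
    )
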
Missing Fake/Meta/Abstract Kernels for Operators
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When tracing, a FakeTensor kernel (aka meta kernel, abstract impl) is
required for all operators. This is used to reason about the input/output shapes
for this operator.

Please see :func:`torch.library.register_fake` for more details.

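As a minimal, hypothetical sketch of what registering such a kernel looks like
for a custom operator (here ``mylib::bar`` is assumed to already be defined and
to return a tensor shaped like its input):

::

    import torch

    # Hypothetical: "mylib::bar" must already exist, e.g. via
    # torch.library.custom_op or a C++ TORCH_LIBRARY registration.
    @torch.library.register_fake("mylib::bar")
    def _(x):
        # Only metadata (shape, dtype, device) matters here; no real
        # computation is performed.
        return torch.empty_like(x)
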
In the unfortunate case where your model uses an ATen operator that does not
have a FakeTensor kernel implementation yet, please file an issue.


Read More
---------

.. toctree::
   :caption: Additional Links for Export Users
   :maxdepth: 1

   export.ir_spec
   torch.compiler_transformations
   torch.compiler_ir
   generated/exportdb/index
   cond

.. toctree::
   :caption: Deep Dive for PyTorch Developers
   :maxdepth: 1

   torch.compiler_dynamo_overview
   torch.compiler_dynamo_deepdive
   torch.compiler_dynamic_shapes
   torch.compiler_fake_tensor


API Reference
-------------

.. automodule:: torch.export
.. autofunction:: export
.. autofunction:: save
.. autofunction:: load
.. autofunction:: register_dataclass
.. autofunction:: torch.export.dynamic_shapes.Dim
.. autofunction:: dims
.. autoclass:: torch.export.dynamic_shapes.ShapesCollection

    .. automethod:: dynamic_shapes

.. autofunction:: torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes
.. autoclass:: Constraint
.. autoclass:: ExportedProgram

    .. automethod:: module
    .. automethod:: buffers
    .. automethod:: named_buffers
    .. automethod:: parameters
    .. automethod:: named_parameters
    .. automethod:: run_decompositions

.. autoclass:: ExportBackwardSignature
.. autoclass:: ExportGraphSignature
.. autoclass:: ModuleCallSignature
.. autoclass:: ModuleCallEntry


.. automodule:: torch.export.exported_program
.. automodule:: torch.export.graph_signature
.. autoclass:: InputKind
.. autoclass:: InputSpec
.. autoclass:: OutputKind
.. autoclass:: OutputSpec
.. autoclass:: ExportGraphSignature

    .. automethod:: replace_all_uses
    .. automethod:: get_replace_hook

.. autoclass:: torch.export.graph_signature.CustomObjArgument

.. py:module:: torch.export.dynamic_shapes

.. automodule:: torch.export.unflatten
    :members:

.. automodule:: torch.export.custom_obj

.. automodule:: torch.export.experimental
.. automodule:: torch.export.passes
.. autofunction:: torch.export.passes.move_to_device_pass