.. _torch.export:

torch.export
=====================

.. warning::
    This feature is a prototype under active development and there WILL BE
    BREAKING CHANGES in the future.


Overview
--------

:func:`torch.export.export` takes an arbitrary Python callable (a
:class:`torch.nn.Module`, a function, or a method) and produces a traced graph
representing only the Tensor computation of the function in an Ahead-of-Time
(AOT) fashion, which can subsequently be executed with different inputs or
serialized.

::

    import torch
    from torch.export import export

    class Mod(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            a = torch.sin(x)
            b = torch.cos(y)
            return a + b

    example_args = (torch.randn(10, 10), torch.randn(10, 10))

    exported_program: torch.export.ExportedProgram = export(
        Mod(), args=example_args
    )
    print(exported_program)

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[10, 10], arg1_1: f32[10, 10]):
                # code: a = torch.sin(x)
                sin: f32[10, 10] = torch.ops.aten.sin.default(arg0_1);

                # code: b = torch.cos(y)
                cos: f32[10, 10] = torch.ops.aten.cos.default(arg1_1);

                # code: return a + b
                add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos);
                return (add,)

        Graph signature: ExportGraphSignature(
            parameters=[],
            buffers=[],
            user_inputs=['arg0_1', 'arg1_1'],
            user_outputs=['add'],
            inputs_to_parameters={},
            inputs_to_buffers={},
            buffers_to_mutate={},
            backward_signature=None,
            assertion_dep_token=None,
        )
        Range constraints: {}

``torch.export`` produces a clean intermediate representation (IR) with the
following invariants. More specifications about the IR can be found
:ref:`here <export.ir_spec>`.

* **Soundness**: It is guaranteed to be a sound representation of the original
  program, and maintains the same calling conventions as the original program.

* **Normalized**: There are no Python semantics within the graph. Submodules
  from the original program are inlined to form one fully flattened
  computational graph.

* **Graph properties**: The graph is purely functional, meaning it does not
  contain operations with side effects such as mutations or aliasing. It does
  not mutate any intermediate values, parameters, or buffers.

* **Metadata**: The graph contains metadata captured during tracing, such as a
  stack trace from the user's code.

Under the hood, ``torch.export`` leverages the following technologies:

* **TorchDynamo (torch._dynamo)** is an internal API that uses a CPython feature
  called the Frame Evaluation API to safely trace PyTorch graphs. This
  provides a massively improved graph capturing experience, with far fewer
  rewrites needed in order to fully trace PyTorch code.

* **AOT Autograd** provides a functionalized PyTorch graph and ensures the graph
  is decomposed/lowered to the ATen operator set.

* **Torch FX (torch.fx)** is the underlying representation of the graph,
  allowing flexible Python-based transformations.


Existing frameworks
^^^^^^^^^^^^^^^^^^^

:func:`torch.compile` also utilizes the same PT2 stack as ``torch.export``, but
is slightly different:

* **JIT vs. AOT**: :func:`torch.compile` is a JIT compiler and is not intended
  to be used to produce compiled artifacts outside of deployment, whereas
  ``torch.export`` produces its artifact ahead of time so that it can be saved
  and deployed separately (see the sketch after this list).

* **Partial vs. Full Graph Capture**: When :func:`torch.compile` runs into an
  untraceable part of a model, it will "graph break" and fall back to running
  the program in the eager Python runtime. In comparison, ``torch.export`` aims
  to get a full graph representation of a PyTorch model, so it will error out
  when something untraceable is reached. Since ``torch.export`` produces a full
  graph disjoint from any Python features or runtime, this graph can then be
  saved, loaded, and run in different environments and languages.

* **Usability tradeoff**: Since :func:`torch.compile` is able to fall back to the
  Python runtime whenever it reaches something untraceable, it is a lot more
  flexible. ``torch.export`` will instead require users to provide more
  information or rewrite their code to make it traceable.
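The difference in workflow looks roughly like the following minimal sketch (the
module and inputs are purely illustrative):

::

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.sin(x) + 1

    m = M()
    x = torch.randn(4)

    # torch.compile is JIT: compilation happens on the first call, and any
    # untraceable code simply "graph breaks" back to eager Python.
    compiled = torch.compile(m)
    out = compiled(x)

    # torch.export is AOT: the full graph is captured up front into an
    # ExportedProgram, which can be serialized with torch.export.save;
    # untraceable code raises an error instead of falling back.
    exported = torch.export.export(m, (x,))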
Compared to :func:`torch.fx.symbolic_trace`, ``torch.export`` traces using
TorchDynamo, which operates at the Python bytecode level, giving it the ability
to trace arbitrary Python constructs not limited by what Python operator
overloading supports. Additionally, ``torch.export`` keeps fine-grained track of
tensor metadata, so that conditionals on things like tensor shapes do not
fail tracing. In general, ``torch.export`` is expected to work on more user
programs, and produce lower-level graphs (at the ``torch.ops.aten`` operator
level). Note that users can still use :func:`torch.fx.symbolic_trace` as a
preprocessing step before ``torch.export``.

Compared to :func:`torch.jit.script`, ``torch.export`` does not capture Python
control flow or data structures, but it supports more Python language features
than TorchScript (as it is easier to have comprehensive coverage over Python
bytecodes). The resulting graphs are simpler and only have straight-line control
flow (except for explicit control flow operators).

Compared to :func:`torch.jit.trace`, ``torch.export`` is sound: it is able to
trace code that performs integer computation on sizes and records all of the
side-conditions necessary to show that a particular trace is valid for other
inputs.
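For instance, a Python conditional on a tensor shape fails under
:func:`torch.fx.symbolic_trace` but traces fine with ``torch.export``. A minimal
sketch (the module is illustrative):

::

    import torch
    from torch.export import export

    class M(torch.nn.Module):
        def forward(self, x):
            # A Python conditional on a tensor shape.
            if x.shape[0] > 2:
                return x.cos()
            return x.sin()

    # symbolic_trace raises a TraceError here, because `x` is a Proxy and its
    # shape cannot be used in Python control flow.
    try:
        torch.fx.symbolic_trace(M())
    except torch.fx.proxy.TraceError as e:
        print(e)

    # torch.export tracks real shape metadata while tracing, so the conditional
    # is evaluated (and specialized on the sample input's shape).
    print(export(M(), (torch.randn(4),)))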
Exporting a PyTorch Model
-------------------------

An Example
^^^^^^^^^^

The main entrypoint is through :func:`torch.export.export`, which takes a
callable (:class:`torch.nn.Module`, function, or method) and sample inputs, and
captures the computation graph into an :class:`torch.export.ExportedProgram`. An
example:

::

    import torch
    from torch.export import export

    # Simple module for demonstration
    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(
                in_channels=3, out_channels=16, kernel_size=3, padding=1
            )
            self.relu = torch.nn.ReLU()
            self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

        def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
            a = self.conv(x)
            a.add_(constant)
            return self.maxpool(self.relu(a))

    example_args = (torch.randn(1, 3, 256, 256),)
    example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}

    exported_program: torch.export.ExportedProgram = export(
        M(), args=example_args, kwargs=example_kwargs
    )
    print(exported_program)

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256], arg3_1: f32[1, 16, 256, 256]):

                # code: a = self.conv(x)
                convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(
                    arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
                );

                # code: a.add_(constant)
                add: f32[1, 16, 256, 256] = torch.ops.aten.add.Tensor(convolution, arg3_1);

                # code: return self.maxpool(self.relu(a))
                relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(add);
                max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default(
                    relu, [3, 3], [3, 3]
                );
                getitem: f32[1, 16, 85, 85] = max_pool2d_with_indices[0];
                return (getitem,)

        Graph signature: ExportGraphSignature(
            parameters=['L__self___conv.weight', 'L__self___conv.bias'],
            buffers=[],
            user_inputs=['arg2_1', 'arg3_1'],
            user_outputs=['getitem'],
            inputs_to_parameters={
                'arg0_1': 'L__self___conv.weight',
                'arg1_1': 'L__self___conv.bias',
            },
            inputs_to_buffers={},
            buffers_to_mutate={},
            backward_signature=None,
            assertion_dep_token=None,
        )
        Range constraints: {}

Inspecting the ``ExportedProgram``, we can note the following:

* The :class:`torch.fx.Graph` contains the computation graph of the original
  program, along with records of the original code for easy debugging.

* The graph contains only ``torch.ops.aten`` operators found `here <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml>`__
  and custom operators, and is fully functional, without any inplace operators
  such as ``torch.add_``.

* The parameters (weight and bias to conv) are lifted as inputs to the graph,
  resulting in no ``get_attr`` nodes in the graph, which previously existed in
  the result of :func:`torch.fx.symbolic_trace`.

* The :class:`torch.export.ExportGraphSignature` models the input and output
  signature, along with specifying which inputs are parameters.

* The resulting shape and dtype of tensors produced by each node in the graph is
  noted. For example, the ``convolution`` node will result in a tensor of dtype
  ``torch.float32`` and shape (1, 16, 256, 256).


.. _Non-Strict Export:

Non-Strict Export
^^^^^^^^^^^^^^^^^

In PyTorch 2.3, we introduced a new mode of tracing called **non-strict mode**.
It's still going through hardening, so if you run into any issues, please file
them on GitHub with the "oncall: export" tag.

In *non-strict mode*, we trace through the program using the Python interpreter.
Your code will execute exactly as it would in eager mode; the only difference is
that all Tensor objects will be replaced by ProxyTensors, which will record all
their operations into a graph.

In *strict* mode, which is currently the default, we first trace through the
program using TorchDynamo, a bytecode analysis engine. TorchDynamo does not
actually execute your Python code. Instead, it symbolically analyzes it and
builds a graph based on the results. This analysis allows torch.export to
provide stronger guarantees about safety, but not all Python code is supported.
An example of a case where one might want to use non-strict mode is if you run
into an unsupported TorchDynamo feature that might not be easily solved, and you
know the Python code is not needed for the tensor computation. For example:

::

    import torch
    from torch.export import export

    class ContextManager():
        def __init__(self):
            self.count = 0
        def __enter__(self):
            self.count += 1
        def __exit__(self, exc_type, exc_value, traceback):
            self.count -= 1

    class M(torch.nn.Module):
        def forward(self, x):
            with ContextManager():
                return x.sin() + x.cos()

    export(M(), (torch.ones(3, 3),), strict=False)  # Non-strict traces successfully
    export(M(), (torch.ones(3, 3),))  # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager

In this example, the first call using non-strict mode (through the
``strict=False`` flag) traces successfully, whereas the second call using strict
mode (the default) fails because TorchDynamo does not support context managers.
One option is to rewrite the code (see :ref:`Limitations of torch.export
<Limitations of torch.export>`), but since the context manager does not affect
the tensor computations in the model, we can go with the non-strict mode's
result.


Expressing Dynamism
^^^^^^^^^^^^^^^^^^^

By default ``torch.export`` will trace the program assuming all input shapes are
**static**, specializing the exported program to those dimensions. However,
some dimensions, such as a batch dimension, can be dynamic and vary from run to
run. Such dimensions must be specified by using the
:func:`torch.export.Dim` API to create them and by passing them into
:func:`torch.export.export` through the ``dynamic_shapes`` argument. An example:

::

    import torch
    from torch.export import Dim, export

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

            self.branch1 = torch.nn.Sequential(
                torch.nn.Linear(64, 32), torch.nn.ReLU()
            )
            self.branch2 = torch.nn.Sequential(
                torch.nn.Linear(128, 64), torch.nn.ReLU()
            )
            self.buffer = torch.ones(32)

        def forward(self, x1, x2):
            out1 = self.branch1(x1)
            out2 = self.branch2(x2)
            return (out1 + self.buffer, out2)

    example_args = (torch.randn(32, 64), torch.randn(32, 128))

    # Create a dynamic batch size
    batch = Dim("batch")
    # Specify that the first dimension of each input is that batch size
    dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

    exported_program: torch.export.ExportedProgram = export(
        M(), args=example_args, dynamic_shapes=dynamic_shapes
    )
    print(exported_program)
.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[32, 64], arg1_1: f32[32], arg2_1: f32[64, 128], arg3_1: f32[64], arg4_1: f32[32], arg5_1: f32[s0, 64], arg6_1: f32[s0, 128]):

                # code: out1 = self.branch1(x1)
                permute: f32[64, 32] = torch.ops.aten.permute.default(arg0_1, [1, 0]);
                addmm: f32[s0, 32] = torch.ops.aten.addmm.default(arg1_1, arg5_1, permute);
                relu: f32[s0, 32] = torch.ops.aten.relu.default(addmm);

                # code: out2 = self.branch2(x2)
                permute_1: f32[128, 64] = torch.ops.aten.permute.default(arg2_1, [1, 0]);
                addmm_1: f32[s0, 64] = torch.ops.aten.addmm.default(arg3_1, arg6_1, permute_1);
                relu_1: f32[s0, 64] = torch.ops.aten.relu.default(addmm_1); addmm_1 = None

                # code: return (out1 + self.buffer, out2)
                add: f32[s0, 32] = torch.ops.aten.add.Tensor(relu, arg4_1);
                return (add, relu_1)

        Graph signature: ExportGraphSignature(
            parameters=[
                'branch1.0.weight',
                'branch1.0.bias',
                'branch2.0.weight',
                'branch2.0.bias',
            ],
            buffers=['L__self___buffer'],
            user_inputs=['arg5_1', 'arg6_1'],
            user_outputs=['add', 'relu_1'],
            inputs_to_parameters={
                'arg0_1': 'branch1.0.weight',
                'arg1_1': 'branch1.0.bias',
                'arg2_1': 'branch2.0.weight',
                'arg3_1': 'branch2.0.bias',
            },
            inputs_to_buffers={'arg4_1': 'L__self___buffer'},
            buffers_to_mutate={},
            backward_signature=None,
            assertion_dep_token=None,
        )
        Range constraints: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)}

Some additional things to note:

* Through the :func:`torch.export.Dim` API and the ``dynamic_shapes`` argument,
  we specified the first dimension of each input to be dynamic. Looking at the
  inputs ``arg5_1`` and ``arg6_1``, they have a symbolic shape of (s0, 64) and
  (s0, 128), instead of the (32, 64) and (32, 128) shaped tensors that we passed
  in as example inputs. ``s0`` is a symbol representing that this dimension can
  take on a range of values (see the sketch after this list).

* ``exported_program.range_constraints`` describes the ranges of each symbol
  appearing in the graph. In this case, we see that ``s0`` has the range
  [2, inf]. For technical reasons that are difficult to explain here, symbols are
  assumed to not be 0 or 1. This is not a bug, and does not necessarily mean
  that the exported program will not work for dimensions 0 or 1. See
  `The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`_
  for an in-depth discussion of this topic.
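Because the batch dimension is symbolic, the exported program can be run with a
batch size that differs from the example inputs. A minimal sketch, using
``ExportedProgram.module()`` to obtain a runnable module:

::

    # Run the exported program with a batch size different from the example inputs.
    gm = exported_program.module()
    out1, out2 = gm(torch.randn(64, 64), torch.randn(64, 128))
    print(out1.shape, out2.shape)  # torch.Size([64, 32]) torch.Size([64, 64])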
We can also specify more expressive relationships between input shapes, such as
where a pair of shapes might differ by one, a shape might be double
another, or a shape is even. An example:

::

    class M(torch.nn.Module):
        def forward(self, x, y):
            return x + y[1:]

    x, y = torch.randn(5), torch.randn(6)
    dimx = torch.export.Dim("dimx", min=3, max=6)
    dimy = dimx + 1

    exported_program = torch.export.export(
        M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}),
    )
    print(exported_program)

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: "f32[s0]", arg1_1: "f32[s0 + 1]"):
                # code: return x + y[1:]
                slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(arg1_1, 0, 1, 9223372036854775807); arg1_1 = None
                add: "f32[s0]" = torch.ops.aten.add.Tensor(arg0_1, slice_1); arg0_1 = slice_1 = None
                return (add,)

        Graph signature: ExportGraphSignature(
            input_specs=[
                InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None),
                InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)
            ],
            output_specs=[
                OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]
        )
        Range constraints: {s0: ValueRanges(lower=3, upper=6, is_bool=False), s0 + 1: ValueRanges(lower=4, upper=7, is_bool=False)}

Some things to note:

* By specifying ``{0: dimx}`` for the first input, we see that the resulting
  shape of the first input is now dynamic, being ``[s0]``. And now by specifying
  ``{0: dimy}`` for the second input, we see that the resulting shape of the
  second input is also dynamic. However, because we expressed ``dimy = dimx + 1``,
  instead of ``arg1_1``'s shape containing a new symbol, we see that it is
  now being represented with the same symbol used in ``arg0_1``, ``s0``. We can
  see that the relationship ``dimy = dimx + 1`` shows up as ``s0 + 1``.

* Looking at the range constraints, we see that ``s0`` has the range [3, 6],
  which was specified initially, and we can see that ``s0 + 1`` has the solved
  range of [4, 7].


Serialization
^^^^^^^^^^^^^

To save the ``ExportedProgram``, users can use the :func:`torch.export.save` and
:func:`torch.export.load` APIs. A convention is to save the ``ExportedProgram``
using a ``.pt2`` file extension.

An example:

::

    import torch

    class MyModule(torch.nn.Module):
        def forward(self, x):
            return x + 10

    exported_program = torch.export.export(MyModule(), (torch.randn(5),))

    torch.export.save(exported_program, 'exported_program.pt2')
    saved_exported_program = torch.export.load('exported_program.pt2')


Specializations
^^^^^^^^^^^^^^^

A key concept in understanding the behavior of ``torch.export`` is the
difference between *static* and *dynamic* values.

A *dynamic* value is one that can change from run to run. These behave like
normal arguments to a Python function: you can pass different values for an
argument and expect your function to do the right thing. Tensor *data* is
treated as dynamic.

A *static* value is a value that is fixed at export time and cannot change
between executions of the exported program. When the value is encountered during
tracing, the exporter will treat it as a constant and hard-code it into the
graph.

When an operation is performed (e.g. ``x + y``) and all inputs are static, then
the output of the operation will be directly hard-coded into the graph, and the
operation won't show up (i.e. it will get constant-folded).

When a value has been hard-coded into the graph, we say that the graph has been
*specialized* to that value.
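For instance, arithmetic between two static Python values is folded away at
export time. A minimal sketch (the module and values are illustrative):

::

    import torch
    from torch.export import export

    class M(torch.nn.Module):
        def forward(self, x):
            scale = 2 + 3      # both operands are static Python ints
            return x * scale   # recorded as x * 5; the addition is constant-folded

    # The printed graph contains a single aten.mul with the hard-coded constant 5.
    print(export(M(), (torch.randn(2, 2),)))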
The following values are static:

Input Tensor Shapes
~~~~~~~~~~~~~~~~~~~

By default, ``torch.export`` will trace the program specializing on the input
tensors' shapes, unless a dimension is specified as dynamic via the
``dynamic_shapes`` argument to ``torch.export``. This means that if there exists
shape-dependent control flow, ``torch.export`` will specialize on the branch
that is being taken with the given sample inputs. For example:

::

    import torch
    from torch.export import export

    class Mod(torch.nn.Module):
        def forward(self, x):
            if x.shape[0] > 5:
                return x + 1
            else:
                return x - 1

    example_inputs = (torch.rand(10, 2),)
    exported_program = export(Mod(), example_inputs)
    print(exported_program)

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[10, 2]):
                add: f32[10, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
                return (add,)

The conditional (``x.shape[0] > 5``) does not appear in the
``ExportedProgram`` because the example inputs have the static
shape of (10, 2). Since ``torch.export`` specializes on the inputs' static
shapes, the else branch (``x - 1``) will never be reached. To preserve the dynamic
branching behavior based on the shape of a tensor in the traced graph,
:func:`torch.export.Dim` will need to be used to specify the dimension
of the input tensor (``x.shape[0]``) to be dynamic, and the source code will
need to be :ref:`rewritten <Data/Shape-Dependent Control Flow>`.

Note that tensors that are part of the module state (e.g. parameters and
buffers) always have static shapes.

Python Primitives
~~~~~~~~~~~~~~~~~

``torch.export`` also specializes on Python primitives,
such as ``int``, ``float``, ``bool``, and ``str``. However, they do have dynamic
variants such as ``SymInt``, ``SymFloat``, and ``SymBool``.

For example:

::

    import torch
    from torch.export import export

    class Mod(torch.nn.Module):
        def forward(self, x: torch.Tensor, const: int, times: int):
            for i in range(times):
                x = x + const
            return x

    example_inputs = (torch.rand(2, 2), 1, 3)
    exported_program = export(Mod(), example_inputs)
    print(exported_program)

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[2, 2], arg1_1, arg2_1):
                add: f32[2, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
                add_1: f32[2, 2] = torch.ops.aten.add.Tensor(add, 1);
                add_2: f32[2, 2] = torch.ops.aten.add.Tensor(add_1, 1);
                return (add_2,)

Because integers are specialized, the ``torch.ops.aten.add.Tensor`` operations
are all computed with the hard-coded constant ``1``, rather than ``arg1_1``. If
a user passes a different value for ``arg1_1`` at runtime (say, 2) than the
value used at export time (1), this will result in an error.
Additionally, the ``times`` iterator used in the ``for`` loop is also "inlined"
into the graph through the 3 repeated ``torch.ops.aten.add.Tensor`` calls, and
the input ``arg2_1`` is never used.

Python Containers
~~~~~~~~~~~~~~~~~

Python containers (``List``, ``Dict``, ``NamedTuple``, etc.) are considered to
have static structure.


.. _Limitations of torch.export:

Limitations of torch.export
---------------------------

Graph Breaks
^^^^^^^^^^^^

As ``torch.export`` is a one-shot process for capturing a computation graph from
a PyTorch program, it might ultimately run into untraceable parts of programs, as
it is nearly impossible to support tracing all PyTorch and Python features. In
the case of ``torch.compile``, an unsupported operation will cause a "graph
break" and the unsupported operation will be run with default Python evaluation.
In contrast, ``torch.export`` will require users to provide additional
information or rewrite parts of their code to make it traceable. Because the
tracing is based on TorchDynamo, which evaluates at the Python
bytecode level, there will be significantly fewer rewrites required compared to
previous tracing frameworks.

When a graph break is encountered, :ref:`ExportDB <torch.export_db>` is a great
resource for learning about the kinds of programs that are supported and
unsupported, along with ways to rewrite programs to make them traceable.

One option for getting past such graph breaks is to use
:ref:`non-strict export <Non-Strict Export>`.

.. _Data/Shape-Dependent Control Flow:

Data/Shape-Dependent Control Flow
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Graph breaks can also be encountered on data-dependent control flow (e.g. ``if
x.shape[0] > 2``) when shapes are not being specialized, as a tracing compiler
cannot possibly deal with such cases without generating code for a
combinatorially exploding number of paths. In such cases, users will need to
rewrite their code using special control flow operators. Currently, we support
:ref:`torch.cond <cond>` to express if-else like control flow (more coming
soon!).
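As an illustration, a data-dependent branch can be rewritten with ``torch.cond``
so that both branches are preserved in the exported graph. A minimal sketch (the
module is illustrative):

::

    import torch
    from torch.export import export

    class M(torch.nn.Module):
        def forward(self, x):
            def true_fn(x):
                return x + 1

            def false_fn(x):
                return x - 1

            # Both branches are traced and kept in the graph; the predicate is
            # evaluated at runtime instead of being specialized at export time.
            return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

    exported_program = export(M(), (torch.randn(3, 3),))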
Missing Fake/Meta/Abstract Kernels for Operators
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When tracing, a FakeTensor kernel (aka meta kernel, abstract impl) is
required for all operators. This is used to reason about the input/output shapes
for this operator.

Please see :func:`torch.library.register_fake` for more details.

In the unfortunate case where your model uses an ATen operator that does not
have a FakeTensor kernel implementation yet, please file an issue.
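For custom operators, a fake kernel can be registered with
:func:`torch.library.register_fake`. A minimal sketch, assuming a hypothetical
custom op ``mylib::double`` (the names are for illustration only):

::

    import torch

    # A hypothetical custom operator.
    @torch.library.custom_op("mylib::double", mutates_args=())
    def double(x: torch.Tensor) -> torch.Tensor:
        return x * 2

    # The fake kernel only describes output metadata (shape/dtype), which is
    # all torch.export needs in order to trace through the operator.
    @torch.library.register_fake("mylib::double")
    def _(x):
        return torch.empty_like(x)

    class M(torch.nn.Module):
        def forward(self, x):
            return double(x)

    exported_program = torch.export.export(M(), (torch.randn(3),))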
Read More
---------

.. toctree::
    :caption: Additional Links for Export Users
    :maxdepth: 1

    export.ir_spec
    torch.compiler_transformations
    torch.compiler_ir
    generated/exportdb/index
    cond

.. toctree::
    :caption: Deep Dive for PyTorch Developers
    :maxdepth: 1

    torch.compiler_dynamo_overview
    torch.compiler_dynamo_deepdive
    torch.compiler_dynamic_shapes
    torch.compiler_fake_tensor


API Reference
-------------

.. automodule:: torch.export
.. autofunction:: export
.. autofunction:: save
.. autofunction:: load
.. autofunction:: register_dataclass
.. autofunction:: torch.export.dynamic_shapes.Dim
.. autofunction:: dims
.. autoclass:: torch.export.dynamic_shapes.ShapesCollection

    .. automethod:: dynamic_shapes

.. autofunction:: torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes
.. autoclass:: Constraint
.. autoclass:: ExportedProgram

    .. automethod:: module
    .. automethod:: buffers
    .. automethod:: named_buffers
    .. automethod:: parameters
    .. automethod:: named_parameters
    .. automethod:: run_decompositions

.. autoclass:: ExportBackwardSignature
.. autoclass:: ExportGraphSignature
.. autoclass:: ModuleCallSignature
.. autoclass:: ModuleCallEntry


.. automodule:: torch.export.exported_program
.. automodule:: torch.export.graph_signature
.. autoclass:: InputKind
.. autoclass:: InputSpec
.. autoclass:: OutputKind
.. autoclass:: OutputSpec
.. autoclass:: ExportGraphSignature

    .. automethod:: replace_all_uses
    .. automethod:: get_replace_hook

.. autoclass:: torch.export.graph_signature.CustomObjArgument

.. py:module:: torch.export.dynamic_shapes

.. automodule:: torch.export.unflatten
    :members:

.. automodule:: torch.export.custom_obj

.. automodule:: torch.export.experimental
.. automodule:: torch.export.passes
.. autofunction:: torch.export.passes.move_to_device_pass