1Frequently Asked Questions 2========================== 3**Author**: `Mark Saroufim <https://github.com/msaroufim>`_ 4 5Does ``torch.compile`` support training? 6~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 7 8``torch.compile`` supports training, using AOTAutograd to capture backwards: 9 101. The ``.forward()`` graph and ``optimizer.step()`` is captured by 11 TorchDynamo’s python ``evalframe`` frontend. 122. For each segment of ``.forward()`` that torchdynamo captures, it uses 13 AOTAutograd to generate a backward graph segment. 143. Each pair of forward and backward graph are (optionally) min-cut 15 partitioned to save the minimal state between forward and backward. 164. The forward and backward pairs are wrapped in ``autograd.function`` modules. 175. Usercode calling\ ``.backward()`` still triggers eager’s autograd engine, 18 which runs each *compiled backward* graph as if it were one op, also running 19 any non-compiled eager ops’ ``.backward()`` functions. 20 21Do you support Distributed code? 22~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 23 24``torch.compile`` supports ``DistributedDataParallel`` (DDP). 25Support for other distributed training libraries is being considered. 26 27The main reason why Distributed code is challenging with dynamo is 28because AOTAutograd unrolls both the forward and backward pass and 29provides 2 graphs for backends to optimize. This is a problem for 30distributed code because we’d like to ideally overlap communication 31operations with computations. Eager pytorch accomplishes this in 32different ways for DDP/FSDP- using autograd hooks, module hooks, and 33modifications/mutations of module states. In a naive application of 34dynamo, hooks that should run directly after an operation during 35backwards may be delayed until after the entire compiled region of 36backwards ops, due to how AOTAutograd compiled functions interact with 37dispatcher hooks. 38 39The basic strategy for optimizing DDP with Dynamo is outlined in 40`distributed.py <https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/backends/distributed.py>`__ 41where the main idea will be to graph break on `DDP bucket 42boundaries <https://pytorch.org/docs/stable/notes/ddp.html#internal-design>`__. 43 44When each node in DDP needs to synchronize its weights with the other 45nodes it organizes its gradients and parameters into buckets which 46reduces communication times and allows a node to broadcast a fraction of 47its gradients to other waiting nodes. 48 49Graph breaks in distributed code mean you can expect dynamo and its 50backends to optimize the compute overhead of a distributed program but 51not its communication overhead. Graph-breaks may interfere with 52compilation speedups, if the reduced graph-size robs the compiler of 53fusion opportunities. However, there are diminishing returns with 54increasing graph size since most of the current compute optimizations 55are local fusions. So in practice this approach may be sufficient. 56 57Do I still need to export whole graphs? 58~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 59 60For the vast majority of models you probably don’t and you can use 61``torch.compile()`` as is but there are a few situations where 62full graphs are necessary and you can can ensure a full graph by simply 63running ``torch.compile(..., fullgraph=True)``. These situations include: 64 65* Large scale training runs, such as $250K+ that require pipeline parallelism 66 and other advanced sharding strategies. 67 68* Inference optimizers like `TensorRT <https://github.com/pytorch/TensorRT>`__ 69 or `AITemplate <https://github.com/facebookincubator/AITemplate>`__ that 70 rely on fusing much more aggressively than training optimizers. 71 72* Mobile training or inference. 73 74Future work will include tracing communication operations into graphs, 75coordinating these operations with compute optimizations, and optimizing 76the communication operations. 77 78Why is my code crashing? 79~~~~~~~~~~~~~~~~~~~~~~~~ 80 81If your code ran just fine without ``torch.compile`` and started to 82crash with it is enabled, then the most important first step is figuring 83out which part of the stack your failure occurred. To troubleshoot that, 84follow the steps below and only try the next step if the previous one 85succeeded. 86 871. ``torch.compile(..., backend="eager")`` which only runs TorchDynamo 88 forward graph capture and then runs the captured graph with PyTorch. 89 If this fails then there’s an issue with TorchDynamo. 90 912. ``torch.compile(..., backend="aot_eager")`` 92 which runs TorchDynamo to capture a forward graph, and then AOTAutograd 93 to trace the backward graph without any additional backend compiler 94 steps. PyTorch eager will then be used to run the forward and backward 95 graphs. If this fails then there’s an issue with AOTAutograd. 96 973. ``torch.compile(..., backend="inductor")`` which runs TorchDynamo to capture a 98 forward graph, and then AOTAutograd to trace the backward graph with the 99 TorchInductor compiler. If this fails then there’s an issue with TorchInductor 100 101Why is compilation slow? 102~~~~~~~~~~~~~~~~~~~~~~~~ 103 104* **Dynamo Compilation**– TorchDynamo has a builtin stats function for 105 collecting and displaying the time spent in each compilation phase. 106 These stats can be accessed by calling ``torch._dynamo.utils.compile_times()`` 107 after executing ``torch._dynamo``. By default, this returns a string 108 representation of the compile times spent in each TorchDynamo function by name. 109 110* **Inductor Compilation**– TorchInductor has a builtin stats and trace function 111 for displaying time spent in each compilation phase, output code, output 112 graph visualization and IR dump. ``env TORCH_COMPILE_DEBUG=1 python repro.py``. 113 This is a debugging tool designed to make it easier to debug/understand the 114 internals of TorchInductor with an output that will look something like 115 `this <https://gist.github.com/jansel/f4af078791ad681a0d4094adeb844396>`__ 116 Each file in that debug trace can be enabled/disabled via 117 ``torch._inductor.config.trace.*``. The profile and the diagram are both 118 disabled by default since they are expensive to generate. See the 119 `example debug directory 120 output <https://gist.github.com/jansel/f4af078791ad681a0d4094adeb844396>`__ 121 for more examples. 122 123* **Excessive Recompilation** 124 When TorchDynamo compiles a function (or part of one), it makes certain 125 assumptions about locals and globals in order to allow compiler 126 optimizations, and expresses these assumptions as guards that check 127 particular values at runtime. If any of these guards fail, Dynamo will 128 recompile that function (or part) up to 129 ``torch._dynamo.config.cache_size_limit`` times. If your program is 130 hitting the cache limit, you will first need to determine which guard is 131 failing and what part of your program is triggering it. The 132 `recompilation profiler <#recompilation-profiler>`__ automates the 133 process of setting TorchDynamo’s cache limit to 1 and running your 134 program under an observation-only ‘compiler’ that records the causes of 135 any guard failures. You should be sure to run your program for at least 136 as long (as many iterations) as you were running when you ran into 137 trouble, and the profiler will accumulate statistics over this duration. 138 139 140Why are you recompiling in production? 141~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 142 143In some cases, you may not want unexpected compiles after a program has 144warmed up. For example, if you are serving production traffic in a 145latency critical application. For this, TorchDynamo provides an 146alternate mode where prior compiled graphs are used, but no new ones are 147generated: 148 149.. code-block:: python 150 151 frozen_toy_example = dynamo.run(toy_example) 152 frozen_toy_example(torch.randn(10), torch.randn(10)) 153 154How are you speeding up my code? 155~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 156 157There are 3 major ways to accelerate PyTorch code: 158 1591. Kernel fusion via vertical fusions which fuse sequential operations to avoid 160 excessive read/writes. For example, fuse 2 subsequent cosines means you 161 can can do 1 read 1 write instead 2 reads 2 writes 2. Horizontal fusion: 162 the simplest example being batching where a single matrix is multiplied 163 with a batch of examples but the more general scenario is a grouped GEMM 164 where a group of matrix multiplications are scheduled together 165 1662. Out of order execution: A general optimization for compilers, by looking ahead 167 at the exact data dependencies within a graph we can decide on the most 168 opportune time to execute a node and which buffers can be reused 169 1703. Automatic work placement: Similar of the out of order execution point, 171 but by matching nodes of a graph to resources like physical hardware or 172 memory we can design an appropriate schedule 173 174The above are general principles for accelerating PyTorch code but 175different backends will each make different tradeoffs on what to 176optimize. For example Inductor first takes care of fusing whatever it 177can and only then generates `Triton <https://openai.com/blog/triton/>`__ 178kernels. 179 180Triton in addition offers speedups because of automatic memory 181coalescing, memory management and scheduling within each Streaming 182Multiprocessor and has been designed to handle tiled computations. 183 184However, regardless of the backend you use it’s best to use a benchmark 185and see approach so try out the PyTorch profiler, visually inspect the 186generated kernels and try to see what’s going on for yourself. 187 188Why am I not seeing speedups? 189~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 190 191.. _torch.compiler_graph_breaks: 192 193Graph Breaks 194------------ 195 196The main reason you won’t see the speedups you’d like to by using dynamo 197is excessive graph breaks. So what’s a graph break? 198 199Given a program like: 200 201.. code-block:: python 202 203 def some_fun(x): 204 ... 205 206 torch.compile(some_fun)(x) 207 ... 208 209Torchdynamo will attempt to compile all of the torch/tensor operations 210within ``some_fun()`` into a single FX graph, but it may fail to capture 211everything into one graph. 212 213Some graph break reasons are insurmountable to TorchDynamo like calling 214into a C extension other than PyTorch is invisible to TorchDynamo, and 215could do arbitrary things without TorchDynamo being able to introduce 216necessary guards to ensure that the compiled program would be safe to reuse. 217 218 To maximize performance, it’s important to have as few graph breaks 219 as possible. 220 221Identifying the cause of a graph break 222-------------------------------------- 223 224To identify all graph breaks in a program and the associated reasons for 225the breaks, ``torch._dynamo.explain`` can be used. This tool runs 226TorchDynamo on the supplied function and aggregates the graph breaks 227that are encountered. Here is an example usage: 228 229.. code-block:: python 230 231 import torch 232 import torch._dynamo as dynamo 233 def toy_example(a, b): 234 x = a / (torch.abs(a) + 1) 235 print("woo") 236 if b.sum() < 0: 237 b = b * -1 238 return x * b 239 explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10)) 240 print(explanation) 241 """ 242 Graph Count: 3 243 Graph Break Count: 2 244 Op Count: 5 245 Break Reasons: 246 Break Reason 1: 247 Reason: builtin: print [<class 'torch._dynamo.variables.constant.ConstantVariable'>] False 248 User Stack: 249 <FrameSummary file foo.py, line 5 in toy_example> 250 Break Reason 2: 251 Reason: generic_jump TensorVariable() 252 User Stack: 253 <FrameSummary file foo.py, line 6 in torch_dynamo_resume_in_toy_example_at_5> 254 Ops per Graph: 255 ... 256 Out Guards: 257 ... 258 """ 259 260To throw an error on the first graph break encountered you can 261disable python fallbacks by using ``fullgraph=True``, this should be 262familiar if you’ve worked with export based compilers. 263 264.. code-block:: python 265 266 def toy_example(a, b): 267 ... 268 269 torch.compile(toy_example, fullgraph=True, backend=<compiler>)(a, b) 270 271Why didn’t my code recompile when I changed it? 272----------------------------------------------- 273 274If you enabled dynamic shapes by setting 275``env TORCHDYNAMO_DYNAMIC_SHAPES=1 python model.py`` then your code 276won’t recompile on shape changes. We’ve added support for dynamic shapes 277which avoids recompilations in the case when shapes vary by less than a 278factor of 2. This is especially useful in scenarios like varying image 279sizes in CV or variable sequence length in NLP. In inference scenarios 280it’s often not possible to know what a batch size will be beforehand 281because you take what you can get from different client apps. 282 283In general, TorchDynamo tries very hard not to recompile things 284unnecessarily so if for example TorchDynamo finds 3 graphs and your 285change only modified one graph then only that graph will recompile. So 286another tip to avoid potentially slow compilation times is to warmup a 287model by compiling it once after which subsequent compilations will be 288much faster. Cold start compile times is still a metric we track 289visibly. 290 291Why am I getting incorrect results? 292~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 293 294Accuracy issues can also be minified if you set the environment variable 295``TORCHDYNAMO_REPRO_LEVEL=4``, it operates with a similar git bisect 296model and a full repro might be something like 297``TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4`` the reason 298we need this is downstream compilers will codegen code whether it’s 299Triton code or the C++ backend, the numerics from those downstream 300compilers can be different in subtle ways yet have dramatic impact on 301your training stability. So the accuracy debugger is very useful for us 302to detect bugs in our codegen or with a backend compiler. 303 304If you'd like to ensure that random number generation is the same across both torch 305and triton then you can enable ``torch._inductor.config.fallback_random = True`` 306 307Why am I getting OOMs? 308~~~~~~~~~~~~~~~~~~~~~~ 309 310Dynamo is still an alpha product so there’s a few sources of OOMs and if 311you’re seeing an OOM try disabling the following configurations in this 312order and then open an issue on GitHub so we can solve the root problem 3131. If you’re using dynamic shapes try disabling them, we’ve disabled 314them by default: ``env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py`` 2. 315CUDA graphs with Triton are enabled by default in inductor but removing 316them may alleviate some OOM issues: ``torch._inductor.config.triton.cudagraphs = False``. 317 318Does ``torch.func`` work with ``torch.compile`` (for `grad` and `vmap` transforms)? 319~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 320 321Applying a ``torch.func`` transform to a function that uses ``torch.compile`` 322does work: 323 324.. code-block:: python 325 326 import torch 327 328 @torch.compile 329 def f(x): 330 return torch.sin(x) 331 332 def g(x): 333 return torch.grad(f)(x) 334 335 x = torch.randn(2, 3) 336 g(x) 337 338Calling ``torch.func`` transform inside of a function handled with ``torch.compile`` 339------------------------------------------------------------------------------------ 340 341 342Compiling ``torch.func.grad`` with ``torch.compile`` 343---------------------------------------------------- 344 345.. code-block:: python 346 347 import torch 348 349 def wrapper_fn(x): 350 return torch.func.grad(lambda x: x.sin().sum())(x) 351 352 x = torch.randn(3, 3, 3) 353 grad_x = torch.compile(wrapper_fn)(x) 354 355Compiling ``torch.vmap`` with ``torch.compile`` 356----------------------------------------------- 357 358.. code-block:: python 359 360 import torch 361 362 def my_fn(x): 363 return torch.vmap(lambda x: x.sum(1))(x) 364 365 x = torch.randn(3, 3, 3) 366 output = torch.compile(my_fn)(x) 367 368 369Compiling functions besides the ones which are supported (escape hatch) 370----------------------------------------------------------------------- 371 372For other transforms, as a workaround, use ``torch._dynamo.allow_in_graph`` 373 374``allow_in_graph`` is an escape hatch. If your code does not work with 375``torch.compile``, which introspects Python bytecode, but you believe it 376will work via a symbolic tracing approach (like ``jax.jit``), then use 377``allow_in_graph``. 378 379By using ``allow_in_graph`` to annotate a function, you must make sure 380your code meets the following requirements: 381 382- All outputs in your function only depend on the inputs and 383 do not depend on any captured Tensors. 384- Your function is functional. That is, it does not mutate any state. This may 385 be relaxed; we actually support functions that appear to be functional from 386 the outside: they may have in-place PyTorch operations, but may not mutate 387 global state or inputs to the function. 388- Your function does not raise data-dependent errors. 389 390.. code-block:: python 391 392 import torch 393 394 @torch.compile 395 def f(x): 396 return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x) 397 398 x = torch.randn(2, 3) 399 f(x) 400 401A common pitfall is using ``allow_in_graph`` to annotate a function that 402invokes an ``nn.Module``. This is because the outputs now depend on the 403parameters of the ``nn.Module``. To get this to work, use 404``torch.func.functional_call`` to extract the module state. 405 406Does NumPy work with ``torch.compile``? 407~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 408 409Starting in 2.1, ``torch.compile`` understands native NumPy programs that 410work on NumPy arrays, and mixed PyTorch-NumPy programs that convert from PyTorch 411to NumPy and back via ``x.numpy()``, ``torch.from_numpy``, and related functions. 412 413.. _nonsupported-numpy-feats: 414 415Which NumPy features does ``torch.compile`` support? 416---------------------------------------------------- 417 418NumPy within ``torch.compile`` follows NumPy 2.0 pre-release. 419 420Generally, ``torch.compile`` is able to trace through most NumPy constructions, 421and when it cannot, it falls back to eager and lets NumPy execute that piece of 422code. Even then, there are a few features where ``torch.compile`` semantics 423slightly deviate from those of NumPy: 424 425- NumPy scalars: We model them as 0-D arrays. That is, ``np.float32(3)`` returns 426 a 0-D array under ``torch.compile``. To avoid a graph break, it is best to use this 0-D 427 array. If this breaks your code, you can workaround this by casting the NumPy scalar 428 to the relevant Python scalar type ``bool/int/float``. 429 430- Negative strides: ``np.flip`` and slicing with a negative step return a copy. 431 432- Type promotion: NumPy's type promotion will change in NumPy 2.0. The new rules 433 are described in `NEP 50 <https://numpy.org/neps/nep-0050-scalar-promotion.html)>`__. 434 ``torch.compile`` implements NEP 50 rather than the current soon-to-be deprecated rules. 435 436- ``{tril,triu}_indices_from/{tril,triu}_indices`` return arrays rather than a tuple of arrays. 437 438There are other features for which we do not support tracing and we gracefully 439fallback to NumPy for their execution: 440 441- Non-numeric dtypes like datetimes, strings, chars, void, structured dtypes and recarrays. 442 443- Long dtypes ``np.float128/np.complex256`` and some unsigned dtypes ``np.uint16/np.uint32/np.uint64``. 444 445- ``ndarray`` subclasses. 446 447- Masked arrays. 448 449- Esoteric ufunc machinery like ``axes=[(n,k),(k,m)->(n,m)]`` and ufunc methods (e.g., ``np.add.reduce``). 450 451- Sorting / ordering ``complex64/complex128`` arrays. 452 453- NumPy ``np.poly1d`` and ``np.polynomial``. 454 455- Positional ``out1, out2`` args in functions with 2 or more returns (``out=tuple`` does work). 456 457- ``__array_function__``, ``__array_interface__`` and ``__array_wrap__``. 458 459- ``ndarray.ctypes`` attribute. 460 461Can I compile NumPy code using ``torch.compile``? 462------------------------------------------------- 463 464Of course you do! ``torch.compile`` understands NumPy code natively, and treats it 465as if it were PyTorch code. To do so, simply wrap NumPy code with the ``torch.compile`` 466decorator. 467 468.. code-block:: python 469 470 import torch 471 import numpy as np 472 473 @torch.compile 474 def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray: 475 return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)) 476 477 X = np.random.randn(1024, 64) 478 Y = np.random.randn(1024, 64) 479 Z = numpy_fn(X, Y) 480 assert isinstance(Z, np.ndarray) 481 482Executing this example with the environment variable ``TORCH_LOGS=output_code``, we can see 483that ``torch.compile`` was able to fuse the multiplication and the sum into one C++ kernel. 484It was also able to execute them in parallel using OpenMP (native NumPy is single-threaded). 485This can easily make your NumPy code ``n`` times faster, where ``n`` is the number of cores 486in your processor! 487 488Tracing NumPy code this way also supports graph breaks within the compiled code. 489 490Can I execute NumPy code on CUDA and compute gradients via ``torch.compile``? 491----------------------------------------------------------------------------- 492 493Yes you can! To do so, you may simply execute your code within a ``torch.device("cuda")`` 494context. Consider the example 495 496.. code-block:: python 497 498 import torch 499 import numpy as np 500 501 @torch.compile 502 def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray: 503 return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)) 504 505 X = np.random.randn(1024, 64) 506 Y = np.random.randn(1024, 64) 507 with torch.device("cuda"): 508 Z = numpy_fn(X, Y) 509 assert isinstance(Z, np.ndarray) 510 511In this example, ``numpy_fn`` will be executed in CUDA. For this to be 512possible, ``torch.compile`` automatically moves ``X`` and ``Y`` from CPU 513to CUDA, and then it moves the result ``Z`` from CUDA to CPU. If we are 514executing this function several times in the same program run, we may want 515to avoid all these rather expensive memory copies. To do so, we just need 516to tweak our ``numpy_fn`` so that it accepts cuda Tensors and returns tensors. 517We can do so by using ``torch.compiler.wrap_numpy``: 518 519.. code-block:: python 520 521 @torch.compile(fullgraph=True) 522 @torch.compiler.wrap_numpy 523 def numpy_fn(X, Y): 524 return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)) 525 526 X = torch.randn(1024, 64, device="cuda") 527 Y = torch.randn(1024, 64, device="cuda") 528 Z = numpy_fn(X, Y) 529 assert isinstance(Z, torch.Tensor) 530 assert Z.device.type == "cuda" 531 532Here, we explicitly create the tensors in CUDA memory, and pass them to the 533function, which performs all the computations on the CUDA device. 534``wrap_numpy`` is in charge of marking any ``torch.Tensor`` input as an input 535with ``np.ndarray`` semantics at a ``torch.compile`` level. Marking tensors 536inside the compiler is a very cheap operation, so no data copy or data movement 537happens during runtime. 538 539Using this decorator, we can also differentiate through NumPy code! 540 541.. code-block:: python 542 543 @torch.compile(fullgraph=True) 544 @torch.compiler.wrap_numpy 545 def numpy_fn(X, Y): 546 return np.mean(np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))) 547 548 X = torch.randn(1024, 64, device="cuda", requires_grad=True) 549 Y = torch.randn(1024, 64, device="cuda") 550 Z = numpy_fn(X, Y) 551 assert isinstance(Z, torch.Tensor) 552 Z.backward() 553 # X.grad now holds the gradient of the computation 554 print(X.grad) 555 556We have been using ``fullgraph=True`` as graph break are problematic in this context. 557When a graph break occurs, we need to materialize the NumPy arrays. Since NumPy arrays 558do not have a notion of ``device`` or ``requires_grad``, this information is lost during 559a graph break. 560 561We cannot propagate gradients through a graph break, as the graph break code may execute 562arbitrary code that don't know how to differentiate. On the other hand, in the case of 563the CUDA execution, we can work around this problem as we did in the first example, by 564using the ``torch.device("cuda")`` context manager: 565 566.. code-block:: python 567 568 @torch.compile 569 @torch.compiler.wrap_numpy 570 def numpy_fn(X, Y): 571 prod = X[:, :, None] * Y[:, None, :] 572 print("oops, a graph break!") 573 return np.sum(prod, axis=(-2, -1)) 574 575 X = torch.randn(1024, 64, device="cuda") 576 Y = torch.randn(1024, 64, device="cuda") 577 578 with torch.device("cuda"): 579 Z = numpy_fn(X, Y) 580 assert isinstance(Z, torch.Tensor) 581 assert Z.device.type == "cuda" 582 583During the graph break, the intermediary tensors still need to be moved to CPU, but when the 584tracing is resumed after the graph break, the rest of the graph is still traced on CUDA. 585Given this CUDA <> CPU and CPU <> CUDA movement, graph breaks are fairly costly in the NumPy 586context and should be avoided, but at least they allow tracing through complex pieces of code. 587 588 589How do I debug NumPy code under ``torch.compile``? 590-------------------------------------------------- 591 592Debugging JIT compiled code is challenging, given the complexity of modern 593compilers and the daunting errors that they raise. 594`The tutorial on how to diagnose runtime errors within torch.compile <https://pytorch.org/docs/main/torch.compiler_troubleshooting.html#diagnosing-runtime-errors>`__ 595contains a few tips and tricks on how to tackle this task. 596 597If the above is not enough to pinpoint the origin of the issue, there are still 598a few other NumPy-specific tools we can use. We can discern whether the bug 599is entirely in the PyTorch code by disabling tracing through NumPy functions: 600 601 602.. code-block:: python 603 604 from torch._dynamo import config 605 config.trace_numpy = False 606 607If the bug lies in the traced NumPy code, we can execute the NumPy code eagerly (without ``torch.compile``) 608using PyTorch as a backend by importing ``import torch._numpy as np``. 609This should just be used for **debugging purposes** and is in no way a 610replacement for the PyTorch API, as it is **much less performant** and, as a 611private API, **may change without notice**. At any rate, ``torch._numpy`` is a 612Python implementation of NumPy in terms of PyTorch and it is used internally by ``torch.compile`` to 613transform NumPy code into Pytorch code. It is rather easy to read and modify, 614so if you find any bug in it feel free to submit a PR fixing it or simply open 615an issue. 616 617If the program does work when importing ``torch._numpy as np``, chances are 618that the bug is in TorchDynamo. If this is the case, please feel open an issue 619with a `minimal reproducer <https://pytorch.org/docs/2.1/torch.compiler_troubleshooting.html>`__. 620 621I ``torch.compile`` some NumPy code and I did not see any speed-up. 622------------------------------------------------------------------- 623 624The best place to start is the 625`tutorial with general advice for how to debug these sort of torch.compile issues <https://pytorch.org/docs/main/torch.compiler_faq.html#why-am-i-not-seeing-speedups>`__. 626 627Some graph breaks may happen because of the use of unsupported features. See 628:ref:`nonsupported-numpy-feats`. More generally, it is useful to keep in mind 629that some widely used NumPy features do not play well with compilers. For 630example, in-place modifications make reasoning difficult within the compiler and 631often yield worse performance than their out-of-place counterparts.As such, it is best to avoid 632them. Same goes for the use of the ``out=`` parameter. Instead, prefer 633out-of-place ops and let ``torch.compile`` optimize the memory use. Same goes 634for data-dependent ops like masked indexing through boolean masks, or 635data-dependent control flow like ``if`` or ``while`` constructions. 636 637 638Which API to use for fine grain tracing? 639~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 640 641In some cases, you might need to exclude small parts of your code from the 642torch.compile compilations. This section provides some of the answers and 643you can find more information in :ref:`torchdynamo_fine_grain_tracing`. 644 645How do I graph break on a function? 646----------------------------------- 647 648Graph break on a function is not enough to sufficiently express what you want 649PyTorch to do. You need to be more specific about your use case. Some of the 650most common use cases you might want to consider: 651 652* If you want to disable compilation on this function frame and the recursively 653 invoked frames, use ``torch._dynamo.disable``. 654 655* If you want a particular operator, such as ``fbgemm`` to use the eager mode, 656 use ``torch._dynamo.disallow_in_graph``. 657 658Some of the uncommon use cases include: 659 660* If you want to disable TorchDynamo on the function frame but enable it back 661 on the recursively invoked frames – use ``torch._dynamo.disable(recursive=False)``. 662 663* If you want to prevent inlining of a function frame – use ``torch._dynamo.graph_break`` 664 at the beginning of the function you want to prevent inlining. 665 666What's the difference between ``torch._dynamo.disable`` and ``torch._dynamo.disallow_in_graph`` 667----------------------------------------------------------------------------------------------- 668 669Disallow-in-graph works at the level of operators, or more specifically, 670the operators that you see in the TorchDynamo extracted graphs. 671 672Disable works at the function frame level and decides if TorchDynamo 673should look into the function frame or not. 674 675What's the difference between ``torch._dynamo.disable`` and ``torch._dynamo_skip`` 676---------------------------------------------------------------------------------- 677 678.. note:: 679 ``torch._dynamo_skip`` is deprecated. 680 681You most likely need ``torch._dynamo.disable``. But in an unlikely scenario, you 682might need even finer control. Suppose you want to disable the tracing on just 683the ``a_fn`` function, but want to continue the tracing back in ``aa_fn`` and 684``ab_fn``. The image below demonstrates this use case: 685 686 687.. figure:: _static/img/fine_grained_apis/call_stack_diagram.png 688 :alt: diagram of torch.compile + disable(a_fn, recursive=False) 689 690In this case, you can use ``torch._dynamo.disable(recursive=False)``. 691In previous versions, this functionality was provided by ``torch._dynamo.skip``. 692This is now supported by the ``recursive`` flag inside ``torch._dynamo.disable``. 693