Frequently Asked Questions
==========================
**Author**: `Mark Saroufim <https://github.com/msaroufim>`_

Does ``torch.compile`` support training?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

``torch.compile`` supports training, using AOTAutograd to capture backwards
(a minimal example follows the list below):

1. The ``.forward()`` graph and ``optimizer.step()`` are captured by
   TorchDynamo’s python ``evalframe`` frontend.
2. For each segment of ``.forward()`` that TorchDynamo captures, it uses
   AOTAutograd to generate a backward graph segment.
3. Each pair of forward and backward graphs is (optionally) min-cut
   partitioned to save the minimal state between forward and backward.
4. The forward and backward pairs are wrapped in ``autograd.function`` modules.
5. User code calling ``.backward()`` still triggers eager’s autograd engine,
   which runs each *compiled backward* graph as if it were one op, also running
   any non-compiled eager ops’ ``.backward()`` functions.

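A minimal training sketch looks the same as eager training; only the
``torch.compile`` call is new. The model, loss, and optimizer below are
placeholders chosen for illustration:

.. code-block:: python

   import torch

   model = torch.nn.Linear(10, 10)          # any nn.Module works here
   optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
   compiled_model = torch.compile(model)    # forward is captured by TorchDynamo

   x = torch.randn(16, 10)
   y = torch.randn(16, 10)

   loss = torch.nn.functional.mse_loss(compiled_model(x), y)
   loss.backward()    # runs the AOTAutograd-generated backward graphs
   optimizer.step()
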
Do you support Distributed code?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

``torch.compile`` supports ``DistributedDataParallel`` (DDP).
Support for other distributed training libraries is being considered.

The main reason distributed code is challenging with dynamo is that
AOTAutograd unrolls both the forward and backward pass and provides 2
graphs for backends to optimize. This is a problem for distributed code
because we’d ideally like to overlap communication operations with
computations. Eager PyTorch accomplishes this in different ways for
DDP/FSDP, using autograd hooks, module hooks, and modifications/mutations
of module states. In a naive application of dynamo, hooks that should run
directly after an operation during backwards may be delayed until after
the entire compiled region of backwards ops, due to how AOTAutograd
compiled functions interact with dispatcher hooks.

The basic strategy for optimizing DDP with Dynamo is outlined in
`distributed.py <https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/backends/distributed.py>`__
where the main idea is to graph break on `DDP bucket
boundaries <https://pytorch.org/docs/stable/notes/ddp.html#internal-design>`__.

When each node in DDP needs to synchronize its weights with the other
nodes, it organizes its gradients and parameters into buckets, which
reduces communication times and allows a node to broadcast a fraction of
its gradients to other waiting nodes.

Graph breaks in distributed code mean you can expect dynamo and its
backends to optimize the compute overhead of a distributed program but
not its communication overhead. Graph breaks may interfere with
compilation speedups if the reduced graph size robs the compiler of
fusion opportunities. However, there are diminishing returns with
increasing graph size since most of the current compute optimizations
are local fusions. So in practice this approach may be sufficient.

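In practice, using ``torch.compile`` with DDP looks like ordinary DDP usage; a
common pattern is to wrap the module in DDP first and then compile it. The
single-process ``gloo`` setup below is only for illustration (a real job would
be launched with ``torchrun`` and typically use NCCL on GPUs):

.. code-block:: python

   import os
   import torch
   import torch.distributed as dist
   from torch.nn.parallel import DistributedDataParallel as DDP

   # Minimal single-process process group, just so the sketch runs end to end.
   os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
   os.environ.setdefault("MASTER_PORT", "29500")
   dist.init_process_group("gloo", rank=0, world_size=1)

   model = torch.nn.Linear(10, 10)
   ddp_model = DDP(model)                     # DDP wraps the module first ...
   compiled_model = torch.compile(ddp_model)  # ... then torch.compile wraps DDP

   out = compiled_model(torch.randn(16, 10))
   out.sum().backward()

   dist.destroy_process_group()
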
Do I still need to export whole graphs?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

For the vast majority of models you probably don’t, and you can use
``torch.compile()`` as is, but there are a few situations where
full graphs are necessary, and you can ensure a full graph by simply
running ``torch.compile(..., fullgraph=True)``. These situations include:

* Large scale training runs, such as $250K+ runs, that require pipeline parallelism
  and other advanced sharding strategies.

* Inference optimizers like `TensorRT <https://github.com/pytorch/TensorRT>`__
  or `AITemplate <https://github.com/facebookincubator/AITemplate>`__ that
  rely on fusing much more aggressively than training optimizers.

* Mobile training or inference.

Future work will include tracing communication operations into graphs,
coordinating these operations with compute optimizations, and optimizing
the communication operations.

Why is my code crashing?
~~~~~~~~~~~~~~~~~~~~~~~~

If your code ran just fine without ``torch.compile`` and started to
crash when it was enabled, then the most important first step is figuring
out in which part of the stack your failure occurred. To troubleshoot that,
follow the steps below, and only try the next step if the previous one
succeeded; a runnable sketch follows the list.

1. ``torch.compile(..., backend="eager")`` which only runs TorchDynamo
   forward graph capture and then runs the captured graph with PyTorch.
   If this fails then there’s an issue with TorchDynamo.

2. ``torch.compile(..., backend="aot_eager")``
   which runs TorchDynamo to capture a forward graph, and then AOTAutograd
   to trace the backward graph without any additional backend compiler
   steps. PyTorch eager will then be used to run the forward and backward
   graphs. If this fails then there’s an issue with AOTAutograd.

3. ``torch.compile(..., backend="inductor")`` which runs TorchDynamo to capture a
   forward graph, and then AOTAutograd to trace the backward graph, with
   TorchInductor compiling both. If this fails then there’s an issue with TorchInductor.

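A sketch of stepping through these backends (``fn`` and its input are
placeholders; substitute the code that is actually crashing):

.. code-block:: python

   import torch
   import torch._dynamo

   def fn(x):
       return torch.sin(x) + torch.cos(x)

   x = torch.randn(8)

   for backend in ("eager", "aot_eager", "inductor"):
       torch._dynamo.reset()  # clear cached compilations before switching backends
       torch.compile(fn, backend=backend)(x)
       print(f"backend {backend} succeeded")
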
Why is compilation slow?
~~~~~~~~~~~~~~~~~~~~~~~~

* **Dynamo Compilation**: TorchDynamo has a builtin stats function for
  collecting and displaying the time spent in each compilation phase.
  These stats can be accessed by calling ``torch._dynamo.utils.compile_times()``
  after executing the compiled code (see the sketch after this list). By default,
  this returns a string representation of the compile time spent in each
  TorchDynamo function, by name.

* **Inductor Compilation**: TorchInductor has a builtin stats and trace function
  for displaying time spent in each compilation phase, output code, output
  graph visualization and IR dump. It is enabled by running
  ``env TORCH_COMPILE_DEBUG=1 python repro.py``.
  This is a debugging tool designed to make it easier to debug/understand the
  internals of TorchInductor, with an output that will look something like
  `this <https://gist.github.com/jansel/f4af078791ad681a0d4094adeb844396>`__.
  Each file in that debug trace can be enabled/disabled via
  ``torch._inductor.config.trace.*``. The profile and the diagram are both
  disabled by default since they are expensive to generate. See the
  `example debug directory
  output <https://gist.github.com/jansel/f4af078791ad681a0d4094adeb844396>`__
  for more examples.

* **Excessive Recompilation**:
  When TorchDynamo compiles a function (or part of one), it makes certain
  assumptions about locals and globals in order to allow compiler
  optimizations, and expresses these assumptions as guards that check
  particular values at runtime. If any of these guards fail, Dynamo will
  recompile that function (or part) up to
  ``torch._dynamo.config.cache_size_limit`` times. If your program is
  hitting the cache limit, you will first need to determine which guard is
  failing and what part of your program is triggering it. The
  `recompilation profiler <#recompilation-profiler>`__ automates the
  process of setting TorchDynamo’s cache limit to 1 and running your
  program under an observation-only ‘compiler’ that records the causes of
  any guard failures. You should be sure to run your program for at least
  as long (as many iterations) as you were running when you ran into
  trouble, and the profiler will accumulate statistics over this duration.

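A small sketch of reading the Dynamo compilation stats (the function being
compiled is a placeholder):

.. code-block:: python

   import torch
   import torch._dynamo.utils

   @torch.compile
   def fn(x):
       return torch.relu(x) + 1

   fn(torch.randn(8))

   # Per-function breakdown of where TorchDynamo compilation time was spent.
   print(torch._dynamo.utils.compile_times())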

Why are you recompiling in production?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In some cases, you may not want unexpected compiles after a program has
warmed up, for example, if you are serving production traffic in a
latency-critical application. For this, TorchDynamo provides an
alternate mode where prior compiled graphs are used, but no new ones are
generated:

.. code-block:: python

   import torch._dynamo as dynamo

   frozen_toy_example = dynamo.run(toy_example)  # toy_example was compiled and warmed up earlier
   frozen_toy_example(torch.randn(10), torch.randn(10))

How are you speeding up my code?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

There are 3 major ways to accelerate PyTorch code (a small fusion example
follows this list):

1. Kernel fusion. Vertical fusion fuses sequential operations to avoid
   excessive reads/writes. For example, fusing 2 subsequent cosines means you
   can do 1 read and 1 write instead of 2 reads and 2 writes. Horizontal fusion:
   the simplest example is batching, where a single matrix is multiplied
   with a batch of examples, but the more general scenario is a grouped GEMM,
   where a group of matrix multiplications is scheduled together.

2. Out of order execution: A general optimization for compilers. By looking ahead
   at the exact data dependencies within a graph, we can decide on the most
   opportune time to execute a node and which buffers can be reused.

3. Automatic work placement: Similar to the out of order execution point,
   but by matching nodes of a graph to resources like physical hardware or
   memory, we can design an appropriate schedule.

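As a small illustration of vertical fusion, the two chained pointwise ops in
this made-up function can be compiled into a single generated kernel; you can
inspect the generated code by running with the ``TORCH_LOGS=output_code``
environment variable:

.. code-block:: python

   import torch

   @torch.compile
   def fused_cosines(x):
       # Two sequential pointwise ops: a fusing backend can read x once,
       # compute cos(cos(x)), and write the result once.
       return torch.cos(torch.cos(x))

   fused_cosines(torch.randn(1024))
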
The above are general principles for accelerating PyTorch code, but
different backends will each make different tradeoffs on what to
optimize. For example, Inductor first takes care of fusing whatever it
can and only then generates `Triton <https://openai.com/blog/triton/>`__
kernels.

In addition, Triton offers speedups because of automatic memory
coalescing, memory management and scheduling within each Streaming
Multiprocessor, and has been designed to handle tiled computations.

However, regardless of the backend you use, it’s best to take a
benchmark-and-see approach: try out the PyTorch profiler, visually inspect
the generated kernels, and try to see what’s going on for yourself.

Why am I not seeing speedups?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. _torch.compiler_graph_breaks:

Graph Breaks
------------

The main reason you won’t see the speedups you’d like from dynamo
is excessive graph breaks. So what’s a graph break?

Given a program like:

.. code-block:: python

   def some_fun(x):
       ...

   torch.compile(some_fun)(x)
   ...

TorchDynamo will attempt to compile all of the torch/tensor operations
within ``some_fun()`` into a single FX graph, but it may fail to capture
everything into one graph.

Some graph break reasons are insurmountable to TorchDynamo. For example,
calling into a C extension other than PyTorch is invisible to TorchDynamo,
and could do arbitrary things without TorchDynamo being able to introduce
the necessary guards to ensure that the compiled program would be safe to reuse.

   To maximize performance, it’s important to have as few graph breaks
   as possible.

Identifying the cause of a graph break
--------------------------------------

To identify all graph breaks in a program and the associated reasons for
the breaks, ``torch._dynamo.explain`` can be used. This tool runs
TorchDynamo on the supplied function and aggregates the graph breaks
that are encountered. Here is an example usage:

.. code-block:: python

   import torch
   import torch._dynamo as dynamo
   def toy_example(a, b):
       x = a / (torch.abs(a) + 1)
       print("woo")
       if b.sum() < 0:
           b = b * -1
       return x * b
   explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
   print(explanation)
   """
   Graph Count: 3
   Graph Break Count: 2
   Op Count: 5
   Break Reasons:
     Break Reason 1:
       Reason: builtin: print [<class 'torch._dynamo.variables.constant.ConstantVariable'>] False
       User Stack:
         <FrameSummary file foo.py, line 5 in toy_example>
     Break Reason 2:
       Reason: generic_jump TensorVariable()
       User Stack:
         <FrameSummary file foo.py, line 6 in torch_dynamo_resume_in_toy_example_at_5>
   Ops per Graph:
     ...
   Out Guards:
     ...
   """

To throw an error on the first graph break encountered, you can
disable python fallbacks by using ``fullgraph=True``. This should be
familiar if you’ve worked with export based compilers.

.. code-block:: python

   def toy_example(a, b):
      ...

   torch.compile(toy_example, fullgraph=True, backend=<compiler>)(a, b)

Why didn’t my code recompile when I changed it?
-----------------------------------------------

If you enabled dynamic shapes by setting
``env TORCHDYNAMO_DYNAMIC_SHAPES=1 python model.py``, then your code
won’t recompile on shape changes. We’ve added support for dynamic shapes,
which avoids recompilations in the case when shapes vary by less than a
factor of 2. This is especially useful in scenarios like varying image
sizes in CV or variable sequence lengths in NLP. In inference scenarios,
it’s often not possible to know what a batch size will be beforehand,
because you take what you can get from different client apps.

In general, TorchDynamo tries very hard not to recompile things
unnecessarily, so if, for example, TorchDynamo finds 3 graphs and your
change only modified one graph, then only that graph will recompile. So
another tip to avoid potentially slow compilation times is to warm up a
model by compiling it once, after which subsequent compilations will be
much faster. Cold start compile time is still a metric we track closely.

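Depending on your PyTorch version, dynamic shapes can also be requested
directly through the ``dynamic=True`` argument to ``torch.compile`` instead of
the environment variable. A minimal sketch; you can check for recompilations by
running with ``TORCH_LOGS=recompiles``:

.. code-block:: python

   import torch

   @torch.compile(dynamic=True)
   def fn(x):
       return x.sum(dim=-1)

   # Varying the sequence length should not trigger a fresh compile per shape.
   for seq_len in (8, 12, 17, 33):
       fn(torch.randn(2, seq_len))
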
Why am I getting incorrect results?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Accuracy issues can also be minified if you set the environment variable
``TORCHDYNAMO_REPRO_LEVEL=4``; it operates with a similar git bisect
model, and a full repro might be something like
``TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4``. The reason
we need this is that downstream compilers will generate code, whether it’s
Triton code or the C++ backend, and the numerics from those downstream
compilers can be different in subtle ways yet have a dramatic impact on
your training stability. So the accuracy debugger is very useful for us
to detect bugs in our codegen or with a backend compiler.

If you'd like to ensure that random number generation is the same across both torch
and triton, then you can enable ``torch._inductor.config.fallback_random = True``.

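A quick hand-rolled check (not the minifier itself) for whether the compiled
numerics diverge from eager; the function here is a placeholder:

.. code-block:: python

   import torch
   import torch._inductor.config

   # If the function uses random ops, this keeps RNG identical between eager and inductor.
   torch._inductor.config.fallback_random = True

   def fn(x):
       return torch.sin(x).sum()

   x = torch.randn(1024)
   eager_out = fn(x)
   compiled_out = torch.compile(fn)(x)
   print(torch.allclose(eager_out, compiled_out, rtol=1e-6, atol=1e-6))
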
Why am I getting OOMs?
~~~~~~~~~~~~~~~~~~~~~~

Dynamo is still an alpha product, so there are a few sources of OOMs. If
you’re seeing an OOM, try disabling the following configurations in this
order and then open an issue on GitHub so we can solve the root problem:

1. If you’re using dynamic shapes, try disabling them; we’ve disabled
   them by default: ``env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py``.
2. CUDA graphs with Triton are enabled by default in inductor but removing
   them may alleviate some OOM issues: ``torch._inductor.config.triton.cudagraphs = False``.

Does ``torch.func`` work with ``torch.compile`` (for `grad` and `vmap` transforms)?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Applying a ``torch.func`` transform to a function that uses ``torch.compile``
does work:

.. code-block:: python

    import torch

    @torch.compile
    def f(x):
        # grad requires a scalar output, so reduce to a scalar here.
        return torch.sin(x).sum()

    def g(x):
        return torch.func.grad(f)(x)

    x = torch.randn(2, 3)
    g(x)

Calling ``torch.func`` transform inside of a function handled with ``torch.compile``
------------------------------------------------------------------------------------

This also works; the next two sections show examples of compiling functions
that call ``torch.func.grad`` and ``torch.vmap`` internally.

Compiling ``torch.func.grad`` with ``torch.compile``
----------------------------------------------------

.. code-block:: python

    import torch

    def wrapper_fn(x):
        return torch.func.grad(lambda x: x.sin().sum())(x)

    x = torch.randn(3, 3, 3)
    grad_x = torch.compile(wrapper_fn)(x)

Compiling ``torch.vmap`` with ``torch.compile``
-----------------------------------------------

.. code-block:: python

    import torch

    def my_fn(x):
        return torch.vmap(lambda x: x.sum(1))(x)

    x = torch.randn(3, 3, 3)
    output = torch.compile(my_fn)(x)


Compiling functions besides the ones which are supported (escape hatch)
-----------------------------------------------------------------------

For other transforms, as a workaround, use ``torch._dynamo.allow_in_graph``.

``allow_in_graph`` is an escape hatch. If your code does not work with
``torch.compile``, which introspects Python bytecode, but you believe it
will work via a symbolic tracing approach (like ``jax.jit``), then use
``allow_in_graph``.

When using ``allow_in_graph`` to annotate a function, you must make sure
your code meets the following requirements:

- All outputs in your function only depend on the inputs and
  do not depend on any captured Tensors.
- Your function is functional. That is, it does not mutate any state. This may
  be relaxed; we actually support functions that appear to be functional from
  the outside: they may have in-place PyTorch operations, but may not mutate
  global state or inputs to the function.
- Your function does not raise data-dependent errors.

.. code-block:: python

    import torch

    @torch.compile
    def f(x):
        return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x)

    x = torch.randn(2, 3)
    f(x)

A common pitfall is using ``allow_in_graph`` to annotate a function that
invokes an ``nn.Module``. This is because the outputs now depend on the
parameters of the ``nn.Module``. To get this to work, use
``torch.func.functional_call`` to extract the module state.

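A minimal sketch of that ``functional_call`` pattern; the module and function
names are illustrative, and this only shows the state-extraction piece:

.. code-block:: python

    import torch
    from torch.func import functional_call

    mod = torch.nn.Linear(3, 3)

    def forward_with_explicit_state(params_and_buffers, x):
        # All outputs now depend only on explicit inputs, not captured module state.
        return functional_call(mod, params_and_buffers, (x,))

    state = {**dict(mod.named_parameters()), **dict(mod.named_buffers())}
    out = forward_with_explicit_state(state, torch.randn(2, 3))
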
Does NumPy work with ``torch.compile``?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Starting in 2.1, ``torch.compile`` understands native NumPy programs that
work on NumPy arrays, and mixed PyTorch-NumPy programs that convert from PyTorch
to NumPy and back via ``x.numpy()``, ``torch.from_numpy``, and related functions.

.. _nonsupported-numpy-feats:

Which NumPy features does ``torch.compile`` support?
----------------------------------------------------

NumPy within ``torch.compile`` follows the NumPy 2.0 pre-release.

Generally, ``torch.compile`` is able to trace through most NumPy constructions,
and when it cannot, it falls back to eager and lets NumPy execute that piece of
code. Even then, there are a few features where ``torch.compile`` semantics
slightly deviate from those of NumPy:

- NumPy scalars: We model them as 0-D arrays. That is, ``np.float32(3)`` returns
  a 0-D array under ``torch.compile``. To avoid a graph break, it is best to use this 0-D
  array. If this breaks your code, you can work around this by casting the NumPy scalar
  to the relevant Python scalar type ``bool/int/float``.

- Negative strides: ``np.flip`` and slicing with a negative step return a copy.

- Type promotion: NumPy's type promotion will change in NumPy 2.0. The new rules
  are described in `NEP 50 <https://numpy.org/neps/nep-0050-scalar-promotion.html>`__.
  ``torch.compile`` implements NEP 50 rather than the current soon-to-be-deprecated rules.

- ``{tril,triu}_indices_from/{tril,triu}_indices`` return arrays rather than a tuple of arrays.

There are other features for which we do not support tracing and we gracefully
fall back to NumPy for their execution:

- Non-numeric dtypes like datetimes, strings, chars, void, structured dtypes and recarrays.

- Long dtypes ``np.float128/np.complex256`` and some unsigned dtypes ``np.uint16/np.uint32/np.uint64``.

- ``ndarray`` subclasses.

- Masked arrays.

- Esoteric ufunc machinery like ``axes=[(n,k),(k,m)->(n,m)]`` and ufunc methods (e.g., ``np.add.reduce``).

- Sorting / ordering ``complex64/complex128`` arrays.

- NumPy ``np.poly1d`` and ``np.polynomial``.

- Positional ``out1, out2`` args in functions with 2 or more returns (``out=tuple`` does work).

- ``__array_function__``, ``__array_interface__`` and ``__array_wrap__``.

- ``ndarray.ctypes`` attribute.

Can I compile NumPy code using ``torch.compile``?
-------------------------------------------------

Of course you can! ``torch.compile`` understands NumPy code natively, and treats it
as if it were PyTorch code. To do so, simply wrap NumPy code with the ``torch.compile``
decorator.

.. code-block:: python

   import torch
   import numpy as np

   @torch.compile
   def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
       return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

   X = np.random.randn(1024, 64)
   Y = np.random.randn(1024, 64)
   Z = numpy_fn(X, Y)
   assert isinstance(Z, np.ndarray)

Executing this example with the environment variable ``TORCH_LOGS=output_code``, we can see
that ``torch.compile`` was able to fuse the multiplication and the sum into one C++ kernel.
It was also able to execute them in parallel using OpenMP (native NumPy is single-threaded).
This can easily make your NumPy code ``n`` times faster, where ``n`` is the number of cores
in your processor!

Tracing NumPy code this way also supports graph breaks within the compiled code.

Can I execute NumPy code on CUDA and compute gradients via ``torch.compile``?
-----------------------------------------------------------------------------

Yes, you can! To do so, you may simply execute your code within a ``torch.device("cuda")``
context. Consider the following example:

.. code-block:: python

   import torch
   import numpy as np

   @torch.compile
   def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
       return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

   X = np.random.randn(1024, 64)
   Y = np.random.randn(1024, 64)
   with torch.device("cuda"):
       Z = numpy_fn(X, Y)
   assert isinstance(Z, np.ndarray)

In this example, ``numpy_fn`` will be executed on CUDA. For this to be
possible, ``torch.compile`` automatically moves ``X`` and ``Y`` from CPU
to CUDA, and then it moves the result ``Z`` from CUDA to CPU. If we are
executing this function several times in the same program run, we may want
to avoid all these rather expensive memory copies. To do so, we just need
to tweak our ``numpy_fn`` so that it accepts CUDA tensors and returns tensors.
We can do so by using ``torch.compiler.wrap_numpy``:

.. code-block:: python

   @torch.compile(fullgraph=True)
   @torch.compiler.wrap_numpy
   def numpy_fn(X, Y):
       return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

   X = torch.randn(1024, 64, device="cuda")
   Y = torch.randn(1024, 64, device="cuda")
   Z = numpy_fn(X, Y)
   assert isinstance(Z, torch.Tensor)
   assert Z.device.type == "cuda"

Here, we explicitly create the tensors in CUDA memory, and pass them to the
function, which performs all the computations on the CUDA device.
``wrap_numpy`` is in charge of marking any ``torch.Tensor`` input as an input
with ``np.ndarray`` semantics at a ``torch.compile`` level. Marking tensors
inside the compiler is a very cheap operation, so no data copy or data movement
happens during runtime.

Using this decorator, we can also differentiate through NumPy code!

.. code-block:: python

   @torch.compile(fullgraph=True)
   @torch.compiler.wrap_numpy
   def numpy_fn(X, Y):
       return np.mean(np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)))

   X = torch.randn(1024, 64, device="cuda", requires_grad=True)
   Y = torch.randn(1024, 64, device="cuda")
   Z = numpy_fn(X, Y)
   assert isinstance(Z, torch.Tensor)
   Z.backward()
   # X.grad now holds the gradient of the computation
   print(X.grad)

We have been using ``fullgraph=True`` as graph breaks are problematic in this context.
When a graph break occurs, we need to materialize the NumPy arrays. Since NumPy arrays
do not have a notion of ``device`` or ``requires_grad``, this information is lost during
a graph break.

We cannot propagate gradients through a graph break, as the graph break code may execute
arbitrary code that we don't know how to differentiate. On the other hand, in the case of
CUDA execution, we can work around this problem as we did in the first example, by
using the ``torch.device("cuda")`` context manager:

.. code-block:: python

   @torch.compile
   @torch.compiler.wrap_numpy
   def numpy_fn(X, Y):
       prod = X[:, :, None] * Y[:, None, :]
       print("oops, a graph break!")
       return np.sum(prod, axis=(-2, -1))

   X = torch.randn(1024, 64, device="cuda")
   Y = torch.randn(1024, 64, device="cuda")

   with torch.device("cuda"):
       Z = numpy_fn(X, Y)
   assert isinstance(Z, torch.Tensor)
   assert Z.device.type == "cuda"

During the graph break, the intermediate tensors still need to be moved to CPU, but when
tracing is resumed after the graph break, the rest of the graph is still traced on CUDA.
Given this CUDA <> CPU and CPU <> CUDA movement, graph breaks are fairly costly in the NumPy
context and should be avoided, but at least they allow tracing through complex pieces of code.


How do I debug NumPy code under ``torch.compile``?
--------------------------------------------------

Debugging JIT compiled code is challenging, given the complexity of modern
compilers and the daunting errors that they raise.
`The tutorial on how to diagnose runtime errors within torch.compile <https://pytorch.org/docs/main/torch.compiler_troubleshooting.html#diagnosing-runtime-errors>`__
contains a few tips and tricks on how to tackle this task.

If the above is not enough to pinpoint the origin of the issue, there are still
a few other NumPy-specific tools we can use. We can discern whether the bug
is entirely in the PyTorch code by disabling tracing through NumPy functions:

.. code-block:: python

   from torch._dynamo import config
   config.trace_numpy = False

If the bug lies in the traced NumPy code, we can execute the NumPy code eagerly
(without ``torch.compile``) using PyTorch as a backend, via ``import torch._numpy as np``.
This should just be used for **debugging purposes** and is in no way a
replacement for the PyTorch API, as it is **much less performant** and, as a
private API, **may change without notice**. At any rate, ``torch._numpy`` is a
Python implementation of NumPy in terms of PyTorch and it is used internally by ``torch.compile`` to
transform NumPy code into PyTorch code. It is rather easy to read and modify,
so if you find any bug in it, feel free to submit a PR fixing it or simply open
an issue.

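For example, assuming the snippet under suspicion only uses operations that
``torch._numpy`` implements, a debugging run might look like the following
(a hypothetical check against real NumPy, debugging only):

.. code-block:: python

   import torch._numpy as np  # private, slower, and may change without notice

   x = np.arange(6).reshape(2, 3)
   # Compare this against the same snippet run with real NumPy; a mismatch
   # points at torch._numpy rather than TorchDynamo.
   print(np.sum(x, axis=0))
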
If the program does work when importing ``torch._numpy as np``, chances are
that the bug is in TorchDynamo. If this is the case, please feel free to open an issue
with a `minimal reproducer <https://pytorch.org/docs/2.1/torch.compiler_troubleshooting.html>`__.

I ``torch.compile`` some NumPy code and I did not see any speed-up.
-------------------------------------------------------------------

The best place to start is the
`tutorial with general advice for how to debug these sorts of torch.compile issues <https://pytorch.org/docs/main/torch.compiler_faq.html#why-am-i-not-seeing-speedups>`__.

Some graph breaks may happen because of the use of unsupported features. See
:ref:`nonsupported-numpy-feats`. More generally, it is useful to keep in mind
that some widely used NumPy features do not play well with compilers. For
example, in-place modifications make reasoning difficult within the compiler and
often yield worse performance than their out-of-place counterparts. As such, it is
best to avoid them. The same goes for the use of the ``out=`` parameter. Instead,
prefer out-of-place ops and let ``torch.compile`` optimize the memory use. The same
goes for data-dependent ops like masked indexing through boolean masks, or
data-dependent control flow like ``if`` or ``while`` constructions.

Which API to use for fine grain tracing?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In some cases, you might need to exclude small parts of your code from the
torch.compile compilations. This section provides some of the answers and
you can find more information in :ref:`torchdynamo_fine_grain_tracing`.

How do I graph break on a function?
-----------------------------------

A graph break on a function is, by itself, not specific enough to express what
you want PyTorch to do. You need to be more specific about your use case. Some
of the most common use cases you might want to consider (see the sketch after
these lists):

* If you want to disable compilation on this function frame and the recursively
  invoked frames, use ``torch._dynamo.disable``.

* If you want a particular operator, such as ``fbgemm``, to use eager mode,
  use ``torch._dynamo.disallow_in_graph``.

Some of the uncommon use cases include:

* If you want to disable TorchDynamo on the function frame but enable it back
  on the recursively invoked frames – use ``torch._dynamo.disable(recursive=False)``.

* If you want to prevent inlining of a function frame – use ``torch._dynamo.graph_break``
  at the beginning of the function you want to prevent inlining.

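A small sketch combining these two APIs; the helper function and the choice of
``torch.sub`` as the disallowed operator are only for illustration:

.. code-block:: python

   import torch
   import torch._dynamo

   @torch._dynamo.disable
   def debug_helper(x):
       # This frame (and anything it calls) always runs in eager mode.
       print("shape:", x.shape)
       return x

   # Force a particular operator to run in eager mode inside compiled regions.
   torch._dynamo.disallow_in_graph(torch.sub)

   @torch.compile
   def fn(x, y):
       x = debug_helper(x)     # graph break: disabled function
       return torch.sub(x, y)  # graph break: disallowed operator

   fn(torch.randn(4), torch.randn(4))
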
What's the difference between ``torch._dynamo.disable`` and ``torch._dynamo.disallow_in_graph``
------------------------------------------------------------------------------------------------

Disallow-in-graph works at the level of operators, or more specifically,
the operators that you see in the TorchDynamo extracted graphs.

Disable works at the function frame level and decides if TorchDynamo
should look into the function frame or not.

What's the difference between ``torch._dynamo.disable`` and ``torch._dynamo.skip``
------------------------------------------------------------------------------------

.. note::
   ``torch._dynamo.skip`` is deprecated.

You most likely need ``torch._dynamo.disable``. But in an unlikely scenario, you
might need even finer control. Suppose you want to disable the tracing on just
the ``a_fn`` function, but want to continue the tracing back in ``aa_fn`` and
``ab_fn``. The image below demonstrates this use case:

.. figure:: _static/img/fine_grained_apis/call_stack_diagram.png
   :alt: diagram of torch.compile + disable(a_fn, recursive=False)

In this case, you can use ``torch._dynamo.disable(recursive=False)``.
In previous versions, this functionality was provided by ``torch._dynamo.skip``.
This is now supported by the ``recursive`` flag inside ``torch._dynamo.disable``.
693