xref: /aosp_15_r20/external/pytorch/torch/_dynamo/eval_frame.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# mypy: disable-error-code="method-assign"
3
4"""
5Functions in this file are responsible for modifying the eval frame
6handler at RUNTIME.  Therefore, all functions in this file are hot.
7Functions that only execute at compile time should be placed
8in torch._dynamo.convert_frame.
9"""
10
11from __future__ import annotations
12
13import contextlib
14import functools
15import inspect
16import logging
17import os
18import sys
19import textwrap
20import traceback
21import types
22import warnings
23import weakref
24from enum import Enum
25from os.path import dirname, join
26from typing import (
27    Any,
28    Callable,
29    Dict,
30    List,
31    NamedTuple,
32    Optional,
33    Set,
34    Tuple,
35    TYPE_CHECKING,
36    Union,
37)
38from unittest.mock import patch
39
40import sympy
41
42import torch
43import torch.fx
44import torch.utils._pytree as pytree
45import torch.utils.checkpoint
46from torch import _guards
47
48# see discussion at https://github.com/pytorch/pytorch/issues/120699
49from torch._C._dynamo.eval_frame import (  # noqa: F401
50    reset_code,
51    set_guard_error_hook,
52    skip_code,
53    unsupported,
54)
55from torch._dispatch.python import enable_python_dispatcher
56from torch._subclasses.fake_tensor import unset_fake_temporarily
57from torch._utils_internal import justknobs_check, log_export_usage
58from torch.export.dynamic_shapes import _combine_args, _process_dynamic_shapes
59from torch.fx import GraphModule
60from torch.fx.experimental.proxy_tensor import make_fx
61from torch.fx.experimental.symbolic_shapes import (
62    ConstraintViolationError,
63    DimDynamic,
64    ShapeEnv,
65    StatelessSymbolicContext,
66)
67from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
68
69from . import config, convert_frame, external_utils, trace_rules, utils
70from .backends.registry import CompilerFn, lookup_backend
71from .code_context import code_context
72from .exc import CondOpArgsMismatchError, UserError, UserErrorType
73from .hooks import Hooks
74from .mutation_guard import install_generation_tagging_init
75from .utils import common_constant_types, compile_times
76
77
78if TYPE_CHECKING:
79    from torch._subclasses import fake_tensor
80
81    from .types import CacheEntry, DynamoCallback
82
83
84log = logging.getLogger(__name__)
85
86
87always_optimize_code_objects = utils.ExactWeakKeyDictionary()
88null_context = contextlib.nullcontext
89
90
91# See https://github.com/python/typing/pull/240
92class Unset(Enum):
93    token = 0
94
95
96cached_backends: Dict[int, CompilerFn] = {}
97
98unset = Unset.token
99
100
101def _maybe_set_eval_frame(callback: DynamoCallback):
102    # A wrapper around set_eval_frame that is guarded by a JustKnob.
103    # Users can disable TorchDynamo by setting the JK to False.
104    from torch._C._dynamo.eval_frame import set_eval_frame
105
106    if not justknobs_check("pytorch/compiler:enable_compiler_set_eval_frame"):
107        torch._dynamo.utils.warn_once(
108            "Dynamo disabled by Justknob: enable_compiler_set_eval_frame, skipping set_eval_frame"
109        )
110        return callback
111    else:
112        return set_eval_frame(callback)
113
114
115def _reset_guarded_backend_cache():
116    global cached_backends
117    for backend in cached_backends.values():
118        if hasattr(backend, "reset"):
119            backend.reset()
120    cached_backends.clear()
121
122
123DONT_WRAP_FILES = {
124    # For tracing into fx modules
125    inspect.getsourcefile(GraphModule),
126    join(dirname(dirname(__file__)), "onnx/_internal/fx/dynamo_graph_extractor.py"),
127}
128
129
130def _debug_get_cache_entry_list(
131    code: Union[types.CodeType, Callable[..., Any]]
132) -> List[CacheEntry]:
133    """
134    Given a code object or a callable object, retrieve the cache entries
135    stored in this code object.
136    """
137    if callable(code):
138        code = code.__code__
139    return torch._C._dynamo.eval_frame._debug_get_cache_entry_list(code)
140
141
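# Illustrative sketch, not part of the original file: how a caller (in their own
# code, outside torch._dynamo) might use _debug_get_cache_entry_list to inspect
# the cache entries attached to a function's code object. The helper name below
# is hypothetical and the function is defined but never called here.
def _example_inspect_cache_entries():  # hypothetical, illustrative only
    def add_one(x):
        return x + 1

    compiled = torch.compile(add_one, backend="eager")
    compiled(torch.randn(3))  # trigger a compilation so a cache entry may exist
    entries = _debug_get_cache_entry_list(add_one)  # also accepts add_one.__code__
    print(f"{len(entries)} cache entries recorded for add_one")
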
142class OptimizedModule(torch.nn.Module):
143    """
144    Wraps the original nn.Module object and later patches its
145    forward method with the optimized self.forward method.
146    """
147
148    _torchdynamo_orig_callable: Callable[..., Any]
149    get_compiler_config: Callable[[], Any]
150
151    _opt_mod_attributes = {
152        "_orig_mod",
153        "dynamo_ctx",
154        "_torchdynamo_orig_callable",
155        "get_compiler_config",
156        "forward",
157        "_forward",
158        "__dict__",
159        "named_children_walk",
160    }
161
162    def __init__(self, mod: torch.nn.Module, dynamo_ctx) -> None:
163        super().__init__()
164        # Installs the params/buffers
165        self._orig_mod = mod
166        self.dynamo_ctx = dynamo_ctx
167        self._initialize()
168        self.training = self._orig_mod.training
169
170    def _initialize(self):
171        # Do this in the constructor to slightly lower overhead
172        if isinstance(self.dynamo_ctx, DisableContext):
173            # No need to check trace rules
174            self.forward = self.dynamo_ctx(self._orig_mod.__call__)
175        elif isinstance(self._orig_mod.forward, types.MethodType) and (
176            trace_rules.check(self._orig_mod.forward)
177            or getattr(self._orig_mod, "_is_fsdp_managed_module", False)
178        ):
179            # This may be a torch.nn.* instance listed in trace_rules.py which
180            # won't trigger a frame evaluation on its own, so as a workaround we
181            # wrap it to add an extra frame that we can capture
182            self.forward = self.dynamo_ctx(external_utils.wrap_inline(self._orig_mod))
183        else:
184            # Invoke hooks outside of dynamo, then pick up the inner frame
185            self.forward = self.dynamo_ctx(self._orig_mod.__call__)
186
187        if hasattr(self._orig_mod, "_initialize_hook"):
188            self._forward = self.forward
189            self.forward = self._call_lazy_check
190
191    def __reduce__(self):
192        return (self.__class__, (self._orig_mod, self.dynamo_ctx))
193
194    def __getstate__(self):
195        state = dict(self.__dict__)
196        state.pop("forward", None)
197        state.pop("__call__", None)
198        return state
199
200    def __setstate__(self, state):
201        self.__dict__ = state
202        self._initialize()
203
204    @property
205    def training(self):
206        return self._orig_mod.training
207
208    @training.setter
209    def training(self, value):
210        try:
211            super().__getattr__("_orig_mod")
212            self._orig_mod.training = value
213        except AttributeError:
214            # still initializing
215            pass
216
217    def __getattr__(self, name):
218        if name == "_orig_mod":
219            return self._modules["_orig_mod"]
220        return getattr(self._orig_mod, name)
221
222    def __setattr__(self, name, val) -> None:
223        # Allow patching over class attributes
224        if hasattr(type(self), name):
225            return super().__setattr__(name, val)
226
227        if name in OptimizedModule._opt_mod_attributes:
228            return super().__setattr__(name, val)
229        return setattr(self._orig_mod, name, val)
230
231    def _call_lazy_check(self, *args, **kwargs):
232        if hasattr(self._orig_mod, "_initialize_hook"):
233            # In the case of a lazy module, we want to run
234            # the pre-hooks which initialize it.
235            # Afterwards, the lazy module deletes its pre-hooks
236            # so it is not treated as lazy on subsequent recompiles.
237            self._orig_mod._infer_parameters(self._orig_mod, args, kwargs)
238        return self._forward(*args, **kwargs)
239
240    def __dir__(self):
241        orig_mod_attrs = self._orig_mod.__dir__()
242        return orig_mod_attrs + [
243            attr for attr in super().__dir__() if attr not in orig_mod_attrs
244        ]
245
246
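# Illustrative sketch, not part of the original file: compiling an nn.Module
# returns an OptimizedModule that keeps the wrapped module reachable via
# `_orig_mod` and forwards other attribute access to it. The helper name is
# hypothetical and the function is defined but never called here.
def _example_optimized_module():  # hypothetical, illustrative only
    mod = torch.nn.Linear(4, 4)
    opt_mod = torch.compile(mod, backend="eager")
    assert isinstance(opt_mod, OptimizedModule)
    assert opt_mod._orig_mod is mod
    # attribute reads/writes fall through to the wrapped module
    assert opt_mod.in_features == 4
    return opt_mod(torch.randn(2, 4))
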
247def remove_from_cache(f):
248    """
249    Make sure f.__code__ is not cached to force a recompile
250    """
251    if isinstance(f, types.CodeType):
252        reset_code(f)
253    elif hasattr(f, "__code__"):
254        reset_code(f.__code__)
255    elif hasattr(getattr(f, "forward", None), "__code__"):
256        reset_code(f.forward.__code__)
257    else:
258        from . import reset  # type: ignore[attr-defined]
259
260        reset()
261        log.warning("could not determine __code__ for %s", f)
262
263
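# Illustrative sketch, not part of the original file: forcing a recompile by
# dropping the cached code attached to a function's __code__ object, as a user
# might do from their own code. Hypothetical helper, defined but never called.
def _example_force_recompile():  # hypothetical, illustrative only
    def fn(x):
        return x * 2

    compiled = torch.compile(fn, backend="eager")
    compiled(torch.randn(2))
    remove_from_cache(fn)  # accepts a code object, a function, or a module
    return compiled(torch.randn(2))  # recompiles on the next call
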
264def nothing():
265    pass
266
267
268def always_false():
269    return False
270
271
272def innermost_fn(fn):
273    """
274    In case of nested _TorchDynamoContext calls, find the innermost
275    function. TorchDynamo caches on the fn.__code__ object, so it's necessary to
276    find the innermost function when applying optimize, run, disable, etc.
277    """
278    unaltered_fn = fn
279    while hasattr(unaltered_fn, "_torchdynamo_orig_callable"):
280        unaltered_fn = unaltered_fn._torchdynamo_orig_callable
281        assert callable(unaltered_fn)
282    return unaltered_fn
283
284
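# Illustrative sketch, not part of the original file: each Dynamo wrapper records
# what it wrapped in `_torchdynamo_orig_callable`, and innermost_fn walks that
# chain back to the innermost callable so nested optimize/run/disable decorators
# key off the same function. Hypothetical helper, defined but never called.
def _example_innermost_fn():  # hypothetical, illustrative only
    def user_fn(x):
        return x.sin()

    wrapped = torch._dynamo.optimize("eager")(user_fn)
    assert hasattr(wrapped, "_torchdynamo_orig_callable")
    return innermost_fn(wrapped)  # follows the chain of wrapped callables
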
285def make_set_enable_dynamic(enable: bool):
286    assert isinstance(enable, bool)
287    if enable:
288        # Assume everything is dynamic by default
289        return config._make_closure_patcher(assume_static_by_default=False)
290    else:
291        return config._make_closure_patcher(
292            automatic_dynamic_shapes=False, assume_static_by_default=True
293        )
294
295
296class _TorchDynamoContext:
297    def __init__(
298        self,
299        callback: DynamoCallback,
300        on_enter=nothing,
301        backend_ctx_ctor=null_context,
302        patch_fn=nothing,
303        first_ctx=False,
304        *,
305        export=False,
306        dynamic=None,
307        compiler_config=None,
308    ) -> None:
309        super().__init__()
310        assert callable(callback) or callback is False or callback is None
311        self.callback: DynamoCallback = callback
312        self._backend_ctx_ctor = backend_ctx_ctor
313        self.prior: Union[Unset, DynamoCallback] = unset
314        self.first_ctx = first_ctx
315        self.export = export
316        self._dynamic = dynamic
317        self.compiler_config = compiler_config
318        self.cleanup_fns: List[Callable[[], Any]] = []
319        self.enter_exit_hooks = []
320        patch_fn()
321
322        # Save the backends so that we can reset them during torch._dynamo.reset
323        backend = innermost_fn(callback)
324        cached_backends.setdefault(id(backend), backend)
325
326        if dynamic is not None:
327            self.enter_exit_hooks.append(make_set_enable_dynamic(dynamic))
328
329        if on_enter is not nothing:
330            # this case is not common
331            def call_on_enter():
332                on_enter()
333                return nothing
334
335            self.enter_exit_hooks.append(call_on_enter)
336
337        if backend_ctx_ctor is not contextlib.nullcontext:
338            # this case is not common
339            def call_backend_ctx():
340                ctx = backend_ctx_ctor()
341                ctx.__enter__()
342                return functools.partial(ctx.__exit__, None, None, None)
343
344            self.enter_exit_hooks.append(call_backend_ctx)
345
346    def __enter__(self):
347        if config.raise_on_ctx_manager_usage:
348            raise RuntimeError(
349                "torch._dynamo.optimize(...) is used with a context manager. "
350                "Please refer to https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html "
351                "to use torch._dynamo.optimize(...) as an annotation/decorator. "
352            )
353        self.cleanup_fns = [enter() for enter in self.enter_exit_hooks]
354        self.prior = _maybe_set_eval_frame(self.callback)
355
356    def __exit__(self, exc_type, exc_val, exc_tb):
357        assert self.prior is not unset
358        _maybe_set_eval_frame(self.prior)
359        self.prior = unset
360        for cleanup in self.cleanup_fns:
361            cleanup()
362        self.cleanup_fns.clear()
363
364    def __call__(self, fn):
365        # public api for compiler config/options
366        def get_compiler_config():
367            return self.compiler_config
368
369        fn = innermost_fn(fn)
370
372        if isinstance(fn, GraphModule):
373            # add context containing GraphModule to any GraphModule forward functions
374            code_context.get_context(fn.forward.__code__)[
375                "orig_graphmodule"
376            ] = weakref.ref(fn)
377
378        # Optimize the forward method of torch.nn.Module object
379        if isinstance(fn, torch.nn.Module):
380            mod = fn
381            new_mod = OptimizedModule(mod, self)
382            # Save the function pointer so the original callable can be found
383            # when decorators are nested.
384            new_mod._torchdynamo_orig_callable = mod.forward
385
386            # when compiling a torch.nn.Module,
387            # provide the public API OptimizedModule.get_compiler_config()
388            assert not hasattr(new_mod, "get_compiler_config")
389            new_mod.get_compiler_config = get_compiler_config
390
391            return new_mod
392
393        if inspect.isclass(fn):
394            # User has wrapped the class with compile/disable decorator. Apply
395            # disable to init/call method.
396            cls_obj = fn
397            cls_obj.__call__ = self(cls_obj.__call__)
398            if issubclass(cls_obj, torch.nn.Module):
399                # NN module variable tracker directly inlines the _call_impl.
400                cls_obj._call_impl = self(cls_obj._call_impl)
401            return cls_obj
402
403        assert callable(fn)
404
405        try:
406            filename = inspect.getsourcefile(fn)
407        except TypeError:
408            filename = None
409        if (
410            (filename is None or trace_rules.check(fn))
411            and (
412                getattr(fn, "__name__", "")
413                not in ["_call_impl", "_wrapped_call_impl", "_lazy_forward"]
414            )
415            and filename not in DONT_WRAP_FILES
416        ):
417            # call to a builtin without a frame for us to capture
418            fn = external_utils.wrap_inline(fn)
419
420        def do_nothing(*arg, **kwargs):
421            pass
422
423        if hasattr(self, "callback"):
424            callback = self.callback
425        else:
426            callback = do_nothing
427
428        is_jit_tracing = torch._C._is_tracing
429        is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing
430
431        @functools.wraps(fn)
432        def _fn(*args, **kwargs):
433            if is_fx_tracing():
434                if config.error_on_nested_fx_trace:
435                    raise RuntimeError(
436                        "Detected that you are using FX to symbolically trace "
437                        "a dynamo-optimized function. This is not supported at the moment."
438                    )
439                else:
440                    return fn(*args, **kwargs)
441
442            if is_jit_tracing():
443                if config.error_on_nested_jit_trace:
444                    raise RuntimeError(
445                        "Detected that you are using FX to torch.jit.trace "
446                        "a dynamo-optimized function. This is not supported at the moment."
447                    )
448                else:
449                    return fn(*args, **kwargs)
450
451            cleanups = [enter() for enter in self.enter_exit_hooks]
452            prior = _maybe_set_eval_frame(callback)
453
454            # Ensure that if an assertion occurs after the graph pushes
455            # something onto the DynamicLayerStack then we pop it off (the
456            # constructed graph code isn't guarded with try/finally).
457            #
458            # This used to be a context manager, but putting a `with` here is a
459            # noticeable perf regression (#126293)
460            saved_dynamic_layer_stack_depth = (
461                torch._C._functorch.get_dynamic_layer_stack_depth()
462            )
463
464            try:
465                return fn(*args, **kwargs)
466            finally:
467                # Restore the dynamic layer stack depth if necessary.
468                torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
469                    saved_dynamic_layer_stack_depth
470                )
471
472                _maybe_set_eval_frame(prior)
473                for cleanup in cleanups:
474                    cleanup()
475
476        # hooks to properly handle inlining
477        _fn._torchdynamo_inline = fn  # type: ignore[attr-defined]
478
479        # Save the function pointer so the original callable can be found
480        # when decorators are nested.
481        _fn._torchdynamo_orig_callable = fn  # type: ignore[attr-defined]
482
483        # when compiling a user function instead of an nn.Module,
484        # provide the public API _fn.get_compiler_config()
485        assert not hasattr(_fn, "get_compiler_config")
486        _fn.get_compiler_config = get_compiler_config  # type: ignore[attr-defined]
487
488        # If the function is called using torch._dynamo.optimize decorator, we
489        # should prevent any type of skipping.
490        if callback not in (None, False):
491            if not hasattr(fn, "__code__"):
492                raise RuntimeError(
493                    textwrap.dedent(
494                        """
495
496                        torch._dynamo.optimize is called on a non-function object.
497                        If this is a callable class, please wrap the relevant code into a function and optimize the
498                        wrapper function.
499
500                        >> class CallableClass:
501                        >>     def __init__(self) -> None:
502                        >>         super().__init__()
503                        >>         self.relu = torch.nn.ReLU()
504                        >>
505                        >>     def __call__(self, x):
506                        >>         return self.relu(torch.sin(x))
507                        >>
508                        >>     def print_hello(self):
509                        >>         print("Hello world")
510                        >>
511                        >> mod = CallableClass()
512
513                        If you want to optimize the __call__ function and other code, wrap that up in a function
514
515                        >> def wrapper_fn(x):
516                        >>     y = mod(x)
517                        >>     return y.sum()
518
519                        and then optimize the wrapper_fn
520
521                        >> opt_wrapper_fn = torch._dynamo.optimize()(wrapper_fn)
522                        """
523                    )
524                )
525            always_optimize_code_objects[fn.__code__] = True
526
527        return _fn
528
529
530class OptimizeContext(_TorchDynamoContext):
531    def __init__(
532        self,
533        callback,
534        backend_ctx_ctor,
535        first_ctx=False,
536        *,
537        export=False,
538        dynamic=None,
539        compiler_config=None,
540        rebuild_ctx: Optional[
541            Callable[[], Union[OptimizeContext, _NullDecorator]]
542        ] = None,
543    ) -> None:
544        def on_enter():
545            install_generation_tagging_init()
546
547        super().__init__(
548            callback=callback,
549            on_enter=on_enter,
550            backend_ctx_ctor=backend_ctx_ctor,
551            patch_fn=TorchPatcher.patch,
552            first_ctx=first_ctx,
553            export=export,
554            dynamic=dynamic,
555            compiler_config=compiler_config,
556        )
557
558        if config.compiled_autograd:
559
560            def call_compiled_autograd():
561                assert rebuild_ctx is not None
562                compiler_fn = rebuild_ctx()
563                ctx = torch._dynamo.compiled_autograd.enable(compiler_fn)
564                ctx.__enter__()
565                return functools.partial(ctx.__exit__, None, None, None)
566
567            self.enter_exit_hooks.append(call_compiled_autograd)
568
569    def __reduce__(self):
570        return (
571            self.__class__,
572            (self.callback, self._backend_ctx_ctor, self.first_ctx),
573            {
574                "export": self.export,
575                "dynamic": self._dynamic,
576                "compiler_config": self.compiler_config,
577            },
578        )
579
580
581class RunOnlyContext(_TorchDynamoContext):
582    def __init__(self) -> None:
583        # cudagraph trees rely on the generation increment
584        def on_enter():
585            torch._dynamo.mutation_guard.GenerationTracker.generation += 1
586
587        super().__init__(callback=False, on_enter=on_enter)
588
589    def __reduce__(self):
590        return (self.__class__, ())
591
592
593class DisableContext(_TorchDynamoContext):
594    def __init__(self) -> None:
595        super().__init__(callback=None)
596
597    def __call__(self, fn):
598        # Earlier this code was in the base class _TorchDynamoContext. But we
599        # moved it here to have better code organization. For disable, we just
600        # want the callback to be None. We don't have to check trace_rules or
601        # create any wrapper.
602        fn = innermost_fn(fn)
603
604        if isinstance(fn, torch.nn.Module):
605            mod = fn
606            new_mod = OptimizedModule(mod, self)
607            new_mod._torchdynamo_orig_callable = mod.forward
608            return new_mod
609
610        if inspect.isclass(fn):
611            # User has wrapped the class with compile/disable decorator. Apply
612            # disable to init/call method.
613            cls_obj = fn
614            # Disable on init is useful for reconstruction of bytecodes where we
615            # want to prevent Dynamo from tracing into the init function. Check
616            # test_reconstruction in test_model_output.py.
617            cls_obj.__init__ = self(cls_obj.__init__)
618            cls_obj.__call__ = self(cls_obj.__call__)
619            if issubclass(cls_obj, torch.nn.Module):
620                # NN module variable tracker directly inlines the _call_impl. Disable it.
621                cls_obj._call_impl = self(cls_obj._call_impl)
622            return cls_obj
623
624        assert callable(fn)
625
626        callback = self.callback
627
628        @functools.wraps(fn)
629        def _fn(*args, **kwargs):
630            prior = _maybe_set_eval_frame(callback)
631            try:
632                return fn(*args, **kwargs)
633            finally:
634                _maybe_set_eval_frame(prior)
635
636        _fn._torchdynamo_disable = True  # type: ignore[attr-defined]
637
638        # Save the function pointer so the original callable can be found
639        # when decorators are nested.
640        _fn._torchdynamo_orig_callable = fn  # type: ignore[attr-defined]
641
642        return _fn
643
644    def __reduce__(self):
645        return (self.__class__, ())
646
647
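# Illustrative sketch, not part of the original file: torch._dynamo.disable,
# which is built on DisableContext, marks a callable so compiled callers fall
# back to eager (with a graph break) around it. Hypothetical helper, defined but
# never called, written as it would appear in user code.
def _example_disable():  # hypothetical, illustrative only
    @torch._dynamo.disable
    def run_in_eager(x):
        return x + 1  # Dynamo will not trace into this body

    @torch.compile(backend="eager")
    def outer(x):
        return run_in_eager(x) * 2  # graph break around the disabled call

    return outer(torch.randn(3))
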
648def _optimize_catch_errors(
649    compile_fn,
650    hooks: Hooks,
651    backend_ctx_ctor=null_context,
652    export=False,
653    dynamic=None,
654    compiler_config=None,
655    rebuild_ctx=None,
656):
657    return OptimizeContext(
658        convert_frame.catch_errors_wrapper(compile_fn, hooks),
659        backend_ctx_ctor=backend_ctx_ctor,
660        first_ctx=True,
661        export=export,
662        dynamic=dynamic,
663        compiler_config=compiler_config,
664        rebuild_ctx=rebuild_ctx,
665    )
666
667
668def get_compiler_fn(compiler_fn):
669    from .repro.after_dynamo import wrap_backend_debug
670
671    if hasattr(compiler_fn, "compiler_name"):
672        compiler_str = compiler_fn.compiler_name
673    elif isinstance(compiler_fn, str):
674        compiler_str = compiler_fn
675    else:
676        compiler_str = None
677    compiler_fn = lookup_backend(compiler_fn)
678    return wrap_backend_debug(compiler_fn, compiler_str)
679
680
681class _NullDecorator(contextlib.nullcontext):  # type: ignore[type-arg]
682    def __call__(self, fn):
683        assert callable(fn)
684        return fn
685
686
687def check_if_dynamo_supported():
688    if sys.version_info >= (3, 13):
689        raise RuntimeError("Python 3.13+ not yet supported for torch.compile")
690
691
692def is_dynamo_supported():
693    try:
694        check_if_dynamo_supported()
695        return True
696    except Exception:
697        return False
698
699
700def check_if_inductor_supported():
701    check_if_dynamo_supported()
702
703
704def is_inductor_supported():
705    try:
706        check_if_inductor_supported()
707        return True
708    except Exception:
709        return False
710
711
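# Illustrative sketch, not part of the original file: the is_*_supported helpers
# let callers opt into compilation only where it is available, e.g. falling back
# to eager on unsupported Python versions. Hypothetical helper, never called.
def _example_supported_guard():  # hypothetical, illustrative only
    def fn(x):
        return x + 1

    return torch.compile(fn) if is_dynamo_supported() else fn
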
712def optimize(*args, **kwargs):
713    def rebuild_ctx():
714        return optimize(*args, **kwargs)
715
716    return _optimize(rebuild_ctx, *args, **kwargs)
717
718
719def _optimize(
720    rebuild_ctx: Callable[[], Union[OptimizeContext, _NullDecorator]],
721    backend="inductor",
722    *,
723    nopython=False,
724    guard_export_fn=None,
725    guard_fail_fn=None,
726    disable=False,
727    dynamic=None,
728) -> Union[OptimizeContext, _NullDecorator]:
729    """
730    The main entrypoint of TorchDynamo.  Do graph capture and call
731    backend() to optimize extracted graphs.
732
733    Args:
734        backend: One of two things:
735            - Either a function/callable taking a torch.fx.GraphModule and
736            example_inputs and returning a Python callable that runs the
737            graph faster.
738            One can also provide additional context for the backend, like
739            torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute.
740            See AOTAutogradMemoryEfficientFusionWithContext for the usage.
741            - Or a string backend name in `torch._dynamo.list_backends()`
742        nopython: If True, graph breaks will be errors and there will
743            be a single whole-program graph.
744        disable: If True, turn this decorator into a no-op
745        dynamic: If True, upfront compile as dynamic a kernel as possible.  If False,
746            disable all dynamic shapes support (always specialize).  If None, automatically
747            detect when sizes vary and generate dynamic kernels upon recompile.
748
749    Example Usage::
750
751        @torch._dynamo.optimize()
752        def toy_example(a, b):
753            ...
754    """
755    check_if_dynamo_supported()
756    # Note: The hooks object could be global instead of passed around; however, that would make
757    # for confusing API usage and a messy plumbing story when multiple .optimize calls are nested.
758    # There is some prior art around this: for nested calls, backends are enforced to be the same
759    # compiler. That feels onerous for callbacks and hooks, though, so we give our users an
760    # easier-to-understand UX at the cost of a little more plumbing on our end.
761    hooks = Hooks(guard_export_fn=guard_export_fn, guard_fail_fn=guard_fail_fn)
762    torch._C._log_api_usage_once("torch._dynamo.optimize")
763    if (
764        disable
765        or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1"
766        or (not justknobs_check("pytorch/compiler:enable_dynamo"))
767    ):
768        return _NullDecorator()
769
770    backend = get_compiler_fn(backend)
771
772    # Find if backend has any extra context manager
773    backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)
774
775    if nopython:
776        return optimize_assert(
777            backend,
778            dynamic=dynamic,
779            hooks=hooks,
780            rebuild_ctx=rebuild_ctx,
781        )
782    # The backend function is stashed in the callable returned by
783    # _optimize_catch_errors in the field _torchdynamo_orig_callable. This can
784    # be used by eval_frame.c to insert a guard on the backend.
785    return _optimize_catch_errors(
786        convert_frame.convert_frame(backend, hooks=hooks),
787        hooks,
788        backend_ctx_ctor,
789        dynamic=dynamic,
790        compiler_config=backend.get_compiler_config()
791        if hasattr(backend, "get_compiler_config")
792        else None,
793        rebuild_ctx=rebuild_ctx,
794    )
795
796
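# Illustrative sketch, not part of the original file: the backend contract
# described in the _optimize docstring above -- a callable taking an
# fx.GraphModule plus example inputs and returning a callable that runs the
# graph. Hypothetical helper, defined but never called, written as it would
# appear in user code.
def _example_custom_backend():  # hypothetical, illustrative only
    def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
        gm.graph.print_tabular()  # look at the captured graph
        return gm.forward  # run the captured graph unchanged

    @torch._dynamo.optimize(inspect_backend)
    def toy_example(a, b):
        return (a + b).relu()

    return toy_example(torch.randn(4), torch.randn(4))
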
797# TODO(voz): Consider making "explain" output alongside a run / part of a run
798@patch("torch._dynamo.symbolic_convert.explain", True)
799def explain(f, *extra_args, **extra_kwargs):
800    def inner(*args, **kwargs):
801        # TODO(voz): Do we want a decorator for this?
802        from . import reset  # type: ignore[attr-defined]
803
804        reset()
805
806        graphs: List[torch.fx.GraphModule] = []
807        break_reasons: List[Any] = []
808        op_count: int = 0
809        ops_per_graph: List[torch.fx.Node] = []
810        out_guards: List[_guards.Guard] = []
811
812        def dynamo_graph_accumulating_compiler(
813            gm: torch.fx.GraphModule, example_inputs
814        ):
815            from .backends.debugging import _explain_graph_detail
816
817            nonlocal graphs
818            nonlocal op_count
819            nonlocal ops_per_graph
820            nonlocal break_reasons
821
822            gm, graphs, op_count, ops_per_graph, break_reasons = _explain_graph_detail(
823                gm, graphs, op_count, ops_per_graph, break_reasons
824            )
825
826            return gm.forward
827
828        def guard_export_print(guards):
829            nonlocal out_guards
830            out_guards.extend(guards)
831
832        opt_f = optimize(
833            dynamo_graph_accumulating_compiler,
834            nopython=False,
835            guard_export_fn=guard_export_print,
836        )(f)
837        # TODO(voz): We may have instances of `f` that mutate inputs; we should track side effects and reject.
838        opt_f(*args, **kwargs)
839
840        graph_count = len(graphs)
841        graph_break_count = graph_count - 1
842        compile_time = compile_times(repr="str")
843
844        # TODO(voz): Do we want a decorator for this?
845        reset()
846        from .backends.debugging import ExplainOutput
847
848        return ExplainOutput(
849            graphs,
850            graph_count,
851            graph_break_count,
852            break_reasons,
853            op_count,
854            ops_per_graph,
855            out_guards,
856            compile_time,
857        )
858
859    if extra_args or extra_kwargs:
860        warnings.warn(
861            "explain(f, *args, **kwargs) is deprecated, use explain(f)(*args, **kwargs) instead.  "
862            "If you don't migrate, we may break your explain call in the future if your user defined kwargs "
863            "conflict with future kwargs added to explain(f).",
864            FutureWarning,
865            stacklevel=2,
866        )
867        return inner(*extra_args, **extra_kwargs)
868    else:
869        return inner
870
871
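# Illustrative sketch, not part of the original file: explain(f) returns a
# callable; invoking it with example inputs runs Dynamo with an accumulating
# backend and returns an ExplainOutput summary. Hypothetical helper, defined but
# never called, written as it would appear in user code.
def _example_explain():  # hypothetical, illustrative only
    def fn(x):
        x = x.sin()
        print("this print forces a graph break")
        return x.cos()

    result = explain(fn)(torch.randn(8))
    print(result.graph_count, result.graph_break_count)
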
872class FlattenInputOutputSignature(torch.fx.interpreter.Transformer):
873    def __init__(
874        self,
875        m: torch.fx.GraphModule,
876        flat_args: Tuple[Any],
877        matched_input_elements_positions: List[int],
878        flat_results: List[Any],
879        matched_output_elements_positions: List[int],
880        example_fake_inputs: List[torch.Tensor],
881        flat_args_dynamic_dims: List[Set[int]],
882        fake_mode: Optional[fake_tensor.FakeTensorMode] = None,
883    ) -> None:
884        super().__init__(m)
885
886        assert len(flat_args_dynamic_dims) == len(flat_args)
887        matched_input_elements_to_fake = {
888            val: example_fake_inputs[ix]
889            for ix, val in enumerate(matched_input_elements_positions)
890        }
891
892        self.new_args = []
893        for i in range(0, len(flat_args)):
894            arg = super().placeholder(f"arg{i}", (), {})
895            if i in matched_input_elements_to_fake:
896                arg.node.meta["val"] = matched_input_elements_to_fake[i]
897            else:
898                # Fill node.meta["val"] with a fake tensor from the input,
899                # if it's not found in matched_input_elements_positions
900                if fake_mode is not None and isinstance(flat_args[i], torch.Tensor):
901                    # TODO(zhxchen17) Also preserve all the user constraints here.
902                    arg.node.meta["val"] = fake_mode.from_tensor(
903                        flat_args[i],
904                        symbolic_context=StatelessSymbolicContext(
905                            dynamic_sizes=[
906                                DimDynamic.DYNAMIC
907                                if d in flat_args_dynamic_dims[i]
908                                else DimDynamic.STATIC
909                                for d in range(len(flat_args[i].shape))
910                            ],
911                            constraint_sizes=[None] * len(flat_args[i].shape),
912                        ),
913                    )
914            self.new_args.append(arg)
915        self.old_args_gen = (self.new_args[i] for i in matched_input_elements_positions)
916        self.matched_output_elements_positions = matched_output_elements_positions
917        self.flat_results = flat_results
918
919    def placeholder(self, target, args, kwargs):
920        arg = next(self.old_args_gen)
921        if "val" in self.current_node.meta:
922            arg.node.meta["val"] = self.current_node.meta["val"]
923        if "tensor_dict" in self.current_node.meta:
924            arg.node.meta["tensor_dict"] = self.current_node.meta["tensor_dict"]
925        if "example_value" in self.current_node.meta:
926            # NB: intentionally do not use set_example_value
927            arg.node.meta["example_value"] = self.current_node.meta["example_value"]
928        if "unbacked_bindings" in self.current_node.meta:
929            arg.node.meta["unbacked_bindings"] = self.current_node.meta[
930                "unbacked_bindings"
931            ]
932        return arg
933
934    def output(self, target, args, kwargs):
935        dynamo_result_flat = args[0]
936        lookup = [*dynamo_result_flat, *self.new_args]
937        new_results_flat = []
938        for i in range(len(self.flat_results)):
939            if self.matched_output_elements_positions[i] is not None:
940                new_results_flat.append(
941                    lookup[self.matched_output_elements_positions[i]]
942                )
943            else:
944                const_val = self.flat_results[i]
945                assert isinstance(const_val, tuple(common_constant_types))
946                new_results_flat.append(const_val)
947        return super().output(target, (new_results_flat,), {})
948
949    def run_node(self, n):
950        self.current_node = n
951        result_proxy = super().run_node(n)
952        if "val" in self.current_node.meta:
953            result_proxy.node.meta["val"] = self.current_node.meta["val"]
954        if "example_value" in self.current_node.meta:
955            # NB: intentionally do not use set_example_value
956            result_proxy.node.meta["example_value"] = self.current_node.meta[
957                "example_value"
958            ]
959        if "unbacked_bindings" in self.current_node.meta:
960            result_proxy.node.meta["unbacked_bindings"] = self.current_node.meta[
961                "unbacked_bindings"
962            ]
963        if self.current_node.op != "output":
964            result_proxy.node._rename(
965                getattr(self.current_node, "name", result_proxy.node.name)
966            )
967        return result_proxy
968
969    def transform(self):
970        result_gm = super().transform()
971        if "dynamo_flat_name_to_original_fqn" in self.module.meta:
972            result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[
973                "dynamo_flat_name_to_original_fqn"
974            ]
975        return result_gm
976
977
978class ExportResult(NamedTuple):
979    graph_module: torch.fx.GraphModule
980    guards: _guards.GuardsSet
981    # NB: Do not add new fields without overriding __iter__; people are
982    # destructuring so it is BC-breaking
983
984
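# Illustrative sketch, not part of the original file: callers commonly
# destructure ExportResult, which is why the field order above is load-bearing.
# Hypothetical helper, defined but never called, written as it would appear in
# user code.
def _example_export_result():  # hypothetical, illustrative only
    def fn(x):
        return x + 1

    graph_module, guards = export(fn)(torch.randn(3))
    graph_module.print_readable()
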
985def check_signature_rewritable(graph):
986    input_errors = []
987    for node in graph.graph.find_nodes(op="placeholder"):
988        assert hasattr(node, "_dynamo_source")
989        source = node._dynamo_source
990        user_stacks = graph._source_to_user_stacks.get(source)
991        if user_stacks is None:
992            continue
993        assert len(user_stacks) > 0
994        # In some cases we may not have a useful stack.  Look for a
995        # useful stack
996        stack = None
997        for s in user_stacks:
998            if len(s) == 0:
999                continue
1000            stack = s
1001            break
1002        if stack is None:
1003            msg = f"{source.name()}, a closed over free variable"
1004        else:
1005            tb = "".join(traceback.format_list(stack))
1006            extra = ""
1007            if len(user_stacks) > 1:
1008                extra = f"(elided {len(user_stacks) - 1} more accesses)"
1009            msg = f"{source.name()}, accessed at:\n{tb}{extra}"
1010        # TODO: option to print ALL of the stack traces at once
1011        input_errors.append(msg)
1012
1013    if input_errors:
1014        raise UserError(
1015            UserErrorType.INVALID_INPUT,
1016            "Cannot export model which references tensors that are neither "
1017            "buffers/parameters/constants nor are direct inputs.  For each tensor, if you'd "
1018            "like this tensor to be an explicit input, add it as a dummy argument "
1019            "to the top-level model definition you are exporting; if you would "
1020            "like its value to be embedded as an exported constant, wrap its access "
1021            "in a function marked with @assume_constant_result.\n\n"
1022            + "\n\n".join(input_errors),
1023        )
1024
1025
1026def rewrite_signature(
1027    f_sig,
1028    graph,
1029    fake_mode,
1030    flat_args,
1031    in_spec,
1032    example_fake_inputs,
1033    graph_captured_input,
1034    graph_captured_output,
1035    dynamo_traced_result,
1036    flat_args_dynamic_dims,
1037):
1038    orig_args, orig_kwargs = pytree.tree_unflatten(flat_args, in_spec)
1039
1040    def check_user_input_output(flat_values, error_type):
1041        supported_types = [
1042            torch.Tensor,
1043            torch.SymInt,
1044            torch.SymFloat,
1045            torch.SymBool,
1046            torch._C.ScriptObject,
1047        ] + list(common_constant_types)
1048
1049        def is_supported_type(val):
1050            return isinstance(val, tuple(supported_types))
1051
1052        value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output"
1053        # We only check that the outputs are not None. Inputs can be None.
1054        for v in flat_values:
1055            if not is_supported_type(v):
1056                if error_type == UserErrorType.INVALID_INPUT and v is None:
1057                    continue
1058
1059                raise UserError(
1060                    error_type,
1061                    f"It looks like one of the {value_type}s with type `{type(v)}` "
1062                    "is not supported or pytree-flattenable. \n"
1063                    f"Exported graphs {value_type}s can only contain the "
1064                    f"following supported types: {supported_types}. \n"
1065                    "If you are using a custom class object, "
1066                    "please register a pytree_flatten/unflatten function "
1067                    "using `torch.utils._pytree.register_pytree_node` or "
1068                    "`torch.export.register_dataclass`.",
1069                )
1070
1071    check_user_input_output(flat_args, UserErrorType.INVALID_INPUT)
1072    flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result)
1073    check_user_input_output(flat_results_traced, UserErrorType.INVALID_OUTPUT)
1074
1075    def check_optional_input_and_error(f_sig: inspect.Signature):
1076        # Check if function has optional input.
1077        for name, param in f_sig.parameters.items():
1078            if param.default is not inspect.Parameter.empty:
1079                from torch._dynamo.exc import Unsupported
1080
1081                log.error(
1082                    "Parameter %s is optional with a default value of %s",
1083                    name,
1084                    param.default,
1085                )
1086                raise Unsupported(
1087                    "Tracing through optional input is not supported yet",
1088                    case_name="optional_input",
1089                )
1090
1091    def produce_matching(debug_type, sources, candidates):
1092        matched_elements_positions: List[Optional[int]] = []
1093        dict_of_source_vals = {}
1094        for i, val in enumerate(sources):
1095            dict_of_source_vals[id(val)] = i
1096
1097        for i, val in enumerate(candidates):
1098            if isinstance(val, tuple(common_constant_types)):
1099                matched_elements_positions.append(None)
1100            elif id(val) not in dict_of_source_vals:
1101                if debug_type == "inputs":
1102                    check_optional_input_and_error(f_sig)
1103                raise AssertionError(
1104                    f"Unexpectedly found a {type(val)} in the {debug_type}.\n"
1105                    'Please file an issue along with a paste of the logs from TORCH_LOGS="+export"',
1106                )
1107            else:
1108                matched_elements_positions.append(dict_of_source_vals[id(val)])
1109
1110        return matched_elements_positions
1111
1112    matched_input_elements_positions = produce_matching(
1113        "inputs", flat_args, graph_captured_input
1114    )
1115
1116    assert graph_captured_output is not None
1117    matched_output_elements_positions = produce_matching(
1118        "outputs", list(graph_captured_output) + flat_args, flat_results_traced
1119    )
1120
1121    new_graph = FlattenInputOutputSignature(
1122        graph,
1123        flat_args,
1124        matched_input_elements_positions,
1125        flat_results_traced,
1126        matched_output_elements_positions,
1127        example_fake_inputs,
1128        flat_args_dynamic_dims,
1129        fake_mode,
1130    ).transform()
1131
1132    # Make the dynamo graph have the same input/output spec as the user code
1133    def argument_names(f_sig, args, kwargs) -> List[str]:
1134        def signature_to_fullargspec(sig: inspect.Signature):
1135            # Get a list of Parameter objects from the Signature object
1136            params = list(sig.parameters.values())
1137            # Separate positional arguments, keyword-only arguments and varargs/varkw
1138            args = [
1139                p.name
1140                for p in params
1141                if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
1142            ]
1143            kwonlyargs = [
1144                p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY
1145            ]
1146            varargs = next(
1147                (p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL),
1148                None,
1149            )
1150            varkw = next(
1151                (p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD),
1152                None,
1153            )
1154            # Get default values for positional arguments and keyword-only arguments
1155            defaults = tuple(
1156                p.default
1157                for p in params
1158                if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
1159                and p.default is not inspect.Parameter.empty
1160            )
1161            kwonlydefaults = {
1162                p.name: p.default
1163                for p in params
1164                if p.kind == inspect.Parameter.KEYWORD_ONLY
1165                and p.default is not inspect.Parameter.empty
1166            }
1167            # Get annotations for parameters and return value
1168            annotations = {}
1169            if sig.return_annotation:
1170                annotations = {"return": sig.return_annotation}
1171            for parameter in params:
1172                annotations[parameter.name] = parameter.annotation
1173            # Return a FullArgSpec object with the extracted attributes
1174            return inspect.FullArgSpec(
1175                args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations
1176            )
1177
1178        fullargspec = signature_to_fullargspec(f_sig)
1179
1180        # 1. Map `args` 1-to-1 to positional arguments in original signature.
1181        input_strs = fullargspec.args[: len(args)]
1182
1183        if len(args) > len(fullargspec.args):
1184            # 2. If there are more arguments left in `args`, they map to varargs in original
1185            # signature. Assign names as {varargs}_0, {varargs}_1, ...
1186            assert fullargspec.varargs is not None, "More arguments than expected"
1187            input_strs += [
1188                f"{fullargspec.varargs}_{i}"
1189                for i in range(0, len(args) - len(input_strs))
1190            ]
1191        elif len(args) < len(fullargspec.args):
1192            # 3. If there are fewer arguments in `args` than `fullargspec.args`,
1193            # it implies these are arguments either with default values, or provided in
1194            # `kwargs`. The former can be safely ignored because Dynamo.export does not
1195            # export them as part of the function signature. The latter will be handled
1196            # in the next step.
1197            for unprovided_arg in fullargspec.args[
1198                len(args) : -len(fullargspec.defaults or [])
1199            ]:
1200                assert unprovided_arg in kwargs, f"Missing argument {unprovided_arg}"
1201
1202        # 4. Keyword arguments provided in `kwargs`.
1203        input_strs += list(kwargs.keys())
1204
1205        # 5. Keyword-only arguments with default values if not provided are not exported
1206        # as part of the function signature.
1207        for kwonly_arg in fullargspec.kwonlyargs:
1208            kwonlydefaults = fullargspec.kwonlydefaults or {}
1209            assert (
1210                kwonly_arg in kwargs or kwonly_arg in kwonlydefaults
1211            ), f"Missing keyword only argument {kwonly_arg}"
1212
1213        return input_strs
1214
1215    new_graph.graph._codegen = _PyTreeCodeGen(
1216        _PyTreeInfo(
1217            argument_names(f_sig, orig_args, orig_kwargs),
1218            in_spec,
1219            out_spec_traced,
1220        )
1221    )
1222    new_graph.recompile()
1223    return new_graph
1224
1225
1226def export(
1227    f: Callable[..., Any],
1228    *extra_args,
1229    aten_graph: bool = False,
1230    pre_dispatch: bool = False,
1231    decomposition_table: Optional[
1232        Dict[torch._ops.OpOverload, Callable[..., Any]]
1233    ] = None,
1234    tracing_mode: str = "symbolic",
1235    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
1236    assume_static_by_default: bool = False,
1237    same_signature: bool = True,
1238    disable_constraint_solver: bool = False,
1239    prefer_deferred_runtime_asserts_over_guards: bool = False,
1240    allow_complex_guards_as_runtime_asserts: bool = False,
1241    _log_export_usage: bool = True,
1242    **extra_kwargs,
1243) -> Callable[..., ExportResult]:
1244    """
1245    Export an input function f to a format that can be executed outside of PyTorch using the FX graph.
1246
1247    Args:
1248        f (callable): A PyTorch function to be exported.
1249
1250        aten_graph (bool): If True, exports a graph with ATen operators.
1251        If False, exports a graph with Python operators. Default is False.
1252
1253        pre_dispatch (bool): If True, exports a graph with ATen operators,
1254        but before any logic in the PyTorch dispatcher has run.
1255        This can be useful if you want to apply further transformations on a graph before running it
1256        through autograd, autocast, or any other functionalities that are integrated into the dispatcher.
1257        This flag is only valid if aten_graph=True is set.
1258        Default is False.
1259
1260        decomposition_table (dict): A dictionary that maps operators to their decomposition functions.
1261        Required if aten_graph or tracing_mode is specified. Default is None.
1262
1263        tracing_mode (str): If "symbolic", turn on dynamic shapes support. Default is "symbolic".
1264
1265        dynamic_shapes:
1266         An optional argument where the type should either be:
1267         1) a dict from argument names of ``f`` to their dynamic shape specifications,
1268         2) a tuple that specifies dynamic shape specifications for each input in original order.
1269         If you are specifying dynamism on keyword args, you will need to pass them in the order that
1270         is defined in the original function signature.
1271
1272         The dynamic shape of a tensor argument can be specified as either
1273         (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
1274         not required to include static dimension indices in this dict, but when they are,
1275         they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
1276         where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
1277         are denoted by None. Arguments that are dicts or tuples / lists of tensors are
1278         recursively specified by using mappings or sequences of contained specifications.
1279
1280        same_signature (bool): If True, rewrite the returned graph's signature to be the same as f.
1281
1282        disable_constraint_solver (bool): Whether the dim constraint solver must be disabled.
1283
1284    Returns:
1285        A function that given args and kwargs, returns a tuple of (graph, guards)
1286        Graph: An FX graph representing the execution of the input PyTorch function with the provided arguments and options.
1287        Guards: The guards we accumulated during tracing f above
1288
1289    Raises:
1290        AssertionError: If decomposition_table is specified without setting aten_graph=True,
1291        or if graph breaks during tracing in export.
1292
1293        AssertionError: If Dynamo input and output is not consistent with traced input/output.
1294
1295    Note - this headerdoc was authored by ChatGPT, with slight modifications by the author.
1296    """
1297    if _log_export_usage:
1298        log_export_usage(event="export.private_api", flags={"_dynamo"})
1299
1300    # Deal with "local variable referenced before assignment"
1301    _f = f
1302    _assume_static_by_default = assume_static_by_default
1303
1304    def inner(*args, **kwargs):
1305        combined_args = _combine_args(_f, args, kwargs)
1306        constraints = _process_dynamic_shapes(combined_args, dynamic_shapes)
1307        f = _f
1308        assume_static_by_default = _assume_static_by_default
1309        check_if_dynamo_supported()
1310        torch._C._log_api_usage_once("torch._dynamo.export")
1311        if decomposition_table is not None:
1312            assert (
1313                aten_graph
1314            ), "Specifying a decomposition_table or tracing mode is illegal without setting aten_graph=True"
1315        if pre_dispatch:
1316            assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True"
1317        f = innermost_fn(f)
1318        call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f
1319        original_signature = inspect.signature(call_to_inspect)
1320        graph = None
1321        out_guards = None
1322        graph_captured_input = None
1323        graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None
1324        fake_mode = None
1325        result_traced = None
1326
1327        def guard_export_print(guards: _guards.GuardsSet):
1328            nonlocal out_guards
1329            assert (
1330                out_guards is None
1331            ), "whole graph export entails exactly one guard export"
1332            out_guards = guards
1333
1334        example_inputs = []
1335
1336        def dynamo_normalization_capturing_compiler(
1337            gm: torch.fx.GraphModule, inner_example_inputs
1338        ):
1339            nonlocal graph
1340            assert (
1341                graph is None
1342            ), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph."
1343            graph = gm
1344
1345            nonlocal fake_mode, example_inputs
1346            # NB: do NOT pass inner_example_inputs here, we are detecting the
1347            # Dynamo allocated fake mode, which should be DISTINCT from a
1348            # potential outer ambient fake mode which the user provided.
1349            # example_inputs is always the user specified inputs, so they
1350            # would have the wrong fake mode attached to them
1351            fake_mode = _guards.detect_fake_mode()
1352            example_inputs = inner_example_inputs
1353
1354            def result_capturing_wrapper(*graph_inputs):
1355                nonlocal graph_captured_result
1356                nonlocal graph_captured_input
1357
1358                graph_captured_input = graph_inputs
1359                assert graph is not None
1360
1361                named_parameters = dict(graph.named_parameters(remove_duplicate=False))
1362                named_buffers = dict(graph.named_buffers(remove_duplicate=False))
1363
1364                ambient_fake_mode = (
1365                    _guards.detect_fake_mode(graph_inputs)
1366                    if _guards.detect_fake_mode(graph_inputs) is not None
1367                    else fake_mode
1368                )
1369
1370                # We reran fake tensor propagation, but we didn't do
1371                # anything with the resulting unbacked SymInts.  Drop them
1372                # from the pending list.
1373                # NB: this is wrong if graph_captured_result has
1374                # data-dependent output size!
1375                ignore_fresh_unbacked = null_context()
1376                if shape_env := ambient_fake_mode.shape_env:
1377                    ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols()
1378
1379                with (
1380                    ambient_fake_mode
1381                ), enable_python_dispatcher(), ignore_fresh_unbacked:
1382                    params_and_buffers = {
1383                        **named_parameters,
1384                        **named_buffers,
1385                    }
1386                    fake_params_buffers = {}
1387
1388                    for name, value in params_and_buffers.items():
1389                        fake_params_buffers[name] = ambient_fake_mode.from_tensor(
1390                            value, static_shapes=True
1391                        )
1392
1393                    fake_graph_inputs = pytree.tree_map(
1394                        ambient_fake_mode.from_tensor, graph_inputs
1395                    )
1396                    graph_captured_result = torch.func.functional_call(
1397                        graph, fake_params_buffers, fake_graph_inputs
1398                    )
1399
1400                return graph_captured_result
1401
1402            return result_capturing_wrapper
1403
1404        # Note: This is needed by rewrite_signature. We need to put it before
1405        # optimize_assert since the user program may mutate the inputs.
1406        flat_args, in_spec = pytree.tree_flatten((args, kwargs))
1407
1408        remove_from_cache(f)
1409        constraint_violation_error = None
1410        if tracing_mode != "symbolic":
1411            assume_static_by_default = True
1412        with config.patch(
1413            specialize_int=True,
1414            assume_static_by_default=assume_static_by_default,
1415            automatic_dynamic_shapes=False,
1416            capture_dynamic_output_shape_ops=True,
1417            capture_scalar_outputs=True,
1418            prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
1419            allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
1420        ):
1421            opt_f = optimize_assert(
1422                dynamo_normalization_capturing_compiler,
1423                hooks=Hooks(
1424                    guard_export_fn=guard_export_print,
1425                    guard_fail_fn=None,
1426                ),
1427                export=True,
1428                export_constraints=constraints,
1429            )(f)
1430            # TODO(voz): We may have instances of `f` that mutate inputs; we should track side effects and reject.
1431            try:
1432                result_traced = opt_f(*args, **kwargs)
1433            except ConstraintViolationError as e:
1434                constraint_violation_error = e
1435        remove_from_cache(f)
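        # (Sketch of the block below: when the constraint solver is enabled
        # and a shape env exists, dim_constraints.solve() reduces the guards
        # collected on input dimensions to a minimal set of dynamic-shape
        # constraints, and forced_specializations() reports dimensions that
        # had to be pinned to a single value.)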
1436
1437        if (
1438            not disable_constraint_solver
1439            and (shape_env := getattr(fake_mode, "shape_env", None)) is not None
1440            and (dim_constraints := shape_env.dim_constraints) is not None
1441            and not isinstance(
1442                call_to_inspect, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)
1443            )
1444            and not trace_rules.check(call_to_inspect)
1445        ):
1446            dim_constraints.solve()
1447            forced_specializations = dim_constraints.forced_specializations()
1448            msg = dim_constraints.prettify_results(
1449                original_signature,
1450                dynamic_shapes,
1451                constraint_violation_error,
1452                forced_specializations,
1453            )
1454            if constraint_violation_error:
1455                constraint_violation_error.args = (
1456                    constraint_violation_error.args[0] + msg,
1457                )
1458            else:
1459                if forced_specializations:
1460                    constraint_violation_error = ConstraintViolationError(msg)
1461                else:
1462                    log.info(
1463                        "Summary of dimension constraints:%s",
1464                        msg,
1465                    )
1466
1467            # Error if we have any constraints on static values
1468            for k in shape_env.var_to_range.keys():
1469                if isinstance(k, sympy.Integer):
1470                    constraint_violation_error = ConstraintViolationError(
1471                        f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
1472                        "It appears that you're trying to set a constraint on a "
1473                        f"value which we evaluated to have a static value of {k}. "
1474                        'Set TORCH_LOGS="+export" for more information.'
1475                    )
1476        if constraint_violation_error:
1477            raise constraint_violation_error
1478
1479        if graph is None:
1480            assert (
1481                same_signature
1482            ), "Failed to produce a graph during tracing as no tensor operations were found and same_signature is False."
1483            # If the module does not contain any tensor computation, we create a graph with only inputs and outputs.
1484            # To be consistent with the graph traced by dynamo, `graph` will have only tensor inputs as placeholders
1485            # and tensor outputs as output nodes. Non-tensor inputs and outputs will be added when rewriting the signature.
1486            # We will also construct the `example_inputs`, `graph_captured_input`, and `graph_captured_result` corresponding
1487            # to `graph`.
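            # (Illustrative only: for a callable f(x, n) with one tensor input
            # x and no tensor compute, the synthesized graph is roughly
            #     def forward(self, x): return ()
            # and rewrite_signature later reattaches the non-tensor pieces.)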
1488            example_inputs = []
1489            graph_captured_input = ()
1490            graph_captured_result = ()
1491            fake_mode = torch._subclasses.FakeTensorMode(
1492                shape_env=ShapeEnv(), export=True
1493            )
1494            if out_guards is None:
1495                out_guards = _guards.GuardsSet()
1496            assert out_guards is not None  # suppress mypy error
1497            parameter_names = list(original_signature.parameters.keys())
1498            fx_graph = torch.fx.Graph()
1499            for i, name in enumerate(parameter_names):
1500                if torch.is_tensor(flat_args[i]):
1501                    node = fx_graph.placeholder(name)
1502                    node.meta["val"] = fake_mode.from_tensor(
1503                        flat_args[i], static_shapes=True
1504                    )
1505                    graph_captured_input = graph_captured_input + (flat_args[i],)
1506                    example_inputs.append(flat_args[i])
1507            fx_graph.output(graph_captured_result)
1508            module = torch.nn.Module()
1509            graph = torch.fx.GraphModule(module, fx_graph)
1510            log.info(
1511                "Failed to capture a graph during tracing as no tensor operations were found:\n\n%s",
1512                graph.print_readable(print_output=False, colored=True),
1513            )
1514        else:
1515            assert hasattr(graph, "_source_to_user_stacks")
1516            assert out_guards is not None, "Failed to produce guards during tracing"
1517            assert fake_mode is not None
1518
1519            log.info(
1520                "Dynamo captured graph:\n\n%s",
1521                graph.print_readable(print_output=False, colored=True),
1522            )
1523
1524            # This check needs to happen before the aten_graph retrace below,
1525            # because a placeholder's _source_node attribute is not preserved by make_fx.
1526            if same_signature:
1527                check_signature_rewritable(graph)
1528
1529        # NB: This is mostly hitting the cache; Dynamo already converted these
1530        example_fake_inputs = [fake_mode.from_tensor(t) for t in example_inputs]
1531
1532        if aten_graph:
1533            # Running the graph with an Interpreter is needed to propagate the stack_trace
1534            def graph_with_interpreter(*args):
1535                with torch.fx.traceback.preserve_node_meta():
1536                    return torch.fx.Interpreter(graph).run(*args)  # type: ignore[arg-type]
1537
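            # (What this does, roughly: make_fx re-runs the captured graph
            # through the Interpreter above under the export fake mode, which
            # lowers torch-level ops to ATen ops, applies any entries in
            # decomposition_table, and keeps per-node stack traces via
            # preserve_node_meta.)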
1538            with unset_fake_temporarily(), enable_python_dispatcher(), fake_mode:
1539                try:
1540                    graph = make_fx(
1541                        graph_with_interpreter,
1542                        decomposition_table=decomposition_table,
1543                        tracing_mode="real",
1544                        _allow_non_fake_inputs=True,
1545                        pre_dispatch=pre_dispatch,
1546                        _allow_fake_constant=False,
1547                    )(*example_fake_inputs)
1548                except CondOpArgsMismatchError as e:
1549                    # Wrap the internal error in a user-facing error
1550                    raise UserError(  # noqa: B904
1551                        UserErrorType.DYNAMIC_CONTROL_FLOW,
1552                        str(e),
1553                        case_name="cond_operands",
1554                    )
1555
1556            assert graph is not None
1557            for node in graph.graph.find_nodes(op="get_attr"):
1558                if isinstance(getattr(graph, node.target), torch.Tensor):  # type: ignore[arg-type]
1559                    node.meta["val"] = fake_mode.from_tensor(
1560                        getattr(graph, node.target), static_shapes=True  # type: ignore[arg-type]
1561                    )
1562
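        # (The get_attr loop above gives tensor constants a fake "val" in
        # node.meta, mirroring placeholders, presumably so downstream export
        # passes can rely on node.meta["val"] being present.)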
1563        if same_signature:
1564            flat_args_dynamic_dims = [
1565                {
1566                    c.dim
1567                    for c in (constraints or ())
1568                    if (
1569                        c.t_id == id(x)
1570                        and c.constraint_range.vr.lower != c.constraint_range.vr.upper
1571                    )
1572                }
1573                for x in flat_args
1574            ]
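            # (Illustrative only: flat_args_dynamic_dims[i] is the set of
            # dimensions of flat_args[i] whose constraint range is not a
            # single point, e.g. {0} when only that input's leading dimension
            # was marked dynamic.)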
1575            graph = rewrite_signature(
1576                original_signature,
1577                graph,
1578                fake_mode,
1579                flat_args,
1580                in_spec,
1581                example_fake_inputs,
1582                graph_captured_input,
1583                graph_captured_result,
1584                result_traced,  # type: ignore[possibly-undefined]
1585                flat_args_dynamic_dims,
1586            )
1587        return ExportResult(graph, out_guards)  # type: ignore[arg-type]
1588
1589    if extra_args or extra_kwargs:
1590        warnings.warn(
1591            "export(f, *args, **kwargs) is deprecated, use export(f)(*args, **kwargs) instead.  "
1592            "If you don't migrate, we may break your export call in the future if your user-defined kwargs "
1593            "conflict with future kwargs added to export(f).",
1594            FutureWarning,
1595            stacklevel=2,
1596        )
1597        return inner(*extra_args, **extra_kwargs)
1598    else:
1599        return inner
1600
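# (Illustrative usage of the export() defined above:
#     gm, guards = torch._dynamo.export(fn)(example_input)
# export(fn) returns a callable; invoking it on example inputs performs the
# trace and returns an ExportResult of (graph_module, guards).)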
1601
1602def optimize_assert(
1603    backend,
1604    *,
1605    hooks=Hooks(None, None),
1606    export=False,
1607    export_constraints=None,
1608    dynamic=None,
1609    rebuild_ctx=None,
1610):
1611    """
1612    The same as `torch._dynamo.optimize(backend, nopython=True)`
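
    Example (illustrative sketch, using the built-in "eager" backend):

        @optimize_assert("eager")
        def fn(x):
            return x.sin() + x.cos()

    A graph break inside `fn` raises an error instead of silently falling
    back to eager execution.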
1613    """
1614    backend = get_compiler_fn(backend)
1615
1616    # Check whether the backend provides an extra context manager
1617    backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)
1618
1619    return _optimize_catch_errors(
1620        convert_frame.convert_frame_assert(
1621            backend, export=export, export_constraints=export_constraints
1622        ),
1623        hooks,
1624        backend_ctx_ctor,
1625        export=export,
1626        dynamic=dynamic,
1627        rebuild_ctx=rebuild_ctx,
1628    )
1629
1630
1631class TorchPatcher:
1632    @staticmethod
1633    @functools.lru_cache(None)
1634    def patch():
1635        # A better way to disable the following would be to decorate the source
1636        # functions with @torch._disable_dynamo. However, this causes issues
1637        # with torch.deploy internally.
1638        from .decorators import disable
1639
1640        torch.jit.trace = disable(torch.jit.trace)
1641        torch.jit.trace_module = disable(torch.jit.trace_module)
1642        torch.jit._get_trace_graph = disable(torch.jit._get_trace_graph)
1643        torch.fx._symbolic_trace.Tracer.trace = disable(
1644            torch.fx._symbolic_trace.Tracer.trace
1645        )
1646        torch.distributions.Distribution.set_default_validate_args(False)
1647
1648        from torch.optim import (
1649            adadelta,
1650            adagrad,
1651            adam,
1652            adamax,
1653            adamw,
1654            asgd,
1655            lbfgs,
1656            nadam,
1657            radam,
1658            rmsprop,
1659            rprop,
1660            sgd,
1661            sparse_adam,
1662        )
1663
1664        optimizer_modules = {
1665            adadelta,
1666            adagrad,
1667            adam,
1668            adamax,
1669            adamw,
1670            asgd,
1671            lbfgs,
1672            nadam,
1673            radam,
1674            rmsprop,
1675            rprop,
1676            sgd,
1677            sparse_adam,
1678        }
1679
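        # (What the loop below does: for each optimizer module, wrap its
        # module-level `_fused_<name>` helper, when present, in disable() so
        # the fused implementation always runs eagerly instead of being
        # traced by Dynamo.)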
1680        for opt_mod in optimizer_modules:
1681            opt_name = opt_mod.__name__.split(".")[-1]
1682            fused_fn_name = f"_fused_{opt_name}"
1683            single_tensor_fn_name = f"_single_tensor_{opt_name}"
1684
1685            if hasattr(opt_mod, fused_fn_name):
1686                setattr(
1687                    opt_mod, fused_fn_name, disable(getattr(opt_mod, fused_fn_name))
1688                )
1689
1690        optimizer_classes = [
1691            opt
1692            for opt in torch.optim.__dict__.values()
1693            if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
1694        ]
1695
1696        # Note: we don't support sparsity or tracing through backwards
1697        excluded_optimizer_classes = {
1698            torch.optim.SparseAdam,
1699            torch.optim.LBFGS,
1700        }
1701
1702        for opt in optimizer_classes:
1703            if opt in excluded_optimizer_classes:
1704                opt.step = disable(opt.step)
1705
1706            if hasattr(opt, "_init_group"):
1707                opt._init_group = disable(opt._init_group)
1708
1709    @staticmethod
1710    def suppress_torch_distributed_warnings(fn):
1711        def inner_fn(*args, **kwargs):
1712            warnings.filterwarnings(
1713                "ignore", category=UserWarning, module="torch.distributed"
1714            )
1715            return fn(*args, **kwargs)
1716
1717        return inner_fn
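
        # (Illustrative usage, hypothetical callsite:
        #     helper = TorchPatcher.suppress_torch_distributed_warnings(helper)
        # every call to the wrapped helper first installs a filter ignoring
        # UserWarnings raised from torch.distributed modules.)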
1718