# mypy: allow-untyped-decorators
from __future__ import annotations

import collections
import contextlib
import cProfile
import dis
import functools
import itertools
import logging
import os
import pstats
import random
import subprocess
import sys
import threading
import time
import traceback
import typing
import weakref
from pathlib import Path
from types import CodeType, FrameType, FunctionType, ModuleType
from typing import Any, Callable, Dict, List, Optional, Set, TypeVar, Union
from typing_extensions import ParamSpec
from weakref import ReferenceType

import torch
import torch._logging
from torch._C._dynamo.guards import GlobalStateGuard
from torch._dynamo.distributed import get_compile_pg
from torch._dynamo.utils import CompileTimeInstructionCounter
from torch._guards import compile_context, CompileContext, CompileId, tracing
from torch._logging import structured
from torch._utils_internal import (
    compile_time_strobelight_meta,
    justknobs_check,
    maybe_upload_prof_stats_to_manifold,
    signpost_event,
)
from torch.fx._lazy_graph_module import _use_lazy_graph_module
from torch.fx.experimental.symbolic_shapes import (
    ConstraintViolationError,
    GuardOnDataDependentSymNode,
)
from torch.fx.graph_module import _forward_from_src as original_forward_from_src
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils._python_dispatch import (
    _disable_current_modes,
    is_in_torch_dispatch_mode,
)
from torch.utils._traceback import CapturedTraceback, format_traceback_short

from . import config, exc, trace_rules
from .bytecode_analysis import remove_dead_code, remove_pointless_jumps
from .bytecode_transformation import (
    check_inst_exn_tab_entries_valid,
    Instruction,
    is_generator,
    propagate_inst_exn_table_entries,
    transform_code_object,
)
from .cache_size import (
    CacheSizeRelevantForFrame,
    compute_cache_size,
    exceeds_cache_size_limit,
    is_recompilation,
)
from .eval_frame import always_optimize_code_objects, skip_code, TorchPatcher
from .exc import (
    augment_exc_message,
    BackendCompilerFailed,
    CacheLimitExceeded,
    format_error_msg,
    InternalTorchDynamoError,
    SkipCodeRecursiveException,
    TorchRuntimeError,
    UncapturedHigherOrderOpError,
    unimplemented,
    Unsupported,
)
from .guards import (
    CheckFunctionManager,
    get_and_maybe_log_recompilation_reason,
    GuardedCode,
)
from .hooks import Hooks
from .replay_record import ExecutionRecord
from .symbolic_convert import (
    DistributedState,
    InstructionTranslator,
    LocalState,
    SpeculationLog,
)
from .trace_rules import is_numpy
from .utils import (
    CleanupManager,
    CompilationMetrics,
    counters,
    dynamo_timed,
    format_bytecode,
    frame_phase_timing,
    gen_record_file_name,
    get_chromium_event_logger,
    increment_frame,
    is_namedtuple,
    istype,
    LazyString,
    orig_code_map,
    record_compilation_metrics,
    reset_graph_break_dup_checker,
    setup_compile_debug,
    troubleshooting_url,
    write_record_to_file,
)


np: Optional[ModuleType]
try:
    import numpy as np
except ModuleNotFoundError:
    np = None


if typing.TYPE_CHECKING:
    from .backends.registry import CompilerFn
    from .repro.after_dynamo import WrapBackendDebug
    from .types import BytecodeHook, CacheEntry
    from .variables.builder import FrameStateSizeEntry


log = logging.getLogger(__name__)
bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode")
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")


compile_lock = threading.RLock()

_T = TypeVar("_T")
_P = ParamSpec("_P")


class TODO_UNKNOWN:
    pass


class Tracker:
    def __init__(self) -> None:
        self.seen: List[ReferenceType[CodeType]] = []
        self.seen_ids: Set[int] = set()

    def add(self, strong_obj: CodeType) -> None:
        idx = id(strong_obj)
        if idx not in self.seen_ids:
            obj = weakref.ref(strong_obj, lambda _: self.seen_ids.remove(idx))
            self.seen.append(obj)
            self.seen_ids.add(idx)

    def __contains__(self, item: CodeType) -> bool:
        return id(item) in self.seen_ids

    def clear(self) -> None:
        self.seen.clear()
        self.seen_ids.clear()


input_codes = Tracker()
output_codes = Tracker()

initial_global_state: Optional[GlobalStateGuard] = None


@functools.wraps(original_forward_from_src)
def fx_forward_from_src_skip_result(
    src: str, globals: Dict[str, Any], co_fields: Optional[Dict[str, str]] = None
) -> FunctionType:
    # we monkey patch FX to prevent infinite loop of trying to convert
    # our generated code
    result = original_forward_from_src(src, globals, co_fields)
    skip_code(result.__code__)
    return result


def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
    """
    Decorator that wraps ``fn`` to:
        1) Save/restore torch.is_grad_enabled() state
        2) Save/restore Python random state
        3) Save/restore torch random state
        4) Monkey patch torch.fx.graph_module._forward_from_src
    """

    @functools.wraps(fn)
    def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        guards = GlobalStateGuard()
        prior_grad_mode = torch.is_grad_enabled()
        # Just in case we get left in a bad dispatch state we want to restore
        # it. This can happen because the dispatch bits aren't a true
        # stack/counter - so we can't just increment/decrement them as we enter
        # and leave.
        with torch._C._PreserveDispatchKeyGuard():
            prior_inference_mode = torch.is_inference_mode_enabled()
            prior_deterministic = torch.are_deterministic_algorithms_enabled()
            prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
            py_rng_state = random.getstate()
            torch_rng_state = torch.random.get_rng_state()
            cuda_rng_state = None
            if torch.cuda.is_available():
                cuda_rng_state = torch.cuda.get_rng_state()
            allow_tf32 = torch._C._get_cublas_allow_tf32()
            prior_fwd_from_src = torch.fx.graph_module._forward_from_src
            torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
            cleanup = setup_compile_debug()

            exit_stack = contextlib.ExitStack()
            exit_stack.enter_context(
                torch.fx._symbolic_trace._maybe_revert_all_patches()
            )
            try:
                return fn(*args, **kwargs)
            finally:
                cleanup.close()
                exit_stack.close()
                torch._C._set_grad_enabled(prior_grad_mode)
                torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
                torch.use_deterministic_algorithms(
                    prior_deterministic, warn_only=prior_warn_only
                )
                random.setstate(py_rng_state)
                torch.random.set_rng_state(torch_rng_state)
                if cuda_rng_state is not None:
                    torch.cuda.set_rng_state(cuda_rng_state)
                torch._C._set_cublas_allow_tf32(allow_tf32)
                torch.fx.graph_module._forward_from_src = prior_fwd_from_src
                assert (
                    guards.check()
                ), f"Global {guards.reason()}state changed while dynamo tracing, please report a bug"

    _fn._torchdynamo_orig_callable = fn  # type: ignore[attr-defined]
    return _fn

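# A minimal usage sketch (not part of this module) of the decorator above:
# the wrapped function may freely flip global state while tracing, and the
# caller observes the original state afterwards. The function name below is
# hypothetical.
#
#     @preserve_global_state
#     def _trace_something() -> None:
#         torch.set_grad_enabled(False)  # mutated during tracing
#         random.seed(0)                 # Python RNG also touched
#
#     prior = torch.is_grad_enabled()
#     _trace_something()
#     assert torch.is_grad_enabled() == prior  # grad mode was restored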

@TorchPatcher.suppress_torch_distributed_warnings
def has_tensor_in_frame(frame: FrameType) -> bool:
    """Check if the frame has torch.* related bits"""
    # Check if the function was decorated using torch._dynamo.optimize
    if frame.f_code in always_optimize_code_objects:
        return True

    # Check if there is global import of torch.*
    for co_name in frame.f_code.co_names:
        if co_name in frame.f_globals:
            obj = frame.f_globals[co_name]
            if isinstance(obj, ModuleType) and (
                obj.__name__.startswith("torch.") or obj is torch
            ):
                return True
            # ... or a global import of numpy.*
            if np and config.trace_numpy and (obj is np or is_numpy(obj)):
                return True

    seen_ids: Dict[int, bool] = {}

    def has_tensor(obj: object) -> bool:
        """Recursively check if the obj has a tensor"""
        obj_id = id(obj)
        if obj_id in seen_ids:
            return seen_ids[obj_id]
        seen_ids[obj_id] = False

        if isinstance(obj, (torch.Tensor, torch.nn.Module)) or (
            istype(obj, type) and issubclass(obj, torch.nn.Module)
        ):
            seen_ids[obj_id] = True
            return seen_ids[obj_id]
        elif (
            config.trace_numpy
            and np
            and (istype(obj, np.ndarray) or isinstance(obj, np.generic))
        ):
            seen_ids[obj_id] = True
            return seen_ids[obj_id]
        elif istype(obj, (list, tuple)):
            seen_ids[obj_id] = any(has_tensor(v) for v in obj)
            return seen_ids[obj_id]
        elif istype(obj, dict):
            # Some packages like pytest can be updated during runtime. So, make a
            # copy of values to avoid issues like "RuntimeError: dictionary
            # changed size during iteration"
            values = list(obj.values())
            seen_ids[obj_id] = any(has_tensor(v) for v in values)
            return seen_ids[obj_id]
        elif istype(obj, (str, int, float, type(None), bool)):
            seen_ids[obj_id] = False
            return seen_ids[obj_id]
        elif is_namedtuple(obj) and hasattr(obj, "_fields"):
            seen_ids[obj_id] = any(has_tensor(getattr(obj, v)) for v in obj._fields)
            return seen_ids[obj_id]
        else:
            # if config.debug:
            #     print(
            #         f"Assuming that object of type {type(obj)} does not have a tensor"
            #     )
            return False

    # Check if the passed arguments are of type Tensor
    for value in frame.f_locals.values():
        if has_tensor(value):
            return True

    log.debug(
        "skipping because no torch.* %s \
            %s %s",
        frame.f_code.co_name,
        frame.f_code.co_filename,
        frame.f_code.co_firstlineno,
    )

    return False

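# Illustration (a hedged sketch, not executed here): has_tensor_in_frame()
# walks the frame's locals recursively through lists, tuples, dicts and
# namedtuples, so a frame whose only local is e.g. {"batch": [torch.zeros(2)]}
# counts as tensor-bearing, while a frame holding only strings/ints/floats is
# skipped by Dynamo.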

def exception_handler(
    e: Exception,
    code: CodeType,
    frame: Optional[FrameType] = None,
    export: bool = False,
) -> None:
    record_filename = None
    if hasattr(e, "exec_record"):
        record_filename = gen_record_file_name(e, code)
        write_record_to_file(record_filename, e.exec_record)
        e.record_filename = record_filename  # type: ignore[attr-defined]

    augment_exc_message(e, export=export)

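# Note: when replay recording is enabled (see config.replay_record_enabled),
# translator exceptions typically carry an ``exec_record``; the handler above
# persists it via write_record_to_file() and stamps ``e.record_filename`` so
# that the failing frame can later be re-run in isolation with replay()
# (defined towards the end of this file).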

FRAME_COUNTER = 0
FRAME_COMPILE_COUNTER: typing.Counter[
    Union[int, FrameStateSizeEntry]
] = collections.Counter()


def maybe_cprofile(func: Callable[_P, _T]) -> Callable[_P, _T]:
    if config.cprofile:
        return cprofile_wrapper(func)
    return func


def cprofile_wrapper(func: Callable[_P, _T]) -> Callable[_P, _T]:
    @functools.wraps(func)
    def profile_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        trace_id = CompileContext.current_trace_id()
        assert trace_id, "Trace id is None"
        profile_path = Path(
            f"/tmp/{func.__name__}_{str(trace_id).replace('/', '_')}.profile"
        )
        prof = cProfile.Profile()
        prof.enable()
        start_ts = time.time()
        retval = prof.runcall(func, *args, **kwargs)
        profile_latency = time.time() - start_ts
        prof.disable()
        log.warning(
            "### Cprofile for %s trace id [%s] took %.3f seconds ###",
            func.__name__,
            trace_id,
            profile_latency,
        )
        ps = pstats.Stats(prof)
        try:
            prof.dump_stats(profile_path)
        except PermissionError:
            log.exception("Cannot write to %s", profile_path)
        log.warning("Raw profile at %s", profile_path)
        svg_path = profile_path.with_suffix(".svg")
        try:
            gprof2dot_process = subprocess.Popen(
                [
                    "gprof2dot",
                    "-f",
                    "pstats",
                    "--node-label=total-time-percentage",
                    "--node-label=self-time-percentage",
                    "--node-label=total-time",
                    str(profile_path),
                ],
                stdout=subprocess.PIPE,
            )
            subprocess.check_call(
                ["dot", "-Tsvg", "-o", str(svg_path)],
                stdin=gprof2dot_process.stdout,
            )
            log.warning("Generated SVG from profile at %s", svg_path)
        except FileNotFoundError:
            log.warning(
                "Failed to generate SVG from profile -- dumping stats instead. "
                "Try installing gprof2dot and dot for a better visualization"
            )
            ps.sort_stats(pstats.SortKey.TIME).print_stats(20)
            ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(20)

        if manifold_link := maybe_upload_prof_stats_to_manifold(
            str(profile_path)
        ):  # fb-only
            torch._logging.trace_structured(
                "link",
                lambda: {"name": "cprofile_manifold_url", "url": manifold_link},
            )
        return retval

    return profile_wrapper

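# For reference, the SVG rendering above is roughly equivalent to running the
# following shell pipeline by hand on the dumped profile (assuming gprof2dot
# and graphviz's ``dot`` are installed; the exact file name depends on the
# wrapped function and trace id):
#
#     gprof2dot -f pstats --node-label=total-time-percentage \
#         --node-label=self-time-percentage --node-label=total-time \
#         /tmp/<func>_<trace_id>.profile | dot -Tsvg -o /tmp/<func>_<trace_id>.svg
#
# Profiling itself is opt-in via config.cprofile (see maybe_cprofile above).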

class ConvertFrameAssert:
    def __init__(
        self,
        compiler_fn: CompilerFn,
        one_graph: bool = True,
        export: bool = False,
        export_constraints: Optional[typing.Never] = None,
    ) -> None:
        # assert export_constraints is None
        reset_graph_break_dup_checker()
        self._torchdynamo_orig_callable = compiler_fn
        self._one_graph = one_graph
        self._export = export
        self._export_constraints = export_constraints

    @property
    def _clone_with_backend(self) -> Callable[[CompilerFn], ConvertFrameAssert]:
        return lambda backend: convert_frame_assert(
            backend, self._one_graph, self._export, self._export_constraints
        )

    def __call__(
        self,
        frame: FrameType,
        cache_entry: Optional[CacheEntry],
        hooks: Hooks,
        frame_state: Dict[str, Union[int, FrameStateSizeEntry]],
        *,
        skip: int = 0,
    ) -> Optional[GuardedCode]:
        increment_frame()

        code = frame.f_code

        cache_size = compute_cache_size(frame, cache_entry)
        input_codes.add(code)
        if code in output_codes:
            return None
        if (
            os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION")
            and os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION") != code.co_name
        ):
            return None
        if code.co_name == "<genexpr>" and code.co_filename.endswith(
            (
                "transformers/file_utils.py",
                "transformers/utils/generic.py",
                "diffusers/utils/outputs.py",
            )
        ):
            # not needed, but cleans up torchbench error stats
            return None
        if code.co_name == "__setattr__":
            # setattr could be tricky to handle generally,
            # but also not likely useful to compile; skip the whole frame
            return None
        if code.co_name == "__init__" and code.co_filename.startswith(
            os.path.dirname(torch.optim.__file__)
        ):
            # optimizer support is still incomplete see
            # test_state_dict in test/dynamo/test_optimizers.py
            return None

        # Check if the frame is generated by an exec builtin call
        # TODO - Running an exec-generated frame seems to propagate f_globals
        # to the next frames.
        if code.co_name == "<module>" and code.co_filename == "<string>":
            return None

        if (
            code.co_name == "<lambda>"
            and code.co_filename == "<string>"
            and not bool(frame.f_builtins)
        ):
            # namedtuple subclass constructor. Empty builtins cause issue with
            # len keyword in LIST_LEN guard.
            return None

        if is_generator(code):
            unimplemented("generator")

        if not has_tensor_in_frame(frame):
            return None

        global initial_global_state
        initial_global_state = GlobalStateGuard()

        global FRAME_COUNTER
        if "_id" not in frame_state:
            frame_state["_id"] = FRAME_COUNTER
            FRAME_COUNTER += 1
        frame_id = frame_state["_id"]
        assert isinstance(frame_id, int)

        frame_compile_id = FRAME_COMPILE_COUNTER[frame_id]
        FRAME_COMPILE_COUNTER[frame_id] += 1

        compile_id = CompileId(frame_id, frame_compile_id)

        signpost_event(
            "dynamo",
            "_convert_frame_assert._compile",
            {
                "co_name": code.co_name,
                "frame_id": frame_id,
                "compile_id": str(compile_id),
                "co_filename": code.co_filename,
                "co_firstlineno": code.co_firstlineno,
                "cache_size": cache_size.num_cache_entries_with_same_id_matched_objs,
                "accumulated_cache_size": cache_size.num_cache_entries,
            },
        )

        return _compile(
            frame.f_code,
            frame.f_globals,
            frame.f_locals,
            frame.f_builtins,
            self._torchdynamo_orig_callable,
            self._one_graph,
            self._export,
            self._export_constraints,
            hooks,
            cache_entry,
            cache_size,
            frame,
            frame_state=frame_state,
            compile_id=compile_id,
            skip=skip + 1,
        )


def convert_frame_assert(
    compiler_fn: CompilerFn,
    one_graph: bool = True,
    export: bool = False,
    export_constraints: Optional[typing.Never] = None,
) -> ConvertFrameAssert:
    """Fully convert a frame into an FX graph"""
    return ConvertFrameAssert(compiler_fn, one_graph, export, export_constraints)


from collections import OrderedDict

from torch.utils.hooks import RemovableHandle


if typing.TYPE_CHECKING:
    from .output_graph import OutputGraph

# we have to use `OrderedDict` to make `RemovableHandle` work.
_bytecode_hooks: Dict[int, BytecodeHook] = OrderedDict()


def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle:
    """Register hooks for bytecode generated by Dynamo. The hook can do some
    logging, as well as return a new code object to be used. Please refer
    to `BytecodeHook` for the hook signature.
    """
    handle = RemovableHandle(_bytecode_hooks)
    _bytecode_hooks[handle.id] = hook
    return handle

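# A minimal usage sketch of the hook API above (the hook body here is
# hypothetical): the hook receives the original and the Dynamo-generated code
# objects and may return a replacement code object, or None to keep the
# generated one (see how _bytecode_hooks is consumed in _compile_inner below).
#
#     def _print_bytecode_hook(old_code: CodeType, new_code: CodeType):
#         print(f"dynamo compiled {old_code.co_name!r}")
#         return None  # keep Dynamo's output unchanged
#
#     handle = register_bytecode_hook(_print_bytecode_hook)
#     ...  # run torch.compile'd code
#     handle.remove()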

def _compile(
    code: CodeType,
    globals: Dict[str, object],
    locals: Dict[str, object],
    builtins: Dict[str, object],
    compiler_fn: CompilerFn,
    one_graph: bool,
    export: bool,
    export_constraints: Optional[typing.Never],
    hooks: Hooks,
    cache_entry: Optional[CacheEntry],
    cache_size: CacheSizeRelevantForFrame,
    frame: Optional[FrameType] = None,
    frame_state: Optional[Dict[str, Union[int, FrameStateSizeEntry]]] = None,
    *,
    compile_id: CompileId,
    skip: int = 0,
) -> Optional[GuardedCode]:
    from torch.fx.experimental.validator import (
        bisect,
        BisectValidationException,
        translation_validation_enabled,
        ValidationException,
    )

    # Only nonlocal defs here please!
    # Time spent compiling this frame before restarting or failing analysis
    dynamo_time_before_restart: float = 0.0
    output: Optional[OutputGraph] = None
    tracer: Optional[InstructionTranslator] = None

    @preserve_global_state
    def transform(
        instructions: List[Instruction], code_options: Dict[str, object]
    ) -> None:
        nonlocal output
        nonlocal tracer
        speculation_log.restart()
        tracer = InstructionTranslator(
            instructions,
            code,
            locals,
            globals,
            builtins,
            code_options,
            compiler_fn,
            one_graph,
            export,
            export_constraints,
            mutated_closure_cell_contents,
            frame_state=frame_state,
            speculation_log=speculation_log,
            distributed_state=distributed_state,
        )

        try:
            with tracing(tracer.output.tracing_context), tracer.set_current_tx():
                tracer.run()
        except exc.UnspecializeRestartAnalysis:
            speculation_log.clear()
            raise
        except (exc.SpeculationRestartAnalysis, exc.SkipFrame):
            raise
        except Exception:
            if translation_validation_enabled():
                bisect(tracer.output.shape_env)
            raise
        finally:
            tracer.output.call_cleanup_hooks()

        output = tracer.output
        assert output is not None
        assert output.output_instructions
        instructions[:] = output.output_instructions
        code_options.update(output.code_options)

        if config.dead_code_elimination:
            propagate_inst_exn_table_entries(instructions)
            check_inst_exn_tab_entries_valid(instructions)
            instructions[:] = remove_pointless_jumps(remove_dead_code(instructions))

    def compile_inner(
        code: CodeType,
        one_graph: bool,
        hooks: Hooks,
        transform: Callable[[List[Instruction], Dict[str, Any]], Any],
    ) -> Optional[GuardedCode]:
        with dynamo_timed("_compile.compile_inner", phase_name="entire_frame_compile"):
            with CompileTimeInstructionCounter.record():
                return _compile_inner(code, one_graph, hooks, transform)

    @compile_time_strobelight_meta(phase_name="compile_inner")
    @maybe_cprofile
    def _compile_inner(
        code: CodeType,
        one_graph: bool,
        hooks: Hooks,
        transform: Callable[[List[Instruction], Dict[str, Any]], Any],
    ) -> Optional[GuardedCode]:
        nonlocal dynamo_time_before_restart
        last_attempt_start_time = start_time = time.time()

        def log_bytecode(
            prefix: str, name: str, filename: str, line_no: int, code: CodeType
        ) -> None:
            if bytecode_log.isEnabledFor(logging.DEBUG):
                bytecode_log.debug(
                    format_bytecode(prefix, name, filename, line_no, code)
                )

        log_bytecode(
            "ORIGINAL BYTECODE",
            code.co_name,
            code.co_filename,
            code.co_firstlineno,
            code,
        )

        out_code = None
        for attempt in itertools.count():
            CompileContext.get().attempt = attempt
            try:
                out_code = transform_code_object(code, transform)
                break
            except exc.RestartAnalysis as e:
                log.info(
                    "Restarting analysis due to %s",
                    LazyString(format_traceback_short, e.__traceback__),
                )
                # If restart reason is None just log the type of the exception
                restart_reasons.add(e.restart_reason or str(type(e)))
                # We now have a new "last attempt", reset the clock
                last_attempt_start_time = time.time()
                if attempt > 100:
                    unimplemented("100+ RestartAnalysis() calls")
            except exc.SkipFrame as e:
                log.debug(
                    "Skipping frame %s %s \
                    %s %s",
                    e,
                    code.co_name,
                    code.co_filename,
                    code.co_firstlineno,
                )
                if one_graph:
                    log.debug("No graph captured with one_graph=True")
                return None

        assert (
            distributed_state is None or distributed_state.all_states is not None
        ), "compiler collective wasn't run before compilation completed"

        assert out_code is not None
        log_bytecode(
            "MODIFIED BYTECODE",
            code.co_name,
            code.co_filename,
            code.co_firstlineno,
            out_code,
        )

        for hook in _bytecode_hooks.values():
            hook_output = hook(code, out_code)
            if hook_output is not None:
                out_code = hook_output

        orig_code_map[out_code] = code
        output_codes.add(out_code)
        dynamo_time_before_restart = last_attempt_start_time - start_time
        assert output is not None

        # Sanity checks for the new code object.
        # The rationale for these checks can be found in torch/csrc/dynamo/eval_frame.c.
        # They run only once, when the code object is created;
        # they are not re-checked at runtime.

        def count_args(code: CodeType) -> int:
            import inspect

            return (
                code.co_argcount
                + code.co_kwonlyargcount
                + bool(code.co_flags & inspect.CO_VARARGS)
                + bool(code.co_flags & inspect.CO_VARKEYWORDS)
            )

        assert out_code is not None

        total_argcount_old = count_args(code)
        total_argcount_new = count_args(out_code)
        msg = "arg mismatch: "
        msg += f"old code object has args {code.co_varnames[:total_argcount_old]}, "
        msg += f"new code object has args {out_code.co_varnames[:total_argcount_new]}"
        assert (
            code.co_varnames[:total_argcount_old]
            == out_code.co_varnames[:total_argcount_new]
        ), msg

        msg = "free var mismatch: "
        msg += f"old code object has free var {code.co_freevars}, "
        msg += f"new code object has free var {out_code.co_freevars}"
        assert code.co_freevars == out_code.co_freevars, msg

        msg = "cell var mismatch: "
        msg += f"old code object has cell var {code.co_cellvars}, "
        msg += f"new code object has cell var {out_code.co_cellvars}"
        assert code.co_cellvars == out_code.co_cellvars, msg

        # Skipping Dynamo on a frame without any extracted graph.
        # This does not affect eager functionality. But this is necessary
        # for export, for cases where Dynamo-reconstructed bytecode can create
        # new function frames, confusing export into thinking that there
        # are extra graphs now.

        if output.export and output.is_empty_graph():
            return None

        assert output.guards is not None
        CleanupManager.instance[out_code] = output.cleanups
        check_fn = CheckFunctionManager(
            output,
            hooks.guard_fail_fn if hooks else None,
        )

        guarded_code = GuardedCode(out_code, check_fn.check_fn, compile_id)

        if not output.is_empty_graph() and hooks.guard_export_fn is not None:
            # We should not run the guard_export_fn when Dynamo does not
            # generate any graph. This can happen in export when
            # TorchDynamo-generated bytecode has some reconstruction logic for
            # mutated variables, which can trigger TorchDynamo on the child
            # frames; they are benign and do not generate any new graphs.
            hooks.guard_export_fn(output.guards)

        return guarded_code

    with _use_lazy_graph_module(config.use_lazy_graph_module), compile_context(
        CompileContext(compile_id)
    ):
        restart_reasons: set[str] = set()
        # This is shared across restarts
        mutated_closure_cell_contents: Set[str] = set()
        speculation_log = SpeculationLog()
        if compile_pg := get_compile_pg():
            distributed_state = DistributedState(compile_pg, LocalState())
        else:
            distributed_state = None
        torch._dynamo.callback_handler.run_start_callbacks()

        # Check recompilations
        recompile_reasons = None
        if is_recompilation(cache_size) and frame:
            recompile_reasons = get_and_maybe_log_recompilation_reason(
                cache_entry, frame
            )

        exceeded, limit_type = exceeds_cache_size_limit(cache_size, compile_id)
        if exceeded:

            def format_func_info(code: CodeType) -> str:
                return f"'{code.co_name}' ({code.co_filename}:{code.co_firstlineno})"

            def format_guard_failures() -> str:
                if not recompile_reasons:
                    return "Unable to find recompilation reasons"
                return recompile_reasons[-1]

            log.warning(
                "torch._dynamo hit config.%s (%s)\n"
                "   function: %s\n"
                "   last reason: %s\n"
                'To log all recompilation reasons, use TORCH_LOGS="recompiles".\n'
                "To diagnose recompilation issues, see %s.",
                limit_type,
                getattr(config, limit_type),
                format_func_info(code),
                format_guard_failures(),
                troubleshooting_url,
            )
            if config.skip_code_recursive_on_cache_limit_hit and justknobs_check(
                "pytorch/compiler:skip_code_recursive_on_cache_limit_hit"
            ):
                raise CacheLimitExceeded(f"{limit_type} reached")
            else:
                # do not recursively skip frames
                unimplemented(f"{limit_type} reached")

        log.debug(
            "torchdynamo start compiling %s %s:%s, stack (elided %s frames):\n%s",
            code.co_name,
            code.co_filename,
            code.co_firstlineno,
            skip + 2,
            # -2: omit current frame, omit contextlib decorator
            "".join(CapturedTraceback.extract(skip=2 + skip).format()),
        )
        # -4: -2 as above, plus trace_structured frames
        #
        # NB: the frame looks like this:
        #
        # # handled by skip argument
        # torch/_dynamo/convert_frame.py:1069 in catch_errors
        # torch/_dynamo/convert_frame.py:910 in _convert_frame
        # torch/_dynamo/convert_frame.py:464 in _convert_frame_assert
        # torch/_utils_internal.py:70 in wrapper_function
        #
        # # 2 current frame and context lib
        # env/lib/python3.10/contextlib.py:79 in inner
        # torch/_dynamo/convert_frame.py:776 in _compile
        #
        # # 2 extra here
        # torch/_logging/_internal.py:1064 in trace_structured
        # torch/_dynamo/convert_frame.py:780 in <lambda>
        convert_frame_intern = structured.intern_string(__file__)
        # Initialize the ChromiumEventLogger on start
        chromium_event_log = get_chromium_event_logger()
        chromium_event_log.reset()
        torch._logging.trace_structured(
            "dynamo_start",
            lambda: {
                "stack": list(
                    itertools.takewhile(
                        lambda f: f["filename"] != convert_frame_intern,
                        structured.from_traceback(
                            CapturedTraceback.extract(skip=4 + skip).summary()
                        ),
                    )
                )
                + [
                    {
                        "line": code.co_firstlineno,
                        "name": code.co_name,
                        "filename": structured.intern_string(code.co_filename),
                    }
                ]
            },
        )
        start_time = time.time()
        fail_type: Optional[str] = None
        fail_reason: Optional[str] = None
        fail_user_frame_filename: Optional[str] = None
        fail_user_frame_lineno: Optional[int] = None
        start_possibly_missed_reinplacing_opportunities = torch._dynamo.utils.counters[
            "inductor"
        ]["possibly_missed_reinplacing_opportunities"]
        guarded_code = None
        try:
            guarded_code = compile_inner(code, one_graph, hooks, transform)
            return guarded_code
        except Exception as e:
            fail_type = type(e).__qualname__
            fail_reason = str(e)
            # NB: e's msg is mutated here to add user stack, but we DON'T want
            # that stack in the Scuba logged fail_reason
            exception_handler(e, code, frame, export=export)
            fail_user_frame_filename, fail_user_frame_lineno = exc.get_exc_message(
                e, compile_id
            )
            if isinstance(
                e,
                (
                    Unsupported,
                    TorchRuntimeError,
                    BackendCompilerFailed,
                    AssertionError,
                    ConstraintViolationError,
                    GuardOnDataDependentSymNode,
                    ValidationException,
                    UncapturedHigherOrderOpError,
                    BisectValidationException,
                ),
            ):
                raise
            else:
                # Rewrap for clarity
                raise InternalTorchDynamoError(
                    f"{type(e).__qualname__}: {str(e)}"
                ).with_traceback(e.__traceback__) from None
        finally:
            if tracer:
                tracer.output.local_scope = {}

            from .utils import curr_frame

            frame_key = str(curr_frame)
            if (
                fail_reason is None
                and output is not None
                and frame_key in frame_phase_timing
            ):
                guard_count = len(output.guards)
                shape_env_guard_count = len(output.shape_env.guards)
                graph_op_count = output.count_calls()
                graph_node_count = len(output.graph.nodes)
                graph_input_count = len(output.placeholders)
                entire_frame_compile_time = frame_phase_timing[frame_key].get(
                    "entire_frame_compile", None
                )
                backend_compile_time = frame_phase_timing[frame_key].get(
                    "backend_compile", None
                )
                inductor_compile_time = frame_phase_timing[frame_key].get(
                    "inductor_compile", None
                )
                code_gen_time = frame_phase_timing[frame_key].get("code_gen", None)
                non_compliant_ops = {op.__qualname__ for op in output.non_compliant_ops}
                compliant_custom_ops = {
                    op.__qualname__ for op in output.compliant_custom_ops
                }
                possibly_missed_reinplacing_opportunities = (
                    torch._dynamo.utils.counters["inductor"][
                        "possibly_missed_reinplacing_opportunities"
                    ]
                    - start_possibly_missed_reinplacing_opportunities
                )
            else:
                guard_count = None
                shape_env_guard_count = None
                graph_op_count = None
                graph_node_count = None
                graph_input_count = None
                entire_frame_compile_time = None
                backend_compile_time = None
                inductor_compile_time = None
                code_gen_time = None
                non_compliant_ops = set({})
                compliant_custom_ops = set({})
                restart_reasons = set()
                # If compilation failed, the entire time is wasted
                dynamo_time_before_restart = time.time() - start_time
                possibly_missed_reinplacing_opportunities = None

            metrics = CompilationMetrics(
                str(compile_id),
                frame_key,
                code.co_name,
                code.co_filename,
                code.co_firstlineno,
                cache_size.num_cache_entries_with_same_id_matched_objs,
                cache_size.num_cache_entries,
                guard_count,
                shape_env_guard_count,
                graph_op_count,
                graph_node_count,
                graph_input_count,
                start_time,
                entire_frame_compile_time,
                backend_compile_time,
                inductor_compile_time,
                code_gen_time,
                fail_type,
                fail_reason,
                fail_user_frame_filename,
                fail_user_frame_lineno,
                non_compliant_ops,
                compliant_custom_ops,
                restart_reasons,
                dynamo_time_before_restart,
                guarded_code is not None,
                possibly_missed_reinplacing_opportunities,
            )
            record_compilation_metrics(metrics)
            torch._dynamo.callback_handler.run_end_callbacks()


class ConvertFrame:
    def __init__(self, compiler_fn: CompilerFn, hooks: Hooks) -> None:
        self._torchdynamo_orig_callable = compiler_fn
        self._inner_convert = convert_frame_assert(compiler_fn, one_graph=False)
        self._hooks = hooks

    @property
    def _clone_with_backend(self) -> Callable[[WrapBackendDebug], ConvertFrame]:
        return lambda backend: convert_frame(backend, self._hooks)

    def __call__(
        self,
        frame: FrameType,
        cache_entry: Optional[CacheEntry],
        hooks: Hooks,
        frame_state: Dict[str, Union[int, FrameStateSizeEntry]],
        skip: int = 0,
    ) -> Optional[
        Union[GuardedCode, torch._C._dynamo.eval_frame.SkipCodeRecursiveFlag]
    ]:
        counters["frames"]["total"] += 1
        try:
            result = self._inner_convert(
                frame, cache_entry, hooks, frame_state, skip=skip + 1
            )
            counters["frames"]["ok"] += 1
            return result
        except Exception as e:
            # These exception types are "soft" failures, in the sense that we
            # know they are due to something we didn't implement all the way,
            # so we scare the user less about them.  That being said, if you
            # are trying to understand why a graph break happened, it's still
            # important to have this information, so offer it.
            #
            # NB: NotImplementedError used to be on this list, but actually
            # it is impossible for it to reach here, as it is converted into
            # InternalTorchDynamoError.  This behavior seemed reasonable
            # to me (ezyang, Aug 2023) so I kept it, but maybe at some point
            # someone wanted these to also get suppressed.  If so, you'll
            # need to make these exceptions not get wrapped

            # We intentionally don't want to suppress error here.
            if isinstance(e, UncapturedHigherOrderOpError):
                raise

            soft_fail = isinstance(e, Unsupported)

            # This is a soft failure, in the sense that the code path reaches
            # here when we do not support graph breaks on bytecodes like
            # LOAD_ATTR, BUILD_SET, etc. In such cases, we can fall back to
            # eager without scaring users.
            if isinstance(e, Unsupported) and graph_break_log.isEnabledFor(
                logging.DEBUG
            ):
                # Log this message in the graph break. Also use the string
                # "skip: " to tell that the whole frame is falling back to
                # eager.
                if hasattr(e, "compile_id"):
                    with compile_context(CompileContext(e.compile_id)):  # type: ignore[attr-defined]
                        user_stack = e.real_stack
                        user_stack_formatted = "".join(
                            traceback.format_list(user_stack)
                        )
                        graph_break_log.debug(
                            "Graph break: skip: from user code at:\n%s",
                            user_stack_formatted,
                            exc_info=True,
                        )

            if not config.suppress_errors and not soft_fail:
                raise

            # Suppress the error.  NB: It's very important to do the
            # suppression logging HERE, where the actual suppression
            # happens. Previously it was somewhere else and so it was
            # possible to accidentally not log at all.
            record_filename = getattr(e, "record_filename", None)
            code = frame.f_code
            error_msg = format_error_msg(e, code, record_filename, frame)

            if soft_fail:
                log.info(error_msg, exc_info=True)
            else:
                log.warning(error_msg, exc_info=True)

            # If we encounter SkipCodeRecursiveException, return skip_code_recursive_flag
            # to signal to Dynamo eval frame to skip the current frame and any recursive calls.
            if isinstance(e, SkipCodeRecursiveException):
                return torch._C._dynamo.eval_frame.skip_code_recursive_flag

        return None


def convert_frame(compiler_fn: CompilerFn, hooks: Hooks) -> ConvertFrame:
    """Try to convert a frame into an FX graph, if error leave frame unmodified"""
    return ConvertFrame(compiler_fn, hooks)


# TODO mlazos: add support for same args, or record them
def replay(filename: str) -> None:
    from .backends.debugging import eager

    original_replay_val = config.replay_record_enabled
    config.replay_record_enabled = False
    with open(filename, "rb") as in_file:
        record = ExecutionRecord.load(in_file)
    record.globals = dict(itertools.chain(record.globals.items(), globals().items()))

    try:
        _compile(
            record.code,
            record.globals,
            record.locals,
            record.builtins,
            compiler_fn=eager,
            one_graph=False,
            export=False,
            export_constraints=None,
            hooks=Hooks(),
            cache_size=CacheSizeRelevantForFrame(0, 0),
            cache_entry=None,
            frame=None,
            frame_state={},
            compile_id=CompileId(42, 999),
        )
    finally:
        config.replay_record_enabled = original_replay_val

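# Hedged usage sketch for replay(): given a record file produced while
# config.replay_record_enabled was on (see exception_handler /
# gen_record_file_name above), the failing frame can be re-compiled in
# isolation with the eager backend.  The path below is a placeholder.
#
#     replay("/path/to/saved/execution_record")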

def first_real_inst_idx(code: CodeType) -> int:
    if sys.version_info < (3, 11):
        return 0
    for inst in dis.get_instructions(code):
        if inst.opname == "RESUME":
            return inst.offset // 2
    raise RuntimeError("RESUME instruction not found in code")

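# Rough intuition for first_real_inst_idx() (a hedged note, not a guarantee
# about every CPython build): on 3.11+ a code object may carry prologue
# instructions before RESUME, so a frame only counts as "already executing"
# in CatchErrorsWrapper below once f_lasti has passed the RESUME offset.
#
#     import dis
#     def f(x):
#         return x + 1
#     dis.dis(f)  # on 3.11+, RESUME appears at/near the start of the output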

class ConvertFrameProtocol(typing.Protocol):
    def __call__(
        self,
        frame: FrameType,
        cache_entry: Optional[CacheEntry],
        hooks: Hooks,
        frame_state: Dict[str, Union[int, FrameStateSizeEntry]],
        *,
        skip: int = 0,
    ) -> Optional[GuardedCode]:
        ...


class CatchErrorsWrapper:
    def __init__(self, callback: ConvertFrameProtocol, hooks: Hooks) -> None:
        functools.wraps(callback)(self)
        self._torchdynamo_orig_callable = callback
        self.hooks = hooks

    def __call__(
        self,
        frame: FrameType,
        cache_entry: Optional[CacheEntry],
        frame_state: Dict[str, Union[int, FrameStateSizeEntry]],
    ) -> Optional[GuardedCode]:
        assert frame_state is not None

        is_skipfile = trace_rules.check(frame.f_code)
        if sys.version_info >= (3, 13):
            has_started_execution = frame.f_lasti > first_real_inst_idx(frame.f_code)
        else:
            has_started_execution = frame.f_lasti >= first_real_inst_idx(frame.f_code)
        if (
            # TODO: the first condition is not covered by any test
            has_started_execution
            or is_skipfile
            or config.disable
            or (
                is_in_torch_dispatch_mode(include_infra_modes=False)
                and not getattr(self._torchdynamo_orig_callable, "_export", False)
            )
        ):
            if log.isEnabledFor(logging.DEBUG):
                print(frame.f_lasti, first_real_inst_idx(frame.f_code))

                if has_started_execution:
                    skip_reason = "traced frame already"
                elif trace_rules.check(frame.f_code):
                    skip_reason = "in skipfiles"
                elif is_in_torch_dispatch_mode(include_infra_modes=False):
                    skip_reason = "non-infra torch dispatch mode present, this is not supported today in torch.compile"
                else:
                    skip_reason = "dynamo tracing is disabled"

                log.debug(
                    "skipping: %s (reason: %s, file: %s)",
                    frame.f_code.co_name,
                    skip_reason,
                    frame.f_code.co_filename,
                )
            return None

        if frame.f_code.co_filename == "<string>" and frame.f_code.co_name == "__new__":
            # namedtuple constructor
            return None
        if config._get_optimize_ddp_mode() == "ddp_optimizer":
            ddp_module = DistributedDataParallel._get_active_ddp_module()
            if ddp_module:
                with compile_lock:
                    from torch._dynamo.backends.distributed import DDPOptimizer

                    ddp_optimizer = DDPOptimizer(
                        bucket_bytes_cap=ddp_module.bucket_bytes_cap,
                        backend_compile_fn=self._torchdynamo_orig_callable._torchdynamo_orig_callable,  # type: ignore[attr-defined]
                    )
                    assert hasattr(
                        self._torchdynamo_orig_callable, "_clone_with_backend"
                    ), "DDPOptimizer only supports callback fns that know how to clone themselves."
                    hijacked_callback = (
                        self._torchdynamo_orig_callable._clone_with_backend(
                            ddp_optimizer.compile_fn,
                        )
                    )
                    return hijacked_callback(
                        frame, cache_entry, self.hooks, frame_state
                    )

        with compile_lock, _disable_current_modes():
            # skip=1: skip this frame
            return self._torchdynamo_orig_callable(
                frame, cache_entry, self.hooks, frame_state, skip=1
            )


def catch_errors_wrapper(
    callback: ConvertFrameProtocol, hooks: Hooks
) -> CatchErrorsWrapper:
    return CatchErrorsWrapper(callback, hooks)