# mypy: allow-untyped-defs
import collections
import collections.abc
import contextlib
import copy
import dataclasses
import dis
import functools
import importlib
import inspect
import itertools
import linecache
import logging
import operator
import re
import sys
import threading
import traceback
import types
import typing
import weakref
from typing import (
    Any,
    Callable,
    cast,
    Deque,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)
from unittest.mock import patch

import torch
import torch._logging
from torch._guards import tracing, TracingContext

from . import config, exc, logging as torchdynamo_logging, trace_rules, variables
from .bytecode_analysis import (
    get_indexof,
    JUMP_OPNAMES,
    livevars_analysis,
    propagate_line_nums,
)
from .bytecode_transformation import (
    cleaned_instructions,
    create_call_function,
    create_instruction,
    create_jump_absolute,
    create_swap,
    get_code_keys,
    Instruction,
    is_generator,
    unique_id,
)
from .code_context import code_context
from .codegen import PyCodegen
from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented, Unsupported
from .funcname_cache import get_funcname
from .guards import GuardBuilder, install_guard
from .output_graph import GraphCompileReason, OutputGraph
from .replay_record import DummyModule, ExecutionRecorder
from .resume_execution import ContinueExecutionCache, ReenterWith
from .source import (
    AttrSource,
    GetItemSource,
    GlobalSource,
    GlobalWeakRefSource,
    LocalSource,
    Source,
    TorchFunctionModeStackSource,
)
from .trace_rules import is_builtin_constant, is_forbidden
from .utils import (
    counters,
    get_fake_value,
    get_instruction_source_311,
    get_torch_function_mode_stack,
    graph_break_dup_warning_checker,
    istype,
    LazyString,
    proxy_args_kwargs,
)
from .variables.base import is_side_effect_safe, MutableLocal, typestr, VariableTracker
from .variables.builder import VariableBuilder, wrap_fx_proxy
from .variables.builtin import BuiltinVariable
from .variables.constant import ConstantVariable
from .variables.ctx_manager import (
    ContextWrappingVariable,
    GenericContextWrappingVariable,
    WithExitFunctionVariable,
)
from .variables.dicts import ConstDictVariable, SetVariable
from .variables.functions import (
    BaseUserFunctionVariable,
    NestedUserFunctionVariable,
    SkipFunctionVariable,
    UserFunctionVariable,
    UserMethodVariable,
)
from .variables.iter import MAX_ITERATOR_LIMIT
from .variables.lists import (
    BaseListVariable,
    ListIteratorVariable,
    ListVariable,
    SliceVariable,
    TupleVariable,
)
from .variables.misc import (
    ClosureVariable,
    GetAttrVariable,
    InlinedClosureVariable,
    NullVariable,
    PythonModuleVariable,
    UnknownVariable,
)
from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable
from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable


if TYPE_CHECKING:
    from .variables.torch_function import TorchFunctionModeVariable

from .variables.user_defined import (
    RemovableHandleVariable,
    UserDefinedClassVariable,
    UserDefinedObjectVariable,
)


log = logging.getLogger(__name__)
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
trace_source_log = torch._logging.getArtifactLogger(__name__, "trace_source")
trace_bytecode_log = torch._logging.getArtifactLogger(__name__, "trace_bytecode")
tls = threading.local()
compare_op_handlers: Dict[str, Any] = {
    k: BuiltinVariable(v).call_function for k, v in supported_comparison_ops.items()
}
handle_contains = BuiltinVariable(operator.contains).call_function
handle_not = BuiltinVariable(operator.not_).call_function
compare_op_handlers["in"] = lambda tx, args, _: handle_contains(
    tx, [*reversed(args)], {}
)
compare_op_handlers["not in"] = lambda tx, args, _: handle_not(
    tx, [handle_contains(tx, [*reversed(args)], {})], {}
)
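# Note (sketch): for `x in y` the popped args arrive as [item, container], while
# operator.contains expects (container, item), hence the reversed args above.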


PT2_ISSUE_TRACKER_URL = "https://github.com/pytorch/pytorch/issues/new?&labels=oncall%3A+pt2&projects=&template=pt2-bug-report.yml"


@dataclasses.dataclass
class SpeculationEntry:
    filename: str
    lineno: int
    instruction_pointer: int
    inst: Instruction  # for debugging only
    failed: bool = False
    reason: Optional[GraphCompileReason] = None

    def fail_and_restart_analysis(self):
        """
        Start tracing of the current frame over again, and don't take this branch.
        """
        self.failed = True
        if self.reason is not None:
            restart_reason = self.reason.reason
        else:
            restart_reason = "Unknown fail_and_restart_analysis"
        raise exc.SpeculationRestartAnalysis(restart_reason=restart_reason)


@dataclasses.dataclass
class SpeculationLog:
    """
    SpeculationLog replaces the prior copy_graphstate/restore_graphstate
    checkpointing.  Rather than saving/restoring state, we restart the
    dynamo conversion process over from the beginning -- but when we
    hit the start of the speculation that failed, we instead generate
    a graph break.
    """
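    # Illustrative flow (sketch): the first trace appends entries as it
    # speculates; a failure marks the entry failed=True and raises
    # SpeculationRestartAnalysis, and on the re-trace next() returns the same
    # entry, so the caller emits a graph break instead of speculating again.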

    entries: List[SpeculationEntry] = dataclasses.field(default_factory=list)
    index: int = 0

    def restart(self):
        self.index = 0

    def clear(self):
        self.entries.clear()
        self.index = 0

    def next(
        self, filename: str, lineno: int, instruction_pointer, inst
    ) -> SpeculationEntry:
        """
        Look up or create a SpeculationEntry() that is shared across
        RestartAnalysis calls.  Args are used only for debug checks.
        """
        if len(self.entries) == self.index:
            self.entries.append(
                SpeculationEntry(filename, lineno, instruction_pointer, inst)
            )
        entry = self.entries[self.index]
        prev_entry_msg = ""
        if self.index != 0:
            prev_entry = self.entries[self.index - 1]
            prev_entry_msg = (
                f"Previous instruction: {prev_entry.filename}:{prev_entry.lineno}"
                f"({prev_entry.inst.opname} @ {prev_entry.instruction_pointer})\n"
            )
        assert (
            entry.instruction_pointer == instruction_pointer
            and entry.filename == filename
            and entry.lineno == lineno
        ), f"""
SpeculationLog diverged at index {self.index} (log had {len(self.entries)} entries):
- Expected: {entry.filename}:{entry.lineno} ({entry.inst.opname} at ip={entry.instruction_pointer})
- Actual: {filename}:{lineno} ({inst.opname} at ip={instruction_pointer})
{prev_entry_msg}
There are two usual reasons why this may have occurred:
- When Dynamo analysis restarted, the second run took a different path than
  the first.  If this occurred, the previous instruction is the critical instruction that
  behaved differently.
- Speculation entries are only added under certain conditions (as seen in
  step()), e.g., there must exist operators in the graph; those conditions may
  have changed on restart.

If this divergence was intentional, clear the speculation log before restarting (do NOT
do this for graph breaks, you will infinite loop).

Otherwise, please submit a bug report, ideally including the contents of TORCH_LOGS=+dynamo
"""
        self.index += 1
        return entry


@dataclasses.dataclass
class LocalState:
    input_sizes: Dict[str, List[int]] = dataclasses.field(default_factory=dict)
    input_strides: Dict[str, List[int]] = dataclasses.field(default_factory=dict)


# Mutable box that is shared across restarts
@dataclasses.dataclass
class DistributedState:
    compile_pg: Any
    local_state: LocalState
    all_states: Optional[List[LocalState]] = None


@functools.lru_cache(None)
def _step_logger():
    return torchdynamo_logging.get_step_logger(log)


@dataclasses.dataclass
class BlockStackEntry:
    # Current instruction that pushes something to block_stack
    inst: Instruction
    target: Instruction
    stack_index: Optional[int] = None
    with_context: Optional[
        Union[ContextWrappingVariable, GenericContextWrappingVariable]
    ] = None

    def can_restore(self):
        return self.with_context is not None

    def resume_fn(self):
        assert self.stack_index is not None
        if (
            self.with_context
            and hasattr(self.with_context, "target_values")
            and self.with_context.target_values
        ):
            return ReenterWith(self.stack_index, tuple(self.with_context.target_values))
        else:
            return ReenterWith(self.stack_index)

    def exit(self, tx):
        assert self.with_context is not None
        return self.with_context.exit(tx)


class ReturnValueOp(Exception):
    pass


def stack_op(fn: typing.Callable[..., object]):
    nargs = len(inspect.signature(fn).parameters)
    fn_var = BuiltinVariable(fn)

    @functools.wraps(fn)
    def impl(self: "InstructionTranslator", inst: Instruction):
        self.push(fn_var.call_function(self, self.popn(nargs), {}))

    return impl

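# Typical usage (illustrative; such opcode handlers appear later in this file):
#   UNARY_POSITIVE = stack_op(operator.pos)
# i.e. the handler pops `nargs` values and pushes fn(*args) as a VariableTracker.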

def _detect_and_normalize_assert_statement(
    self: "InstructionTranslatorBase",
    truth_fn: typing.Callable[[object], bool],
    push: bool,
):
    # Detect whether this jump instruction is part of an assert statement, and
    # normalize the assert by pushing a dummy error message when none is given.
    #
    # A Python 3.9 assertion has the following format:
    # 18 POP_JUMP_IF_TRUE       28
    # 20 LOAD_ASSERTION_ERROR
    # 22 LOAD_CONST               3 ('Assert message') -> optional instruction
    # 24 CALL_FUNCTION            1                    -> optional instruction
    # 26 RAISE_VARARGS
    #
    # A Python 3.8 assertion has the following format:
    # 18 POP_JUMP_IF_TRUE       28
    # 20 LOAD_GLOBAL              0 (Assertion type)
    # 22 LOAD_CONST               3 ('Assert message') -> optional instruction
    # 24 CALL_FUNCTION            1                    -> optional instruction
    # 26 RAISE_VARARGS            1
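    #
    # On Python 3.11+, CALL_FUNCTION is replaced by PRECALL + CALL (3.11) or
    # just CALL (3.12+); the checks below accept any of these spellings.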

    if (truth_fn is not operator.truth) or push:
        return False

    assert isinstance(self.instruction_pointer, int)
    current_instruction_pointer = self.instruction_pointer
    inst = self.instructions[current_instruction_pointer]
    # Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0
    if sys.version_info < (3, 9):
        if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError":
            return False
    else:
        if inst.opname != "LOAD_ASSERTION_ERROR":
            return False

    current_instruction_pointer += 1

    # Use a dummy error message if it's hard to extract
    error_msg = "assertion error"

    inst = self.instructions[current_instruction_pointer]
    # Detect RAISE_VARARGS or LOAD_CONST
    if inst.opname == "LOAD_CONST":
        if not isinstance(inst.argval, str):
            return False
        error_msg = inst.argval

        # if it is LOAD_CONST, it must be followed by CALL_FUNCTION
        # (PRECALL for Python 3.11, CALL for Python 3.12+)
        current_instruction_pointer += 1
        inst = self.instructions[current_instruction_pointer]
        if inst.opname not in ("CALL_FUNCTION", "PRECALL", "CALL"):
            return False

        # for Python 3.11, PRECALL should be followed by CALL, then RAISE_VARARGS
        # for Python != 3.11, CALL_FUNCTION/CALL should be followed by RAISE_VARARGS
        current_instruction_pointer += 1
        if inst.opname == "PRECALL":
            current_instruction_pointer += 1
        inst = self.instructions[current_instruction_pointer]

    if inst.opname != "RAISE_VARARGS":
        return False

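    # Leave the error message on the stack; generic_jump() pops it once this
    # helper returns True.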
    self.push(ConstantVariable.create(error_msg))

    return True


def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
    def jump_graph_break(self, inst, value, extra_msg=""):
        if not self.should_compile_partial_graph():
            unimplemented("should_compile_partial_graph=False")
        # compile a partial subgraph prefix then jump into user code
        if self.maybe_has_backedge():
            msg = (
                "Skipping frame because there is a graph break in a for/while loop\n"
                f"{self.frame_summary()}"
            )
            log.info(msg)
            raise exc.SkipFrame(msg)

        self.push(value)
        log.debug("generic_jump triggered compile")
        self.output.compile_subgraph(
            self,
            reason=GraphCompileReason(
                f"generic_jump {typestr(value)}{extra_msg}", [self.frame_summary()]
            ),
        )
        self.pop()

        if_next = self.create_call_resume_at(self.next_instruction)
        if push:
            self.push(value)
        if_jump = self.create_call_resume_at(inst.target)

        if sys.version_info >= (3, 13):
            # 3.13 requires stack[-1] to be bool type
            self.output.add_output_instructions([create_instruction("TO_BOOL")])

        self.output.add_output_instructions(
            [create_instruction(inst.opname, target=if_jump[0])] + if_next + if_jump
        )

    def inner(self: "InstructionTranslatorBase", inst: Instruction):
        value: VariableTracker = self.pop()
        if (
            config.rewrite_assert_with_torch_assert
            and _detect_and_normalize_assert_statement(self, truth_fn, push)
        ):
            error_msg: VariableTracker = self.pop()
            # Skip over things like `assert True`
            if value.is_python_constant():
                if bool(value.as_python_constant()):
                    return self.jump(inst)
                else:
                    jump_graph_break(self, inst, value)

            # TODO maybe should respect DtoH sync intention of users later??
            # Manually insert torch._assert_async instead of a Python assert and
            # jump over the assert-related instructions, as we no longer need them.

            # if the value is already a Tensor, there is no need to call scalar_tensor
            if isinstance(value, TensorVariable):
                self.output.create_proxy(
                    "call_function",
                    torch._assert_async,
                    *proxy_args_kwargs((value, error_msg), {}),
                )
                self.jump(inst)
                return

            if isinstance(value, SymNodeVariable):
                # if the assertion is a plain shape expression,
                # just install a guard and bail out.
                sym_expr = value.sym_num
                if not isinstance(sym_expr, torch.SymBool):
                    sym_expr = sym_expr != 0

                result = torch.fx.experimental.symbolic_shapes.expect_true(sym_expr)
                if not result:
                    unimplemented(
                        "Assertion failed on symbolic shapes. Did you make sure eager mode succeeds?"
                    )
                self.jump(inst)
                return

            scalar_to_tensor_proxy = self.output.create_proxy(
                "call_function", torch.scalar_tensor, *proxy_args_kwargs((value,), {})
            )

            scalar_to_tensor = wrap_fx_proxy(
                self,
                scalar_to_tensor_proxy,
                example_value=get_fake_value(scalar_to_tensor_proxy.node, self),
            )

            self.output.create_proxy(
                "call_function",
                torch._assert_async,
                *proxy_args_kwargs((scalar_to_tensor, error_msg), {}),
            )
            self.jump(inst)
            return

        if value.is_python_constant():
            if truth_fn(value.as_python_constant()):
                if push:
                    self.push(value)
                self.jump(inst)
        elif isinstance(value, TensorVariable) and self.should_compile_partial_graph():
            jump_graph_break(self, inst, value)
        elif isinstance(value, NNModuleVariable):
            # Equivalent of "self.nn_module is not None"
            mod = self.output.get_submodule(value.module_key)
            if truth_fn(mod):
                if push:
                    self.push(value)
                self.jump(inst)
        elif isinstance(value, UnspecializedNNModuleVariable):
            mod = value.value
            if truth_fn(mod):
                if push:
                    self.push(value)
                self.jump(inst)
        elif isinstance(value, UserDefinedObjectVariable):
            try:
                x = value.var_getattr(self, "__bool__")  # type: ignore[arg-type]
            except exc.ObservedAttributeError:
                exc.handle_observed_exception(self)
                # if __bool__ is missing, try __len__ to infer a truth value.
                try:
                    x = value.var_getattr(self, "__len__")  # type: ignore[arg-type]
                except exc.ObservedAttributeError:
                    exc.handle_observed_exception(self)
                    x = None

            # __bool__ or __len__ is a function
            if isinstance(x, UserMethodVariable):
                result = x.call_function(self, [], {})  # type: ignore[arg-type]
                if isinstance(result, ConstantVariable) and isinstance(
                    result.value, (bool, int)
                ):
                    if truth_fn(result.value):
                        if push:
                            self.push(value)
                        self.jump(inst)
                else:
                    unimplemented(
                        "generic_jump on UserDefined with __bool__ returning non-constant"
                    )
            # __bool__ or __len__ is not a function, or does not exist on the
            # user-defined object
            else:
                if truth_fn(True):
                    if push:
                        self.push(value)
                    self.jump(inst)
        elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence(
            self
        ):
            if truth_fn(len(value.unpack_var_sequence(self))):
                if push:
                    self.push(value)
                self.jump(inst)
        elif isinstance(value, SymNodeVariable):
            try:
                eval_result = value.evaluate_expr(self.output)
            except exc.UserError as e:
                if self.should_compile_partial_graph():
                    return jump_graph_break(self, inst, value, extra_msg=f"\n{e}")
                raise
            if truth_fn(eval_result):
                if push:
                    self.push(value)
                self.jump(inst)
        elif isinstance(value, variables.BackwardHookVariable):
            if truth_fn(True):
                if push:
                    self.push(value)
                self.jump(inst)
        else:
            from .source import is_constant_source

            if value.source is not None and is_constant_source(value.source):
                if truth_fn(value.get_real_value()):  # type: ignore[attr-defined]
                    if push:
                        self.push(value)
                    self.jump(inst)
            else:
                # TODO link the torch.cond doc later
                raise exc.UserError(
                    exc.UserErrorType.DYNAMIC_CONTROL_FLOW,
                    "Dynamic control flow is not supported at the moment. Please use "
                    "functorch.experimental.control_flow.cond to explicitly capture the control flow.",
                    case_name="cond_operands",
                )

    return inner


explain = False


def break_graph_if_unsupported(*, push):
    def decorator(inner_fn):
        @functools.wraps(inner_fn)
        def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
            speculation = self.speculate()
            if speculation.failed:
                assert speculation.reason is not None
                return handle_graph_break(self, inst, speculation.reason)
            try:
                return inner_fn(self, inst)
            except Unsupported as excp:
                if self.generic_context_manager_depth > 0:
                    # We don't support graph breaks under a
                    # GenericContextWrappingVariable; if one occurs, we roll back
                    # to the checkpoint and fall back.
                    excp.remove_from_stats()
                    unimplemented("Graph break under GenericContextWrappingVariable")

                if isinstance(excp, exc.UncapturedHigherOrderOpError):
                    raise

                if not self.should_compile_partial_graph():
                    raise

                user_stack = excp.real_stack
                # TODO: Also report the traceback from the parent frame
                try:
                    frame_loc = (user_stack[-1].filename, user_stack[-1].lineno)
                except IndexError:
                    # first instruction
                    code_options = self.code_options
                    frame_loc = (
                        code_options["co_filename"],
                        code_options["co_firstlineno"],
                    )
                # torch._dynamo.explain() formats this a little nicer, and presents a slightly
                # more actionable user code pointer
                if (
                    graph_break_log.isEnabledFor(logging.DEBUG)
                    and not explain
                    and graph_break_dup_warning_checker.add(frame_loc)
                ):
                    user_stack_formatted = "".join(traceback.format_list(user_stack))
                    # This log line is exercised from
                    #   python test/dynamo/test_exc.py -k test_graph_break_log
                    graph_break_log.debug(
                        "Graph break: from user code at:\n%s",
                        user_stack_formatted,
                        exc_info=True,
                    )
                else:
                    # This log line MUST NOT contain the string "Graph break",
                    # exercised by
                    #   python test/dynamo/test_misc.py -k test_duplicate_graph_break_log
                    log.debug(
                        "Unsupported break in user code at %s:%s (details suppressed)",
                        *frame_loc,
                    )

                if self.maybe_has_backedge():
                    msg = (
                        "Skipping frame because there is a graph break in a for/while loop\n"
                        f"{self.frame_summary()}"
                    )
                    log.info(msg)
                    raise exc.SkipFrame(msg) from excp

                excp.remove_from_stats()
                excp.add_to_stats("graph_break")
                speculation.reason = GraphCompileReason(excp.msg, user_stack)
            speculation.fail_and_restart_analysis()

        def handle_graph_break(
            self: "InstructionTranslatorBase",
            inst: Instruction,
            reason: GraphCompileReason,
        ):
            self.output.compile_subgraph(self, reason=reason)
            cg = PyCodegen(self)
            cleanup: List[Instruction] = []
            # Reconstruct the context variable CLASS in the block stack
            for b in self.block_stack:
                assert b.with_context is not None
                assert isinstance(b.with_context, ContextWrappingVariable)
                b.with_context.reconstruct_type(cg)
                cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup))
            self.output.add_output_instructions(cg.get_instructions())
            del cg

            if sys.version_info >= (3, 11) and inst.opname == "CALL":
                kw_names = (
                    self.kw_names.as_python_constant()
                    if self.kw_names is not None
                    else ()
                )
                if len(kw_names) > 0:
                    # KW_NAMES no longer used in 3.13
                    assert sys.version_info < (3, 13)
                    self.output.add_output_instructions(
                        [create_instruction("KW_NAMES", argval=kw_names)]
                    )
                self.output.add_output_instructions(
                    create_call_function(inst.arg, False)
                )
            else:
                # copy instruction, but without exception table data
                assert inst.target is None
                inst_copy = copy.copy(inst)
                inst_copy.exn_tab_entry = None
                self.output.add_output_instructions([inst_copy])

            self.output.add_output_instructions(cleanup)

            if (
                sys.version_info >= (3, 11)
                and sys.version_info < (3, 12)
                and inst.opname == "CALL"
            ):
                # stack effect for PRECALL + CALL is split between the two instructions
                stack_effect = dis.stack_effect(
                    dis.opmap["PRECALL"], inst.arg
                ) + dis.stack_effect(dis.opmap["CALL"], inst.arg)
            else:
                stack_effect = dis.stack_effect(inst.opcode, inst.arg)
            self.popn(push - stack_effect)

            for _ in range(push):
                self.push(UnknownVariable())
            self.output.add_output_instructions(
                self.create_call_resume_at(self.next_instruction)
            )

        return wrapper

    return decorator

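# Illustrative usage (sketch; such handlers appear later in this file):
#   @break_graph_if_unsupported(push=1)
#   def CALL_FUNCTION(self, inst): ...
# i.e. on Unsupported, compile the traced prefix, emit the original call
# instruction, then resume tracing in a continuation function.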

class BytecodeDistpatchTableMeta(type):
    """Installs a `cls.dispatch_table` on every subclass to speed up calls to self.OPCODE()"""

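    # Lookup sketch: step() dispatches via
    #   self.dispatch_table[inst.opcode](self, inst)
    # and unknown opcodes fall through to _missing(), which calls unimplemented().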
    def __init__(cls, name, bases, dct) -> None:
        super().__init__(name, bases, dct)

        def _missing(opname, *args):
            unimplemented(f"missing: {opname}")

        dispatch_table = {
            op: getattr(cls, opname, functools.partial(_missing, opname))
            for opname, op in dis.opmap.items()
        }
        cls.dispatch_table = [dispatch_table.get(i) for i in range(2**8)]


class InstructionTranslatorBase(
    metaclass=BytecodeDistpatchTableMeta,
):
    output: OutputGraph
    symbolic_locals: Dict[str, VariableTracker]
    symbolic_globals: Dict[str, VariableTracker]
    symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"]
    stack: List[VariableTracker]
    instruction_pointer: Optional[int]
    current_instruction: Instruction
    block_stack: List[BlockStackEntry]
    lineno: int
    kw_names: Optional[ConstantVariable]
    accept_prefix_inst: bool
    prefix_insts: List[Instruction]
    inline_depth: int
    inconsistent_side_effects: bool
    current_speculation: Optional[SpeculationEntry]
    dispatch_table: List[Any]
    exn_vt_stack: List[VariableTracker]
    exec_recorder: Optional[ExecutionRecorder]
    strict_checks_fn: Optional[Callable[[VariableTracker], bool]]

    def mark_inconsistent_side_effects(self):
        """
        InstructionTranslator has encountered instructions which may cause
        dynamo to see a different version of history from eager
        See: https://github.com/pytorch/pytorch/issues/110765
        """
        self.inconsistent_side_effects = True

    def maybe_has_backedge(self):
        # This function employs a heuristic. It does not reliably detect a backedge.
        # The heuristic is straightforward: starting from the current instruction and
        # continuing to the end, if any jump instruction targets an instruction before
        # the current one, there might be a backedge.

        # Python 3.12 introduced changes to bytecode that group common paths in
        # block stacks (with or try...else) and allow for early returns. Consequently,
        # there can be multiple RETURN_VALUE instructions. Another heuristic is to
        # halt detection upon encountering the first RETURN_VALUE or RETURN_CONST.

        # These heuristics can result in both false positives and negatives, but
        # in either case, the Dynamo code remains valid. For false positives
        # (where an edge is incorrectly marked as a backedge), Dynamo will
        # perform a SkipFrame instead of potentially applying optimizations. For
        # false negatives (where an edge that should be marked as a backedge
        # isn't), multiple graphs may be generated if there's a break in the
        # graph during a for loop. In general, it's better to have fewer false
        # negatives so that Dynamo does not skip the whole frame.

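        # Example (hypothetical bytecode): in a `while` loop, the body typically
        # ends with a jump such as JUMP_BACKWARD whose argval targets an offset
        # before the loop test; the scan below reports that as a backedge.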
        cur_offset = self.current_instruction.offset
        assert self.instruction_pointer is not None
        for inst in self.instructions[self.instruction_pointer :]:
            if inst.opname in ("RETURN_VALUE", "RETURN_CONST"):
                return False
            if inst.opname in JUMP_OPNAMES:
                jump_offset = inst.argval
                if jump_offset < cur_offset:
                    return True
        return False

    def cell_and_freevars(self):
        if not hasattr(self, "_cell_and_freevars"):
            self._cell_and_freevars = tuple(
                self.code_options["co_cellvars"] or []
            ) + tuple(self.code_options["co_freevars"] or [])

            # An inlined function might depend on the freevar of the parent
            # function. So, recursively obtain parent cell and freevars.
            if isinstance(self, InliningInstructionTranslator):
                self._cell_and_freevars += self.parent.cell_and_freevars()
        return self._cell_and_freevars

    def prune_dead_locals(self):
        reads = livevars_analysis(self.instructions, self.current_instruction)
        # implicit use by super()
        # reads = reads | {"__class__"}
        # output variables?
        reads = reads | set(self.cell_and_freevars())
        self.symbolic_locals = {
            k: v for k, v in self.symbolic_locals.items() if k in reads
        }
        self.output.side_effects.prune_dead_object_new(self)

    def call_function(
        self,
        fn: VariableTracker,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ):
        assert isinstance(fn, VariableTracker)
        assert isinstance(args, list)
        assert isinstance(kwargs, dict)
        assert all(
            isinstance(x, VariableTracker)
            for x in itertools.chain(args, kwargs.values())
        )
        inner_fn = None
        if hasattr(fn, "value"):
            inner_fn = fn.value
        if hasattr(fn, "fn"):
            inner_fn = fn.fn
        if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
            raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
        self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]

    def inline_user_function_return(self, fn, args, kwargs):
        """
        Call a user-defined function by inlining it.
        """
        return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)

    def get_line_of_code_header(self, lineno=None):
        if lineno is None:
            lineno = self.lineno
        inline_depth_str = (
            f" (inline depth: {self.inline_depth})" if self.inline_depth > 0 else ""
        )
        funcname = get_funcname(self.f_code.co_filename, lineno)
        funcname_str = "" if funcname is None else f" ({funcname})"
        return f"{self.f_code.co_filename}:{lineno} in {self.f_code.co_name}{funcname_str}{inline_depth_str}"

    def get_log_starts_line_log_str(self):
        log_str = f"TRACE starts_line {self.get_line_of_code_header()}\n"
        line = linecache.getline(self.f_code.co_filename, self.lineno).rstrip()
        log_str += f"    {line}"
        return log_str

    def starts_line(self, lineno):
        if self.lineno == lineno:
            return
        self.lineno = lineno
        TracingContext.set_current_loc(
            self.f_code.co_filename, lineno, self.f_code.co_name
        )
        from torch._logging.structured import dump_file

        dump_file(self.f_code.co_filename)
        if trace_source_log.isEnabledFor(logging.DEBUG):
            trace_source_log.debug("%s", LazyString(self.get_log_starts_line_log_str))

    def step(self):
        """Process exactly one instruction; return False if we should exit."""
        ip = self.instruction_pointer
        if ip is None:
            return False
        self.current_instruction = inst = self.instructions[ip]
        self.instruction_pointer = ip + 1

        if inst.starts_line:
            self.starts_line(inst.starts_line)

        if (
            not self.stack
            and self.should_compile_partial_graph()
            and self.is_non_empty_graph()
        ):
            self.current_speculation = self.speculate()
            if self.current_speculation.failed:
                return self.step_graph_break(inst)

        if trace_bytecode_log.isEnabledFor(logging.DEBUG):
            trace_bytecode_log.debug(
                "TRACE %s %s %s", inst.opname, inst.argval, self.stack
            )

        self.update_block_stack(inst)

        try:
            self.dispatch_table[inst.opcode](self, inst)
            return not self.output.should_exit
        except exc.ObservedException as e:
            self.exception_handler(e)
            return True
        except ReturnValueOp:
            return False
        except Unsupported:
            if self.current_speculation is None:
                log.debug("empty checkpoint")
                raise
            log.debug("step triggered compile", exc_info=True)

        self.current_speculation.fail_and_restart_analysis()

    if sys.version_info >= (3, 11):

        def update_block_stack(self, inst):
            # 3.11+ no longer uses a block stack, but we still keep track of one
            # so that we know which contexts are currently active.
            # For our purposes, all exception table entries with the same target
            # are considered to be part of the same "block".
            # NOTE: we only keep track of with blocks that are not contained in try blocks.
            # This is because we will not create continuation functions on graph breaks in try blocks,
            # but we may for with blocks. We do not push blocks here since
            # with blocks are pushed when handling BEFORE_WITH.
            entry = inst.exn_tab_entry
            if entry:
                # Detect when we have exited the top with block.
                # The with blocks on the block stack are not enclosed in try
                # blocks, so a with block's cleanup code should be in the
                # previous with block (if any).
                if (
                    len(self.block_stack) >= 2
                    and entry.target is not self.block_stack[-1].target
                    and entry.target is self.block_stack[-2].target
                ):
                    # exit the current block
                    self.block_stack.pop()
            else:
                # no longer in any block
                # It is possible for NOPs to be between two instructions
                # in the same block, but the NOPs are not covered by an
                # exception table entry. In this case, assume that we
                # are still in the same block.
                # In 3.12+, JUMP_BACKWARD might also not be covered by
                # an exception table entry, so we also assume that we
                # are still in the same block. It is probably safe to do
                # this in 3.11, even though we haven't encountered this case before.
                if self.block_stack and inst.opname not in ("NOP", "JUMP_BACKWARD"):
                    # If we really escape from a block and the current
                    # instruction is not in another block, then there
                    # should be no other nested blocks that we are in.
                    assert len(self.block_stack) == 1
                    self.block_stack.pop()

    else:

        def update_block_stack(self, inst):
            pass

    @property
    def next_instruction(self):
        return self.instructions[self.instruction_pointer]  # type: ignore[index]

    def step_graph_break(self, continue_inst):
        # generate code from checkpoint
        assert not self.output.output_instructions
        assert self.current_speculation is not None
        self.output.compile_subgraph(
            self,
            partial_convert=True,
            reason=GraphCompileReason("step_unsupported", [self.frame_summary()]),
        )
        self.output.add_output_instructions(
            [create_jump_absolute(continue_inst)] + self.instructions
        )

    def run_ctx_mgr(self):
        # NB: Don't push the top level frame summary; set_current_loc will
        # take care of it.  However, DO make sure we attach real_stack to
        # exceptions
        return TracingContext.current_frame(None)

    def run(self):
        with self.run_ctx_mgr():
            try:
                self.output.push_tx(self)
                while self.step():
                    pass
            except BackendCompilerFailed:
                raise
            except Exception as e:
                if self.exec_recorder:
                    e.exec_record = self.exec_recorder.get_record()  # type: ignore[attr-defined]
                raise
            finally:
                self.output.pop_tx()
                # Clean up the OutputGraph to delete the held tensors. We perform
                # the cleanup only for InstructionTranslator and not
                # InliningInstructionTranslator. The InliningInstructionTranslator
                # mutates the output object and is restored to its original state
                # if there was an exception.
                if isinstance(self, InstructionTranslator):
                    self.output.cleanup()

    def push(self, val: Optional[VariableTracker], name: Any = None):
        assert val is None or isinstance(
            val, VariableTracker
        ), f"push expects VariableTracker, got {typestr(val)}"
        self.stack.append(val)  # type: ignore[arg-type]
        if sys.version_info >= (3, 13):
            self.name_stack.append(name)
            assert len(self.stack) == len(self.name_stack)

    def push_many(self, vals: List[VariableTracker]):
        for val in vals:
            self.push(val)

    def pop(self) -> VariableTracker:
        if sys.version_info >= (3, 13):
            assert len(self.stack) == len(self.name_stack)
            self.name_stack.pop()
        return self.stack.pop()

    def popn(self, n: int) -> List[VariableTracker]:
        # Pop n values, returned in stack order, e.g. stack [a, b, c] -> popn(2) == [b, c]
        return [*reversed([self.pop() for _ in range(n)])]

    def _load_closure(self, name):
        return ClosureVariable(name=name)

    def _load_fast(self, name):
        if self.exec_recorder and name in self.f_locals:
            self.exec_recorder.add_local_var(name, self.f_locals[name])

        try:
            self.push(self.symbolic_locals[name].unwrap(), name=name)
        except KeyError:
            if sys.version_info >= (3, 13) and name in self.cell_and_freevars():
                # 3.13 merged LOAD_CLOSURE into LOAD_FAST
                # If we fail to LOAD_FAST, then we probably should have done LOAD_CLOSURE.
                # Closure variable creation is actually done in SET_FUNCTION_ATTRIBUTE,
                # but we'll do it again here so that we don't need to push a dummy variable.
                # We shouldn't actually be doing anything with this variable anyway.
                self.push(self._load_closure(name), name=name)
            elif name.startswith("."):
                try:
                    # This happens in dict/list comprehensions
                    new_name = name.replace(".", "implicit")
                    self.push(self.symbolic_locals[new_name], name=new_name)
                except KeyError:
                    unimplemented("undefined LOAD_FAST (implicit)")
            else:
                unimplemented("undefined LOAD_FAST")

        # for continuation functions
        if name.startswith("___stack"):
            self.symbolic_locals.pop(name)

    def LOAD_FAST(self, inst):
        self._load_fast(inst.argval)

    def LOAD_DEREF(self, inst):
        assert inst.argval in self.cell_and_freevars()

        if self.exec_recorder and inst.argval in self.f_locals:
            self.exec_recorder.add_local_var(inst.argval, self.f_locals[inst.argval])

        if inst.argval not in self.symbolic_locals:
            unimplemented(f"undefined LOAD_DEREF {inst.argval}")
        self.push(self.symbolic_locals[inst.argval])

    def _store_fast(self, name):
        loaded_vt = self.pop()
        loaded_vt.set_name_hint(name)
        self.symbolic_locals[name] = loaded_vt

    def STORE_FAST(self, inst):
        self._store_fast(inst.argval)

    def DELETE_FAST(self, inst):
        del self.symbolic_locals[inst.argval]

    STORE_DEREF = STORE_FAST

    def LOAD_CLOSURE(self, inst):
        self.push(self._load_closure(inst.argval))

    def _load_const(self, inst):
        i = inst.arg
        if i is None:
            return ConstantVariable.create(value=inst.argval)
        val = self._constants_cache[i]
        if not val:
            self._constants_cache[i] = val = ConstantVariable.create(value=inst.argval)
        return val

    def LOAD_CONST(self, inst):
        self.push(self._load_const(inst))

    def _load_global(self, inst):
        name = inst.argval

        if self.exec_recorder:
            if name in self.f_globals:
                self.exec_recorder.add_global_var(name, self.f_globals[name])
            else:
                assert name in self.f_builtins
                self.exec_recorder.builtins[name] = self.f_builtins[name]

        if name in self.symbolic_globals:
            variable = self.output.side_effects[self.symbolic_globals[name]]
            self.push(self.output.side_effects.load_global(variable, name))
            return

        try:
            value = self.f_globals[name]
        except KeyError:
            return self.load_builtin(inst)

        source = GlobalSource(name)
        self.push(VariableBuilder(self, source)(value))

    @functools.cached_property
    def nn_modules_globals_vt(self):
        module_name = "torch.nn.modules.module"
        module_source = self.import_source(module_name)
        fglobals_value = importlib.import_module(module_name)  # type: ignore[assignment]
        return VariableBuilder(self, module_source)(fglobals_value)

    def LOAD_GLOBAL(self, inst):
        # In 3.11/3.12, an odd arg means a NULL should be pushed before the
        # global (it is being loaded for a call); in 3.13 the NULL is pushed
        # after the global instead.
        if sys.version_info >= (3, 11) and sys.version_info < (3, 13) and inst.arg % 2:
            self.PUSH_NULL(inst)
        self._load_global(inst)
        if sys.version_info >= (3, 13) and inst.arg % 2:
            self.PUSH_NULL(inst)

    def STORE_GLOBAL(self, inst):
        value = self.pop()
        name = inst.argval
        source = GlobalSource(name)
        if name not in self.symbolic_globals:
            self.symbolic_globals[name] = object()  # type: ignore[assignment]  # sentinel object
        variable = self.output.side_effects.track_global_existing(
            source, self.symbolic_globals[name]
        )
        if isinstance(value, RemovableHandleVariable):
            unimplemented("Storing handles in globals - NYI")
        self.output.side_effects.store_global(variable, name, value)

    def import_source(self, module_name):
        """Create an alias to a module for use in guards"""
        if "torch_package" in module_name:
            value = torch.package.package_importer._package_imported_modules[
                module_name
            ]
            alias = (
                module_name.replace(">", "_").replace("<", "_").replace(".", "_dot_")
            )
        else:
            value = importlib.import_module(module_name)
            alias = f"__import_{module_name.replace('.', '_dot_')}"
        f_globals = self.output.global_scope
        assert alias not in f_globals or f_globals[alias] is value
        f_globals[alias] = value
        self.output.update_co_names(alias)
        return GlobalSource(alias)

    def resolve_name(self, name, package, level):
        """
        Copied from the CPython implementation of __import__
        Resolve a relative module name to an absolute one.
        https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L902
        """
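        # e.g. (sketch): resolve_name("mod", "pkg.sub", 1) == "pkg.sub.mod"
        #                resolve_name("mod", "pkg.sub", 2) == "pkg.mod"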
        bits = package.rsplit(".", level - 1)
        if len(bits) < level:
            raise ImportError("attempted relative import beyond top-level package")
        base = bits[0]
        return f"{base}.{name}" if name else base

    def calc_package(self):
        """
        Copied from the CPython implementation of __import__
        https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L1090
        """
        package = self.f_globals.get("__package__")
        spec = self.f_globals.get("__spec__")
        if package is not None:
            if spec is not None and package != spec.parent:
                log.warning(
                    "__package__ != __spec__.parent (%r != %r)",
                    package,
                    spec.parent,
                    stacklevel=3,
                )
            return package
        elif spec is not None:
            return spec.parent
        else:
            log.warning(
                "can't resolve package from __spec__ or __package__, "
                "falling back on __name__ and __path__",
                stacklevel=3,
            )
            package = self.f_globals["__name__"]
            if "__path__" not in self.f_globals:
                package = package.rpartition(".")[0]
        return package

    def IMPORT_NAME(self, inst):
        level, fromlist = self.popn(2)
        level = level.as_python_constant()
        fromlist = fromlist.as_python_constant()
        module_name = inst.argval

        # Are we replaying? If so, load the recorded module
        recorded_name = (
            f"{ExecutionRecorder.LOCAL_MOD_PREFIX}_{level}_{fromlist}_{module_name}"
        )
        if recorded_name in self.f_globals:
            value = self.f_globals[recorded_name]
            source = GlobalSource(recorded_name)
        else:
            try:
                value = __import__(
                    module_name,
                    fromlist=fromlist,
                    level=level,
                    globals=self.f_globals,
                )
            except ImportError:
                unimplemented("import a module that does not exist")

            if level != 0:
                pkg = self.calc_package()
                module_name = self.resolve_name(module_name, pkg, level)

            # For __import__, when the name variable is of the form package.module,
            # normally, the top-level package (the name up till the first dot) is
            # returned, not the module named by module_name. However, when a
            # non-empty fromlist argument is given, the module named by name is
            # returned. Therefore, we set the source correctly here.
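            # e.g. (sketch): `import a.b` returns package `a`, so the source is
            # import_source("a"); `from a.b import c` returns module `a.b` itself.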
            if not fromlist:
                top_level_module_name = module_name.partition(".")[0]
                source = self.import_source(top_level_module_name)
            else:
                source = self.import_source(module_name)

        if self.exec_recorder:
            self.exec_recorder.add_local_mod(recorded_name, value)

        if istype(value, (types.ModuleType, DummyModule)):
            self.push(PythonModuleVariable(value, source=source))
        else:
            unimplemented(f"IMPORT_NAME {typestr(value)}")

    def IMPORT_FROM(self, inst):
        self.DUP_TOP(inst)
        self._load_attr(inst)

    def load_builtin_from_argval(self, argval):
        if argval not in self.f_builtins:
            raise NameError(f"name '{argval}' is not defined")
        val = self.f_builtins[argval]

        if callable(val):
            builtins_source = GlobalSource(
                self.output.name_of_builtins_dict_key_in_fglobals
            )
            var_source = GetItemSource(builtins_source, argval)
            self.push(VariableBuilder(self, var_source)(val))
        else:
            assert is_builtin_constant(val)
            self.push(ConstantVariable.create(value=val))

    def load_builtin(self, inst):
        self.load_builtin_from_argval(inst.argval)

    def jump(self, inst):
        self.instruction_pointer = self.indexof[inst.target]

    JUMP_FORWARD = jump
    JUMP_ABSOLUTE = jump

    POP_JUMP_IF_FALSE = generic_jump(operator.not_, False)
    POP_JUMP_IF_TRUE = generic_jump(operator.truth, False)
    JUMP_IF_FALSE_OR_POP = generic_jump(operator.not_, True)
    JUMP_IF_TRUE_OR_POP = generic_jump(operator.truth, True)

    def SETUP_LOOP(self, inst):
        # only exists in python<=3.7
        self.block_stack.append(BlockStackEntry(inst, inst.target))

    def SETUP_EXCEPT(self, inst):
        # only exists in python<=3.7
        self.block_stack.append(BlockStackEntry(inst, inst.target))

    def POP_BLOCK(self, inst):
        self.block_stack.pop()

    def SETUP_WITH(self, inst):
        self.setup_or_before_with(inst)

    def SETUP_FINALLY(self, inst):
        self.block_stack.append(BlockStackEntry(inst, inst.target))

    def BEGIN_FINALLY(self, inst):
        self.push(None)

    def WITH_CLEANUP_START(self, inst):
        exit, exc = self.popn(2)
        assert exc is None
        self.push(exc)
        self.push(exit.call_function(self, [ConstantVariable.create(None)] * 3, {}))

    def WITH_CLEANUP_FINISH(self, inst):
        self.popn(2)
        self.push(None)

    def CALL_FINALLY(self, inst):
        """
        pushes the address of the next instruction onto the stack and increments
        bytecode counter by delta
        """
        # Python 3.8 only
        addr = self.indexof[self.next_instruction]
        self.push(ConstantVariable.create(addr))
        self.jump(inst)

    def END_FINALLY(self, inst):
        # Python 3.8 only
        # https://docs.python.org/3.8/library/dis.html#opcode-END_FINALLY
        tos = self.pop()
        if isinstance(tos, ConstantVariable):
            self.instruction_pointer = tos.as_python_constant()
        else:
            pass

    def POP_FINALLY(self, inst):
        # Python 3.8 only
        preserve_tos = inst.argval
        if preserve_tos:
            tos = self.pop()
        _ = self.pop()
        if preserve_tos:
            self.push(tos)  # type: ignore[possibly-undefined]

    def FOR_ITER(self, inst):
        it = self.pop().realize()
        try:
            val = it.next_variable(self)
            self.push(it)
            self.push(val)
        except (StopIteration, exc.ObservedUserStopIteration) as e:
            if isinstance(e, exc.ObservedUserStopIteration):
                exc.handle_observed_exception(self)

            # leave iterator upon exhaustion in 3.12
            if sys.version_info >= (3, 12):
                # CPython 3.12 actually jumps to the instruction after the END_FOR
                # and performs the action of END_FOR as part of FOR_ITER. We jump
                # to the END_FOR and run it, so we need to make sure 2 values are
                # on the stack for it to pop.
                self.push(it)
                self.push(ConstantVariable.create(None))
            self.jump(inst)
    def _raise_exception_variable(self, inst):
        val = self.pop()
        # A user can raise an exception in 2 ways
        #   1) raise exception type - raise NotImplementedError
        #   2) raise exception instance - raise NotImplementedError("foo")

        # 1) when user raises exception type
        if isinstance(val, variables.BuiltinVariable):
            # Create the instance of the exception type
            # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549
            val = val.call_function(self, [], {})  # type: ignore[arg-type]

        # Save the exception in a global data structure
        self.exn_vt_stack.append(val)

        # 2) when user raises exception instance
        if isinstance(val, variables.ExceptionVariable):
            if observed_exception_type := exc.observed_exception_map.get(val.exc_type):
                raise observed_exception_type(f"raised exception {val}")
            raise exc.ObservedException(f"raised exception {val}")
1382        unimplemented(f"raise {val}")
1383
1384    def RAISE_VARARGS(self, inst):
1385        if inst.arg == 0:
1386            unimplemented("re-raise")
1387        elif inst.arg == 1:
1388            self._raise_exception_variable(inst)
1389        else:
1390            # Support `raise ... from None`. Dynamo does not track __cause__ and
1391            # other exception attributes, so we ignore the `from None` part.
1392            from_vt = self.pop()
1393            if isinstance(from_vt, ConstantVariable) and from_vt.value is None:
1394                self._raise_exception_variable(inst)
1395            unimplemented("raise ... from ...")
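        # Per the dis docs, inst.arg encodes the raise form, e.g.:
        #     raise                      -> RAISE_VARARGS 0 (re-raise)
        #     raise ValueError("foo")    -> RAISE_VARARGS 1
        #     raise ValueError() from e  -> RAISE_VARARGS 2 (TOS is the cause)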
1396
1397    def RERAISE(self, inst):
1398        if sys.version_info >= (3, 11):
1399            # RERAISE is currently supported in a narrow case of `raise ... from None`
1400            self._raise_exception_variable(inst)
1401        unimplemented("RERAISE")
1402
1403    def exception_handler(self, raised_exception):
1404        if sys.version_info >= (3, 11):
1405            exn_tab_entry = self.current_instruction.exn_tab_entry
1406            if exn_tab_entry:
1407                # Implementation is based on https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt
1408
1409                # 1) pop values from the stack until it matches the stack depth
1410                # for the handler
1411                while len(self.stack) > exn_tab_entry.depth:
1412                    self.pop()
1413
1414                # 2) if 'lasti' is true, then push the offset that the exception was raised at
1415                if exn_tab_entry.lasti:
1416                    self.push(
1417                        variables.ConstantVariable(self.current_instruction.offset)
1418                    )
1419
1420                # 3) push the exception to the stack
1421                assert len(self.exn_vt_stack)
1422                self.push(self.exn_vt_stack[-1])
1423
1424                # 4) jump to the handler
1425                self.jump(exn_tab_entry)
1426            else:
1427                # No handler found. Bubble the exception to the parent
1428                # instruction translator. We use a special exception for this.
1429                self.stack.clear()
1430                if type(self) is InstructionTranslator:
1431                    raise Unsupported("Observed exception")
1432                raise raised_exception
1433        else:
1434            if len(self.block_stack):
1435                # base implementation - https://github.com/python/cpython/blob/3.10/Python/ceval.c#L4455
1436
1437                assert len(self.exn_vt_stack)
1438                exception_var = self.exn_vt_stack[-1]
1439
1440                block_stack_entry = self.block_stack.pop()
1441
1442                while block_stack_entry.inst.opname == "EXCEPT_HANDLER":
1443                    # TODO(anijain2305) - This is not tested; unable to create a test case
1444                    # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456
1445                    self.popn(3)
1446                    if len(self.block_stack) == 0:
1447                        # No handler found in this frame. Bubble the exception to the parent
1448                        # instruction translator.
1449                        self.stack.clear()
1450                        if type(self) is InstructionTranslator:
1451                            raise Unsupported("Observed exception")
1452                        raise raised_exception
1453                    block_stack_entry = self.block_stack.pop()
1454
1455                if block_stack_entry.inst.opname != "SETUP_FINALLY":
1456                    unimplemented(
1457                        "exception is raised when top of the block stack "
1458                        "is not exception handler (e.g. try .. with .. except). "
1459                        f"Current TOS is {block_stack_entry.inst}"
1460                    )
1461
1462                # Push a dummy block stack entry of EXCEPT_HANDLER
1463                # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456
1464                except_handler_inst = Instruction(1e6, "EXCEPT_HANDLER", None, 0)
1465                self.block_stack.append(BlockStackEntry(except_handler_inst, None))
1466
1467                # Push old exception
1468                if len(self.exn_vt_stack) >= 2:
1469                    old_exception = self.exn_vt_stack[-2]
1470
1471                    # Push the old exception on to stack - tb, value, type
1472                    # Traceback is currently mapped to UnknownVariable
1473                    self.push(variables.UnknownVariable())
1474                    self.push(old_exception)
1475                    self.push(variables.BuiltinVariable(old_exception.exc_type))
1476                else:
1477                    # Push empty exception tb, value, type
1478                    self.push(variables.ConstantVariable(None))
1479                    self.push(variables.ConstantVariable(None))
1480                    self.push(variables.ConstantVariable(None))
1481
1482                # Push new exception - tb, val, type
1483                # Traceback is currently mapped to UnknownVariable
1484                self.push(variables.UnknownVariable())
1485                self.push(exception_var)
1486                self.push(variables.BuiltinVariable(exception_var.exc_type))
1487
1488                # Jump to target
1489                self.jump(block_stack_entry)
1490            else:
1491                # No handler found. Bubble the exception to the parent
1492                # instruction translator. We use a special exception for this.
1493                self.stack.clear()
1494                if type(self) is InstructionTranslator:
1495                    raise Unsupported("Observed exception")
1496                raise raised_exception
1497
1498    def PUSH_EXC_INFO(self, inst):
1499        val = self.pop()
1500        assert len(self.exn_vt_stack)
1501        self.push(self.exn_vt_stack[-1])
1502        self.push(val)
1503
1504    def POP_EXCEPT(self, inst):
1505        if sys.version_info >= (3, 11):
1506            val = self.pop()
1507            assert isinstance(val, variables.ExceptionVariable)
1508
1509            # This exception is handled and therefore we can clear the error indicator
1510            assert len(self.exn_vt_stack)
1511            self.exn_vt_stack.pop()
1512        else:
1513            assert len(self.block_stack) > 0
1514            if self.block_stack[-1].inst.opname != "EXCEPT_HANDLER":
1515                raise AssertionError(
1516                    "Bug in Dynamo tracing of exception handling. "
1517                    "Top of the block stack is not EXCEPT_HANDLER."
1518                )
1519            self.block_stack.pop()
1520
1521            self.popn(3)
1522
1523            # This exception is handled and therefore we can clear the error indicator
1524            assert len(self.exn_vt_stack)
1525            self.exn_vt_stack.pop()
1526
1527    def check_if_exc_matches(self):
1528        assert len(self.stack) >= 2
1529        expected_exc_types = self.pop()
1530        if sys.version_info >= (3, 11):
1531            # CHECK_EXC_MATCH (which is used from 3.11 onwards) does not pop.
1532            # This is the description from the dis documentation:
1533            #
1534            # Performs exception matching for ``except``. Tests whether the ``STACK[-2]``
1535            # is an exception matching ``STACK[-1]``. Pops ``STACK[-1]`` and pushes the boolean
1536            # result of the test.
1537            exc_instance = self.stack[-1]
1538        else:
1539            # This is used prior to 3.11 via opcode JUMP_IF_NOT_EXC_MATCH
1540            # There is no documentation but here is the code pointer that does 2 pops
1541            # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L3650-L3665
1542            exc_instance = self.stack.pop()
1543
1544        # Users can check the exception in 2 ways
1545        # 1) except NotImplementedError -> BuiltinVariable
1546        # 2) except (NotImplementedError, AttributeError) -> TupleVariable
1547
1548        if not isinstance(expected_exc_types, (BuiltinVariable, TupleVariable)):
1549            unimplemented(
1550                f"except has an unsupported type of object {expected_exc_types}"
1551            )
1552
1553        if sys.version_info >= (3, 11):
1554            if not isinstance(exc_instance, variables.ExceptionVariable):
1555                unimplemented(
1556                    f"except expects to receive an object of exception type but received {exc_instance}"
1557                )
1558
1559        if isinstance(expected_exc_types, TupleVariable):
1560            expected_types = expected_exc_types.items
1561        else:
1562            expected_types = [
1563                expected_exc_types,
1564            ]
1565
1566        for expected_type in expected_types:
1567            if not isinstance(expected_type, BuiltinVariable):
1568                unimplemented(
1569                    f"except has an unsupported type of object {expected_type}"
1570                )
1571            if isinstance(exc_instance, variables.ExceptionVariable) and issubclass(
1572                exc_instance.exc_type, expected_type.fn
1573            ):
1574                return True
1575            elif isinstance(exc_instance, variables.BuiltinVariable) and issubclass(
1576                exc_instance.fn, expected_type.fn
1577            ):
1578                return True
1579
1580        return False
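        # Illustrative tracing sketch: for ``except (TypeError, ValueError):``
        # the expected types arrive as a TupleVariable of BuiltinVariables and
        # the raised value as an ExceptionVariable (3.11+), so the issubclass
        # checks above decide whether the handler matches.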
1581
1582    def CHECK_EXC_MATCH(self, inst):
1583        self.push(variables.ConstantVariable(self.check_if_exc_matches()))
1584
1585    def JUMP_IF_NOT_EXC_MATCH(self, inst):
1586        if not self.check_if_exc_matches():
1587            self.jump(inst)
1588
1589    def COMPARE_OP(self, inst):
1590        if inst.argval == "exception match":
1591            self.CHECK_EXC_MATCH(inst)
1592        else:
1593            self.push(compare_op_handlers[inst.argval](self, self.popn(2), {}))
1594
1595    def GET_ITER(self, inst):
1596        self.call_function(BuiltinVariable(iter), [self.pop()], {})
1597
1598    @break_graph_if_unsupported(push=1)
1599    def CALL_FUNCTION(self, inst):
1600        args = self.popn(inst.argval)
1601        fn = self.pop()
1602        self.call_function(fn, args, {})
1603
1604    @break_graph_if_unsupported(push=1)
1605    def CALL_FUNCTION_EX(self, inst):
1606        kwargsvars: VariableTracker
1607        if inst.argval == 0:
1608            kwargsvars = ConstDictVariable({})
1609            argsvars = self.pop()
1610        elif inst.argval == 1:
1611            kwargsvars = self.pop()
1612            argsvars = self.pop()
1613        else:
1614            unimplemented("CALL_FUNCTION_EX")
1615
1616        if sys.version_info >= (3, 13):
1617            # 3.13 swapped null and callable
1618            null = self.pop()
1619            assert isinstance(null, NullVariable)
1620
1621        fn = self.pop()
1622
1623        if sys.version_info >= (3, 11) and sys.version_info < (3, 13):
1624            null = self.pop()
1625            assert isinstance(null, NullVariable)
1626
1627        if isinstance(fn, GetAttrVariable) and isinstance(fn.obj, TensorVariable):
1628            # realize() is required for Python 3.8
1629            kwargsvars = kwargsvars.realize()
1630            if fn.name == "view" and isinstance(
1631                argsvars, (ConstantVariable, TensorVariable)
1632            ):
1633                # Hack to handle a special case in some BERT models.  Converts
1634                # x.view(*shape) into x.view(shape), which is correct for view()
1635                # but not generally.  See test_transpose_for_scores().
1636                argsvars = TupleVariable([argsvars])
1637            elif (
1638                fn.name == "random_"
1639                and isinstance(argsvars, TupleVariable)
1640                and len(argsvars.items) == 0
1641                and isinstance(kwargsvars, ConstDictVariable)
1642                and ConstantVariable.create("from") in kwargsvars
1643            ):
1644                # `from` is a Python keyword. Adding random_ with `from` in the
1645                # FX graph causes a syntax error. Even if we convert the kwargs to
1646                # args, aot_autograd/inductor generates aten.random.from while
1647                # lowering, again causing syntax errors. Since this use case is
1648                # uncommon, graph break.
1649                unimplemented("random_ op is called with from keyword")
1650            elif (
1651                fn.name == "uniform_"
1652                and isinstance(argsvars, TupleVariable)
1653                and len(argsvars.items) == 0
1654                and isinstance(kwargsvars, ConstDictVariable)
1655                and ConstantVariable.create("from") in kwargsvars
1656            ):
1657                # `from` is a Python keyword. Adding uniform_ with `from` in the
1658                # FX graph causes a syntax error. Even if we convert the kwargs to
1659                # args, aot_autograd/inductor generates aten.uniform.from while
1660                # lowering, again causing syntax errors. Since this use case is
1661                # uncommon, graph break.
1662                unimplemented("uniform_ op is called with from keyword")
1663
1664        if not isinstance(
1665            argsvars, BaseListVariable
1666        ) and argsvars.has_force_unpack_var_sequence(self):
1667            argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self))
1668
1669        # Unpack for cases like fn(**obj) where obj is a map
1670        if isinstance(kwargsvars, UserDefinedObjectVariable):
1671            kwargsvars = BuiltinVariable.call_custom_dict(self, dict, kwargsvars)  # type: ignore[arg-type]
1672
1673        if not isinstance(argsvars, BaseListVariable) or not isinstance(
1674            kwargsvars, ConstDictVariable
1675        ):
1676            unimplemented(f"non-static call {typestr(argsvars)} {typestr(kwargsvars)}")
1677
1678        # Map to a dictionary of str -> VariableTracker
1679        kwargsvars = kwargsvars.keys_as_python_constant()
1680        self.call_function(fn, argsvars.items, kwargsvars)
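        # Illustrative examples (a sketch): ``fn(*args)`` compiles to
        # CALL_FUNCTION_EX 0 and ``fn(*args, **kwargs)`` to CALL_FUNCTION_EX 1.
        # The graph breaks above fire for calls such as
        #     x.random_(**{"from": 0, "to": 10})
        # since ``from`` can only be passed as a keyword via ** unpacking.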
1681
1682    @break_graph_if_unsupported(push=1)
1683    def CALL_FUNCTION_KW(self, inst):
1684        argnames = self.pop()
1685        args = self.popn(inst.argval)
1686        fn = self.pop()
1687        assert isinstance(argnames, TupleVariable) and argnames.is_python_constant()
1688        argnames = argnames.as_python_constant()
1689        args, kwargs_list = args[: -len(argnames)], args[-len(argnames) :]
1690        kwargs = dict(zip(argnames, kwargs_list))
1691        assert len(kwargs) == len(argnames)
1692        self.call_function(fn, args, kwargs)
1693
1694    def LOAD_METHOD_SUPER(self, inst):
1695        self.CALL_FUNCTION(dataclasses.replace(inst, argval=2))
1696        arg = inst.argval[0]
1697        argval = self.code_options["co_names"][arg]
1698        if sys.version_info < (3, 11):
1699            self._load_attr(dataclasses.replace(inst, argval=argval))
1700        else:
1701            self.LOAD_METHOD(dataclasses.replace(inst, argval=argval))
1702
1703    def LOAD_ATTR_SUPER(self, inst):
1704        self.CALL_FUNCTION(dataclasses.replace(inst, argval=2))
1705        arg = inst.argval[0]
1706        argval = self.code_options["co_names"][arg]
1707        self._load_attr(dataclasses.replace(inst, argval=argval))
1708
1709    def LOAD_METHOD(self, inst):
1710        self._load_attr(inst)
1711        obj = self.pop()
1712        if sys.version_info >= (3, 13):
1713            self.push(obj)
1714            self.PUSH_NULL(inst)
1715        elif sys.version_info >= (3, 11):
1716            # always follow the NULL + fn convention, since if obj
1717            # is actually a method, self is already bound to it, so it
1718            # doesn't need to be passed in as an arg.
1719            self.PUSH_NULL(inst)
1720            self.push(obj)
1721        else:
1722            self.push(obj)
1723            self.push(None)
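        # Stack sketch for ``obj.method(arg)`` after the getattr above has
        # produced a bound method: 3.11/3.12 push [NULL, method], 3.13 pushes
        # [method, NULL], and pre-3.11 pushes [method, None] for the
        # CALL_METHOD handler below to pop.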
1724
1725    def CALL_METHOD(self, inst):
1726        args = self.popn(inst.argval)
1727        dummy = self.pop()
1728        assert dummy is None
1729        fn = self.pop()
1730        self.call_function(fn, args, {})
1731
1732    def _load_attr(self, inst):
1733        obj = self.pop()
1734        result = BuiltinVariable(getattr).call_function(
1735            self, [obj, ConstantVariable.create(inst.argval)], {}  # type: ignore[arg-type]
1736        )
1737        self.push(result)
1738
1739    def LOAD_ATTR(self, inst):
1740        if sys.version_info >= (3, 12):
1741            if inst.arg % 2:
1742                self.LOAD_METHOD(inst)
1743                return
1744        self._load_attr(inst)
1745
1746    def STORE_ATTR(self, inst):
1747        speculation = self.speculate()
1748        if speculation.failed:
1749            return self.store_attr_graph_break(inst)
1750        val, obj = self.popn(2)
1751
1752        if isinstance(obj, NNModuleVariable) and not isinstance(val, ConstantVariable):
1753            # We don't allow side effects during export on non-constant values
1754            # https://github.com/pytorch/torchdynamo/issues/1475
1755            assert (
1756                not self.export
1757            ), f"Mutating module attribute {inst.argval} during export."
1758
1759        try:
1760            BuiltinVariable(setattr).call_function(
1761                self, [obj, ConstantVariable.create(inst.argval), val], {}  # type: ignore[arg-type]
1762            )
1763            return
1764        except Unsupported as e:
1765            if not self.should_compile_partial_graph():
1766                raise
1767            log.debug("STORE_ATTR triggered compile", exc_info=True)
1768            e.remove_from_stats()
1769            e.add_to_stats("graph_break")
1770        speculation.fail_and_restart_analysis()
1771
1772    def store_attr_graph_break(self, inst):
1773        if not self.should_compile_partial_graph():
1774            unimplemented("should_compile_partial_graph=False")
1775        self.output.compile_subgraph(
1776            self, reason=GraphCompileReason("store_attr", [self.frame_summary()])
1777        )
1778        self.output.add_output_instructions([copy.copy(inst)])
1779        self.popn(2)
1780        self.output.add_output_instructions(
1781            self.create_call_resume_at(self.next_instruction)
1782        )
1783
1784    def DELETE_ATTR(self, inst):
1785        obj = self.pop()
1786        BuiltinVariable(delattr).call_function(
1787            self, [obj, ConstantVariable.create(inst.argval)], {}  # type: ignore[arg-type]
1788        )
1789
1790    def create_call_resume_at(self, offset):
1791        raise AssertionError(
1792            f"create_call_resume_at not overridden by subclass {type(self)}"
1793        )
1794
1795    def should_compile_partial_graph(self) -> bool:
1796        raise AssertionError(
1797            f"should_compile_partial_graph not overridden by subclass {type(self)}"
1798        )
1799
1800    @break_graph_if_unsupported(push=0)
1801    def STORE_SUBSCR(self, inst):
1802        val, obj, key = self.popn(3)
1803        obj.call_method(self, "__setitem__", [key, val], {})
1804
1805    def DELETE_SUBSCR(self, inst):
1806        obj, key = self.popn(2)
1807        obj.call_method(self, "__delitem__", [key], {})
1808
1809    def BUILD_TUPLE(self, inst):
1810        name_tuple = None
1811        if sys.version_info >= (3, 13):
1812            name_tuple = tuple(self.name_stack[-inst.argval :])
1813        items = self.popn(inst.argval)
1814        self.push(TupleVariable(items), name=name_tuple)
1815
1816    def BUILD_SLICE(self, inst):
1817        items = self.popn(inst.argval)
1818        self.push(SliceVariable(items))
1819
1820    def BUILD_LIST(self, inst):
1821        items = self.popn(inst.argval)
1822        self.push(ListVariable(items, mutable_local=MutableLocal()))
1823
1824    def BUILD_SET(self, inst):
1825        if config.inject_BUILD_SET_unimplemented_TESTING_ONLY:
1826            unimplemented("missing: BUILD_SET")
1827        items = self.popn(inst.argval)
1828        new_set = SetVariable(items, mutable_local=MutableLocal())
1829        self.push(new_set)
1830
1831    def BUILD_LIST_UNPACK(self, inst, cls=ListVariable):
1832        seqs = self.popn(inst.argval)
1833        items = []
1834        for seq in seqs:
1835            try:
1836                items.extend(seq.force_unpack_var_sequence(self))
1837            except NotImplementedError:
1838                unimplemented(f"BUILD_LIST_UNPACK {seq}")
1839        self.push(cls(items, mutable_local=MutableLocal()))
1840
1841    def BUILD_TUPLE_UNPACK(self, inst):
1842        self.BUILD_LIST_UNPACK(inst, cls=TupleVariable)
1843
1844    BUILD_TUPLE_UNPACK_WITH_CALL = BUILD_TUPLE_UNPACK
1845
1846    def BUILD_MAP(self, inst):
1847        items = self.popn(inst.argval * 2)
1848        d = dict(zip(items[::2], items[1::2]))
1849        self.push(ConstDictVariable(d, mutable_local=MutableLocal()))
1850
1851    def BUILD_MAP_UNPACK(self, inst):
1852        items = self.popn(inst.argval)
1853        # ensure everything is a dict
1854        items = [BuiltinVariable(dict).call_function(self, [x], {}) for x in items]  # type: ignore[arg-type]
1855        result = {}
1856        for x in items:
1857            assert isinstance(x, ConstDictVariable)
1858            result.update(x.items)
1859        self.push(
1860            ConstDictVariable(
1861                result,
1862                mutable_local=MutableLocal(),
1863            )
1864        )
1865
1866    BUILD_MAP_UNPACK_WITH_CALL = BUILD_MAP_UNPACK
1867
1868    def BUILD_CONST_KEY_MAP(self, inst):
1869        keys = self.pop()
1870        values = self.popn(inst.argval)
1871        assert isinstance(keys, TupleVariable)
1872        assert keys.is_python_constant()
1873
1874        keys = keys.force_unpack_var_sequence(self)
1875        assert len(keys) == len(values)
1876
1877        self.push(
1878            ConstDictVariable(
1879                dict(zip(keys, values)),
1880                mutable_local=MutableLocal(),
1881            )
1882        )
1883
1884    def MAP_ADD(self, inst):
1885        k, v = self.popn(2)
1886        assert inst.argval > 0
1887        obj = self.stack[-inst.arg].realize()
1888        assert isinstance(obj, ConstDictVariable)
1889        obj.call_method(self, "__setitem__", (k, v), {})  # type: ignore[arg-type]
1890
1891    def SET_ADD(self, inst):
1892        v = self.pop()
1893        assert inst.argval > 0
1894        obj = self.stack[-inst.arg]
1895        assert isinstance(obj, SetVariable)
1896        assert obj.mutable_local
1897        return obj.call_method(self, "add", [v], {})
1898
1899    def SET_UPDATE(self, inst):
1900        v = self.pop()
1901        assert inst.argval > 0
1902        obj = self.stack[-inst.arg]
1903        assert isinstance(obj, SetVariable)
1904        assert obj.mutable_local
1905        obj.call_method(self, "update", [v], {})
1906
1907    def LIST_APPEND(self, inst):
1908        v = self.pop()
1909        assert inst.argval > 0
1910        obj = self.stack[-inst.arg].realize()
1911        assert isinstance(obj, ListVariable)
1912        assert obj.mutable_local
1913        self.output.side_effects.mutation(obj)
1914        obj.items.append(v)
1915
1916    def MAKE_FUNCTION(self, inst):
1917        flags = inst.arg
1918        old_stack = list(self.stack)
1919        if sys.version_info < (3, 11):
1920            fn_name = self.pop()
1921        code = self.pop()
1922        if sys.version_info >= (3, 11):
1923            # MAKE_FUNCTION behavior actually changed in 3.11, see
1924            # https://github.com/python/cpython/pull/93189/
1925            assert hasattr(code.value, "co_qualname")  # type: ignore[attr-defined]
1926            fn_name = ConstantVariable.create(value=code.value.co_qualname)  # type: ignore[attr-defined]
1927        defaults = None
1928        closure = None
1929        annotations = None
1930        kwdefaults = None
1931
1932        if sys.version_info < (3, 13):
1933            # in 3.13, this is handled in SET_FUNCTION_ATTRIBUTE
1934            if flags & 0x08:
1935                closure = self.pop()
1936            if flags & 0x04:
1937                annotations = self.pop()
1938            if flags & 0x02:
1939                kwdefaults = self.pop()
1940            if flags & 0x01:
1941                defaults = self.pop()
1942
1943        self.push(
1944            NestedUserFunctionVariable(
1945                fn_name,
1946                code,
1947                self.f_globals,
1948                defaults,
1949                kwdefaults,
1950                annotations,
1951                closure,
1952                closure_scope=self,
1953            )
1954        )
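        # Flag sketch (pre-3.13, per the dis docs): a nested
        #     def g(a=1, *, b=2): ...
        # reaches here with flags 0x01 (defaults) and 0x02 (kwdefaults) set;
        # 0x04 marks annotations and 0x08 a closure tuple.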
1955
1956    def UNPACK_SEQUENCE(self, inst):
1957        seq = self.pop()
1958        if isinstance(seq, TensorVariable):
1959            val = seq.unpack_var_sequence(self, idxes=range(inst.argval))  # type: ignore[arg-type]
1960        elif isinstance(seq, GetAttrVariable) and isinstance(seq.obj, TensorVariable):
1961            # x, y = a.shape
1962            proxy = getattr(seq.obj.as_proxy(), seq.name)
1963            val = [wrap_fx_proxy(self, proxy[i]) for i in range(inst.argval)]
1964        elif seq.has_force_unpack_var_sequence(self):
1965            val = seq.force_unpack_var_sequence(self)
1966        else:
1967            unimplemented(f"UNPACK_SEQUENCE {seq}")
1968        if len(val) != inst.argval:
1969            unimplemented("UNPACK_SEQUENCE length mismatch")
1970        for i in reversed(val):
1971            self.push(i)
1972
1973    def UNPACK_EX(self, inst):
1974        assert 0 <= inst.argval <= 0xFFFF
1975        prefix = inst.argval & 0xFF  # low byte
1976        suffix = inst.argval >> 8  # high byte
1977        seq = self.pop()
1978        if seq.has_force_unpack_var_sequence(self):
1979            vals = list(seq.force_unpack_var_sequence(self))
1980            assert len(vals) >= prefix + suffix
1981            vals_prefix = vals[:prefix]
1982            vals_list = vals[prefix : len(vals) - suffix]
1983            vals_suffix = vals[len(vals) - suffix :]
1984            for item in reversed(vals_suffix):
1985                self.push(item)
1986            self.push(TupleVariable(vals_list))
1987            for item in reversed(vals_prefix):
1988                self.push(item)
1989        else:
1990            unimplemented(f"UNPACK_EX {seq}")
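        # Encoding sketch: ``a, *b, c = seq`` compiles to UNPACK_EX with
        # argval 0x0101, i.e. prefix=1 target before the starred name and
        # suffix=1 target after it.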
1991
1992    def NOP(self, inst):
1993        pass
1994
1995    def POP_TOP(self, inst):
1996        self.pop()
1997
1998    def ROT_TWO(self, inst):
1999        a = self.pop()
2000        b = self.pop()
2001        self.push(a)
2002        self.push(b)
2003
2004    def ROT_THREE(self, inst):
2005        a = self.pop()
2006        b = self.pop()
2007        c = self.pop()
2008        self.push(a)
2009        self.push(c)
2010        self.push(b)
2011
2012    def ROT_FOUR(self, inst):
2013        a = self.pop()
2014        b = self.pop()
2015        c = self.pop()
2016        d = self.pop()
2017        self.push(a)
2018        self.push(d)
2019        self.push(c)
2020        self.push(b)
2021
2022    def DUP_TOP(self, inst):
2023        a = self.pop()
2024        self.push(a)
2025        self.push(a)
2026
2027    def DUP_TOP_TWO(self, inst):
2028        a = self.pop()
2029        b = self.pop()
2030        self.push(b)
2031        self.push(a)
2032        self.push(b)
2033        self.push(a)
2034
2035    def FORMAT_VALUE(self, inst):
2036        flags = inst.arg
2037        if (flags & 0x04) == 0x04:
2038            fmt_spec = self.pop()
2039        else:
2040            fmt_spec = ConstantVariable.create("")
2041
2042        value = self.pop()
2043        if isinstance(value, SymNodeVariable):
2044            from torch._dynamo.variables.lazy import (
2045                LazySymNodeFormatString,
2046                LazyVariableTracker,
2047            )
2048
2049            value = LazyVariableTracker.create(
2050                LazySymNodeFormatString(value, fmt_spec), source=value.source
2051            )
2052            self.push(value)
2053            return
2054        if (flags & 0x03) == 0x01:
2055            value = BuiltinVariable(str).call_function(self, [value], {})  # type: ignore[arg-type]
2056        elif (flags & 0x03) == 0x02:
2057            value = BuiltinVariable(repr).call_function(self, [value], {})  # type: ignore[arg-type]
2058        elif (flags & 0x03) == 0x03:
2059            value = BuiltinVariable(ascii).call_function(self, [value], {})  # type: ignore[arg-type]
2060
2061        fmt_var = ConstantVariable.create("{:" + fmt_spec.as_python_constant() + "}")
2062
2063        self.call_function(BuiltinVariable(str.format), [fmt_var, value], {})
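        # Flag sketch (per the dis docs): ``f"{x!r:>10}"`` reaches here with
        # flags 0x02 | 0x04, i.e. apply repr() and pop the format spec ">10"
        # from the stack; 0x01 means str() and 0x03 means ascii().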
2064
2065    def BUILD_STRING(self, inst):
2066        format_string_parts: List[str] = []
2067        args: List[VariableTracker] = []
2068        kwargs: Dict[str, VariableTracker] = {}
2069        for part in self.popn(inst.arg):
2070            if isinstance(part, ConstantVariable):
2071                format_string_parts.append("{}")
2072                args.append(part)
2073            elif isinstance(part, variables.StringFormatVariable):
2074                format_string_parts.append(part.format_string)
2075                args.extend(part.sym_args)
2076                if set(kwargs.keys()) & set(part.sym_kwargs.keys()):
2077                    unimplemented(
2078                        f"BUILD_STRING key conflict {kwargs} & {part.sym_kwargs}"
2079                    )
2080                kwargs.update(part.sym_kwargs)
2081            else:
2082                unimplemented(f"BUILD_STRING {part}")
2083        self.push(
2084            variables.StringFormatVariable.create(
2085                "".join(format_string_parts), args, kwargs
2086            )
2087        )
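        # Illustrative example (a sketch): ``f"a{x}b{y}"`` pushes "a",
        # formatted x, "b", formatted y and then runs BUILD_STRING 4; the
        # handler above folds the parts into one StringFormatVariable.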
2088
2089    def IS_OP(self, inst):
2090        assert inst.argval == 0 or inst.argval == 1
2091        if inst.argval == 0:
2092            new_argval = "is"
2093        else:
2094            new_argval = "is not"
2095        new_inst = create_instruction("COMPARE_OP", argval=new_argval)
2096        self.COMPARE_OP(new_inst)
2097
2098    def CONTAINS_OP(self, inst):
2099        assert inst.argval == 0 or inst.argval == 1
2100        left, right = self.popn(2)
2101        op = inst.argval
2102        self.push(right.call_method(self, "__contains__", [left], {}))
2103        if op == 1:
2104            self.UNARY_NOT(inst)
2105
2106    def LIST_EXTEND(self, inst):
2107        v = self.pop()
2108        assert inst.argval > 0
2109        obj = self.stack[-inst.arg]
2110        assert isinstance(obj, ListVariable)
2111        assert obj.mutable_local
2112        obj.call_method(self, "extend", [v], {})
2113
2114    def LIST_TO_TUPLE(self, inst):
2115        self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {}))  # type: ignore[arg-type]
2116
2117    def DICT_MERGE(self, inst):
2118        v = self.pop()
2119        assert inst.argval > 0
2120        obj = self.stack[-inst.arg].realize()
2121        assert isinstance(obj, ConstDictVariable)
2122        assert obj.mutable_local
2123        obj.call_method(self, "update", [v], {})
2124
2125    DICT_UPDATE = DICT_MERGE
2126
2127    def GEN_START(self, inst):
2128        self.pop()
2129
2130    def GET_LEN(self, inst):
2131        tos = self.stack[-1]
2132        if tos.is_python_constant():
2133            self.push(ConstantVariable.create(len(tos.as_python_constant())))
2134        else:
2135            self.push(tos.call_method(self, "__len__", [], {}))
2136
2137    def MATCH_MAPPING(self, inst):
2138        tos = self.stack[-1]
2139        assert isinstance(tos, ConstDictVariable)
2140        if isinstance(tos.items, collections.abc.Mapping):
2141            self.push(ConstantVariable.create(True))
2142        else:
2143            self.push(ConstantVariable.create(False))
2144
2145    def MATCH_SEQUENCE(self, inst):
2146        tos = self.stack[-1]
2147        assert tos.is_python_constant()
2148        tos_value = tos.as_python_constant()
2149        if isinstance(tos_value, collections.abc.Sequence) and not isinstance(
2150            tos_value, (str, bytes, bytearray)
2151        ):
2152            self.push(ConstantVariable.create(True))
2153        else:
2154            self.push(ConstantVariable.create(False))
2155
2156    def MATCH_KEYS(self, inst):
2157        tos = self.stack[-1]
2158        tos1 = self.stack[-2]
2159        assert isinstance(tos1, ConstDictVariable)
2160
2161        if all(k in tos1 for k in tos):  # type: ignore[attr-defined]
2162            self.push(TupleVariable([tos1.getitem_const(self, k) for k in tos]))  # type: ignore[attr-defined,arg-type]
2163            if sys.version_info < (3, 11):
2164                self.push(ConstantVariable.create(True))
2165        else:
2166            self.push(ConstantVariable.create(None))
2167            if sys.version_info < (3, 11):
2168                self.push(ConstantVariable.create(False))
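        # Illustrative example (a sketch): ``match d: case {"a": v}: ...``
        # emits MATCH_MAPPING then MATCH_KEYS with ("a",) on top of the
        # stack; on a match the tuple of values is pushed (plus True on
        # versions before 3.11), otherwise None (plus False before 3.11).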
2169
2170    def LOAD_ASSERTION_ERROR(self, inst):
2171        self.load_builtin_from_argval("AssertionError")
2172
2173    UNARY_POSITIVE = stack_op(operator.pos)
2174    UNARY_NEGATIVE = stack_op(operator.neg)
2175    UNARY_NOT = stack_op(operator.not_)
2176    UNARY_INVERT = stack_op(operator.invert)
2177
2178    BINARY_POWER = stack_op(operator.pow)
2179    BINARY_MULTIPLY = stack_op(operator.mul)
2180    BINARY_MATRIX_MULTIPLY = stack_op(operator.matmul)
2181    BINARY_FLOOR_DIVIDE = stack_op(operator.floordiv)
2182    BINARY_TRUE_DIVIDE = stack_op(operator.truediv)
2183    BINARY_MODULO = stack_op(operator.mod)
2184    BINARY_REMAINDER = stack_op(operator.mod)
2185    BINARY_ADD = stack_op(operator.add)
2186    BINARY_SUBTRACT = stack_op(operator.sub)
2187    BINARY_SUBSCR = break_graph_if_unsupported(push=1)(stack_op(operator.getitem))
2188    BINARY_LSHIFT = stack_op(operator.lshift)
2189    BINARY_RSHIFT = stack_op(operator.rshift)
2190    BINARY_AND = stack_op(operator.and_)
2191    BINARY_OR = stack_op(operator.or_)
2192    BINARY_XOR = stack_op(operator.xor)
2193
2194    INPLACE_POWER = stack_op(operator.ipow)
2195    INPLACE_MULTIPLY = stack_op(operator.imul)
2196    INPLACE_MATRIX_MULTIPLY = stack_op(operator.imatmul)
2197    INPLACE_FLOOR_DIVIDE = stack_op(operator.ifloordiv)
2198    INPLACE_TRUE_DIVIDE = stack_op(operator.itruediv)
2199    INPLACE_MODULO = stack_op(operator.imod)
2200    INPLACE_REMAINDER = stack_op(operator.imod)
2201    INPLACE_ADD = stack_op(operator.iadd)
2202    INPLACE_SUBTRACT = stack_op(operator.isub)
2203    INPLACE_LSHIFT = stack_op(operator.ilshift)
2204    INPLACE_RSHIFT = stack_op(operator.irshift)
2205    INPLACE_AND = stack_op(operator.iand)
2206    INPLACE_XOR = stack_op(operator.ixor)
2207    INPLACE_OR = stack_op(operator.ior)
2208
2209    # 3.11 opcodes
2210    def RESUME(self, inst):
2211        if inst.arg == 0:
2212            self.append_prefix_inst(inst)
2213            self.accept_prefix_inst = False
2214        else:
2215            assert not self.accept_prefix_inst
2216
2217    if sys.version_info >= (3, 11):
2218
2219        def BINARY_OP(self, inst):
2220            return _binary_op_lookup[inst.arg](self, inst)
2221
2222    def PRECALL(self, inst):
2223        pass
2224
2225    def KW_NAMES(self, inst):
2226        kw_names = self.code_options["co_consts"][inst.arg]
2227        assert isinstance(kw_names, tuple)
2228        for name in kw_names:
2229            assert isinstance(name, str)
2230        assert self.kw_names is None
2231        self.kw_names = ConstantVariable.create(value=kw_names)  # type: ignore[assignment]
2232
2233    def PUSH_NULL(self, inst):
2234        self.push(NullVariable())
2235
2236    def _call(self, inst, call_kw=False):
2237        # see https://docs.python.org/3.11/library/dis.html#opcode-CALL
2238        # for convention
2239        if call_kw:
2240            # TOS is kw_names for CALL_KW instruction
2241            assert sys.version_info >= (3, 13)
2242            kw_names = self.pop()
2243            assert isinstance(kw_names, TupleVariable) and kw_names.is_python_constant()
2244            kw_names = kw_names.as_python_constant()
2245        else:
2246            kw_names = self.kw_names.value if self.kw_names else ()
2247
2248        contents = self.popn(inst.arg + 2)
2249        if sys.version_info >= (3, 13):
2250            # NULL and callable swapped
2251            fn = contents[0]
2252            args = [] if isinstance(contents[1], NullVariable) else [contents[1]]
2253        else:
2254            if isinstance(contents[0], NullVariable):
2255                fn = contents[1]
2256                args = []
2257            else:
2258                fn = contents[0]
2259                args = [contents[1]]
2260
2261        if kw_names:
2262            args = args + contents[2 : -len(kw_names)]
2263            kwargs_list = contents[-len(kw_names) :]
2264            kwargs = dict(zip(kw_names, kwargs_list))
2265            assert len(kwargs) == len(kw_names)
2266        else:
2267            args = args + contents[2:]
2268            kwargs = {}
2269
2270        try:
2271            # if call_function fails, need to set kw_names to None, otherwise
2272            # a subsequent call may have self.kw_names set to an old value
2273            self.call_function(fn, args, kwargs)
2274        finally:
2275            self.kw_names = None
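        # Stack layout sketch (per the dis docs): on 3.11/3.12 CALL pops
        # [NULL, fn, args...] or [fn, self, args...]; 3.13 swaps the first
        # two slots to [fn, NULL-or-self, args...], which is why
        # contents[0]/contents[1] are interpreted differently above.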
2276
2277    @break_graph_if_unsupported(push=1)
2278    def CALL(self, inst):
2279        self._call(inst)
2280
2281    def COPY(self, inst):
2282        self.push(self.stack[-inst.arg])
2283
2284    def SWAP(self, inst):
2285        self.stack[-1], self.stack[-inst.arg] = self.stack[-inst.arg], self.stack[-1]
2286
2287    JUMP_BACKWARD = jump
2288    JUMP_BACKWARD_NO_INTERRUPT = jump
2289
2290    POP_JUMP_FORWARD_IF_TRUE = generic_jump(operator.truth, False)
2291    POP_JUMP_BACKWARD_IF_TRUE = generic_jump(operator.truth, False)
2292    POP_JUMP_FORWARD_IF_FALSE = generic_jump(operator.not_, False)
2293    POP_JUMP_BACKWARD_IF_FALSE = generic_jump(operator.not_, False)
2294
2295    def CACHE(self, inst):
2296        pass
2297
2298    def BEFORE_WITH(self, inst):
2299        self.setup_or_before_with(inst)
2300
2301    def setup_or_before_with(self, inst):
2302        ctx = self.pop()
2303        if not isinstance(
2304            ctx, (ContextWrappingVariable, GenericContextWrappingVariable)
2305        ):
2306            unimplemented(f"{inst.opname} {ctx}")
2307
2308        if isinstance(ctx, GenericContextWrappingVariable):
2309            self.generic_context_manager_depth += 1
2310
2311        # Need this redundant check for mypy
2312        assert isinstance(
2313            ctx, (ContextWrappingVariable, GenericContextWrappingVariable)
2314        )
2315
2316        exit = WithExitFunctionVariable(
2317            ctx,
2318            inst.target,
2319        )
2320
2321        if sys.version_info >= (3, 11):
2322            # See create_call_resume_at for block stack details.
2323            # Only push a block if the current instruction's block is a
2324            # with block that is not nested in a try block - that is, the current
2325            # instruction's block target is the same as the top block's target.
2326            if inst.exn_tab_entry and (
2327                not self.block_stack
2328                or inst.exn_tab_entry.target is not self.block_stack[-1].target
2329            ):
2330                target = None
2331            else:
2332                target = self.next_instruction.exn_tab_entry.target
2333        else:
2334            target = inst.target
2335
2336        if target:
2337            if isinstance(self, InstructionTranslator):
2338                self.block_stack.append(
2339                    BlockStackEntry(inst, target, len(self.stack), ctx)
2340                )
2341            else:
2342                self.block_stack.append(BlockStackEntry(inst, target))
2343
2344        self.push(exit)
2345        self.push(ctx.enter(self))
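        # Illustrative example (a sketch): ``with ctx: body()`` reaches here
        # via SETUP_WITH (<3.11) or BEFORE_WITH (3.11+); afterwards the stack
        # holds the __exit__ wrapper (WithExitFunctionVariable) below the
        # value returned by ctx.__enter__().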
2346
2347    def append_prefix_inst(self, inst):
2348        assert self.accept_prefix_inst
2349        self.prefix_insts.append(inst)
2350
2351    def MAKE_CELL(self, inst):
2352        if sys.version_info >= (3, 12) and not self.accept_prefix_inst:
2353            # In 3.12+, MAKE_CELL is no longer necessarily a prefix instruction.
2354            # It can be generated by inlined comprehensions.
2355            assert isinstance(self.symbolic_locals[inst.argval], NullVariable)
2356            self.symbolic_locals[
2357                inst.argval
2358            ] = self.output.side_effects.track_cell_new()
2359        else:
2360            self.append_prefix_inst(inst)
2361
2362    def COPY_FREE_VARS(self, inst):
2363        self.append_prefix_inst(inst)
2364
2365    def RETURN_GENERATOR(self, inst):
2366        self.append_prefix_inst(inst)
2367
2368    # 3.12 opcodes
2369    # BINARY/STORE_SLICE opcodes are broken down into
2370    # BUILD_SLICE 2 and BINARY/STORE_SUBSCR
2371
2372    def END_FOR(self, inst):
2373        if sys.version_info >= (3, 13):
2374            self.pop()
2375        else:
2376            self.popn(2)
2377
2378    def LOAD_FAST_CHECK(self, inst):
2379        if isinstance(self.symbolic_locals[inst.argval], NullVariable):
2380            unimplemented("LOAD_FAST_CHECK on uninitialized variable")
2381        self.LOAD_FAST(inst)
2382
2383    def LOAD_FAST_AND_CLEAR(self, inst):
2384        if inst.argval not in self.symbolic_locals:
2385            self.push(NullVariable())
2386        else:
2387            self.LOAD_FAST(inst)
2388        self.symbolic_locals[inst.argval] = NullVariable()
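        # Illustrative example (a sketch): with PEP 709 inlined comprehensions
        # in 3.12, ``[x for x in y]`` emits LOAD_FAST_AND_CLEAR x to stash the
        # outer binding of x (or NULL if unbound) and restores it afterwards.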
2389
2390    def LOAD_SUPER_ATTR(self, inst):
2391        self.CALL_FUNCTION(dataclasses.replace(inst, argval=2))
2392        if inst.arg & 1:
2393            self.LOAD_METHOD(inst)
2394        else:
2395            self._load_attr(inst)
2396
2397    def CALL_INTRINSIC_1(self, inst):
2398        if inst.argval == 5:
2399            # INTRINSIC_UNARY_POSITIVE
2400            self.UNARY_POSITIVE(inst)
2401        elif inst.argval == 6:
2402            # INTRINSIC_LIST_TO_TUPLE
2403            self.push(TupleVariable(self.pop().force_unpack_var_sequence(self)))
2404        else:
2405            unimplemented(f"missing CALL_INTRINSIC_1 operand {inst.argval}")
2406
2407    def END_SEND(self, inst):
2408        tos = self.pop()
2409        self.pop()
2410        self.push(tos)
2411
2412    # 3.13 opcodes
2413    # fused instructions LOAD_FAST_LOAD_FAST, STORE_FAST_STORE_FAST, STORE_FAST_LOAD_FAST
2414    # are broken down.
2415    @break_graph_if_unsupported(push=1)
2416    def CALL_KW(self, inst):
2417        self._call(inst, call_kw=True)
2418
2419    def TO_BOOL(self, inst):
2420        # TO_BOOL only precedes a conditional jump or UNARY_NOT (see compile.c in CPython)
2421        # So we can skip this instruction as long as we remember to codegen a TO_BOOL
2422        # before conditional jumps/UNARY_NOT.
2423        assert self.next_instruction.opname in (
2424            "POP_JUMP_IF_TRUE",
2425            "POP_JUMP_IF_FALSE",
2426            "UNARY_NOT",
2427        )
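        # Illustrative example (a sketch): on 3.13 ``if x: ...`` compiles to
        # TO_BOOL followed by POP_JUMP_IF_FALSE, so the coercion can be
        # deferred to the jump/UNARY_NOT handling as noted above.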
2428
2429    def SET_FUNCTION_ATTRIBUTE(self, inst):
2430        flags = inst.arg
2431        fn = self.pop()
2432        assert isinstance(fn, NestedUserFunctionVariable)
2433        attr_names = self.name_stack[-1]
2434        attr = self.pop()
2435
2436        if flags & 0x08:
2437            # 3.13 merged LOAD_CLOSURE into LOAD_FAST, so we won't know if a given LOAD_FAST
2438            # is meant to load a closure variable or not. Our workaround is to maintain a stack
2439            # of LOAD_FAST variable names and tuples (self.name_stack). So if we are indeed
2440            # constructing a closure tuple, we can use self.name_stack to construct the closure
2441            # variables here.
2442            assert isinstance(attr_names, tuple) and all(
2443                isinstance(name, str) for name in attr_names
2444            )
2445            fn.closure = TupleVariable(
2446                [self._load_closure(name) for name in attr_names]
2447            )
2448            fn.closure_scope = self
2449        elif flags & 0x04:
2450            fn.annotations = attr
2451        elif flags & 0x02:
2452            fn.kwdefaults = attr
2453        elif flags & 0x01:
2454            fn.defaults = attr
2455
2456        self.push(fn)
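        # Illustrative 3.13 example (a sketch): a nested
        #     def inner(): return x   # x closed over from the enclosing frame
        # is built via MAKE_FUNCTION followed by SET_FUNCTION_ATTRIBUTE 0x08,
        # whose closure tuple is reconstructed here from self.name_stack.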
2457
2458    def _format_value_313(self, fmt_spec):
2459        value = self.pop()
2460        if isinstance(value, SymNodeVariable):
2461            value = ConstantVariable.create(str(value.sym_num))
2462
2463        fmt_var = ConstantVariable.create("{:" + fmt_spec.as_python_constant() + "}")
2464
2465        self.call_function(BuiltinVariable(str.format), [fmt_var, value], {})
2466
2467    def FORMAT_SIMPLE(self, inst):
2468        self._format_value_313(ConstantVariable.create(""))
2469
2470    def FORMAT_WITH_SPEC(self, inst):
2471        self._format_value_313(self.pop())
2472
2473    def is_non_empty_graph(self):
2474        if self.output.count_calls() > 1:
2475            # perf optimization only
2476            self.is_non_empty_graph = lambda: True  # type: ignore[method-assign]
2477            return True
2478        return False
2479
2480    def format_frame_summary(self, additional_stack_frames=None):
2481        if additional_stack_frames is None:
2482            additional_stack_frames = []
2483        return "".join(
2484            traceback.format_list(
2485                [self.frame_summary()] + list(reversed(additional_stack_frames))
2486            )
2487        )
2488
2489    def frame_summary(self):
2490        return traceback.FrameSummary(
2491            getattr(self.f_code, "co_filename", "<unknown>"),
2492            self.lineno,
2493            getattr(self.f_code, "co_name", "<unknown>"),
2494            lookup_line=False,
2495        )
2496
2497    def is_co_filename_from_nn_modules(self):
2498        filename = getattr(self.f_code, "co_filename", "<unknown>")
2499        nn_modules_pattern = re.compile(r".*torch/nn/modules.*")
2500        return nn_modules_pattern.match(filename) is not None
2501
2502    def store_global_weakref_by_id(self, prefix, value):
2503        global_name = self.output.install_global_by_id(prefix, weakref.ref(value))
2504        install_guard(
2505            GlobalWeakRefSource(global_name).make_guard(GuardBuilder.WEAKREF_ALIVE)
2506        )
2507        return global_name
2508
2509    @property
2510    def fake_mode(self):
2511        return self.output.tracing_context.fake_mode
2512
2513    def find_symbolic_locals_name(self, tensor_variable):
2514        for key, value in self.symbolic_locals.items():
2515            if value is tensor_variable:
2516                return key
2517        return None
2518
2519    @contextlib.contextmanager
2520    def strict_translation_mode(self, check_fn: Callable[[VariableTracker], bool]):
2521        """
2522        Strict mode is enabled on a per-VariableTracker level depending on the return value of check_fn(node).
2523        """
2524        prior = self.strict_checks_fn
2525        self.strict_checks_fn = check_fn
2526        try:
2527            yield
2528        finally:
2529            self.strict_checks_fn = prior
2530
2531    def speculate(self) -> SpeculationEntry:
2532        assert self.instruction_pointer is not None
2533        assert self.instruction_pointer > 0
2534        return self.speculation_log.next(
2535            self.f_code.co_filename,
2536            self.lineno,
2537            self.instruction_pointer - 1,
2538            self.instructions[self.instruction_pointer - 1],
2539        )
2540
2541    def __init__(
2542        self,
2543        output: OutputGraph,
2544        instructions: List[Instruction],
2545        f_locals: Dict[str, Any],
2546        f_globals: Dict[str, Any],
2547        f_builtins: Dict[str, Any],
2548        code_options: Dict[str, Any],
2549        symbolic_locals: Dict[str, VariableTracker],
2550        symbolic_globals: Dict[str, VariableTracker],
2551        symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"],
2552        f_code: types.CodeType,
2553        export: bool,
2554        inline_depth: int,
2555        speculation_log: SpeculationLog,
2556        distributed_state: Optional[DistributedState],
2557    ) -> None:
2558        super().__init__()
2559        self.speculation_log = speculation_log
2560        self.distributed_state = distributed_state
2561
2562        # Mutable state checkpointed by copy_graphstate()
2563        self.output = output
2564        self.symbolic_locals = symbolic_locals
2565        self.symbolic_globals = symbolic_globals
2566        self.symbolic_torch_function_mode_stack = symbolic_torch_function_mode_stack
2567        self.stack = []
2568        # stack of variable names for tracking 3.13 closures
2569        self.name_stack: list[Any] = []
2570        self.instruction_pointer = 0
2571        self.current_instruction = create_instruction("NOP")
2572        self.block_stack = []
2573        # states before SETUP_WITH for checkpointing and fallback
2574        self.generic_context_manager_depth = 0
2575        self.lineno = -1
2576        self.kw_names = None
2577        self.accept_prefix_inst = True
2578        self.prefix_insts = []
2579        self.exn_vt_stack = []
2580
2581        # Properties of the input/output code
2582        self.instructions: List[Instruction] = instructions
2583        self.indexof: Dict[Instruction, int] = get_indexof(self.instructions)
2584        self.f_locals: Dict[
2585            str, Any
2586        ] = f_locals  # needed for recording accessed locals for replay
2587        self.f_globals: Dict[str, Any] = f_globals
2588        self.f_builtins: Dict[str, Any] = f_builtins
2589        self.code_options: Dict[str, Any] = code_options
2590        self.f_code: types.CodeType = f_code
2591
2592        # Execution record for replaying errors
2593        if config.replay_record_enabled:
2594            self.exec_recorder = ExecutionRecorder(
2595                code=f_code, code_options=code_options
2596            )
2597        else:
2598            self.exec_recorder = None
2599        # Stack of modules being parsed; the current nn.Module is at the end of the ordered dict.
2600        # The first field of the tuple is the fully qualified name of the current module
2601        # in the original hierarchy.  The second field is the type of the current nn.Module.
2602        self.nn_module_stack: Dict[str, Tuple[str, Type[Any]]] = {}
2603        # Flag to indicate whether tracing is used for export.
2604        self.export = export
2605        self.one_graph = False
2606
2607        self.current_speculation = None
2608
2609        self.strict_checks_fn = None
2610
2611        if sys.version_info >= (3, 10):
2612            from .resume_execution import (
2613                CO_ASYNC_GENERATOR,
2614                CO_COROUTINE,
2615                CO_GENERATOR,
2616                CO_ITERABLE_COROUTINE,
2617            )
2618
2619            if f_code.co_flags & (
2620                CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR
2621            ):
2622                self.push(BuiltinVariable(None))
2623
2624        self.inline_depth = inline_depth
2625        self.inconsistent_side_effects = False
2626        self._constants_cache: List[Optional[VariableTracker]] = [None] * len(
2627            f_code.co_consts
2628        )
2629        linecache.lazycache(f_code.co_filename, f_globals)
2630
2631
2632class InstructionTranslator(InstructionTranslatorBase):
2633    mutated_closure_cell_contents: Set[str]
2634
2635    @staticmethod
2636    def current_tx() -> "InstructionTranslator":
2637        return tls.current_tx
2638
2639    @contextlib.contextmanager
2640    def set_current_tx(self):
2641        prior = getattr(tls, "current_tx", None)
2642        tls.current_tx = self
2643        try:
2644            yield
2645        finally:
2646            tls.current_tx = prior
2647
2648    def __init__(
2649        self,
2650        instructions: List[Instruction],
2651        f_code,
2652        f_locals,
2653        f_globals,
2654        f_builtins,
2655        code_options,
2656        compiler_fn,
2657        one_graph,
2658        export,
2659        export_constraints,
2660        mutated_closure_cell_contents: Set[str],
2661        frame_state,
2662        speculation_log: SpeculationLog,
2663        distributed_state: Optional[DistributedState],
2664    ) -> None:
2665        _step_logger()(
2666            logging.INFO,
2667            f"torchdynamo start tracing {f_code.co_name} {code_options['co_filename']}:{code_options['co_firstlineno']}",
2668        )
2669        super().__init__(
2670            output=OutputGraph(
2671                code_options,
2672                compiler_fn,
2673                self,
2674                export,
2675                export_constraints,
2676                frame_state,
2677                local_scope=f_locals,
2678                global_scope=f_globals,
2679                f_code=f_code,
2680            ),
2681            instructions=instructions,
2682            f_locals=f_locals,
2683            f_globals=f_globals,
2684            f_builtins=f_builtins,
2685            code_options=code_options,
2686            symbolic_locals={},  # set below
2687            # A global var is inserted only after a STORE_GLOBAL happens to it
2688            symbolic_globals={},
2689            symbolic_torch_function_mode_stack=collections.deque(),
2690            f_code=f_code,
2691            export=export,
2692            inline_depth=0,
2693            speculation_log=speculation_log,
2694            distributed_state=distributed_state,
2695        )
2696
2697        self._throw_if_in_functorch()
2698
2699        # as soon as we create the tracing context we should keep it active, so any calls
2700        # into dynamo apis can rely on finding it
2701        with tracing(self.output.tracing_context), self.set_current_tx():
2702            self.one_graph: bool = one_graph
2703            self.export = export
2704            self.mutated_closure_cell_contents = mutated_closure_cell_contents
2705            if self.export:
2706                assert (
2707                    self.one_graph
2708                ), "Export without one graph - something has gone wrong."
2709
2710            vars = list(code_options["co_varnames"])
2711            cells_and_freevars = [x for x in self.cell_and_freevars() if x not in vars]
2712            vars.extend(cells_and_freevars)
2713            cells_and_freevars_set = set(cells_and_freevars)
2714
2715            self.symbolic_locals = {
2716                k: variables.LazyVariableTracker.create(
2717                    f_locals[k],
2718                    source=LocalSource(k, cell_or_freevar=k in cells_and_freevars_set),
2719                )
2720                for k in vars
2721                if k in f_locals
2722            }
2723
2724            self._init_torch_function_mode_stack()
2725
2726            self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = []
2727            if export:
2728                # export gets confused if we never realize unused inputs;
2729                # in export mode, just eagerly realize everything
2730                self.symbolic_locals = variables.LazyVariableTracker.realize_all(
2731                    self.symbolic_locals
2732                )
2733
2734            self._freevars_ids = {}
2735            for name in self.code_options["co_freevars"]:
2736                if name in f_locals:
2737                    self._freevars_ids[name] = id(f_locals[name])
2738
2739    def _throw_if_in_functorch(self):
2740        # Fallback to eager in case of a graph break inside vmap
2741        eager = torch._dynamo.lookup_backend("eager")
2742        compiler_fn = inspect.getattr_static(
2743            self.output.compiler_fn, "compiler_fn", self.output.compiler_fn
2744        )
2745        ci = torch._C._functorch.peek_interpreter_stack()
2746        forbidden_keys = (
2747            torch._C._functorch.TransformType.Vmap,
2748            torch._C._functorch.TransformType.Grad,
2749            torch._C._functorch.TransformType.Jvp,
2750        )
2751
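        # e.g. (hypothetical) calling torch.vmap(torch.compile(f))(x) from eager
        # mode lands here with ci.key() == TransformType.Vmap; the supported
        # composition is torch.compile(torch.vmap(f))(x).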
2752        if ci is not None and ci.key() in forbidden_keys and compiler_fn is not eager:
2753            name = ci.key().name.lower()
2754            msg = (
2755                "If you are seeing this, it means dynamo failed for one of the following reasons:\n"
2756                # Calling a torch.compile'd function under torch.func from eager mode
2757                f"- Calling torch.func.{name}(compiled_fn) from eager mode is not supported. "
2758                f"Ensure that torch.func.{name} is also wrapped within a torch.compile function. "
2759                "For more information, see PyTorch issue #128711.\n"
2760                # Dynamo failed to inline a functorch function
2761                f"- torch.func.{name}(fn) requires the function to be inlined by dynamo"
2762            )
2763            unimplemented(msg)
2764
2765    def _init_torch_function_mode_stack(self):
2766        from .variables.torch_function import TorchFunctionModeStackVariable
2767
2768        TorchFunctionModeStackVariable.reset()
2769
2770        self.symbolic_torch_function_mode_stack: Deque[
2771            TorchFunctionModeVariable
2772        ] = collections.deque()
2773        # We want to retrieve all modes to properly reconstruct the stack if needed
2774        py_stack = get_torch_function_mode_stack(filter_ignored=False)
2775
2781        for i, val in enumerate(py_stack):
2782            self.symbolic_torch_function_mode_stack.append(
2783                variables.LazyVariableTracker.create(
2784                    val, source=TorchFunctionModeStackSource(i)
2785                )
2786            )
2787
2788    def get_example_value(self, source: Source):
2789        if isinstance(source, LocalSource):
2790            return self.f_locals[source.local_name]
2791        if isinstance(source, GlobalSource):
2792            return self.f_globals[source.global_name]
2793        raise KeyError
2794
2795    def run(self):
2796        super().run()
2797
2798    def match_nested_cell(self, name, cell):
2799        """Match a cell in this method to one in a function we are inlining"""
2800        try:
2801            value = cell.cell_contents
2802        except ValueError:
2803            return None
2804        # TODO(jansel): check the id of the cell rather than the contents
2805        if id(value) != self._freevars_ids.get(name):
2806            return None
2807        return self.symbolic_locals[name]
2808
2809    def should_compile_partial_graph(self):
2810        if sys.version_info >= (3, 11):
2811            # Do not compile if the current instruction's block is not the top-most with block
2812            entry = self.current_instruction.exn_tab_entry
2813            if entry and (
2814                not self.block_stack or entry.target is not self.block_stack[-1].target
2815            ):
2816                return False
2817        return (
2818            all(b.can_restore() for b in self.block_stack)
2819            and not self.one_graph
2820            and self.generic_context_manager_depth == 0
2821        )
2822
2823    def create_call_resume_at(self, inst):
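        """
        Codegen a call into a synthesized "resume function" that continues
        executing this frame from `inst` onward after a graph break.
        """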
2824        self.instruction_pointer = None
2825
2826        if inst.opname == "RETURN_VALUE":
2827            return [create_instruction("RETURN_VALUE")]
2828        elif inst.opname == "RETURN_CONST":
2829            return [create_instruction("RETURN_CONST", argval=inst.argval)]
2830
2831        reads = livevars_analysis(self.instructions, inst)
2832        all_argnames = tuple(
2833            k
2834            for k in self.symbolic_locals.keys()
2835            if k in reads and k not in self.cell_and_freevars()
2836        )
2837        # NOTE: do not use isinstance, since it realizes lazy VTs
2838        argnames = tuple(
2839            k
2840            for k in all_argnames
2841            if not type.__instancecheck__(NullVariable, self.symbolic_locals[k])
2842        )
2843        argnames_null = tuple(
2844            k
2845            for k in all_argnames
2846            if type.__instancecheck__(NullVariable, self.symbolic_locals[k])
2847        )
2848        if sys.version_info < (3, 12):
2849            assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
2850
2851        cg = PyCodegen(self)
2852
2853        # Handle inactive context variables.
2854        # The resume function assumes that context variables are the class, NOT the object.
2855        # e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
2856        stack_ctx_vars = []
2857        for i, var in enumerate(self.stack):
2858            if type.__instancecheck__(ContextWrappingVariable, var):
2859                ctx = cast(ContextWrappingVariable, var)
2860                target_values = (
2861                    () if ctx.target_values is None else tuple(ctx.target_values)
2862                )
2863                stack_ctx_vars.append((i, target_values))
2864                # Replace the current stack var with the context class
2865                ctx.reconstruct_type(cg)
2866                cg.extend_output(create_swap(len(self.stack) - i + 1))
2867                cg.append_output(create_instruction("POP_TOP"))
2868
2869        argnames_ctx_vars = []
2870        for name in argnames:
2871            if type.__instancecheck__(
2872                ContextWrappingVariable, var := self.symbolic_locals[name]
2873            ):
2874                ctx = cast(ContextWrappingVariable, var)
2875                target_values = (
2876                    () if ctx.target_values is None else tuple(ctx.target_values)
2877                )
2878                argnames_ctx_vars.append((name, target_values))
2879                # Replace the local with the context class
2880                ctx.reconstruct_type(cg)
2881                cg.append_output(create_instruction("STORE_FAST", argval=name))
2882
2883        # Python does not allow null to be an arg to a function, so
2884        # we remove nulls from the stack and restore them in the
2885        # prologue of the resume function
2886
2887        # sorted list of indices of nulls on the stack
2888        null_idxes: List[int] = []
2889        if sys.version_info >= (3, 11):
2890            # find indices of NullVariables
2891            for i, var in enumerate(self.stack):
2892                if type.__instancecheck__(NullVariable, var):
2893                    null_idxes.append(i)
2894            # generate bytecode to pop the nulls
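            # e.g. (illustrative) for a stack [fn, NULL, a, b] (bottom to top),
            # the NULL sits at depth 3 from the top, so we emit SWAP 2, SWAP 3
            # to rotate it to the top, then pop it, leaving [fn, a, b] in order.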
2895            null_cnt = 0
2896            for i, var in enumerate(reversed(self.stack)):
2897                if type.__instancecheck__(NullVariable, var):
2898                    for j in range(2, i + 2 - null_cnt):
2899                        cg.append_output(create_instruction("SWAP", arg=j))
2900                    cg.extend_output(cg.pop_null())
2901                    null_cnt += 1
2902
2903        # we popped all nulls from the stack at runtime,
2904        # so we should not count NullVariables
2905        stack_len = len(self.stack) - len(null_idxes)
2906        nargs = stack_len + len(argnames)
2907
2908        name = unique_id(f"__resume_at_{inst.offset}")
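        # e.g. a graph break at bytecode offset 52 yields a resume function
        # named like "__resume_at_52_1" (unique_id appends a counter).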
2909
2910        new_code: types.CodeType = ContinueExecutionCache.lookup(
2911            self.f_code,
2912            self.lineno,
2913            inst.offset,
2914            tuple(b.target.offset for b in self.block_stack),
2915            stack_len,
2916            argnames,
2917            argnames_null,
2918            tuple(b.resume_fn() for b in self.block_stack),
2919            tuple(stack_ctx_vars),
2920            tuple(argnames_ctx_vars),
2921            tuple(null_idxes),
2922        )
2923
2924        # Add original GraphModule context to the resume function to handle
2925        # the case of a graph break while tracing a GraphModule
2926        orig_graphmodule_maybe = code_context.get_context(self.f_code).get(
2927            "orig_graphmodule", lambda: None
2928        )()
2929        if orig_graphmodule_maybe is not None:
2930            code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
2931                orig_graphmodule_maybe
2932            )
2933
2934        if new_code.co_freevars:
2935            # expose code object for debugging purposes
2936            self.output.install_global_unsafe(name, new_code)
2937            cg.make_function_with_closure(name, new_code, True, stack_len)
2938        else:
2939            # This is safe: we pre-generate a unique name
2940            self.output.install_global_unsafe(
2941                name, types.FunctionType(new_code, self.f_globals, name)
2942            )
2943            cg.extend_output(cg.load_function_name(name, True, stack_len))
2944
2945        cg.extend_output([cg.create_load(k) for k in argnames])
2946        cg.extend_output(create_call_function(nargs, False))
2947        cg.append_output(create_instruction("RETURN_VALUE"))
2948        return cg.get_instructions()
2949
2950    def symbolic_locals_contain_module_class(self):
2951        for v in self.symbolic_locals.values():
2952            if isinstance(v, UserDefinedClassVariable) and issubclass(
2953                v.as_python_constant(), torch.nn.Module
2954            ):
2955                return True
2956        return False
2957
2958    def _return(self, inst):
2959        if (
2960            self.output.count_calls() == 0
2961            and not self.inconsistent_side_effects
2962            and not self.symbolic_locals_contain_module_class()
2963            and not self.export
2964        ):
2965            raise exc.SkipFrame("because no content in function call")
2966        self.instruction_pointer = None
2967        _step_logger()(
2968            logging.INFO,
2969            f"torchdynamo done tracing {self.f_code.co_name} ({inst.opname})",
2970        )
2971        log.debug("%s triggered compile", inst.opname)
2972        self.output.compile_subgraph(
2973            self,
2974            reason=GraphCompileReason(
2975                "return_value", [self.frame_summary()], graph_break=False
2976            ),
2977        )
2978        return_inst = (
2979            create_instruction("RETURN_VALUE")
2980            if inst.opname == "RETURN_VALUE"
2981            else create_instruction("RETURN_CONST", argval=inst.argval)
2982        )
2983        self.output.add_output_instructions([return_inst])
2984        raise ReturnValueOp
2985
2986    def RETURN_VALUE(self, inst):
2987        self._return(inst)
2988
2989    def RETURN_CONST(self, inst):
2990        self._return(inst)
2991
2992
2993if sys.version_info >= (3, 11):
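    # On 3.11+, a single BINARY_OP opcode covers all binary and inplace
    # operators; its oparg indexes dis._nb_ops, so e.g. (illustrative)
    # ("NB_ADD", "+") maps to the BINARY_ADD handler and
    # ("NB_INPLACE_ADD", "+=") to INPLACE_ADD.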
2994    _binary_op_lookup = [
2995        getattr(
2996            InstructionTranslator,
2997            opname[3:] if "INPLACE" in opname else f"BINARY_{opname[3:]}",
2998        )
2999        for opname, _ in dis._nb_ops  # type: ignore[attr-defined]
3000    ]
3001
3002
3003class InliningInstructionTranslator(InstructionTranslatorBase):
3004    """Trace and inline a called method"""
3005
3006    symbolic_result: Optional[TensorVariable]
3007
3008    @classmethod
3009    def inline_call(cls, parent, func, args, kwargs):
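        # Graph breaks hit while inlining are tallied under "inline_call"
        # rather than "unimplemented", keeping top-level stats attributable.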
3010        with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
3011            return cls.inline_call_(parent, func, args, kwargs)
3012
3013    @staticmethod
3014    def check_inlineable(func):
3015        if func.has_self():
3016            unimplemented("inline with __self__")
3017
3018        result = trace_rules.check_verbose(func, is_inlined_call=True)
3019        if result.skipped:
3020            from torch._dynamo.variables.misc import produce_trampoline_autograd_apply
3021
3022            # _origin marks this as coming from an internal dynamo known function that is safe to
3023            # trace through.
3024            if hasattr(getattr(func, "fn", None), "_origin") and func.fn._origin in [
3025                produce_trampoline_autograd_apply,
3026            ]:
3027                # Known sound
3028                return trace_rules.SkipResult(
3029                    False, "allowlist in dynamo known function"
3030                )
3031            fn_qualname = func.fn.__qualname__ if hasattr(func, "fn") else ""
3032            unimplemented(
3033                f"'inline in skipfiles: {fn_qualname} | {func.get_name()} {func.get_filename()}, {result.reason}'"
3034            )
3035
3036        if isinstance(func, UserFunctionVariable) and inspect.getattr_static(
3037            func.get_function(), "_torchdynamo_disable", False
3038        ):
3039            unimplemented(
3040                f"call torch._dynamo.disable() wrapped function {func.get_function()}"
3041            )
3042        else:
3043            return result
3044
3045    @staticmethod
3046    def inline_call_(
3047        parent, func: VariableTracker, args: List[VariableTracker], kwargs
3048    ):
3049        if isinstance(func, SkipFunctionVariable):
3050            unimplemented("inline with functions in skip files")
3051        assert isinstance(
3052            func,
3053            (UserFunctionVariable, NestedUserFunctionVariable),
3054        )
3055        result = InliningInstructionTranslator.check_inlineable(func)
3056        assert result.skipped is False
3057        try:
3058            sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
3059        except TypeError as e:
3060            # Wrap the generic TypeError raised during bind_args() in the internal ArgsMismatchError with detailed info
3061            raise ArgsMismatchError(  # noqa: B904
3062                "{reason}.\n  func = {func}, args = {args}, kwargs = {kwargs}".format(
3063                    reason=str(e),
3064                    func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}",
3065                    args=[arg.python_type() for arg in args],
3066                    kwargs=kwargs,
3067                ),
3068            )
3069
3070        for v in itertools.chain(sub_locals.values(), closure_cells.values()):
3071            if not isinstance(v, VariableTracker):
3072                unimplemented(f"unconverted arg {v}")
3073
3074        code: types.CodeType = func.get_code()
3075        if code.co_name in ("__setitem__", "__setattr__") and not (
3076            args
3077            and isinstance(
3078                args[0],
3079                (variables.CustomizedDictVariable, variables.UserDefinedObjectVariable),
3080            )
3081        ):
3082            unimplemented(f"inline {code.co_name}")
3083
3084        suffix = ""
3085        # TODO: mlazos, add support for enabling multiple artifact logs
3086        # with a single alias
3087        if torch._logging._internal.log_state.is_artifact_enabled("bytecode"):
3088            suffix = f"\n{dis.Bytecode(code).dis()}"
3089        if sys.version_info >= (3, 11):
3090            cur_inst = parent.current_instruction
3091            parent_code = parent.f_code
3092            header = parent.get_line_of_code_header(lineno=cur_inst.positions.lineno)
3093
3094            def get_trace_call_log_str():
3095                line = get_instruction_source_311(parent_code, cur_inst).rstrip()
3096                return f"TRACE inlined call {code.co_name} from {header}\n{line}"
3097
3098            trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
3099        log.debug("INLINING %s%s, %s", code, suffix, result.reason)
3100
3101        # Detect inlined GraphModule calls in order to propagate node metadata
3102        # by checking if the first argument (self) is a variable tracking a GraphModule.
3103        if args and isinstance(args[0], NNModuleVariable):
3104            module = parent.output.get_submodule(args[0].module_key)
3105            if isinstance(module, torch.fx.GraphModule):
3106                # The inline call might not actually be a call to `forward`,
3107                # but it is enough to add a context for `forward` in case it is called.
3108                code_context.get_context(module.forward.__code__)[
3109                    "orig_graphmodule"
3110                ] = weakref.ref(module)
3111
3112        tracer: InliningInstructionTranslator
3113        if is_generator(code):
3114            tracer = InliningGeneratorInstructionTranslator(
3115                parent,
3116                code,
3117                sub_locals,
3118                parent.symbolic_globals,
3119                parent.symbolic_torch_function_mode_stack,
3120                closure_cells,
3121                func,
3122            )
3123        else:
3124            tracer = InliningInstructionTranslator(
3125                parent,
3126                code,
3127                sub_locals,
3128                parent.symbolic_globals,
3129                parent.symbolic_torch_function_mode_stack,
3130                closure_cells,
3131                func,
3132            )
3133
3134        strict_ctx: Any = contextlib.nullcontext()
3135        if parent.strict_checks_fn:
3136            strict_ctx = tracer.strict_translation_mode(parent.strict_checks_fn)
3137        try:
3138            with strict_ctx:
3139                tracer.run()
3140        except exc.ObservedException as e:
3141            msg = f"Observed exception DURING INLINING {code}: {e}"
3142            # TODO(anijain2305) - This works but we should probably have a
3143            # global/central data structure for the exception stack.
3144            parent.exn_vt_stack.extend(tracer.exn_vt_stack)
3145            log.debug(msg)
3146            # bubble up the exception to the parent frame.
3147            raise
3148        except exc.SkipFrame as e:
3149            msg = f"SKIPPED INLINING {code}: {e}"
3150            log.debug(msg)
3151            raise Unsupported(msg) from e
3152        except Exception:
3153            log.debug("FAILED INLINING %s", code)
3154            raise
3155        assert tracer.symbolic_result is not None
3156        func.export_freevars(parent, tracer)
3157
3158        if tracer.f_globals is parent.f_globals:
3159            # Merge symbolic_globals back if parent and child are in the same namespace
3160            parent.symbolic_globals.update(tracer.symbolic_globals)
3161
3162        parent.inconsistent_side_effects |= tracer.inconsistent_side_effects
3163
3164        log.debug("DONE INLINING %s", code)
3165
3166        if is_generator(code):
3167            assert isinstance(tracer, InliningGeneratorInstructionTranslator)
3168            assert tracer.symbolic_result.as_python_constant() is None
3169            return ListIteratorVariable(
3170                tracer.generated_items,
3171                mutable_local=MutableLocal(),
3172            )
3173        else:
3174            return tracer.symbolic_result
3175
3176    def __init__(
3177        self,
3178        parent: InstructionTranslatorBase,
3179        code: types.CodeType,
3180        symbolic_locals: Dict[str, VariableTracker],
3181        symbolic_globals: Dict[str, VariableTracker],
3182        symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"],
3183        closure_cells: Dict[str, VariableTracker],
3184        funcvar: BaseUserFunctionVariable,
3185    ) -> None:
3186        f_globals = funcvar.get_globals()  # type: ignore[attr-defined]
3187        f_builtins = f_globals["__builtins__"]
3188        if not isinstance(f_builtins, dict):
3189            f_builtins = f_builtins.__dict__
3190        instructions = cleaned_instructions(code)
3191        propagate_line_nums(instructions)
3192        super().__init__(
3193            output=parent.output,
3194            f_locals={},
3195            f_globals=f_globals,
3196            f_builtins=f_builtins,
3197            symbolic_locals=symbolic_locals,
3198            symbolic_globals=symbolic_globals,
3199            symbolic_torch_function_mode_stack=symbolic_torch_function_mode_stack,
3200            instructions=instructions,
3201            code_options={k: getattr(code, k) for k in get_code_keys()},
3202            f_code=code,
3203            export=parent.export,
3204            inline_depth=parent.inline_depth + 1,
3205            speculation_log=parent.speculation_log,
3206            distributed_state=parent.distributed_state,
3207        )
3208        self.parent = parent
3209        self.symbolic_result = None
3210        self.closure_cells = closure_cells
3211        self.nn_module_stack = parent.nn_module_stack.copy()
3212        self.one_graph = parent.one_graph
3213
3214    @property
3215    def fake_mode(self):
3216        return self.parent.fake_mode
3217
3218    def run_ctx_mgr(self):
3219        return TracingContext.current_frame(self.parent.frame_summary())
3220
3221    def STORE_DEREF(self, inst):
3222        if inst.argval in self.closure_cells:
3223            cell = self.closure_cells[inst.argval]
3224            val = self.pop()
3225            if isinstance(cell, ClosureVariable):
3226                if not self.output.is_root_tracer():
3227                    unimplemented(
3228                        "HigherOrderOperator: Mutating a variable not in the current scope (ClosureVariable)"
3229                    )
3230                self.output.root_tx.symbolic_locals[cell.name] = val
3231            else:
3232                self.output.side_effects.store_cell(cell, val)
3233        else:
3234            maybe_cell = self.symbolic_locals.get(inst.argval)
3235            if isinstance(
3236                maybe_cell,
3237                variables.NewCellVariable,
3238            ):
3239                self.output.side_effects.store_cell(
3240                    self.symbolic_locals[inst.argval], self.pop()
3241                )
3242            else:
3243                if (
3244                    maybe_cell is not None
3245                    and maybe_cell.source.name()
3246                    not in self.output.root_tx.mutated_closure_cell_contents
3247                ):
3248                    # Why is the source name here unique?
3249                    # mutated_closure_cell_contents is a per-frame
3250                    # concept, and sources identify, e.g., particular
3251                    # locals from the frame.  If you had two locals,
3252                    # they'll get different source names, and therefore
3253                    # differ here.
3254                    self.output.root_tx.mutated_closure_cell_contents.add(
3255                        maybe_cell.source.name()
3256                    )
3257                    raise exc.UnspecializeRestartAnalysis
3258                unimplemented("write to __closure__ while inlining")
3259
3260    def LOAD_DEREF(self, inst):
3261        if inst.argval in self.closure_cells:
3262            cell = self.closure_cells[inst.argval]
3263            if isinstance(cell, ClosureVariable):
3264                self.push(self.output.root_tx.symbolic_locals[cell.name])
3265            else:
3266                self.push(self.output.side_effects.load_cell(cell))
3267        else:
3268            maybe_sym_local = self.symbolic_locals.get(inst.argval, None)
3269            if isinstance(maybe_sym_local, variables.NewCellVariable):
3270                self.push(self.output.side_effects.load_cell(maybe_sym_local))
3271            else:
3272                super().LOAD_DEREF(inst)
3273
3274    def _load_closure(self, name):
3275        assert name in self.cell_and_freevars()
3276        if name in self.closure_cells:
3277            return self.closure_cells[name]
3278        else:
3279            return InlinedClosureVariable(name=name)
3280
3281    def check_replace_is_safe(self, oldvar):
3282        if not is_side_effect_safe(oldvar.mutable_local):
3283            unimplemented(
3284                "HigherOrderOperator: Mutating a variable not in the current scope (replace_all)"
3285            )
3286
3287    def should_compile_partial_graph(self):
3288        return False  # inlining functions is all-or-nothing
3289
3290    def create_call_resume_at(self, offset):
3291        unimplemented("can't resume while inlining")
3292
3293    def RETURN_VALUE(self, inst):
3294        self.symbolic_result = self.pop()  # type: ignore[assignment]
3295        self.instruction_pointer = None
3296        raise ReturnValueOp
3297
3298    def RETURN_CONST(self, inst):
3299        self.symbolic_result = self._load_const(inst)
3300        self.instruction_pointer = None
3301        raise ReturnValueOp
3302
3303    def get_globals_source_and_value(self, name):
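        # An inlined function may come from a different module than the root
        # frame, so resolve (and guard) its globals through that module, or
        # through an installed reference for unnamed scopes (e.g. exec'd code).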
3304        if "__name__" in self.f_globals:
3305            module_name = self.f_globals["__name__"]
3306            module_source = self.import_source(module_name)
3307            if "torch_package" in module_name:
3308                fglobals_value = torch.package.package_importer._package_imported_modules[module_name]  # type: ignore[assignment]
3309            else:
3310                fglobals_value = importlib.import_module(module_name)  # type: ignore[assignment]
3311            fglobals_vt = VariableBuilder(self, module_source)(fglobals_value)
3312            global_source = AttrSource(module_source, name)
3313        else:
3314            globals_name = self.output.install_global_by_id(
3315                "___unnamed_scope", self.f_globals
3316            )
3317            globals_source = GlobalSource(globals_name)
3318            fglobals_value = self.f_globals  # type: ignore[assignment]
3319            fglobals_vt = VariableBuilder(self, globals_source)(fglobals_value)
3320            global_source = GetItemSource(globals_source, name)  # type: ignore[assignment]
3321        return fglobals_value, fglobals_vt, global_source
3322
3323    def _load_global(self, inst):
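        # When inlining crosses module boundaries, global loads are resolved
        # (and guarded) through the owning module rather than the root frame's
        # global scope; see get_globals_source_and_value above.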
3324        if self.output.global_scope is self.f_globals:
3325            super()._load_global(inst)
3326        else:
3327            name = inst.argval
3328
3329            _, fglobals_vt, global_source = self.get_globals_source_and_value(name)
3330            if self.output.side_effects.has_pending_mutation_of_attr(fglobals_vt, name):
3331                self.push(self.output.side_effects.load_attr(fglobals_vt, name))
3332            else:
3333                try:
3334                    value = self.f_globals[name]
3335                except KeyError:
3336                    return self.load_builtin(inst)
3337
3338                self.push(VariableBuilder(self, global_source)(value))
3339
3340    def STORE_GLOBAL(self, inst):
3341        if self.f_globals is self.parent.f_globals:
3342            super().STORE_GLOBAL(inst)
3343        else:
3344            value = self.pop()
3345            if isinstance(value, RemovableHandleVariable):
3346                unimplemented("Storing handles in globals - NYI")
3347            name = inst.argval
3348            fglobals_value, fglobals_vt, _ = self.get_globals_source_and_value(name)
3349            self.output.side_effects.store_attr(fglobals_vt, name, value)
3350
3351
3352class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
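    """Trace an inlined generator, eagerly accumulating its yielded values"""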
3353    generated_items: List[VariableTracker]
3354
3355    def __init__(self, *args, **kwargs) -> None:
3356        super().__init__(*args, **kwargs)
3357        self.generated_items = []
3358
3359    def YIELD_VALUE(self, inst: Instruction):
3360        self.generated_items.append(self.pop())
3361        if len(self.generated_items) > MAX_ITERATOR_LIMIT:
3362            unimplemented(
3363                "Too many yield values in generator. Maybe you are inlining an infinite generator. "
3364                f"If not, please report a bug at {PT2_ISSUE_TRACKER_URL}",
3365            )
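        # A yield expression evaluates to whatever send() provides; an inlined
        # generator is only driven by iteration here, so that value is None.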
3366        self.push(ConstantVariable.create(None))
3367
3368    def GET_YIELD_FROM_ITER(self, inst):
3369        tos = self.stack[-1]
3370        if not isinstance(tos, ListIteratorVariable):
3371            self.pop()
3372            res = BuiltinVariable(iter).call_function(self, [tos], {})  # type: ignore[arg-type]
3373            self.push(res)
3374
3375    def YIELD_FROM(self, inst):
3376        assert len(self.stack) >= 2
3377        val = self.pop()
3378        tos = self.stack[-1]
3379        if not (isinstance(val, ConstantVariable) and val.value is None):
3380            # invoke send
3381            # Unreachable code - if you hit this, you are implementing generator support and have
3382            # lifted the `unimplemented("generator")` in frame conversion. This codepath handles
3383            # subgenerators and lines up with this line in Python 3.10
3384            # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L2599
3385            unimplemented("Unreachable sub-generator code")
3386
3387        try:
3388            val = tos.next_variable(self)
3389        except (StopIteration, exc.ObservedUserStopIteration) as ex:
3390            if isinstance(ex, exc.ObservedUserStopIteration):
3391                exc.handle_observed_exception(self)
3392
3393            # The iterator is exhausted. Stop the loop and return.
3394            self.pop()
3395            self.push(ConstantVariable.create(ex.value))
3396        else:
3397            self.push(val)
3398            # Add the value to yield into generated_items and replace the top of the stack with None
3399            self.YIELD_VALUE(inst)
3400
3401            # Repeat the YIELD_FROM instruction in the next eval loop
3402            assert (
3403                isinstance(self.instruction_pointer, int)
3404                and self.instruction_pointer > 0
3405            )
3406            self.instruction_pointer -= 1
3407
3408    def SEND(self, inst):
3409        assert len(self.stack) >= 2
3410        val = self.pop()
3411        tos = self.stack[-1]
3412        if isinstance(tos, ListIteratorVariable) or (
3413            isinstance(tos, UserDefinedObjectVariable)
3414            and isinstance(tos.value, collections.abc.Iterator)
3415        ):
3416            if isinstance(val, ConstantVariable) and val.value is None:
3417                try:
3418                    val = tos.next_variable(self)
3419                except (StopIteration, exc.ObservedUserStopIteration) as ex:
3420                    # To implement SEND, we have to look at the implementation
3421                    # when the iterator returns StopIteration. This translates to this code
3422                    # 3.11: https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2613-L2619
3423                    # 3.12: https://github.com/python/cpython/blob/3.12/Python/bytecodes.c#L863-L866
3424                    # The implementation is different in 3.11 and 3.12. In 3.12, we rely
3425                    # on END_SEND to clean up. In 3.11, SEND does the cleanup as well.
3426                    if sys.version_info < (3, 12):
3427                        self.pop()  # Python 3.12 uses new opcode END_SEND
3428                    self.push(ConstantVariable.create(ex.value))
3429                    self.jump(inst)
3430                else:
3431                    self.push(val)
3432            else:
3433                # invoke send
3434                # Unreachable code - if you hit this, you are implementing generator support and have
3435                # lifted the `unimplemented("generator")` in frame conversion. This codepath handles
3436                # subgenerators and lines up with this line in Python 3.11
3437                # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2597
3438                unimplemented("Unreachable sub-generator code")
3439        else:
3440            unimplemented(f"SEND {typestr(tos)}")
3441