1# mypy: allow-untyped-defs 2import collections 3import collections.abc 4import contextlib 5import copy 6import dataclasses 7import dis 8import functools 9import importlib 10import inspect 11import itertools 12import linecache 13import logging 14import operator 15import re 16import sys 17import threading 18import traceback 19import types 20import typing 21import weakref 22from typing import ( 23 Any, 24 Callable, 25 cast, 26 Deque, 27 Dict, 28 List, 29 Optional, 30 Set, 31 Tuple, 32 Type, 33 TYPE_CHECKING, 34 Union, 35) 36from unittest.mock import patch 37 38import torch 39import torch._logging 40from torch._guards import tracing, TracingContext 41 42from . import config, exc, logging as torchdynamo_logging, trace_rules, variables 43from .bytecode_analysis import ( 44 get_indexof, 45 JUMP_OPNAMES, 46 livevars_analysis, 47 propagate_line_nums, 48) 49from .bytecode_transformation import ( 50 cleaned_instructions, 51 create_call_function, 52 create_instruction, 53 create_jump_absolute, 54 create_swap, 55 get_code_keys, 56 Instruction, 57 is_generator, 58 unique_id, 59) 60from .code_context import code_context 61from .codegen import PyCodegen 62from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented, Unsupported 63from .funcname_cache import get_funcname 64from .guards import GuardBuilder, install_guard 65from .output_graph import GraphCompileReason, OutputGraph 66from .replay_record import DummyModule, ExecutionRecorder 67from .resume_execution import ContinueExecutionCache, ReenterWith 68from .source import ( 69 AttrSource, 70 GetItemSource, 71 GlobalSource, 72 GlobalWeakRefSource, 73 LocalSource, 74 Source, 75 TorchFunctionModeStackSource, 76) 77from .trace_rules import is_builtin_constant, is_forbidden 78from .utils import ( 79 counters, 80 get_fake_value, 81 get_instruction_source_311, 82 get_torch_function_mode_stack, 83 graph_break_dup_warning_checker, 84 istype, 85 LazyString, 86 proxy_args_kwargs, 87) 88from .variables.base import is_side_effect_safe, MutableLocal, typestr, VariableTracker 89from .variables.builder import VariableBuilder, wrap_fx_proxy 90from .variables.builtin import BuiltinVariable 91from .variables.constant import ConstantVariable 92from .variables.ctx_manager import ( 93 ContextWrappingVariable, 94 GenericContextWrappingVariable, 95 WithExitFunctionVariable, 96) 97from .variables.dicts import ConstDictVariable, SetVariable 98from .variables.functions import ( 99 BaseUserFunctionVariable, 100 NestedUserFunctionVariable, 101 SkipFunctionVariable, 102 UserFunctionVariable, 103 UserMethodVariable, 104) 105from .variables.iter import MAX_ITERATOR_LIMIT 106from .variables.lists import ( 107 BaseListVariable, 108 ListIteratorVariable, 109 ListVariable, 110 SliceVariable, 111 TupleVariable, 112) 113from .variables.misc import ( 114 ClosureVariable, 115 GetAttrVariable, 116 InlinedClosureVariable, 117 NullVariable, 118 PythonModuleVariable, 119 UnknownVariable, 120) 121from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable 122from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable 123 124 125if TYPE_CHECKING: 126 from .variables.torch_function import TorchFunctionModeVariable 127 128from .variables.user_defined import ( 129 RemovableHandleVariable, 130 UserDefinedClassVariable, 131 UserDefinedObjectVariable, 132) 133 134 135log = logging.getLogger(__name__) 136graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks") 137trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call") 138trace_source_log = torch._logging.getArtifactLogger(__name__, "trace_source") 139trace_bytecode_log = torch._logging.getArtifactLogger(__name__, "trace_bytecode") 140tls = threading.local() 141compare_op_handlers: Dict[str, Any] = { 142 k: BuiltinVariable(v).call_function for k, v in supported_comparison_ops.items() 143} 144handle_contains = BuiltinVariable(operator.contains).call_function 145handle_not = BuiltinVariable(operator.not_).call_function 146compare_op_handlers["in"] = lambda tx, args, _: handle_contains( 147 tx, [*reversed(args)], {} 148) 149compare_op_handlers["not in"] = lambda tx, args, _: handle_not( 150 tx, [handle_contains(tx, [*reversed(args)], {})], {} 151) 152 153 154PT2_ISSUE_TRACKER_URL = "https://github.com/pytorch/pytorch/issues/new?&labels=oncall%3A+pt2&projects=&template=pt2-bug-report.yml" 155 156 157@dataclasses.dataclass 158class SpeculationEntry: 159 filename: str 160 lineno: int 161 instruction_pointer: int 162 inst: Instruction # for debugging only 163 failed: bool = False 164 reason: Optional[GraphCompileReason] = None 165 166 def fail_and_restart_analysis(self): 167 """ 168 Start tracing of the current frame over again, and don't take this branch. 169 """ 170 self.failed = True 171 if self.reason is not None: 172 restart_reason = self.reason.reason 173 else: 174 restart_reason = "Unknown fail_and_restart_analysis" 175 raise exc.SpeculationRestartAnalysis(restart_reason=restart_reason) 176 177 178@dataclasses.dataclass 179class SpeculationLog: 180 """ 181 SpeculationLog replaces the prior copy_graphstate/restore_graphstate 182 checkpointing. Rather than saving/restoring state, we restart the 183 dynamo conversion process over from the beginning -- but when we 184 hit the start of the speculation that failed, we instead generate 185 a graph break. 186 """ 187 188 entries: List[SpeculationEntry] = dataclasses.field(default_factory=list) 189 index: int = 0 190 191 def restart(self): 192 self.index = 0 193 194 def clear(self): 195 self.entries.clear() 196 self.index = 0 197 198 def next( 199 self, filename: str, lineno: int, instruction_pointer, inst 200 ) -> SpeculationEntry: 201 """ 202 Lookup or create a SpeculationEntry() that is shared across 203 RestartAnalysis calls. Args are used only for debug checks. 204 """ 205 if len(self.entries) == self.index: 206 self.entries.append( 207 SpeculationEntry(filename, lineno, instruction_pointer, inst) 208 ) 209 entry = self.entries[self.index] 210 prev_entry_msg = "" 211 if self.index != 0: 212 prev_entry = self.entries[self.index - 1] 213 prev_entry_msg = ( 214 f"Previous instruction: {prev_entry.filename}:{prev_entry.lineno}" 215 f"({prev_entry.inst.opname} @ {prev_entry.instruction_pointer})\n" 216 ) 217 assert ( 218 entry.instruction_pointer == instruction_pointer 219 and entry.filename == filename 220 and entry.lineno == lineno 221 ), f""" 222SpeculationLog diverged at index {self.index} (log had {len(self.entries)} entries): 223- Expected: {entry.filename}:{entry.lineno} ({entry.inst.opname} at ip={entry.instruction_pointer}) 224- Actual: {filename}:{lineno} ({inst.opname} at ip={instruction_pointer}) 225{prev_entry_msg} 226There are two usual reasons why this may have occured: 227- When Dynamo analysis restarted, the second run took a different path than 228 the first. If this occurred, the previous instruction is the critical instruction that 229 behaved differently. 230- Speculation entries are only added under certain conditions (as seen in 231 step()), e.g., there must exist operators in the graph; those conditions may 232 have changed on restart. 233 234If this divergence was intentional, clear the speculation log before restarting (do NOT 235do this for graph breaks, you will infinite loop). 236 237Otherwise, please submit a bug report, ideally including the contents of TORCH_LOGS=+dynamo 238""" 239 self.index += 1 240 return entry 241 242 243@dataclasses.dataclass 244class LocalState: 245 input_sizes: Dict[str, List[int]] = dataclasses.field(default_factory=dict) 246 input_strides: Dict[str, List[int]] = dataclasses.field(default_factory=dict) 247 248 249# Mutable box that is shared across restarts 250@dataclasses.dataclass 251class DistributedState: 252 compile_pg: Any 253 local_state: LocalState 254 all_states: Optional[List[LocalState]] = None 255 256 257@functools.lru_cache(None) 258def _step_logger(): 259 return torchdynamo_logging.get_step_logger(log) 260 261 262@dataclasses.dataclass 263class BlockStackEntry: 264 # Current instruction that pushes something to block_stack 265 inst: Instruction 266 target: Instruction 267 stack_index: Optional[int] = None 268 with_context: Optional[ 269 Union[ContextWrappingVariable, GenericContextWrappingVariable] 270 ] = None 271 272 def can_restore(self): 273 return self.with_context is not None 274 275 def resume_fn(self): 276 assert self.stack_index is not None 277 if ( 278 self.with_context 279 and hasattr(self.with_context, "target_values") 280 and self.with_context.target_values 281 ): 282 return ReenterWith(self.stack_index, tuple(self.with_context.target_values)) 283 else: 284 return ReenterWith(self.stack_index) 285 286 def exit(self, tx): 287 assert self.with_context is not None 288 return self.with_context.exit(tx) 289 290 291class ReturnValueOp(Exception): 292 pass 293 294 295def stack_op(fn: typing.Callable[..., object]): 296 nargs = len(inspect.signature(fn).parameters) 297 fn_var = BuiltinVariable(fn) 298 299 @functools.wraps(fn) 300 def impl(self: "InstructionTranslator", inst: Instruction): 301 self.push(fn_var.call_function(self, self.popn(nargs), {})) 302 303 return impl 304 305 306def _detect_and_normalize_assert_statement( 307 self: "InstructionTranslatorBase", 308 truth_fn: typing.Callable[[object], bool], 309 push: bool, 310): 311 # Detect if this jump instruction is assert and normalize the assert 312 # by pushing dummy error message when nothing is given. 313 # 314 # Python 3.9 assertion is in following format: 315 # 18 POP_JUMP_IF_TRUE 28 316 # 20 LOAD_ASSERTION_ERROR 317 # 22 LOAD_CONST 3 ('Assert message') -> optional instruction 318 # 24 CALL_FUNCTION 1 -> optional instruction 319 # 26 RAISE_VARARGS 320 # 321 # Python 3.8 assertion is in following format: 322 # 18 POP_JUMP_IF_TRUE 28 323 # 20 LOAD_GLOBAL 0 (Assertion type) 324 # 22 LOAD_CONST 3 ('Assert message') -> optional instruction 325 # 24 CALL_FUNCTION 1 -> optional instruction 326 # 26 RAISE_VARARGS 1 327 328 if (truth_fn is not operator.truth) or push: 329 return False 330 331 assert isinstance(self.instruction_pointer, int) 332 current_instruction_pointer = self.instruction_pointer 333 inst = self.instructions[current_instruction_pointer] 334 # Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0 335 if sys.version_info < (3, 9): 336 if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError": 337 return False 338 else: 339 if inst.opname != "LOAD_ASSERTION_ERROR": 340 return False 341 342 current_instruction_pointer += 1 343 344 # Use dummy error message if its hard to extract 345 error_msg = "assertion error" 346 347 inst = self.instructions[current_instruction_pointer] 348 # DETECT RAISE_VARARGS or LOAD CONST 349 if inst.opname == "LOAD_CONST": 350 if not isinstance(inst.argval, str): 351 return False 352 error_msg = inst.argval 353 354 # if it is LOAD_CONSTANT, it must be followed by CALL_FUNCTION 355 # (PRECALL for Python 3.11, CALL for Python 3.12+) 356 current_instruction_pointer += 1 357 inst = self.instructions[current_instruction_pointer] 358 if inst.opname not in ("CALL_FUNCTION", "PRECALL", "CALL"): 359 return False 360 361 # for Python 3.11, PRECALL should be followed by CALL, then RAISE_VARARGS 362 # for Python != 3.11, CALL_FUNCTION/CALL should be followed by RAISE_VARARGS 363 current_instruction_pointer += 1 364 if inst.opname == "PRECALL": 365 current_instruction_pointer += 1 366 inst = self.instructions[current_instruction_pointer] 367 368 if inst.opname != "RAISE_VARARGS": 369 return False 370 371 self.push(ConstantVariable.create(error_msg)) 372 373 return True 374 375 376def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool): 377 def jump_graph_break(self, inst, value, extra_msg=""): 378 if not self.should_compile_partial_graph(): 379 unimplemented("should_compile_partial_graph=False") 380 # compile a partial subgraph prefix then jump into user code 381 if self.maybe_has_backedge(): 382 msg = ( 383 "Skipping frame because there is a graph break in a for/while loop\n" 384 f"{self.frame_summary()}" 385 ) 386 log.info(msg) 387 raise exc.SkipFrame(msg) 388 389 self.push(value) 390 log.debug("generic_jump triggered compile") 391 self.output.compile_subgraph( 392 self, 393 reason=GraphCompileReason( 394 f"generic_jump {typestr(value)}{extra_msg}", [self.frame_summary()] 395 ), 396 ) 397 self.pop() 398 399 if_next = self.create_call_resume_at(self.next_instruction) 400 if push: 401 self.push(value) 402 if_jump = self.create_call_resume_at(inst.target) 403 404 if sys.version_info >= (3, 13): 405 # 3.13 requires stack[-1] to be bool type 406 self.output.add_output_instructions([create_instruction("TO_BOOL")]) 407 408 self.output.add_output_instructions( 409 [create_instruction(inst.opname, target=if_jump[0])] + if_next + if_jump 410 ) 411 412 def inner(self: "InstructionTranslatorBase", inst: Instruction): 413 value: VariableTracker = self.pop() 414 if ( 415 config.rewrite_assert_with_torch_assert 416 and _detect_and_normalize_assert_statement(self, truth_fn, push) 417 ): 418 error_msg: VariableTracker = self.pop() 419 # Skip over things like `assert True` 420 if value.is_python_constant(): 421 if bool(value.as_python_constant()): 422 return self.jump(inst) 423 else: 424 jump_graph_break(self, inst, value) 425 426 # TODO maybe should respect DtoH sync intention of users later?? 427 # Manually insert torch._assert_async instead of python assert and jump over 428 # assert related instructions as we don't need them anymore. 429 430 # if we see Tensor as assert statement, no need to call scalar_tensor 431 if isinstance(value, TensorVariable): 432 self.output.create_proxy( 433 "call_function", 434 torch._assert_async, 435 *proxy_args_kwargs((value, error_msg), {}), 436 ) 437 self.jump(inst) 438 return 439 440 if isinstance(value, SymNodeVariable): 441 # if the assertion is normal shape expression. 442 # just install guard and bail out. 443 sym_expr = value.sym_num 444 if not isinstance(sym_expr, torch.SymBool): 445 sym_expr = sym_expr != 0 446 447 result = torch.fx.experimental.symbolic_shapes.expect_true(sym_expr) 448 if not result: 449 unimplemented( 450 "Assertion failed on symbolic shapes. Did you make sure eager mode succeeds?" 451 ) 452 self.jump(inst) 453 return 454 455 scalar_to_tensor_proxy = self.output.create_proxy( 456 "call_function", torch.scalar_tensor, *proxy_args_kwargs((value,), {}) 457 ) 458 459 scalar_to_tensor = wrap_fx_proxy( 460 self, 461 scalar_to_tensor_proxy, 462 example_value=get_fake_value(scalar_to_tensor_proxy.node, self), 463 ) 464 465 self.output.create_proxy( 466 "call_function", 467 torch._assert_async, 468 *proxy_args_kwargs((scalar_to_tensor, error_msg), {}), 469 ) 470 self.jump(inst) 471 return 472 473 if value.is_python_constant(): 474 if truth_fn(value.as_python_constant()): 475 if push: 476 self.push(value) 477 self.jump(inst) 478 elif ( 479 isinstance(value, (TensorVariable)) and self.should_compile_partial_graph() 480 ): 481 jump_graph_break(self, inst, value) 482 elif isinstance(value, NNModuleVariable): 483 # Equivalent of "self.nn_module is not None" 484 mod = self.output.get_submodule(value.module_key) 485 if truth_fn(mod): 486 if push: 487 self.push(value) 488 self.jump(inst) 489 elif isinstance(value, UnspecializedNNModuleVariable): 490 mod = value.value 491 if truth_fn(mod): 492 if push: 493 self.push(value) 494 self.jump(inst) 495 elif isinstance(value, UserDefinedObjectVariable): 496 try: 497 x = value.var_getattr(self, "__bool__") # type: ignore[arg-type] 498 except exc.ObservedAttributeError: 499 exc.handle_observed_exception(self) 500 # if __bool__ is missing, trying __len__ to infer a truth value. 501 try: 502 x = value.var_getattr(self, "__len__") # type: ignore[arg-type] 503 except exc.ObservedAttributeError: 504 exc.handle_observed_exception(self) 505 x = None 506 507 # __bool__ or __len__ is function 508 if isinstance(x, UserMethodVariable): 509 result = x.call_function(self, [], {}) # type: ignore[arg-type] 510 if isinstance(result, ConstantVariable) and isinstance( 511 result.value, (bool, int) 512 ): 513 if truth_fn(result.value): 514 if push: 515 self.push(value) 516 self.jump(inst) 517 else: 518 unimplemented( 519 "generic_jump on UserDefined with __bool__ returning non-constant" 520 ) 521 # __bool__ or __len__ is non-function or not existed in the user defined object 522 else: 523 if truth_fn(True): 524 if push: 525 self.push(value) 526 self.jump(inst) 527 elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence( 528 self 529 ): 530 if truth_fn(len(value.unpack_var_sequence(self))): 531 if push: 532 self.push(value) 533 self.jump(inst) 534 elif isinstance(value, SymNodeVariable): 535 try: 536 eval_result = value.evaluate_expr(self.output) 537 except exc.UserError as e: 538 if self.should_compile_partial_graph(): 539 return jump_graph_break(self, inst, value, extra_msg=f"\n{e}") 540 raise 541 if truth_fn(eval_result): 542 if push: 543 self.push(value) 544 self.jump(inst) 545 elif isinstance(value, variables.BackwardHookVariable): 546 if truth_fn(True): 547 if push: 548 self.push(value) 549 self.jump(inst) 550 else: 551 from .source import is_constant_source 552 553 if value.source is not None and is_constant_source(value.source): 554 if truth_fn(value.get_real_value()): # type: ignore[attr-defined] 555 if push: 556 self.push(value) 557 self.jump(inst) 558 else: 559 # TODO link the torch.cond doc later 560 raise exc.UserError( 561 exc.UserErrorType.DYNAMIC_CONTROL_FLOW, 562 "Dynamic control flow is not supported at the moment. Please use " 563 "functorch.experimental.control_flow.cond to explicitly capture the control flow.", 564 case_name="cond_operands", 565 ) 566 567 return inner 568 569 570explain = False 571 572 573def break_graph_if_unsupported(*, push): 574 def decorator(inner_fn): 575 @functools.wraps(inner_fn) 576 def wrapper(self: "InstructionTranslatorBase", inst: Instruction): 577 speculation = self.speculate() 578 if speculation.failed: 579 assert speculation.reason is not None 580 return handle_graph_break(self, inst, speculation.reason) 581 try: 582 return inner_fn(self, inst) 583 except Unsupported as excp: 584 if self.generic_context_manager_depth > 0: 585 # We don't support graph break under GenericContextWrappingVariable, 586 # If there is, we roll back to the checkpoint and fall back. 587 excp.remove_from_stats() 588 unimplemented("Graph break under GenericContextWrappingVariable") 589 590 if isinstance(excp, exc.UncapturedHigherOrderOpError): 591 raise 592 593 if not self.should_compile_partial_graph(): 594 raise 595 596 user_stack = excp.real_stack 597 # TODO: Also report the traceback from the parent frame 598 try: 599 frame_loc = (user_stack[-1].filename, user_stack[-1].lineno) 600 except IndexError: 601 # first instruction 602 code_options = self.code_options 603 frame_loc = ( 604 code_options["co_filename"], 605 code_options["co_firstlineno"], 606 ) 607 # torch._dynamo.explain() formats this a little nicer, and presents a slightly 608 # more actionable user code pointer 609 if ( 610 graph_break_log.isEnabledFor(logging.DEBUG) 611 and not explain 612 and graph_break_dup_warning_checker.add(frame_loc) 613 ): 614 user_stack_formatted = "".join(traceback.format_list(user_stack)) 615 # This log line is exercised from 616 # python test/dynamo/test_exc.py -k test_graph_break_log 617 graph_break_log.debug( 618 "Graph break: from user code at:\n%s", 619 user_stack_formatted, 620 exc_info=True, 621 ) 622 else: 623 # This log line MUST NOT contain the string "Graph break", 624 # exercised by 625 # python test/dynamo/test_misc.py -k test_duplicate_graph_break_log 626 log.debug( 627 "Unsupported break in user code at %s:%s (details suppressed)", 628 *frame_loc, 629 ) 630 631 if self.maybe_has_backedge(): 632 msg = ( 633 "Skipping frame because there is a graph break in a for/while loop\n" 634 f"{self.frame_summary()}" 635 ) 636 log.info(msg) 637 raise exc.SkipFrame(msg) from excp 638 639 excp.remove_from_stats() 640 excp.add_to_stats("graph_break") 641 speculation.reason = GraphCompileReason(excp.msg, user_stack) 642 speculation.fail_and_restart_analysis() 643 644 def handle_graph_break( 645 self: "InstructionTranslatorBase", 646 inst: Instruction, 647 reason: GraphCompileReason, 648 ): 649 self.output.compile_subgraph(self, reason=reason) 650 cg = PyCodegen(self) 651 cleanup: List[Instruction] = [] 652 # Reconstruct the context variable CLASS in the block stack 653 for b in self.block_stack: 654 assert b.with_context is not None 655 assert isinstance(b.with_context, ContextWrappingVariable) 656 b.with_context.reconstruct_type(cg) 657 cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup)) 658 self.output.add_output_instructions(cg.get_instructions()) 659 del cg 660 661 if sys.version_info >= (3, 11) and inst.opname == "CALL": 662 kw_names = ( 663 self.kw_names.as_python_constant() 664 if self.kw_names is not None 665 else () 666 ) 667 if len(kw_names) > 0: 668 # KW_NAMES no longer used in 3.13 669 assert sys.version_info < (3, 13) 670 self.output.add_output_instructions( 671 [create_instruction("KW_NAMES", argval=kw_names)] 672 ) 673 self.output.add_output_instructions( 674 create_call_function(inst.arg, False) 675 ) 676 else: 677 # copy instruction, but without exception table data 678 assert inst.target is None 679 inst_copy = copy.copy(inst) 680 inst_copy.exn_tab_entry = None 681 self.output.add_output_instructions([inst_copy]) 682 683 self.output.add_output_instructions(cleanup) 684 685 if ( 686 sys.version_info >= (3, 11) 687 and sys.version_info < (3, 12) 688 and inst.opname == "CALL" 689 ): 690 # stack effect for PRECALL + CALL is split between the two instructions 691 stack_effect = dis.stack_effect( 692 dis.opmap["PRECALL"], inst.arg 693 ) + dis.stack_effect(dis.opmap["CALL"], inst.arg) 694 else: 695 stack_effect = dis.stack_effect(inst.opcode, inst.arg) 696 self.popn(push - stack_effect) 697 698 for _ in range(push): 699 self.push(UnknownVariable()) 700 self.output.add_output_instructions( 701 self.create_call_resume_at(self.next_instruction) 702 ) 703 704 return wrapper 705 706 return decorator 707 708 709class BytecodeDistpatchTableMeta(type): 710 """Installs a `cls.dispatch_table` on every subclass to speed up calls to self.OPCODE()""" 711 712 def __init__(cls, name, bases, dct) -> None: 713 super().__init__(name, bases, dct) 714 715 def _missing(opname, *args): 716 unimplemented(f"missing: {opname}") 717 718 dispatch_table = { 719 op: getattr(cls, opname, functools.partial(_missing, opname)) 720 for opname, op in dis.opmap.items() 721 } 722 cls.dispatch_table = [dispatch_table.get(i) for i in range(2**8)] 723 724 725class InstructionTranslatorBase( 726 metaclass=BytecodeDistpatchTableMeta, 727): 728 output: OutputGraph 729 symbolic_locals: Dict[str, VariableTracker] 730 symbolic_globals: Dict[str, VariableTracker] 731 symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"] 732 stack: List[VariableTracker] 733 instruction_pointer: Optional[int] 734 current_instruction: Instruction 735 block_stack: List[BlockStackEntry] 736 lineno: int 737 kw_names: Optional[ConstantVariable] 738 accept_prefix_inst: bool 739 prefix_insts: List[Instruction] 740 inline_depth: int 741 inconsistent_side_effects: bool 742 current_speculation: Optional[SpeculationEntry] 743 dispatch_table: List[Any] 744 exn_vt_stack: List[VariableTracker] 745 exec_recorder: Optional[ExecutionRecorder] 746 strict_checks_fn: Optional[Callable[[VariableTracker], bool]] 747 748 def mark_inconsistent_side_effects(self): 749 """ 750 InstructionTranslator has encountered instructions which may cause 751 dynamo to see a different version of history from eager 752 See: https://github.com/pytorch/pytorch/issues/110765 753 """ 754 self.inconsistent_side_effects = True 755 756 def maybe_has_backedge(self): 757 # This function employs a heuristic. It does not reliably detect a backedge. 758 # The heuristic is straightforward: starting from the current instruction and 759 # continuing to the end, if any jump instruction targets an instruction before 760 # the current one, there might be a backedge. 761 762 # Python 3.12 introduced changes to bytecode that group common paths in 763 # blockstacks (with or try...else) and allow for early returns. Consequently, 764 # there can be multiple RETURN_VALUE instructions. Another heuristic is to 765 # halt detection upon encountering the first RETURN_VALUE or RETURN_CONST. 766 767 # These heuristics can result in both false positives and negatives, but 768 # in either case, the Dynamo code remains valid. For false positives 769 # (where an edge is incorrectly marked as a backedge), Dynamo will 770 # perform a SkipFrame instead of potentially applying optimizations. For 771 # false negatives (where an edge that should be marked as a backedge 772 # isn't), multiple graphs may be generated if there's a break in the 773 # graph during a for loop. In general, its better to have fewer false 774 # negatives so that Dynamo does not skip the whole frame. 775 776 cur_offset = self.current_instruction.offset 777 assert self.instruction_pointer is not None 778 for inst in self.instructions[self.instruction_pointer :]: 779 if inst.opname in ("RETURN_VALUE", "RETURN_CONST"): 780 return False 781 if inst.opname in JUMP_OPNAMES: 782 jump_offset = inst.argval 783 if jump_offset < cur_offset: 784 return True 785 return False 786 787 def cell_and_freevars(self): 788 if not hasattr(self, "_cell_and_freevars"): 789 self._cell_and_freevars = tuple( 790 self.code_options["co_cellvars"] or [] 791 ) + tuple(self.code_options["co_freevars"] or []) 792 793 # An inlined function might depend on the freevar of the parent 794 # function. So, recursively obtain parent cell and freevars. 795 if isinstance(self, InliningInstructionTranslator): 796 self._cell_and_freevars += self.parent.cell_and_freevars() 797 return self._cell_and_freevars 798 799 def prune_dead_locals(self): 800 reads = livevars_analysis(self.instructions, self.current_instruction) 801 # implicit use by super() 802 # reads = reads | {"__class__"} 803 # output variables? 804 reads = reads | set(self.cell_and_freevars()) 805 self.symbolic_locals = { 806 k: v for k, v in self.symbolic_locals.items() if k in reads 807 } 808 self.output.side_effects.prune_dead_object_new(self) 809 810 def call_function( 811 self, 812 fn: VariableTracker, 813 args: List[VariableTracker], 814 kwargs: Dict[str, VariableTracker], 815 ): 816 assert isinstance(fn, VariableTracker) 817 assert isinstance(args, list) 818 assert isinstance(kwargs, dict) 819 assert all( 820 isinstance(x, VariableTracker) 821 for x in itertools.chain(args, kwargs.values()) 822 ) 823 inner_fn = None 824 if hasattr(fn, "value"): 825 inner_fn = fn.value 826 if hasattr(fn, "fn"): 827 inner_fn = fn.fn 828 if inner_fn and callable(inner_fn) and is_forbidden(inner_fn): 829 raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}") 830 self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type] 831 832 def inline_user_function_return(self, fn, args, kwargs): 833 """ 834 A call to some user defined function by inlining it. 835 """ 836 return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) 837 838 def get_line_of_code_header(self, lineno=None): 839 if lineno is None: 840 lineno = self.lineno 841 inline_depth_str = ( 842 f" (inline depth: {self.inline_depth})" if self.inline_depth > 0 else "" 843 ) 844 funcname = get_funcname(self.f_code.co_filename, lineno) 845 funcname_str = "" if funcname is None else f" ({funcname})" 846 return f"{self.f_code.co_filename}:{lineno} in {self.f_code.co_name}{funcname_str}{inline_depth_str}" 847 848 def get_log_starts_line_log_str(self): 849 log_str = f"TRACE starts_line {self.get_line_of_code_header()}\n" 850 line = linecache.getline(self.f_code.co_filename, self.lineno).rstrip() 851 log_str += f" {line}" 852 return log_str 853 854 def starts_line(self, lineno): 855 if self.lineno == lineno: 856 return 857 self.lineno = lineno 858 TracingContext.set_current_loc( 859 self.f_code.co_filename, lineno, self.f_code.co_name 860 ) 861 from torch._logging.structured import dump_file 862 863 dump_file(self.f_code.co_filename) 864 if trace_source_log.isEnabledFor(logging.DEBUG): 865 trace_source_log.debug("%s", LazyString(self.get_log_starts_line_log_str)) 866 867 def step(self): 868 """Process exactly one instruction, return False we should exit""" 869 ip = self.instruction_pointer 870 if ip is None: 871 return False 872 self.current_instruction = inst = self.instructions[ip] 873 self.instruction_pointer = ip + 1 874 875 if inst.starts_line: 876 self.starts_line(inst.starts_line) 877 878 if ( 879 not self.stack 880 and self.should_compile_partial_graph() 881 and self.is_non_empty_graph() 882 ): 883 self.current_speculation = self.speculate() 884 if self.current_speculation.failed: 885 return self.step_graph_break(inst) 886 887 if trace_bytecode_log.isEnabledFor(logging.DEBUG): 888 trace_bytecode_log.debug( 889 "TRACE %s %s %s", inst.opname, inst.argval, self.stack 890 ) 891 892 self.update_block_stack(inst) 893 894 try: 895 self.dispatch_table[inst.opcode](self, inst) 896 return not self.output.should_exit 897 except exc.ObservedException as e: 898 self.exception_handler(e) 899 return True 900 except ReturnValueOp: 901 return False 902 except Unsupported: 903 if self.current_speculation is None: 904 log.debug("empty checkpoint") 905 raise 906 log.debug("step triggered compile", exc_info=True) 907 908 self.current_speculation.fail_and_restart_analysis() 909 910 if sys.version_info >= (3, 11): 911 912 def update_block_stack(self, inst): 913 # 3.11+ no longer uses a block stack, but we still keep track of one 914 # so that we know which contexts are currently active. 915 # For our purposes, all exception table entries with the same target 916 # are considered to be part of the same "block". 917 # NOTE: we only keep track of with blocks that are not contained in try blocks. 918 # This is because we will not create continuation functions on graph breaks in try blocks, 919 # but we may for with blocks. We do not push blocks here since 920 # with blocks are pushed when handling BEFORE_WITH. 921 entry = inst.exn_tab_entry 922 if entry: 923 # Detect when we have exited the top with block. 924 # The with blocks on the block stack are not enclosed in try 925 # blocks, so a with block's cleanup code should be in the 926 # previous with block (if any). 927 if ( 928 len(self.block_stack) >= 2 929 and entry.target is not self.block_stack[-1].target 930 and entry.target is self.block_stack[-2].target 931 ): 932 # exit the current block 933 self.block_stack.pop() 934 else: 935 # no longer in any block 936 # It is possible for NOPs to be between two instructions 937 # in the same block, but the NOPs are not covered by an 938 # exception table entry. In this case, assume that we 939 # are still in the same block. 940 # In 3.12+, JUMP_BACKWARD might also not be covered by 941 # an exception table entry, so we also assume that we 942 # are still in the same block. It is probably safe to do 943 # this in 3.11, even though we haven't encountered this case before. 944 if self.block_stack and inst.opname not in ("NOP", "JUMP_BACKWARD"): 945 # If we really escape from a block and the current 946 # instruction is not in another block, then there 947 # should be no other nested blocks that we are in. 948 assert len(self.block_stack) == 1 949 self.block_stack.pop() 950 951 else: 952 953 def update_block_stack(self, inst): 954 pass 955 956 @property 957 def next_instruction(self): 958 return self.instructions[self.instruction_pointer] # type: ignore[index] 959 960 def step_graph_break(self, continue_inst): 961 # generate code from checkpoint 962 assert not self.output.output_instructions 963 assert self.current_speculation is not None 964 self.output.compile_subgraph( 965 self, 966 partial_convert=True, 967 reason=GraphCompileReason("step_unsupported", [self.frame_summary()]), 968 ) 969 self.output.add_output_instructions( 970 [create_jump_absolute(continue_inst)] + self.instructions 971 ) 972 973 def run_ctx_mgr(self): 974 # NB: Don't push the top level frame summary; set_current_loc will 975 # take care of it. However, DO make sure we attach real_stack to 976 # exceptions 977 return TracingContext.current_frame(None) 978 979 def run(self): 980 with self.run_ctx_mgr(): 981 try: 982 self.output.push_tx(self) 983 while self.step(): 984 pass 985 except BackendCompilerFailed: 986 raise 987 except Exception as e: 988 if self.exec_recorder: 989 e.exec_record = self.exec_recorder.get_record() # type: ignore[attr-defined] 990 raise 991 finally: 992 self.output.pop_tx() 993 # Cleanup the outputGraph to delete the held tensors. We perform the 994 # cleanup only for InstructionTranslator and not 995 # InliningInstructionTranslator. The InliningInstructionTranslator 996 # mutates the output object and is restored to original state if 997 # there was an exception. 998 if isinstance(self, InstructionTranslator): 999 self.output.cleanup() 1000 1001 def push(self, val: Optional[VariableTracker], name: Any = None): 1002 assert val is None or isinstance( 1003 val, VariableTracker 1004 ), f"push expects VariableTracker, got {typestr(val)}" 1005 self.stack.append(val) # type: ignore[arg-type] 1006 if sys.version_info >= (3, 13): 1007 self.name_stack.append(name) 1008 assert len(self.stack) == len(self.name_stack) 1009 1010 def push_many(self, vals: List[VariableTracker]): 1011 for val in vals: 1012 self.push(val) 1013 1014 def pop(self) -> VariableTracker: 1015 if sys.version_info >= (3, 13): 1016 assert len(self.stack) == len(self.name_stack) 1017 self.name_stack.pop() 1018 return self.stack.pop() 1019 1020 def popn(self, n: int) -> List[VariableTracker]: 1021 return [*reversed([self.pop() for _ in range(n)])] 1022 1023 def _load_closure(self, name): 1024 return ClosureVariable(name=name) 1025 1026 def _load_fast(self, name): 1027 if self.exec_recorder and name in self.f_locals: 1028 self.exec_recorder.add_local_var(name, self.f_locals[name]) 1029 1030 try: 1031 self.push(self.symbolic_locals[name].unwrap(), name=name) 1032 except KeyError: 1033 if sys.version_info >= (3, 13) and name in self.cell_and_freevars(): 1034 # 3.13 merged LOAD_CLOSURE into LOAD_FAST 1035 # If we fail to LOAD_FAST, then we probably should have done LOAD_CLOSURE. 1036 # Closure variable creation is actually done in SET_FUNCTION_ATTRIBUTE, 1037 # but we'll do it again here so that we don't need to push a dummy variable. 1038 # We shouldn't actually be doing anything with this variable anyway. 1039 self.push(self._load_closure(name), name=name) 1040 elif name.startswith("."): 1041 try: 1042 # This happens in dict/list comprehensions 1043 new_name = name.replace(".", "implicit") 1044 self.push(self.symbolic_locals[new_name], name=new_name) 1045 except KeyError: 1046 unimplemented("undefined LOAD_FAST (implicit)") 1047 else: 1048 unimplemented("undefined LOAD_FAST") 1049 1050 # for continuation functions 1051 if name.startswith("___stack"): 1052 self.symbolic_locals.pop(name) 1053 1054 def LOAD_FAST(self, inst): 1055 self._load_fast(inst.argval) 1056 1057 def LOAD_DEREF(self, inst): 1058 assert inst.argval in self.cell_and_freevars() 1059 1060 if self.exec_recorder and inst.argval in self.f_locals: 1061 self.exec_recorder.add_local_var(inst.argval, self.f_locals[inst.argval]) 1062 1063 if inst.argval not in self.symbolic_locals: 1064 unimplemented(f"undefined LOAD_DEREF {inst.argval}") 1065 self.push(self.symbolic_locals[inst.argval]) 1066 1067 def _store_fast(self, name): 1068 loaded_vt = self.pop() 1069 loaded_vt.set_name_hint(name) 1070 self.symbolic_locals[name] = loaded_vt 1071 1072 def STORE_FAST(self, inst): 1073 self._store_fast(inst.argval) 1074 1075 def DELETE_FAST(self, inst): 1076 del self.symbolic_locals[inst.argval] 1077 1078 STORE_DEREF = STORE_FAST 1079 1080 def LOAD_CLOSURE(self, inst): 1081 self.push(self._load_closure(inst.argval)) 1082 1083 def _load_const(self, inst): 1084 i = inst.arg 1085 if i is None: 1086 return ConstantVariable.create(value=inst.argval) 1087 val = self._constants_cache[i] 1088 if not val: 1089 self._constants_cache[i] = val = ConstantVariable.create(value=inst.argval) 1090 return val 1091 1092 def LOAD_CONST(self, inst): 1093 self.push(self._load_const(inst)) 1094 1095 def _load_global(self, inst): 1096 name = inst.argval 1097 1098 if self.exec_recorder: 1099 if name in self.f_globals: 1100 self.exec_recorder.add_global_var(name, self.f_globals[name]) 1101 else: 1102 assert name in self.f_builtins 1103 self.exec_recorder.builtins[name] = self.f_builtins[name] 1104 1105 if name in self.symbolic_globals: 1106 variable = self.output.side_effects[self.symbolic_globals[name]] 1107 self.push(self.output.side_effects.load_global(variable, name)) 1108 return 1109 1110 try: 1111 value = self.f_globals[name] 1112 except KeyError: 1113 return self.load_builtin(inst) 1114 1115 source = GlobalSource(name) 1116 self.push(VariableBuilder(self, source)(value)) 1117 1118 @functools.cached_property 1119 def nn_modules_globals_vt(self): 1120 module_name = "torch.nn.modules.module" 1121 module_source = self.import_source(module_name) 1122 fglobals_value = importlib.import_module(module_name) # type: ignore[assignment] 1123 return VariableBuilder(self, module_source)(fglobals_value) 1124 1125 def LOAD_GLOBAL(self, inst): 1126 if sys.version_info >= (3, 11) and sys.version_info < (3, 13) and inst.arg % 2: 1127 self.PUSH_NULL(inst) 1128 self._load_global(inst) 1129 if sys.version_info >= (3, 13) and inst.arg % 2: 1130 self.PUSH_NULL(inst) 1131 1132 def STORE_GLOBAL(self, inst): 1133 value = self.pop() 1134 name = inst.argval 1135 source = GlobalSource(name) 1136 if name not in self.symbolic_globals: 1137 self.symbolic_globals[name] = object() # type: ignore[assignment] # sentinel object 1138 variable = self.output.side_effects.track_global_existing( 1139 source, self.symbolic_globals[name] 1140 ) 1141 if isinstance(value, RemovableHandleVariable): 1142 unimplemented("Storing handles in globals - NYI") 1143 self.output.side_effects.store_global(variable, name, value) 1144 1145 def import_source(self, module_name): 1146 """Create an alias to a module for use in guards""" 1147 if "torch_package" in module_name: 1148 value = torch.package.package_importer._package_imported_modules[ 1149 module_name 1150 ] 1151 alias = ( 1152 module_name.replace(">", "_").replace("<", "_").replace(".", "_dot_") 1153 ) 1154 else: 1155 value = importlib.import_module(module_name) 1156 alias = f"__import_{module_name.replace('.', '_dot_')}" 1157 f_globals = self.output.global_scope 1158 assert alias not in f_globals or f_globals[alias] is value 1159 f_globals[alias] = value 1160 self.output.update_co_names(alias) 1161 return GlobalSource(alias) 1162 1163 def resolve_name(self, name, package, level): 1164 """ 1165 Copied from the Cpython implementation of __import__ 1166 Resolve a relative module name to an absolute one. 1167 https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L902 1168 """ 1169 bits = package.rsplit(".", level - 1) 1170 if len(bits) < level: 1171 raise ImportError("attempted relative import beyond top-level package") 1172 base = bits[0] 1173 return f"{base}.{name}" if name else base 1174 1175 def calc_package(self): 1176 """ 1177 Copied from the Cpython implementation of __import__ 1178 https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L1090 1179 """ 1180 package = self.f_globals.get("__package__") 1181 spec = self.f_globals.get("__spec__") 1182 if package is not None: 1183 if spec is not None and package != spec.parent: 1184 log.warning( 1185 "__package__ != __spec__.parent (%r != %r)", 1186 package, 1187 spec.parent, 1188 stacklevel=3, 1189 ) 1190 return package 1191 elif spec is not None: 1192 return spec.parent 1193 else: 1194 log.warning( 1195 "can't resolve package from __spec__ or __package__, " 1196 "falling back on __name__ and __path__", 1197 stacklevel=3, 1198 ) 1199 package = self.f_globals["__name__"] 1200 if "__path__" not in self.f_globals: 1201 package = package.rpartition(".")[0] 1202 return package 1203 1204 def IMPORT_NAME(self, inst): 1205 level, fromlist = self.popn(2) 1206 level = level.as_python_constant() 1207 fromlist = fromlist.as_python_constant() 1208 module_name = inst.argval 1209 1210 # Are we replaying? if so, load recorded module 1211 recorded_name = ( 1212 f"{ExecutionRecorder.LOCAL_MOD_PREFIX}_{level}_{fromlist}_{module_name}" 1213 ) 1214 if recorded_name in self.f_globals: 1215 value = self.f_globals[recorded_name] 1216 source = GlobalSource(recorded_name) 1217 else: 1218 try: 1219 value = __import__( 1220 module_name, 1221 fromlist=fromlist, 1222 level=level, 1223 globals=self.f_globals, 1224 ) 1225 except ImportError: 1226 unimplemented("import a module that does not exist") 1227 1228 if level != 0: 1229 pkg = self.calc_package() 1230 module_name = self.resolve_name(module_name, pkg, level) 1231 1232 # For __import__, when the name variable is of the form package.module, 1233 # normally, the top-level package (the name up till the first dot) is 1234 # returned, not the module named by module_name. However, when a 1235 # non-empty fromlist argument is given, the module named by name is 1236 # returned. Therefore, we set the source correctly here. 1237 if not fromlist: 1238 top_level_module_name = module_name.partition(".")[0] 1239 source = self.import_source(top_level_module_name) 1240 else: 1241 source = self.import_source(module_name) 1242 1243 if self.exec_recorder: 1244 self.exec_recorder.add_local_mod(recorded_name, value) 1245 1246 if istype(value, (types.ModuleType, DummyModule)): 1247 self.push(PythonModuleVariable(value, source=source)) 1248 else: 1249 unimplemented(f"IMPORT_NAME {typestr(value)}") 1250 1251 def IMPORT_FROM(self, inst): 1252 self.DUP_TOP(inst) 1253 self._load_attr(inst) 1254 1255 def load_builtin_from_argval(self, argval): 1256 if argval not in self.f_builtins: 1257 raise NameError(f"name '{argval}' is not defined") 1258 val = self.f_builtins[argval] 1259 1260 if callable(val): 1261 builtins_source = GlobalSource( 1262 self.output.name_of_builtins_dict_key_in_fglobals 1263 ) 1264 var_source = GetItemSource(builtins_source, argval) 1265 self.push(VariableBuilder(self, var_source)(val)) 1266 else: 1267 assert is_builtin_constant(val) 1268 self.push(ConstantVariable.create(value=val)) 1269 1270 def load_builtin(self, inst): 1271 self.load_builtin_from_argval(inst.argval) 1272 1273 def jump(self, inst): 1274 self.instruction_pointer = self.indexof[inst.target] 1275 1276 JUMP_FORWARD = jump 1277 JUMP_ABSOLUTE = jump 1278 1279 POP_JUMP_IF_FALSE = generic_jump(operator.not_, False) 1280 POP_JUMP_IF_TRUE = generic_jump(operator.truth, False) 1281 JUMP_IF_FALSE_OR_POP = generic_jump(operator.not_, True) 1282 JUMP_IF_TRUE_OR_POP = generic_jump(operator.truth, True) 1283 1284 def SETUP_LOOP(self, inst): 1285 # only exists in python<=3.7 1286 self.block_stack.append(BlockStackEntry(inst, inst.target)) 1287 1288 def SETUP_EXCEPT(self, inst): 1289 # only exists in python<=3.7 1290 self.block_stack.append(BlockStackEntry(inst, inst.target)) 1291 1292 def POP_BLOCK(self, inst): 1293 self.block_stack.pop() 1294 1295 def SETUP_WITH(self, inst): 1296 self.setup_or_before_with(inst) 1297 1298 def SETUP_FINALLY(self, inst): 1299 self.block_stack.append(BlockStackEntry(inst, inst.target)) 1300 1301 def BEGIN_FINALLY(self, inst): 1302 self.push(None) 1303 1304 def WITH_CLEANUP_START(self, inst): 1305 exit, exc = self.popn(2) 1306 assert exc is None 1307 self.push(exc) 1308 self.push(exit.call_function(self, [ConstantVariable.create(None)] * 3, {})) 1309 1310 def WITH_CLEANUP_FINISH(self, inst): 1311 self.popn(2) 1312 self.push(None) 1313 1314 def CALL_FINALLY(self, inst): 1315 """ 1316 pushes the address of the next instruction onto the stack and increments 1317 bytecode counter by delta 1318 """ 1319 # Python 3.8 only 1320 addr = self.indexof[self.next_instruction] 1321 self.push(ConstantVariable.create(addr)) 1322 self.jump(inst) 1323 1324 def END_FINALLY(self, inst): 1325 # Python 3.8 only 1326 # https://docs.python.org/3.8/library/dis.html#opcode-END_FINALLY 1327 tos = self.pop() 1328 if isinstance(tos, ConstantVariable): 1329 self.instruction_pointer = tos.as_python_constant() 1330 else: 1331 pass 1332 1333 def POP_FINALLY(self, inst): 1334 # Python 3.8 only 1335 preserve_tos = inst.argval 1336 if preserve_tos: 1337 tos = self.pop() 1338 _ = self.pop() 1339 if preserve_tos: 1340 self.push(tos) # type: ignore[possibly-undefined] 1341 1342 def FOR_ITER(self, inst): 1343 it = self.pop().realize() 1344 try: 1345 val = it.next_variable(self) 1346 self.push(it) 1347 self.push(val) 1348 except (StopIteration, exc.ObservedUserStopIteration) as e: 1349 if isinstance(e, exc.ObservedUserStopIteration): 1350 exc.handle_observed_exception(self) 1351 1352 # leave iterator upon exhaustion in 3.12 1353 if sys.version_info >= (3, 12): 1354 # CPython 3.12 actually jumps to the instruction after the END_FOR 1355 # and performs the action of END_FOR as part of FOR_ITER. We jump 1356 # to the END_FOR and run it, so we need to make sure 2 values are 1357 # on the stack for it to pop. 1358 self.push(it) 1359 self.push(ConstantVariable.create(None)) 1360 self.jump(inst) 1361 1362 def _raise_exception_variable(self, inst): 1363 val = self.pop() 1364 # User can raise exception in 2 ways 1365 # 1) raise exception type - raise NotImplementedError 1366 # 2) raise execption instance - raise NotImplemetedError("foo") 1367 1368 # 1) when user raises exception type 1369 if isinstance(val, variables.BuiltinVariable): 1370 # Create the instance of the exception type 1371 # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549 1372 val = val.call_function(self, [], {}) # type: ignore[arg-type] 1373 1374 # Save the exception in a global data structure 1375 self.exn_vt_stack.append(val) 1376 1377 # 2) when user raises exception instance 1378 if isinstance(val, variables.ExceptionVariable): 1379 if observed_exception_type := exc.observed_exception_map.get(val.exc_type): 1380 raise observed_exception_type(f"raised exception {val}") 1381 raise exc.ObservedException(f"raised exception {val}") 1382 unimplemented(f"raise {exc}") 1383 1384 def RAISE_VARARGS(self, inst): 1385 if inst.arg == 0: 1386 unimplemented("re-raise") 1387 elif inst.arg == 1: 1388 self._raise_exception_variable(inst) 1389 else: 1390 # Support raise .. from None ... Dynamo does not track __cause__ and other attributes of exception. So we 1391 # ignore `from None` part. 1392 from_vt = self.pop() 1393 if isinstance(from_vt, ConstantVariable) and from_vt.value is None: 1394 self._raise_exception_variable(inst) 1395 unimplemented("raise ... from ...") 1396 1397 def RERAISE(self, inst): 1398 if sys.version_info >= (3, 11): 1399 # RERAISE is currently supported in a narrow case of `raise ... from None` 1400 self._raise_exception_variable(inst) 1401 unimplemented("RERAISE") 1402 1403 def exception_handler(self, raised_exception): 1404 if sys.version_info >= (3, 11): 1405 exn_tab_entry = self.current_instruction.exn_tab_entry 1406 if exn_tab_entry: 1407 # Implementation is based on https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt 1408 1409 # 1) pop values from the stack until it matches the stack depth 1410 # for the handler 1411 while len(self.stack) > exn_tab_entry.depth: 1412 self.pop() 1413 1414 # 2) if 'lasti' is true, then push the offset that the exception was raised at 1415 if exn_tab_entry.lasti: 1416 self.push( 1417 variables.ConstantVariable(self.current_instruction.offset) 1418 ) 1419 1420 # 3) push the exception to the stack 1421 assert len(self.exn_vt_stack) 1422 self.push(self.exn_vt_stack[-1]) 1423 1424 # 4) jump to the handler 1425 self.jump(exn_tab_entry) 1426 else: 1427 # No handler found. Bubble the exception to the parent 1428 # instruction translater. We use special exception for this. 1429 self.stack.clear() 1430 if type(self) is InstructionTranslator: 1431 raise Unsupported("Observed exception") 1432 raise raised_exception 1433 else: 1434 if len(self.block_stack): 1435 # base implementation - https://github.com/python/cpython/blob/3.10/Python/ceval.c#L4455 1436 1437 assert len(self.exn_vt_stack) 1438 exception_var = self.exn_vt_stack[-1] 1439 1440 block_stack_entry = self.block_stack.pop() 1441 1442 while block_stack_entry.inst.opname == "EXCEPT_HANDLER": 1443 # TODO(anijain2305) - This is not tested .. unable to create a testcase 1444 # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456 1445 self.popn(3) 1446 if len(self.block_stack) == 0: 1447 # No handler found in this frame. Bubble the exception to the parent 1448 # instruction translater. 1449 self.stack.clear() 1450 if type(self) is InstructionTranslator: 1451 raise Unsupported("Observed exception") 1452 raise raised_exception 1453 block_stack_entry = self.block_stack.pop() 1454 1455 if block_stack_entry.inst.opname != "SETUP_FINALLY": 1456 unimplemented( 1457 "exception is raised when top of the block stack " 1458 "is not exception handler (e.g. try .. with .. except). " 1459 f"Current TOS is {block_stack_entry.inst}" 1460 ) 1461 1462 # Push a dummy block stack entry of EXCEPT_HANDLER 1463 # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456 1464 except_handler_inst = Instruction(1e6, "EXCEPT_HANDLER", None, 0) 1465 self.block_stack.append(BlockStackEntry(except_handler_inst, None)) 1466 1467 # Push old exception 1468 if len(self.exn_vt_stack) >= 2: 1469 old_exception = self.exn_vt_stack[-2] 1470 1471 # Push the old exception on to stack - tb, value, type 1472 # Traceback is currently mapped to UnknownVariable 1473 self.push(variables.UnknownVariable()) 1474 self.push(old_exception) 1475 self.push(variables.BuiltinVariable(old_exception.exc_type)) 1476 else: 1477 # Push empty exception tb, value, type 1478 self.push(variables.ConstantVariable(None)) 1479 self.push(variables.ConstantVariable(None)) 1480 self.push(variables.ConstantVariable(None)) 1481 1482 # Push new exception - tb, val, type 1483 # Traceback is currently mapped to UnknownVariable 1484 self.push(variables.UnknownVariable()) 1485 self.push(exception_var) 1486 self.push(variables.BuiltinVariable(exception_var.exc_type)) 1487 1488 # Jump to target 1489 self.jump(block_stack_entry) 1490 else: 1491 # No handler found. Bubble the exception to the parent 1492 # instruction translater. We use special exception for this. 1493 self.stack.clear() 1494 if type(self) is InstructionTranslator: 1495 raise Unsupported("Observed exception") 1496 raise raised_exception 1497 1498 def PUSH_EXC_INFO(self, inst): 1499 val = self.pop() 1500 assert len(self.exn_vt_stack) 1501 self.push(self.exn_vt_stack[-1]) 1502 self.push(val) 1503 1504 def POP_EXCEPT(self, inst): 1505 if sys.version_info >= (3, 11): 1506 val = self.pop() 1507 assert isinstance(val, variables.ExceptionVariable) 1508 1509 # This exception is handled and therefore we can clear the error indicator 1510 assert len(self.exn_vt_stack) 1511 self.exn_vt_stack.pop() 1512 else: 1513 assert len(self.block_stack) > 0 1514 if self.block_stack[-1].inst.opname != "EXCEPT_HANDLER": 1515 raise AssertionError( 1516 "Bug in Dynamo tracing of exception handling." 1517 "Top of the block stack is not EXCEPT_HANDLER." 1518 ) 1519 self.block_stack.pop() 1520 1521 self.popn(3) 1522 1523 # This exception is handled and therefore we can clear the error indicator 1524 assert len(self.exn_vt_stack) 1525 self.exn_vt_stack.pop() 1526 1527 def check_if_exc_matches(self): 1528 assert len(self.stack) >= 2 1529 expected_exc_types = self.pop() 1530 if sys.version_info >= (3, 11): 1531 # CHECK_EXC_MATCH (which is used from 3.11 onwards) does not pop. 1532 # This is the description from the disassembly doc 1533 # 1534 # Performs exception matching for ``except``. Tests whether the ``STACK[-2]`` 1535 # is an exception matching ``STACK[-1]``. Pops ``STACK[-1]`` and pushes the boolean 1536 # result of the test. 1537 exc_instance = self.stack[-1] 1538 else: 1539 # This is used prior to 3.11 via opcode JUMP_IF_NOT_EXC_MATCH 1540 # There is no documentation but here is the code pointer that does 2 pops 1541 # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L3650-L3665 1542 exc_instance = self.stack.pop() 1543 1544 # Users can check exception in 2 ways 1545 # 1) except NotImplementedError --> BuilinVariable 1546 # 2) except (NotImplemetedError, AttributeError) -> TupleVariable 1547 1548 if not isinstance(expected_exc_types, (BuiltinVariable, TupleVariable)): 1549 unimplemented( 1550 f"except has an unsupported types of objects {expected_exc_types}" 1551 ) 1552 1553 if sys.version_info >= (3, 11): 1554 if not isinstance(exc_instance, variables.ExceptionVariable): 1555 unimplemented( 1556 f"except expects to recieve an object of exception type but received {exc_instance}" 1557 ) 1558 1559 if isinstance(expected_exc_types, TupleVariable): 1560 expected_types = expected_exc_types.items 1561 else: 1562 expected_types = [ 1563 expected_exc_types, 1564 ] 1565 1566 for expected_type in expected_types: 1567 if not isinstance(expected_type, BuiltinVariable): 1568 unimplemented( 1569 f"except has an unsupported types of object {expected_type}" 1570 ) 1571 if isinstance(exc_instance, variables.ExceptionVariable) and issubclass( 1572 exc_instance.exc_type, expected_type.fn 1573 ): 1574 return True 1575 elif isinstance(exc_instance, variables.BuiltinVariable) and issubclass( 1576 exc_instance.fn, expected_type.fn 1577 ): 1578 return True 1579 1580 return False 1581 1582 def CHECK_EXC_MATCH(self, inst): 1583 self.push(variables.ConstantVariable(self.check_if_exc_matches())) 1584 1585 def JUMP_IF_NOT_EXC_MATCH(self, inst): 1586 if not self.check_if_exc_matches(): 1587 self.jump(inst) 1588 1589 def COMPARE_OP(self, inst): 1590 if inst.argval == "exception match": 1591 self.CHECK_EXC_MATCH(inst) 1592 else: 1593 self.push(compare_op_handlers[inst.argval](self, self.popn(2), {})) 1594 1595 def GET_ITER(self, inst): 1596 self.call_function(BuiltinVariable(iter), [self.pop()], {}) 1597 1598 @break_graph_if_unsupported(push=1) 1599 def CALL_FUNCTION(self, inst): 1600 args = self.popn(inst.argval) 1601 fn = self.pop() 1602 self.call_function(fn, args, {}) 1603 1604 @break_graph_if_unsupported(push=1) 1605 def CALL_FUNCTION_EX(self, inst): 1606 kwargsvars: VariableTracker 1607 if inst.argval == 0: 1608 kwargsvars = ConstDictVariable({}) 1609 argsvars = self.pop() 1610 elif inst.argval == 1: 1611 kwargsvars = self.pop() 1612 argsvars = self.pop() 1613 else: 1614 unimplemented("CALL_FUNCTION_EX") 1615 1616 if sys.version_info >= (3, 13): 1617 # 3.13 swapped null and callable 1618 null = self.pop() 1619 assert isinstance(null, NullVariable) 1620 1621 fn = self.pop() 1622 1623 if sys.version_info >= (3, 11) and sys.version_info < (3, 13): 1624 null = self.pop() 1625 assert isinstance(null, NullVariable) 1626 1627 if isinstance(fn, GetAttrVariable) and isinstance(fn.obj, TensorVariable): 1628 # realize is requires for Python 3.8 1629 kwargsvars = kwargsvars.realize() 1630 if fn.name == "view" and isinstance( 1631 argsvars, (ConstantVariable, TensorVariable) 1632 ): 1633 # Hack to handle special case in some bert models. Converts 1634 # x.view(*shape) into x.view(shape), which is correct for view() 1635 # but not generally. See test_transpose_for_scores(). 1636 argsvars = TupleVariable([argsvars]) 1637 elif ( 1638 fn.name == "random_" 1639 and isinstance(argsvars, TupleVariable) 1640 and len(argsvars.items) == 0 1641 and isinstance(kwargsvars, ConstDictVariable) 1642 and ConstantVariable.create("from") in kwargsvars 1643 ): 1644 # `from`` is python keyword. Adding random_ with `from` in the 1645 # Fx graph causes syntax error. Even if we convert the kwargs to 1646 # args, aot_autograd/inductor while lowering generates 1647 # aten.random.from, again causing syntax errors. Since this 1648 # usecase is uncommon, graph break. 1649 unimplemented("random_ op is called with from keyword") 1650 elif ( 1651 fn.name == "uniform_" 1652 and isinstance(argsvars, TupleVariable) 1653 and len(argsvars.items) == 0 1654 and isinstance(kwargsvars, ConstDictVariable) 1655 and ConstantVariable.create("from") in kwargsvars 1656 ): 1657 # `from`` is python keyword. Adding uniform_ with `from` in the 1658 # Fx graph causes syntax error. Even if we convert the kwargs to 1659 # args, aot_autograd/inductor while lowering generates 1660 # aten.uniform.from, again causing syntax errors. Since this 1661 # usecase is uncommon, graph break. 1662 unimplemented("uniform_ op is called with from keyword") 1663 1664 if not isinstance( 1665 argsvars, BaseListVariable 1666 ) and argsvars.has_force_unpack_var_sequence(self): 1667 argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self)) 1668 1669 # Unpack for cases like fn(**obj) where obj is a map 1670 if isinstance(kwargsvars, UserDefinedObjectVariable): 1671 kwargsvars = BuiltinVariable.call_custom_dict(self, dict, kwargsvars) # type: ignore[arg-type] 1672 1673 if not isinstance(argsvars, BaseListVariable) or not isinstance( 1674 kwargsvars, ConstDictVariable 1675 ): 1676 unimplemented(f"non-static call {typestr(argsvars)} {typestr(kwargsvars)}") 1677 1678 # Map to a dictionary of str -> VariableTracker 1679 kwargsvars = kwargsvars.keys_as_python_constant() 1680 self.call_function(fn, argsvars.items, kwargsvars) 1681 1682 @break_graph_if_unsupported(push=1) 1683 def CALL_FUNCTION_KW(self, inst): 1684 argnames = self.pop() 1685 args = self.popn(inst.argval) 1686 fn = self.pop() 1687 assert isinstance(argnames, TupleVariable) and argnames.is_python_constant() 1688 argnames = argnames.as_python_constant() 1689 args, kwargs_list = args[: -len(argnames)], args[-len(argnames) :] 1690 kwargs = dict(zip(argnames, kwargs_list)) 1691 assert len(kwargs) == len(argnames) 1692 self.call_function(fn, args, kwargs) 1693 1694 def LOAD_METHOD_SUPER(self, inst): 1695 self.CALL_FUNCTION(dataclasses.replace(inst, argval=2)) 1696 arg = inst.argval[0] 1697 argval = self.code_options["co_names"][arg] 1698 if sys.version_info < (3, 11): 1699 self._load_attr(dataclasses.replace(inst, argval=argval)) 1700 else: 1701 self.LOAD_METHOD(dataclasses.replace(inst, argval=argval)) 1702 1703 def LOAD_ATTR_SUPER(self, inst): 1704 self.CALL_FUNCTION(dataclasses.replace(inst, argval=2)) 1705 arg = inst.argval[0] 1706 argval = self.code_options["co_names"][arg] 1707 self._load_attr(dataclasses.replace(inst, argval=argval)) 1708 1709 def LOAD_METHOD(self, inst): 1710 self._load_attr(inst) 1711 obj = self.pop() 1712 if sys.version_info >= (3, 13): 1713 self.push(obj) 1714 self.PUSH_NULL(inst) 1715 elif sys.version_info >= (3, 11): 1716 # always follow the NULL + fn convention, since if obj 1717 # is actually a method, self is already bound to it, so it 1718 # doesn't need to be passed in as an arg. 1719 self.PUSH_NULL(inst) 1720 self.push(obj) 1721 else: 1722 self.push(obj) 1723 self.push(None) 1724 1725 def CALL_METHOD(self, inst): 1726 args = self.popn(inst.argval) 1727 dummy = self.pop() 1728 assert dummy is None 1729 fn = self.pop() 1730 self.call_function(fn, args, {}) 1731 1732 def _load_attr(self, inst): 1733 obj = self.pop() 1734 result = BuiltinVariable(getattr).call_function( 1735 self, [obj, ConstantVariable.create(inst.argval)], {} # type: ignore[arg-type] 1736 ) 1737 self.push(result) 1738 1739 def LOAD_ATTR(self, inst): 1740 if sys.version_info >= (3, 12): 1741 if inst.arg % 2: 1742 self.LOAD_METHOD(inst) 1743 return 1744 self._load_attr(inst) 1745 1746 def STORE_ATTR(self, inst): 1747 speculation = self.speculate() 1748 if speculation.failed: 1749 return self.store_attr_graph_break(inst) 1750 val, obj = self.popn(2) 1751 1752 if isinstance(obj, NNModuleVariable) and not isinstance(val, ConstantVariable): 1753 # We don't allow side effects during export on non-constant values 1754 # https://github.com/pytorch/torchdynamo/issues/1475 1755 assert ( 1756 not self.export 1757 ), f"Mutating module attribute {inst.argval} during export." 1758 1759 try: 1760 BuiltinVariable(setattr).call_function( 1761 self, [obj, ConstantVariable.create(inst.argval), val], {} # type: ignore[arg-type] 1762 ) 1763 return 1764 except Unsupported as e: 1765 if not self.should_compile_partial_graph(): 1766 raise 1767 log.debug("STORE_ATTR triggered compile", exc_info=True) 1768 e.remove_from_stats() 1769 e.add_to_stats("graph_break") 1770 speculation.fail_and_restart_analysis() 1771 1772 def store_attr_graph_break(self, inst): 1773 if not self.should_compile_partial_graph(): 1774 unimplemented("should_compile_partial_graph=False") 1775 self.output.compile_subgraph( 1776 self, reason=GraphCompileReason("store_attr", [self.frame_summary()]) 1777 ) 1778 self.output.add_output_instructions([copy.copy(inst)]) 1779 self.popn(2) 1780 self.output.add_output_instructions( 1781 self.create_call_resume_at(self.next_instruction) 1782 ) 1783 1784 def DELETE_ATTR(self, inst): 1785 obj = self.pop() 1786 BuiltinVariable(delattr).call_function( 1787 self, [obj, ConstantVariable.create(inst.argval)], {} # type: ignore[arg-type] 1788 ) 1789 1790 def create_call_resume_at(self, offset): 1791 raise AssertionError( 1792 f"create_call_resume_at not overridden by subclass {type(self)}" 1793 ) 1794 1795 def should_compile_partial_graph(self) -> bool: 1796 raise AssertionError( 1797 f"should_compile_partial_graph not overridden by subclass {type(self)}" 1798 ) 1799 1800 @break_graph_if_unsupported(push=0) 1801 def STORE_SUBSCR(self, inst): 1802 val, obj, key = self.popn(3) 1803 result = obj.call_method(self, "__setitem__", [key, val], {}) 1804 1805 def DELETE_SUBSCR(self, inst): 1806 obj, key = self.popn(2) 1807 obj.call_method(self, "__delitem__", [key], {}) 1808 1809 def BUILD_TUPLE(self, inst): 1810 name_tuple = None 1811 if sys.version_info >= (3, 13): 1812 name_tuple = tuple(self.name_stack[-inst.argval :]) 1813 items = self.popn(inst.argval) 1814 self.push(TupleVariable(items), name=name_tuple) 1815 1816 def BUILD_SLICE(self, inst): 1817 items = self.popn(inst.argval) 1818 self.push(SliceVariable(items)) 1819 1820 def BUILD_LIST(self, inst): 1821 items = self.popn(inst.argval) 1822 self.push(ListVariable(items, mutable_local=MutableLocal())) 1823 1824 def BUILD_SET(self, inst): 1825 if config.inject_BUILD_SET_unimplemented_TESTING_ONLY: 1826 unimplemented("missing: BUILD_SET") 1827 items = self.popn(inst.argval) 1828 new_set = SetVariable(items, mutable_local=MutableLocal()) 1829 self.push(new_set) 1830 1831 def BUILD_LIST_UNPACK(self, inst, cls=ListVariable): 1832 seqs = self.popn(inst.argval) 1833 items = [] 1834 for seq in seqs: 1835 try: 1836 items.extend(seq.force_unpack_var_sequence(self)) 1837 except NotImplementedError: 1838 unimplemented(f"BUILD_LIST_UNPACK {seq}") 1839 self.push(cls(items, mutable_local=MutableLocal())) 1840 1841 def BUILD_TUPLE_UNPACK(self, inst): 1842 self.BUILD_LIST_UNPACK(inst, cls=TupleVariable) 1843 1844 BUILD_TUPLE_UNPACK_WITH_CALL = BUILD_TUPLE_UNPACK 1845 1846 def BUILD_MAP(self, inst): 1847 items = self.popn(inst.argval * 2) 1848 d = dict(zip(items[::2], items[1::2])) 1849 self.push(ConstDictVariable(d, mutable_local=MutableLocal())) 1850 1851 def BUILD_MAP_UNPACK(self, inst): 1852 items = self.popn(inst.argval) 1853 # ensure everything is a dict 1854 items = [BuiltinVariable(dict).call_function(self, [x], {}) for x in items] # type: ignore[arg-type] 1855 result = {} 1856 for x in items: 1857 assert isinstance(x, ConstDictVariable) 1858 result.update(x.items) 1859 self.push( 1860 ConstDictVariable( 1861 result, 1862 mutable_local=MutableLocal(), 1863 ) 1864 ) 1865 1866 BUILD_MAP_UNPACK_WITH_CALL = BUILD_MAP_UNPACK 1867 1868 def BUILD_CONST_KEY_MAP(self, inst): 1869 keys = self.pop() 1870 values = self.popn(inst.argval) 1871 assert isinstance(keys, TupleVariable) 1872 assert keys.is_python_constant() 1873 1874 keys = keys.force_unpack_var_sequence(self) 1875 assert len(keys) == len(values) 1876 1877 self.push( 1878 ConstDictVariable( 1879 dict(zip(keys, values)), 1880 mutable_local=MutableLocal(), 1881 ) 1882 ) 1883 1884 def MAP_ADD(self, inst): 1885 k, v = self.popn(2) 1886 assert inst.argval > 0 1887 obj = self.stack[-inst.arg].realize() 1888 assert isinstance(obj, ConstDictVariable) 1889 obj.call_method(self, "__setitem__", (k, v), {}) # type: ignore[arg-type] 1890 1891 def SET_ADD(self, inst): 1892 v = self.pop() 1893 assert inst.argval > 0 1894 obj = self.stack[-inst.arg] 1895 assert isinstance(obj, SetVariable) 1896 assert obj.mutable_local 1897 return obj.call_method(self, "add", [v], {}) 1898 1899 def SET_UPDATE(self, inst): 1900 v = self.pop() 1901 assert inst.argval > 0 1902 obj = self.stack[-inst.arg] 1903 assert isinstance(obj, SetVariable) 1904 assert obj.mutable_local 1905 obj.call_method(self, "update", [v], {}) 1906 1907 def LIST_APPEND(self, inst): 1908 v = self.pop() 1909 assert inst.argval > 0 1910 obj = self.stack[-inst.arg].realize() 1911 assert isinstance(obj, ListVariable) 1912 assert obj.mutable_local 1913 self.output.side_effects.mutation(obj) 1914 obj.items.append(v) 1915 1916 def MAKE_FUNCTION(self, inst): 1917 flags = inst.arg 1918 old_stack = list(self.stack) 1919 if sys.version_info < (3, 11): 1920 fn_name = self.pop() 1921 code = self.pop() 1922 if sys.version_info >= (3, 11): 1923 # MAKE_FUNCTION behavior actually changed in 3.11, see 1924 # https://github.com/python/cpython/pull/93189/ 1925 assert hasattr(code.value, "co_qualname") # type: ignore[attr-defined] 1926 fn_name = ConstantVariable.create(value=code.value.co_qualname) # type: ignore[attr-defined] 1927 defaults = None 1928 closure = None 1929 annotations = None 1930 kwdefaults = None 1931 1932 if sys.version_info < (3, 13): 1933 # in 3.13, this is handled in SET_FUNCTION_ATTRIBUTE 1934 if flags & 0x08: 1935 closure = self.pop() 1936 if flags & 0x04: 1937 annotations = self.pop() 1938 if flags & 0x02: 1939 kwdefaults = self.pop() 1940 if flags & 0x01: 1941 defaults = self.pop() 1942 1943 self.push( 1944 NestedUserFunctionVariable( 1945 fn_name, 1946 code, 1947 self.f_globals, 1948 defaults, 1949 kwdefaults, 1950 annotations, 1951 closure, 1952 closure_scope=self, 1953 ) 1954 ) 1955 1956 def UNPACK_SEQUENCE(self, inst): 1957 seq = self.pop() 1958 if isinstance(seq, TensorVariable): 1959 val = seq.unpack_var_sequence(self, idxes=range(inst.argval)) # type: ignore[arg-type] 1960 elif isinstance(seq, GetAttrVariable) and isinstance(seq.obj, TensorVariable): 1961 # x, y = a.shape 1962 proxy = getattr(seq.obj.as_proxy(), seq.name) 1963 val = [wrap_fx_proxy(self, proxy[i]) for i in range(inst.argval)] 1964 elif seq.has_force_unpack_var_sequence(self): 1965 val = seq.force_unpack_var_sequence(self) 1966 else: 1967 unimplemented(f"UNPACK_SEQUENCE {seq}") 1968 if len(val) != inst.argval: 1969 unimplemented("UNPACK_SEQUENCE length mismatch") 1970 for i in reversed(val): 1971 self.push(i) 1972 1973 def UNPACK_EX(self, inst): 1974 assert 0 <= inst.argval <= 0xFFFF 1975 prefix = inst.argval & 0xFF # low byte 1976 suffix = inst.argval >> 8 # high byte 1977 seq = self.pop() 1978 if seq.has_force_unpack_var_sequence(self): 1979 vals = list(seq.force_unpack_var_sequence(self)) 1980 assert len(vals) >= prefix + suffix 1981 vals_prefix = vals[:prefix] 1982 vals_list = vals[prefix : len(vals) - suffix] 1983 vals_suffix = vals[len(vals) - suffix :] 1984 for item in reversed(vals_suffix): 1985 self.push(item) 1986 self.push(TupleVariable(vals_list)) 1987 for item in reversed(vals_prefix): 1988 self.push(item) 1989 else: 1990 unimplemented(f"UNPACK_EX {seq}") 1991 1992 def NOP(self, inst): 1993 pass 1994 1995 def POP_TOP(self, inst): 1996 self.pop() 1997 1998 def ROT_TWO(self, inst): 1999 a = self.pop() 2000 b = self.pop() 2001 self.push(a) 2002 self.push(b) 2003 2004 def ROT_THREE(self, inst): 2005 a = self.pop() 2006 b = self.pop() 2007 c = self.pop() 2008 self.push(a) 2009 self.push(c) 2010 self.push(b) 2011 2012 def ROT_FOUR(self, inst): 2013 a = self.pop() 2014 b = self.pop() 2015 c = self.pop() 2016 d = self.pop() 2017 self.push(a) 2018 self.push(d) 2019 self.push(c) 2020 self.push(b) 2021 2022 def DUP_TOP(self, inst): 2023 a = self.pop() 2024 self.push(a) 2025 self.push(a) 2026 2027 def DUP_TOP_TWO(self, inst): 2028 a = self.pop() 2029 b = self.pop() 2030 self.push(b) 2031 self.push(a) 2032 self.push(b) 2033 self.push(a) 2034 2035 def FORMAT_VALUE(self, inst): 2036 flags = inst.arg 2037 if (flags & 0x04) == 0x04: 2038 fmt_spec = self.pop() 2039 else: 2040 fmt_spec = ConstantVariable.create("") 2041 2042 value = self.pop() 2043 if isinstance(value, SymNodeVariable): 2044 from torch._dynamo.variables.lazy import ( 2045 LazySymNodeFormatString, 2046 LazyVariableTracker, 2047 ) 2048 2049 value = LazyVariableTracker.create( 2050 LazySymNodeFormatString(value, fmt_spec), source=value.source 2051 ) 2052 self.push(value) 2053 return 2054 if (flags & 0x03) == 0x01: 2055 value = BuiltinVariable(str).call_function(self, [value], {}) # type: ignore[arg-type] 2056 elif (flags & 0x03) == 0x02: 2057 value = BuiltinVariable(repr).call_function(self, [value], {}) # type: ignore[arg-type] 2058 elif (flags & 0x03) == 0x03: 2059 value = BuiltinVariable(ascii).call_function(self, [value], {}) # type: ignore[arg-type] 2060 2061 fmt_var = ConstantVariable.create("{:" + fmt_spec.as_python_constant() + "}") 2062 2063 self.call_function(BuiltinVariable(str.format), [fmt_var, value], {}) 2064 2065 def BUILD_STRING(self, inst): 2066 format_string_parts: List[str] = [] 2067 args: List[VariableTracker] = [] 2068 kwargs: Dict[str, VariableTracker] = {} 2069 for part in self.popn(inst.arg): 2070 if isinstance(part, ConstantVariable): 2071 format_string_parts.append("{}") 2072 args.append(part) 2073 elif isinstance(part, variables.StringFormatVariable): 2074 format_string_parts.append(part.format_string) 2075 args.extend(part.sym_args) 2076 if set(kwargs.keys()) & set(part.sym_kwargs.keys()): 2077 unimplemented( 2078 f"BUILD_STRING key conflict {kwargs} & {part.sym_kwargs}" 2079 ) 2080 kwargs.update(part.sym_kwargs) 2081 else: 2082 unimplemented(f"BUILD_STRING {part}") 2083 self.push( 2084 variables.StringFormatVariable.create( 2085 "".join(format_string_parts), args, kwargs 2086 ) 2087 ) 2088 2089 def IS_OP(self, inst): 2090 assert inst.argval == 0 or inst.argval == 1 2091 if inst.argval == 0: 2092 new_argval = "is" 2093 else: 2094 new_argval = "is not" 2095 new_inst = create_instruction("COMPARE_OP", argval=new_argval) 2096 self.COMPARE_OP(new_inst) 2097 2098 def CONTAINS_OP(self, inst): 2099 assert inst.argval == 0 or inst.argval == 1 2100 left, right = self.popn(2) 2101 op = inst.argval 2102 self.push(right.call_method(self, "__contains__", [left], {})) 2103 if op == 1: 2104 self.UNARY_NOT(inst) 2105 2106 def LIST_EXTEND(self, inst): 2107 v = self.pop() 2108 assert inst.argval > 0 2109 obj = self.stack[-inst.arg] 2110 assert isinstance(obj, ListVariable) 2111 assert obj.mutable_local 2112 obj.call_method(self, "extend", [v], {}) 2113 2114 def LIST_TO_TUPLE(self, inst): 2115 self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {})) # type: ignore[arg-type] 2116 2117 def DICT_MERGE(self, inst): 2118 v = self.pop() 2119 assert inst.argval > 0 2120 obj = self.stack[-inst.arg].realize() 2121 assert isinstance(obj, ConstDictVariable) 2122 assert obj.mutable_local 2123 obj.call_method(self, "update", [v], {}) 2124 2125 DICT_UPDATE = DICT_MERGE 2126 2127 def GEN_START(self, inst): 2128 self.pop() 2129 2130 def GET_LEN(self, inst): 2131 tos = self.stack[-1] 2132 if tos.is_python_constant(): 2133 self.push(ConstantVariable.create(len(tos.as_python_constant()))) 2134 else: 2135 self.push(tos.call_method(self, "__len__", [], {})) 2136 2137 def MATCH_MAPPING(self, inst): 2138 tos = self.stack[-1] 2139 assert isinstance(tos, ConstDictVariable) 2140 if isinstance(tos.items, collections.abc.Mapping): 2141 self.push(ConstantVariable.create(True)) 2142 else: 2143 self.push(ConstantVariable.create(False)) 2144 2145 def MATCH_SEQUENCE(self, inst): 2146 tos = self.stack[-1] 2147 assert tos.is_python_constant() 2148 tos_value = tos.as_python_constant() 2149 if isinstance(tos_value, collections.abc.Sequence) and not isinstance( 2150 tos_value, (str, bytes, bytearray) 2151 ): 2152 self.push(ConstantVariable.create(True)) 2153 else: 2154 self.push(ConstantVariable.create(False)) 2155 2156 def MATCH_KEYS(self, inst): 2157 tos = self.stack[-1] 2158 tos1 = self.stack[-2] 2159 assert isinstance(tos1, ConstDictVariable) 2160 2161 if all(k in tos1 for k in tos): # type: ignore[attr-defined] 2162 self.push(TupleVariable([tos1.getitem_const(self, k) for k in tos])) # type: ignore[attr-defined,arg-type] 2163 if sys.version_info < (3, 11): 2164 self.push(ConstantVariable.create(True)) 2165 else: 2166 self.push(ConstantVariable.create(None)) 2167 if sys.version_info < (3, 11): 2168 self.push(ConstantVariable.create(False)) 2169 2170 def LOAD_ASSERTION_ERROR(self, inst): 2171 self.load_builtin_from_argval("AssertionError") 2172 2173 UNARY_POSITIVE = stack_op(operator.pos) 2174 UNARY_NEGATIVE = stack_op(operator.neg) 2175 UNARY_NOT = stack_op(operator.not_) 2176 UNARY_INVERT = stack_op(operator.invert) 2177 2178 BINARY_POWER = stack_op(operator.pow) 2179 BINARY_MULTIPLY = stack_op(operator.mul) 2180 BINARY_MATRIX_MULTIPLY = stack_op(operator.matmul) 2181 BINARY_FLOOR_DIVIDE = stack_op(operator.floordiv) 2182 BINARY_TRUE_DIVIDE = stack_op(operator.truediv) 2183 BINARY_MODULO = stack_op(operator.mod) 2184 BINARY_REMAINDER = stack_op(operator.mod) 2185 BINARY_ADD = stack_op(operator.add) 2186 BINARY_SUBTRACT = stack_op(operator.sub) 2187 BINARY_SUBSCR = break_graph_if_unsupported(push=1)(stack_op(operator.getitem)) 2188 BINARY_LSHIFT = stack_op(operator.lshift) 2189 BINARY_RSHIFT = stack_op(operator.rshift) 2190 BINARY_AND = stack_op(operator.and_) 2191 BINARY_OR = stack_op(operator.or_) 2192 BINARY_XOR = stack_op(operator.xor) 2193 2194 INPLACE_POWER = stack_op(operator.ipow) 2195 INPLACE_MULTIPLY = stack_op(operator.imul) 2196 INPLACE_MATRIX_MULTIPLY = stack_op(operator.imatmul) 2197 INPLACE_FLOOR_DIVIDE = stack_op(operator.ifloordiv) 2198 INPLACE_TRUE_DIVIDE = stack_op(operator.itruediv) 2199 INPLACE_MODULO = stack_op(operator.imod) 2200 INPLACE_REMAINDER = stack_op(operator.imod) 2201 INPLACE_ADD = stack_op(operator.iadd) 2202 INPLACE_SUBTRACT = stack_op(operator.isub) 2203 INPLACE_LSHIFT = stack_op(operator.ilshift) 2204 INPLACE_RSHIFT = stack_op(operator.irshift) 2205 INPLACE_AND = stack_op(operator.iand) 2206 INPLACE_XOR = stack_op(operator.ixor) 2207 INPLACE_OR = stack_op(operator.ior) 2208 2209 # 3.11 opcodes 2210 def RESUME(self, inst): 2211 if inst.arg == 0: 2212 self.append_prefix_inst(inst) 2213 self.accept_prefix_inst = False 2214 else: 2215 assert not self.accept_prefix_inst 2216 2217 if sys.version_info >= (3, 11): 2218 2219 def BINARY_OP(self, inst): 2220 return _binary_op_lookup[inst.arg](self, inst) 2221 2222 def PRECALL(self, inst): 2223 pass 2224 2225 def KW_NAMES(self, inst): 2226 kw_names = self.code_options["co_consts"][inst.arg] 2227 assert isinstance(kw_names, tuple) 2228 for name in kw_names: 2229 assert isinstance(name, str) 2230 assert self.kw_names is None 2231 self.kw_names = ConstantVariable.create(value=kw_names) # type: ignore[assignment] 2232 2233 def PUSH_NULL(self, inst): 2234 self.push(NullVariable()) 2235 2236 def _call(self, inst, call_kw=False): 2237 # see https://docs.python.org/3.11/library/dis.html#opcode-CALL 2238 # for convention 2239 if call_kw: 2240 # TOS is kw_names for CALL_KW instruction 2241 assert sys.version_info >= (3, 13) 2242 kw_names = self.pop() 2243 assert isinstance(kw_names, TupleVariable) and kw_names.is_python_constant() 2244 kw_names = kw_names.as_python_constant() 2245 else: 2246 kw_names = self.kw_names.value if self.kw_names else () 2247 2248 contents = self.popn(inst.arg + 2) 2249 if sys.version_info >= (3, 13): 2250 # NULL and callable swapped 2251 fn = contents[0] 2252 args = [] if isinstance(contents[1], NullVariable) else [contents[1]] 2253 else: 2254 if isinstance(contents[0], NullVariable): 2255 fn = contents[1] 2256 args = [] 2257 else: 2258 fn = contents[0] 2259 args = [contents[1]] 2260 2261 if kw_names: 2262 args = args + contents[2 : -len(kw_names)] 2263 kwargs_list = contents[-len(kw_names) :] 2264 kwargs = dict(zip(kw_names, kwargs_list)) 2265 assert len(kwargs) == len(kw_names) 2266 else: 2267 args = args + contents[2:] 2268 kwargs = {} 2269 2270 try: 2271 # if call_function fails, need to set kw_names to None, otherwise 2272 # a subsequent call may have self.kw_names set to an old value 2273 self.call_function(fn, args, kwargs) 2274 finally: 2275 self.kw_names = None 2276 2277 @break_graph_if_unsupported(push=1) 2278 def CALL(self, inst): 2279 self._call(inst) 2280 2281 def COPY(self, inst): 2282 self.push(self.stack[-inst.arg]) 2283 2284 def SWAP(self, inst): 2285 self.stack[-1], self.stack[-inst.arg] = self.stack[-inst.arg], self.stack[-1] 2286 2287 JUMP_BACKWARD = jump 2288 JUMP_BACKWARD_NO_INTERRUPT = jump 2289 2290 POP_JUMP_FORWARD_IF_TRUE = generic_jump(operator.truth, False) 2291 POP_JUMP_BACKWARD_IF_TRUE = generic_jump(operator.truth, False) 2292 POP_JUMP_FORWARD_IF_FALSE = generic_jump(operator.not_, False) 2293 POP_JUMP_BACKWARD_IF_FALSE = generic_jump(operator.not_, False) 2294 2295 def CACHE(self, inst): 2296 pass 2297 2298 def BEFORE_WITH(self, inst): 2299 self.setup_or_before_with(inst) 2300 2301 def setup_or_before_with(self, inst): 2302 ctx = self.pop() 2303 if not isinstance( 2304 ctx, (ContextWrappingVariable, GenericContextWrappingVariable) 2305 ): 2306 unimplemented(f"{inst.opname} {ctx}") 2307 2308 if isinstance(ctx, GenericContextWrappingVariable): 2309 self.generic_context_manager_depth += 1 2310 2311 # Need this redundant check for mypy 2312 assert isinstance( 2313 ctx, (ContextWrappingVariable, GenericContextWrappingVariable) 2314 ) 2315 2316 exit = WithExitFunctionVariable( 2317 ctx, 2318 inst.target, 2319 ) 2320 2321 if sys.version_info >= (3, 11): 2322 # See create_call_resume_at for block stack details. 2323 # Only push a block if the current instruction's block is a 2324 # with block that is not nested in a try block - that is, the current 2325 # instruction's block target is the same as the top block's target. 2326 if inst.exn_tab_entry and ( 2327 not self.block_stack 2328 or inst.exn_tab_entry.target is not self.block_stack[-1].target 2329 ): 2330 target = None 2331 else: 2332 target = self.next_instruction.exn_tab_entry.target 2333 else: 2334 target = inst.target 2335 2336 if target: 2337 if isinstance(self, InstructionTranslator): 2338 self.block_stack.append( 2339 BlockStackEntry(inst, target, len(self.stack), ctx) 2340 ) 2341 else: 2342 self.block_stack.append(BlockStackEntry(inst, target)) 2343 2344 self.push(exit) 2345 self.push(ctx.enter(self)) 2346 2347 def append_prefix_inst(self, inst): 2348 assert self.accept_prefix_inst 2349 self.prefix_insts.append(inst) 2350 2351 def MAKE_CELL(self, inst): 2352 if sys.version_info >= (3, 12) and not self.accept_prefix_inst: 2353 # In 3.12+, MAKE_CELL is not longer necessarily a prefix instruction. 2354 # It can be generated by inlined comprehensions. 2355 assert isinstance(self.symbolic_locals[inst.argval], NullVariable) 2356 self.symbolic_locals[ 2357 inst.argval 2358 ] = self.output.side_effects.track_cell_new() 2359 else: 2360 self.append_prefix_inst(inst) 2361 2362 def COPY_FREE_VARS(self, inst): 2363 self.append_prefix_inst(inst) 2364 2365 def RETURN_GENERATOR(self, inst): 2366 self.append_prefix_inst(inst) 2367 2368 # 3.12 opcodes 2369 # BINARY/STORE_SLICE opcodes are broken down into 2370 # BUILD_SLICE 2 and BINARY/STORE_SUBSCR 2371 2372 def END_FOR(self, inst): 2373 if sys.version_info >= (3, 13): 2374 self.pop() 2375 else: 2376 self.popn(2) 2377 2378 def LOAD_FAST_CHECK(self, inst): 2379 if isinstance(self.symbolic_locals[inst.argval], NullVariable): 2380 unimplemented("LOAD_FAST_CHECK on uninitialized variable") 2381 self.LOAD_FAST(inst) 2382 2383 def LOAD_FAST_AND_CLEAR(self, inst): 2384 if inst.argval not in self.symbolic_locals: 2385 self.push(NullVariable()) 2386 else: 2387 self.LOAD_FAST(inst) 2388 self.symbolic_locals[inst.argval] = NullVariable() 2389 2390 def LOAD_SUPER_ATTR(self, inst): 2391 self.CALL_FUNCTION(dataclasses.replace(inst, argval=2)) 2392 if inst.arg & 1: 2393 self.LOAD_METHOD(inst) 2394 else: 2395 self._load_attr(inst) 2396 2397 def CALL_INTRINSIC_1(self, inst): 2398 if inst.argval == 5: 2399 # INTRINSIC_UNARY_POSITIVE 2400 self.UNARY_POSITIVE(inst) 2401 elif inst.argval == 6: 2402 # INTRINSIC_LIST_TO_TUPLE 2403 self.push(TupleVariable(self.pop().force_unpack_var_sequence(self))) 2404 else: 2405 unimplemented(f"missing CALL_INTRINSIC_1 operand {inst.argval}") 2406 2407 def END_SEND(self, inst): 2408 tos = self.pop() 2409 self.pop() 2410 self.push(tos) 2411 2412 # 3.13 opcodes 2413 # fused instructions LOAD_FAST_LOAD_FAST, STORE_FAST_STORE_FAST, STORE_FAST_LOAD_FAST 2414 # are broken down. 2415 @break_graph_if_unsupported(push=1) 2416 def CALL_KW(self, inst): 2417 self._call(inst, call_kw=True) 2418 2419 def TO_BOOL(self, inst): 2420 # TO_BOOL only precedes a conditional jump or UNARY_NOT (see compile.c in CPython) 2421 # So we can skip this instruction as long as we remember to codegen a TO_BOOL 2422 # before conditional jumps/UNARY_NOT. 2423 assert self.next_instruction.opname in ( 2424 "POP_JUMP_IF_TRUE", 2425 "POP_JUMP_IF_FALSE", 2426 "UNARY_NOT", 2427 ) 2428 2429 def SET_FUNCTION_ATTRIBUTE(self, inst): 2430 flags = inst.arg 2431 fn = self.pop() 2432 assert isinstance(fn, NestedUserFunctionVariable) 2433 attr_names = self.name_stack[-1] 2434 attr = self.pop() 2435 2436 if flags & 0x08: 2437 # 3.13 merged LOAD_CLOSURE into LOAD_FAST, so we won't know if a given LOAD_FAST 2438 # is meant to load a closure variable or not. Our workaround is to maintain a stack 2439 # of LOAD_FAST variable names and tuples (self.name_stack). So if we are indeed 2440 # constructing a closure tuple, we can use self.name_stack to construct the closure 2441 # variables here. 2442 assert isinstance(attr_names, tuple) and all( 2443 isinstance(name, str) for name in attr_names 2444 ) 2445 fn.closure = TupleVariable( 2446 [self._load_closure(name) for name in attr_names] 2447 ) 2448 fn.closure_scope = self 2449 elif flags & 0x04: 2450 fn.annotations = attr 2451 elif flags & 0x02: 2452 fn.kwdefaults = attr 2453 elif flags & 0x01: 2454 fn.defaults = attr 2455 2456 self.push(fn) 2457 2458 def _format_value_313(self, fmt_spec): 2459 value = self.pop() 2460 if isinstance(value, SymNodeVariable): 2461 value = ConstantVariable.create(str(value.sym_num)) 2462 2463 fmt_var = ConstantVariable.create("{:" + fmt_spec.as_python_constant() + "}") 2464 2465 self.call_function(BuiltinVariable(str.format), [fmt_var, value], {}) 2466 2467 def FORMAT_SIMPLE(self, inst): 2468 self._format_value_313(ConstantVariable.create("")) 2469 2470 def FORMAT_WITH_SPEC(self, inst): 2471 self._format_value_313(self.pop()) 2472 2473 def is_non_empty_graph(self): 2474 if self.output.count_calls() > 1: 2475 # perf optimization only 2476 self.is_non_empty_graph = lambda: True # type: ignore[method-assign] 2477 return True 2478 return False 2479 2480 def format_frame_summary(self, additional_stack_frames=None): 2481 if additional_stack_frames is None: 2482 additional_stack_frames = [] 2483 return "".join( 2484 traceback.format_list( 2485 [self.frame_summary()] + list(reversed(additional_stack_frames)) 2486 ) 2487 ) 2488 2489 def frame_summary(self): 2490 return traceback.FrameSummary( 2491 getattr(self.f_code, "co_filename", "<unknown>"), 2492 self.lineno, 2493 getattr(self.f_code, "co_name", "<unknown>"), 2494 lookup_line=False, 2495 ) 2496 2497 def is_co_filename_from_nn_modules(self): 2498 filename = getattr(self.f_code, "co_filename", "<unknown>") 2499 nn_modules_pattern = re.compile(r".*torch/nn/modules.*") 2500 return nn_modules_pattern.match(filename) is not None 2501 2502 def store_global_weakref_by_id(self, prefix, value): 2503 global_name = self.output.install_global_by_id(prefix, weakref.ref(value)) 2504 install_guard( 2505 GlobalWeakRefSource(global_name).make_guard(GuardBuilder.WEAKREF_ALIVE) 2506 ) 2507 return global_name 2508 2509 @property 2510 def fake_mode(self): 2511 return self.output.tracing_context.fake_mode 2512 2513 def find_symbolic_locals_name(self, tensor_variable): 2514 for key, value in self.symbolic_locals.items(): 2515 if value is tensor_variable: 2516 return key 2517 return None 2518 2519 @contextlib.contextmanager 2520 def strict_translation_mode(self, check_fn: Callable[[VariableTracker], bool]): 2521 """ 2522 Strict mode is enabled on a per-VariableTracker level depending on the return value of check_fn(node). 2523 """ 2524 prior = self.strict_checks_fn 2525 self.strict_checks_fn = check_fn 2526 try: 2527 yield 2528 finally: 2529 self.strict_checks_fn = prior 2530 2531 def speculate(self) -> SpeculationEntry: 2532 assert self.instruction_pointer is not None 2533 assert self.instruction_pointer > 0 2534 return self.speculation_log.next( 2535 self.f_code.co_filename, 2536 self.lineno, 2537 self.instruction_pointer - 1, 2538 self.instructions[self.instruction_pointer - 1], 2539 ) 2540 2541 def __init__( 2542 self, 2543 output: OutputGraph, 2544 instructions: List[Instruction], 2545 f_locals: Dict[str, Any], 2546 f_globals: Dict[str, Any], 2547 f_builtins: Dict[str, Any], 2548 code_options: Dict[str, Any], 2549 symbolic_locals: Dict[str, VariableTracker], 2550 symbolic_globals: Dict[str, VariableTracker], 2551 symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], 2552 f_code: types.CodeType, 2553 export: bool, 2554 inline_depth: int, 2555 speculation_log: SpeculationLog, 2556 distributed_state: Optional[DistributedState], 2557 ) -> None: 2558 super().__init__() 2559 self.speculation_log = speculation_log 2560 self.distributed_state = distributed_state 2561 2562 # Mutable state checkpointed by copy_graphstate() 2563 self.output = output 2564 self.symbolic_locals = symbolic_locals 2565 self.symbolic_globals = symbolic_globals 2566 self.symbolic_torch_function_mode_stack = symbolic_torch_function_mode_stack 2567 self.stack = [] 2568 # stack of variable names for tracking 3.13 closures 2569 self.name_stack: list[Any] = [] 2570 self.instruction_pointer = 0 2571 self.current_instruction = create_instruction("NOP") 2572 self.block_stack = [] 2573 # states before SETUP_WITH for checkpointing and fallback 2574 self.generic_context_manager_depth = 0 2575 self.lineno = -1 2576 self.kw_names = None 2577 self.accept_prefix_inst = True 2578 self.prefix_insts = [] 2579 self.exn_vt_stack = [] 2580 2581 # Properties of the input/output code 2582 self.instructions: List[Instruction] = instructions 2583 self.indexof: Dict[Instruction, int] = get_indexof(self.instructions) 2584 self.f_locals: Dict[ 2585 str, Any 2586 ] = f_locals # needed for recording accessed locals for replay 2587 self.f_globals: Dict[str, Any] = f_globals 2588 self.f_builtins: Dict[str, Any] = f_builtins 2589 self.code_options: Dict[str, Any] = code_options 2590 self.f_code: types.CodeType = f_code 2591 2592 # Execution record for replaying errors 2593 if config.replay_record_enabled: 2594 self.exec_recorder = ExecutionRecorder( 2595 code=f_code, code_options=code_options 2596 ) 2597 else: 2598 self.exec_recorder = None 2599 # Stack of module being parsed, current nn.module is at the end of ordered dict. 2600 # The first field of tuple is the fully qualified name of current module 2601 # in original hierarchy. The second field is the type of current nn.module 2602 self.nn_module_stack: Dict[str, Tuple[str, Type[Any]]] = {} 2603 # Flag to indicate whether tracing is used for export. 2604 self.export = export 2605 self.one_graph = False 2606 2607 self.current_speculation = None 2608 2609 self.strict_checks_fn = None 2610 2611 if sys.version_info >= (3, 10): 2612 from .resume_execution import ( 2613 CO_ASYNC_GENERATOR, 2614 CO_COROUTINE, 2615 CO_GENERATOR, 2616 CO_ITERABLE_COROUTINE, 2617 ) 2618 2619 if f_code.co_flags & ( 2620 CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR 2621 ): 2622 self.push(BuiltinVariable(None)) 2623 2624 self.inline_depth = inline_depth 2625 self.inconsistent_side_effects = False 2626 self._constants_cache: List[Optional[VariableTracker]] = [None] * len( 2627 f_code.co_consts 2628 ) 2629 linecache.lazycache(f_code.co_filename, f_globals) 2630 2631 2632class InstructionTranslator(InstructionTranslatorBase): 2633 mutated_closure_cell_contents: Set[str] 2634 2635 @staticmethod 2636 def current_tx() -> "InstructionTranslator": 2637 return tls.current_tx 2638 2639 @contextlib.contextmanager 2640 def set_current_tx(self): 2641 prior = getattr(tls, "current_tx", None) 2642 tls.current_tx = self 2643 try: 2644 yield 2645 finally: 2646 tls.current_tx = prior 2647 2648 def __init__( 2649 self, 2650 instructions: List[Instruction], 2651 f_code, 2652 f_locals, 2653 f_globals, 2654 f_builtins, 2655 code_options, 2656 compiler_fn, 2657 one_graph, 2658 export, 2659 export_constraints, 2660 mutated_closure_cell_contents: Set[str], 2661 frame_state, 2662 speculation_log: SpeculationLog, 2663 distributed_state: Optional[DistributedState], 2664 ) -> None: 2665 _step_logger()( 2666 logging.INFO, 2667 f"torchdynamo start tracing {f_code.co_name} {code_options['co_filename']}:{code_options['co_firstlineno']}", 2668 ) 2669 super().__init__( 2670 output=OutputGraph( 2671 code_options, 2672 compiler_fn, 2673 self, 2674 export, 2675 export_constraints, 2676 frame_state, 2677 local_scope=f_locals, 2678 global_scope=f_globals, 2679 f_code=f_code, 2680 ), 2681 instructions=instructions, 2682 f_locals=f_locals, 2683 f_globals=f_globals, 2684 f_builtins=f_builtins, 2685 code_options=code_options, 2686 symbolic_locals={}, # set below 2687 # A global var is inserted only after a STORE_GLOBAL happens to it 2688 symbolic_globals={}, 2689 symbolic_torch_function_mode_stack=collections.deque(), 2690 f_code=f_code, 2691 export=export, 2692 inline_depth=0, 2693 speculation_log=speculation_log, 2694 distributed_state=distributed_state, 2695 ) 2696 2697 self._throw_if_in_functorch() 2698 2699 # as soon as we create the tracing context we should keep it active, so any calls 2700 # into dynamo apis can rely on finding it 2701 with tracing(self.output.tracing_context), self.set_current_tx(): 2702 self.one_graph: bool = one_graph 2703 self.export = export 2704 self.mutated_closure_cell_contents = mutated_closure_cell_contents 2705 if self.export: 2706 assert ( 2707 self.one_graph 2708 ), "Export without one graph - something has gone wrong." 2709 2710 vars = list(code_options["co_varnames"]) 2711 cells_and_freevars = [x for x in self.cell_and_freevars() if x not in vars] 2712 vars.extend(cells_and_freevars) 2713 cells_and_freevars_set = set(cells_and_freevars) 2714 2715 self.symbolic_locals = { 2716 k: variables.LazyVariableTracker.create( 2717 f_locals[k], 2718 source=LocalSource(k, cell_or_freevar=k in cells_and_freevars_set), 2719 ) 2720 for k in vars 2721 if k in f_locals 2722 } 2723 2724 self._init_torch_function_mode_stack() 2725 2726 self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = [] 2727 if export: 2728 # export gets confused if we never realize unused inputs 2729 # in export mode just eagerly realize everything 2730 self.symbolic_locals = variables.LazyVariableTracker.realize_all( 2731 self.symbolic_locals 2732 ) 2733 2734 self._freevars_ids = {} 2735 for name in self.code_options["co_freevars"]: 2736 if name in f_locals: 2737 self._freevars_ids[name] = id(f_locals[name]) 2738 2739 def _throw_if_in_functorch(self): 2740 # Fallback to eager in case of a graph break inside vmap 2741 eager = torch._dynamo.lookup_backend("eager") 2742 compiler_fn = inspect.getattr_static( 2743 self.output.compiler_fn, "compiler_fn", self.output.compiler_fn 2744 ) 2745 ci = torch._C._functorch.peek_interpreter_stack() 2746 forbidden_keys = ( 2747 torch._C._functorch.TransformType.Vmap, 2748 torch._C._functorch.TransformType.Grad, 2749 torch._C._functorch.TransformType.Jvp, 2750 ) 2751 2752 if ci is not None and ci.key() in forbidden_keys and compiler_fn is not eager: 2753 name = ci.key().name.lower() 2754 msg = ( 2755 "If you are reaching here, it means dynamo failed for one of the following reasons:\n" 2756 # Calling a torch.compiled function 2757 f"- Calling torch.func.{name}(compiled_fn) function from eager mode is not supported. " 2758 f"Ensure that torch.func.{name} is also wrapped within a torch.compile function. " 2759 "For more information, see PyTorch issue #128711.\n" 2760 # if it reaches here, it means Dynamo failed to inline a functorch function 2761 f"- torch.func.{name}(fn) requires the function to be inlined by dynamo" 2762 ) 2763 unimplemented(msg) 2764 2765 def _init_torch_function_mode_stack(self): 2766 from .variables.torch_function import TorchFunctionModeStackVariable 2767 2768 TorchFunctionModeStackVariable.reset() 2769 2770 self.symbolic_torch_function_mode_stack: Deque[ 2771 TorchFunctionModeVariable 2772 ] = collections.deque() 2773 # We want to retrieve all modes to properly reconstruct the stack if needed 2774 py_stack = get_torch_function_mode_stack(filter_ignored=False) 2775 2776 if py_stack: 2777 has_device_context = isinstance( 2778 py_stack[0], torch.utils._device.DeviceContext 2779 ) 2780 2781 for i, val in enumerate(py_stack): 2782 self.symbolic_torch_function_mode_stack.append( 2783 variables.LazyVariableTracker.create( 2784 val, source=TorchFunctionModeStackSource(i) 2785 ) 2786 ) 2787 2788 def get_example_value(self, source: Source): 2789 if isinstance(source, LocalSource): 2790 return self.f_locals[source.local_name] 2791 if isinstance(source, GlobalSource): 2792 return self.f_globals[source.global_name] 2793 raise KeyError 2794 2795 def run(self): 2796 super().run() 2797 2798 def match_nested_cell(self, name, cell): 2799 """Match a cell in this method to one in a function we are inlining""" 2800 try: 2801 value = cell.cell_contents 2802 except ValueError: 2803 return None 2804 # TODO(jansel): check the id of the cell rather than the contents 2805 if id(value) != self._freevars_ids.get(name): 2806 return None 2807 return self.symbolic_locals[name] 2808 2809 def should_compile_partial_graph(self): 2810 if sys.version_info >= (3, 11): 2811 # Do not compile if current instruction's block is not the top with block 2812 entry = self.current_instruction.exn_tab_entry 2813 if entry and ( 2814 not self.block_stack or entry.target is not self.block_stack[-1].target 2815 ): 2816 return False 2817 return ( 2818 all(b.can_restore() for b in self.block_stack) 2819 and not self.one_graph 2820 and self.generic_context_manager_depth == 0 2821 ) 2822 2823 def create_call_resume_at(self, inst): 2824 self.instruction_pointer = None 2825 2826 if inst.opname == "RETURN_VALUE": 2827 return [create_instruction("RETURN_VALUE")] 2828 elif inst.opname == "RETURN_CONST": 2829 return [create_instruction("RETURN_CONST", argval=inst.argval)] 2830 2831 reads = livevars_analysis(self.instructions, inst) 2832 all_argnames = tuple( 2833 k 2834 for k in self.symbolic_locals.keys() 2835 if k in reads and k not in self.cell_and_freevars() 2836 ) 2837 # NOTE: do not use isinstance, since it realizes lazy VT's 2838 argnames = tuple( 2839 k 2840 for k in all_argnames 2841 if not type.__instancecheck__(NullVariable, self.symbolic_locals[k]) 2842 ) 2843 argnames_null = tuple( 2844 k 2845 for k in all_argnames 2846 if type.__instancecheck__(NullVariable, self.symbolic_locals[k]) 2847 ) 2848 if sys.version_info < (3, 12): 2849 assert len(argnames_null) == 0, "variables should not be NULL in < 3.12" 2850 2851 cg = PyCodegen(self) 2852 2853 # Handle inactive context variables. 2854 # The resume function assumes that context variables are the class, NOT the object. 2855 # e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled 2856 stack_ctx_vars = [] 2857 for i, var in enumerate(self.stack): 2858 if type.__instancecheck__(ContextWrappingVariable, var): 2859 ctx = cast(ContextWrappingVariable, var) 2860 target_values = ( 2861 () if ctx.target_values is None else tuple(ctx.target_values) 2862 ) 2863 stack_ctx_vars.append((i, target_values)) 2864 # Replace the current stack var with the context class 2865 ctx.reconstruct_type(cg) 2866 cg.extend_output(create_swap(len(self.stack) - i + 1)) 2867 cg.append_output(create_instruction("POP_TOP")) 2868 2869 argnames_ctx_vars = [] 2870 for name in argnames: 2871 if type.__instancecheck__( 2872 ContextWrappingVariable, var := self.symbolic_locals[name] 2873 ): 2874 ctx = cast(ContextWrappingVariable, var) 2875 target_values = ( 2876 () if ctx.target_values is None else tuple(ctx.target_values) 2877 ) 2878 argnames_ctx_vars.append((name, target_values)) 2879 # Replace the local with the context class 2880 ctx.reconstruct_type(cg) 2881 cg.append_output(create_instruction("STORE_FAST", argval=name)) 2882 2883 # Python does not allow null to be an arg to a function, so 2884 # we remove nulls from the stack and restore them in the 2885 # prologue of the resume function 2886 2887 # sorted list of indices of nulls on the stack 2888 null_idxes: List[int] = [] 2889 if sys.version_info >= (3, 11): 2890 # find indices of NullVariables 2891 for i, var in enumerate(self.stack): 2892 if type.__instancecheck__(NullVariable, var): 2893 null_idxes.append(i) 2894 # generate bytecode to pop the nulls 2895 null_cnt = 0 2896 for i, var in enumerate(reversed(self.stack)): 2897 if type.__instancecheck__(NullVariable, var): 2898 for j in range(2, i + 2 - null_cnt): 2899 cg.append_output(create_instruction("SWAP", arg=j)) 2900 cg.extend_output(cg.pop_null()) 2901 null_cnt += 1 2902 2903 # we popped all nulls from the stack at runtime, 2904 # so we should not count NullVariables 2905 stack_len = len(self.stack) - len(null_idxes) 2906 nargs = stack_len + len(argnames) 2907 2908 name = unique_id(f"__resume_at_{inst.offset}") 2909 2910 new_code: types.CodeType = ContinueExecutionCache.lookup( 2911 self.f_code, 2912 self.lineno, 2913 inst.offset, 2914 tuple(b.target.offset for b in self.block_stack), 2915 stack_len, 2916 argnames, 2917 argnames_null, 2918 tuple(b.resume_fn() for b in self.block_stack), 2919 tuple(stack_ctx_vars), 2920 tuple(argnames_ctx_vars), 2921 tuple(null_idxes), 2922 ) 2923 2924 # Add original GraphModule context to the resume function to handle 2925 # the case of a graph break while tracing a GraphModule 2926 orig_graphmodule_maybe = code_context.get_context(self.f_code).get( 2927 "orig_graphmodule", lambda: None 2928 )() 2929 if orig_graphmodule_maybe is not None: 2930 code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref( 2931 orig_graphmodule_maybe 2932 ) 2933 2934 if new_code.co_freevars: 2935 # expose code object for debugging purposes 2936 self.output.install_global_unsafe(name, new_code) 2937 cg.make_function_with_closure(name, new_code, True, stack_len) 2938 else: 2939 # This is safe: we pre-generate a unique name 2940 self.output.install_global_unsafe( 2941 name, types.FunctionType(new_code, self.f_globals, name) 2942 ) 2943 cg.extend_output(cg.load_function_name(name, True, stack_len)) 2944 2945 cg.extend_output([cg.create_load(k) for k in argnames]) 2946 cg.extend_output(create_call_function(nargs, False)) 2947 cg.append_output(create_instruction("RETURN_VALUE")) 2948 return cg.get_instructions() 2949 2950 def symbolic_locals_contain_module_class(self): 2951 for v in self.symbolic_locals.values(): 2952 if isinstance(v, UserDefinedClassVariable) and issubclass( 2953 v.as_python_constant(), torch.nn.Module 2954 ): 2955 return True 2956 return False 2957 2958 def _return(self, inst): 2959 if ( 2960 self.output.count_calls() == 0 2961 and not self.inconsistent_side_effects 2962 and not self.symbolic_locals_contain_module_class() 2963 and not self.export 2964 ): 2965 raise exc.SkipFrame("because no content in function call") 2966 self.instruction_pointer = None 2967 _step_logger()( 2968 logging.INFO, 2969 f"torchdynamo done tracing {self.f_code.co_name} ({inst.opname})", 2970 ) 2971 log.debug("%s triggered compile", inst.opname) 2972 self.output.compile_subgraph( 2973 self, 2974 reason=GraphCompileReason( 2975 "return_value", [self.frame_summary()], graph_break=False 2976 ), 2977 ) 2978 return_inst = ( 2979 create_instruction("RETURN_VALUE") 2980 if inst.opname == "RETURN_VALUE" 2981 else create_instruction("RETURN_CONST", argval=inst.argval) 2982 ) 2983 self.output.add_output_instructions([return_inst]) 2984 raise ReturnValueOp 2985 2986 def RETURN_VALUE(self, inst): 2987 self._return(inst) 2988 2989 def RETURN_CONST(self, inst): 2990 self._return(inst) 2991 2992 2993if sys.version_info >= (3, 11): 2994 _binary_op_lookup = [ 2995 getattr( 2996 InstructionTranslator, 2997 opname[3:] if "INPLACE" in opname else f"BINARY_{opname[3:]}", 2998 ) 2999 for opname, _ in dis._nb_ops # type: ignore[attr-defined] 3000 ] 3001 3002 3003class InliningInstructionTranslator(InstructionTranslatorBase): 3004 """Trace and inline a called method""" 3005 3006 symbolic_result: Optional[TensorVariable] 3007 3008 @classmethod 3009 def inline_call(cls, parent, func, args, kwargs): 3010 with patch.dict(counters, {"unimplemented": counters["inline_call"]}): 3011 return cls.inline_call_(parent, func, args, kwargs) 3012 3013 @staticmethod 3014 def check_inlineable(func): 3015 if func.has_self(): 3016 unimplemented("inline with __self__") 3017 3018 result = trace_rules.check_verbose(func, is_inlined_call=True) 3019 if result.skipped: 3020 from torch._dynamo.variables.misc import produce_trampoline_autograd_apply 3021 3022 # _origin marks this as coming from an internal dynamo known function that is safe to 3023 # trace through. 3024 if hasattr(getattr(func, "fn", None), "_origin") and func.fn._origin in [ 3025 produce_trampoline_autograd_apply, 3026 ]: 3027 # Known sound 3028 return trace_rules.SkipResult( 3029 False, "allowlist in dynamo known function" 3030 ) 3031 fn_qualname = func.fn.__qualname__ if hasattr(func, "fn") else "" 3032 unimplemented( 3033 f"'inline in skipfiles: {fn_qualname} | {func.get_name()} {func.get_filename()}, {result.reason}'" 3034 ) 3035 3036 if isinstance(func, UserFunctionVariable) and inspect.getattr_static( 3037 func.get_function(), "_torchdynamo_disable", False 3038 ): 3039 unimplemented( 3040 f"call torch._dynamo.disable() wrapped function {func.get_function()}" 3041 ) 3042 else: 3043 return result 3044 3045 @staticmethod 3046 def inline_call_( 3047 parent, func: VariableTracker, args: List[VariableTracker], kwargs 3048 ): 3049 if isinstance(func, SkipFunctionVariable): 3050 unimplemented("inline with functions in skip files") 3051 assert isinstance( 3052 func, 3053 (UserFunctionVariable, NestedUserFunctionVariable), 3054 ) 3055 result = InliningInstructionTranslator.check_inlineable(func) 3056 assert result.skipped is False 3057 try: 3058 sub_locals, closure_cells = func.bind_args(parent, args, kwargs) 3059 except TypeError as e: 3060 # Wrap the general TypeError during bind_args() to the internal ArgsMismatchError with detailed info 3061 raise ArgsMismatchError( # noqa: B904 3062 "{reason}.\n func = {func}, args = {args}, kwargs = {kwargs}".format( 3063 reason=str(e), 3064 func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}", 3065 args=[arg.python_type() for arg in args], 3066 kwargs=kwargs, 3067 ), 3068 ) 3069 3070 for v in itertools.chain(sub_locals.values(), closure_cells.values()): 3071 if not isinstance(v, VariableTracker): 3072 unimplemented(f"unconverted arg {v}") 3073 3074 code: types.CodeType = func.get_code() 3075 if code.co_name in ("__setitem__", "__setattr__") and not ( 3076 args 3077 and isinstance( 3078 args[0], 3079 (variables.CustomizedDictVariable, variables.UserDefinedObjectVariable), 3080 ) 3081 ): 3082 unimplemented(f"inline {code.co_name}") 3083 3084 suffix = "" 3085 # TODO: mlazos, add support for enabling multiple artifact logs 3086 # with a single alias 3087 if torch._logging._internal.log_state.is_artifact_enabled("bytecode"): 3088 suffix = f"\n{dis.Bytecode(code).dis()}" 3089 if sys.version_info >= (3, 11): 3090 cur_inst = parent.current_instruction 3091 parent_code = parent.f_code 3092 header = parent.get_line_of_code_header(lineno=cur_inst.positions.lineno) 3093 3094 def get_trace_call_log_str(): 3095 line = get_instruction_source_311(parent_code, cur_inst).rstrip() 3096 return f"TRACE inlined call {code.co_name} from {header}\n{line}" 3097 3098 trace_call_log.debug("%s", LazyString(get_trace_call_log_str)) 3099 log.debug("INLINING %s%s, %s", code, suffix, result.reason) 3100 3101 # Detect inline GraphModule calls in order to propagate node metadata, 3102 # by checking if the first argument (self) is a variable tracking a GraphModule. 3103 if args and isinstance(args[0], NNModuleVariable): 3104 module = parent.output.get_submodule(args[0].module_key) 3105 if isinstance(module, torch.fx.GraphModule): 3106 # The inline call might not actually be a call to `forward`, 3107 # but it is enough to add a context for `forward` in case it is called. 3108 code_context.get_context(module.forward.__code__)[ 3109 "orig_graphmodule" 3110 ] = weakref.ref(module) 3111 3112 tracer: InliningInstructionTranslator 3113 if is_generator(code): 3114 tracer = InliningGeneratorInstructionTranslator( 3115 parent, 3116 code, 3117 sub_locals, 3118 parent.symbolic_globals, 3119 parent.symbolic_torch_function_mode_stack, 3120 closure_cells, 3121 func, 3122 ) 3123 else: 3124 tracer = InliningInstructionTranslator( 3125 parent, 3126 code, 3127 sub_locals, 3128 parent.symbolic_globals, 3129 parent.symbolic_torch_function_mode_stack, 3130 closure_cells, 3131 func, 3132 ) 3133 3134 strict_ctx: Any = contextlib.nullcontext() 3135 if parent.strict_checks_fn: 3136 strict_ctx = tracer.strict_translation_mode(parent.strict_checks_fn) 3137 try: 3138 with strict_ctx: 3139 tracer.run() 3140 except exc.ObservedException as e: 3141 msg = f"Observed exception DURING INLING {code} : {e}" 3142 # TODO(anijain2305) - This works but we should probably have a 3143 # global/central data structure for the exception stack. 3144 parent.exn_vt_stack.extend(tracer.exn_vt_stack) 3145 log.debug(msg) 3146 # bubble up the exception to the parent frame. 3147 raise 3148 except exc.SkipFrame as e: 3149 msg = f"SKIPPED INLINING {code}: {e}" 3150 log.debug(msg) 3151 raise Unsupported(msg) from e 3152 except Exception as e: 3153 log.debug("FAILED INLINING %s", code) 3154 raise 3155 assert tracer.symbolic_result is not None 3156 func.export_freevars(parent, tracer) 3157 3158 if tracer.f_globals is parent.f_globals: 3159 # Merge symbolic_globals back if parent and child are in the same namespace 3160 parent.symbolic_globals.update(tracer.symbolic_globals) 3161 3162 parent.inconsistent_side_effects |= tracer.inconsistent_side_effects 3163 3164 log.debug("DONE INLINING %s", code) 3165 3166 if is_generator(code): 3167 assert isinstance(tracer, InliningGeneratorInstructionTranslator) 3168 assert tracer.symbolic_result.as_python_constant() is None 3169 return ListIteratorVariable( 3170 tracer.generated_items, 3171 mutable_local=MutableLocal(), 3172 ) 3173 else: 3174 return tracer.symbolic_result 3175 3176 def __init__( 3177 self, 3178 parent: InstructionTranslatorBase, 3179 code: types.CodeType, 3180 symbolic_locals: Dict[str, VariableTracker], 3181 symbolic_globals: Dict[str, VariableTracker], 3182 symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], 3183 closure_cells: Dict[str, VariableTracker], 3184 funcvar: BaseUserFunctionVariable, 3185 ) -> None: 3186 f_globals = funcvar.get_globals() # type: ignore[attr-defined] 3187 f_builtins = f_globals["__builtins__"] 3188 if not isinstance(f_builtins, dict): 3189 f_builtins = f_builtins.__dict__ 3190 instructions = cleaned_instructions(code) 3191 propagate_line_nums(instructions) 3192 super().__init__( 3193 output=parent.output, 3194 f_locals={}, 3195 f_globals=f_globals, 3196 f_builtins=f_builtins, 3197 symbolic_locals=symbolic_locals, 3198 symbolic_globals=symbolic_globals, 3199 symbolic_torch_function_mode_stack=symbolic_torch_function_mode_stack, 3200 instructions=instructions, 3201 code_options={k: getattr(code, k) for k in get_code_keys()}, 3202 f_code=code, 3203 export=parent.export, 3204 inline_depth=parent.inline_depth + 1, 3205 speculation_log=parent.speculation_log, 3206 distributed_state=parent.distributed_state, 3207 ) 3208 self.parent = parent 3209 self.symbolic_result = None 3210 self.closure_cells = closure_cells 3211 self.nn_module_stack = parent.nn_module_stack.copy() 3212 self.one_graph = parent.one_graph 3213 3214 @property 3215 def fake_mode(self): 3216 return self.parent.fake_mode 3217 3218 def run_ctx_mgr(self): 3219 return TracingContext.current_frame(self.parent.frame_summary()) 3220 3221 def STORE_DEREF(self, inst): 3222 if inst.argval in self.closure_cells: 3223 cell = self.closure_cells[inst.argval] 3224 val = self.pop() 3225 if isinstance(cell, ClosureVariable): 3226 if not self.output.is_root_tracer(): 3227 unimplemented( 3228 "HigherOrderOperator: Mutating a variable not in the current scope (ClosureVariable)" 3229 ) 3230 self.output.root_tx.symbolic_locals[cell.name] = val 3231 else: 3232 self.output.side_effects.store_cell(cell, val) 3233 else: 3234 maybe_cell = self.symbolic_locals.get(inst.argval) 3235 if isinstance( 3236 maybe_cell, 3237 variables.NewCellVariable, 3238 ): 3239 self.output.side_effects.store_cell( 3240 self.symbolic_locals[inst.argval], self.pop() 3241 ) 3242 else: 3243 if ( 3244 maybe_cell is not None 3245 and maybe_cell.source.name() 3246 not in self.output.root_tx.mutated_closure_cell_contents 3247 ): 3248 # Why is the source name here unique? 3249 # mutated_closure_cell_contents is a per-frame 3250 # concept, and sources identify, e.g., particular 3251 # locals from the frame. If you had two locals, 3252 # they'll get different source names, and therefore 3253 # differ here. 3254 self.output.root_tx.mutated_closure_cell_contents.add( 3255 maybe_cell.source.name() 3256 ) 3257 raise exc.UnspecializeRestartAnalysis 3258 unimplemented("write to __closure__ while inlining") 3259 3260 def LOAD_DEREF(self, inst): 3261 if inst.argval in self.closure_cells: 3262 cell = self.closure_cells[inst.argval] 3263 if isinstance(cell, ClosureVariable): 3264 self.push(self.output.root_tx.symbolic_locals[cell.name]) 3265 else: 3266 self.push(self.output.side_effects.load_cell(cell)) 3267 else: 3268 maybe_sym_local = self.symbolic_locals.get(inst.argval, None) 3269 if isinstance(maybe_sym_local, variables.NewCellVariable): 3270 self.push(self.output.side_effects.load_cell(maybe_sym_local)) 3271 else: 3272 super().LOAD_DEREF(inst) 3273 3274 def _load_closure(self, name): 3275 assert name in self.cell_and_freevars() 3276 if name in self.closure_cells: 3277 return self.closure_cells[name] 3278 else: 3279 return InlinedClosureVariable(name=name) 3280 3281 def check_replace_is_safe(self, oldvar): 3282 if not is_side_effect_safe(oldvar.mutable_local): 3283 unimplemented( 3284 "HigherOrderOperator: Mutating a variable not in the current scope (replace_all)" 3285 ) 3286 3287 def should_compile_partial_graph(self): 3288 return False # inlining functions is all-or-nothing 3289 3290 def create_call_resume_at(self, offset): 3291 unimplemented("cant resume while inlining") 3292 3293 def RETURN_VALUE(self, inst): 3294 self.symbolic_result = self.pop() # type: ignore[assignment] 3295 self.instruction_pointer = None 3296 raise ReturnValueOp 3297 3298 def RETURN_CONST(self, inst): 3299 self.symbolic_result = self._load_const(inst) 3300 self.instruction_pointer = None 3301 raise ReturnValueOp 3302 3303 def get_globals_source_and_value(self, name): 3304 if "__name__" in self.f_globals: 3305 module_name = self.f_globals["__name__"] 3306 module_source = self.import_source(module_name) 3307 if "torch_package" in module_name: 3308 fglobals_value = torch.package.package_importer._package_imported_modules[module_name] # type: ignore[assignment] 3309 else: 3310 fglobals_value = importlib.import_module(module_name) # type: ignore[assignment] 3311 fglobals_vt = VariableBuilder(self, module_source)(fglobals_value) 3312 global_source = AttrSource(module_source, name) 3313 else: 3314 globals_name = self.output.install_global_by_id( 3315 "___unnamed_scope", self.f_globals 3316 ) 3317 globals_source = GlobalSource(globals_name) 3318 fglobals_value = self.f_globals # type: ignore[assignment] 3319 fglobals_vt = VariableBuilder(self, globals_source)(fglobals_value) 3320 global_source = GetItemSource(globals_source, name) # type: ignore[assignment] 3321 return fglobals_value, fglobals_vt, global_source 3322 3323 def _load_global(self, inst): 3324 if self.output.global_scope is self.f_globals: 3325 super()._load_global(inst) 3326 else: 3327 name = inst.argval 3328 3329 _, fglobals_vt, global_source = self.get_globals_source_and_value(name) 3330 if self.output.side_effects.has_pending_mutation_of_attr(fglobals_vt, name): 3331 self.push(self.output.side_effects.load_attr(fglobals_vt, name)) 3332 else: 3333 try: 3334 value = self.f_globals[name] 3335 except KeyError: 3336 return self.load_builtin(inst) 3337 3338 self.push(VariableBuilder(self, global_source)(value)) 3339 3340 def STORE_GLOBAL(self, inst): 3341 if self.f_globals is self.parent.f_globals: 3342 super().STORE_GLOBAL(inst) 3343 else: 3344 value = self.pop() 3345 if isinstance(value, RemovableHandleVariable): 3346 unimplemented("Storing handles in globals - NYI") 3347 name = inst.argval 3348 fglobals_value, fglobals_vt, _ = self.get_globals_source_and_value(name) 3349 self.output.side_effects.store_attr(fglobals_vt, name, value) 3350 3351 3352class InliningGeneratorInstructionTranslator(InliningInstructionTranslator): 3353 generated_items: List[VariableTracker] 3354 3355 def __init__(self, *args, **kwargs) -> None: 3356 super().__init__(*args, **kwargs) 3357 self.generated_items = [] 3358 3359 def YIELD_VALUE(self, inst: Instruction): 3360 self.generated_items.append(self.pop()) 3361 if len(self.generated_items) > MAX_ITERATOR_LIMIT: 3362 unimplemented( 3363 "Too many yield values in generator. Maybe you are inlining an infinite generator. " 3364 f"If not, please report a bug at {PT2_ISSUE_TRACKER_URL}", 3365 ) 3366 self.push(ConstantVariable.create(None)) 3367 3368 def GET_YIELD_FROM_ITER(self, inst): 3369 tos = self.stack[-1] 3370 if not isinstance(tos, ListIteratorVariable): 3371 self.pop() 3372 res = BuiltinVariable(iter).call_function(self, [tos], {}) # type: ignore[arg-type] 3373 self.push(res) 3374 3375 def YIELD_FROM(self, inst): 3376 assert len(self.stack) >= 2 3377 val = self.pop() 3378 tos = self.stack[-1] 3379 if not (isinstance(val, ConstantVariable) and val.value is None): 3380 # invoke send 3381 # Unreachable code - if you hit this, you are implementing generator support and have 3382 # lifted the `unimplemented("generator")` in frame conversion. This codepath handles 3383 # subgenerator and lines up with this line in Python 3.10 3384 # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L2599 3385 unimplemented("Unreachable sub-generator code") 3386 3387 try: 3388 val = tos.next_variable(self) 3389 except (StopIteration, exc.ObservedUserStopIteration) as ex: 3390 if isinstance(ex, exc.ObservedUserStopIteration): 3391 exc.handle_observed_exception(self) 3392 3393 # The iterator is exhausted. Stop the loop and return. 3394 self.pop() 3395 self.push(ConstantVariable.create(ex.value)) 3396 else: 3397 self.push(val) 3398 # Add the value to yield into generated_items and replace the top of the stack with None 3399 self.YIELD_VALUE(inst) 3400 3401 # Repeat the YIELD_FROM instruction in the next eval loop 3402 assert ( 3403 isinstance(self.instruction_pointer, int) 3404 and self.instruction_pointer > 0 3405 ) 3406 self.instruction_pointer -= 1 3407 3408 def SEND(self, inst): 3409 assert len(self.stack) >= 2 3410 val = self.pop() 3411 tos = self.stack[-1] 3412 if isinstance(tos, ListIteratorVariable) or ( 3413 isinstance(tos, UserDefinedObjectVariable) 3414 and isinstance(tos.value, collections.abc.Iterator) 3415 ): 3416 if isinstance(val, ConstantVariable) and val.value is None: 3417 try: 3418 val = tos.next_variable(self) 3419 except (StopIteration, exc.ObservedUserStopIteration) as ex: 3420 # To implement SEND, we have to look at the implementation 3421 # when the iterator returns StopIteration. This translates to this code 3422 # 3.11: https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2613-L2619 3423 # 3.12: https://github.com/python/cpython/blob/3.12/Python/bytecodes.c#L863-L866 3424 # The implementation is different in 3.11 and 3.12. In 3.12, we rely 3425 # on END_SEND to clean up. In 3.11, SEND does the cleanup as well. 3426 if sys.version_info < (3, 12): 3427 self.pop() # Python 3.12 uses new opcode END_SEND 3428 self.push(ConstantVariable.create(ex.value)) 3429 self.jump(inst) 3430 else: 3431 self.push(val) 3432 else: 3433 # invoke send 3434 # Unreachable code - if you hit this, you are implementing generator support and have 3435 # lifted the `unimplemented("generator")` in frame conversion. This codepath handles 3436 # subgenerator and lines up with this line in Python 3.11 3437 # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2597 3438 unimplemented("Unreachable sub-generator code") 3439 else: 3440 unimplemented(f"SEND {typestr(tos)}") 3441