# mypy: allow-untyped-defs
import collections
import contextlib
import copy
import dataclasses
import functools
import itertools
import json
import logging
import operator
import re
import sys
import traceback
import weakref
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union

import sympy

import torch._guards
import torch._logging
import torch.distributed as dist
import torch.nn
import torch.utils._pytree as pytree
from torch import fx
from torch._guards import GlobalContextCheckpointState, Source, TracingContext
from torch._utils_internal import signpost_event
from torch.fx._lazy_graph_module import _make_graph_module  # type: ignore[attr-defined]
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.symbolic_shapes import free_symbols, is_symbolic, ShapeEnv
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

from . import config, exc, logging as torchdynamo_logging, variables
from .backends.registry import CompiledFn, CompilerFn
from .bytecode_transformation import (
    create_call_function,
    create_instruction,
    Instruction,
    unique_id,
)
from .code_context import code_context
from .codegen import PyCodegen
from .current_scope_id import enter_new_scope
from .exc import (
    BackendCompilerFailed,
    exceptions_allowed_to_be_fallback,
    SkipFrame,
    unimplemented,
    unimplemented_with_warning,
)
from .guards import GuardBuilder, install_guard
from .mutation_guard import is_dynamic_nn_module
from .side_effects import AttributeMutationExisting, SideEffects
from .source import (
    AttrSource,
    BackwardStateSource,
    ConstantSource,
    GetItemSource,
    GlobalStateSource,
    is_constant_source,
    is_from_local_source,
    LocalSource,
    ParamBufferSource,
    ShapeEnvSource,
    SyntheticLocalSource,
    TensorProperty,
    TensorPropertySource,
)
from .utils import (
    _extract_tensor_dict,
    checkpoint_params,
    CleanupHook,
    clone_inputs,
    count_calls,
    counters,
    dynamo_timed,
    get_instruction_source_311,
    get_locals_to_steal,
    get_static_address_type,
    get_torch_function_mode_stack,
    graph_break_reasons,
    increment_op_count,
    lazy_format_graph_code,
    LazyString,
    nn_module_proxy,
    same,
    set_example_value,
)
from .variables.base import VariableTracker
from .variables.builder import (
    BackwardStateGraphArg,
    GraphArg,
    TrackedFake,
    VariableBuilder,
    wrap_fx_proxy,
)
from .variables.lists import BaseListVariable
from .variables.misc import NullVariable
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    NumpyNdarrayVariable,
    SymNodeVariable,
    TensorVariable,
    UnspecializedPythonVariable,
)
from .variables.torch_function import TensorWithTFOverrideVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslatorBase


log = logging.getLogger(__name__)
graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph")
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")


@dataclass(frozen=True)
class VariableTrackerCacheKey:
    """Hashable key for VariableTrackerCache: ``id()`` of a value plus its Source."""

    vt_id: int
    # Two different sources can point to the same object. However, Dynamo handles
    # globals and local sources differently when it comes to guards and possibly
    # some other parts as well. So, the cache also keys on the source.
    source: Source


class VariableTrackerCache:
    """
    Maps (``id(value)``, source) to a previously built VariableTracker.

    NOTE(review): only ``id(value)`` is stored, not a reference to ``value``;
    presumably callers keep the value alive for the cache's lifetime so the id
    cannot be reused — confirm.
    """

    def __init__(self):
        self.cache = {}

    def lookup(self, value, source):
        """Return the tracker cached for ``value`` reached via ``source``, or None."""
        return self.cache.get(VariableTrackerCacheKey(id(value), source))

    def add(self, value, source, vt):
        """Cache ``vt`` as the tracker for ``value`` reached via ``source``."""
        self.cache[VariableTrackerCacheKey(id(value), source)] = vt

    def clone(self):
        """Return a new cache with the same entries (trackers are shared, not copied)."""
        # Needed for copy and restore graph state
        new_cache = VariableTrackerCache()
        new_cache.cache.update(self.cache)
        return new_cache

    def clear(self):
        """Drop every cached entry."""
        self.cache.clear()


@functools.lru_cache(None)
def _step_logger():
    """Return the Dynamo step logger for this module (built once, then memoized)."""
    return torchdynamo_logging.get_step_logger(log)


@dataclass
class GraphCompileReason:
    """Stores why a given output graph was compiled; i.e. what caused the graph break."""

    reason: str
    user_stack: List[traceback.FrameSummary]

    # Indicates if this was a graph compile reason due to graph break.
    graph_break: bool = True

    def __post_init__(self):
        # Graph-break reasons are also recorded in the shared graph_break_reasons list.
        if self.graph_break:
            graph_break_reasons.append(self)


def _get_gen_rand_values_fn(random_calls):
    """
    Build a zero-argument function that replays recorded random calls.

    Args:
        random_calls: iterable of ``(fn, args, kwargs)`` triples.

    Returns:
        A function returning ``[fn(*args, **kwargs) for each triple]``.
    """

    def _gen_rand_values():
        return [fn(*args, **kwargs) for fn, args, kwargs in random_calls]

    return _gen_rand_values


class FakeRootModule(torch.nn.Module):
    """
    Trick the constructor of fx.GraphModule.

    Each ``name -> module`` entry of ``nn_modules`` is set as an attribute so
    it can be resolved by name on this root module.
    """

    def __init__(self, nn_modules: Dict[str, torch.nn.Module]):
        super().__init__()
        for k, v in nn_modules.items():
            setattr(self, k, v)

    def __repr__(self):
        return "FakeRootModule(...)"


class WrapperBackend:
    """
    Wrap a backend compiler function.

    The backend is run on a deep copy of the graph module. When
    ``config.verify_correctness`` is enabled, the compiled candidate is checked
    against eager ``gm.forward`` on cloned inputs before it is returned.
    """

    def __init__(self, backend: CompilerFn):
        self.backend: CompilerFn = backend

    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        """
        Compile ``gm`` with the wrapped backend.

        Returns:
            ``gm.forward`` if the backend produced nothing usable, otherwise the
            compiled candidate.

        Raises:
            RuntimeError: if verification is enabled and the candidate's results
                differ from eager execution. Any exception raised during
                verification is logged and re-raised.
        """
        self.restore = checkpoint_params(gm)
        self.gm = gm
        copy_gm = copy.deepcopy(self.gm)
        self.candidate = self.backend(copy_gm, example_inputs)

        # NOTE(review): bound methods are created fresh on each attribute access
        # and the backend receives a deep copy, so the identity test below is
        # unlikely to ever be true; only the None check is effective — confirm.
        if self.candidate is None or self.candidate is self.gm.forward:
            return self.gm.forward

        if not config.verify_correctness:
            return self.candidate

        # if verify_correctness=True
        try:
            correct = self.gm.forward(*clone_inputs(example_inputs))
            result = self.candidate(*clone_inputs(example_inputs))

            # TODO: replace `same` function with the one in testing
            if same(correct, result):
                return self.candidate

            raise RuntimeError(f"incorrect results of backend {self}")
        except Exception:
            log.exception("error in verify_correctness")
            raise
        finally:
            # Undo any parameter mutation caused by running the models above.
            self.restore()


Scope = Dict[str, object]


class OutputGraph:
    """
    Wrapper class to hold outputs of InstructionTranslator.  Mainly the
    generated fx.Graph.

    OutputGraph is 1:1 with a frame being processed. Each frame is associated
    with some root InstructionTranslator. When user code calls a function,
    we construct an InliningInstructionTranslator that continues to write into
    the root InstructionTranslator's OutputGraph.
    """

    def __init__(
        self,
        code_options: Dict[str, Any],
        compiler_fn: Optional[CompilerFn],
        root_tx,
        export: bool,
        export_constraints,
        frame_state,
        local_scope: Scope,
        global_scope: Scope,
        f_code,
    ):
        """
        Set up tracing state for one frame.

        This creates the root SubgraphTracer, a ShapeEnv and FakeTensorMode
        wrapped in a TracingContext, the ambient guards, and a snapshot of the
        global torch state. It also installs a reference to the builtins dict
        into ``global_scope``.
        """
        super().__init__()
        self.tracers = [SubgraphTracer(self, export_root=export)]
        # Map from graph input's `Source` to its `VariableTracker` to
        # de-duplicate graph inputs by source and reuse the tracker
        self.input_source_to_var: Dict[Source, VariableTracker] = {}
        self.export = export
        self.export_constraints = export_constraints
        self.frame_state = frame_state
        # Map from graph input's `Source` to sizes / strides metadata
        self.input_source_to_sizes_strides: Dict[Source, Dict[str, Any]] = {}
        self.cleanup_hooks: List[Callable[[], Any]] = []
        # compile_id is an id number for the current torch.compile
        self.compile_id: int = next(_compile_id_counter)
        # Set of globals installed via install_global* APIs
        self.installed_globals: Set[str] = set()

        # TODO: maybe should just pass the entire f_code in here?  Not
        # sure...
        self.co_fields = {
            "co_name": f_code.co_name,
            "co_filename": f_code.co_filename,
            "co_firstlineno": f_code.co_firstlineno,
        }

        # tracked_fakes says where any tensor that was wrapped to fake came
        # from.  It is similar to GraphArg, in that all GraphArgs will get
        # added to TrackedFakes, but TrackedFakes also contains
        # GraphArgs that got pruned, and things like Tensor attributes which
        # aren't explicit graph inputs.  Used by shape guard
        self.tracked_fakes: List[TrackedFake] = []

        # List of symbols for which we have exact bindings in the arguments
        # already
        self.bound_symbols: Set[sympy.Symbol] = set()

        shape_env = ShapeEnv(
            # Reference Cycle!
            # Share a reference to the list of TrackedFake.
            #
            # ShapeEnv needs this in order to be able to reproduce the call
            # to produce_guards at an arbitrary time point. That is because
            # TrackedFake instances may have its metadata changed throughout
            # the program execution.
            tracked_fakes=self.tracked_fakes,
            allow_scalar_outputs=config.capture_scalar_outputs,
            allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
            prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
            allow_complex_guards_as_runtime_asserts=config.allow_complex_guards_as_runtime_asserts,
            co_fields=self.co_fields,
        )

        # In export mode, we force the shape_env to strictly disallow any constraining
        # of the user marked dynamic dims
        import torch._functorch.config as _config

        with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
            fake_mode = torch._subclasses.FakeTensorMode(
                shape_env=shape_env,
                # TODO (tmanlaibaatar) Remove this once we always lift params and buffers
                allow_non_fake_inputs=bool(self.export),
                export=self.export,
            )
        self.tracing_context: TracingContext = TracingContext(fake_mode)
        self.init_ambient_guards()

        # Map each tensor id to a list of sources. This is necessary because
        # tensor ids cannot be recovered from tracked fakes (in general).
        # We use this map to interpret (i.e., check for violations of) constraints,
        # specifically equality constraints, which have shared tensor ids in them.
        # This map should also be generally useful, e.g., for (de)serialization.
        self.tracked_fakes_id_to_source: Dict[
            int, List[Source]
        ] = collections.defaultdict(list)
        # Stores the full fqn of a param or buffer to the relevant source.
        self.param_name_to_source: Optional[Dict[str, Source]] = {}
        self.side_effects = SideEffects()
        # Cached variable trackers. This makes symbolic analysis of LOAD_GLOBAL
        # and LOAD_ATTR for same python objects free.
        self.variable_tracker_cache = VariableTrackerCache()
        self.unique_var_id = itertools.count()
        self.code_options = dict(code_options)
        self.output_instructions: List[Instruction] = []
        # used to track nodes that are added between calls of copy_graphstate
        # and restore_graphstate
        self.timestamp = 0

        # A list of register_finalizer_fns to apply to the output graph module
        self.register_finalizer_fns: List[Callable[[fx.GraphModule], None]] = []

        # Not checkpointed
        self.compiler_fn: Optional[CompilerFn] = compiler_fn
        self.global_scope = global_scope
        self.local_scope = local_scope
        self.root_tx = root_tx

        # Given a source, what are the user stacks of all locations that
        # accessed it?
        #
        # For efficiency, we only populate this:
        #   - During export, and
        #   - If the source could potentially lead to a spurious export input
        #
        # Feel free to populate this more frequently if other use-cases arise,
        # but be aware that we have to generate full stacks for each
        # recording!
        self.source_to_user_stacks: Dict[Source, List[traceback.StackSummary]] = {}

        self._current_tx: List[InstructionTranslatorBase] = []
        self.cleanups: List[CleanupHook] = []
        self.should_exit = False
        self.unspec_variable_map: Dict[str, UnspecializedPythonVariable] = {}

        # Note this returns true iff TF Mode and TF Subclasses are enabled
        self.torch_function_enabled = torch._C._is_torch_function_enabled()
        # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
        self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
        # This records the initial torch function mode stack for guarding
        self.torch_function_mode_stack = get_torch_function_mode_stack()

        # Tracks if the output graph has a user defined allowed function in the
        # graph. This is used later to determine if we should fallback to eager
        # for certain exceptions. The idea is that if the user has applied
        # allow_in_graph, they would like to see the error instead of falling
        # back for backend errors.
        self.has_user_defined_allowed_in_graph = False

        # Tracks a list of called ops that were not tagged with "pt2_compliant_tag".
        # This information is useful for logging.
        self.non_compliant_ops: Set[torch._ops.OpOverload] = set()

        # Tracks a list of called custom ops that were tagged with "pt2_compliant_tag".
        # This information is useful for logging.
        self.compliant_custom_ops: Set[torch._ops.OpOverload] = set()

        # We save the global torch state here to be restored in case of graph
        # breaks. The relevant issue is seen here
        # https://github.com/pytorch/pytorch/pull/100570#issuecomment-1543427086
        # where inlining of a function changes the global state (because of the
        # presence of torch.no_grad) and there is a graph break.
        self.save_global_state()

        # Tracks the original FQNs of the constant tensors from the original graph,
        # i.e. buffers and parameters.
        self.dynamo_flat_name_to_original_fqn: Dict[str, str] = {}

        # All calls to random() are replaced with a single call to __gen_rand_values
        # functions that return the random values, one per original call.
        # random_calls tracks calls to random() and random_values_var stores the name of
        # the variable that stores __gen_rand_values results.
        self.random_calls: List[
            Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
        ] = []
        self.random_values_var = None

        # Bytecode to insert right before we call the graph
        self.pregraph_bytecode: List[Instruction] = []

        # Use to pass values to backward hooks when using compiled autograd
        self.backward_state: Dict[str, VariableTracker] = {}
        self.backward_state_proxy: Optional[torch.fx.Proxy] = None
        self.backward_state_var: Optional[str] = None

        self.name_of_builtins_dict_key_in_fglobals: str = (
            self.install_builtins_dict_in_fglobals()
        )

        self.guard_on_key_order: Set[str] = set()

    def install_builtins_dict_in_fglobals(self):
        """
        Install the builtins ``__dict__`` into f_globals under a unique key.

        f_globals["__builtins__"] can be a dict or a module. This is an
        implementation detail -
        https://docs.python.org/3/library/builtins.html.

        This makes guarding on any builtin messy because the guard check_fn
        has to check if the __builtins__ is a module or dict, and then access
        by either using getattr or getitem respectively.

        To solve this problem, we insert a new entry in f_globals which points
        to the builtins __dict__ and then we guard any builtin on this dict.
        To avoid any collision with the pre-existing keys, we use the
        install_global to give us a unique dict key.

        Returns:
            The key chosen by ``install_global``.
        """
        f_builtins = self.global_scope["__builtins__"]
        if not isinstance(f_builtins, dict):
            f_builtins = f_builtins.__dict__
        return self.install_global("__builtins_dict__", f_builtins)

    def add_backward_state_hook(self, hook: VariableTracker, prefix="hook"):
        """
        Store ``hook`` in ``backward_state`` under the name ``f"{prefix}{n}"``,
        where ``n`` is the number of entries already stored.

        Returns:
            ``(name, backward_state_proxy)``.
        """
        name = f"{prefix}{len(self.backward_state)}"
        assert name not in self.backward_state
        self.backward_state[name] = hook
        return name, self.get_backward_state_proxy()

    def get_backward_state_proxy(self):
        """
        Return the root-graph input proxy carrying BackwardState, creating it on
        first use. Creating it is not supported in export mode (``unimplemented``).
        """
        if self.backward_state_proxy is None:
            if self.export:
                unimplemented("backward_state does not support export")
            self.backward_state_proxy = self.root_tracer.create_graph_input(
                "dynamo_backward_state", BackwardState, source=BackwardStateSource()
            )
            self.backward_state_proxy.node.meta["grapharg"] = BackwardStateGraphArg()
            set_example_value(self.backward_state_proxy.node, BackwardState())
            self.backward_state_var = self.new_var()
        return self.backward_state_proxy

    # This gets its own helper function so guards DEBUG logs are more informative
    def init_ambient_guards(self):
        """
        Add the guards that every frame needs.

        These cover the shape env, deterministic algorithms, grad mode, the
        default device and torch-function state. A functorch stack match guard
        is added only when a functorch interpreter is active.
        """
        # Register a SHAPE_ENV guard to make sure we setup shape guards
        # that show up in ShapeEnv
        self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))

        self.guards.add(
            GlobalStateSource().make_guard(GuardBuilder.DETERMINISTIC_ALGORITHMS)
        )

        self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GRAD_MODE))

        self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE))

        self.guards.add(
            GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE)
        )

        ci = torch._C._functorch.peek_interpreter_stack()
        if ci is not None:
            self.guards.add(
                GlobalStateSource().make_guard(GuardBuilder.FUNCTORCH_STACK_MATCH)
            )

    def synthetic_graph_input(self, fn, args):
        """
        call fn(*args) before the graph runs and turn the result into a fake input.

        ``fn`` is called once now to obtain an example value. Bytecode that
        imports ``fn`` from ``fn.__module__``, calls it with ``args`` as
        constants and stores the result in a fresh local is appended to
        ``pregraph_bytecode``.

        The example value is wrapped by VariableBuilder under a
        SyntheticLocalSource, and guards installed for that source are
        removed afterwards.
        """
        example_value = fn(*args)
        varname = self.new_var()
        cg = PyCodegen(self.root_tx)
        cg.add_push_null(
            lambda: cg.load_import_from(
                fn.__module__,
                fn.__name__,
            )
        )
        cg.foreach(map(variables.ConstantVariable.create, args))
        cg.call_function(len(args), False)
        cg.store(varname)
        self.pregraph_bytecode.extend(cg.get_instructions())
        source = SyntheticLocalSource(varname)
        result = VariableBuilder(self.root_tx, source)(example_value)
        TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
            source
        )
        return result

    def add_cleanup_hook(self, fn: Callable[[], Any]):
        """Register ``fn`` to be run by ``call_cleanup_hooks``."""
        self.cleanup_hooks.append(fn)

    def call_cleanup_hooks(self):
        """Run registered cleanup hooks in reverse registration order, then forget them."""
        for hook in reversed(self.cleanup_hooks):
            hook()
        self.cleanup_hooks.clear()

    @property
    def root_tracer(self):
        """The outermost tracer (first entry of ``self.tracers``)."""
        return self.tracers[0]

    @property
    def current_tracer(self):
        """The innermost active tracer (last entry of ``self.tracers``)."""
        return self.tracers[-1]

    def is_root_tracer(self):
        # Helper to tell if we are inside the higher order operator tracing.
        return len(self.tracers) == 1

    @property
    def graph(self):
        """The fx graph of the current tracer."""
        return self.current_tracer.graph

    # TODO(rzou): can delete after we refactor speculate_subgraph to use nested GraphTracer.
    @graph.setter
    def graph(self, value):
        self.current_tracer.graph = value

    @property
    def input_name_to_proxy(self):
        """The current tracer's ``input_name_to_proxy`` mapping."""
        return self.current_tracer.input_name_to_proxy

    @property
    def real_value_cache(self):
        """The current tracer's ``real_value_cache``."""
        return self.current_tracer.real_value_cache

    # If you are here, and you're looking for create_graph_input,
    # to avoid ambiguity, please call one of the following:
    # - self.current_tracer.create_graph_input
    # - self.root_tracer.create_graph_input
    # See NOTE [HigherOrderOperator tracing design] for more context.
549 550 def create_proxy(self, *args, **kwargs): 551 return self.current_tracer.create_proxy(*args, **kwargs) 552 553 def create_node(self, *args, **kwargs): 554 return self.current_tracer.create_node(*args, **kwargs) 555 556 def remove_node(self, *args, **kwargs): 557 return self.current_tracer.remove_node(*args, **kwargs) 558 559 @contextlib.contextmanager 560 def subtracer(self, source_target, prior_tracer): 561 new_scope_ctx = enter_new_scope() 562 try: 563 if prior_tracer: 564 # Lineage MUST stay preserved 565 assert prior_tracer.parent is self.current_tracer 566 new_scope_ctx.__enter__() 567 tracer = ( 568 prior_tracer 569 if prior_tracer 570 else SubgraphTracer( 571 self, parent=self.current_tracer, source_target=source_target 572 ) 573 ) 574 self.tracers.append(tracer) 575 yield tracer 576 finally: 577 new_scope_ctx.__exit__(None, None, None) 578 self.tracers.pop() 579 580 @property 581 def output(self): 582 return self 583 584 @property 585 def fake_mode(self): 586 return self.tracing_context.fake_mode 587 588 @property 589 def shape_env(self): 590 return self.tracing_context.fake_mode.shape_env 591 592 @property 593 def guards(self) -> torch._guards.GuardsSet: 594 return self.tracing_context.guards_context.dynamo_guards 595 596 @property 597 def nn_modules(self) -> Dict[str, Any]: 598 return self.tracing_context.module_context.nn_modules 599 600 def save_global_state(self, out=None): 601 """ 602 Saves to out if it is provided. Else saves to the tracing context's global_state. 603 """ 604 global_state = ( 605 out if out is not None else self.tracing_context.global_context.global_state 606 ) 607 608 # TODO - Consider having a torch level API for torch_function_state. As 609 # of now, we create a ref cycle by passing the 610 # output.set_torch_function_state to 611 # output.tracing_context.global_context.global_state. In the interim, 612 # the problem can be solved by manually set 613 # output.tracing_context.global_context.global_state to None at cleanup. 
614 global_state["torch_function_enabled"] = ( 615 self.set_torch_function_state, 616 self.torch_function_enabled, 617 ) 618 global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled()) 619 620 global_state["autocast_enabled"] = ( 621 functools.partial(torch.set_autocast_enabled, "cuda"), 622 torch.is_autocast_enabled("cuda"), 623 ) 624 global_state["autocast_cpu_enabled"] = ( 625 functools.partial(torch.set_autocast_enabled, "cpu"), 626 torch.is_autocast_enabled("cpu"), 627 ) 628 global_state["autocast_gpu_dtype"] = ( 629 functools.partial(torch.set_autocast_dtype, "cuda"), 630 torch.get_autocast_dtype("cuda"), 631 ) 632 global_state["autocast_cpu_dtype"] = ( 633 functools.partial(torch.set_autocast_dtype, "cpu"), 634 torch.get_autocast_dtype("cpu"), 635 ) 636 global_state["autocast_cache_enabled"] = ( 637 torch.set_autocast_cache_enabled, 638 torch.is_autocast_cache_enabled(), 639 ) 640 641 def push_tx(self, tx): 642 self._current_tx.append(tx) 643 644 def pop_tx(self): 645 return self._current_tx.pop() 646 647 @property 648 def current_tx(self): 649 return self.root_tx if not self._current_tx else self._current_tx[-1] 650 651 def add_symbol_bindings(self, arg: GraphArg): 652 # Insert implicit size vars as necessary. With dynamic shapes, we 653 # maintain the invariant that every sizevar gets a direct SymInt input 654 # into the graph. This means downstream graph transforms can assume 655 # every size variable is explicitly bound and accessible, instead of 656 # having to pull it out implicitly from tensors. 
657 658 if self.export: 659 return 660 661 assert arg.fake_tensor is not None 662 663 def bind_symint(s, prop): 664 if not (is_symbolic(s) and isinstance(s.node.expr, sympy.Symbol)): 665 return 666 s0 = s.node.expr 667 if s0 in self.bound_symbols: 668 return 669 self.bound_symbols.add(s0) 670 log.debug("bind_symint %s %s", s, prop.name()) 671 # TODO: don't readd symint if we already have it in graph 672 # (this is harmless because we do remove the unused ones later) 673 proxy = self.root_tracer.create_graph_input( 674 str(s0), 675 torch.SymInt, 676 before=True, 677 source=prop, 678 ) 679 set_example_value(proxy.node, s) 680 proxy.node.meta["grapharg"] = GraphArg( 681 prop, 682 s, 683 pass_arg_as_tensor=False, 684 fake_tensor=None, 685 is_tensor=False, 686 ) 687 688 def handle_tensor(t, src): 689 for i, s in enumerate(t.size()): 690 bind_symint(s, TensorPropertySource(src, TensorProperty.SIZE, i)) 691 if t.layout is torch.strided: 692 for i, s in enumerate(t.stride()): 693 bind_symint(s, TensorPropertySource(src, TensorProperty.STRIDE, i)) 694 bind_symint( 695 t.storage_offset(), 696 TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), 697 ) 698 elif t.layout is torch.sparse_coo: 699 handle_tensor(t._indices(), src) 700 handle_tensor(t._values(), src) 701 elif t.layout in {torch.sparse_csr, torch.sparse_bsr}: 702 handle_tensor(t.crow_indices(), src) 703 handle_tensor(t.col_indices(), src) 704 elif t.layout in {torch.sparse_csc, torch.sparse_bsc}: 705 handle_tensor(t.ccol_indices(), src) 706 handle_tensor(t.row_indices(), src) 707 if is_traceable_wrapper_subclass(t): 708 attrs, ctx = t.__tensor_flatten__() 709 for attr in attrs: 710 inner_t = getattr(t, attr) 711 handle_tensor(inner_t, AttrSource(src, attr)) 712 713 handle_tensor(arg.fake_tensor, arg.source) 714 715 def count_calls(self): 716 return count_calls(self.graph) 717 718 def is_empty_graph(self): 719 return len(list(self.graph.nodes)) == 0 720 721 def get_submodule(self, keys): 722 assert keys 723 obj: 
Union[torch.nn.Module, Dict[str, torch.nn.Module]] = self.nn_modules 724 for k in keys.split("."): 725 if isinstance(obj, dict): 726 obj = obj[k] 727 else: 728 obj = getattr(obj, k) 729 return obj 730 731 def new_var(self, name="tmp"): 732 existing = set(self.code_options["co_varnames"]) 733 # In common case, this will be O(1) 734 while True: 735 var = f"{name}_{next(self.unique_var_id)}" 736 if var not in existing: 737 self.code_options["co_varnames"] += (var,) 738 return var 739 740 def update_co_names(self, name): 741 """Ensure self.code_options.co_names contains name""" 742 if name not in self.code_options["co_names"]: 743 self.code_options["co_names"] += (name,) 744 745 @staticmethod 746 def module_key_name(*names): 747 # create a new unique name 748 name = "_".join(map(str, names)) 749 # Strip the guard lookup L/G access 750 name = re.sub(r"^[GL]\['?(.*?)'?\]$", r"\1", name) 751 # e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv 752 name = re.sub(r"\[(\d+)\]", r"_\g<1>", name) 753 # e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv 754 name = re.sub(r"[^a-zA-Z0-9]", "_", name) 755 756 if not name or not name[0].isalpha(): 757 name = "sub" + name 758 759 return name 760 761 def register_attr_or_module( 762 self, 763 target: Union[torch.nn.Module, torch.Tensor, Any], 764 *names, 765 **options, 766 ): 767 if is_dynamic_nn_module(target, self.root_tx.export): 768 # Instead of returning UnspecializedNNModuleVariable, call 769 # VariableBuilder so that it is tracked for mutation. 770 return VariableBuilder(self.current_tx, **options)(target) 771 772 options = dict(options) 773 assert "source" in options 774 source = options["source"] 775 assert not isinstance(source, ParamBufferSource) 776 777 if isinstance(target, torch.Tensor): 778 tracer = self.current_tracer 779 if not self.is_root_tracer(): 780 # For higher order ops, we don't want to insert the get_attr in 781 # innermost graph. 
Instead, we want to raise the params/buffers 782 # as inputs to the higher-order graph, and register them as 783 # get_attrs in the root tracer. 784 785 # Note that Dynamo will still call lift_tracked_freevar_to_input 786 # when these inputs are encountered for the inner graph. The 787 # only difference is what happens at the root tracer for 788 # nn.Parameters vs free inputs. The free inputs are registered 789 # as placeholders in the root graph, whereas the nn.Parameters 790 # are registered as get_attr nodes in the root graph. 791 tracer = self.root_tracer 792 793 def wrap_name(module_key): 794 assert self.param_name_to_source is not None 795 self.param_name_to_source[module_key] = source 796 797 # Check if the attr has already been registered. This can happen 798 # when two different sources point to the same tensor. 799 if target in self.root_tx.output.side_effects: 800 return self.root_tx.output.side_effects[target] 801 802 if get_static_address_type(target) == "guarded": 803 install_guard(source.make_guard(GuardBuilder.ID_MATCH)) 804 elif not is_constant_source(source): 805 install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH)) 806 807 vt = wrap_fx_proxy( 808 self.root_tx, 809 tracer.create_proxy("get_attr", module_key, (), {}), 810 example_value=target, 811 **options, 812 ) 813 814 # Track the object so to avoid duplicate registration in case of 815 # different sources pointing to the same tensor object. 
816 vt = self.root_tx.output.side_effects.track_object_existing(target, vt) 817 818 assert "tensor_dict" not in vt.proxy.node.meta 819 vt.proxy.node.meta["tensor_dict"] = _extract_tensor_dict(target) 820 821 return vt 822 823 elif isinstance(target, torch.nn.Module): 824 assert isinstance(target, torch.nn.Module) 825 826 if source: 827 install_guard(source.make_guard(GuardBuilder.NN_MODULE)) 828 829 def wrap_name(module_key): 830 return NNModuleVariable(type(target), module_key, target, **options) 831 832 else: 833 # This is Dynamo created graph module, e.g., graph module coming 834 # from higher order ops. NNModuleVariable tracker can't be 835 # sourceless, so let's return a unspecializedNNModule variable 836 # tracker. 837 def wrap_name(module_key): 838 return variables.UnspecializedNNModuleVariable(target, **options) 839 840 elif isinstance(target, (torch.SymInt, torch.SymFloat)): 841 # HACKY CODE REGION BEGIN 842 # WE ARE PIGGYBACKING ON EXISTING INFRA TO REGISTER ATTRS 843 # This ultimately gets written to self.nn_modules, which is unfortunate 844 # Attrs that are tenors and symints and such need to be migrated to have their 845 # own storage 846 # alas, this is like this for now 847 848 def wrap_name(module_key): 849 return SymNodeVariable.create( 850 self, 851 self.create_proxy("get_attr", module_key, (), {}), 852 sym_num=target, 853 **options, 854 ) 855 856 # HACKY CODE REGION END 857 else: 858 859 def wrap_name(module_key): 860 self.output.update_co_names(module_key) 861 self.global_scope[module_key] = target 862 return VariableBuilder(self, ConstantSource(source_name=module_key))( 863 target 864 ) 865 866 for k, v in self.nn_modules.items(): 867 if v is target: 868 # it already exists 869 return wrap_name(k) 870 871 name = OutputGraph.module_key_name(*names) 872 873 base = name 874 for i in itertools.count(): 875 if name not in self.nn_modules: 876 self.nn_modules[name] = target 877 if isinstance(target, torch.nn.Module): 878 879 def 
register_leaf_name(leaf_name): 880 assert self.param_name_to_source is not None 881 new_source = ParamBufferSource(source, leaf_name) 882 new_name = f"{name}.{leaf_name}" 883 self.param_name_to_source[new_name] = new_source 884 if isinstance(source, LocalSource): 885 self.dynamo_flat_name_to_original_fqn[ 886 OutputGraph.module_key_name(new_source.name()) 887 ] = leaf_name 888 889 # annoying, but there are cases when we do not have parameters 890 # see test_nn_moduledict_contains 891 if hasattr(target, "_parameters"): 892 for leaf_name, _ in target.named_parameters(): 893 register_leaf_name(leaf_name) 894 if hasattr(target, "_buffers"): 895 for leaf_name, _ in target.named_buffers(): 896 register_leaf_name(leaf_name) 897 898 return wrap_name(name) 899 name = f"{base}_{i}" 900 901 raise AssertionError("unreachable") 902 903 def handle_aliases_for_stolen_lists(self, tx): 904 # If list inputs are stolen, but still needed after the function call, create aliases to keep them alive 905 maybe_gm = self.local_scope.get("self") 906 stolen_list_names = get_locals_to_steal(maybe_gm) 907 if not stolen_list_names: 908 return [] 909 910 alias_insts = [] 911 needs_alias: Dict[ 912 str, List[Union[VariableTracker, AttributeMutationExisting]] 913 ] = {} 914 915 queue = [ 916 *tx.stack, 917 *tx.symbolic_locals.values(), 918 *self.side_effects.store_attr_mutations.keys(), 919 ] 920 921 while queue: 922 x = queue.pop() 923 if isinstance(x, BaseListVariable): 924 assert isinstance(x.items, List) 925 queue += x.items 926 continue 927 928 if not ( 929 isinstance(x, (VariableTracker, AttributeMutationExisting)) 930 and isinstance(x.source, GetItemSource) 931 and isinstance(x.source.base, LocalSource) 932 and x.source.base.local_name in stolen_list_names 933 ): 934 continue 935 936 stolen_name = x.source.base.local_name 937 if stolen_name not in needs_alias: 938 needs_alias[stolen_name] = [] 939 needs_alias[stolen_name].append(x) 940 941 visited = {} 942 for arg in self.graphargs: 943 if 
not ( 944 isinstance(arg._example, list) 945 and isinstance(arg.source, LocalSource) 946 and arg.source.local_name in needs_alias 947 ): 948 continue 949 950 # arg is a list that will be cleared by the compiled function 951 list_name = arg.source.local_name 952 assert list_name in self.code_options["co_varnames"] 953 for x in needs_alias[list_name]: 954 list_idx = x.source.index 955 if list_idx not in visited: 956 alias_name = self.new_var( 957 f"{list_name}_ref" 958 ) # self.new_var already adds unique id suffix 959 960 visited[list_idx] = alias_name 961 # bytecode of `alias_name = list_name[list_idx]` 962 alias_insts.extend( 963 [ 964 create_instruction("LOAD_FAST", argval=list_name), 965 create_instruction("LOAD_CONST", argval=list_idx), 966 create_instruction("BINARY_SUBSCR"), 967 create_instruction("STORE_FAST", argval=alias_name), 968 ] 969 ) 970 971 # operate on alias, handled by suffix codegen 972 x.source = LocalSource(visited[list_idx]) 973 974 return alias_insts 975 976 def compile_subgraph( 977 self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None 978 ): 979 """ 980 Generate a subgraph to continue execution on user code. 981 Automatically restore live variables. 
982 """ 983 assert reason is not None 984 985 from .decorators import disable 986 987 self.partial_convert = partial_convert 988 self.compile_subgraph_reason = reason 989 self.should_exit = True 990 991 log.debug("COMPILING GRAPH due to %s", reason) 992 993 if not all(block.can_restore() for block in tx.block_stack): 994 unimplemented("compile_subgraph with block_depth != 0") 995 996 prefix_insts: List[Instruction] = [] 997 if sys.version_info >= (3, 11): 998 # prefix instructions (Python 3.11+) 999 for inst in tx.prefix_insts: 1000 if inst.opname == "MAKE_CELL": 1001 prefix_insts.append( 1002 create_instruction("MAKE_CELL", argval=inst.argval) 1003 ) 1004 elif inst.opname == "COPY_FREE_VARS": 1005 prefix_insts.append( 1006 create_instruction( 1007 "COPY_FREE_VARS", arg=len(tx.code_options["co_freevars"]) 1008 ) 1009 ) 1010 else: 1011 prefix_insts.append(copy.copy(inst)) 1012 assert not ( 1013 self.pregraph_bytecode and self.export 1014 ), "export does not support pregraph_bytecode" 1015 prefix_insts.extend(self.pregraph_bytecode) 1016 prefix_insts.extend(self.handle_aliases_for_stolen_lists(tx)) 1017 1018 def append_prefix_insts(): 1019 self.add_output_instructions(prefix_insts) 1020 prefix_insts.clear() 1021 1022 for block in reversed(tx.block_stack): 1023 block.exit(tx) 1024 1025 self.cleanup_graph() 1026 tx.prune_dead_locals() 1027 stack_values = list(tx.stack) 1028 1029 # realize any unrealized tensor VTs in case they 1030 # need to be added to self.nn_modules as attributes 1031 for value in stack_values: 1032 value.realize() 1033 1034 # Use nn.Module "proxies" in the constructed GraphModule so that 1035 # the resulting GM does not hold additional strong references to the original modules. 1036 # This prevents a strong ref cycle where Dynamo created code holds on to references 1037 # to modules that also have Dynamo code cache invalidation checks. 1038 # When cache invalidation runs, the generated GM will be invalidated, which also deletes 1039 # the proxies. 
1040 nn_modules_proxies = { 1041 name: nn_module_proxy(mod) for name, mod in self.nn_modules.items() 1042 } 1043 root = FakeRootModule(nn_modules_proxies) 1044 # Add all the local vars to the "stack" so restore at the end 1045 restore_vars = [] 1046 val_to_names: Dict[VariableTracker, List[str]] = {} 1047 if stack_values: 1048 val_to_names[stack_values[-1]] = [] 1049 # NB: Typically (i.e., for graph compile from RETURN_VALUE), 1050 # symbolic_locals will be empty at this point, as prune_dead_locals 1051 # will clear out all of symbolic_locals because RETURN_VALUE is the 1052 # last instruction and no more locals are used. The fanciness here 1053 # is only needed for partial graphs. 1054 for k, v in tx.symbolic_locals.items(): 1055 # Note! this explicitly uses .local_name for matching 1056 # Failure to do so will cause spurious registrations in val_to_names. 1057 # This will in turn result in spurious variables showing up in the graph. 1058 # This was very tricky to debug. For an example, dump the graph at call_user_compiler 1059 # while running test_subgraphs.py 1060 if isinstance(v.source, LocalSource) and v.source.local_name == k: 1061 continue # no need to restore initial state 1062 # Do not load variable if it is NULL. 1063 if sys.version_info >= (3, 12): 1064 # Continuation function will load the NULL for v. 
1065 if type.__instancecheck__(NullVariable, v): 1066 continue 1067 else: 1068 # A variable should never be NULL in < 3.12 1069 assert not type.__instancecheck__(NullVariable, v) 1070 if v not in val_to_names: 1071 val_to_names[v] = [] 1072 val_to_names[v].append(k) 1073 for v in val_to_names.keys(): 1074 restore_vars.extend(val_to_names[v]) 1075 stack_values.extend([v] * len(val_to_names[v])) 1076 1077 # to handle random calls 1078 if len(self.random_calls) > 0: 1079 append_prefix_insts() 1080 random_calls_instructions = [] 1081 self.random_values_var = self.new_var("random_values") 1082 rand_fn = disable(_get_gen_rand_values_fn(self.random_calls)) 1083 rand_fn_name = self.install_global("__gen_rand_values", rand_fn) 1084 codegen = PyCodegen(tx, root) 1085 random_calls_instructions.extend( 1086 codegen.load_function_name(rand_fn_name, True) 1087 ) 1088 random_calls_instructions.extend(create_call_function(0, False)) 1089 random_calls_instructions.append( 1090 codegen.create_store(tx.output.random_values_var), 1091 ) 1092 self.add_output_instructions(random_calls_instructions) 1093 1094 if ( 1095 stack_values 1096 and all( 1097 not isinstance( 1098 v, 1099 ( 1100 UnspecializedPythonVariable, 1101 NumpyNdarrayVariable, 1102 TensorWithTFOverrideVariable, 1103 ), 1104 ) 1105 and not (isinstance(v, SymNodeVariable) and v.python_type() is float) 1106 for v in stack_values 1107 ) 1108 and all(isinstance(x, TensorVariable) for x in stack_values) 1109 and len(set(stack_values)) == len(stack_values) 1110 and self.side_effects.is_empty() 1111 and not len(tx.debug_locals) != 0 1112 and not self.backward_state 1113 ): 1114 append_prefix_insts() 1115 # optimization to generate better code in a common case 1116 self.add_output_instructions( 1117 self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root) 1118 + [create_instruction("UNPACK_SEQUENCE", arg=len(stack_values))] 1119 ) 1120 # restore all the live local vars 1121 self.add_output_instructions( 1122 
[PyCodegen(tx).create_store(var) for var in reversed(restore_vars)] 1123 ) 1124 else: 1125 graph_output_var = self.new_var("graph_out") 1126 pass1 = PyCodegen(tx, root, graph_output_var) 1127 self.codegen_suffix(tx, stack_values, pass1) 1128 1129 # one more time now that we have established tempvars 1130 pass2 = PyCodegen( 1131 tx, 1132 root, 1133 graph_output_var, 1134 tempvars={val: None for val, count in pass1.uses.items() if count > 1}, 1135 ) 1136 self.codegen_suffix(tx, stack_values, pass2) 1137 1138 stored_graph_output_var = False 1139 output = [] 1140 if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0: 1141 output.extend( 1142 self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root) 1143 ) 1144 1145 if len(pass2.graph_outputs) != 0: 1146 output.append(pass2.create_store(graph_output_var)) 1147 stored_graph_output_var = True 1148 else: 1149 output.append(create_instruction("POP_TOP")) 1150 else: 1151 # NB: Important to run compiler collective even when there is 1152 # a graph break 1153 self.run_compiler_collective(tx) 1154 append_prefix_insts() 1155 self.add_output_instructions(output + pass2.get_instructions()) 1156 1157 # restore all the live local vars 1158 self.add_output_instructions( 1159 [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)] 1160 ) 1161 1162 if stored_graph_output_var: 1163 self.add_output_instructions( 1164 [PyCodegen(tx).create_delete(graph_output_var)] 1165 ) 1166 1167 def codegen_suffix(self, tx, stack_values, cg): 1168 if self.backward_state: 1169 assert not self.export 1170 for name, val in self.backward_state.items(): 1171 cg(val) 1172 cg.append_output(cg.create_load(self.backward_state_var)) 1173 cg.store_attr(name) 1174 self.side_effects.codegen_hooks(cg) 1175 self.side_effects.codegen_save_tempvars(cg) 1176 1177 # Return variables used for logging at the end 1178 for debug_var, args in tx.debug_locals: 1179 cg.add_push_null(lambda: cg(debug_var)) 1180 for arg in args: 1181 cg(arg) 1182 
cg.extend_output(create_call_function(len(args), False)) 1183 cg.extend_output([create_instruction("POP_TOP")]) 1184 1185 cg.restore_stack(stack_values, value_from_source=not tx.export) 1186 self.side_effects.codegen_update_mutated(cg) 1187 1188 def cleanup_graph(self): 1189 """ 1190 Remove "creation_timestamp" from node meta 1191 1192 Remove this pattern from the graph: 1193 torch._C._set_grad_enabled(False) 1194 torch._C._set_grad_enabled(True) 1195 """ 1196 assert self.should_exit 1197 nodes = list(self.graph.nodes) 1198 for node in nodes: 1199 node.meta.pop("creation_timestamp", None) 1200 1201 grad_enabled = torch.is_grad_enabled() 1202 for node1, node2 in zip(nodes, nodes[1:]): 1203 if ( 1204 node1.target is torch._C._set_grad_enabled 1205 and tuple(node1.args) == (not grad_enabled,) 1206 and not node1._erased 1207 ): 1208 grad_enabled = node1.args[0] 1209 if ( 1210 node2.target is torch._C._set_grad_enabled 1211 and tuple(node2.args) == (not grad_enabled,) 1212 and not node2._erased 1213 ): 1214 grad_enabled = node2.args[0] 1215 self.graph.erase_node(node1) 1216 self.graph.erase_node(node2) 1217 1218 def get_graph_sizes_structured(self): 1219 ret = {} 1220 for node in self.graph.nodes: 1221 example_value = node.meta.get("example_value", None) 1222 if isinstance(example_value, torch._subclasses.FakeTensor): 1223 size = example_value.size() 1224 ret[node.name] = [s if isinstance(s, int) else repr(s) for s in size] 1225 return ret 1226 1227 def get_graph_sizes(self, name: str): 1228 graph_sizes_str = "TRACED GRAPH TENSOR SIZES\n" 1229 graph_sizes_str += f"===== {name} =====\n" 1230 for node in self.graph.nodes: 1231 example_value = node.meta.get("example_value", None) 1232 if isinstance(example_value, torch._subclasses.FakeTensor): 1233 size = example_value.size() 1234 graph_sizes_str += f"{node.name}: {tuple(size)}\n" 1235 concrete_size = [] 1236 has_symint = False 1237 for sz in size: 1238 if isinstance(sz, int): 1239 concrete_size.append(sz) 1240 elif 
isinstance(sz, torch.SymInt): 1241 has_symint = True 1242 concrete_size.append(sz.node.hint) 1243 else: 1244 break 1245 else: 1246 if has_symint: 1247 graph_sizes_str += ( 1248 f"{node.name} (concrete): {tuple(concrete_size)}\n" 1249 ) 1250 return graph_sizes_str 1251 1252 @contextlib.contextmanager 1253 def restore_global_state(self): 1254 """ 1255 Momentarily restores the global state to what it was prior to tracing the current output 1256 """ 1257 prior_global_state = self.tracing_context.global_context.copy_graphstate() 1258 current_global_state: Dict[str, Tuple[Any, bool]] = {} 1259 self.save_global_state(out=current_global_state) 1260 try: 1261 # Set to state prior to tracing the graph 1262 self.tracing_context.global_context.restore_graphstate(prior_global_state) 1263 yield 1264 finally: 1265 # Reset to state at the current time (e.g. before calling the user compiler) 1266 self.tracing_context.global_context.restore_graphstate( 1267 GlobalContextCheckpointState(current_global_state) 1268 ) 1269 1270 def run_compiler_collective(self, tx): 1271 if (ds := tx.distributed_state) is not None and ds.all_states is None: 1272 compile_pg = ds.compile_pg 1273 log.info("compiler_collective %s", ds.local_state) 1274 torch._logging.trace_structured( 1275 "artifact", 1276 metadata_fn=lambda: { 1277 "name": "compiler_collective", 1278 "encoding": "json", 1279 }, 1280 payload_fn=lambda: json.dumps( 1281 dataclasses.asdict(ds.local_state), 1282 ), 1283 ) 1284 with torch.cuda.device(compile_pg.rank() % torch.cuda.device_count()): 1285 all_states = [None] * compile_pg.size() 1286 dist.all_gather_object(all_states, ds.local_state, group=compile_pg) 1287 ds.all_states = all_states 1288 # Clear speculation log, because are tracing may diverge due to 1289 # this information from the compiler collective 1290 tx.speculation_log.clear() 1291 raise exc.CompileCollectiveRestartAnalysis 1292 1293 def compile_and_call_fx_graph(self, tx, rv, root): 1294 """ 1295 Generate code from 
self.graph and return the Instruction()s to 1296 call that generated code. 1297 """ 1298 with torch._guards.TracingContext.clear_frame(): 1299 from .decorators import disable 1300 1301 assert self.should_exit 1302 1303 self.run_compiler_collective(tx) 1304 1305 name = unique_id("__compiled_fn") 1306 1307 assert isinstance(rv, list) 1308 assert isinstance(root, FakeRootModule) 1309 output_node = self.create_node( 1310 "output", 1311 "output", 1312 (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),), 1313 {}, 1314 ) 1315 tx.output.current_tracer._maybe_preserve_original_meta(tx, output_node) 1316 if not config.do_not_emit_runtime_asserts: 1317 insert_deferred_runtime_asserts( 1318 fx.GraphModule(root, self.graph), 1319 self.shape_env, 1320 name, 1321 ) 1322 # NB: deferred runtime asserts can keep graphargs live, so make sure 1323 # those are inserted before pruning 1324 self.remove_unused_graphargs() 1325 ncalls = count_calls(self.graph) 1326 counters["stats"]["calls_captured"] += ncalls 1327 1328 # free a bit of memory 1329 self.real_value_cache.clear() 1330 1331 gm = _make_graph_module(root, self.graph) 1332 for register_finalizer in self.register_finalizer_fns: 1333 register_finalizer(gm) 1334 1335 gm.compile_subgraph_reason = self.compile_subgraph_reason 1336 gm.meta[ 1337 "dynamo_flat_name_to_original_fqn" 1338 ] = self.dynamo_flat_name_to_original_fqn.copy() 1339 1340 graph_code_log.debug( 1341 "%s", 1342 lazy_format_graph_code( 1343 name, gm, include_stride=True, include_device=True, colored=True 1344 ), 1345 ) 1346 torch._logging.trace_structured( 1347 "dynamo_output_graph", 1348 lambda: {"sizes": self.get_graph_sizes_structured()}, 1349 payload_fn=lambda: gm.print_readable( 1350 print_output=False, include_stride=True, include_device=True 1351 ), 1352 ) 1353 self.call_cleanup_hooks() 1354 old_fake_mode = self.tracing_context.fake_mode 1355 if not self.export: 1356 import torch._functorch.config as _config 1357 1358 with 
_config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): 1359 # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting 1360 backend_fake_mode = torch._subclasses.FakeTensorMode( 1361 shape_env=old_fake_mode.shape_env, 1362 ) 1363 # TODO(voz): Ostensibily, this should be scoped and 1364 # restore back to old_fake_mode, but doing so currently violates 1365 # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode 1366 self.tracing_context.fake_mode = backend_fake_mode 1367 1368 with self.restore_global_state(): 1369 compiled_fn = self.call_user_compiler(gm) 1370 1371 from torch.fx._lazy_graph_module import _LazyGraphModule 1372 1373 if isinstance(compiled_fn, _LazyGraphModule) or ( 1374 isinstance(getattr(compiled_fn, "__self__", None), _LazyGraphModule) 1375 and compiled_fn.__name__ == "_lazy_forward" # type: ignore[attr-defined] 1376 ): 1377 # Since dynamo will run the forward method for the GraphModule shortly 1378 # anyways, it does not hurt to do the real recompilation here if 1379 # this is a _LazyGraphModule. This makes it easier for dynamo to 1380 # optimize a _LazyGraphModule. 
1381 1382 lazy_gm = ( 1383 compiled_fn 1384 if isinstance(compiled_fn, _LazyGraphModule) 1385 else compiled_fn.__self__ # type: ignore[attr-defined] 1386 ) 1387 1388 _LazyGraphModule.force_recompile(lazy_gm) 1389 1390 if not isinstance(compiled_fn, _LazyGraphModule): 1391 # replace compiled_fn with the real forward method 1392 compiled_fn = lazy_gm.forward 1393 1394 compiled_fn = disable(compiled_fn) 1395 1396 counters["stats"]["unique_graphs"] += 1 1397 # This is safe because we pre-process name to be unique 1398 self.install_global_unsafe(name, compiled_fn) 1399 1400 cg = PyCodegen(tx) 1401 cg.make_call_generated_code(name) 1402 return cg.get_instructions() 1403 1404 @property 1405 def placeholders(self) -> List[fx.Node]: 1406 return self.graph.find_nodes(op="placeholder") 1407 1408 @property 1409 def graphargs(self) -> List[GraphArg]: 1410 return [node.meta["grapharg"] for node in self.placeholders] 1411 1412 def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn: 1413 with dynamo_timed( 1414 "OutputGraph.call_user_compiler", phase_name="backend_compile" 1415 ): 1416 return self._call_user_compiler(gm) 1417 1418 def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn: 1419 assert self.compiler_fn is not None 1420 tot = 0 1421 placeholders = [] 1422 for node in gm.graph.nodes: 1423 if node.op in ("call_function", "call_method", "call_module"): 1424 tot += 1 1425 if node.op == "placeholder": 1426 placeholders.append(node) 1427 increment_op_count(tot) 1428 for pl in placeholders: 1429 arg = pl.meta["grapharg"] 1430 # TODO: Why isn't this stored in meta :think: 1431 pl._dynamo_source = arg.source 1432 1433 gm._param_name_to_source = self.param_name_to_source # type: ignore[assignment] 1434 gm._source_to_user_stacks = self.source_to_user_stacks # type: ignore[assignment] 1435 1436 try: 1437 name = ( 1438 self.compiler_fn.__name__ 1439 if hasattr(self.compiler_fn, "__name__") 1440 else "" 1441 ) 1442 _step_logger()(logging.INFO, f"calling compiler 
function {name}") 1443 compiler_fn = self.compiler_fn 1444 if config.verify_correctness: 1445 compiler_fn = WrapperBackend(compiler_fn) 1446 compiled_fn = compiler_fn(gm, self.example_inputs()) 1447 _step_logger()(logging.INFO, f"done compiler function {name}") 1448 assert callable(compiled_fn), "compiler_fn did not return callable" 1449 except exceptions_allowed_to_be_fallback as e: 1450 if self.has_user_defined_allowed_in_graph: 1451 raise BackendCompilerFailed(self.compiler_fn, e).with_traceback( 1452 e.__traceback__ 1453 ) from None 1454 msg = ( 1455 "Backend compiler failed with a fake tensor exception at \n" 1456 f"{self.root_tx.format_frame_summary()}" 1457 "Adding a graph break." 1458 ) 1459 unimplemented_with_warning(e, self.root_tx.f_code, msg) 1460 except SkipFrame as e: 1461 # The backend compiler has requested that we skip the frame, instead of 1462 # aborting execution. 1463 raise e 1464 except Exception as e: 1465 raise BackendCompilerFailed(self.compiler_fn, e) from e 1466 1467 signpost_event( 1468 "dynamo", 1469 "OutputGraph.call_user_compiler", 1470 { 1471 **self.co_fields, 1472 "op_count": tot, 1473 "node_count": len(gm.graph.nodes), 1474 "input_count": len(placeholders), 1475 }, 1476 ) 1477 1478 return compiled_fn 1479 1480 def example_inputs(self) -> List[torch.Tensor]: 1481 result = [] 1482 for arg in self.graphargs: 1483 result.append(arg.example) 1484 return result 1485 1486 def remove_unused_graphargs(self) -> None: 1487 # NB: It's always OK to drop GraphArg for symbols that ended up being 1488 # specialized. You don't even have to make a guard for it, because 1489 # ShapeEnv produce_guards operates on tracked_fakes, which never gets 1490 # pruned. That being said, you'll get marginally better generated 1491 # guard code if you promote the guard into a Dynamo guard (since that 1492 # allows for the guard to be done using C++ guards.) If we get 1493 # ShapeEnv guards to go into C++ guards, this will stop being a thing 1494 # though! 
1495 1496 assert self.should_exit 1497 1498 # Miniature DCE pass, but only for obviously trivial operations 1499 def is_static_true(b_node: fx.node.Argument): 1500 if b_node is True: 1501 return True 1502 if not isinstance(b_node, fx.Node): 1503 return False 1504 b = b_node.meta.get("example_value") 1505 if b is None: 1506 return False 1507 if b is True: 1508 return True 1509 if ( 1510 isinstance(b, torch.SymBool) 1511 and (r := b.node.maybe_as_bool()) is not None 1512 ): 1513 return r 1514 # TODO: We can also technically remove all cases when the input 1515 # doesn't have unbacked inputs, since it's all in the ShapeEnv 1516 return False 1517 1518 def is_symnode_arg(a: fx.node.Argument): 1519 from torch.fx.experimental.sym_node import SymTypes 1520 1521 if isinstance(a, (int, float, bool)): 1522 return True 1523 if isinstance(a, fx.Node): 1524 return isinstance(a.meta.get("example_value"), SymTypes) 1525 return False 1526 1527 # NB: We assume that you cannot do mutations on int/float/bool, 1528 # because they are immutable types, and therefore is always safe to 1529 # DCE. 1530 def is_symnode_compute_node(node): 1531 from torch.fx.experimental.sym_node import SymTypes 1532 1533 if node.op != "call_function": 1534 return False 1535 # TODO: I don't think it's possible to have a bare int/float here? 
1536 if not isinstance(node.meta.get("example_value"), SymTypes): 1537 return False 1538 # TODO: This will bail here if you ever end up with a more complicated 1539 # computation function, like sum(list_of_ints), even though it 1540 # should be DCE'able 1541 if not all(is_symnode_arg(a) for a in node.args): 1542 return False 1543 if not all(is_symnode_arg(a) for a in node.kwargs.values()): 1544 return False 1545 return True 1546 1547 from torch.fx.experimental.symbolic_shapes import is_accessor_node 1548 1549 for node in reversed(list(self.graph.nodes)): 1550 if len(list(node.users)) == 0: 1551 if ( 1552 node.op == "get_attr" 1553 or (node.op == "call_function" and node.target is operator.getitem) 1554 or ( 1555 node.op == "call_function" 1556 and node.target is torch._check 1557 and is_static_true(node.args[0]) 1558 ) 1559 or is_symnode_compute_node(node) 1560 or is_accessor_node(node) 1561 ): 1562 self.remove_node(node) 1563 1564 def placeholder_binds_symbol(node): 1565 arg = node.meta["grapharg"] 1566 example = arg.example 1567 if isinstance(example, torch.SymInt) and isinstance( 1568 example.node.expr, sympy.Symbol 1569 ): 1570 return example.node.expr 1571 return None 1572 1573 def remove_unused(node): 1574 log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name()) 1575 # I'm not really sure why you need to delete these from the 1576 # node since the node is going to get removed 1577 del node.meta["grapharg"] 1578 self.remove_node(node) 1579 self.real_value_cache.pop(node, None) 1580 1581 used_symbols: Set[sympy.Symbol] = set() 1582 1583 def update_used_symbols(used_symbols, fake: Union[torch.SymInt, torch.Tensor]): 1584 used_symbols |= free_symbols(fake) 1585 1586 recheck_placeholders = [] 1587 for node in self.placeholders: 1588 binds_symbol = placeholder_binds_symbol(node) is not None 1589 # Don't delete symbol bindings yet 1590 if binds_symbol: 1591 if not node.users: 1592 recheck_placeholders.append(node) 1593 else: 1594 if not node.users 
and not isinstance( 1595 node.meta["grapharg"], BackwardStateGraphArg 1596 ): 1597 remove_unused(node) 1598 else: 1599 # Register the free symbols as uses 1600 arg = node.meta["grapharg"] 1601 if isinstance(arg, BackwardStateGraphArg): 1602 continue 1603 if isinstance(node.meta["grapharg"].example, torch.ScriptObject): 1604 real_script_obj = node.meta["grapharg"].example 1605 fake_script_obj = node.meta["grapharg"].example_strong_ref 1606 if not torch._library.fake_class_registry.tracing_with_real( 1607 real_script_obj 1608 ): 1609 flat_dict = dict(real_script_obj.__obj_flatten__()) # type: ignore[attr-defined] 1610 for attr in flat_dict.keys(): 1611 fake_attr_val = getattr( 1612 fake_script_obj.wrapped_obj, attr 1613 ) 1614 pytree.tree_map_only( 1615 (torch.SymInt, torch.Tensor), 1616 lambda t: update_used_symbols(used_symbols, t), 1617 fake_attr_val, 1618 ) 1619 continue 1620 fake = ( 1621 arg.fake_tensor if arg.fake_tensor is not None else arg.example 1622 ) 1623 update_used_symbols(used_symbols, fake) 1624 1625 # After removing unused graphargs, prune unused binds_symbol 1626 for node in recheck_placeholders: 1627 symbol = placeholder_binds_symbol(node) 1628 if symbol is not None: 1629 if symbol not in used_symbols: 1630 remove_unused(node) 1631 else: 1632 # Make sure we delete later occurrences of the same symbol 1633 used_symbols.remove(symbol) 1634 1635 def add_output_instructions(self, prefix: List[Instruction]) -> None: 1636 """ 1637 We call this on the creation of a new compiled subgraph that is inserted 1638 before user code. 1639 """ 1640 self.output_instructions.extend(prefix) 1641 self.should_exit = True 1642 1643 def install_global_unsafe(self, name, value) -> None: 1644 """ 1645 WARNING: prefer the safer `install_global_by_id/install_global`. 1646 torch.compile instances should be independent of each other; 1647 one footgun is to have one instance depend on the existence of 1648 a global installed by another instance. 
This can happen if we mangle 1649 a global the same way across both instances. 1650 """ 1651 assert name not in self.installed_globals 1652 self.installed_globals.add(name) 1653 self.cleanups.append(CleanupHook.create(self.global_scope, name, value)) 1654 1655 def install_global_by_id(self, prefix, value) -> str: 1656 """ 1657 Installs a global if it hasn't been installed already. 1658 This is determined by (prefix, id(value)) pair. 1659 1660 Returns the name of the newly installed global. 1661 """ 1662 # NB: need self.compile_id to distinguish this global 1663 # from another global created in a different torch.compile instance 1664 name = f"{prefix}_{id(value)}_c{self.compile_id}" 1665 if name in self.installed_globals: 1666 return name 1667 self.install_global_unsafe(name, value) 1668 return name 1669 1670 def install_global(self, prefix, value) -> str: 1671 """ 1672 Installs a global, generating a unique name for it. 1673 1674 Returns the name of the newly installed global. 1675 """ 1676 # NB: unique_id is unique, even across torch.compile instances 1677 name = unique_id(prefix) 1678 self.install_global_unsafe(name, value) 1679 return name 1680 1681 def cleanup(self) -> None: 1682 # There is a reference cycle between tracer and OutputGraph, causing 1683 # some of the tensor objects to be held alive for longer than necessary. 
1684 self.root_tx = None 1685 self.nn_modules.clear() 1686 self.param_name_to_source = None 1687 1688 for node in self.graph.nodes: 1689 if "grapharg" in node.meta: 1690 del node.meta["grapharg"] 1691 self.real_value_cache.clear() 1692 self.input_name_to_proxy.clear() 1693 self.side_effects.clear() 1694 self.variable_tracker_cache.clear() 1695 self.register_finalizer_fns.clear() 1696 self.dynamo_flat_name_to_original_fqn.clear() 1697 self.tracing_context.clear() 1698 1699 def set_torch_function_state(self, enabled: bool) -> None: 1700 self.torch_function_enabled = enabled 1701 1702 def add_graph_finalizer( 1703 self, register_finalizer: Callable[[fx.GraphModule], None] 1704 ) -> None: 1705 self.register_finalizer_fns.append(register_finalizer) 1706 1707 def example_value_from_input_node(self, node: torch.fx.Node): 1708 """Extract the non-fake example tensor""" 1709 if node.op == "placeholder": 1710 return node.meta["grapharg"].example 1711 assert node.op == "get_attr" 1712 return self.nn_modules[node.target] # type: ignore[index] 1713 1714 1715err_epilogue = ( 1716 "With the current config, we will graph break " 1717 "(and fall back to eager-mode PyTorch) on all ops " 1718 "that have do not have the 'pt2_compliant_tag'. 
" 1719 "Please see the following doc for how to mark this op as PT2 compliant " 1720 "https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html" 1721) 1722 1723 1724def check_pt2_compliant_op(output_graph, kind, target, args, kwargs): 1725 if kind != "call_function": 1726 return 1727 1728 def encountered_compliant_op(target): 1729 if target.namespace in {"prim", "prims", "aten"}: 1730 return 1731 output_graph.compliant_custom_ops.add(target) 1732 1733 def encountered_non_compliant_op(target, msg): 1734 output_graph.non_compliant_ops.add(target) 1735 if config.only_allow_pt2_compliant_ops: 1736 unimplemented(msg + " " + err_epilogue) 1737 1738 if isinstance(target, torch._ops.OpOverload): 1739 if torch.Tag.pt2_compliant_tag in target.tags: 1740 encountered_compliant_op(target) 1741 return 1742 encountered_non_compliant_op( 1743 target, 1744 f"Encountered the torch.ops.OpOverload {target} " 1745 f"that is not PT2 compliant.", 1746 ) 1747 return 1748 1749 if isinstance(target, torch._ops.OpOverloadPacket): 1750 overloads = tuple(target.overloads()) 1751 # Optimization: Overload resolution is expensive. 1752 # If there's only one overload, we know what it will resolve to. 1753 if len(overloads) == 1: 1754 op = getattr(target, overloads[0]) 1755 if torch.Tag.pt2_compliant_tag in op.tags: 1756 encountered_compliant_op(op) 1757 return 1758 encountered_non_compliant_op( 1759 op, 1760 f"Encountered the non-overloaded " 1761 f"torch.ops.OpOverloadPacket {target} " 1762 f"that is not PT2 compliant. 
", 1763 ) 1764 return 1765 1766 args, kwargs = torch._dynamo.utils.get_fake_values_from_nodes( 1767 output_graph.current_tx, (args, kwargs), False 1768 ) 1769 try: 1770 overload = torch._C._jit_resolve_packet( 1771 target._qualified_op_name, *args, **kwargs 1772 ) 1773 except RuntimeError as e: 1774 unimplemented(str(e)) 1775 1776 op = getattr(target, overload) 1777 if torch.Tag.pt2_compliant_tag in op.tags: 1778 encountered_compliant_op(op) 1779 else: 1780 encountered_non_compliant_op( 1781 op, 1782 f"Encountered the torch.ops.OpOverloadPacket {target} " 1783 f"which resolves to the overload ({overload}) that is " 1784 f"not PT2 compliant.", 1785 ) 1786 1787 1788_compile_id_counter = itertools.count() 1789 1790 1791class SubgraphTracer(fx.Tracer): 1792 """ 1793 Holds an FX graph that is being traced. OutputGraph owns a SubgraphTracer 1794 and the separation of responsibilities is that SubgraphTracer is 1795 responsible for building the graph while OutputGraph is responsible for 1796 compiling and executing the graph. 1797 """ 1798 1799 def __init__( 1800 self, output_graph, parent=None, export_root=False, source_target=None 1801 ): 1802 super().__init__() 1803 self.output_graph = weakref.proxy(output_graph) 1804 self.graph = torch.fx.Graph() 1805 1806 # The export is only ever set for the ROOT tracer. It controls 1807 # whether or not certain inputs are allowed to be added or not. 1808 # Look at call sites of create_graph_input to see how it is used. 1809 if export_root: 1810 assert parent is None 1811 self.export_root = export_root 1812 # Map from graph input name to its placeholder proxy object, where the 1813 # map's keys give all current placeholder node names and can be used to 1814 # create unique node names 1815 self.input_name_to_proxy: Dict[str, fx.Proxy] = {} 1816 # Node => computed real value (see utils.get_real_value) 1817 self.real_value_cache: Dict[fx.Node, torch.Tensor] = {} 1818 1819 # SubgraphTracers can be nested. 
See NOTE [HigherOrderOperator tracing design] 1820 self.parent = parent 1821 # A dict mapping previously free variables (Proxy objects) 1822 # to new Proxy objects that wrap inputs to this subgraph. 1823 # 1824 # This dict serves two purposes: 1825 # - Proxies are associated with VariableTrackers. If we see 1826 # the same VariableTracker twice (and it is a free variable), 1827 # then we want to use the same Proxy in the current subgraph to 1828 # record the tracing. 1829 # - If we are tracing a HigherOrderOperator's body_fn, then we 1830 # need to keep track of what free variables were lifted so we can 1831 # rewrite the HigherOrderOperator call using the traced body_fn. 1832 # Dicts maintain the order of args for the HigherOrderOperator call. 1833 self.lifted_freevars = {} 1834 self.prev_inst = None 1835 1836 self._cur_code = None 1837 self._orig_gm_meta = None 1838 self._orig_gm_lineno_map = None 1839 self._orig_gm_firstlineno = None 1840 # Each SubgraphTracer is associated with a source target, which indicates 1841 # which operator this subgraph is attached to. We compute a source_fn_stack 1842 # based on the source target. For the root tracer, it's set to []. 1843 # This is useful for debugging and transforming the exported graph. 
1844 if self.parent is None: 1845 self.source_fn_stack = [] 1846 else: 1847 self.source_fn_stack = self.parent.source_fn_stack + [ 1848 (self.graph._target_to_str(source_target), source_target) 1849 ] 1850 1851 # preserve original meta if it is available 1852 def _maybe_preserve_original_meta(self, tx, node): 1853 if ( 1854 self._orig_gm_meta 1855 and self._orig_gm_lineno_map 1856 and self._orig_gm_firstlineno 1857 ): 1858 lineno = tx.current_instruction.starts_line 1859 node_idx = None 1860 if lineno is not None: 1861 node_idx = self._orig_gm_lineno_map.get( 1862 lineno - self._orig_gm_firstlineno, None 1863 ) 1864 if node_idx is not None: 1865 meta = self._orig_gm_meta[node_idx] 1866 for field in fx.proxy._COPY_META_FIELDS: 1867 if field in meta: 1868 node.meta[field] = meta[field] 1869 if "stack_trace" in meta: 1870 node.meta["stack_trace"] = meta["stack_trace"] 1871 1872 def create_proxy( 1873 self, 1874 kind, 1875 target, 1876 args, 1877 kwargs, 1878 name=None, 1879 type_expr=None, 1880 proxy_factory_fn=None, 1881 ): 1882 # NOTE: [Nested SubgraphTracer and free_variable handling] 1883 # -------------------------------------------------------- 1884 # Read NOTE [HigherOrderOperator tracing design] first. 1885 # 1886 # Let's say we're in the middle of introspecting the body of a possibly 1887 # nested HigherOrderOperator, and we see a free variable. 1888 # 1889 # There are two cases: 1890 # 1. We see a free variable that is already tracked by Dynamo. 1891 # 2. We see a free variable that has not been tracked by Dynamo 1892 # 1893 # In case 1, we call `maybe_lift_tracked_freevar_to_input` (below) 1894 # which will lift the freevar to be an input of this subgraph 1895 # and also recursively lift it to be an input on the parent(s). 1896 # 1897 # In case 2, before the call to `create_proxy`, the InstructionTranslator 1898 # will see the freevar when it gets loaded by Python bytecode. 1899 # E.g. 
for Python 3.11 the bytecodes that may do this are LOAD_DEREF or 1900 # LOAD_GLOBAL. 1901 # There, the InstructionTranslator asks Dynamo to begin tracking the 1902 # freevar by building a new Variable. 1903 # Building a new Variable automatically lifts the freevar to be an 1904 # input of the root SubgraphTracer. 1905 # 1906 # The implications for the code below are: 1907 # - We will always be in Case 1 when we get to this code. 1908 # - Any "free variable" we encounter here is guaranteed to already be 1909 # bound, that is, it is either a graph input of the root graph, or 1910 # some local variable of the root graph or a subgraph. 1911 # - The additional work we need to do here is *only* that we need to 1912 # lift this free variable into inputs (recursively) of each nested 1913 # higher-order-op subgraph until we hit the subgraph where the free 1914 # variable is bound 1915 if self.parent is not None: 1916 flat_args, tree_spec = pytree.tree_flatten((args, kwargs)) 1917 new_flat_args = [] 1918 for arg in flat_args: 1919 maybe_new_arg = self.maybe_lift_tracked_freevar_to_input(arg) 1920 new_flat_args.append(maybe_new_arg) 1921 1922 args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec) 1923 1924 rv = super().create_proxy( 1925 kind, target, args, kwargs, name, type_expr, proxy_factory_fn 1926 ) 1927 1928 # append stack trace to fx node 1929 tx = self.output_graph.current_tx 1930 1931 # log detailed location of line of code in 3.11 1932 if sys.version_info >= (3, 11) and kind in ( 1933 "call_function", 1934 "call_method", 1935 "call_module", 1936 ): 1937 cur_inst = tx.current_instruction 1938 if ( 1939 cur_inst is not self.prev_inst 1940 and cur_inst.positions is not None 1941 and cur_inst.positions.lineno is not None 1942 ): 1943 tx_code = tx.f_code 1944 header = tx.get_line_of_code_header(lineno=cur_inst.positions.lineno) 1945 1946 def get_trace_call_log_str(): 1947 line = get_instruction_source_311(tx_code, cur_inst).rstrip() 1948 return f"TRACE FX call 
{rv.node.name} from {header}\n{line}" 1949 1950 trace_call_log.debug("%s", LazyString(get_trace_call_log_str)) 1951 self.prev_inst = cur_inst 1952 1953 # update reference to original meta if we're tracing a new code object 1954 is_retracing = False 1955 if tx.f_code is not self._cur_code: 1956 orig_graphmodule_maybe = code_context.get_context(tx.f_code).get( 1957 "orig_graphmodule", lambda: None 1958 )() 1959 if isinstance(orig_graphmodule_maybe, torch.fx.GraphModule): 1960 is_retracing = True 1961 self._orig_gm_meta = [ 1962 nd.meta for nd in orig_graphmodule_maybe.graph.nodes 1963 ] 1964 self._orig_gm_lineno_map = orig_graphmodule_maybe._lineno_map 1965 self._orig_gm_firstlineno = ( 1966 orig_graphmodule_maybe.forward.__code__.co_firstlineno 1967 ) 1968 else: 1969 self._orig_gm_meta = None 1970 self._orig_gm_lineno_map = None 1971 self._orig_gm_firstlineno = None 1972 nn_module_stack = tx.nn_module_stack 1973 if nn_module_stack: 1974 rv.node.meta["nn_module_stack"] = nn_module_stack.copy() 1975 1976 if kind in {"call_function", "call_method"}: 1977 rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ 1978 (rv.node.name, target) 1979 ] 1980 elif kind == "call_module": 1981 if self.parent is not None: 1982 unimplemented("Invoking an nn.Module inside HigherOrderOperator") 1983 # For modules we store the class 1984 rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ 1985 ( 1986 rv.node.name, 1987 rv.node.meta["nn_module_stack"][target][1], 1988 ) 1989 ] 1990 1991 self._maybe_preserve_original_meta(tx, rv.node) 1992 1993 if not is_retracing: 1994 if "nn_module_stack" not in rv.node.meta: 1995 nn_module_stack = tx.nn_module_stack 1996 if nn_module_stack: 1997 rv.node.meta["nn_module_stack"] = nn_module_stack.copy() 1998 1999 if "source_fn_stack" not in rv.node.meta: 2000 if kind in {"call_function", "call_method"}: 2001 rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ 2002 (rv.node.name, target) 2003 ] 2004 elif kind == "call_module": 2005 if 
self.parent is not None: 2006 unimplemented( 2007 "Invoking an nn.Module inside HigherOrderOperator" 2008 ) 2009 # For modules we store the class 2010 rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ 2011 ( 2012 rv.node.name, 2013 rv.node.meta["nn_module_stack"][target][1], 2014 ) 2015 ] 2016 2017 if "stack_trace" not in rv.node.meta: 2018 frame_summaries: List[traceback.FrameSummary] = [] 2019 while tx: 2020 # Avoid frame summaries from inside the torch/nn/modules. This ensures that we keep the stack trace of 2021 # the user code. 2022 if not tx.is_co_filename_from_nn_modules(): 2023 frame_summaries.append(tx.frame_summary()) 2024 tx = getattr(tx, "parent", None) 2025 # Reverse the frame_summaries, such that the innermost frame is at the last 2026 frame_summaries.reverse() 2027 2028 # official from_list stub doesn't have new-style type 2029 msgs = traceback.StackSummary.from_list(frame_summaries).format() 2030 rv.node.stack_trace = "".join(msgs) 2031 2032 return rv 2033 2034 def create_node( 2035 self, op, target, args=None, kwargs=None, name=None, type_expr=None 2036 ): 2037 check_pt2_compliant_op(self.output_graph, op, target, args, kwargs) 2038 if self.parent is not None: 2039 flat_args = pytree.arg_tree_leaves(*args, **kwargs) 2040 for arg in flat_args: 2041 if not isinstance(arg, torch.fx.Node): 2042 continue 2043 assert ( 2044 arg.graph == self.graph 2045 ), "create_node using arg not from this SubgraphTracer" 2046 2047 node = super().create_node(op, target, args, kwargs, name, type_expr) 2048 node.meta["creation_timestamp"] = self.output_graph.timestamp 2049 return node 2050 2051 # Note: we did not override erase_node since 2052 # we call self.graph.erase_node elsewhere 2053 def remove_node(self, node): 2054 if len(node.users) > 0: 2055 user_graph_nodes: List[torch.fx.Node] = [] 2056 for user in node.users.keys(): 2057 # For the case where user.graph == self.graph, that is a real bug and will raise 2058 # properly. 
2059 if user.graph != self.graph: 2060 # This is a nested graph, which needs to be deleted. 2061 # If we do not do this, we will raise on attempting to remove this. 2062 # As we only get here during restoration cleanup, this is sound. 2063 user_graph_nodes.extend(reversed(list(user.graph.nodes))) 2064 for other_graph_node in user_graph_nodes: 2065 other_graph_node.graph.erase_node(other_graph_node) 2066 self.graph.erase_node(node) 2067 self.input_name_to_proxy.pop(node.name, None) 2068 2069 # when before=True, we will insert this input before the most recent 2070 # inserted proxy. This is a hack to get around an ordering problem, 2071 # where we first insert a tensor argument, and then insert bindings 2072 # for SymInts that may occur in the tensor argument. 2073 # Remove this if https://github.com/pytorch/pytorch/issues/99007 gets 2074 # fixed. 2075 def create_graph_input(self, name, type_expr=None, before=False, source=None): 2076 log.debug( 2077 "create_graph_input %s %s", 2078 name, 2079 source.name() if source is not None else "(none)", 2080 ) 2081 if source is None: 2082 assert ( 2083 self.parent is not None 2084 ), "you are required to provide a source for inputs on the root tracer" 2085 2086 # In eager, we are generally OK with adding graph inputs whenever we 2087 # want, because we take care of writing the bytecode that knows how 2088 # to source all the inputs. 2089 # 2090 # In export, this is bad, because you want a self-contained export 2091 # object which only depends on the inputs you explicitly passed to it. 
2092 # So we are a bit more strict about what sources can become inputs 2093 # in export 2094 if self.export_root: 2095 if not is_from_local_source(source, allow_cell_or_freevar=False): 2096 self.output_graph.source_to_user_stacks.setdefault(source, []).append( 2097 TracingContext.extract_stack() 2098 ) 2099 2100 # unique 2101 if name in self.input_name_to_proxy: 2102 for i in itertools.count(): 2103 candidate_name = f"{name}_{i}" 2104 if candidate_name not in self.input_name_to_proxy: 2105 name = candidate_name 2106 break 2107 2108 if self.input_name_to_proxy: 2109 prev_name = next(reversed(self.input_name_to_proxy)) 2110 node = self.input_name_to_proxy[prev_name].node 2111 if before: 2112 ctx = self.graph.inserting_before(node) 2113 else: 2114 ctx = self.graph.inserting_after(node) 2115 else: 2116 ctx = self.graph.inserting_before(None) 2117 with ctx: 2118 proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr) 2119 if self.input_name_to_proxy and before: 2120 k, v = self.input_name_to_proxy.popitem() 2121 self.input_name_to_proxy[name] = proxy 2122 self.input_name_to_proxy[k] = v 2123 else: 2124 self.input_name_to_proxy[name] = proxy 2125 return proxy 2126 2127 # See NOTE: [Nested SubgraphTracer and free_variable handling] for more details 2128 def lift_tracked_freevar_to_input(self, proxy): 2129 # You're doing something wrong if we are the root SubgraphTracer because 2130 # Dynamo adds tensors to graph inputs before creating a proxy for them. 2131 assert ( 2132 self.parent is not None 2133 ), "lift_tracked_freevar_to_input should not be called on root SubgraphTracer" 2134 # Proxys are associated with VariableTracker. 2135 # It is possible that we've already lifted the Proxy to be an input. 2136 # If that is the case, just return the already lifted Proxy. 
2137 if proxy in self.lifted_freevars: 2138 return self.lifted_freevars[proxy] 2139 new_proxy = self.create_graph_input(proxy.node.name) 2140 set_example_value(new_proxy.node, proxy.node.meta["example_value"]) 2141 self.lifted_freevars[proxy] = new_proxy 2142 if self.parent is not None and proxy.tracer != self.parent: 2143 self.parent.lift_tracked_freevar_to_input(proxy) 2144 return new_proxy 2145 2146 def maybe_lift_tracked_freevar_to_input(self, arg): 2147 """ 2148 If arg is a free variable, then lift it to be an input. 2149 Returns the new lifted arg (if arg was a freevar), else the 2150 original arg. 2151 """ 2152 if not isinstance(arg, torch.fx.Proxy): 2153 return arg 2154 elif arg.tracer == self: 2155 return arg 2156 return self.lift_tracked_freevar_to_input(arg) 2157 2158 2159# NOTE: [HigherOrderOperator tracing design] 2160# Ignoring HigherOrderOperators for a moment, 2161# OutputGraph represents the graph being built by Dynamo that may be compiled 2162# and executed. It holds a root SubgraphTracer where the FX graph is built. 2163# 2164# HigherOrderOperators are operators that take functions as their arguments. 2165# When Dynamo encounters a HigherOrderOperator, then it attempts to introspect 2166# the function passed to it (call this the "body function"), capture it into a 2167# GraphModule, and rewrite the call to the HigherOrderOperator to use the 2168# GraphModule. 2169# 2170# The way we handle the capture of body functions is through having 2171# (possibly nested) SubgraphTracers, one per body function. 2172# 2173# Mechanically, we do the introspection by: 2174# - Creating a new SubgraphTracer via OutputGraph.subtracer 2175# - Executing the body function. 2176# This constructs the graph of the body function in the new SubgraphTracer 2177# while modifying the state of the OutputGraph. 
For example: 2178# - the OutputGraph can receive new GraphArgs (if we discover any new 2179# untracked Tensors) 2180# - side effects from the body function get accumulated into 2181# OutputGraph.side_effects 2182# - guards produced by the body function get accumulated into OutputGraph.guards 2183# 2184# The traced function has some special properties that make it easier for us 2185# to transform later down the line: 2186# - we lift all free variables to being inputs. 2187# 2188# If the introspection fails (due to the existence of graph breaks), then 2189# we roll back the current OutputGraph state and graph break on the 2190# HigherOrderOperator. 2191