xref: /aosp_15_r20/external/pytorch/torch/_dynamo/output_graph.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import collections
3import contextlib
4import copy
5import dataclasses
6import functools
7import itertools
8import json
9import logging
10import operator
11import re
12import sys
13import traceback
14import weakref
15from dataclasses import dataclass
16from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
17
18import sympy
19
20import torch._guards
21import torch._logging
22import torch.distributed as dist
23import torch.nn
24import torch.utils._pytree as pytree
25from torch import fx
26from torch._guards import GlobalContextCheckpointState, Source, TracingContext
27from torch._utils_internal import signpost_event
28from torch.fx._lazy_graph_module import _make_graph_module  # type: ignore[attr-defined]
29from torch.fx.experimental._backward_state import BackwardState
30from torch.fx.experimental.symbolic_shapes import free_symbols, is_symbolic, ShapeEnv
31from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
32from torch.utils._python_dispatch import is_traceable_wrapper_subclass
33
34from . import config, exc, logging as torchdynamo_logging, variables
35from .backends.registry import CompiledFn, CompilerFn
36from .bytecode_transformation import (
37    create_call_function,
38    create_instruction,
39    Instruction,
40    unique_id,
41)
42from .code_context import code_context
43from .codegen import PyCodegen
44from .current_scope_id import enter_new_scope
45from .exc import (
46    BackendCompilerFailed,
47    exceptions_allowed_to_be_fallback,
48    SkipFrame,
49    unimplemented,
50    unimplemented_with_warning,
51)
52from .guards import GuardBuilder, install_guard
53from .mutation_guard import is_dynamic_nn_module
54from .side_effects import AttributeMutationExisting, SideEffects
55from .source import (
56    AttrSource,
57    BackwardStateSource,
58    ConstantSource,
59    GetItemSource,
60    GlobalStateSource,
61    is_constant_source,
62    is_from_local_source,
63    LocalSource,
64    ParamBufferSource,
65    ShapeEnvSource,
66    SyntheticLocalSource,
67    TensorProperty,
68    TensorPropertySource,
69)
70from .utils import (
71    _extract_tensor_dict,
72    checkpoint_params,
73    CleanupHook,
74    clone_inputs,
75    count_calls,
76    counters,
77    dynamo_timed,
78    get_instruction_source_311,
79    get_locals_to_steal,
80    get_static_address_type,
81    get_torch_function_mode_stack,
82    graph_break_reasons,
83    increment_op_count,
84    lazy_format_graph_code,
85    LazyString,
86    nn_module_proxy,
87    same,
88    set_example_value,
89)
90from .variables.base import VariableTracker
91from .variables.builder import (
92    BackwardStateGraphArg,
93    GraphArg,
94    TrackedFake,
95    VariableBuilder,
96    wrap_fx_proxy,
97)
98from .variables.lists import BaseListVariable
99from .variables.misc import NullVariable
100from .variables.nn_module import NNModuleVariable
101from .variables.tensor import (
102    NumpyNdarrayVariable,
103    SymNodeVariable,
104    TensorVariable,
105    UnspecializedPythonVariable,
106)
107from .variables.torch_function import TensorWithTFOverrideVariable
108
109
110if TYPE_CHECKING:
111    from torch._dynamo.symbolic_convert import InstructionTranslatorBase
112
113
114log = logging.getLogger(__name__)
115graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph")
116graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
117graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
118trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
119
120
121@dataclass(frozen=True)
122class VariableTrackerCacheKey:
123    vt_id: int
124    # Two different sources can point to the same object. However, Dynamo handles
125    # global and local sources differently when it comes to guards (and possibly
126    # some other parts as well), so the cache also keys on the source.
127    source: Source
128
129
130class VariableTrackerCache:
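    # Illustrative usage (a sketch with hypothetical caller-side names, not the
    # exact call sites in this file):
    #
    #   cache = VariableTrackerCache()
    #   vt = cache.lookup(obj, source)
    #   if vt is None:
    #       vt = build_variable_tracker(obj, source)
    #       cache.add(obj, source, vt)
    #
    # Keys combine id(value) with the Source, so the same object reached via a
    # global source vs. a local source gets its own tracker (and guards).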
131    def __init__(self):
132        self.cache = {}
133
134    def lookup(self, value, source):
135        key = VariableTrackerCacheKey(id(value), source)
136        if key not in self.cache:
137            return None
138        return self.cache[key]
139
140    def add(self, value, source, vt):
141        key = VariableTrackerCacheKey(id(value), source)
142        self.cache[key] = vt
143
144    def clone(self):
145        # Needed for copy and restore graph state
146        new_cache = VariableTrackerCache()
147        new_cache.cache.update(self.cache)
148        return new_cache
149
150    def clear(self):
151        self.cache.clear()
152
153
154@functools.lru_cache(None)
155def _step_logger():
156    return torchdynamo_logging.get_step_logger(log)
157
158
159@dataclass
160class GraphCompileReason:
161    """Stores why a given output graph was compiled; i.e. what caused the graph break."""
162
163    reason: str
164    user_stack: List[traceback.FrameSummary]
165
166    # Indicates whether this graph compile was caused by a graph break.
167    graph_break: bool = True
168
169    def __post_init__(self):
170        if self.graph_break:
171            graph_break_reasons.append(self)
172
173
174def _get_gen_rand_values_fn(random_calls):
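    # Given random_calls = [(fn, args, kwargs), ...] recorded during tracing,
    # return a closure that replays each call and collects the results in a
    # list.  Rough sketch of what a caller sees (illustrative values only):
    #
    #   gen = _get_gen_rand_values_fn([(random.random, (), {})])
    #   gen()  # -> [0.1234...]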
175    def _gen_rand_values():
176        return [fn(*args, **kwargs) for fn, args, kwargs in random_calls]
177
178    return _gen_rand_values
179
180
181class FakeRootModule(torch.nn.Module):
182    """Trick the constructor of fx.GraphModule"""
183
184    def __init__(self, nn_modules: Dict[str, torch.nn.Module]):
185        super().__init__()
186        for k, v in nn_modules.items():
187            setattr(self, k, v)
188
189    def __repr__(self):
190        return "FakeRootModule(...)"
191
192
193class WrapperBackend:
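    # Thin wrapper around a user compiler_fn.  When config.verify_correctness
    # is set, the candidate compiled fn is run on cloned inputs and compared
    # (via `same`) against eager gm.forward before being accepted.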
194    def __init__(self, backend: CompilerFn):
195        self.backend: CompilerFn = backend
196
197    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
198        self.restore = checkpoint_params(gm)
199        self.gm = gm
200        copy_gm = copy.deepcopy(self.gm)
201        self.candidate = self.backend(copy_gm, example_inputs)
202
203        if self.candidate is None or self.candidate is self.gm.forward:
204            return self.gm.forward
205
206        if not config.verify_correctness:
207            return self.candidate
208
209        # if verify_correctness=True
210        try:
211            correct = self.gm.forward(*clone_inputs(example_inputs))
212            result = self.candidate(*clone_inputs(example_inputs))
213
214            # TODO: replace `same` function with the one in testing
215            if same(correct, result):
216                return self.candidate
217
218            raise RuntimeError(f"incorrect results of backend {self}")
220
221        except Exception:
222            log.exception("error in verify_correctness")
223            raise
224        finally:
225            self.restore()
226
227
228Scope = Dict[str, object]
229
230
231class OutputGraph:
232    """
233    Wrapper class to hold outputs of InstructionTranslator.  Mainly the
234    generated fx.Graph.
235
236    OutputGraph is 1:1 with a frame being processed. Each frame is associated
237    with some root InstructionTranslator. When user code calls a function,
238    we construct an InliningInstructionTranslator that continues to write into
239    the root InstructionTranslator's OutputGraph.
240    """
241
242    def __init__(
243        self,
244        code_options: Dict[str, Any],
245        compiler_fn: Optional[CompilerFn],
246        root_tx,
247        export: bool,
248        export_constraints,
249        frame_state,
250        local_scope: Scope,
251        global_scope: Scope,
252        f_code,
253    ):
254        super().__init__()
255        self.tracers = [SubgraphTracer(self, export_root=export)]
256        # Map from graph input's `Source` to its `VariableTracker` to
257        # de-duplicate graph inputs by source and reuse the tracker
258        self.input_source_to_var: Dict[Source, VariableTracker] = {}
259        self.export = export
260        self.export_constraints = export_constraints
261        self.frame_state = frame_state
262        # Map from graph input's `Source` to sizes / strides metadata
263        self.input_source_to_sizes_strides: Dict[Source, Dict[str, Any]] = {}
264        self.cleanup_hooks: List[Callable[[], Any]] = []
265        # compile_id is an id number for the current torch.compile
266        self.compile_id: int = next(_compile_id_counter)
267        # Set of globals installed via install_global* APIs
268        self.installed_globals: Set[str] = set()
269
270        # TODO: maybe should just pass the entire f_code in here?  Not
271        # sure...
272        self.co_fields = {
273            "co_name": f_code.co_name,
274            "co_filename": f_code.co_filename,
275            "co_firstlineno": f_code.co_firstlineno,
276        }
277
278        # tracked_fakes says where any tensor that was wrapped to fake came
279        # from.  It is similar to GraphArg, in that all GraphArgs will get
280        # added to TrackedFakes, but TrackedFakes also contains
281        # GraphArgs that got pruned, and things like Tensor attributes which
282        # aren't explicit graph inputs.  Used by shape guards.
283        self.tracked_fakes: List[TrackedFake] = []
284
285        # List of symbols for which we have exact bindings in the arguments
286        # already
287        self.bound_symbols: Set[sympy.Symbol] = set()
288
289        shape_env = ShapeEnv(
290            # Reference Cycle!
291            # Share a reference to the list of TrackedFake.
292            #
293            # ShapeEnv needs this in order to be able to reproduce the call
294            # to produce_guards at an arbitrary time point. That is because
295            # TrackedFake instances may have their metadata changed throughout
296            # the program execution.
297            tracked_fakes=self.tracked_fakes,
298            allow_scalar_outputs=config.capture_scalar_outputs,
299            allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
300            prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
301            allow_complex_guards_as_runtime_asserts=config.allow_complex_guards_as_runtime_asserts,
302            co_fields=self.co_fields,
303        )
304
305        # In export mode, we force the shape_env to strictly disallow any constraining
306        # of the user marked dynamic dims
307        import torch._functorch.config as _config
308
309        with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
310            fake_mode = torch._subclasses.FakeTensorMode(
311                shape_env=shape_env,
312                # TODO (tmanlaibaatar) Remove this once we always lift params and buffers
313                allow_non_fake_inputs=self.export,
314                export=self.export,
315            )
316        self.tracing_context: TracingContext = TracingContext(fake_mode)
317        self.init_ambient_guards()
318
319        # Map each tensor id to a list of sources. This is necessary because
320        # tensor ids cannot be recovered from tracked fakes (in general).
321        # We use this map to interpret (i.e., check for violations of) constraints,
322        # specifically equality constraints, which have shared tensor ids in them.
323        # This map should also be generally useful, e.g., for (de)serialization.
324        self.tracked_fakes_id_to_source: Dict[
325            int, List[Source]
326        ] = collections.defaultdict(list)
327        # Stores the full fqn of a param or buffer to the relevant source.
328        self.param_name_to_source: Optional[Dict[str, Source]] = {}
329        self.side_effects = SideEffects()
330        # Cached variable trackers. This makes symbolic analysis of LOAD_GLOBAL
331        # and LOAD_ATTR for the same Python objects free.
332        self.variable_tracker_cache = VariableTrackerCache()
333        self.unique_var_id = itertools.count()
334        self.code_options = dict(code_options)
335        self.output_instructions: List[Instruction] = []
336        # used to track nodes that are added between calls of copy_graphstate
337        # and restore_graphstate
338        self.timestamp = 0
339
340        # A list of register_finalizer_fns to apply to the output graph module
341        self.register_finalizer_fns: List[Callable[[fx.GraphModule], None]] = []
342
343        # Not checkpointed
344        self.compiler_fn: Optional[CompilerFn] = compiler_fn
345        self.global_scope = global_scope
346        self.local_scope = local_scope
347        self.root_tx = root_tx
348
349        # Given a source, what are the user stacks of all locations that
350        # accessed it?
351        #
352        # For efficiency, we only populate this:
353        #   - During export, and
354        #   - If the source could potentially lead to a spurious export input
355        #
356        # Feel free to populate this more frequently if other use-cases arise,
357        # but be aware that we have to generate full stacks for each
358        # recording!
359        self.source_to_user_stacks: Dict[Source, List[traceback.StackSummary]] = {}
360
361        self._current_tx: List[InstructionTranslatorBase] = []
362        self.cleanups: List[CleanupHook] = []
363        self.should_exit = False
364        self.unspec_variable_map: Dict[str, UnspecializedPythonVariable] = {}
365
366        # Note this returns true iff TF Mode and TF Subclasses are enabled
367        self.torch_function_enabled = torch._C._is_torch_function_enabled()
368        # This returns false if TF overall (both mode and subclass) is disabled OR the TF Mode stack is empty
369        self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
370        # This records the initial torch function mode stack for guarding
371        self.torch_function_mode_stack = get_torch_function_mode_stack()
372
373        # Tracks if the output graph has a user defined allowed function in the
374        # graph. This is used later to determine if we should fall back to eager
375        # for certain exceptions. The idea is that if the user has applied
376        # allow_in_graph, they would like to see the error instead of falling
377        # back for backend errors.
378        self.has_user_defined_allowed_in_graph = False
379
380        # Tracks the set of called ops that were not tagged with "pt2_compliant_tag".
381        # This information is useful for logging.
382        self.non_compliant_ops: Set[torch._ops.OpOverload] = set()
383
384        # Tracks the set of called custom ops that were tagged with "pt2_compliant_tag".
385        # This information is useful for logging.
386        self.compliant_custom_ops: Set[torch._ops.OpOverload] = set()
387
388        # We save the global torch state here to be restored in case of graph
389        # breaks. The relevant issue is seen here
390        # https://github.com/pytorch/pytorch/pull/100570#issuecomment-1543427086
391        # where inlining of a function changes the global state (because of the
392        # presence of torch.no_grad) and there is a graph break.
393        self.save_global_state()
394
395        # Tracks the original FQNs of the constant tensors from the original graph,
396        # i.e. buffers and parameters.
397        self.dynamo_flat_name_to_original_fqn: Dict[str, str] = {}
398
399        # All calls to random() are replaced with a single call to a __gen_rand_values
400        # function that returns a list of random values, one for each original call.
401        # random_calls tracks calls to random() and random_values_var stores the name of
402        # the variable that stores __gen_rand_values results.
403        self.random_calls: List[
404            Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
405        ] = []
406        self.random_values_var = None
407
408        # Bytecode to insert right before we call the graph
409        self.pregraph_bytecode: List[Instruction] = []
410
411        # Used to pass values to backward hooks when using compiled autograd
412        self.backward_state: Dict[str, VariableTracker] = {}
413        self.backward_state_proxy: Optional[torch.fx.Proxy] = None
414        self.backward_state_var: Optional[str] = None
415
416        self.name_of_builtins_dict_key_in_fglobals: str = (
417            self.install_builtins_dict_in_fglobals()
418        )
419
420        self.guard_on_key_order: Set[str] = set()
421
422    def install_builtins_dict_in_fglobals(self):
423        # f_globals["__builtins__"] can be a dict or a module. This is an
424        # implementation detail -
425        # https://docs.python.org/3/library/builtins.html.
426
427        # This makes guarding on any builtin messy because the guard check_fn
428        # has to check if the __builtins__ is a module or dict, and then access
429        # by either using getattr or getitem respectively.
430
431        # To solve this problem, we insert a new entry in f_globals which points
432        # to the builtins __dict__ and then we guard any builtin on this dict.
433        # To avoid any collision with the pre-existing keys, we use
434        # install_global to give us a unique dict key.
435
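        # For example, after this runs a guard on a builtin like `print` can be
        # expressed against f_globals[<unique key>]["print"] regardless of
        # whether __builtins__ was originally a module or a dict.  (The exact
        # key is whatever install_global returns below.)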
436        f_builtins = self.global_scope["__builtins__"]
437        if not isinstance(f_builtins, dict):
438            f_builtins = f_builtins.__dict__
439        return self.install_global("__builtins_dict__", f_builtins)
440
441    def add_backward_state_hook(self, hook: VariableTracker, prefix="hook"):
442        name = f"{prefix}{len(self.backward_state)}"
443        assert name not in self.backward_state
444        self.backward_state[name] = hook
445        return name, self.get_backward_state_proxy()
446
447    def get_backward_state_proxy(self):
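        # Lazily create a single `dynamo_backward_state` graph input that all
        # backward-state hooks share; the matching local variable name is
        # reserved here and materialized later during codegen.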
448        if self.backward_state_proxy is None:
449            if self.export:
450                unimplemented("backward_state does not support export")
451            self.backward_state_proxy = self.root_tracer.create_graph_input(
452                "dynamo_backward_state", BackwardState, source=BackwardStateSource()
453            )
454            self.backward_state_proxy.node.meta["grapharg"] = BackwardStateGraphArg()
455            set_example_value(self.backward_state_proxy.node, BackwardState())
456            self.backward_state_var = self.new_var()
457        return self.backward_state_proxy
458
459    # This gets its own helper function so guards DEBUG logs are more informative
460    def init_ambient_guards(self):
461        # Register a SHAPE_ENV guard to make sure we set up shape guards
462        # that show up in ShapeEnv
463        self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
464
465        self.guards.add(
466            GlobalStateSource().make_guard(GuardBuilder.DETERMINISTIC_ALGORITHMS)
467        )
468
469        self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GRAD_MODE))
470
471        self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE))
472
473        self.guards.add(
474            GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE)
475        )
476
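        # If we are tracing under functorch transforms, also guard that the
        # same functorch interpreter stack is active at runtime.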
477        ci = torch._C._functorch.peek_interpreter_stack()
478        if ci is not None:
479            self.guards.add(
480                GlobalStateSource().make_guard(GuardBuilder.FUNCTORCH_STACK_MATCH)
481            )
482
483    def synthetic_graph_input(self, fn, args):
484        """
485        Call fn(*args) before the graph runs and turn the result into a graph input.
486        """
487        example_value = fn(*args)
488        varname = self.new_var()
489        cg = PyCodegen(self.root_tx)
490        cg.add_push_null(
491            lambda: cg.load_import_from(
492                fn.__module__,
493                fn.__name__,
494            )
495        )
496        cg.foreach(map(variables.ConstantVariable.create, args))
497        cg.call_function(len(args), False)
498        cg.store(varname)
499        self.pregraph_bytecode.extend(cg.get_instructions())
500        source = SyntheticLocalSource(varname)
501        result = VariableBuilder(self.root_tx, source)(example_value)
502        TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
503            source
504        )
505        return result
506
507    def add_cleanup_hook(self, fn: Callable[[], Any]):
508        self.cleanup_hooks.append(fn)
509
510    def call_cleanup_hooks(self):
511        for hook in reversed(self.cleanup_hooks):
512            hook()
513        self.cleanup_hooks.clear()
514
515    @property
516    def root_tracer(self):
517        return self.tracers[0]
518
519    @property
520    def current_tracer(self):
521        return self.tracers[-1]
522
523    def is_root_tracer(self):
524        # Helper to tell if we are inside higher-order operator tracing.
525        return len(self.tracers) == 1
526
527    @property
528    def graph(self):
529        return self.current_tracer.graph
530
531    # TODO(rzou): can delete after we refactor speculate_subgraph to use nested GraphTracer.
532    @graph.setter
533    def graph(self, value):
534        self.current_tracer.graph = value
535
536    @property
537    def input_name_to_proxy(self):
538        return self.current_tracer.input_name_to_proxy
539
540    @property
541    def real_value_cache(self):
542        return self.current_tracer.real_value_cache
543
544    # If you are here, and you're looking for create_graph_input,
545    # to avoid ambiguity, please call one of the following:
546    # - self.current_tracer.create_graph_input
547    # - self.root_tracer.create_graph_input
548    # See NOTE [HigherOrderOperator tracing design] for more context.
549
550    def create_proxy(self, *args, **kwargs):
551        return self.current_tracer.create_proxy(*args, **kwargs)
552
553    def create_node(self, *args, **kwargs):
554        return self.current_tracer.create_node(*args, **kwargs)
555
556    def remove_node(self, *args, **kwargs):
557        return self.current_tracer.remove_node(*args, **kwargs)
558
559    @contextlib.contextmanager
560    def subtracer(self, source_target, prior_tracer):
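        # Push a child SubgraphTracer for the duration of the `with` block
        # (used when tracing HigherOrderOperator bodies).  If a prior_tracer is
        # supplied we re-enter it instead of creating a new one, so its lineage
        # and already-lifted inputs are preserved.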
561        new_scope_ctx = enter_new_scope()
562        try:
563            if prior_tracer:
564                # Lineage MUST stay preserved
565                assert prior_tracer.parent is self.current_tracer
566            new_scope_ctx.__enter__()
567            tracer = (
568                prior_tracer
569                if prior_tracer
570                else SubgraphTracer(
571                    self, parent=self.current_tracer, source_target=source_target
572                )
573            )
574            self.tracers.append(tracer)
575            yield tracer
576        finally:
577            new_scope_ctx.__exit__(None, None, None)
578            self.tracers.pop()
579
580    @property
581    def output(self):
582        return self
583
584    @property
585    def fake_mode(self):
586        return self.tracing_context.fake_mode
587
588    @property
589    def shape_env(self):
590        return self.tracing_context.fake_mode.shape_env
591
592    @property
593    def guards(self) -> torch._guards.GuardsSet:
594        return self.tracing_context.guards_context.dynamo_guards
595
596    @property
597    def nn_modules(self) -> Dict[str, Any]:
598        return self.tracing_context.module_context.nn_modules
599
600    def save_global_state(self, out=None):
601        """
602        Saves to out if it is provided. Else saves to the tracing context's global_state.
603        """
604        global_state = (
605            out if out is not None else self.tracing_context.global_context.global_state
606        )
607
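        # Each entry maps a state name to (setter_fn, current_value); restoring
        # the checkpoint later replays setter_fn(current_value) to put the
        # global torch state back the way it was.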
608        # TODO - Consider having a torch level API for torch_function_state. As
609        # of now, we create a ref cycle by passing the
610        # output.set_torch_function_state to
611        # output.tracing_context.global_context.global_state. In the interim,
612        # the problem can be solved by manually setting
613        # output.tracing_context.global_context.global_state to None at cleanup.
614        global_state["torch_function_enabled"] = (
615            self.set_torch_function_state,
616            self.torch_function_enabled,
617        )
618        global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())
619
620        global_state["autocast_enabled"] = (
621            functools.partial(torch.set_autocast_enabled, "cuda"),
622            torch.is_autocast_enabled("cuda"),
623        )
624        global_state["autocast_cpu_enabled"] = (
625            functools.partial(torch.set_autocast_enabled, "cpu"),
626            torch.is_autocast_enabled("cpu"),
627        )
628        global_state["autocast_gpu_dtype"] = (
629            functools.partial(torch.set_autocast_dtype, "cuda"),
630            torch.get_autocast_dtype("cuda"),
631        )
632        global_state["autocast_cpu_dtype"] = (
633            functools.partial(torch.set_autocast_dtype, "cpu"),
634            torch.get_autocast_dtype("cpu"),
635        )
636        global_state["autocast_cache_enabled"] = (
637            torch.set_autocast_cache_enabled,
638            torch.is_autocast_cache_enabled(),
639        )
640
641    def push_tx(self, tx):
642        self._current_tx.append(tx)
643
644    def pop_tx(self):
645        return self._current_tx.pop()
646
647    @property
648    def current_tx(self):
649        return self.root_tx if not self._current_tx else self._current_tx[-1]
650
651    def add_symbol_bindings(self, arg: GraphArg):
652        # Insert implicit size vars as necessary.  With dynamic shapes, we
653        # maintain the invariant that every sizevar gets a direct SymInt input
654        # into the graph.  This means downstream graph transforms can assume
655        # every size variable is explicitly bound and accessible, instead of
656        # having to pull it out implicitly from tensors.
657
658        if self.export:
659            return
660
661        assert arg.fake_tensor is not None
662
663        def bind_symint(s, prop):
664            if not (is_symbolic(s) and isinstance(s.node.expr, sympy.Symbol)):
665                return
666            s0 = s.node.expr
667            if s0 in self.bound_symbols:
668                return
669            self.bound_symbols.add(s0)
670            log.debug("bind_symint %s %s", s, prop.name())
671            # TODO: don't readd symint if we already have it in graph
672            # (this is harmless because we do remove the unused ones later)
673            proxy = self.root_tracer.create_graph_input(
674                str(s0),
675                torch.SymInt,
676                before=True,
677                source=prop,
678            )
679            set_example_value(proxy.node, s)
680            proxy.node.meta["grapharg"] = GraphArg(
681                prop,
682                s,
683                pass_arg_as_tensor=False,
684                fake_tensor=None,
685                is_tensor=False,
686            )
687
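        # Recursively visit sizes, strides and storage offset (plus sparse
        # components and traceable wrapper-subclass inner tensors) so that
        # every backing SymInt symbol gets bound as a graph input.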
688        def handle_tensor(t, src):
689            for i, s in enumerate(t.size()):
690                bind_symint(s, TensorPropertySource(src, TensorProperty.SIZE, i))
691            if t.layout is torch.strided:
692                for i, s in enumerate(t.stride()):
693                    bind_symint(s, TensorPropertySource(src, TensorProperty.STRIDE, i))
694                bind_symint(
695                    t.storage_offset(),
696                    TensorPropertySource(src, TensorProperty.STORAGE_OFFSET),
697                )
698            elif t.layout is torch.sparse_coo:
699                handle_tensor(t._indices(), src)
700                handle_tensor(t._values(), src)
701            elif t.layout in {torch.sparse_csr, torch.sparse_bsr}:
702                handle_tensor(t.crow_indices(), src)
703                handle_tensor(t.col_indices(), src)
704            elif t.layout in {torch.sparse_csc, torch.sparse_bsc}:
705                handle_tensor(t.ccol_indices(), src)
706                handle_tensor(t.row_indices(), src)
707            if is_traceable_wrapper_subclass(t):
708                attrs, ctx = t.__tensor_flatten__()
709                for attr in attrs:
710                    inner_t = getattr(t, attr)
711                    handle_tensor(inner_t, AttrSource(src, attr))
712
713        handle_tensor(arg.fake_tensor, arg.source)
714
715    def count_calls(self):
716        return count_calls(self.graph)
717
718    def is_empty_graph(self):
719        return len(list(self.graph.nodes)) == 0
720
721    def get_submodule(self, keys):
722        assert keys
723        obj: Union[torch.nn.Module, Dict[str, torch.nn.Module]] = self.nn_modules
724        for k in keys.split("."):
725            if isinstance(obj, dict):
726                obj = obj[k]
727            else:
728                obj = getattr(obj, k)
729        return obj
730
731    def new_var(self, name="tmp"):
732        existing = set(self.code_options["co_varnames"])
733        # In the common case, this will be O(1)
734        while True:
735            var = f"{name}_{next(self.unique_var_id)}"
736            if var not in existing:
737                self.code_options["co_varnames"] += (var,)
738                return var
739
740    def update_co_names(self, name):
741        """Ensure self.code_options.co_names contains name"""
742        if name not in self.code_options["co_names"]:
743            self.code_options["co_names"] += (name,)
744
745    @staticmethod
746    def module_key_name(*names):
747        # create a new unique name
748        name = "_".join(map(str, names))
749        # Strip the guard lookup L/G access
750        name = re.sub(r"^[GL]\['?(.*?)'?\]$", r"\1", name)
751        # e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv
752        name = re.sub(r"\[(\d+)\]", r"_\g<1>", name)
753        # e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv
754        name = re.sub(r"[^a-zA-Z0-9]", "_", name)
755
756        if not name or not name[0].isalpha():
757            name = "sub" + name
758
759        return name
760
761    def register_attr_or_module(
762        self,
763        target: Union[torch.nn.Module, torch.Tensor, Any],
764        *names,
765        **options,
766    ):
767        if is_dynamic_nn_module(target, self.root_tx.export):
768            # Instead of returning UnspecializedNNModuleVariable, call
769            # VariableBuilder so that it is tracked for mutation.
770            return VariableBuilder(self.current_tx, **options)(target)
771
772        options = dict(options)
773        assert "source" in options
774        source = options["source"]
775        assert not isinstance(source, ParamBufferSource)
776
777        if isinstance(target, torch.Tensor):
778            tracer = self.current_tracer
779            if not self.is_root_tracer():
780                # For higher order ops, we don't want to insert the get_attr in the
781                # innermost graph. Instead, we want to lift the params/buffers
782                # as inputs to the higher-order graph, and register them as
783                # get_attrs in the root tracer.
784
785                # Note that Dynamo will still call lift_tracked_freevar_to_input
786                # when these inputs are encountered for the inner graph. The
787                # only difference is what happens at the root tracer for
788                # nn.Parameters vs free inputs. The free inputs are registered
789                # as placeholders in the root graph, whereas the nn.Parameters
790                # are registered as get_attr nodes in the root graph.
791                tracer = self.root_tracer
792
793            def wrap_name(module_key):
794                assert self.param_name_to_source is not None
795                self.param_name_to_source[module_key] = source
796
797                # Check if the attr has already been registered. This can happen
798                # when two different sources point to the same tensor.
799                if target in self.root_tx.output.side_effects:
800                    return self.root_tx.output.side_effects[target]
801
802                if get_static_address_type(target) == "guarded":
803                    install_guard(source.make_guard(GuardBuilder.ID_MATCH))
804                elif not is_constant_source(source):
805                    install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH))
806
807                vt = wrap_fx_proxy(
808                    self.root_tx,
809                    tracer.create_proxy("get_attr", module_key, (), {}),
810                    example_value=target,
811                    **options,
812                )
813
814                # Track the object to avoid duplicate registration in case of
815                # different sources pointing to the same tensor object.
816                vt = self.root_tx.output.side_effects.track_object_existing(target, vt)
817
818                assert "tensor_dict" not in vt.proxy.node.meta
819                vt.proxy.node.meta["tensor_dict"] = _extract_tensor_dict(target)
820
821                return vt
822
823        elif isinstance(target, torch.nn.Module):
824            assert isinstance(target, torch.nn.Module)
825
826            if source:
827                install_guard(source.make_guard(GuardBuilder.NN_MODULE))
828
829                def wrap_name(module_key):
830                    return NNModuleVariable(type(target), module_key, target, **options)
831
832            else:
833                # This is a Dynamo-created graph module, e.g., a graph module coming
834                # from higher order ops. The NNModuleVariable tracker can't be
835                # sourceless, so let's return an UnspecializedNNModuleVariable
836                # tracker.
837                def wrap_name(module_key):
838                    return variables.UnspecializedNNModuleVariable(target, **options)
839
840        elif isinstance(target, (torch.SymInt, torch.SymFloat)):
841            # HACKY CODE REGION BEGIN
842            # WE ARE PIGGYBACKING ON EXISTING INFRA TO REGISTER ATTRS
843            # This ultimately gets written to self.nn_modules, which is unfortunate
844            # Attrs that are tensors and symints and such need to be migrated to have their
845            # own storage
846            # alas, this is like this for now
847
848            def wrap_name(module_key):
849                return SymNodeVariable.create(
850                    self,
851                    self.create_proxy("get_attr", module_key, (), {}),
852                    sym_num=target,
853                    **options,
854                )
855
856            # HACKY CODE REGION END
857        else:
858
859            def wrap_name(module_key):
860                self.output.update_co_names(module_key)
861                self.global_scope[module_key] = target
862                return VariableBuilder(self, ConstantSource(source_name=module_key))(
863                    target
864                )
865
866        for k, v in self.nn_modules.items():
867            if v is target:
868                # it already exists
869                return wrap_name(k)
870
871        name = OutputGraph.module_key_name(*names)
872
873        base = name
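        # Probe for an unused module key: try `name`, then `name_0`, `name_1`, ...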
874        for i in itertools.count():
875            if name not in self.nn_modules:
876                self.nn_modules[name] = target
877                if isinstance(target, torch.nn.Module):
878
879                    def register_leaf_name(leaf_name):
880                        assert self.param_name_to_source is not None
881                        new_source = ParamBufferSource(source, leaf_name)
882                        new_name = f"{name}.{leaf_name}"
883                        self.param_name_to_source[new_name] = new_source
884                        if isinstance(source, LocalSource):
885                            self.dynamo_flat_name_to_original_fqn[
886                                OutputGraph.module_key_name(new_source.name())
887                            ] = leaf_name
888
889                    # annoying, but there are cases when we do not have parameters
890                    # see test_nn_moduledict_contains
891                    if hasattr(target, "_parameters"):
892                        for leaf_name, _ in target.named_parameters():
893                            register_leaf_name(leaf_name)
894                    if hasattr(target, "_buffers"):
895                        for leaf_name, _ in target.named_buffers():
896                            register_leaf_name(leaf_name)
897
898                return wrap_name(name)
899            name = f"{base}_{i}"
900
901        raise AssertionError("unreachable")
902
903    def handle_aliases_for_stolen_lists(self, tx):
904        # If list inputs are stolen, but still needed after the function call, create aliases to keep them alive
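        # Illustrative example (hypothetical names): if the compiled function
        # steals `inputs` but some tracked variable still refers to `inputs[3]`,
        # we emit bytecode equivalent to `inputs_ref_0 = inputs[3]` before the
        # call and retarget that variable's source to the new local, so later
        # codegen reads the alias instead of the cleared list.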
905        maybe_gm = self.local_scope.get("self")
906        stolen_list_names = get_locals_to_steal(maybe_gm)
907        if not stolen_list_names:
908            return []
909
910        alias_insts = []
911        needs_alias: Dict[
912            str, List[Union[VariableTracker, AttributeMutationExisting]]
913        ] = {}
914
915        queue = [
916            *tx.stack,
917            *tx.symbolic_locals.values(),
918            *self.side_effects.store_attr_mutations.keys(),
919        ]
920
921        while queue:
922            x = queue.pop()
923            if isinstance(x, BaseListVariable):
924                assert isinstance(x.items, List)
925                queue += x.items
926                continue
927
928            if not (
929                isinstance(x, (VariableTracker, AttributeMutationExisting))
930                and isinstance(x.source, GetItemSource)
931                and isinstance(x.source.base, LocalSource)
932                and x.source.base.local_name in stolen_list_names
933            ):
934                continue
935
936            stolen_name = x.source.base.local_name
937            if stolen_name not in needs_alias:
938                needs_alias[stolen_name] = []
939            needs_alias[stolen_name].append(x)
940
941        visited = {}
942        for arg in self.graphargs:
943            if not (
944                isinstance(arg._example, list)
945                and isinstance(arg.source, LocalSource)
946                and arg.source.local_name in needs_alias
947            ):
948                continue
949
950            # arg is a list that will be cleared by the compiled function
951            list_name = arg.source.local_name
952            assert list_name in self.code_options["co_varnames"]
953            for x in needs_alias[list_name]:
954                list_idx = x.source.index
955                if list_idx not in visited:
956                    alias_name = self.new_var(
957                        f"{list_name}_ref"
958                    )  # self.new_var already adds unique id suffix
959
960                    visited[list_idx] = alias_name
961                    # bytecode of `alias_name = list_name[list_idx]`
962                    alias_insts.extend(
963                        [
964                            create_instruction("LOAD_FAST", argval=list_name),
965                            create_instruction("LOAD_CONST", argval=list_idx),
966                            create_instruction("BINARY_SUBSCR"),
967                            create_instruction("STORE_FAST", argval=alias_name),
968                        ]
969                    )
970
971                # operate on alias, handled by suffix codegen
972                x.source = LocalSource(visited[list_idx])
973
974        return alias_insts
975
976    def compile_subgraph(
977        self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None
978    ):
979        """
980        Generate a subgraph to continue execution on user code.
981        Automatically restore live variables.
982        """
983        assert reason is not None
984
985        from .decorators import disable
986
987        self.partial_convert = partial_convert
988        self.compile_subgraph_reason = reason
989        self.should_exit = True
990
991        log.debug("COMPILING GRAPH due to %s", reason)
992
993        if not all(block.can_restore() for block in tx.block_stack):
994            unimplemented("compile_subgraph with block_depth != 0")
995
996        prefix_insts: List[Instruction] = []
997        if sys.version_info >= (3, 11):
998            # prefix instructions (Python 3.11+)
999            for inst in tx.prefix_insts:
1000                if inst.opname == "MAKE_CELL":
1001                    prefix_insts.append(
1002                        create_instruction("MAKE_CELL", argval=inst.argval)
1003                    )
1004                elif inst.opname == "COPY_FREE_VARS":
1005                    prefix_insts.append(
1006                        create_instruction(
1007                            "COPY_FREE_VARS", arg=len(tx.code_options["co_freevars"])
1008                        )
1009                    )
1010                else:
1011                    prefix_insts.append(copy.copy(inst))
1012        assert not (
1013            self.pregraph_bytecode and self.export
1014        ), "export does not support pregraph_bytecode"
1015        prefix_insts.extend(self.pregraph_bytecode)
1016        prefix_insts.extend(self.handle_aliases_for_stolen_lists(tx))
1017
1018        def append_prefix_insts():
1019            self.add_output_instructions(prefix_insts)
1020            prefix_insts.clear()
1021
1022        for block in reversed(tx.block_stack):
1023            block.exit(tx)
1024
1025        self.cleanup_graph()
1026        tx.prune_dead_locals()
1027        stack_values = list(tx.stack)
1028
1029        # realize any unrealized tensor VTs in case they
1030        # need to be added to self.nn_modules as attributes
1031        for value in stack_values:
1032            value.realize()
1033
1034        # Use nn.Module "proxies" in the constructed GraphModule so that
1035        # the resulting GM does not hold additional strong references to the original modules.
1036        # This prevents a strong ref cycle where Dynamo created code holds on to references
1037        # to modules that also have Dynamo code cache invalidation checks.
1038        # When cache invalidation runs, the generated GM will be invalidated, which also deletes
1039        # the proxies.
1040        nn_modules_proxies = {
1041            name: nn_module_proxy(mod) for name, mod in self.nn_modules.items()
1042        }
1043        root = FakeRootModule(nn_modules_proxies)
1044        # Add all the local vars to the "stack" so they are restored at the end
1045        restore_vars = []
1046        val_to_names: Dict[VariableTracker, List[str]] = {}
1047        if stack_values:
1048            val_to_names[stack_values[-1]] = []
1049        # NB: Typically (i.e., for graph compile from RETURN_VALUE),
1050        # symbolic_locals will be empty at this point, as prune_dead_locals
1051        # will clear out all of symbolic_locals because RETURN_VALUE is the
1052        # last instruction and no more locals are used.  The fanciness here
1053        # is only needed for partial graphs.
1054        for k, v in tx.symbolic_locals.items():
1055            # Note! This explicitly uses .local_name for matching.
1056            # Failure to do so will cause spurious registrations in val_to_names.
1057            # This will in turn result in spurious variables showing up in the graph.
1058            # This was very tricky to debug. For an example, dump the graph at call_user_compiler
1059            # while running test_subgraphs.py
1060            if isinstance(v.source, LocalSource) and v.source.local_name == k:
1061                continue  # no need to restore initial state
1062            # Do not load variable if it is NULL.
1063            if sys.version_info >= (3, 12):
1064                # Continuation function will load the NULL for v.
1065                if type.__instancecheck__(NullVariable, v):
1066                    continue
1067            else:
1068                # A variable should never be NULL in < 3.12
1069                assert not type.__instancecheck__(NullVariable, v)
1070            if v not in val_to_names:
1071                val_to_names[v] = []
1072            val_to_names[v].append(k)
1073        for v in val_to_names.keys():
1074            restore_vars.extend(val_to_names[v])
1075            stack_values.extend([v] * len(val_to_names[v]))
1076
1077        # to handle random calls
1078        if len(self.random_calls) > 0:
1079            append_prefix_insts()
1080            random_calls_instructions = []
1081            self.random_values_var = self.new_var("random_values")
1082            rand_fn = disable(_get_gen_rand_values_fn(self.random_calls))
1083            rand_fn_name = self.install_global("__gen_rand_values", rand_fn)
1084            codegen = PyCodegen(tx, root)
1085            random_calls_instructions.extend(
1086                codegen.load_function_name(rand_fn_name, True)
1087            )
1088            random_calls_instructions.extend(create_call_function(0, False))
1089            random_calls_instructions.append(
1090                codegen.create_store(tx.output.random_values_var),
1091            )
1092            self.add_output_instructions(random_calls_instructions)
1093
1094        if (
1095            stack_values
1096            and all(
1097                not isinstance(
1098                    v,
1099                    (
1100                        UnspecializedPythonVariable,
1101                        NumpyNdarrayVariable,
1102                        TensorWithTFOverrideVariable,
1103                    ),
1104                )
1105                and not (isinstance(v, SymNodeVariable) and v.python_type() is float)
1106                for v in stack_values
1107            )
1108            and all(isinstance(x, TensorVariable) for x in stack_values)
1109            and len(set(stack_values)) == len(stack_values)
1110            and self.side_effects.is_empty()
1111            and len(tx.debug_locals) == 0
1112            and not self.backward_state
1113        ):
1114            append_prefix_insts()
1115            # optimization to generate better code in a common case
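            # Here every live stack value is a distinct TensorVariable and there
            # are no pending side effects, so the fx graph can return the stack
            # values directly and a single UNPACK_SEQUENCE restores them.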
1116            self.add_output_instructions(
1117                self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
1118                + [create_instruction("UNPACK_SEQUENCE", arg=len(stack_values))]
1119            )
1120            # restore all the live local vars
1121            self.add_output_instructions(
1122                [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
1123            )
1124        else:
1125            graph_output_var = self.new_var("graph_out")
1126            pass1 = PyCodegen(tx, root, graph_output_var)
1127            self.codegen_suffix(tx, stack_values, pass1)
1128
1129            # one more time now that we have established tempvars
1130            pass2 = PyCodegen(
1131                tx,
1132                root,
1133                graph_output_var,
1134                tempvars={val: None for val, count in pass1.uses.items() if count > 1},
1135            )
1136            self.codegen_suffix(tx, stack_values, pass2)
1137
1138            stored_graph_output_var = False
1139            output = []
1140            if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
1141                output.extend(
1142                    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
1143                )
1144
1145                if len(pass2.graph_outputs) != 0:
1146                    output.append(pass2.create_store(graph_output_var))
1147                    stored_graph_output_var = True
1148                else:
1149                    output.append(create_instruction("POP_TOP"))
1150            else:
1151                # NB: Important to run compiler collective even when there is
1152                # a graph break
1153                self.run_compiler_collective(tx)
1154            append_prefix_insts()
1155            self.add_output_instructions(output + pass2.get_instructions())
1156
1157            # restore all the live local vars
1158            self.add_output_instructions(
1159                [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
1160            )
1161
1162            if stored_graph_output_var:
1163                self.add_output_instructions(
1164                    [PyCodegen(tx).create_delete(graph_output_var)]
1165                )
1166
1167    def codegen_suffix(self, tx, stack_values, cg):
1168        if self.backward_state:
1169            assert not self.export
1170            for name, val in self.backward_state.items():
1171                cg(val)
1172                cg.append_output(cg.create_load(self.backward_state_var))
1173                cg.store_attr(name)
1174        self.side_effects.codegen_hooks(cg)
1175        self.side_effects.codegen_save_tempvars(cg)
1176
1177        # Return variables used for logging at the end
1178        for debug_var, args in tx.debug_locals:
1179            cg.add_push_null(lambda: cg(debug_var))
1180            for arg in args:
1181                cg(arg)
1182            cg.extend_output(create_call_function(len(args), False))
1183            cg.extend_output([create_instruction("POP_TOP")])
1184
1185        cg.restore_stack(stack_values, value_from_source=not tx.export)
1186        self.side_effects.codegen_update_mutated(cg)
1187
1188    def cleanup_graph(self):
1189        """
1190        Remove "creation_timestamp" from node meta
1191
1192        Remove this pattern from the graph:
1193            torch._C._set_grad_enabled(False)
1194            torch._C._set_grad_enabled(True)
1195        """
1196        assert self.should_exit
1197        nodes = list(self.graph.nodes)
1198        for node in nodes:
1199            node.meta.pop("creation_timestamp", None)
1200
1201        grad_enabled = torch.is_grad_enabled()
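        # Scan adjacent pairs of nodes: a redundant toggle is a
        # _set_grad_enabled(x) immediately followed by _set_grad_enabled(not x)
        # relative to the grad mode in effect at that point, and both nodes of
        # such a pair can be erased.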
1202        for node1, node2 in zip(nodes, nodes[1:]):
1203            if (
1204                node1.target is torch._C._set_grad_enabled
1205                and tuple(node1.args) == (not grad_enabled,)
1206                and not node1._erased
1207            ):
1208                grad_enabled = node1.args[0]
1209                if (
1210                    node2.target is torch._C._set_grad_enabled
1211                    and tuple(node2.args) == (not grad_enabled,)
1212                    and not node2._erased
1213                ):
1214                    grad_enabled = node2.args[0]
1215                    self.graph.erase_node(node1)
1216                    self.graph.erase_node(node2)
1217
1218    def get_graph_sizes_structured(self):
1219        ret = {}
1220        for node in self.graph.nodes:
1221            example_value = node.meta.get("example_value", None)
1222            if isinstance(example_value, torch._subclasses.FakeTensor):
1223                size = example_value.size()
1224                ret[node.name] = [s if isinstance(s, int) else repr(s) for s in size]
1225        return ret
1226
1227    def get_graph_sizes(self, name: str):
1228        graph_sizes_str = "TRACED GRAPH TENSOR SIZES\n"
1229        graph_sizes_str += f"===== {name} =====\n"
1230        for node in self.graph.nodes:
1231            example_value = node.meta.get("example_value", None)
1232            if isinstance(example_value, torch._subclasses.FakeTensor):
1233                size = example_value.size()
1234                graph_sizes_str += f"{node.name}: {tuple(size)}\n"
1235                concrete_size = []
1236                has_symint = False
1237                for sz in size:
1238                    if isinstance(sz, int):
1239                        concrete_size.append(sz)
1240                    elif isinstance(sz, torch.SymInt):
1241                        has_symint = True
1242                        concrete_size.append(sz.node.hint)
1243                    else:
1244                        break
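                # for/else: the else branch runs only when no `break` occurred,
                # i.e. every dim was an int or a SymInt with a hint.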
1245                else:
1246                    if has_symint:
1247                        graph_sizes_str += (
1248                            f"{node.name} (concrete): {tuple(concrete_size)}\n"
1249                        )
1250        return graph_sizes_str
1251
1252    @contextlib.contextmanager
1253    def restore_global_state(self):
1254        """
1255        Momentarily restores the global state to what it was prior to tracing the current output
1256        """
1257        prior_global_state = self.tracing_context.global_context.copy_graphstate()
1258        current_global_state: Dict[str, Tuple[Any, bool]] = {}
1259        self.save_global_state(out=current_global_state)
1260        try:
1261            # Set to state prior to tracing the graph
1262            self.tracing_context.global_context.restore_graphstate(prior_global_state)
1263            yield
1264        finally:
1265            # Reset to state at the current time (e.g. before calling the user compiler)
1266            self.tracing_context.global_context.restore_graphstate(
1267                GlobalContextCheckpointState(current_global_state)
1268            )
1269
1270    def run_compiler_collective(self, tx):
1271        if (ds := tx.distributed_state) is not None and ds.all_states is None:
1272            compile_pg = ds.compile_pg
1273            log.info("compiler_collective %s", ds.local_state)
1274            torch._logging.trace_structured(
1275                "artifact",
1276                metadata_fn=lambda: {
1277                    "name": "compiler_collective",
1278                    "encoding": "json",
1279                },
1280                payload_fn=lambda: json.dumps(
1281                    dataclasses.asdict(ds.local_state),
1282                ),
1283            )
1284            with torch.cuda.device(compile_pg.rank() % torch.cuda.device_count()):
1285                all_states = [None] * compile_pg.size()
1286                dist.all_gather_object(all_states, ds.local_state, group=compile_pg)
1287                ds.all_states = all_states
1288            # Clear the speculation log, because our tracing may diverge due to
1289            # the information from the compiler collective
1290            tx.speculation_log.clear()
1291            raise exc.CompileCollectiveRestartAnalysis
1292
1293    def compile_and_call_fx_graph(self, tx, rv, root):
1294        """
1295        Generate code from self.graph and return the Instruction()s to
1296        call that generated code.
1297        """
1298        with torch._guards.TracingContext.clear_frame():
1299            from .decorators import disable
1300
1301            assert self.should_exit
1302
1303            self.run_compiler_collective(tx)
1304
1305            name = unique_id("__compiled_fn")
1306
1307            assert isinstance(rv, list)
1308            assert isinstance(root, FakeRootModule)
1309            output_node = self.create_node(
1310                "output",
1311                "output",
1312                (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),),
1313                {},
1314            )
1315            tx.output.current_tracer._maybe_preserve_original_meta(tx, output_node)
1316            if not config.do_not_emit_runtime_asserts:
1317                insert_deferred_runtime_asserts(
1318                    fx.GraphModule(root, self.graph),
1319                    self.shape_env,
1320                    name,
1321                )
1322            # NB: deferred runtime asserts can keep graphargs live, so make sure
1323            # those are inserted before pruning
1324            self.remove_unused_graphargs()
1325            ncalls = count_calls(self.graph)
1326            counters["stats"]["calls_captured"] += ncalls
1327
1328            # free a bit of memory
1329            self.real_value_cache.clear()
1330
1331            gm = _make_graph_module(root, self.graph)
1332            for register_finalizer in self.register_finalizer_fns:
1333                register_finalizer(gm)
1334
1335            gm.compile_subgraph_reason = self.compile_subgraph_reason
1336            gm.meta[
1337                "dynamo_flat_name_to_original_fqn"
1338            ] = self.dynamo_flat_name_to_original_fqn.copy()
1339
1340            graph_code_log.debug(
1341                "%s",
1342                lazy_format_graph_code(
1343                    name, gm, include_stride=True, include_device=True, colored=True
1344                ),
1345            )
1346            torch._logging.trace_structured(
1347                "dynamo_output_graph",
1348                lambda: {"sizes": self.get_graph_sizes_structured()},
1349                payload_fn=lambda: gm.print_readable(
1350                    print_output=False, include_stride=True, include_device=True
1351                ),
1352            )
1353            self.call_cleanup_hooks()
1354            old_fake_mode = self.tracing_context.fake_mode
1355            if not self.export:
1356                import torch._functorch.config as _config
1357
1358                with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
1359                    # TODO(voz): The way export uses gm and fake tensors is not supported with us resetting the fake mode here
1360                    backend_fake_mode = torch._subclasses.FakeTensorMode(
1361                        shape_env=old_fake_mode.shape_env,
1362                    )
1363                # TODO(voz): Ostensibly, this should be scoped and
1364                # restored back to old_fake_mode, but doing so currently violates
1365                # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
1366                self.tracing_context.fake_mode = backend_fake_mode
1367
1368            with self.restore_global_state():
1369                compiled_fn = self.call_user_compiler(gm)
1370
1371            from torch.fx._lazy_graph_module import _LazyGraphModule
1372
1373            if isinstance(compiled_fn, _LazyGraphModule) or (
1374                isinstance(getattr(compiled_fn, "__self__", None), _LazyGraphModule)
1375                and compiled_fn.__name__ == "_lazy_forward"  # type: ignore[attr-defined]
1376            ):
1377                # Since dynamo will run the forward method for the GraphModule shortly
1378                # anyways, it does not hurt to do the real recompilation here if
1379                # this is a _LazyGraphModule. This makes it easier for dynamo to
1380                # optimize a _LazyGraphModule.
1381
1382                lazy_gm = (
1383                    compiled_fn
1384                    if isinstance(compiled_fn, _LazyGraphModule)
1385                    else compiled_fn.__self__  # type: ignore[attr-defined]
1386                )
1387
1388                _LazyGraphModule.force_recompile(lazy_gm)
1389
1390                if not isinstance(compiled_fn, _LazyGraphModule):
1391                    # replace compiled_fn with the real forward method
1392                    compiled_fn = lazy_gm.forward
1393
1394            compiled_fn = disable(compiled_fn)
1395
1396            counters["stats"]["unique_graphs"] += 1
1397            # This is safe because we pre-process name to be unique
1398            self.install_global_unsafe(name, compiled_fn)
1399
1400            cg = PyCodegen(tx)
1401            cg.make_call_generated_code(name)
1402            return cg.get_instructions()
1403
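    # Roughly speaking (illustrative only), compile_and_call_fx_graph turns the
    # traced graph into a callable, installs it under a unique global name, and
    # returns instructions that call it, as if the rewritten frame contained:
    #
    #     __compiled_fn_0 = disable(backend(gm, example_inputs))  # installed global
    #     ...
    #     result = __compiled_fn_0(*graph_args)                   # emitted call
    #
    # where the "__compiled_fn_<n>" name comes from unique_id() and the call is
    # generated by PyCodegen.make_call_generated_code(name).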
1404    @property
1405    def placeholders(self) -> List[fx.Node]:
1406        return self.graph.find_nodes(op="placeholder")
1407
1408    @property
1409    def graphargs(self) -> List[GraphArg]:
1410        return [node.meta["grapharg"] for node in self.placeholders]
1411
1412    def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
1413        with dynamo_timed(
1414            "OutputGraph.call_user_compiler", phase_name="backend_compile"
1415        ):
1416            return self._call_user_compiler(gm)
1417
1418    def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
1419        assert self.compiler_fn is not None
1420        tot = 0
1421        placeholders = []
1422        for node in gm.graph.nodes:
1423            if node.op in ("call_function", "call_method", "call_module"):
1424                tot += 1
1425            if node.op == "placeholder":
1426                placeholders.append(node)
1427        increment_op_count(tot)
1428        for pl in placeholders:
1429            arg = pl.meta["grapharg"]
1430            # TODO: Why isn't this stored in meta :think:
1431            pl._dynamo_source = arg.source
1432
1433        gm._param_name_to_source = self.param_name_to_source  # type: ignore[assignment]
1434        gm._source_to_user_stacks = self.source_to_user_stacks  # type: ignore[assignment]
1435
1436        try:
1437            name = (
1438                self.compiler_fn.__name__
1439                if hasattr(self.compiler_fn, "__name__")
1440                else ""
1441            )
1442            _step_logger()(logging.INFO, f"calling compiler function {name}")
1443            compiler_fn = self.compiler_fn
1444            if config.verify_correctness:
1445                compiler_fn = WrapperBackend(compiler_fn)
1446            compiled_fn = compiler_fn(gm, self.example_inputs())
1447            _step_logger()(logging.INFO, f"done compiler function {name}")
1448            assert callable(compiled_fn), "compiler_fn did not return callable"
1449        except exceptions_allowed_to_be_fallback as e:
1450            if self.has_user_defined_allowed_in_graph:
1451                raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
1452                    e.__traceback__
1453                ) from None
1454            msg = (
1455                "Backend compiler failed with a fake tensor exception at \n"
1456                f"{self.root_tx.format_frame_summary()}"
1457                "Adding a graph break."
1458            )
1459            unimplemented_with_warning(e, self.root_tx.f_code, msg)
1460        except SkipFrame as e:
1461            # The backend compiler has requested that we skip the frame, instead of
1462            # aborting execution.
1463            raise e
1464        except Exception as e:
1465            raise BackendCompilerFailed(self.compiler_fn, e) from e
1466
1467        signpost_event(
1468            "dynamo",
1469            "OutputGraph.call_user_compiler",
1470            {
1471                **self.co_fields,
1472                "op_count": tot,
1473                "node_count": len(gm.graph.nodes),
1474                "input_count": len(placeholders),
1475            },
1476        )
1477
1478        return compiled_fn
1479
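    # Illustrative sketch of the backend contract exercised above: compiler_fn is
    # any callable taking the FX GraphModule plus example inputs and returning a
    # callable. A minimal debugging backend might look like:
    #
    #     def my_backend(gm: torch.fx.GraphModule, example_inputs):
    #         gm.graph.print_tabular()    # inspect the captured graph
    #         return gm.forward           # fall back to running the graph eagerly
    #
    #     torch.compile(fn, backend=my_backend)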
1480    def example_inputs(self) -> List[torch.Tensor]:
1481        result = []
1482        for arg in self.graphargs:
1483            result.append(arg.example)
1484        return result
1485
1486    def remove_unused_graphargs(self) -> None:
1487        # NB: It's always OK to drop GraphArg for symbols that ended up being
1488        # specialized.  You don't even have to make a guard for it, because
1489        # ShapeEnv produce_guards operates on tracked_fakes, which never gets
1490        # pruned.  That being said, you'll get marginally better generated
1491        # guard code if you promote the guard into a Dynamo guard (since that
1492        # allows for the guard to be done using C++ guards.)  If we get
1493        # ShapeEnv guards to go into C++ guards, this will stop being a thing
1494        # though!
1495
1496        assert self.should_exit
1497
1498        # Miniature DCE pass, but only for obviously trivial operations
1499        def is_static_true(b_node: fx.node.Argument):
1500            if b_node is True:
1501                return True
1502            if not isinstance(b_node, fx.Node):
1503                return False
1504            b = b_node.meta.get("example_value")
1505            if b is None:
1506                return False
1507            if b is True:
1508                return True
1509            if (
1510                isinstance(b, torch.SymBool)
1511                and (r := b.node.maybe_as_bool()) is not None
1512            ):
1513                return r
1514            # TODO: We can also technically remove all cases when the input
1515            # doesn't have unbacked inputs, since it's all in the ShapeEnv
1516            return False
1517
1518        def is_symnode_arg(a: fx.node.Argument):
1519            from torch.fx.experimental.sym_node import SymTypes
1520
1521            if isinstance(a, (int, float, bool)):
1522                return True
1523            if isinstance(a, fx.Node):
1524                return isinstance(a.meta.get("example_value"), SymTypes)
1525            return False
1526
1527        # NB: We assume that you cannot do mutations on int/float/bool,
1528        # because they are immutable types, and therefore it is always safe to
1529        # DCE.
1530        def is_symnode_compute_node(node):
1531            from torch.fx.experimental.sym_node import SymTypes
1532
1533            if node.op != "call_function":
1534                return False
1535            # TODO: I don't think it's possible to have a bare int/float here?
1536            if not isinstance(node.meta.get("example_value"), SymTypes):
1537                return False
1538            # TODO: This will bail here if you ever end up with a more complicated
1539            # computation function, like sum(list_of_ints), even though it
1540            # should be DCE'able
1541            if not all(is_symnode_arg(a) for a in node.args):
1542                return False
1543            if not all(is_symnode_arg(a) for a in node.kwargs.values()):
1544                return False
1545            return True
1546
1547        from torch.fx.experimental.symbolic_shapes import is_accessor_node
1548
1549        for node in reversed(list(self.graph.nodes)):
1550            if len(list(node.users)) == 0:
1551                if (
1552                    node.op == "get_attr"
1553                    or (node.op == "call_function" and node.target is operator.getitem)
1554                    or (
1555                        node.op == "call_function"
1556                        and node.target is torch._check
1557                        and is_static_true(node.args[0])
1558                    )
1559                    or is_symnode_compute_node(node)
1560                    or is_accessor_node(node)
1561                ):
1562                    self.remove_node(node)
1563
1564        def placeholder_binds_symbol(node):
1565            arg = node.meta["grapharg"]
1566            example = arg.example
1567            if isinstance(example, torch.SymInt) and isinstance(
1568                example.node.expr, sympy.Symbol
1569            ):
1570                return example.node.expr
1571            return None
1572
1573        def remove_unused(node):
1574            log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name())
1575            # I'm not really sure why you need to delete these from the
1576            # node since the node is going to get removed
1577            del node.meta["grapharg"]
1578            self.remove_node(node)
1579            self.real_value_cache.pop(node, None)
1580
1581        used_symbols: Set[sympy.Symbol] = set()
1582
1583        def update_used_symbols(used_symbols, fake: Union[torch.SymInt, torch.Tensor]):
1584            used_symbols |= free_symbols(fake)
1585
1586        recheck_placeholders = []
1587        for node in self.placeholders:
1588            binds_symbol = placeholder_binds_symbol(node) is not None
1589            # Don't delete symbol bindings yet
1590            if binds_symbol:
1591                if not node.users:
1592                    recheck_placeholders.append(node)
1593            else:
1594                if not node.users and not isinstance(
1595                    node.meta["grapharg"], BackwardStateGraphArg
1596                ):
1597                    remove_unused(node)
1598                else:
1599                    # Register the free symbols as uses
1600                    arg = node.meta["grapharg"]
1601                    if isinstance(arg, BackwardStateGraphArg):
1602                        continue
1603                    if isinstance(node.meta["grapharg"].example, torch.ScriptObject):
1604                        real_script_obj = node.meta["grapharg"].example
1605                        fake_script_obj = node.meta["grapharg"].example_strong_ref
1606                        if not torch._library.fake_class_registry.tracing_with_real(
1607                            real_script_obj
1608                        ):
1609                            flat_dict = dict(real_script_obj.__obj_flatten__())  # type: ignore[attr-defined]
1610                            for attr in flat_dict.keys():
1611                                fake_attr_val = getattr(
1612                                    fake_script_obj.wrapped_obj, attr
1613                                )
1614                                pytree.tree_map_only(
1615                                    (torch.SymInt, torch.Tensor),
1616                                    lambda t: update_used_symbols(used_symbols, t),
1617                                    fake_attr_val,
1618                                )
1619                        continue
1620                    fake = (
1621                        arg.fake_tensor if arg.fake_tensor is not None else arg.example
1622                    )
1623                    update_used_symbols(used_symbols, fake)
1624
1625        # After removing unused graphargs, prune unused binds_symbol
1626        for node in recheck_placeholders:
1627            symbol = placeholder_binds_symbol(node)
1628            if symbol is not None:
1629                if symbol not in used_symbols:
1630                    remove_unused(node)
1631                else:
1632                    # Make sure we delete later occurrences of the same symbol
1633                    used_symbols.remove(symbol)
1634
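    # Worked example (illustrative only) of the two-pass pruning above, with
    # hypothetical placeholders:
    #
    #     s0 = placeholder()   # binds SymInt symbol s0, no direct users
    #     x  = placeholder()   # tensor of shape (s0, 4), still used by the graph
    #     y  = placeholder()   # unused tensor with no symbols
    #
    # The first pass removes `y` (unused, binds no symbol) and records s0 in
    # used_symbols because it is free in x's fake value; the second pass therefore
    # keeps the s0 placeholder. Had `x` also been unused, s0 would never enter
    # used_symbols and its placeholder would be pruned as well.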
1635    def add_output_instructions(self, prefix: List[Instruction]) -> None:
1636        """
1637        We call this on the creation of a new compiled subgraph that is inserted
1638        before user code.
1639        """
1640        self.output_instructions.extend(prefix)
1641        self.should_exit = True
1642
1643    def install_global_unsafe(self, name, value) -> None:
1644        """
1645        WARNING: prefer the safer `install_global_by_id/install_global`.
1646        torch.compile instances should be independent of each other;
1647        one footgun is to have one instance depend on the existence of
1648        a global installed by another instance. This can happen if we mangle
1649        a global the same way across both instances.
1650        """
1651        assert name not in self.installed_globals
1652        self.installed_globals.add(name)
1653        self.cleanups.append(CleanupHook.create(self.global_scope, name, value))
1654
1655    def install_global_by_id(self, prefix, value) -> str:
1656        """
1657        Installs a global if it hasn't been installed already.
1658        This is determined by (prefix, id(value)) pair.
1659
1660        Returns the name of the newly installed global.
1661        """
1662        # NB: need self.compile_id to distinguish this global
1663        # from another global created in a different torch.compile instance
1664        name = f"{prefix}_{id(value)}_c{self.compile_id}"
1665        if name in self.installed_globals:
1666            return name
1667        self.install_global_unsafe(name, value)
1668        return name
1669
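    # Usage sketch (illustrative only): install_global_by_id is idempotent for the
    # same object within one compile, while install_global (below) always mints a
    # fresh name. With a hypothetical helper object `obj`:
    #
    #     name1 = self.install_global_by_id("helper", obj)
    #     name2 = self.install_global_by_id("helper", obj)
    #     assert name1 == name2                 # keyed on (prefix, id(obj), compile_id)
    #
    #     name3 = self.install_global("helper", obj)   # new unique name every call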
1670    def install_global(self, prefix, value) -> str:
1671        """
1672        Installs a global, generating a unique name for it.
1673
1674        Returns the name of the newly installed global.
1675        """
1676        # NB: unique_id is unique, even across torch.compile instances
1677        name = unique_id(prefix)
1678        self.install_global_unsafe(name, value)
1679        return name
1680
1681    def cleanup(self) -> None:
1682        # There is a reference cycle between tracer and OutputGraph, causing
1683        # some of the tensor objects to be held alive for longer than necessary.
1684        self.root_tx = None
1685        self.nn_modules.clear()
1686        self.param_name_to_source = None
1687
1688        for node in self.graph.nodes:
1689            if "grapharg" in node.meta:
1690                del node.meta["grapharg"]
1691        self.real_value_cache.clear()
1692        self.input_name_to_proxy.clear()
1693        self.side_effects.clear()
1694        self.variable_tracker_cache.clear()
1695        self.register_finalizer_fns.clear()
1696        self.dynamo_flat_name_to_original_fqn.clear()
1697        self.tracing_context.clear()
1698
1699    def set_torch_function_state(self, enabled: bool) -> None:
1700        self.torch_function_enabled = enabled
1701
1702    def add_graph_finalizer(
1703        self, register_finalizer: Callable[[fx.GraphModule], None]
1704    ) -> None:
1705        self.register_finalizer_fns.append(register_finalizer)
1706
1707    def example_value_from_input_node(self, node: torch.fx.Node):
1708        """Extract the non-fake example tensor"""
1709        if node.op == "placeholder":
1710            return node.meta["grapharg"].example
1711        assert node.op == "get_attr"
1712        return self.nn_modules[node.target]  # type: ignore[index]
1713
1714
1715err_epilogue = (
1716    "With the current config, we will graph break "
1717    "(and fall back to eager-mode PyTorch) on all ops "
1718    "that do not have the 'pt2_compliant_tag'. "
1719    "Please see the following doc for how to mark this op as PT2 compliant "
1720    "https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html"
1721)
1722
1723
1724def check_pt2_compliant_op(output_graph, kind, target, args, kwargs):
1725    if kind != "call_function":
1726        return
1727
1728    def encountered_compliant_op(target):
1729        if target.namespace in {"prim", "prims", "aten"}:
1730            return
1731        output_graph.compliant_custom_ops.add(target)
1732
1733    def encountered_non_compliant_op(target, msg):
1734        output_graph.non_compliant_ops.add(target)
1735        if config.only_allow_pt2_compliant_ops:
1736            unimplemented(msg + " " + err_epilogue)
1737
1738    if isinstance(target, torch._ops.OpOverload):
1739        if torch.Tag.pt2_compliant_tag in target.tags:
1740            encountered_compliant_op(target)
1741            return
1742        encountered_non_compliant_op(
1743            target,
1744            f"Encountered the torch.ops.OpOverload {target} "
1745            f"that is not PT2 compliant.",
1746        )
1747        return
1748
1749    if isinstance(target, torch._ops.OpOverloadPacket):
1750        overloads = tuple(target.overloads())
1751        # Optimization: Overload resolution is expensive.
1752        # If there's only one overload, we know what it will resolve to.
1753        if len(overloads) == 1:
1754            op = getattr(target, overloads[0])
1755            if torch.Tag.pt2_compliant_tag in op.tags:
1756                encountered_compliant_op(op)
1757                return
1758            encountered_non_compliant_op(
1759                op,
1760                f"Encountered the non-overloaded "
1761                f"torch.ops.OpOverloadPacket {target} "
1762                f"that is not PT2 compliant. ",
1763            )
1764            return
1765
1766        args, kwargs = torch._dynamo.utils.get_fake_values_from_nodes(
1767            output_graph.current_tx, (args, kwargs), False
1768        )
1769        try:
1770            overload = torch._C._jit_resolve_packet(
1771                target._qualified_op_name, *args, **kwargs
1772            )
1773        except RuntimeError as e:
1774            unimplemented(str(e))
1775
1776        op = getattr(target, overload)
1777        if torch.Tag.pt2_compliant_tag in op.tags:
1778            encountered_compliant_op(op)
1779        else:
1780            encountered_non_compliant_op(
1781                op,
1782                f"Encountered the torch.ops.OpOverloadPacket {target} "
1783                f"which resolves to the overload ({overload}) that is "
1784                f"not PT2 compliant.",
1785            )
1786
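# Illustrative sketch: the compliance check above ultimately inspects the tags of
# the resolved overload. For an already-registered custom op (names hypothetical):
#
#     op = torch.ops.mylib.my_op.default            # an OpOverload
#     is_compliant = torch.Tag.pt2_compliant_tag in op.tags
#
# encountered_compliant_op only records ops outside the built-in "prim"/"prims"/
# "aten" namespaces, and a non-compliant op becomes a graph break only when
# config.only_allow_pt2_compliant_ops is enabled.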
1787
1788_compile_id_counter = itertools.count()
1789
1790
1791class SubgraphTracer(fx.Tracer):
1792    """
1793    Holds an FX graph that is being traced. OutputGraph owns a SubgraphTracer
1794    and the separation of responsibilities is that SubgraphTracer is
1795    responsible for building the graph while OutputGraph is responsible for
1796    compiling and executing the graph.
1797    """
1798
1799    def __init__(
1800        self, output_graph, parent=None, export_root=False, source_target=None
1801    ):
1802        super().__init__()
1803        self.output_graph = weakref.proxy(output_graph)
1804        self.graph = torch.fx.Graph()
1805
1806        # export_root is only ever set for the ROOT tracer.  It controls
1807        # whether or not certain inputs are allowed to be added.
1808        # Look at call sites of create_graph_input to see how it is used.
1809        if export_root:
1810            assert parent is None
1811        self.export_root = export_root
1812        # Map from graph input name to its placeholder proxy object, where the
1813        # map's keys give all current placeholder node names and can be used to
1814        # create unique node names
1815        self.input_name_to_proxy: Dict[str, fx.Proxy] = {}
1816        # Node => computed real value (see utils.get_real_value)
1817        self.real_value_cache: Dict[fx.Node, torch.Tensor] = {}
1818
1819        # SubgraphTracers can be nested. See NOTE [HigherOrderOperator tracing design]
1820        self.parent = parent
1821        # A dict mapping previously free variables (Proxy objects)
1822        # to new Proxy objects that wrap inputs to this subgraph.
1823        #
1824        # This dict serves two purposes:
1825        # - Proxies are associated with VariableTrackers. If we see
1826        # the same VariableTracker twice (and it is a free variable),
1827        # then we want to use the same Proxy in the current subgraph to
1828        # record the tracing.
1829        # - If we are tracing a HigherOrderOperator's body_fn, then we
1830        # need to keep track of what free variables were lifted so we can
1831        # rewrite the HigherOrderOperator call using the traced body_fn.
1832        # Dicts maintain the order of args for the HigherOrderOperator call.
1833        self.lifted_freevars = {}
1834        self.prev_inst = None
1835
1836        self._cur_code = None
1837        self._orig_gm_meta = None
1838        self._orig_gm_lineno_map = None
1839        self._orig_gm_firstlineno = None
1840        # Each SubgraphTracer is associated with a source target, which indicates
1841        # which operator this subgraph is attached to. We compute a source_fn_stack
1842        # based on the source target. For the root tracer, it's set to [].
1843        # This is useful for debugging and transforming the exported graph.
1844        if self.parent is None:
1845            self.source_fn_stack = []
1846        else:
1847            self.source_fn_stack = self.parent.source_fn_stack + [
1848                (self.graph._target_to_str(source_target), source_target)
1849            ]
1850
1851    # preserve original meta if it is available
1852    def _maybe_preserve_original_meta(self, tx, node):
1853        if (
1854            self._orig_gm_meta
1855            and self._orig_gm_lineno_map
1856            and self._orig_gm_firstlineno
1857        ):
1858            lineno = tx.current_instruction.starts_line
1859            node_idx = None
1860            if lineno is not None:
1861                node_idx = self._orig_gm_lineno_map.get(
1862                    lineno - self._orig_gm_firstlineno, None
1863                )
1864            if node_idx is not None:
1865                meta = self._orig_gm_meta[node_idx]
1866                for field in fx.proxy._COPY_META_FIELDS:
1867                    if field in meta:
1868                        node.meta[field] = meta[field]
1869                if "stack_trace" in meta:
1870                    node.meta["stack_trace"] = meta["stack_trace"]
1871
1872    def create_proxy(
1873        self,
1874        kind,
1875        target,
1876        args,
1877        kwargs,
1878        name=None,
1879        type_expr=None,
1880        proxy_factory_fn=None,
1881    ):
1882        # NOTE: [Nested SubgraphTracer and free_variable handling]
1883        # --------------------------------------------------------
1884        # Read NOTE [HigherOrderOperator tracing design] first.
1885        #
1886        # Let's say we're in the middle of introspecting the body of a possibly
1887        # nested HigherOrderOperator, and we see a free variable.
1888        #
1889        # There are two cases:
1890        # 1. We see a free variable that is already tracked by Dynamo.
1891        # 2. We see a free variable that has not been tracked by Dynamo.
1892        #
1893        # In case 1, we call `maybe_lift_tracked_freevar_to_input` (below)
1894        # which will lift the freevar to be an input of this subgraph
1895        # and also recursively lift it to be an input on the parent(s).
1896        #
1897        # In case 2, before the call to `create_proxy`, the InstructionTranslator
1898        # will see the freevar when it gets loaded by Python bytecode.
1899        # E.g. for Python 3.11 the bytecodes that may do this are LOAD_DEREF or
1900        # LOAD_GLOBAL.
1901        # There, the InstructionTranslator asks Dynamo to begin tracking the
1902        # freevar by building a new Variable.
1903        # Building a new Variable automatically lifts the freevar to be an
1904        # input of the root SubgraphTracer.
1905        #
1906        # The implications for the code below are:
1907        # - We will always be in Case 1 when we get to this code.
1908        # - Any "free variable" we encounter here is guaranteed to already be
1909        #   bound, that is, it is either a graph input of the root graph, or
1910        #   some local variable of the root graph or a subgraph.
1911        # - The additional work we need to do here is *only* that we need to
1912        #   lift this free variable into inputs (recursively) of each nested
1913        #   higher-order-op subgraph until we hit the subgraph where the free
1914        #   variable is bound
1915        if self.parent is not None:
1916            flat_args, tree_spec = pytree.tree_flatten((args, kwargs))
1917            new_flat_args = []
1918            for arg in flat_args:
1919                maybe_new_arg = self.maybe_lift_tracked_freevar_to_input(arg)
1920                new_flat_args.append(maybe_new_arg)
1921
1922            args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec)
1923
1924        rv = super().create_proxy(
1925            kind, target, args, kwargs, name, type_expr, proxy_factory_fn
1926        )
1927
1928        # append stack trace to fx node
1929        tx = self.output_graph.current_tx
1930
1931        # log detailed location of line of code in 3.11
1932        if sys.version_info >= (3, 11) and kind in (
1933            "call_function",
1934            "call_method",
1935            "call_module",
1936        ):
1937            cur_inst = tx.current_instruction
1938            if (
1939                cur_inst is not self.prev_inst
1940                and cur_inst.positions is not None
1941                and cur_inst.positions.lineno is not None
1942            ):
1943                tx_code = tx.f_code
1944                header = tx.get_line_of_code_header(lineno=cur_inst.positions.lineno)
1945
1946                def get_trace_call_log_str():
1947                    line = get_instruction_source_311(tx_code, cur_inst).rstrip()
1948                    return f"TRACE FX call {rv.node.name} from {header}\n{line}"
1949
1950                trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
1951                self.prev_inst = cur_inst
1952
1953        # update reference to original meta if we're tracing a new code object
1954        is_retracing = False
1955        if tx.f_code is not self._cur_code:
1956            orig_graphmodule_maybe = code_context.get_context(tx.f_code).get(
1957                "orig_graphmodule", lambda: None
1958            )()
1959            if isinstance(orig_graphmodule_maybe, torch.fx.GraphModule):
1960                is_retracing = True
1961                self._orig_gm_meta = [
1962                    nd.meta for nd in orig_graphmodule_maybe.graph.nodes
1963                ]
1964                self._orig_gm_lineno_map = orig_graphmodule_maybe._lineno_map
1965                self._orig_gm_firstlineno = (
1966                    orig_graphmodule_maybe.forward.__code__.co_firstlineno
1967                )
1968            else:
1969                self._orig_gm_meta = None
1970                self._orig_gm_lineno_map = None
1971                self._orig_gm_firstlineno = None
1972        nn_module_stack = tx.nn_module_stack
1973        if nn_module_stack:
1974            rv.node.meta["nn_module_stack"] = nn_module_stack.copy()
1975
1976        if kind in {"call_function", "call_method"}:
1977            rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
1978                (rv.node.name, target)
1979            ]
1980        elif kind == "call_module":
1981            if self.parent is not None:
1982                unimplemented("Invoking an nn.Module inside HigherOrderOperator")
1983            # For modules we store the class
1984            rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
1985                (
1986                    rv.node.name,
1987                    rv.node.meta["nn_module_stack"][target][1],
1988                )
1989            ]
1990
1991        self._maybe_preserve_original_meta(tx, rv.node)
1992
1993        if not is_retracing:
1994            if "nn_module_stack" not in rv.node.meta:
1995                nn_module_stack = tx.nn_module_stack
1996                if nn_module_stack:
1997                    rv.node.meta["nn_module_stack"] = nn_module_stack.copy()
1998
1999            if "source_fn_stack" not in rv.node.meta:
2000                if kind in {"call_function", "call_method"}:
2001                    rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
2002                        (rv.node.name, target)
2003                    ]
2004                elif kind == "call_module":
2005                    if self.parent is not None:
2006                        unimplemented(
2007                            "Invoking an nn.Module inside HigherOrderOperator"
2008                        )
2009                    # For modules we store the class
2010                    rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
2011                        (
2012                            rv.node.name,
2013                            rv.node.meta["nn_module_stack"][target][1],
2014                        )
2015                    ]
2016
2017        if "stack_trace" not in rv.node.meta:
2018            frame_summaries: List[traceback.FrameSummary] = []
2019            while tx:
2020                # Avoid frame summaries from inside the torch/nn/modules. This ensures that we keep the stack trace of
2021                # the user code.
2022                if not tx.is_co_filename_from_nn_modules():
2023                    frame_summaries.append(tx.frame_summary())
2024                tx = getattr(tx, "parent", None)
2025            # Reverse the frame_summaries so that the innermost frame comes last
2026            frame_summaries.reverse()
2027
2028            # official from_list stub doesn't have new-style type
2029            msgs = traceback.StackSummary.from_list(frame_summaries).format()
2030            rv.node.stack_trace = "".join(msgs)
2031
2032        return rv
2033
2034    def create_node(
2035        self, op, target, args=None, kwargs=None, name=None, type_expr=None
2036    ):
2037        check_pt2_compliant_op(self.output_graph, op, target, args, kwargs)
2038        if self.parent is not None:
2039            flat_args = pytree.arg_tree_leaves(*args, **kwargs)
2040            for arg in flat_args:
2041                if not isinstance(arg, torch.fx.Node):
2042                    continue
2043                assert (
2044                    arg.graph == self.graph
2045                ), "create_node using arg not from this SubgraphTracer"
2046
2047        node = super().create_node(op, target, args, kwargs, name, type_expr)
2048        node.meta["creation_timestamp"] = self.output_graph.timestamp
2049        return node
2050
2051    # Note: we did not override erase_node since
2052    # we call self.graph.erase_node elsewhere
2053    def remove_node(self, node):
2054        if len(node.users) > 0:
2055            user_graph_nodes: List[torch.fx.Node] = []
2056            for user in node.users.keys():
2057                # For the case where user.graph == self.graph, that is a real bug and will raise
2058                # properly.
2059                if user.graph != self.graph:
2060                    # This is a nested graph, which needs to be deleted.
2061                    # If we do not do this, we will raise on attempting to remove this.
2062                    # As we only get here during restoration cleanup, this is sound.
2063                    user_graph_nodes.extend(reversed(list(user.graph.nodes)))
2064            for other_graph_node in user_graph_nodes:
2065                other_graph_node.graph.erase_node(other_graph_node)
2066        self.graph.erase_node(node)
2067        self.input_name_to_proxy.pop(node.name, None)
2068
2069    # when before=True, we will insert this input before the most recent
2070    # inserted proxy.  This is a hack to get around an ordering problem,
2071    # where we first insert a tensor argument, and then insert bindings
2072    # for SymInts that may occur in the tensor argument.
2073    # Remove this if https://github.com/pytorch/pytorch/issues/99007 gets
2074    # fixed.
2075    def create_graph_input(self, name, type_expr=None, before=False, source=None):
2076        log.debug(
2077            "create_graph_input %s %s",
2078            name,
2079            source.name() if source is not None else "(none)",
2080        )
2081        if source is None:
2082            assert (
2083                self.parent is not None
2084            ), "you are required to provide a source for inputs on the root tracer"
2085
2086        # In eager, we are generally OK with adding graph inputs whenever we
2087        # want, because we take care of writing the bytecode that knows how
2088        # to source all the inputs.
2089        #
2090        # In export, this is bad, because you want a self-contained export
2091        # object which only depends on the inputs you explicitly passed to it.
2092        # So we are a bit more strict about what sources can become inputs
2093        # in export
2094        if self.export_root:
2095            if not is_from_local_source(source, allow_cell_or_freevar=False):
2096                self.output_graph.source_to_user_stacks.setdefault(source, []).append(
2097                    TracingContext.extract_stack()
2098                )
2099
2100        # de-duplicate the requested name against existing placeholder names
2101        if name in self.input_name_to_proxy:
2102            for i in itertools.count():
2103                candidate_name = f"{name}_{i}"
2104                if candidate_name not in self.input_name_to_proxy:
2105                    name = candidate_name
2106                    break
2107
2108        if self.input_name_to_proxy:
2109            prev_name = next(reversed(self.input_name_to_proxy))
2110            node = self.input_name_to_proxy[prev_name].node
2111            if before:
2112                ctx = self.graph.inserting_before(node)
2113            else:
2114                ctx = self.graph.inserting_after(node)
2115        else:
2116            ctx = self.graph.inserting_before(None)
2117        with ctx:
2118            proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
2119            if self.input_name_to_proxy and before:
2120                k, v = self.input_name_to_proxy.popitem()
2121                self.input_name_to_proxy[name] = proxy
2122                self.input_name_to_proxy[k] = v
2123            else:
2124                self.input_name_to_proxy[name] = proxy
2125            return proxy
2126
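    # Illustrative call order for the `before=True` escape hatch described above
    # (names hypothetical): a tensor placeholder is created first, and the SymInt
    # binding discovered while wrapping it is then inserted ahead of it:
    #
    #     tracer.create_graph_input("l_x_", source=...)              # tensor arg
    #     tracer.create_graph_input("s0", before=True, source=...)   # x.size(0) binding
    #
    # leaving the placeholders in the order [s0, l_x_] even though l_x_ was
    # created first.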
2127    # See NOTE: [Nested SubgraphTracer and free_variable handling] for more details
2128    def lift_tracked_freevar_to_input(self, proxy):
2129        # You're doing something wrong if we are the root SubgraphTracer because
2130        # Dynamo adds tensors to graph inputs before creating a proxy for them.
2131        assert (
2132            self.parent is not None
2133        ), "lift_tracked_freevar_to_input should not be called on root SubgraphTracer"
2134        # Proxies are associated with VariableTrackers.
2135        # It is possible that we've already lifted the Proxy to be an input.
2136        # If that is the case, just return the already lifted Proxy.
2137        if proxy in self.lifted_freevars:
2138            return self.lifted_freevars[proxy]
2139        new_proxy = self.create_graph_input(proxy.node.name)
2140        set_example_value(new_proxy.node, proxy.node.meta["example_value"])
2141        self.lifted_freevars[proxy] = new_proxy
2142        if self.parent is not None and proxy.tracer != self.parent:
2143            self.parent.lift_tracked_freevar_to_input(proxy)
2144        return new_proxy
2145
2146    def maybe_lift_tracked_freevar_to_input(self, arg):
2147        """
2148        If arg is a free variable, then lift it to be an input.
2149        Returns the new lifted arg (if arg was a freevar), else the
2150        original arg.
2151        """
2152        if not isinstance(arg, torch.fx.Proxy):
2153            return arg
2154        elif arg.tracer == self:
2155            return arg
2156        return self.lift_tracked_freevar_to_input(arg)
2157
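    # Illustrative picture of the lifting above for a hypothetical two-level
    # nesting: a proxy bound in the root graph and used inside a grandchild body
    # graph becomes a fresh placeholder at every level in between,
    #
    #     root graph:        x = placeholder()    # where x is actually bound
    #     child body graph:  x lifted to a new placeholder, recorded in lifted_freevars
    #     grandchild graph:  x lifted again, pointing at the child's placeholder
    #
    # so each body graph stays self-contained and the HigherOrderOperator call can
    # pass the lifted values explicitly (see the NOTE below).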
2158
2159# NOTE: [HigherOrderOperator tracing design]
2160# Ignoring HigherOrderOperators for a moment,
2161# OutputGraph represents the graph being built by Dynamo that may be compiled
2162# and executed. It holds a root SubgraphTracer where the FX graph is built.
2163#
2164# HigherOrderOperators are operators that take functions as their arguments.
2165# When Dynamo encounters a HigherOrderOperator, then it attempts to introspect
2166# the function passed to it (call this the "body function"), capture it into a
2167# GraphModule, and rewrite the call to the HigherOrderOperator to use the
2168# GraphModule.
2169#
2170# The way we handle the capture of body functions is through having
2171# (possibly nested) SubgraphTracers, one per body function.
2172#
2173# Mechanically, we do the introspection by:
2174# - Creating a new SubgraphTracer via OutputGraph.subtracer
2175# - Executing the body function.
2176# This constructs the graph of the body function in the new SubgraphTracer
2177# while modifying the state of the OutputGraph. For example:
2178# - the OutputGraph can receive new GraphArgs (if we discover any new
2179#   untracked Tensors)
2180# - side effects from the body function get accumulated into
2181#   OutputGraph.side_effects
2182# - guards produced by the body function get accumulated into OutputGraph.guards
2183#
2184# The traced function has some special properties that make it easier for us
2185# to transform later down the line:
2186# - we lift all free variables to being inputs.
2187#
2188# If the introspection fails (due to the existence of graph breaks), then
2189# we roll back the current OutputGraph state and graph break on the
2190# HigherOrderOperator.
2191