# mypy: ignore-errors

import abc
import collections
import contextlib
import dataclasses
import enum
import functools
import inspect
import itertools
import logging
import math
import operator
import random
import re
import sys
import types
import warnings
import weakref
from typing import (
    Any,
    Callable,
    Dict,
    FrozenSet,
    List,
    MutableMapping,
    NamedTuple,
    Optional,
    Set,
    TYPE_CHECKING,
    Union,
)

import torch
from torch import SymInt
from torch._guards import GuardSource, TracingContext
from torch._higher_order_ops.torchbind import call_torchbind
from torch._ops import HigherOrderOperator
from torch._streambase import _EventBase, _StreamBase
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
from torch._subclasses.meta_utils import is_sparse_any, safe_grad
from torch._utils_internal import justknobs_check
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.symbolic_shapes import (
    _constrain_range_for_size,
    DimDynamic,
    RelaxedUnspecConstraint,
    StatefulSymbolicContext,
    SubclassSymbolicContext,
    SymbolicContext,
)
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils.weak import TensorWeakRef

from .. import config, mutation_guard, replay_record, trace_rules
from ..device_interface import get_registered_device_interfaces
from ..exc import InternalTorchDynamoError, unimplemented
from ..guards import GuardBuilder, install_guard, make_dupe_guard
from ..side_effects import SideEffects
from ..source import (
    AttrProxySource,
    AttrSource,
    CallMethodItemSource,
    ConstantSource,
    ConstDictKeySource,
    ConvertIntSource,
    FloatTensorSource,
    GetItemSource,
    GradSource,
    is_cell_contents,
    is_constant_source,
    is_from_defaults,
    is_from_optimizer_source,
    LocalSource,
    NumpyTensorSource,
    OptimizerSource,
    RandomValueSource,
    Source,
    SubclassAttrListSource,
    TupleIteratorGetItemSource,
)
from ..trace_rules import (
    is_callable_allowed,
    is_numpy,
    is_numpy_dtype,
    is_numpy_type_info,
)
from ..utils import (
    _extract_tensor_dict,
    build_checkpoint_variable,
    clone_input,
    common_constant_types,
    get_fake_value,
    get_locals_to_steal,
    get_static_address_type,
    is_frozen_dataclass,
    is_function_or_wrapper,
    is_lru_cache_wrapped_function,
    is_namedtuple,
    is_parameter_freezing,
    is_typing,
    is_utils_checkpoint,
    is_wrapper_or_member_descriptor,
    istype,
    odict_values,
    proxy_args_kwargs,
    set_example_value,
    tensor_always_has_static_shape,
    tuple_iterator,
    tuple_iterator_getitem,
    tuple_iterator_len,
    unwrap_with_attr_name_if_wrapper,
    wrap_fake_exception,
)
from .base import MutableLocal, typestr, VariableTracker, VariableTrackerMeta
from .constant import ConstantVariable, EnumVariable
from .ctx_manager import (
    AutocastModeVariable,
    EventVariable,
    NullContextVariable,
    PreserveVersionContextVariable,
    StreamContextVariable,
    StreamVariable,
)
from .dicts import (
    ConstDictVariable,
    CustomizedDictVariable,
    DefaultDictVariable,
    HFPretrainedConfigVariable,
    PythonSysModulesVariable,
    SetVariable,
)
from .distributed import (
    DeviceMeshVariable,
    PlacementClassVariable,
    PlacementVariable,
    ProcessGroupVariable,
    WorldMetaClassVariable,
)
from .functions import (
    CollectiveFunctionRewriteVariable,
    FunctoolsPartialVariable,
    TritonKernelVariable,
    UserFunctionVariable,
    UserMethodVariable,
    WrapperUserFunctionVariable,
)
from .higher_order_ops import TorchHigherOrderOperatorVariable
from .iter import ItertoolsVariable
from .lazy import LazyVariableTracker
from .lists import (
    BaseListVariable,
    ListVariable,
    NamedTupleVariable,
    RangeVariable,
    RestrictedListSubclassVariable,
    SizeVariable,
    SliceVariable,
    TupleIteratorVariable,
    TupleVariable,
)
from .misc import (
    AutogradEngineVariable,
    AutogradFunctionContextVariable,
    AutogradFunctionVariable,
    ComptimeVariable,
    DebuggingVariable,
    DelayGraphBreakVariable,
    GetAttrVariable,
    GetSetDescriptorVariable,
    InspectSignatureVariable,
    LambdaVariable,
    LoggingLoggerVariable,
    MethodWrapperVariable,
    NumpyDTypeVariable,
    NumpyTypeInfoVariable,
    NumpyVariable,
    PythonModuleVariable,
    RandomClassVariable,
    RandomVariable,
    RegexPatternVariable,
    SavedTensorBox,
    TorchVersionVariable,
    TypingVariable,
)
from .nn_module import (
    FSDPManagedNNModuleVariable,
    UnspecializedBuiltinNNModuleVariable,
    UnspecializedNNModuleVariable,
)
from .optimizer import OptimizerVariable
from .script_object import TorchScriptObjectVariable
from .sdpa import SDPAParamsVariable
from .tensor import (
    NumpyNdarrayVariable,
    SymNodeVariable,
    TensorSubclassVariable,
    TensorVariable,
    UnspecializedPythonVariable,
)
from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
from .torch_function import (
    build_torch_function_fn,
    TensorWithTFOverrideVariable,
    TorchFunctionModeVariable,
)
from .user_defined import (
    FrozenDataClassVariable,
    KeyedJaggedTensorVariable,
    MutableMappingVariable,
    SourcelessGraphModuleVariable,
    UserDefinedClassVariable,
    UserDefinedObjectVariable,
    WeakRefVariable,
)


try:
    import numpy as np
except ModuleNotFoundError:
    np = None


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


log = logging.getLogger(__name__)
static_inputs_log = torch._logging.getArtifactLogger(
    __name__, "cudagraph_static_inputs"
)


DimList = List


def safe_has_grad(t):
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
        return hasattr(t, "grad")


class _missing:
    pass


@dataclasses.dataclass
class GraphArg:
    source: Source
    # TODO: storing a SymInt here but not a FakeTensor is a pretty strange
    # thing to do.  Probably should have example (which stores an int) and
    # fake_example
    _example: Union[TensorWeakRef, torch.SymInt]
    # When True, this indicates that this GraphArg is a Python quantity (e.g.,
    # a float or int) which we pass to the FX graph as a Tensor.  This
    # controls how we codegen calls into the Dynamo graph: we will call
    # torch.as_tensor on the quantity before passing it in.
    #
    # Note that we typically do not pass dynamic integers as tensors, because
    # they will most frequently just be used for size computation.  But this
    # is a policy decision that we can change our mind on; in particular, when
    # an int comes from a random number generator (e.g., random.randint), we
    # DO pass it as a tensor.
    #
    # It's also worth noting that our current tracing rules for
    # pass_arg_as_tensor are subtly broken: we just pun the variable as a
    # 0d scalar Tensor and pray that the semantics are the same.  Which they
    # often are, but not necessarily.  ezyang (May 2024) plans to fix this
    # soon.
    pass_arg_as_tensor: bool
    fake_tensor: Optional[torch._subclasses.fake_tensor.FakeTensor]
    # UnspecializedPythonVariable often masquerades as a tensor.
    # We MUST NOT generate shape guard code
    # that actually tries to access tensor properties on these values.
    # is_tensor lets us tell if this graph arg actually is a tensor
    # or not.
    is_tensor: bool = True
    # Sometimes, the Tensor we pass to example is freshly allocated (smh).
    # In that case we cannot keep only a weak reference to it.  This lets you
    # stash a strong reference too.
    example_strong_ref: Optional[torch.Tensor] = None

    @property
    def example(self):
        if isinstance(self._example, TensorWeakRef):
            r = self._example()
            assert r is not None
            return r
        else:
            return self._example

    def __post_init__(self):
        if isinstance(self._example, torch.Tensor):
            self._example = TensorWeakRef(self._example)
            assert is_fake(self.fake_tensor)

    def reconstruct(self, codegen):
        self.source.reconstruct(codegen)

    def erase(self):
        self._example = None
        self.example_strong_ref = None

    def __eq__(self, other):
        return self.source.name() == other.source.name()

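# The sketch below is purely illustrative (not part of the module's API): it
# shows how GraphArg stores tensor examples.  Per __post_init__ above, a real
# torch.Tensor passed as _example is downgraded to a TensorWeakRef, and the
# .example property re-dereferences it (asserting the referent is still
# alive), while non-tensor examples such as SymInts are stored directly.
# Roughly:
#
#     t = torch.randn(3)                      # hypothetical input tensor
#     ga = GraphArg(source=some_source, _example=t,
#                   pass_arg_as_tensor=False, fake_tensor=fake_t)
#     assert isinstance(ga._example, TensorWeakRef)
#     assert ga.example is t                  # weakref resolved on access
#
# `some_source` and `fake_t` are placeholders for a real Source and the
# corresponding FakeTensor.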

class BackwardStateGraphArg(GraphArg):
    def __init__(self) -> None:
        super().__init__(
            source=None,
            _example=BackwardState(),
            pass_arg_as_tensor=False,
            fake_tensor=None,
            is_tensor=False,
        )

    def reconstruct(self, codegen):
        assert codegen.tx.output.backward_state_var
        codegen.add_push_null(
            lambda: codegen.load_import_from(BackwardState.__module__, "BackwardState")
        )
        codegen.call_function(0, False)
        codegen.dup_top()
        codegen.store(codegen.tx.output.backward_state_var)


@dataclasses.dataclass
class FrameStateSizeEntry:
    scalar: Optional[int]
    size: Optional[List[int]]
    stride: Optional[List[int]]


# All class-based iterators in itertools
# NOTE: keyed by id() because some of these objects are not hashable; hashing them would raise an error during lookup
ITERTOOLS_TYPE_IDS: FrozenSet[int] = frozenset(
    id(member)
    for name, member in vars(itertools).items()
    if not name.startswith("_") and inspect.isclass(member)
)
# Will be updated later in substitute_in_graph in torch/_dynamo/polyfills/itertools.py
ITERTOOLS_POLYFILLED_TYPE_IDS: Set[int] = set()
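# For intuition (illustrative only): the comprehension above picks up the
# public class-based itertools entry points, e.g.
#
#     assert id(itertools.count) in ITERTOOLS_TYPE_IDS
#     assert id(itertools.chain) in ITERTOOLS_TYPE_IDS
#
# Any of these ids later registered via substitute_in_graph ends up in
# ITERTOOLS_POLYFILLED_TYPE_IDS and is excluded from the ItertoolsVariable
# fallback in VariableBuilder._wrap below.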


class VariableBuilder:
    """Wrap a python value in a VariableTracker() instance"""

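    # Illustrative usage (not a prescribed API): inside symbolic_convert, a
    # builder is constructed with the current InstructionTranslator and a
    # Source describing where the value came from, then called on the value:
    #
    #     vt = VariableBuilder(tx, LocalSource("x"))(some_python_value)
    #
    # `tx` and `some_python_value` are stand-ins here; __call__ consults the
    # side-effect table and the variable-tracker cache before falling back to
    # _wrap below.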
    def __init__(
        self,
        tx,
        source: Source,
    ) -> None:
        assert (
            source is not None
        ), "Consider SourcelessBuilder for ephemeral objects, usually objects created locally."
        assert TracingContext.try_get() is not None, "Expected active TracingContext"
        super().__init__()
        self.tx = tx
        self.source = source
        self.name = source.name()

    def __call__(self, value):
        if value in self.tx.output.side_effects:
            side_effect_result = self.tx.output.side_effects[value]
            dup_guard = make_dupe_guard(self.source, side_effect_result.source)
            if dup_guard:
                self.install_guards(dup_guard)
            return side_effect_result

        cached_vt = self.tx.output.variable_tracker_cache.lookup(value, self.source)
        if cached_vt:
            return cached_vt

        vt = self._wrap(value)
        vt.source = self.source
        if (
            self._can_lift_attrs_to_inputs(vt)
            and value not in self.tx.output.side_effects
            and not is_wrapper_or_member_descriptor(value)
        ):
            vt = self.tx.output.side_effects.track_object_existing(value, vt)

        self.tx.output.variable_tracker_cache.add(value, self.source, vt)
        return vt

    def _can_lift_attrs_to_inputs(self, vt):
        return type(vt) in {
            TensorVariable,
            TensorWithTFOverrideVariable,
            UserDefinedObjectVariable,
            NumpyNdarrayVariable,
        }

    @staticmethod
    @functools.lru_cache(None)
    def _common_constants():
        return {
            # We zero-one specialize shapes, so specialize these constants
            # too
            0,
            1,
            # NB: There used to be more constants here, but honestly it was
            # pretty confusing.  Note we specialize floats by default, and
            # DON'T specialize ints by default.  This all only matters with
            # dynamic_shapes
        }

    def get_source(self):
        return self.source

    def install_guards(self, *guards):
        source = self.get_source()
        if (
            isinstance(source, ConstantSource)
            or source.guard_source() == GuardSource.CONSTANT
        ):
            return None
        install_guard(*[source.make_guard(guard) for guard in guards], skip=1)
        return {}
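    # NOTE (illustrative): when the source is not constant, install_guards
    # returns an empty dict, which lets call sites splat it into keyword
    # arguments purely for its side effect of installing guards, e.g. the
    # `**self.install_guards(GuardBuilder.CLOSURE_MATCH)` pattern used in
    # _id_dispatch below.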

    def set_source_and_track_mutable(self, value, var):
        assert isinstance(var, VariableTracker)
        var.source = self.source
        return self.tx.output.side_effects.track_mutable(value, var)

    @classmethod
    @functools.lru_cache(None)
    def _type_dispatch(cls):
        # NB: Careful not to close over self to avoid ref cycle from lru_cache
        entries = [
            (
                (
                    torch.Tensor,
                    torch.nn.Parameter,
                    torch._subclasses.FakeTensor,
                    torch._subclasses.functional_tensor.FunctionalTensor,
                ),
                cls.wrap_tensor,
            ),
            (
                (tuple, list, odict_values, collections.deque, torch.Size),
                cls.wrap_listlike,
            ),
            (tuple_iterator, cls.wrap_tuple_iterator),
            ((slice, range), cls.wrap_slice_range),
            (tuple(common_constant_types), cls.wrap_literal),
            (re.Pattern, cls.wrap_regex_pattern),
            (weakref.ReferenceType, cls.wrap_weakref),
            (torch.utils.hooks.RemovableHandle, cls.wrap_removable_handle),
            (torch.jit.ScriptFunction, cls.wrap_jit_function),
        ]

        if config.trace_numpy and np:
            entries.append((np.ndarray, cls.wrap_numpy_ndarray))

        result = {}
        for ts, fn in entries:
            for t in ts if isinstance(ts, tuple) else (ts,):
                assert t not in result
                result[t] = fn

        return result
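    # A rough mental model (illustrative only): _type_dispatch is keyed by
    # exact type, so _wrap looks it up with type(value) and subclasses do not
    # hit these fast paths, e.g.
    #
    #     fn = VariableBuilder._type_dispatch().get(type(torch.Size([1])))
    #     # fn is VariableBuilder.wrap_listlike
    #
    # Subclass-aware handling (tensor subclasses, namedtuples, ...) happens in
    # the isinstance-based chain inside _wrap instead.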

    def wrap_regex_pattern(self, value: re.Pattern):
        # TODO(jansel): something like a REPR_MATCH might be more robust here
        self.install_guards(GuardBuilder.ID_MATCH)
        return RegexPatternVariable(value)

    def wrap_weakref(self, value: weakref.ReferenceType):
        self.install_guards(GuardBuilder.TYPE_MATCH)
        return WeakRefVariable(value, source=self.source)

    def wrap_removable_handle(self, value):
        # This means that the removable handle was created in some other frame.
        # Our current infra requires the hook to be registered and removed in
        # the same frame. So graph break.
        # Related test - PYTORCH_TEST_WITH_DYNAMO=1 python test/test_autograd.py -k TestAutograd.test_hooks
        unimplemented("unregistered hook removable handle")

    def wrap_jit_function(self, value):
        self.install_guards(GuardBuilder.TYPE_MATCH)
        return WrapperUserFunctionVariable(
            value, "_torchdynamo_inline", source=self.source
        )

    @classmethod
    @functools.lru_cache(None)
    def _id_dispatch(
        cls,
    ) -> Dict[int, Callable[["VariableBuilder", Any], VariableTracker]]:
        from ..comptime import comptime

        entries = [
            (
                inspect.signature,
                lambda self, value: LambdaVariable(
                    InspectSignatureVariable.create,
                    source=self.source,
                    **self.install_guards(GuardBuilder.CLOSURE_MATCH),
                ),
            ),
            (comptime, lambda self, value: ComptimeVariable()),
            (
                dataclasses.fields,
                lambda self, value: LambdaVariable(
                    _dataclasses_fields_lambda,
                    source=self.source,
                    **self.install_guards(GuardBuilder.FUNCTION_MATCH),
                ),
            ),
            (torch.__version__, lambda self, value: TorchVersionVariable()),
        ]

        result = {}
        for ts, fn in entries:
            for t in ts if isinstance(ts, (tuple, list)) else (ts,):
                assert t not in result
                result[id(t)] = fn

        return result

    def _wrap(self, value):
        # import here to avoid circular dependencies
        from torch.utils._triton import has_triton

        if has_triton():
            from triton.runtime.autotuner import Autotuner
            from triton.runtime.jit import JITFunction
        else:

            class JITFunction:
                pass

            class Autotuner:
                pass

        # Handle exact type() match
        type_dispatch = self._type_dispatch().get(type(value))
        if type_dispatch is not None:
            return type_dispatch(self, value)

        # Handle exact id() match
        id_dispatch = self._id_dispatch().get(id(value))
        if id_dispatch is not None:
            return id_dispatch(self, value)

        # Note - There are some nested values where types mismatch!
        # We want to get those out and wrap those.
        if is_function_or_wrapper(value):
            value = inspect.getattr_static(value, "_torchdynamo_inline", value)

        # Everything else (NB: order matters!)
        if is_traceable_wrapper_subclass(value) or istype(
            value, config.traceable_tensor_subclasses
        ):
            return self.wrap_tensor(value)
        elif is_namedtuple(value):
            return self.wrap_listlike(value)

        elif value is torch.utils._pytree.SUPPORTED_NODES:
            # For SUPPORTED_NODES, we guard on the dictionary version (PEP509)
            # under the assumption that the values themselves don't change.
            self.install_guards(GuardBuilder.DICT_VERSION)

            # The keys of SUPPORTED_NODES can be arbitrary, so guard on the
            # key order.
            self.tx.output.guard_on_key_order.add(self.source.name())
            result = {
                ConstantVariable.create(k): UserDefinedObjectVariable(
                    v,
                    source=GetItemSource(
                        self.get_source(), ConstDictKeySource(self.get_source(), i)
                    ),
                )
                for i, (k, v) in enumerate(value.items())
            }
            return ConstDictVariable(result, type(value))
        elif value is sys.modules:
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return PythonSysModulesVariable(source=self.source)
        elif CustomizedDictVariable.is_matching_cls_hf(type(value)):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            result = CustomizedDictVariable.wrap(self, value)
            result.source = self.source
            return self.tx.output.side_effects.track_object_existing(value, result)
        elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)):
            self.install_guards(GuardBuilder.SEQUENCE_LENGTH)

            # Optimisation for the common case: strings, ints, etc.
            all_const = all(ConstantVariable.is_literal(k) for k in value.keys())
            if all_const:
                # TODO(anijain2305) - Do we have to guard on all the keys? Can
                # keys be guarded lazily, similar to values?
                self.install_guards(GuardBuilder.DICT_CONST_KEYS)
            else:
                # Guard on the key order.
                # This is not ideal: in principle there is no need to guard on
                # the key order. But we do so because of the following
                # complications:
                #
                # 1) For non-constant objects, we can't save the keys in the
                # guard context because that can be memory heavy. We could add
                # weakrefs, but this complicates the accesses.
                #
                # 2) For non-constant objects, we also have to guard on the keys
                # (like TENSOR_MATCH on tensor). We might also have guards on
                # the attributes of the keys (like tensor.grad). Making this
                # work in a tree structure is complicated.
                #
                # So, instead, we guard on the key order. While guarding on key
                # order, we just save the indices and use them to access keys
                # and values. Indices are cheap to save.
                self.tx.output.guard_on_key_order.add(self.source.name())

            # We need all the keys to be hashable. We do this within the
            # _HashableTracker class in dicts.py
            def build_key_value(i, k, v):
                if all_const:
                    key = ConstantVariable.create(k)
                    source_key = k
                else:
                    source_key = ConstDictKeySource(self.get_source(), i)
                    key = LazyVariableTracker.create(k, source_key)

                source_value = GetItemSource(self.get_source(), source_key)
                value = LazyVariableTracker.create(v, source_value)

                return key, value

            result = dict(
                build_key_value(i, k, v) for i, (k, v) in enumerate(value.items())
            )

            if istype(value, collections.defaultdict):
                factory_source = AttrSource(self.source, "default_factory")
                result = DefaultDictVariable(
                    result,
                    type(value),
                    default_factory=VariableBuilder(self.tx, factory_source)(
                        value.default_factory
                    ),
                    source=self.source,
                )
            else:
                result = ConstDictVariable(result, type(value), source=self.source)

            return self.set_source_and_track_mutable(value, result)
        elif isinstance(value, torch.nn.Module):
            return self.wrap_module(value)
        elif ConstantVariable.is_literal(value):  # non-atomic literals
            return self.wrap_literal(value)
        elif isinstance(value, torch.overrides.TorchFunctionMode):
            var = TorchFunctionModeVariable(value, source=self.source)
            self.tx.output.side_effects.track_object_existing(value, var)
            return var
        elif istype(value, frozenset):
            # For frozenset, we can guard by object ID instead of value
            # equality; this allows us to handle non-literal values.
            self.install_guards(GuardBuilder.ID_MATCH)
            return ConstantVariable.create(value=value, source=self.source)
        elif isinstance(value, enum.Enum):
            self.install_guards(GuardBuilder.ID_MATCH)
            return EnumVariable(value=value, source=self.source)
        elif DebuggingVariable.is_reorderable_logging_function(value):
            # Put this above builtin_callable so that print() can be handled
            # along with other builtin debugging functions
            self.install_guards(GuardBuilder.BUILTIN_MATCH)
            return DebuggingVariable(value, source=self.source)
        elif isinstance(value, logging.Logger):
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return LoggingLoggerVariable(value, source=self.source)
        elif is_utils_checkpoint(value):
            return build_checkpoint_variable(source=self.source)
        elif isinstance(value, functools.partial):
            func_src = AttrSource(self.get_source(), "func")
            func_obj = VariableBuilder(self.tx, func_src)(value.func)

            args = []
            args_source = AttrSource(self.get_source(), "args")
            for i, arg in enumerate(value.args):
                args.append(
                    VariableBuilder(self.tx, GetItemSource(args_source, i))(arg)
                )

            keywords = {}
            keywords_source = AttrSource(self.get_source(), "keywords")
            for k, v in value.keywords.items():
                if not ConstantVariable.is_literal(k):
                    unimplemented("functools.partial with non-literal keyword")
                keywords[k] = VariableBuilder(
                    self.tx, GetItemSource(keywords_source, k)
                )(v)

            install_guard(
                self.get_source().make_guard(GuardBuilder.TYPE_MATCH),
                keywords_source.make_guard(GuardBuilder.DICT_KEYS),
                args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH),
            )
            return FunctoolsPartialVariable(func_obj, args, keywords)
        elif is_typing(value):
            # typing.List, typing.Mapping, etc.
            self.install_guards(GuardBuilder.ID_MATCH)
            return TypingVariable(
                value,
                source=self.source,
            )
        elif np is not None and isinstance(value, np.generic):
            # numpy array scalars: convert to 0D arrays
            return self.wrap_numpy_ndarray(np.asarray(value))
        elif is_numpy(value):
            assert np
            self.install_guards(
                GuardBuilder.FUNCTION_MATCH
                if callable(value)
                else GuardBuilder.TYPE_MATCH
            )
            return NumpyVariable(value, source=self.source)
        elif is_numpy_dtype(value):
            self.install_guards(GuardBuilder.ID_MATCH)
            return NumpyDTypeVariable(value, source=self.source)
        elif is_numpy_type_info(value):
            if isinstance(value, np.iinfo):
                self.install_guards(GuardBuilder.TYPE_MATCH)
                dt_source = AttrSource(self.source, "dtype")
                install_guard(dt_source.make_guard(GuardBuilder.ID_MATCH))
            else:
                self.install_guards(GuardBuilder.ID_MATCH)
            return NumpyTypeInfoVariable(value, source=self.source)
        # NB: These can't be put in type_dispatch, they have to run later
        elif CollectiveFunctionRewriteVariable.can_rewrite(value):
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return CollectiveFunctionRewriteVariable.create(
                self.tx,
                value,
                source=self.source,
            )
        elif istype(value, torch.autograd.function.FunctionMeta):
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return AutogradFunctionVariable(
                value,
                source=self.source,
            )
        elif isinstance(value, torch.autograd.function.FunctionCtx):
            actual_saved_tensors = None
            try:
                actual_saved_tensors = value.saved_tensors
            except RuntimeError:
                pass

            saved_tensors = []
            guards = [self.source.make_guard(GuardBuilder.TYPE_MATCH)]
            if isinstance(actual_saved_tensors, tuple):
                saved_tensors_source = AttrSource(self.source, "saved_tensors")
                guards.append(
                    saved_tensors_source.make_guard(GuardBuilder.SEQUENCE_LENGTH)
                )
                for i, v in enumerate(actual_saved_tensors):
                    saved_tensors.append(
                        VariableBuilder(
                            self.tx, GetItemSource(saved_tensors_source, i)
                        )(v)
                    )
            install_guard(*guards)

            return self.tx.output.side_effects.track_object_existing(
                value,
                AutogradFunctionContextVariable(
                    value,
                    source=self.source,
                    saved_tensors=SavedTensorBox(saved_tensors),
                ),
            )
        elif (
            isinstance(value, types.MethodType)
            and istype(
                getattr(value, "__self__", None), torch.autograd.function.FunctionMeta
            )
            and getattr(value, "__name__", "") == "apply"
            and value == getattr(value.__self__, "apply", None)
        ):
            # handle aliased autograd function `apply` calls
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return GetAttrVariable(
                AutogradFunctionVariable(
                    value.__self__, source=AttrSource(self.source, member="__self__")
                ),
                "apply",
            )
        elif isinstance(value, torch._C._ImperativeEngine):
            self.install_guards(GuardBuilder.ID_MATCH)
            return AutogradEngineVariable(value, source=self.source)
        elif (
            value
            is torch._dynamo.external_utils.FakeCompiledAutogradEngine._exec_final_callbacks_stub
        ):
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return LambdaVariable(
                lambda: UserFunctionVariable(
                    torch._dynamo.external_utils.FakeCompiledAutogradEngine.exec_final_callbacks,
                ).call_function(
                    self.tx,
                    (self.tx.output.side_effects.get_ca_final_callbacks_var(),),
                    {},
                )
            )
        elif callable(value) and trace_rules.lookup_callable(value) is not None:
            if is_callable_allowed(value):
                self.tx.output.has_user_defined_allowed_in_graph = True
            return trace_rules.lookup_callable(value).create_with_source(
                value, source=self.source
            )
        elif np and isinstance(value, np.number):
            return self.wrap_unspecialized_primitive(value)
        elif HFPretrainedConfigVariable.is_matching_object(value):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            return HFPretrainedConfigVariable(value)
        elif isinstance(value, HigherOrderOperator):
            self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH)
            return TorchHigherOrderOperatorVariable.make(value, source=self.source)
        elif isinstance(value, torch.cuda.StreamContext):
            self.install_guards(GuardBuilder.ID_MATCH)
            stream_source = AttrSource(self.source, "stream")
            stream_var = VariableBuilder(self.tx, stream_source)(value.stream)
            return StreamContextVariable.create(self.tx, stream_var)
        elif isinstance(value, _StreamBase):
            self.install_guards(GuardBuilder.ID_MATCH)
            stream_proxy = self.tx.output.create_proxy(
                "call_function",
                torch.cuda.Stream,
                (),
                {
                    "stream_id": value.stream_id,
                    "device_index": value.device_index,
                    "device_type": value.device_type,
                },
            )
            set_example_value(stream_proxy.node, value)
            return StreamVariable(
                stream_proxy,
                value,
                value.device,
                source=self.source,
            )
        elif isinstance(value, (torch._C._SDPAParams)):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            return SDPAParamsVariable.create(self.tx, value, self.source)
        elif isinstance(value, _EventBase):
            self.install_guards(GuardBuilder.ID_MATCH)
            torch._dynamo.utils.store_user_object_weakref(value)
            event_proxy = self.tx.output.create_proxy(
                "call_function",
                torch._dynamo.utils.get_user_object_from_id,
                (id(value),),
                {},
            )
            set_example_value(event_proxy.node, value)
            return EventVariable(
                event_proxy,
                value,
                source=self.source,
            )
        elif (
            isinstance(value, torch._C._TensorMeta)
            and value in config.traceable_tensor_subclasses
        ):
            return TensorSubclassVariable(value, source=self.source)
        elif (
            istype(value, contextlib.nullcontext)
            and inspect.getattr_static(value, "enter_result", None) is None
        ):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            return NullContextVariable(source=self.source)
        elif KeyedJaggedTensorVariable.is_matching_object(value):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            result = KeyedJaggedTensorVariable(value, source=self.source)
            # TODO: doing this manually is bad
            return self.tx.output.side_effects.track_object_existing(value, result)
        elif isinstance(value, torch.optim.Optimizer):
            self.install_guards(GuardBuilder.ID_MATCH)
            self.source = OptimizerSource(self.source)
            return OptimizerVariable(value, source=self.source)
        elif WorldMetaClassVariable.is_group_member_type(value):
            return WorldMetaClassVariable(value, source=self.source)
        elif ProcessGroupVariable.is_process_group(value):
            self.install_guards(GuardBuilder.ID_MATCH)
            return ProcessGroupVariable(value, source=self.source)
        elif DeviceMeshVariable.is_device_mesh(value):
            # TODO: see if we need to add custom guard instead of a simple ID_MATCH
            self.install_guards(GuardBuilder.EQUALS_MATCH)
            return DeviceMeshVariable(value, source=self.source)
        elif PlacementClassVariable.is_placement_type(value):
            # TODO: see if we need to add custom guard instead of a simple ID_MATCH
            self.install_guards(GuardBuilder.ID_MATCH)
            return PlacementClassVariable(value, source=self.source)
        elif PlacementVariable.is_placement(value):
            # TODO: see if we need to add custom guard instead of a simple ID_MATCH
            self.install_guards(GuardBuilder.EQUALS_MATCH)
            return PlacementVariable(
                value,
                source=self.source,
            )
        elif (
            id(value) in ITERTOOLS_TYPE_IDS
            and id(value) not in ITERTOOLS_POLYFILLED_TYPE_IDS
        ):
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return ItertoolsVariable(value, source=self.source)
        elif isinstance(value, torch.SymBool):
            # Note: the idea here is to reuse the infra we've built for SymInt by simulating the
            # user-provided SymBool with a SymInt in dynamo.

            # Concretely:
            # 1. We create a SymInt in dynamo's shape_env, whose source is constructed as ConvertIntSource(self.source),
            # so that guards on the SymInt can be effectively applied to the original SymBool in the user program.
            # 2. We create a SymBool based on that SymInt in dynamo's ShapeEnv, because the original user program
            # depends on the value being a SymBool. This allows dynamo to interpret the user's program correctly.

            new_source = ConvertIntSource(self.source)
            if value.node.has_hint():
                value_hint = value.node.require_hint()

                new_symint = (
                    self.tx.output.shape_env.create_unspecified_symint_and_symbol(
                        int(value_hint),
                        new_source,
                        dynamic_dim=DimDynamic.DYNAMIC,
                    )
                )
            else:
                # We need to create an unbacked symint to replace the unbacked symbool.
                new_symint = self.tx.output.shape_env.create_unbacked_symint()

            sym_node_proxy = self.tx.output.root_tracer.create_graph_input(
                re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
                type(new_symint),
                source=new_source,
            )

            sym_node_proxy.node.meta["grapharg"] = GraphArg(
                new_source,
                new_symint,
                False,
                None,
                is_tensor=False,
                example_strong_ref=new_symint,
            )
            # We bind the new_symint to graph input.
            set_example_value(sym_node_proxy.node, new_symint)
            self.tx.output.bound_symbols.add(new_symint.node.expr)
            self.tx.output.tracked_fakes.append(
                TrackedFake(new_symint, new_source, None)
            )
            return SymNodeVariable(
                sym_node_proxy,
                new_symint == 1,
            )
        elif isinstance(value, (JITFunction, Autotuner)):
            self.install_guards(GuardBuilder.ID_MATCH)
            return TritonKernelVariable(
                value,
                None,  # No kernel idx provided
                None,  # No grid provided
                source=self.source,
            )
        elif isinstance(value, torch.amp.autocast_mode.autocast):
            self.install_guards(GuardBuilder.ID_MATCH)
            return AutocastModeVariable(
                target_values=[
                    value.device,
                    value.fast_dtype,
                    value._enabled,
                    value._cache_enabled,
                ],
                source=self.source,
            )
        elif TorchCtxManagerClassVariable.is_matching_cls(value):
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return TorchCtxManagerClassVariable(value, source=self.source)
        elif inspect.getattr_static(value, "__script_if_tracing_wrapper", False):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            return WrapperUserFunctionVariable(
                value, "__original_fn", source=self.source
            )
        elif is_lru_cache_wrapped_function(value):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source)
        elif is_function_or_wrapper(value):
            value, attr_name = unwrap_with_attr_name_if_wrapper(value)
            # For these wrappers, Dynamo points to the wrapped function,
            # so source needs to be updated as well.
            if attr_name is not None:
                self.source = AttrSource(self.source, attr_name)
            return trace_rules.lookup(value).create_with_source(
                value, source=self.source
            )
        elif value is random.Random:
            self.install_guards(GuardBuilder.ID_MATCH)
            return RandomClassVariable(source=self.source)
        elif istype(value, random.Random) and RandomVariable.is_supported_random_obj(
            value
        ):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            result = RandomVariable(value, source=self.source)
            self.tx.output.side_effects.track_mutable(value, result)
            return result
        # Don't use istype, since some Python modules are not subclasses of types.ModuleType directly.
        # E.g., type(torch.ops) -> <class 'torch._ops._Ops'>,
        # type(torch.backends.cudnn) -> <class 'torch.backends.cudnn.CudnnModule'>
        elif isinstance(value, (types.ModuleType, replay_record.DummyModule)):
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            result = PythonModuleVariable(
                value,
                source=self.source,
            )
            self.tx.output.side_effects.track_object_existing(value, result)
            return result
        elif isinstance(value, types.MethodType) and isinstance(
            value.__self__, (torch.nn.Module, torch.utils._pytree.TreeSpec)
        ):
            # don't let MethodTypes fall through to UserDefinedObject,
            # which doesn't support 'CALL_FUNCTION'

            # TODO(whc): Why do we limit this to methods on NNModules?
            # I don't have a good reason for this, but it preserves the existing behavior
            # for MBartForConditionalGeneration, which generates many graph breaks and OOMs otherwise.
            # I suspect we probably want to relax this check and dig deeper there.

            # In order to construct a MethodVariable in Dynamo, we start with an actual method obj from python,
            # but need to separately wrap its underlying `__func__` and its `self` argument.  We wrap `self` here
            # and then `__func__` gets wrapped inside UserMethodVariable.
            self_obj = VariableBuilder(
                self.tx, source=AttrSource(self.source, "__self__")
            )(value.__self__)
            assert self_obj and isinstance(
                self_obj, VariableTracker
            ), "Failed to produce a valid self obj"
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return UserMethodVariable(
                value.__func__,
                self_obj,
                source=self.source,
            )
        elif isinstance(value, types.GetSetDescriptorType):
            # GetSet descriptors are C functions attached to an attribute lookup
            # using PyGetSetDef. Python, on attribute lookup, can decide to
            # create a new object on the fly, and therefore the `id` of the
            # descriptors is not guaranteed to be same for different attribute
            # accesses. Since these are unlikely to change during the program
            # execution, we can skip guarding on them.
            return GetSetDescriptorVariable(value)
        elif isinstance(value, types.MethodWrapperType):
            # Method-wrappers are written in C, and they are not guaranteed to
            # return the same object on attribute lookup. Therefore, we cannot
            # insert a FUNCTION_MATCH guard here. Method-wrappers are very
            # unlikely to change, so it's ok to skip the guard here.
            return MethodWrapperVariable(value)
        elif issubclass(type(value), type):
            if value in (
                torch.utils.hooks.BackwardHook,
                torch.nn.Parameter,
                torch.nn.Buffer,
            ):
                # TODO(jansel): combine this case with the one above
                return trace_rules.lookup(value).create_with_source(
                    value, source=self.source
                )
            if value is torch.autograd._unsafe_preserve_version_counter:
                self.install_guards(GuardBuilder.FUNCTION_MATCH)
                return PreserveVersionContextVariable.constructor(self.tx)
            # This is a user-defined class, so install an ID_MATCH even if it's a
            # global variable.
            self.install_guards(GuardBuilder.ID_MATCH)
            return UserDefinedClassVariable(
                value,
                source=self.source,
            )
        elif RestrictedListSubclassVariable.is_matching_cls(type(value)):
            self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
            return self.set_source_and_track_mutable(
                value,
                RestrictedListSubclassVariable(
                    [
                        LazyVariableTracker.create(
                            value=value[i], source=GetItemSource(self.source, i)
                        )
                        for i in range(len(value))
                    ],
                    user_cls=type(value),
                    user_cls_source=AttrSource(self.source, "__class__"),
                ),
            )
        elif TorchScriptObjectVariable.is_matching_cls(type(value)):
            from ..source import (
                FlattenScriptObjectSource,
                ScriptObjectQualifiedNameSource,
            )

            if torch._library.fake_class_registry.tracing_with_real(value):
                proxy = self.tx.output.root_tracer.create_graph_input(
                    re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
                    type(value),
                    source=self.source,
                )

                # Setting pass_arg_as_tensor=False so that we do not insert an as_tensor call in reconstruct by default.
                # Setting the example to the real value, because these example values will be used
                # as example_inputs for the user compiler.
                proxy.node.meta["grapharg"] = GraphArg(
                    self.source, value, False, None, False, value
                )
                return TorchScriptObjectVariable.create(
                    proxy,
                    value,
                    source=self.source,
                )

            # This exists to allow a smoother transition.
            # The implications are:
            # The script objects won't be tracked as proxies.
            # Methods on these objects won't show up in the graph.
            # The original script object might be mutated.
            if not hasattr(value, "__obj_flatten__"):
                return self.wrap_user_defined(value)

            # Install the guards on the fully qualified name of the script object
            LazyVariableTracker.realize_all(
                VariableBuilder(self.tx, ScriptObjectQualifiedNameSource(self.source))(
                    value._type().qualified_name()  # type: ignore[attr-defined]
                )
            )
            # Install the guards on the content of the script object by setting the source
            # to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
            LazyVariableTracker.realize_all(
                VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
                    value.__obj_flatten__()
                )
            )

            fake_script_obj = torch._library.fake_class_registry.maybe_to_fake_obj(
                self.tx.output.fake_mode, value
            )

            proxy = self.tx.output.root_tracer.create_graph_input(
                re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
                type(value),
                source=self.source,
            )

            # Setting pass_arg_as_tensor=False so that we do not insert an as_tensor call in reconstruct by default.
            # Setting the example to the real value, because these example values will be used
            # as example_inputs for the user compiler.
            proxy.node.meta["grapharg"] = GraphArg(
                self.source, value, False, None, False, fake_script_obj
            )
            return TorchScriptObjectVariable.create(
                proxy,
                fake_script_obj,
                source=self.source,
            )
        elif issubclass(type(value), MutableMapping):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            return MutableMappingVariable(value, source=self.source)
        elif is_frozen_dataclass(value):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            result = FrozenDataClassVariable.create(self.tx, value, source=self.source)
            return self.tx.output.side_effects.track_object_existing(value, result)
        else:
            return self.wrap_user_defined(value)

    def wrap_user_defined(self, value: Any):
        self.install_guards(GuardBuilder.TYPE_MATCH)
        result = UserDefinedObjectVariable(value, source=self.source)
        if not SideEffects.cls_supports_mutation_side_effects(type(value)):
            # don't allow STORE_ATTR mutation with custom __setattr__
            return result
        return self.tx.output.side_effects.track_object_existing(value, result)

    def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
        if config.specialize_int and type(value) is torch.Size:
            self.install_guards(GuardBuilder.CONSTANT_MATCH)
            return ConstantVariable.create(value=value)

        # One can index a tensor with a list/tuple. Therefore, we need to
        # have a stricter match.
        self.install_guards(GuardBuilder.SEQUENCE_LENGTH)

        for item in value:
            if item is value:
                unimplemented("list elements are pointing to the list itself")

        # Tuples are immutable objects, so we should mark their items static. This
        # avoids wrapping tuple items as symints. This helps for nn module
        # attributes like conv2d strides and dilations.
        if (
            istype(value, tuple)
            and all(ConstantVariable.is_literal(item) for item in value)
            and self.source.guard_source().is_unspecialized_nn_module()
        ):
            self.install_guards(GuardBuilder.CONSTANT_MATCH)
            return TupleVariable([ConstantVariable.create(item) for item in value])

        output = [
            LazyVariableTracker.create(
                item,
                source=GetItemSource(self.get_source(), i),
            )
            for i, item in enumerate(value)
        ]

        maybe_gm = self.tx.output.local_scope.get("self")
        if isinstance(
            self.source, LocalSource
        ) and self.source.local_name in get_locals_to_steal(maybe_gm):
            # The input tensor list to dynamo from compiled autograd may contain activations
            # which are freed as they are used in inductor. Dynamo's default behavior is to
            # lift all tensors to the graph inputs, but this will cause dynamo to hold an
            # extra reference to the activation tensors and increase peak memory usage.
            # To allow freeing ASAP, we keep the list as graph argument to the dynamo output
            # graph, and unpack it locally.
            # e.g. instead of `def forward(self, L_inputs_0_, L_inputs_1_, ...):`, we have
            # `def forward(self, L_inputs_):`
            source = self.source
            assert isinstance(value, list)
            tensor_list_proxy = self.tx.output.root_tracer.create_graph_input(
                re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value), source=source
            )
            tensor_list_proxy.node.meta["steal_arg"] = True

            list_variable = wrap_fx_proxy_cls(
                target_cls=TensorVariable,
                tx=self.tx,
                proxy=tensor_list_proxy,
                example_value=value,
                subclass_type=None,
                source=source,
            )

            guards = []
            for i, tensor_variable in enumerate(list_variable.items):
                source_i = GetItemSource(base=source, index=i, index_is_slice=False)
                # access unpacked tensor from this list instead of from a lifted arg
                self.tx.output.input_source_to_var[source_i] = tensor_variable
                tensor_variable.proxy.node.meta["tensor_dict"] = _extract_tensor_dict(
                    value[i]
                )

                guard = functools.partial(
                    GuardBuilder.TENSOR_MATCH, value=TensorWeakRef(value[i])
                )
                guards.append(source_i.make_guard(guard))

            install_guard(*guards, skip=1)

            grapharg = GraphArg(
                source,
                value,
                pass_arg_as_tensor=False,
                fake_tensor=None,
                is_tensor=False,
            )
            tensor_list_proxy.node.meta["grapharg"] = grapharg

        result = BaseListVariable.cls_for_instance(value)(
            output, mutable_local=MutableLocal()
        )
        if istype(value, list):
            return self.set_source_and_track_mutable(value, result)
        return result

    def wrap_tuple_iterator(self, value: tuple_iterator):
        self.install_guards(GuardBuilder.TUPLE_ITERATOR_LEN)
        output = [
            VariableBuilder(self.tx, TupleIteratorGetItemSource(self.get_source(), i))(
                tuple_iterator_getitem(value, i)
            )
            for i in range(tuple_iterator_len(value))
        ]
        result = TupleIteratorVariable(
            output, mutable_local=MutableLocal(), source=self.source
        )

        return self.set_source_and_track_mutable(value, result)

    def wrap_slice_range(self, value: Union[slice, range]):
        items = [
            VariableBuilder(self.tx, AttrSource(self.get_source(), k))(
                getattr(value, k)
            )
            for k in ("start", "stop", "step")
        ]
        self.install_guards(GuardBuilder.TYPE_MATCH)
        if isinstance(value, slice):
            return SliceVariable(items, source=self.source)
        else:
            return RangeVariable(items, source=self.source)

    def mark_static_input(self, value: torch.Tensor, guard: bool):
        from ..decorators import mark_static_address

        static_inputs_log.debug(
1305            "Marking static input %s, id: %s)", self.source.name(), id(value)
        )
        mark_static_address(value, guard=guard)

        # Check if we've seen this tensor before and update graph metadata if needed
        # As long as this runs before AOT this is sound
        if value in self.tx.output.side_effects:
            var = self.tx.output.side_effects[value]
            var.proxy.node.meta["tensor_dict"][
                "_dynamo_static_input_type"
            ] = value._dynamo_static_input_type
1316
1317    def wrap_module(self, value: torch.nn.Module):
1318        from ..eval_frame import OptimizedModule
1319
1320        if len(value.__dict__) == 0:
1321            unimplemented(f"uninitialized nn.Module: {typestr(value)}")
1322        if istype(value, OptimizedModule):
1323            # Check if the optimized module was disabled
1324            if inspect.getattr_static(value.forward, "_torchdynamo_disable", False):
1325                # This bytecode is mostly of kind LOAD_ATTR or LOAD_METHOD. If
1326                # we graph break here, Dynamo does not know how to create
1327                # continuation functions for such bytecodes. So, we delay the
1328                # graph break to CALL_FUNCTION.
1329                return DelayGraphBreakVariable(source=self.source)
1330
1331            self.install_guards(GuardBuilder.TYPE_MATCH)
1332            self.source = AttrSource(self.source, "_orig_mod")
1333            return self.wrap_module(value._orig_mod)
1334
1335        if (
1336            isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM))
1337            and not config.allow_rnn
1338        ):
1339            unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs")
1340
1341        if getattr(value, "_is_fsdp_managed_module", False):
1342            # See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule]
1343            # in fully_sharded_data_parallel.py for more information
1344
1345            # we can't do this assert inside the FSDP constructor,
1346            # since we don't yet know whether Dynamo will be used
1347            assert getattr(
1348                value, "_fsdp_use_orig_params", False
1349            ), "Dynamo only supports FSDP with use_orig_params=True"
1350
1351            # Note on FSDP guarding
1352            # Eager FSDP already assumes (requires, but without enforcement)
1353            # that users don't mutate their model parameters/structure after
1354            # FSDP wrapping, because FSDP wouldn't notice or update its
1355            # FlatParams.
1356            #
1357            # Therefore, torch.compile can skip guarding on params or submodule
1358            # structure of fsdp_managed modules, by using FSDPNNModuleSource as
1359            # the guard source.  This behavior is gated on
1360            # config.skip_fsdp_guards.
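            # (Illustrative consequence of the above, stated loosely: guards
            # whose source chain goes through FSDPNNModuleSource can be dropped
            # when config.skip_fsdp_guards is enabled, so mutating an
            # FSDP-managed module's parameters after compilation may not be
            # detected by guards in this mode.)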
1361            self.install_guards(GuardBuilder.TYPE_MATCH)
1362            result = FSDPManagedNNModuleVariable(value, source=self.get_source())
1363            if not SideEffects.cls_supports_mutation_side_effects(type(value)):
1364                # don't allow STORE_ATTR mutation with custom __setattr__
1365                return result
1366            return self.tx.output.side_effects.track_object_existing(value, result)
1367        elif mutation_guard.is_dynamic_nn_module(value, self.tx.export):
1368            # created dynamically, don't specialize on it
1369
1370            # Note [Tracing a torch.compiled function]
1371            # When make_fx traces a compiled function, we need to unwrap any _AttrProxy to its base module.
1372            if isinstance(value, torch.fx.experimental.proxy_tensor._AttrProxy):
1373                value = value.get_base()
1374                self.source = AttrProxySource(self.source)
1375
1376            self.install_guards(GuardBuilder.TYPE_MATCH)
1377            if torch._dynamo.config.inline_inbuilt_nn_modules:
1378                freezing = is_parameter_freezing()
1379                for p in value.parameters():
1380                    self.mark_static_input(p, guard=freezing)
1381
1382                for b in value.buffers():
1383                    self.mark_static_input(b, guard=freezing)
1384
1385                if freezing:
1386                    # we need to add the module to the tracing context
1387                    # in order to allow its params to get invalidated;
1388                    # this will get cleaned up once compile ends
1389                    self.tx.output.nn_modules[self.name] = value
1390
1391            if value.__module__.startswith(("torch.nn.", "torch.ao.")) or getattr(
1392                value.__class__, "_dynamo_marked_static", False
1393            ):
1394                result = UnspecializedBuiltinNNModuleVariable(value, source=self.source)
1395            else:
1396                result = UnspecializedNNModuleVariable(value, source=self.source)
1397
1398            if not SideEffects.cls_supports_mutation_side_effects(type(value)):
1399                # don't allow STORE_ATTR mutation with custom __setattr__
1400                return result
1401            return self.tx.output.side_effects.track_object_existing(value, result)
1402        elif issubclass(
1403            value.__class__, torch.nn.parallel.distributed.DistributedDataParallel
1404        ):
1405            self.install_guards(GuardBuilder.TYPE_MATCH)
1406            return UnspecializedNNModuleVariable(value, source=self.get_source())
1407        else:
1408            return self.tx.output.register_attr_or_module(
1409                value,
1410                self.name,
1411                source=self.get_source(),
1412                # Guards are added inside register_attr_or_module
1413            )
1414
1415    def wrap_literal(self, value):
1416        if not config.specialize_int and type(value) is int:
1417            # We unspecialize ints by default, but still
1418            # specialize them under the following conditions
1419            if not TracingContext.get().force_unspec_int_unbacked_size_like and (
1420                # Assume integers from global variables want to be specialized
1421                not self.source.guard_source().is_local()
1422                # Assume that integers that came from NN modules want to be
1423                # specialized (as we don't expect users to be changing the
1424                # NN modules on the fly)
1425                or self.source.guard_source().is_specialized_nn_module()
1426                or self.source.guard_source().is_unspecialized_builtin_nn_module()
1427                or is_from_defaults(self.source)
1428                or is_cell_contents(self.source)
1429                # TODO: Delete this condition when rollout is done.  NB: this
1430                # condition never evaluates True in open source
1431                or (
1432                    not justknobs_check(
1433                        "pytorch/dynamo:enable_unspecialize_zero_one_plain_int"
1434                    )
1435                    and value in self._common_constants()
1436                )
1437            ):
1438                self.install_guards(GuardBuilder.CONSTANT_MATCH)
1439                return ConstantVariable.create(value=value, source=self.source)
1440            else:
1441                return self.wrap_symint(value)
1442        elif not config.specialize_float and type(value) is float:
1443            return self.wrap_symfloat(value)
1444        else:
1445            self.install_guards(GuardBuilder.CONSTANT_MATCH)
1446            result = ConstantVariable.create(value=value, source=self.source)
1447            if isinstance(value, (list, set)):
1448                return self.set_source_and_track_mutable(value, result)
1449            return result
1450
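    # A small illustrative sketch of the policy above (assumed example, not
    # executed): with config.specialize_int=False, a plain int local such as
    #
    #     def f(x, n):  # n is a Python int argument
    #         return x + n
    #
    # is routed to wrap_symint, where automatic dynamic shapes decide whether it
    # stays a constant or becomes a SymInt graph input, while the same int read
    # from a global, a function default, or a specialized nn.Module attribute is
    # specialized immediately under CONSTANT_MATCH.
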
1451    def assert_not_wrapped_by_this_graph(self, value: torch.Tensor):
1452        if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode:
1453            raise InternalTorchDynamoError(
1454                "Cannot wrap a Tensor that has already been "
1455                "wrapped by this instance of Dynamo"
1456            )
1457
1458    def wrap_tensor(self, value: torch.Tensor):
1459        source = self.get_source()
1460
1461        # We cannot already be tracking the tensor, which implies
1462        # it would have already been wrapped
1463        assert value not in self.tx.output.side_effects
1464
1465        is_static_input = get_static_address_type(value) is not None
1466
1467        if (
1468            config.inline_inbuilt_nn_modules
1469            and not is_static_input
1470            and (
1471                isinstance(value, torch.nn.Parameter)
1472                # mark tensor attributes of nn modules static. This is done to keep inline_inbuilt_nn_modules
1473                # compatible with the previous behavior.
1474                or (source and source.guard_source().is_unspecialized_nn_module())
1475            )
1476        ):
1477            self.mark_static_input(value, guard=is_parameter_freezing())
1478            is_static_input = True
1479
1480        make_graph_attribute = is_static_input and (
1481            not config.inline_inbuilt_nn_modules or is_parameter_freezing()
1482        )
1483
1484        if (
1485            source.guard_source().is_specialized_nn_module() or make_graph_attribute
1486        ) and not source.guard_source().is_fsdp_module():
1487            self.assert_not_wrapped_by_this_graph(value)
1488            return self.tx.output.register_attr_or_module(
1489                value, self.name, source=source
1490            )
1491
1492        if is_constant_source(source):
1493            self.assert_not_wrapped_by_this_graph(value)
1494            return self.tx.output.register_attr_or_module(
1495                value,
1496                re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
1497                source=source,
1498                # Guards are added inside register_attr_or_module
1499            )
1500
1501        if type(value) in config.traceable_tensor_subclasses:
1502            # Ordinarily, we would fakeify a tensor so that it can get dynamic
1503            # shapes and be computed on without triggering actual operations.
1504            # However, how can we fakeify a tensor subclass?  Neither ordinary
1505            # inheritance nor multiple inheritance will work.
1506            #
1507            # Instead, our plan is to *manually simulate* the tensor subclass
1508            # inheriting from a fake tensor with dynamo.  This means our
1509            # data representation for a tensor subclass will be a fake tensor
1510            # + tensor subclass type + any extra data the subclass may have
1511            # been storing on the tensor.  Because all Python accesses are
1512            # mediated through TensorWithTFOverrideVariable, we can ensure
1513            # that we dispatch differently, e.g., according to
1514            # __torch_function__
1515            #
1516            # To simplify things for now, the __dict__ tracking bits haven't
1517            # been implemented yet, but they can be added into this design at
1518            # a later point in time.
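            # (Illustrative, using a hypothetical subclass: a MyTensor class
            # registered in config.traceable_tensor_subclasses is traced as a
            # FakeTensor plus subclass_type=MyTensor, with Python-level accesses
            # mediated by TensorWithTFOverrideVariable so that
            # __torch_function__ dispatch still fires.)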
1519            subclass_type = type(value)
1520        else:
1521            assert type(value) in (
1522                torch.Tensor,
1523                torch.nn.Parameter,
1524                torch._subclasses.fake_tensor.FakeTensor,
1525                torch._subclasses.functional_tensor.FunctionalTensor,
1526            ) or is_traceable_wrapper_subclass(value), type(value)
1527            subclass_type = None
1528
1529        # NB: this just says we accessed a tensor from the same source again
1530        # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice).
1531        # This is distinct from two distinct sources mapping to the same
1532        # Tensor (per id())!  No guard is necessary here.  See below for the
1533        # other case.
1534        is_duplicate_tensor = source in self.tx.output.input_source_to_var
1535        if is_duplicate_tensor:
1536            return self.tx.output.input_source_to_var[source]
1537
1538        if get_static_address_type(value) == "guarded":
1539            self.install_guards(GuardBuilder.ID_MATCH)
1540
1541        # By this point, we should have deduplicated all tensors
1542        self.assert_not_wrapped_by_this_graph(value)
1543
1544        # tx.output has multiple tracers if we're introspecting HigherOrderOperator.
1545        # When we've discovered an untracked tensor, then we actually need
1546        # to get Dynamo to track the tensor (which is what this function does)
1547        # and put it as a graph input on the root tracer. Later on,
1548        # if the input is actually used in the body of the HigherOrderOperator,
1549        # then the relevant SubgraphTracer will lift it to being an input of
1550        # the subgraph.
1551        # See NOTE [HigherOrderOperator tracing design] for more details.
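        # (Illustrative example of that note: for torch.cond(pred, true_fn,
        # false_fn, (x,)), a tensor that true_fn closes over is first made a
        # placeholder on the root graph here, and the SubgraphTracer for
        # true_fn later lifts it into the subgraph's inputs if it is used.)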
1552
1553        tensor_proxy = self.tx.output.root_tracer.create_graph_input(
1554            re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value), source=source
1555        )
1556        options = {}
1557        if type(value) in config.traceable_tensor_subclasses:
1558            options["torch_function_fn"] = build_torch_function_fn(
1559                self.tx, value, self.source
1560            )
1561            self.install_guards(GuardBuilder.TYPE_MATCH)
1562
1563        if (
1564            isinstance(value, torch.Tensor)
1565            and value.is_nested
1566            and not isinstance(value, torch.nested._internal.nested_tensor.NestedTensor)
1567        ):
1568            unimplemented("torch.compile does not support strided NestedTensor")
1569
1570        # TODO(pearu,sparse-team) - Add the corresponding SPARSE_TENSOR_MATCH guards
1571        if (
1572            isinstance(value, torch.Tensor)
1573            and is_sparse_any(value)
1574            and (not self.tx.export or not config.capture_sparse_compute)
1575        ):
1576            # A hot fix for sparse tensors + torch.compile. Support for
1577            # export + sparsity is being added but we need to create
1578            # SPARSE_TENSOR_GUARDS for guards to work properly.
1579            unimplemented("torch.compile does not support sparse Tensors")
1580
1581        if (
1582            safe_has_grad(value)
1583            and safe_grad(value) is not None
1584            and value.dtype != safe_grad(value).dtype
1585        ):
1586            unimplemented(
1587                "Inconsistent dtype between tensor and its gradient. "
1588                "This can happen in FSDP and crashes meta tensor creation. "
1589                "This is potentially a workaround. Fixing it correctly "
1590                "requires some design around FSDP + torch.compile."
1591            )
1592
1593        tensor_variable = wrap_fx_proxy(
1594            tx=self.tx,
1595            proxy=tensor_proxy,
1596            example_value=value,
1597            subclass_type=subclass_type,
1598            source=source,
1599            **options,
1600        )
1601
1602        guard_type = GuardBuilder.TENSOR_MATCH
1603
1604        if isinstance(source, GradSource) and is_from_optimizer_source(source):
1605            guard_type = GuardBuilder.NOT_NONE_MATCH
1606
1607        self.install_guards(
1608            functools.partial(
1609                guard_type,
1610                value=(
1611                    value
1612                    if isinstance(source, NumpyTensorSource)
1613                    else TensorWeakRef(value)
1614                ),
1615            )
1616        )
1617
1618        # We install TYPE_MATCH guards for traceable wrapper subclass objects,
1619        # and recursively install the corresponding guard for each inner attribute.
1620        if is_traceable_wrapper_subclass(value):
1621            self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH)
1622            self.install_guards(GuardBuilder.TYPE_MATCH)
1623            install_guard(
1624                SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH)
1625            )
1626
1627            attrs, _ = value.__tensor_flatten__()
1628            for attr in attrs:
1629                inner_value = getattr(value, attr)
1630                inner_source = AttrSource(self.source, attr)
1631                LazyVariableTracker.realize_all(
1632                    VariableBuilder(self.tx, inner_source)(inner_value)
1633                )
1634
1635        self.tx.output.input_source_to_var[source] = tensor_variable
1636        assert "tensor_dict" not in tensor_proxy.node.meta
1637        tensor_proxy.node.meta["tensor_dict"] = _extract_tensor_dict(value)
1638
1639        # Note: this information is conveyed via subclass_type now
1640        fake_tensor_value = tensor_variable.proxy.node.meta["example_value"]
1641        if maybe_get_fake_mode(fake_tensor_value) is not self.tx.fake_mode:
1642            raise InternalTorchDynamoError("Wrapped Tensor must be this graph's fake")
1643
1644        grapharg = GraphArg(source, value, False, fake_tensor_value)
1645        tensor_proxy.node.meta["grapharg"] = grapharg
1646        self.tx.output.add_symbol_bindings(grapharg)
1647        return tensor_variable
1648
1649    def wrap_numpy_ndarray(self, value):
1650        assert np is not None
1651        assert isinstance(value, np.ndarray)
1652
1653        source = NumpyTensorSource(self.get_source())
1654
1655        from torch._numpy import _util
1656
1657        readonly = not value.flags.writeable
1658        if readonly:
1659            try:
1660                value.flags.writeable = True
1661            except ValueError:
1662                # One cannot easily make nditer elements writable,
1663                # but that is not the end of the world
1664                assert isinstance(value.base, np.nditer)
1665
1666        try:
1667            tensor_value = _util._try_convert_to_tensor(value)
1668            if readonly:
1669                from torch._prims_common import clone_preserve_strides
1670
1671                tensor_value = clone_preserve_strides(tensor_value)
1672        except NotImplementedError as e:
1673            # failed to convert to tensor, graph break
1674            unimplemented(str(e))
1675
1676        # We do this because we want the full behavior of guarding the numpy ndarray as if it were
1677        # a tensor. It's a little annoying to make a VT just to throw it out, but there are so many side
1678        # effects here that there's no better way to do this at the moment.
1679        # This creates the right graphargs, as well as registration for guards in tensor names and shape env.
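        # (Roughly, as an illustrative sketch: for an input ndarray `a`, we
        # realize a throwaway tensor VT for its converted tensor so that tensor
        # guards and ShapeEnv registration happen, and then build the
        # NumpyNdarrayVariable graph input below, whose GraphArg is passed to
        # the compiled graph as a tensor at runtime.)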
1680        LazyVariableTracker.realize_all(VariableBuilder(self.tx, source)(tensor_value))
1681        proxy = self.tx.output.root_tracer.create_graph_input(
1682            re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(tensor_value), source=source
1683        )
1684        options = {"source": source}
1685        numpy_ndarray_variable = wrap_fx_proxy_cls(
1686            target_cls=NumpyNdarrayVariable,
1687            tx=self.tx,
1688            proxy=proxy,
1689            example_value=tensor_value,
1690            **options,
1691        )
1692
1693        self.tx.output.input_source_to_var[source] = numpy_ndarray_variable
1694        example_value = numpy_ndarray_variable.proxy.node.meta["example_value"]
1695
1696        # pass_arg_as_tensor should be true because we are wrapping an np.ndarray as an argument input, and
1697        # it needs to be converted to a tensor.
1698        grapharg = GraphArg(
1699            source,
1700            tensor_value,
1701            pass_arg_as_tensor=True,
1702            fake_tensor=example_value,
1703            is_tensor=True,
1704            example_strong_ref=tensor_value,
1705        )
1706        proxy.node.meta["grapharg"] = grapharg
1707
1708        return numpy_ndarray_variable
1709
1710    def wrap_symint(self, value):
1711        assert type(value) is int
1712
1713        if self.name in self.tx.output.unspec_variable_map:
1714            return self.tx.output.unspec_variable_map[self.name]
1715
1716        shape_env = self.tx.output.shape_env
1717        if TracingContext.get().force_unspec_int_unbacked_size_like:
1718            wrapped_value = shape_env.create_unbacked_symint()
1719            _constrain_range_for_size(wrapped_value)
1720            self.tx.output.bound_symbols.add(wrapped_value.node.expr)
1721            self.tx.output.tracked_fakes.append(
1722                TrackedFake(wrapped_value, self.source, None)
1723            )
1724
1725        # NB: We do not do float.  For motivation, see
1726        # https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit
1727        # but the general idea is that we generate kernels that can
1728        # take unspecialized floats and use them in sizevar computation
1729        elif not is_constant_source(self.get_source()):
1730            if torch._dynamo.config.specialize_int:
1731                # If specialize_int is True, specialize and return
1732                # a constant (though this case should normally have been
1733                # handled in the caller, TBH)
1734                self.install_guards(GuardBuilder.CONSTANT_MATCH)
1735                return ConstantVariable.create(value=value, source=self.source)
1736
1737            name = self.source.name()
1738
1739            def update_frame_state(value):
1740                if name not in self.tx.output.frame_state:
1741                    # Note - this essentially means that if this name gets reused as a tensor,
1742                    # it will start fully dynamic. That should always be a safe option, and not awfully inefficient.
1743                    # Alternatively, if we want to improve perf here, we can add a third state of unset, but I am not
1744                    # sure that is necessary for now.
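                    # (Illustrative: the first call with value 5 records
                    # scalar=5 here; a later call with value 7 takes the else
                    # branch below, logs the mismatch, and resets scalar to
                    # None, after which automatic dynamic shapes treat the int
                    # as dynamic.)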
1745                    frame_state_entry = FrameStateSizeEntry(
1746                        scalar=value, size=None, stride=None
1747                    )
1748                else:
1749                    frame_state_entry = self.tx.output.frame_state[name]
1750                    if frame_state_entry.scalar != value:
1751                        log.debug(
1752                            "automatic dynamic int %s val %s != %s",
1753                            name,
1754                            value,
1755                            frame_state_entry.scalar,
1756                        )
1757                        if self.source.guard_source().is_unspecialized_nn_module():
1758                            log.info(
1759                                "%s",
1760                                (
1761                                    f"{name} is converted to a symbolic integer. It is an attribute of a "
1762                                    "user defined nn module class. If you wish to keep it static, you can "
1763                                    "mark the nn module class as `torch._dynamo.mark_static`."
1764                                ),
1765                            )
1766                        frame_state_entry.scalar = None
1767                self.tx.output.frame_state[name] = frame_state_entry
1768
1769            if (st := self.tx.distributed_state) is None:
1770                update_frame_state(value)
1771                frame_state_entry = self.tx.output.frame_state[name]
1772            elif st.all_states is None:
1773                # Preflight, always pretend as if it's static
1774                frame_state_entry = FrameStateSizeEntry(
1775                    size=None, scalar=value, stride=None
1776                )
1777                st.local_state.input_sizes[name] = value
1778            else:
1779                # Apply the updates
1780                for sub_state in st.all_states:
1781                    update_frame_state(sub_state.input_sizes[name])
1782                frame_state_entry = self.tx.output.frame_state[name]
1783
1784            # TODO: This should be dynamic, as we in general do not
1785            # know if bare integers are actually going to be sizevars
1786            # and it is inappropriate to eagerly duck size them with
1787            # real sizevars
1788            if (
1789                config.automatic_dynamic_shapes and frame_state_entry.scalar is None
1790            ) or not config.assume_static_by_default:
1791                dynamic_dim = DimDynamic.DYNAMIC
1792            else:  # assume_static_by_default
1793                # TODO: dynamic_dim = DimDynamic.STATIC should work but
1794                # for some reason it doesn't
1795                self.install_guards(GuardBuilder.CONSTANT_MATCH)
1796                return ConstantVariable.create(value=value)
1797
1798            wrapped_value = shape_env.create_unspecified_symint_and_symbol(
1799                value,
1800                source=self.source,
1801                dynamic_dim=dynamic_dim,
1802            )
1803            self.tx.output.bound_symbols.add(wrapped_value.node.expr)
1804
1805            self.tx.output.tracked_fakes.append(
1806                TrackedFake(wrapped_value, self.source, None)
1807            )
1808        else:
1809            assert is_constant_source(self.get_source())
1810            # TODO: Do I actually need guard for constant source?
1811            self.install_guards(GuardBuilder.CONSTANT_MATCH)
1812            return ConstantVariable.create(value=value, source=self.source)
1813
1814        assert not isinstance(self.get_source(), RandomValueSource)
1815        install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))
1816
1817        options = {"source": self.get_source()}
1818
1819        proxy = self.tx.output.root_tracer.create_graph_input(
1820            re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
1821            type(wrapped_value),
1822            source=self.get_source(),
1823        )
1824
1825        set_example_value(proxy.node, wrapped_value)
1826        unspec_var = SymNodeVariable(proxy, wrapped_value, **options)
1827        self.tx.output.unspec_variable_map[self.name] = unspec_var
1828
1829        if not is_constant_source(self.get_source()):
1830            if self.tx.export and not isinstance(self.get_source(), LocalSource):
1831                raise AssertionError(
1832                    f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}"
1833                )
1834
1835            example_value = unspec_var.proxy.node.meta["example_value"]
1836
1837            proxy.node.meta["grapharg"] = GraphArg(
1838                self.get_source(),
1839                wrapped_value,
1840                pass_arg_as_tensor=False,
1841                fake_tensor=None,
1842                is_tensor=False,
1843                example_strong_ref=wrapped_value,
1844            )
1845
1846        return unspec_var
1847
1848    def wrap_symfloat(self, value):
1849        # SymFloat wrapping is special.  We first wrap it in the same way we
1850        # do an unspecialized primitive, and then we item() it into a
1851        # SymFloat.  Removal of the item() call is left to a later FX pass,
1852        # mostly because that pass is more easily done after we have lowered
1853        # to ATen ops.  (Dynamo doesn't do decomposition right now).
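        # (Illustrative sketch with assumed names: a float local `y` is wrapped
        # as a float64 0-d tensor graph input `t_y`, and the graph then computes
        # t_y.item() to obtain the SymFloat used by the rest of the trace; the
        # later pass mentioned above is expected to remove that item() call.)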
1854
1855        if self.name in self.tx.output.unspec_variable_map:
1856            return self.tx.output.unspec_variable_map[self.name]
1857
1858        # NB: we specialize on nan input, because our guard modeling in
1859        # ShapeEnv cannot deal with nan
1860        if (
1861            torch._dynamo.config.specialize_float
1862            or is_constant_source(self.get_source())
1863            or math.isnan(value)
1864        ):
1865            self.install_guards(GuardBuilder.CONSTANT_MATCH)
1866            return ConstantVariable.create(value=value, source=self.source)
1867
1868        # NB: At the point we've gotten here, we don't assume static by
1869        # default.  Since we have a guard mechanism, there isn't really any
1870        # downside to trying to be dynamic for float all the time.  Unlike
1871        # ints, this won't make codegen perf worse.  Modest cost to compile
1872        # time.
1873
1874        wrapped_value = torch.tensor(value, dtype=torch.float64)
1875        # TODO: Switch RandomValueSource over to use this, this is more
1876        # accurate
1877        assert not isinstance(self.get_source(), RandomValueSource)
1878        install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))
1879
1880        # The FloatTensorSource here is just for pedantic correctness: if you
1881        # guard against an UnspecializedPythonVariable, you need to guard
1882        # against the tensor-ified version of the local, otherwise it's not a
1883        # Tensor.  However, we never let the UnspecializedPythonVariable escape
1884        # here, so there should never actually be any guards against this
1885        # source.
1886        options = {"source": FloatTensorSource(self.get_source()), "raw_value": value}
1887
1888        # TODO: Maybe the tensor-ification should be built into the source,
1889        # rather than by special pattern match
1890        proxy = self.tx.output.root_tracer.create_graph_input(
1891            re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
1892            type(wrapped_value),
1893            source=self.get_source(),
1894        )
1895
1896        unspec_var = wrap_fx_proxy_cls(
1897            UnspecializedPythonVariable,
1898            tx=self.tx,
1899            proxy=proxy,
1900            example_value=wrapped_value,
1901            **options,
1902        )
1903        assert isinstance(unspec_var, UnspecializedPythonVariable)
1904        self.tx.output.unspec_variable_map[self.name] = unspec_var
1905
1906        if self.tx.export and not isinstance(self.get_source(), LocalSource):
1907            raise AssertionError(
1908                f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}"
1909            )
1910        fake_tensor_value = None
1911        example_value = unspec_var.proxy.node.meta["example_value"]
1912        assert is_fake(example_value)
1913
1914        fake_tensor_value = example_value
1915        assert fake_tensor_value.fake_mode is self.tx.fake_mode, (
1916            f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode "
1917            f"({self.tx.fake_mode}) from InstructionTranslator"
1918        )
1919
1920        # There's something a bit incoherent about pass_arg_as_tensor,
1921        # specifically regarding sources.
1922        #
1923        # Specifically, suppose we have "x: float" local argument.  We
1924        # eventually end up with an UnspecializedPythonVariable denoting
1925        # torch.as_tensor(x)... but its source is still L['x'] (which if you
1926        # accessed it directly is a float!)  So you gotta be careful when
1927        # setting up your guards, because it's still going to be a float at
1928        # this point; the conversion happens only at the point where we're
1929        # actually calling the FX graph.  This happens to be what we want for
1930        # shape guard generation, but it's kind of unintuitive.
1931        proxy.node.meta["grapharg"] = GraphArg(
1932            self.get_source(),
1933            wrapped_value,
1934            pass_arg_as_tensor=True,
1935            fake_tensor=fake_tensor_value,
1936            is_tensor=False,
1937            example_strong_ref=wrapped_value,
1938        )
1939
1940        # Directly do item to bypass capture_scalar_outputs
1941        r = wrap_fx_proxy(
1942            self.tx,
1943            self.tx.output.create_proxy(
1944                "call_method",
1945                "item",
1946                *proxy_args_kwargs([unspec_var], {}),
1947            ),
1948        )
1949        self.tx.output.tracked_fakes.append(TrackedFake(r.sym_num, self.source, None))
1950
1951        return r
1952
1953    def wrap_unspecialized_primitive(self, value):
1954        if self.name in self.tx.output.unspec_variable_map:
1955            return self.tx.output.unspec_variable_map[self.name]
1956
1957        wrapped_value = torch.tensor(value)
1958        if not isinstance(self.get_source(), RandomValueSource):
1959            install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))
1960
1961        options = {"source": self.get_source()}
1962        options.update({"raw_value": value})
1963
1964        proxy = self.tx.output.root_tracer.create_graph_input(
1965            re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
1966            type(wrapped_value),
1967            source=self.get_source(),
1968        )
1969
1970        unspec_var = wrap_fx_proxy_cls(
1971            UnspecializedPythonVariable,
1972            tx=self.tx,
1973            proxy=proxy,
1974            example_value=wrapped_value,
1975            **options,
1976        )
1977        self.tx.output.unspec_variable_map[self.name] = unspec_var
1978        if not is_constant_source(self.get_source()):
1979            if self.tx.export and not isinstance(self.get_source(), LocalSource):
1980                raise AssertionError(
1981                    f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}"
1982                )
1983            fake_tensor_value = None
1984            if isinstance(unspec_var, ConstantVariable):
1985                # TODO: when can this happen?
1986                example_value = unspec_var.value
1987            else:
1988                example_value = unspec_var.proxy.node.meta["example_value"]
1989            assert is_fake(example_value)
1990
1991            fake_tensor_value = example_value
1992            assert fake_tensor_value.fake_mode is self.tx.fake_mode, (
1993                f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode "
1994                f"({self.tx.fake_mode}) from InstructionTranslator"
1995            )
1996
1997            proxy.node.meta["grapharg"] = GraphArg(
1998                self.get_source(),
1999                wrapped_value,
2000                pass_arg_as_tensor=True,
2001                fake_tensor=fake_tensor_value,
2002                is_tensor=False,
2003                example_strong_ref=wrapped_value,
2004            )
2005        return unspec_var
2006
2007
2008def _dataclasses_fields_lambda(obj):
2009    if isinstance(obj, UserDefinedObjectVariable):
2010        value = obj.value
2011    elif isinstance(obj, CustomizedDictVariable):
2012        value = obj.user_cls
2013    else:
2014        unimplemented(f"Dataclass fields handling fails for type {obj}")
2015    items = []
2016    for field in dataclasses.fields(value):
2017        source = None
2018        if obj.source:
2019            source = GetItemSource(
2020                AttrSource(obj.source, "__dataclass_fields__"), field.name
2021            )
2022        items.append(UserDefinedObjectVariable(field, source=source))
2023    return TupleVariable(items)
2024
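# Illustrative example for the helper above (hypothetical dataclass, not from
# this file): for
#
#     @dataclasses.dataclass
#     class Point:
#         x: int
#         y: int
#
# a traced dataclasses.fields(point) call produces a TupleVariable with one
# UserDefinedObjectVariable per Field, each sourced (when obj.source exists)
# from obj.source.__dataclass_fields__["x"] / ["y"].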
2025
2026def wrap_fx_proxy(
2027    tx, proxy, example_value=None, subclass_type=None, **options
2028) -> VariableTracker:
2029    kwargs = {
2030        "tx": tx,
2031        "proxy": proxy,
2032        "example_value": example_value,
2033        "subclass_type": subclass_type,
2034        **options,
2035    }
2036    if subclass_type is None:
2037        return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
2038    else:
2039        result = wrap_fx_proxy_cls(target_cls=TensorWithTFOverrideVariable, **kwargs)
2040        result.install_global(tx)
2041        return result
2042
2043
2044# Note: Unfortunate split due to some gross classes existing that subclass TensorVariable
2045# Should be compositional instead
2046#
2047# This is a horribly complicated function that does too many things. To
2048# explain what it does, let's first talk about the classic usage of wrap_fx_proxy
2049# for a TensorVariable.  There are two primary modes of use:
2050#
2051#   1. Wrapping a pre-existing Tensor.  In this case, example_value is set
2052#      to the pre-existing Tensor.  (Note that this example_value will NOT
2053#      be the final example_value we put into node.meta['example_value'],
2054#      instead it is converted into a fake tensor using
2055#      wrap_to_fake_tensor_and_record and registered as a graph input.)
2056#
2057#   2. "Wrapping" the result of some Tensor operation Dynamo traced over. In
2058#      this case, example_value is None (and we are going to figure it out
2059#      ourselves using FakeTensors, via get_fake_value, which will run
2060#      the operation represented by the (singular!) FX node referenced by
2061#      the passed in proxy.)
2062#
2063# The expectation is you end up with a Tensor output, and everything is
2064# straightforwardly traced into the graph.
2065#
2066# In all cases, the returned `TensorVariable` subclass will have an `example_value`
2067# and that `example_value` must be a `FakeTensor` produced by the currently running
2068# instance of Dynamo.
2069#
2070# Upon closer inspection, you may notice that there are a slurry of non-Tensor
2071# output cases.  What gives?  Well, we sometimes trace operations into the
2072# graph that don't involve tensors.
2073#
2074#   * Some operators return tuples; we need to recursively handle their
2075#     contents
2076#
2077#   * Some operators have side effects that will affect subsequent AOTAutograd
2078#     tracing but don't otherwise return anything.
2079#
2080#   * Some operators return symbolic ints/floats/bools which can go in the
2081#     graph and be traced (but only if they're actually symbolic!  If they're
2082#     static you don't want to put them in the graph, which means you
2083#     shouldn't call this function.)
2084#
2085# The common theme is that you only use this function WHEN YOU ARE TRACING
2086# SOMETHING INTO THE GRAPH.  This is sort of obvious, because you can't call
2087# this function without a proxy.
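#
# A minimal usage sketch of the two modes described above (illustrative only;
# `tx` and the proxies are assumed to come from an active InstructionTranslator
# and its tracers):
#
#   # mode 1: wrap a pre-existing Tensor that is becoming a graph input
#   var = wrap_fx_proxy(tx, input_proxy, example_value=real_tensor, source=src)
#
#   # mode 2: wrap the result of a traced op; the fake example_value is
#   # computed from the proxy's single FX node via get_fake_value
#   var = wrap_fx_proxy(tx, call_function_proxy)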
2088def wrap_fx_proxy_cls(
2089    target_cls, tx, proxy, example_value=None, subclass_type=None, **options
2090):
2091    from ..symbolic_convert import InstructionTranslatorBase
2092
2093    assert isinstance(tx, InstructionTranslatorBase)
2094    if "guards" in options and options["guards"] is not None:
2095        tx.output.guards.update(options["guards"])
2096
2097    assert "example_value" not in proxy.node.meta, f"{proxy.node.meta['example_value']}"
2098
2099    initial_example_value = example_value
2100
2101    def _clone_input(value):
2102        if isinstance(value, torch.Tensor):
2103            # tensor subclasses will not be converted to FakeTensors and need to be cloned
2104            if not (
2105                isinstance(value, FakeTensor)
2106                or (
2107                    # Is functional tensor fakeified by this instance of Dynamo
2108                    torch._is_functional_tensor(value)
2109                    and maybe_get_fake_mode(value) is tx.fake_mode
2110                )
2111                or value.is_nested
2112            ):
2113                # NB: ensure strides are preserved
2114                value = clone_input(value)
2115
2116        return value
2117
2118    # See NOTE: [Deferring tensor pack/unpack hooks until runtime]
2119    with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
2120        # with preserve_rng_state():
2121        if example_value is None:
2122            # only allow_non_graph_fake in this instance because we handle the non-fake
2123            # cases properly below.
2124            example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
2125
2126        # Handle recursive calls here
2127        elif maybe_get_fake_mode(example_value) is tx.fake_mode:
2128            pass
2129
2130        elif isinstance(example_value, torch.Tensor):
2131            if tx.export:
2132                # The legacy behavior for real value cache with subclasses was
2133                # to perform a clone WITHOUT preserving the subclass.  It's
2134                # not entirely clear this is what you actually want though.
2135                with torch._C.DisableTorchFunctionSubclass():
2136                    proxy.tracer.real_value_cache[proxy.node] = _clone_input(
2137                        example_value
2138                    )
2139            # NB: If we're ignoring subclass, then the expectation is you will
2140            # take the returned TensorVariable and wrap it into a more
2141            # accurate TensorVariable that is able to track subclass-ness;
2142            # otherwise this is wrong!
2143            kwargs = {
2144                "is_tensor": target_cls
2145                in (TensorVariable, TensorWithTFOverrideVariable),
2146            }
2147            assert "source" in options and options["source"] is not None
2148            kwargs["source"] = options["source"]
2149            example_value = wrap_to_fake_tensor_and_record(
2150                example_value, tx=tx, **kwargs
2151            )
2152        if (
2153            isinstance(example_value, torch.Tensor)
2154            and example_value.device.type != "meta"
2155            and (maybe_get_fake_mode(example_value) is not tx.fake_mode)
2156        ):
2157            raise InternalTorchDynamoError(
2158                "`example_value` needs to be a `FakeTensor`"
2159                f"wrapped by this instance of Dynamo. Found: {example_value}"
2160            )
2161
2162    if isinstance(example_value, torch.Tensor):
2163        is_parameter = isinstance(example_value, torch.nn.Parameter)
2164        is_buffer = isinstance(example_value, torch.nn.Buffer)
2165
2166        # NB: In most (all?) cases, this does not actually do a clone.
2167        # (WARNING: this means that if we mutate metadata on the fake
2168        # tensor, the stored example value will update too!)
2169        example_value = _clone_input(example_value)
2170        set_example_value(proxy.node, example_value)
2171        specialized_props = target_cls.specialize(example_value)
2172        # TODO: not sure about this fake mode test
2173        if (
2174            isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor)
2175            and example_value.fake_mode is tx.fake_mode
2176        ):
2177            tensor_type = subclass_type if subclass_type else torch.Tensor
2178            specialized_props["class_type"] = (
2179                torch.nn.Parameter
2180                if is_parameter
2181                else torch.nn.Buffer
2182                if is_buffer
2183                else tensor_type
2184            )
2185
2186        options.update(specialized_props)
2187        return target_cls(proxy, **options)
2188    elif (
2189        hasattr(proxy.node.target, "__name__")
2190        and proxy.node.target.__name__ == "set_state"
2191        and isinstance(proxy.node.target.__self__, torch._C.Generator)
2192        or proxy.node.target == torch.random.set_rng_state
2193    ):
2194        return TorchInGraphFunctionVariable(proxy.node.target)
2195    elif (
2196        proxy.node.target == torch._C._DisableFuncTorch
2197        or proxy.node.target == torch.cuda._is_in_bad_fork
2198    ):
2199        return UserDefinedObjectVariable(example_value)
2200    elif istype(example_value, torch.Size) and all(
2201        isinstance(x, int) for x in example_value
2202    ):
2203        sizes = [ConstantVariable.create(x) for x in example_value]
2204        return SizeVariable(sizes, **options)
2205    elif isinstance(example_value, (tuple, list)):
2206        set_example_value(proxy.node, example_value)
2207        unpacked = []
2208        for i, val in enumerate(example_value):
2209            if val is None:
2210                # nn.MultiheadAttention() can return None, see issue #175
2211                unpacked.append(
2212                    ConstantVariable.create(None, **options),
2213                )
2214            else:
2215                proxy_i = proxy.tracer.create_proxy(
2216                    kind="call_function",
2217                    target=operator.getitem,
2218                    args=(proxy, i),
2219                    kwargs={},
2220                )
2221
2222                if "source" in options:
2223                    source = options["source"]
2224                    options_i = options.copy()
2225                    options_i["source"] = GetItemSource(
2226                        base=source, index=i, index_is_slice=False
2227                    )
2228                else:
2229                    # use the same options object as parent
2230                    options_i = options
2231
2232                # WARNING: this assumes the same target_cls as this tuple/list call
2233                unpacked.append(
2234                    wrap_fx_proxy_cls(
2235                        target_cls=target_cls,
2236                        tx=tx,
2237                        proxy=proxy_i,
2238                        example_value=val,
2239                        **options_i,
2240                    )
2241                )
2242        if isinstance(example_value, torch.Size):
2243            # NB: Keep the old proxy around.  See SizeVariable for an
2244            # explanation why
2245            return SizeVariable(unpacked, proxy, **options)
2246        elif istype(example_value, tuple):
2247            return TupleVariable(unpacked, **options)
2248        elif istype(example_value, (list, immutable_list)):
2249            return ListVariable(unpacked, mutable_local=MutableLocal(), **options)
2250        else:
2251            assert example_value.__class__.__module__ == "torch.return_types" or hasattr(
2252                example_value, "_fields"
2253            ), f"expected {example_value.__class__.__module__} == torch.return_types or named tuple but got {type(example_value)}"
2254            return NamedTupleVariable(unpacked, example_value.__class__, **options)
2255    elif example_value is None or proxy.node.target is torch.manual_seed:
2256        return ConstantVariable.create(None, **options)
2257    elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
2258        set_example_value(proxy.node, example_value)
2259        return SymNodeVariable(proxy, example_value, **options)
2260    elif (
2261        inspect.isclass(proxy.node.target)
2262        and issubclass(proxy.node.target, _StreamBase)
2263    ) or proxy.node.target in [
2264        device_interface.current_stream
2265        for _, device_interface in get_registered_device_interfaces()
2266    ]:
2267        set_example_value(proxy.node, example_value)
2268        return StreamVariable(proxy, example_value, example_value.device, **options)
2269    elif (
2270        inspect.isclass(proxy.node.target) and issubclass(proxy.node.target, _EventBase)
2271    ) or proxy.node.target in [
2272        device_interface.Event
2273        for _, device_interface in get_registered_device_interfaces()
2274    ]:
2275        set_example_value(proxy.node, example_value)
2276        return EventVariable(proxy, example_value, **options)
2277    elif proxy.node.target == "query" and proxy.node.op == "call_method":
2278        set_example_value(proxy.node, example_value)
2279        return ConstantVariable(example_value, **options)
2280    elif (
2281        example_value is not None
2282        and isinstance(example_value, _EventBase)
2283        and proxy.node.target == "record_event"
2284        and proxy.node.op == "call_method"
2285    ):
2286        set_example_value(proxy.node, example_value)
2287        return EventVariable(proxy, example_value, **options)
2288    elif isinstance(example_value, int) and (
2289        proxy.node.target
2290        in [
2291            torch.sym_int,
2292            getattr,
2293            operator.getitem,
2294            torch._utils._element_size,
2295            torch.seed,
2296            operator.mod,
2297            torch._functorch.vmap._validate_and_get_batch_size,
2298            # some mac builds are missing torch.distributed.get_rank()
2299            getattr(torch.distributed, "get_rank", _missing),
2300            getattr(torch.distributed, "get_world_size", _missing),
2301            # This always wants to be in the graph, even if the constraint
2302            # results in a constant int
2303            torch._constrain_as_size,
2304        ]
2305        or (
2306            # TODO: this is a little sus, because we didn't check what the self is
2307            proxy.node.op == "call_method"
2308            and proxy.node.target in ["bit_length"]
2309        )
2310    ):
2311        set_example_value(proxy.node, example_value)
2312        return ConstantVariable.create(example_value, **options)
2313    elif isinstance(example_value, torch.backends.cuda.SDPAParams):
2314        from .sdpa import SDPAParamsVariable
2315
2316        set_example_value(proxy.node, example_value)
2317        return SDPAParamsVariable(proxy, **options)
2318    elif isinstance(example_value, bool) and proxy.node.target in [
2319        torch._C._are_functorch_transforms_active,
2320        torch.backends.cuda.is_flash_attention_available,
2321        torch.backends.cuda.can_use_flash_attention,
2322        torch.backends.cuda.can_use_efficient_attention,
2323    ]:
2324        set_example_value(proxy.node, example_value)
2325        return ConstantVariable.create(example_value, **options)
2326    elif (
2327        isinstance(example_value, (int, float, bool))
2328        and proxy.node.target is call_torchbind
2329    ):
2330        set_example_value(proxy.node, example_value)
2331        return ConstantVariable.create(example_value, **options)
2332    else:
2333        unimplemented(
2334            "torch.* op returned non-Tensor "
2335            + f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}",
2336            case_name="unsupported_operator",
2337        )
2338
2339
2340# Tracks the sources of all fake tensors we wrap in Dynamo.
2341# Used by shape guard computation.
2342@dataclasses.dataclass
2343class TrackedFake:
2344    fake: Union[FakeTensor, SymInt]
2345    source: Source
2346    # Is None when fake is SymInt
2347    symbolic_context: Optional[SymbolicContext]
2348
2349    def __hash__(self) -> int:
2350        return hash((self.fake, self.source.name()))
2351
2352    def __eq__(self, other: object) -> bool:
2353        if isinstance(other, TrackedFake):
2354            return self.fake is other.fake and self.source.name() == other.source.name()
2355        return False
2356
2357
2358# Performs automatic dynamic dim determination.
2359# Returns a SymbolicContext
2360def _automatic_dynamic(
2361    e, tx, source, static_shapes, outer_only=False
2362) -> SymbolicContext:
2363    # strided NT not supported
2364    if e.is_nested and not isinstance(
2365        e, torch.nested._internal.nested_tensor.NestedTensor
2366    ):
2367        unimplemented("torch.compile does not support strided NestedTensor")
2368
2369    name = source.name()
2370    prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None)
2371    shape_env_to_source_to_symbol_cache = (
2372        prior_policy.shape_env_to_source_to_symbol_cache if prior_policy else None
2373    )
2374
2375    # Get base context if the tensor is a view
2376    view_base_context: Optional[SymbolicContext] = None
2377    if e._is_view():
2378        base_source = AttrSource(source, "_base")
2379        view_base_context = _automatic_dynamic(e._base, tx, base_source, static_shapes)
2380
2381    if is_traceable_wrapper_subclass(e) and not outer_only:
2382        # Get symbolic context for outer tensor
2383        outer_context = _automatic_dynamic(
2384            e, tx, source, static_shapes, outer_only=True
2385        )
2386
2387        # Get symbolic contexts for inner tensors
2388        inner_contexts = {}  # mapping from attr -> symbolic context
2389        attrs, _ = type(e).__tensor_flatten__(e)
2390        for attr in attrs:
2391            inner_tensor = getattr(e, attr)
2392            inner_source = AttrSource(source, attr)
2393            inner_contexts[attr] = _automatic_dynamic(
2394                inner_tensor, tx, inner_source, static_shapes
2395            )
2396
2397        return SubclassSymbolicContext(
2398            dynamic_sizes=outer_context.dynamic_sizes,
2399            dynamic_strides=outer_context.dynamic_strides,
2400            constraint_sizes=outer_context.constraint_sizes,
2401            constraint_strides=outer_context.constraint_strides,
2402            view_base_context=view_base_context,
2403            tensor_source=outer_context.tensor_source,
2404            shape_env_to_source_to_symbol_cache=outer_context.shape_env_to_source_to_symbol_cache,
2405            inner_contexts=inner_contexts,
2406        )
2407
2408    if static_shapes:
2409        return StatefulSymbolicContext(
2410            dynamic_sizes=[DimDynamic.STATIC] * e.dim(),
2411            dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(),
2412            constraint_sizes=[None] * e.dim(),
2413            constraint_strides=[None] * e.dim(),
2414            view_base_context=view_base_context,
2415            tensor_source=source,
2416            shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
2417        )
2418
2419    # We preserve the dynamism of inputs. For example, when users call
2420    # make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes.
2421    from torch.fx.experimental.symbolic_shapes import is_nested_int
2422
2423    if any(isinstance(s, SymInt) and not is_nested_int(s) for s in e.size()):
2424        return StatefulSymbolicContext(
2425            dynamic_sizes=[
2426                DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC
2427                for s in e.size()
2428            ],
2429            dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(),
2430            constraint_sizes=[None] * e.dim(),
2431            constraint_strides=[None] * e.dim(),
2432            view_base_context=view_base_context,
2433            tensor_source=source,
2434            shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
2435        )
2436
2437    # Prep for automatic dynamic
2438    def update_frame_state(size, stride):
2439        # Intentionally shadow e from parent scope so it is not accidentally
2440        # Intentionally shadow e from the parent scope so it is not accidentally
2441        # used
2442        frame_state_entry = None
2443        if name not in tx.output.frame_state:
2444            # If there is no entry for this source, add the tensor to frame state with its current static size.
2445            # E.g., {} -> {"x": [2, 4]}
2446            frame_state_entry = FrameStateSizeEntry(None, None, None)
2447            frame_state_entry.size = list(size)
2448            frame_state_entry.stride = list(stride)
2449        else:
2450            frame_state_entry = tx.output.frame_state[name]
2451            if frame_state_entry.size is not None:
2452                if len(size) != len(frame_state_entry.size):
2453                    # If there is already an entry, and the dim mismatches, replace the frame state entry with None.
2454                    # E.g. {"x": [2, 3, 4]} -> {"x": None}
2455                    log.debug(
2456                        "automatic dynamic %s dim %s != %s",
2457                        name,
2458                        len(size),
2459                        frame_state_entry.size,
2460                    )
2461                    frame_state_entry.size = None
2462                    frame_state_entry.stride = None
2463                else:
2464                    # If there is already an entry, and the dim matches, for every size/stride in the frame state which
2465                    # disagrees with the current static size/stride, replace it with None.
2466                    # E.g., {"x": [2, 3]} -> {"x": [2, None]}
2467
2468                    has_size_changed = False
2469                    for i, dim in enumerate(frame_state_entry.size):
2470                        if dim is not None and size[i] != dim:
2471                            log.debug(
2472                                "automatic dynamic %s size(%s) %s != %s",
2473                                name,
2474                                i,
2475                                size[i],
2476                                dim,
2477                            )
2478                            frame_state_entry.size[i] = None
2479                        has_size_changed = (
2480                            has_size_changed or frame_state_entry.size[i] is None
2481                        )
2482
2483                    # We want to trigger automatic dynamism when strides change, but we have to decide whether the stride
2484                    # should be INFER_STRIDE or DYNAMIC.
2485                    #
2486                    # Case 1: if strides change because of size changes, we might not want to allocate a new symbol for the
2487                    # stride. Let's say we have a tensor (10, 20) and we mark dim=1 dynamic for size. The resulting size will
2488                    # be (10, s0) and the stride can be either (s0, 1) or (s1, 1). In most cases, (s0, 1) is preferred because
2489                    # users are not changing both size and stride.
2490                    #
2491                    # Case 2: suppose the size stays the same between two invocations but the strides
2492                    # change. In this case, we definitely want to mark the changing strides as DYNAMIC.
2493
2494                    # Here, we use a heuristic to simplify the determination of dynamic strides. For case 1, we always
2495                    # assume the stride will be inferred (INFER_STRIDE). This might be suboptimal when the user is doing
2496                    # arbitrary size and stride resizing and we fail to trigger dynamism, but we have not seen such cases
2497                    # yet. For case 2, we mark the changing dimensions DYNAMIC.
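                    # Illustrative sketch (hypothetical shapes, not taken from this code):
                    #   Case 1: call 1 sees size (10, 20) / stride (20, 1); call 2 sees size (10, 30) / stride (30, 1).
                    #           Only size[1] becomes None; the stride change is explained by the size change,
                    #           so the stride stays INFER_STRIDE.
                    #   Case 2: both calls see size (10, 20), but strides (20, 1) vs (1, 10) (a transposed layout).
                    #           Sizes are unchanged, so both stride entries become None and those dims are later
                    #           treated as DYNAMIC for stride.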
2498                    if not has_size_changed:
2499                        for i, dim in enumerate(frame_state_entry.stride):
2500                            if dim is not None and stride[i] != dim:
2501                                log.debug(
2502                                    "automatic dynamic %s stride(%s) %s != %s",
2503                                    name,
2504                                    i,
2505                                    stride[i],
2506                                    dim,
2507                                )
2508                                frame_state_entry.stride[i] = None
2509        tx.output.frame_state[name] = frame_state_entry
2510
2511    if (st := tx.distributed_state) is None:
2512        stride = e.stride() if not is_sparse_any(e) else ()
2513        update_frame_state(e.size(), stride)
2514        frame_state_entry = tx.output.frame_state[name]
2515    elif st.all_states is None:
2516        # Preflight, always pretend as if it's static
2517        frame_state_entry = FrameStateSizeEntry(
2518            size=e.size(), scalar=None, stride=e.stride()
2519        )
2520        st.local_state.input_sizes[name] = list(e.size())
2521        st.local_state.input_strides[name] = list(e.stride())
2522    else:
2523        # Apply the updates
2524        for sub_state in st.all_states:
2525            # Not all inputs are necessarily present on all ranks
2526            if name in sub_state.input_sizes and name in sub_state.input_strides:
2527                update_frame_state(
2528                    sub_state.input_sizes[name], sub_state.input_strides[name]
2529                )
2530        frame_state_entry = tx.output.frame_state[name]
2531
2532    # TODO: index export_constraints ahead of time so we don't have to
2533    # do a linear scan every time here
2534    t_id = id(e)
2535    dim2constraint = {}
2536
2537    def update_dim2constraint(dim, constraint_range, name):
2538        if dim in dim2constraint:
2539            from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
2540
2541            old_constraint_range, old_name = dim2constraint[dim]
2542            new_constraint_range = StrictMinMaxConstraint(
2543                vr=constraint_range.vr & old_constraint_range.vr,
2544                warn_only=False,
2545            )
2546            # It is possible for (non-None) old_name and name to be different
2547            # but this will only happen the corresponding Dims can be derived equal.
2548            # but this will only happen if the corresponding Dims can be derived to be equal.
2549            dim2constraint[dim] = new_constraint_range, new_name
2550        else:
2551            dim2constraint[dim] = constraint_range, name
2552
2553    if tx.output.export_constraints:
2554        for constraint in tx.output.export_constraints:
2555            if constraint.t_id == t_id:
2556                update_dim2constraint(
2557                    constraint.dim, constraint.constraint_range, constraint.name
2558                )
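
    # At this point, dim2constraint maps dim index -> (constraint_range, debug_name).
    # Illustrative sketch (hypothetical export constraint on dim 0 named "batch"):
    #   dim2constraint == {0: (StrictMinMaxConstraint(vr=ValueRanges(2, 8), warn_only=False), "batch")}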
2559
2560    dynamic_sizes = []
2561    dynamic_strides = []
2562    constraint_sizes = []
2563    constraint_strides = []
2564    for i in range(e.dim()):
2565        # NB: mark dynamic has precedence over static
2566        marked_unbacked = i in getattr(e, "_dynamo_unbacked_indices", set())
2567        marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set())
2568        marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set())
2569        marked_static = i in getattr(e, "_dynamo_static_indices", set())
2570
2571        # NB: both static and dynamic user annotations have precedence over automatic dynamic
2572        automatic_dynamic_size = config.automatic_dynamic_shapes and (
2573            frame_state_entry.size is None or frame_state_entry.size[i] is None
2574        )
2575
2576        # if the size entry is already None (fully dynamic), there is no need to also make the stride dynamic
2577        automatic_dynamic_stride = config.automatic_dynamic_shapes and (
2578            frame_state_entry.size is not None
2579            and (
2580                frame_state_entry.stride is None or frame_state_entry.stride[i] is None
2581            )
2582        )
2583
2584        automatic_dynamic = automatic_dynamic_size or automatic_dynamic_stride
2585
2586        # Reflect the user directive in the frame_state
2587        # For dynamic, apply None always
2588        if frame_state_entry.size and marked_dynamic:
2589            log.debug("automatic dynamic %s marked dynamic", name)
2590            frame_state_entry.size[i] = None
2591            frame_state_entry.stride[i] = None
2592
2593        # We will process constraints first, as they will imply that we
2594        # have a dynamic dimension
2595        # Precedence: export constraints > eager constraints
2596        constraint = dim2constraint.get(i)
2597        if constraint is None:
2598            constraint_size = None
2599            constraint_stride = None
2600            if marked_dynamic and not config.allow_ignore_mark_dynamic:
2601                # constraint_stride is deliberately kept None because there is no easy way to provide value ranges for mark_dynamic
2602                constraint_stride = None
2603                if hasattr(e, "_dynamo_dynamic_range"):
2604                    dim_range = [
2605                        dr for dr in e._dynamo_dynamic_range if dr.dim == i
2606                    ].pop()
2607                    if dim_range.min is None and dim_range.max is None:
2608                        constraint_size = RelaxedUnspecConstraint(warn_only=False)
2609                    else:
2610                        from torch.fx.experimental.symbolic_shapes import (
2611                            StrictMinMaxConstraint,
2612                        )
2613
2614                        constraint_size = StrictMinMaxConstraint(
2615                            vr=ValueRanges(lower=dim_range.min, upper=dim_range.max),
2616                            warn_only=False,
2617                        )
2618                else:
2619                    constraint_size = RelaxedUnspecConstraint(warn_only=False)
2620            elif not marked_static and automatic_dynamic:
2621                if automatic_dynamic_size:
2622                    constraint_size = RelaxedUnspecConstraint(warn_only=True)
2623                if automatic_dynamic_stride:
2624                    constraint_stride = RelaxedUnspecConstraint(warn_only=True)
2625            else:
2626                constraint_size = None
2627                constraint_stride = None
2628        else:
2629            constraint_size, name_ = constraint
2630            constraint_stride = None
2631            dim_name = f"{name}.size()[{i}]"
2632            tx.output.shape_env.source_name_to_debug_name[dim_name] = name_
2633        constraint_sizes.append(constraint_size)
2634        constraint_strides.append(constraint_stride)
2635
2636        if marked_unbacked:
2637            dynamic_size = DimDynamic.SIZE_LIKE_UNBACKED
2638        elif (
2639            constraint_size is not None
2640            or marked_dynamic
2641            or marked_weak_dynamic
2642            or is_nested_int(e.size()[i])
2643        ):
2644            # NB: We could assert static_shapes is False here, but it
2645            # seems better to allow the user to override symbolic_context in this
2646            # case
2647            dynamic_size = DimDynamic.DYNAMIC
2648        elif static_shapes or config.assume_static_by_default or marked_static:
2649            dynamic_size = DimDynamic.STATIC
2650        else:
2651            dynamic_size = DimDynamic.DUCK
2652
2653        if constraint_stride is not None:
2654            dynamic_stride = DimDynamic.DYNAMIC
2655        else:
2656            dynamic_stride = DimDynamic.INFER_STRIDE
2657
2658        dynamic_sizes.append(dynamic_size)
2659        dynamic_strides.append(dynamic_stride)
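
    # Summary of the per-dim decision above (mirrors the if/elif chain, for readability):
    #   marked_unbacked                                            -> DimDynamic.SIZE_LIKE_UNBACKED
    #   constraint_size / marked_dynamic / marked_weak_dynamic /
    #     nested-int size                                          -> DimDynamic.DYNAMIC
    #   static_shapes / assume_static_by_default / marked_static   -> DimDynamic.STATIC
    #   otherwise                                                  -> DimDynamic.DUCK
    #   stride: DimDynamic.DYNAMIC iff a stride constraint was set, else DimDynamic.INFER_STRIDE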
2660
2661    return StatefulSymbolicContext(
2662        dynamic_sizes=dynamic_sizes,
2663        dynamic_strides=dynamic_strides,
2664        constraint_sizes=constraint_sizes,
2665        constraint_strides=constraint_strides,
2666        view_base_context=view_base_context,
2667        tensor_source=source,
2668        shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
2669    )
2670
2671
2672# See note [Tensor Fakification and Symbol Caching]
2673def wrap_to_fake_tensor_and_record(
2674    e, tx, *, source: Optional[Source], is_tensor: bool, parent_context=None
2675):
2676    if (
2677        type(e) in (torch.Tensor, torch.nn.Parameter, FakeTensor)
2678        or isinstance(e, torch.Tensor)
2679        or is_traceable_wrapper_subclass(e)
2680    ):
2681        assert source is not None
2682        static_shapes, reason = tensor_always_has_static_shape(
2683            e,
2684            is_tensor,
2685            tensor_source=source,
2686        )
2687
2688        if not parent_context:
2689            symbolic_context = _automatic_dynamic(e, tx, source, static_shapes)
2690        else:
2691            # Parent contexts are passed in when we are recursively creating
2692            # fake tensors for subclasses. A better design would be not to create a
2693            # parent/child relationship, but to recursively call _automatic_dynamic
2694            # as we recursively call wrap_to_fake_tensor_and_record. This runs
2695            # into bugs around how meta_utils creates fake tensors
2696            # for tensor subclasses. Ideally, dynamo would drive both the recursive
2697            # wrap_to_fake_tensor_and_record and _automatic_dynamic policy creation.
2698            assert isinstance(source, AttrSource)
2699            inner_context_name = source.member
2700            symbolic_context = parent_context.inner_contexts[inner_context_name]
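            # Illustrative sketch (hypothetical subclass with inner tensors "a" and "b"):
            # the outer call built parent_context.inner_contexts == {"a": ctx_a, "b": ctx_b},
            # and the recursive call for attr "a" reaches here with source == AttrSource(outer_source, "a"),
            # selecting ctx_a via source.member.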
2701
2702        log.debug(
2703            "wrap_to_fake %s %s %s %s",
2704            source.name(),
2705            tuple(e.shape),
2706            symbolic_context,
2707            type(e),
2708        )
2709        fake_e = wrap_fake_exception(
2710            lambda: tx.fake_mode.from_tensor(
2711                e,
2712                source=source,
2713                symbolic_context=symbolic_context,
2714            )
2715        )
2716        if (
2717            source is not None
2718            and isinstance(fake_e, FakeTensor)
2719            and (sym_val := fake_e.item_memo) is not None
2720        ):
2721            tx.output.tracked_fakes.append(
2722                TrackedFake(sym_val, CallMethodItemSource(source), symbolic_context)
2723            )
2724
2725        if is_traceable_wrapper_subclass(fake_e):
2726            attrs, _ = fake_e.__tensor_flatten__()
2727            for attr in attrs:
2728                fake_inner = getattr(fake_e, attr)
2729                inner = getattr(e, attr)
2730                inner_source = AttrSource(source, attr)
2731                wrap_to_fake_tensor_and_record(
2732                    inner,
2733                    tx,
2734                    source=inner_source,
2735                    is_tensor=isinstance(fake_inner, torch.Tensor),
2736                    parent_context=symbolic_context,
2737                )
2738
2739        tx.output.tracing_context.tensor_to_context[e] = symbolic_context
2740        if is_sparse_any(fake_e):
2741            # TODO: for TensorGuards, this eventually may need more
2742            #       fields for the size/stride of any other constituents
2743            values = fake_e._values() if fake_e.is_sparse else fake_e.values()
2744            tx.output.input_source_to_sizes_strides[source] = {
2745                "size": fake_e.size(),
2746                # TODO: revise this, but for now using this stride instead of ()
2747                #       avoids a SegFault with PYTORCH_TEST_WITH_DYNAMO=1
2748                "stride": (1,) * fake_e.ndim,
2749                "values_size": values.size(),
2750                "values_stride": values.stride(),
2751            }
2752        else:
2753            tx.output.input_source_to_sizes_strides[source] = {
2754                "size": fake_e.size(),
2755                "stride": fake_e.stride(),
2756            }
2757
2758        if (
2759            is_tensor
2760            and not (static_shapes and source.is_specialized_nn_module())
2761            and not is_constant_source(source)
2762        ):
2763            tx.output.tracked_fakes.append(
2764                TrackedFake(fake_e, source, symbolic_context)
2765            )
2766            tx.output.tracked_fakes_id_to_source[id(e)].append(source)
2767
2768        return fake_e
2769    else:
2770        return e
2771
2772
2773class SourcelessBuilder:
2774    """
2775    Like VariableBuilder, but stateless and does not require a source. Useful for simple type->VT objects, or objects
2776    that are created and evaporated during inlining (e.g. a locally made list of tensors that we then iterate
2777    over). Such a list should not show up as an artifact from inputs, nor in reconstruction, nor in the graph; however,
2778    there may be reasons to represent it as a ListVariable internally.
2779
2780    NOTE - Objects produced here are born UNGUARDED due to the nature of sources!
2781
2782    NOTE - This class is very new! It will have some rough edges, but it was created to stem the bleeding of giant
2783    if/else type->VariableTracker trees that were cropping up all over dynamo.
2784    """
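
    # Minimal usage sketch (hypothetical; assumes an active InstructionTranslator `tx`):
    #   vt = SourcelessBuilder.create(tx, [1, 2, 3])
    #   # -> a ListVariable of three ConstantVariables, born unguarded because there is
    #   #    no Source to install guards against.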
2785
2786    def __init__(self) -> None:
2787        raise AssertionError("Use SourcelessBuilder.create()")
2788
2789    @staticmethod
2790    def create(tx: "InstructionTranslator", value) -> VariableTracker:
2791        value_type = type(value)
2792        fast_handler = SourcelessBuilder._type_handlers.get(value_type)
2793        if fast_handler:
2794            return fast_handler(tx, value)
2795
2796        if isinstance(value, VariableTracker):
2797            # This is always valid to call, and useful for recursive calls.
2798            return value
2799        elif isinstance(value, dataclasses._HAS_DEFAULT_FACTORY_CLASS):
2800            return UserDefinedObjectVariable(value)
2801        elif ConstantVariable.is_literal(value):
2802            return ConstantVariable.create(value)
2803        elif callable(value) and trace_rules.lookup_callable(value) is not None:
2804            if is_callable_allowed(value):
2805                tx.output.has_user_defined_allowed_in_graph = True
2806            return trace_rules.lookup_callable(value)(value)
2807        elif is_function_or_wrapper(value):
2808            return trace_rules.lookup(value)(value)
2809        elif isinstance(value, enum.Enum):
2810            return EnumVariable(value)
2811        elif isinstance(value, (type, abc.ABCMeta)):
2812            return UserDefinedClassVariable(value)
2813        elif isinstance(value, types.MethodWrapperType):
2814            return MethodWrapperVariable(value)
2815        elif isinstance(value, torch.fx.graph_module.GraphModule):
2816            return SourcelessGraphModuleVariable(value)
2817        elif isinstance(
2818            value, (torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec)
2819        ):
2820            return UserDefinedObjectVariable(value)
2821        elif PlacementVariable.is_placement(value):
2822            return PlacementVariable(value)
2823        elif DeviceMeshVariable.is_device_mesh(value):
2824            return DeviceMeshVariable(value)
2825        elif isinstance(value, re.Pattern):
2826            return RegexPatternVariable(value)
2827        elif isinstance(value, torch._dynamo.variables.lazy.LazySymNodeFormatString):
2828            return ConstantVariable.create(str(value))
2829        unimplemented(
2830            f"Unexpected type in sourceless builder {value_type.__module__}.{value_type.__qualname__}"
2831        )
2832
2833    @staticmethod
2834    def wrap_constant_literal(value):
2835        assert ConstantVariable.is_literal(value)
2836        return ConstantVariable.create(value=value)
2837
2838    @staticmethod
2839    def make_type_handlers():
2840        create = SourcelessBuilder.create
2841        handlers = {}
2842        for t in common_constant_types:
2843            handlers[t] = lambda tx, value: ConstantVariable(value)
2844        handlers[set] = lambda tx, value: SetVariable(
2845            [create(tx, x) for x in value], mutable_local=MutableLocal()
2846        )
2847        handlers[dict] = lambda tx, value: ConstDictVariable(
2848            {create(tx, k): create(tx, v) for k, v in value.items()},
2849            type(value),
2850            mutable_local=MutableLocal(),
2851        )
2852        handlers[list] = lambda tx, value: ListVariable(
2853            [create(tx, x) for x in value], mutable_local=MutableLocal()
2854        )
2855        handlers[tuple] = lambda tx, value: TupleVariable(
2856            [create(tx, x) for x in value]
2857        )
2858        handlers[torch.Size] = lambda tx, value: SizeVariable(
2859            [create(tx, x) for x in value]
2860        )
2861        handlers[collections.OrderedDict] = handlers[dict]
2862        handlers[immutable_dict] = handlers[dict]
2863        handlers[immutable_list] = handlers[list]
2864        handlers[random.Random] = lambda tx, value: RandomClassVariable()
2865        handlers[types.ModuleType] = lambda tx, value: PythonModuleVariable(value)
2866
2867        handlers[
2868            torch.distributions.constraints._Real
2869        ] = lambda tx, value: UserDefinedObjectVariable(
2870            value, mutable_local=MutableLocal()
2871        )
2872        handlers[
2873            torch.distributions.constraints._Interval
2874        ] = lambda tx, value: UserDefinedObjectVariable(
2875            value, mutable_local=MutableLocal()
2876        )
2877        handlers[
2878            torch.distributions.constraints.Constraint
2879        ] = lambda tx, value: UserDefinedObjectVariable(
2880            value, mutable_local=MutableLocal()
2881        )
2882
2883        def passthrough(tx: "InstructionTranslator", value):
2884            return value
2885
2886        for cls in VariableTrackerMeta.all_subclasses:
2887            handlers[cls] = passthrough
2888        return handlers
2889
2890
2891SourcelessBuilder._type_handlers = SourcelessBuilder.make_type_handlers()
2892