xref: /aosp_15_r20/external/pytorch/torch/_dynamo/variables/user_defined.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import collections
4import contextlib
5import dataclasses
6import enum
7import functools
8import inspect
9import itertools
10import random
11import sys
12import types
13import warnings
14from typing import Dict, Generic, List, TYPE_CHECKING
15
16import torch._dynamo.config
17import torch.nn
18from torch._guards import TracingContext
19from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type
20
21from .. import polyfills, variables
22from ..bytecode_transformation import create_call_function
23from ..create_parameter_op import do_not_convert_to_tracable_parameter
24from ..exc import (
25    handle_observed_exception,
26    ObservedAttributeError,
27    raise_observed_exception,
28    unimplemented,
29)
30from ..guards import GuardBuilder, install_guard
31from ..source import (
32    AttrSource,
33    GetItemSource,
34    ODictGetItemSource,
35    RandomValueSource,
36    UnspecializedParamBufferSource,
37    WeakRefCallSource,
38)
39from ..utils import (
40    build_checkpoint_variable,
41    check_constant_args,
42    get_custom_getattr,
43    has_torch_function,
44    is_frozen_dataclass,
45    is_namedtuple_cls,
46    is_utils_checkpoint,
47    is_wrapper_or_member_descriptor,
48    istype,
49    namedtuple_fields,
50    object_has_getattribute,
51    proxy_args_kwargs,
52    tensortype_to_dtype,
53    unpatched_nn_module_getattr,
54)
55from .base import MutableLocal, VariableTracker
56from .dicts import DefaultDictVariable
57
58
59try:
60    import numpy as np
61except ModuleNotFoundError:
62    np = None
63
64try:
65    from torch.utils._cxx_pytree import PyTreeSpec
66except ImportError:
67    PyTreeSpec = type(None)
68
69
70if TYPE_CHECKING:
71    from torch._dynamo.symbolic_convert import InstructionTranslator
72
73
def is_standard_setattr(val):
    """Return True if *val* is the default ``object.__setattr__`` implementation.

    Used to decide whether attribute stores can be modelled directly or must go
    through a user-provided ``__setattr__``.
    """
    standard_setattrs = (object.__setattr__,)
    return val in standard_setattrs
76
77
def is_forbidden_context_manager(ctx):
    """Return True if *ctx* is one of a fixed set of "forbidden" context-manager classes.

    ``UserDefinedClassVariable.call_function`` uses this to skip its generic
    context-manager handling for these classes. The set is assembled from
    optional imports. Each import is guarded independently, so a missing or
    changed optional dependency (e.g. pytest) does not also drop the unrelated
    torch entry.

    Args:
        ctx: the object (normally a class) to test.

    Returns:
        True if ``ctx`` is in the forbidden set, else False.
    """
    f_ctxs = []

    # TODO mlazos: Temporary to get this stack to pass
    # remove in subsequent PR
    try:
        from torch.overrides import BaseTorchFunctionMode

        f_ctxs.append(BaseTorchFunctionMode)
    except ImportError:
        pass

    # pytest helpers are optional; their private module layout may also change
    # between pytest versions, so tolerate ImportError here.
    try:
        from _pytest.python_api import RaisesContext
        from _pytest.recwarn import WarningsChecker

        f_ctxs.append(RaisesContext)
        f_ctxs.append(WarningsChecker)
    except ImportError:
        pass

    try:
        from torch.testing._internal.jit_utils import (
            _AssertRaisesRegexWithHighlightContext,
        )

        f_ctxs.append(_AssertRaisesRegexWithHighlightContext)
    except ImportError:
        pass

    return ctx in f_ctxs
105
106
class UserDefinedVariable(VariableTracker):
    """Common base of UserDefinedClassVariable and UserDefinedObjectVariable.

    It adds no behavior of its own beyond VariableTracker.
    """
109
110
class UserDefinedClassVariable(UserDefinedVariable):
    """Tracks a class object: ``self.value`` is the class itself, not an instance.

    Models attribute access on the class, a few special-cased class-level method
    calls, and calling the class (i.e. instantiation).
    """

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        # The tracked class object.
        self.value = value

    def as_python_constant(self):
        return self.value

    def as_proxy(self):
        return self.value

    def __str__(self) -> str:
        return f"UserDefinedClassVariable({self.value})"

    @staticmethod
    @functools.lru_cache(None)
    def _constant_fold_classes():
        """Classes that ``call_function`` constructs eagerly when all args are constants."""
        return {
            torch.device,
            torch.finfo,
            torch.iinfo,
            torch.Size,
        }

    @staticmethod
    @functools.lru_cache(None)
    def _in_graph_classes():
        """Classes whose construction ``call_function`` emits as a graph node.

        This is torch.Tensor, the keys of ``tensortype_to_dtype``, CUDA
        streams/events and, when ``torch.hpu`` exists, HPU streams/events.
        """
        _in_graph_class_list = {
            torch.Tensor,
            torch.cuda.Stream,
            torch.cuda.Event,
        }
        if hasattr(torch, "hpu"):
            _in_graph_class_list.update(
                {
                    torch.hpu.Stream,
                    torch.hpu.Event,
                }
            )

        return set(tensortype_to_dtype.keys()) | _in_graph_class_list

    def can_constant_fold_through(self):
        return self.value in self._constant_fold_classes()

    def has_key_in_generic_dict(self, tx: "InstructionTranslator", key):
        """Return whether ``key`` is in the class ``__dict__``, honoring traced mutations."""
        if tx.output.side_effects.has_pending_mutation_of_attr(self, key):
            mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True)
            return not isinstance(mutated_attr, variables.DeletedVariable)

        return key in self.value.__dict__

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        """Model ``getattr(cls, name)`` for the tracked class.

        A few dunder names are answered directly. Otherwise the attribute is
        looked up with ``inspect.getattr_static``, so descriptors are seen
        unbound. The result is then wrapped according to its kind (static/class
        method, member descriptor, literal, enum member, ...), using
        ``VariableBuilder`` with an ``AttrSource`` when ``self.source`` is known.
        """
        from . import ConstantVariable, EnumVariable
        from .builder import SourcelessBuilder, VariableBuilder

        source = AttrSource(self.source, name) if self.source is not None else None

        if name == "__name__":
            return ConstantVariable.create(self.value.__name__)
        elif name == "__qualname__":
            return ConstantVariable.create(self.value.__qualname__)
        elif name == "__dict__":
            options = {"source": source}
            return variables.GetAttrVariable(self, name, **options)

        # Special handling of collections.OrderedDict.fromkeys()
        # Wrap it as GetAttrVariable(collections.OrderedDict, "fromkeys") to make it consistent with
        # collections.defaultdict, and both will be handled at UserDefinedClassVariable.call_method().
        # Otherwise, it would be wrapped as UserDefinedObjectVariable(collections.OrderedDict.fromkeys),
        # and we need duplicate code to handle both cases.
        if (
            self.value in {collections.OrderedDict, collections.defaultdict}
            and name == "fromkeys"
        ):
            return super().var_getattr(tx, name)

        try:
            obj = inspect.getattr_static(self.value, name)
        except AttributeError:
            # NOTE(review): a missing attribute leaves obj = None. The is_literal
            # check below then folds it into ConstantVariable(None) instead of
            # raising AttributeError — confirm this is intended.
            obj = None

        if isinstance(obj, staticmethod):
            func = obj.__get__(self.value)
            if source is not None:
                return VariableBuilder(tx, source)(func)
            else:
                return SourcelessBuilder.create(tx, func)
        elif isinstance(obj, classmethod):
            return variables.UserMethodVariable(obj.__func__, self, source=source)
        elif isinstance(obj, types.ClassMethodDescriptorType):
            # e.g.: inspect.getattr_static(dict, "fromkeys")
            #       inspect.getattr_static(itertools.chain, "from_iterable")
            func = obj.__get__(None, self.value)
            if source is not None:
                return VariableBuilder(tx, source)(func)
            else:
                return SourcelessBuilder.create(tx, func)
        elif source:
            # __mro__ is a member in < 3.12, an attribute in >= 3.12
            if inspect.ismemberdescriptor(obj) or (
                sys.version_info >= (3, 12) and name == "__mro__"
            ):
                return VariableBuilder(tx, source)(obj.__get__(self.value))

        if ConstantVariable.is_literal(obj):
            return ConstantVariable.create(obj)
        elif isinstance(obj, enum.Enum):
            return EnumVariable(obj)
        elif name in getattr(self.value, "__dict__", {}) or (
            self.value.__module__.startswith("torch.")
            or self.value.__module__ == "torch"
        ):
            if source:
                return VariableBuilder(tx, source)(obj)

        if (
            source
            and not inspect.ismethoddescriptor(obj)
            and not is_wrapper_or_member_descriptor(obj)
        ):
            return VariableBuilder(tx, source)(obj)
        return super().var_getattr(tx, name)

    def _call_cross_entropy_loss(self, tx: "InstructionTranslator", args, kwargs):
        """
        functional: input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean',
        label_smoothing=0.0

        non functional ctor: weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean',
        label_smoothing=0.0

        non functional loss call: input, target, optional_output

        Model ``torch.nn.CrossEntropyLoss(...)``: bind the constructor args and
        return a LambdaVariable. Calling it emits a
        ``torch.nn.functional.cross_entropy`` node with those bound args.
        """
        from . import ConstantVariable

        def normalize_args(
            weight=ConstantVariable.create(None),
            size_average=ConstantVariable.create(None),
            ignore_index=ConstantVariable.create(-100),
            reduce=ConstantVariable.create(None),
            reduction=ConstantVariable.create("mean"),
            label_smoothing=ConstantVariable.create(0.0),
        ):
            return (
                weight,
                size_average,
                ignore_index,
                reduce,
                reduction,
                label_smoothing,
            )

        (
            weight,
            size_average,
            ignore_index,
            reduce_arg,
            reduction,
            label_smoothing,
        ) = normalize_args(*args, **kwargs)

        def fake_cross_entropy_loss(input, target):
            from .builder import wrap_fx_proxy

            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    torch.nn.functional.cross_entropy,
                    *proxy_args_kwargs(
                        [
                            input,
                            target,
                            weight,
                            size_average,
                            ignore_index,
                            reduce_arg,
                            reduction,
                            label_smoothing,
                        ],
                        {},
                    ),
                ),
            )

        return variables.LambdaVariable(fake_cross_entropy_loss)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Model ``cls.<name>(*args, **kwargs)`` for a few special-cased methods.

        Handled here: ``__subclasses__()``, ``fromkeys`` on OrderedDict and
        defaultdict, and ``__eq__``/``__ne__`` against another tracked value.
        Everything else is delegated to the base class.
        """
        if (
            name == "__subclasses__"
            and len(args) == 0
            and not kwargs
            and "__subclasses__" not in self.value.__dict__
        ):
            options = {"mutable_local": MutableLocal()}
            subs_as_vars: List[VariableTracker] = []
            for sub in self.value.__subclasses__():
                source = AttrSource(tx.import_source(sub.__module__), sub.__name__)
                subs_as_vars.append(
                    variables.UserDefinedClassVariable(sub, source=source)
                )

            return variables.ListVariable(subs_as_vars, **options)
        elif (
            self.value in {collections.OrderedDict, collections.defaultdict}
            and name == "fromkeys"
        ):
            from .builtin import BuiltinVariable

            return BuiltinVariable.call_custom_dict_fromkeys(
                tx, self.value, *args, **kwargs
            )
        elif name == "__eq__" and len(args) == 1 and hasattr(args[0], "value"):
            return variables.ConstantVariable(self.value == args[0].value)
        elif name == "__ne__" and len(args) == 1 and hasattr(args[0], "value"):
            return variables.ConstantVariable(self.value != args[0].value)

        return super().call_method(tx, name, args, kwargs)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Model calling the class, i.e. ``cls(*args, **kwargs)``.

        Known classes (constant-foldable types, CrossEntropyLoss, nullcontext,
        OrderedDict/defaultdict/deque, functools.partial, ...) get dedicated
        trackers. Other user classes are instantiated through the side-effect
        tracker, or via a polyfill when ``__new__`` is overridden. Unhandled
        cases are delegated to the base class.
        """
        from ..side_effects import SideEffects
        from .builder import SourcelessBuilder, wrap_fx_proxy
        from .builtin import BuiltinVariable

        constant_args = check_constant_args(args, kwargs)

        if self.can_constant_fold_through() and constant_args:
            # constant fold
            return variables.ConstantVariable.create(
                self.as_python_constant()(
                    *[x.as_python_constant() for x in args],
                    **{k: v.as_python_constant() for k, v in kwargs.items()},
                ),
            )
        elif self.value is torch.nn.CrossEntropyLoss:
            return self._call_cross_entropy_loss(tx, args, kwargs)
        elif self.value is contextlib.nullcontext:
            # import here to avoid circular dependency
            from .ctx_manager import NullContextVariable

            return NullContextVariable()
        elif self.value is collections.OrderedDict:
            return BuiltinVariable.call_custom_dict(
                tx, collections.OrderedDict, *args, **kwargs
            )
        elif (
            self.value is collections.defaultdict
            # Exactly one positional arg (the default factory): indexing args[0]
            # with zero args would raise IndexError. Initial contents passed as
            # kwargs would be silently dropped, so those cases fall through.
            and len(args) == 1
            and not kwargs
            and DefaultDictVariable.is_supported_arg(args[0])
        ):
            return DefaultDictVariable(
                {},
                collections.defaultdict,
                args[0],
                mutable_local=MutableLocal(),
            )
        elif self.value is collections.deque and not kwargs:
            if len(args) == 0:
                items = []
            elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
                items = args[0].force_unpack_var_sequence(tx)
            else:
                unimplemented("deque() with more than 1 arg not supported")
            return variables.lists.DequeVariable(items, mutable_local=MutableLocal())
        elif self.value is functools.partial:
            if not args:
                unimplemented("functools.partial malformed")
            # The first arg, a callable (the ctor below will assert on types)
            fn = args[0]
            rest_args = args[1:]
            # guards for the produced FunctoolsPartialVariable are installed in FunctoolsPartialVariable ctor from the
            # args and keywords
            return variables.functions.FunctoolsPartialVariable(
                fn, args=rest_args, keywords=kwargs
            )
        elif self.value is warnings.catch_warnings and not args:
            return variables.CatchWarningsCtxManagerVariable.create(tx, kwargs)
        elif self.value is torch.cuda.device and not kwargs and len(args) == 1:
            assert args[0].is_python_constant()
            return variables.CUDADeviceVariable.create(tx, args[0].as_python_constant())
        elif (
            issubclass(type(self.value), type)
            and hasattr(
                self.value, "__enter__"
            )  # TODO(voz): These can invoke user code!
            and hasattr(
                self.value, "__exit__"
            )  # TODO(voz): These can invoke user code!
            and self.is_standard_new()
            and SideEffects.cls_supports_mutation_side_effects(self.value)
            and self.source
            and not is_forbidden_context_manager(self.value)
        ):
            # import here to avoid an unfortunate circular dependency.
            from .ctx_manager import GenericContextWrappingVariable

            cm_obj = tx.output.side_effects.track_object_new(
                self.source, self.value, GenericContextWrappingVariable, {}
            )
            cm_obj.call_method(tx, "__init__", args, kwargs)
            return cm_obj

        elif is_namedtuple_cls(self.value):
            fields = namedtuple_fields(self.value)
            # check if this a quasi-namedtuple or a real one
            if self.value.__module__ == "torch.return_types":
                # A quasi-namedtuple (e.g. torch.return_types.max) is constructed
                # from ONE sequence holding every field value. Unpack it rather
                # than storing the whole sequence as the first field.
                if (
                    len(args) != 1
                    or kwargs
                    or not args[0].has_force_unpack_var_sequence(tx)
                ):
                    unimplemented(
                        f"{self.value.__name__}() expects exactly one sequence argument"
                    )
                items = list(args[0].force_unpack_var_sequence(tx))
                if len(items) != len(fields):
                    unimplemented(
                        f"{self.value.__name__}() expects a sequence of length {len(fields)}"
                    )
                return variables.NamedTupleVariable(items, self.value)

            field_defaults = self.value._field_defaults

            # Positional args fill the leading fields; the rest come from kwargs
            # or the namedtuple's declared defaults.
            items = list(args)
            items.extend([None] * (len(fields) - len(items)))

            var_tracker_kwargs = {}
            for field_name, var_tracker in zip(fields, items):
                if var_tracker is None:
                    if field_name in kwargs:
                        field_var = kwargs[field_name]
                    else:
                        assert field_name in field_defaults
                        field_var = SourcelessBuilder.create(
                            tx, field_defaults[field_name]
                        )
                    var_tracker_kwargs[field_name] = field_var

            for field_name, field_var in var_tracker_kwargs.items():
                assert field_name in fields
                items[fields.index(field_name)] = field_var

            assert all(x is not None for x in items)
            return variables.NamedTupleVariable(items, self.value)
        elif is_frozen_dataclass(self.value) and self.is_standard_new():
            fields = dataclasses.fields(self.value)
            items = list(args)
            items.extend([None] * (len(fields) - len(items)))

            # Materialize defaults / default_factory values for fields that were
            # not passed positionally or by keyword.
            default_kwargs = {}
            for field, var_tracker in zip(fields, items):
                if var_tracker is None:
                    if field.name in kwargs:
                        var_tracker = kwargs[field.name]
                    else:
                        if not field.init:
                            continue

                        if field.default is not dataclasses.MISSING:
                            var_tracker = SourcelessBuilder.create(tx, field.default)
                        elif field.default_factory is not dataclasses.MISSING:
                            factory_fn = SourcelessBuilder.create(
                                tx, field.default_factory
                            )
                            var_tracker = factory_fn.call_function(tx, [], {})
                        else:
                            # if we are subclass, the constructor could possibly
                            # be missing args
                            continue

                    default_kwargs[field.name] = var_tracker
            # Build a new dict rather than mutating the caller's kwargs.
            kwargs = {**kwargs, **default_kwargs}

            var = tx.output.side_effects.track_object_new_from_user_defined_class(self)
            var.call_method(tx, "__init__", args, kwargs)
            return var
        elif (
            self.is_standard_new()
            and SideEffects.cls_supports_mutation_side_effects(self.value)
            and self.source
        ):
            var = tx.output.side_effects.track_object_new_from_user_defined_class(self)
            with do_not_convert_to_tracable_parameter():
                var.call_method(tx, "__init__", args, kwargs)
                return var
        elif variables.CustomizedDictVariable.is_matching_cls(self.value):
            options = {"mutable_local": MutableLocal()}
            return variables.CustomizedDictVariable.create(
                self.value, args, kwargs, options
            )
        elif (
            variables.RestrictedListSubclassVariable.is_matching_cls(self.value)
            and self.source
        ):
            return variables.RestrictedListSubclassVariable(
                variables.BuiltinVariable(list).call_function(tx, args, kwargs).items,
                user_cls=self.value,
                user_cls_source=self.source,
                mutable_local=MutableLocal(),
            )
        elif (
            self.value in self._in_graph_classes()
            or is_traceable_wrapper_subclass_type(self.value)
        ):
            # torch.LongTensor cannot accept a list of FakeTensors.
            # So we stack the list of FakeTensors instead.
            if (
                np
                and self.value in tensortype_to_dtype
                and len(args) == 1
                and isinstance(args[0], variables.ListVariable)
                and len(args[0].items) > 1
                and all(isinstance(x, variables.TensorVariable) for x in args[0].items)
            ):
                # Stack FakeTensor
                stacked = wrap_fx_proxy(
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_function",
                        torch.stack,
                        *proxy_args_kwargs(args, kwargs),
                    ),
                )
                args = [stacked]

            tensor_variable = wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    self.value,
                    *proxy_args_kwargs(args, kwargs),
                ),
            )

            return tensor_variable
        elif issubclass(self.value, enum.Enum) and len(args) == 1 and not kwargs:
            options = {"mutable_local": MutableLocal()}
            return variables.EnumVariable.create(self.value, args[0], options)
        elif self.value is random.Random:
            if len(args) == 1 and isinstance(args[0], variables.ConstantVariable):
                seed = args[0].value
            else:
                seed = None
            random_object = random.Random(seed)
            return RandomVariable(random_object)
        elif (
            not self.is_standard_new()
            and SideEffects.cls_supports_mutation_side_effects(self.value)
            and self.source
        ):
            # Custom __new__: inline a Python polyfill that performs the
            # instantiation so user __new__/__init__ code gets traced.
            return tx.inline_user_function_return(
                SourcelessBuilder.create(
                    tx, polyfills.instantiate_user_defined_class_object
                ),
                [self, *args],
                kwargs,
            )

        return super().call_function(tx, args, kwargs)

    def is_standard_new(self):
        """Check for __new__ being overridden"""
        new_fn = inspect.getattr_static(self.value, "__new__", None)
        if isinstance(new_fn, staticmethod):
            new_fn = new_fn.__func__
        return new_fn in (object.__new__, Generic.__new__)

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        """Answer ``hasattr(cls, name)`` eagerly (guarded) when the class has a source."""
        if self.source:
            source = AttrSource(self.source, name)
            install_guard(source.make_guard(GuardBuilder.HASATTR))
            return variables.ConstantVariable(hasattr(self.value, name))
        return super().call_hasattr(tx, name)

    def const_getattr(self, tx: "InstructionTranslator", name):
        if name == "__name__":
            return self.value.__name__
        return super().const_getattr(tx, name)
591
592
class NO_SUCH_SUBOBJ:
    """Marker class, presumably used as a "no such sub-object" sentinel.

    Its uses are not visible in this section — confirm against callers.
    """
595
596
def call_random_fn(tx, fn, args, kwargs):
    """Eagerly evaluate a ``random`` function on constant args and record the call.

    The call ``(fn, args, kwargs)`` is appended to ``tx.output.random_calls``.
    The sampled value is wrapped as an unspecialized primitive whose source is
    ``RandomValueSource`` indexed by the position of that record.
    """
    from .builder import VariableBuilder

    const_args = [arg.as_python_constant() for arg in args]
    const_kwargs = {key: val.as_python_constant() for key, val in kwargs.items()}
    # Capture the slot index before appending so it refers to this call.
    call_index = len(tx.output.random_calls)
    sample = fn(*const_args, **const_kwargs)
    tx.output.random_calls.append((fn, const_args, const_kwargs))
    # TODO: arguably, this should route to wrap_symint/wrap_symfloat
    # (currently hypothetical), but I'm not going to poke my hand in
    # this nest for now
    builder = VariableBuilder(tx, RandomValueSource(call_index))
    return builder.wrap_unspecialized_primitive(sample)
610
611
612class UserDefinedObjectVariable(UserDefinedVariable):
613    """
614    Mostly objects of defined type.  Catch-all for something where we only know the type.
615    """
616
617    _nonvar_fields = {"value", "value_type", *UserDefinedVariable._nonvar_fields}
618
619    def __init__(self, value, value_type=None, cls_source=None, **kwargs) -> None:
620        super().__init__(**kwargs)
621        self.value = value
622        self.value_type = value_type or type(value)
623        assert type(value) is self.value_type
624        # This is used with __new__, when the new object is sourceless but the user class can be sourceful.
625        self.cls_source = cls_source
626
627    def __str__(self) -> str:
628        inner = self.value_type.__name__
629        if inner in [
630            "builtin_function_or_method",
631            "getset_descriptor",
632            "method_descriptor",
633            "method",
634        ]:
635            inner = str(getattr(self.value, "__name__", None))
636        return f"{self.__class__.__name__}({inner})"
637
638    def __repr__(self) -> str:
639        return f"{self.__class__.__name__}({self.value_type.__name__})"
640
641    def python_type(self):
642        return self.value_type
643
644    def guard_as_python_constant(self):
645        if self.source:
646            install_guard(self.source.make_guard(GuardBuilder.ID_MATCH))
647            return self.value
648        return super().guard_as_python_constant()
649
650    def torch_function_check(self):
651        assert has_torch_function(
652            self
653        ), f"calling torch function on object without __torch_function__ {self}"
654
655    def get_torch_fn(self, tx):
656        self.torch_function_check()
657        from .torch_function import build_torch_function_fn
658
659        return build_torch_function_fn(tx, self.value, self.source)
660
661    def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
662        self.torch_function_check()
663
664        from .torch_function import _get_subclass_type_var, call_torch_function
665
666        return call_torch_function(
667            tx,
668            _get_subclass_type_var(tx, self),
669            self.get_torch_fn(tx),
670            fn,
671            types,
672            args,
673            kwargs,
674        )
675
676    @staticmethod
677    @functools.lru_cache(None)
678    def _supported_random_functions():
679        fns = {
680            random.random,
681            random.randint,
682            random.randrange,
683            random.uniform,
684        }
685        return fns
686
687    def _maybe_get_baseclass_method(self, name):
688        if name not in getattr(self.value, "__dict__", {}):
689            try:
690                return inspect.getattr_static(type(self.value), name)
691            except AttributeError:
692                pass
693        return None
694
695    def call_method(
696        self,
697        tx,
698        name,
699        args: "List[VariableTracker]",
700        kwargs: "Dict[str, VariableTracker]",
701    ) -> "VariableTracker":
702        from . import (
703            BuiltinVariable,
704            ConstantVariable,
705            TupleVariable,
706            UserMethodVariable,
707        )
708
709        method = self._maybe_get_baseclass_method(name)
710        if method is not None:
711            if method is object.__init__:
712                return ConstantVariable.create(None)
713
714            if is_standard_setattr(method):
715                return self.method_setattr_standard(tx, *args, **kwargs)
716
717            # [NOTE] OrderedDict, dict subtypes must always have source
718            # We cannot instantiate such subtypes in-graph due to builtin __new__
719            if method is collections.OrderedDict.keys:
720                # subclass of OrderedDict
721                assert not (args or kwargs)
722                assert self.source  # OrderedDict, dict subtypes must always have source
723                keys = list(self.value.keys())
724                assert all(map(ConstantVariable.is_literal, keys))
725                install_guard(self.source.make_guard(GuardBuilder.DICT_CONST_KEYS))
726                tx.output.guard_on_key_order.add(self.source.name())
727                return TupleVariable([ConstantVariable.create(k) for k in keys])
728
729            if (
730                method in (collections.OrderedDict.__contains__, dict.__contains__)
731                and len(args) == 1
732                and isinstance(args[0], (ConstantVariable, BuiltinVariable))
733                and inspect.getattr_static(type(self.value), "keys")
734                in (collections.OrderedDict.keys, dict.keys)
735            ):
736                assert not kwargs
737                assert self.source  # OrderedDict, dict subtypes must always have source
738
739                # TODO(anijain2305) - Why do we need to guard on all keys?
740                install_guard(self.source.make_guard(GuardBuilder.DICT_CONST_KEYS))
741                return ConstantVariable.create(
742                    args[0].as_python_constant() in self.value
743                )
744
745            if method is collections.OrderedDict.items and isinstance(
746                self.value, collections.OrderedDict
747            ):
748                assert self.source  # OrderedDict, dict subtypes must always have source
749                assert not (args or kwargs)
750                items = []
751                keys = self.call_method(tx, "keys", [], {})
752                for key in keys.force_unpack_var_sequence(tx):
753                    items.append(
754                        TupleVariable(
755                            [key, self.odict_getitem(tx, key)],
756                        )
757                    )
758                tx.output.guard_on_key_order.add(self.source.name())
759                return TupleVariable(items)
760
761            if method is collections.OrderedDict.__getitem__ and len(args) == 1:
762                assert not kwargs
763                assert self.source  # OrderedDict, dict subtypes must always have source
764                return self.odict_getitem(tx, args[0])
765
766            if (
767                method in (object.__ne__, object.__eq__)
768                and len(args) == 1
769                and not kwargs
770                and hasattr(args[0], "value")
771            ):
772                return ConstantVariable(
773                    (self.value is args[0].value) is (method is object.__eq__)
774                )
775
776            # check for methods implemented in C++
777            if isinstance(method, types.FunctionType):
778                source = (
779                    None
780                    if self.source is None
781                    else AttrSource(AttrSource(self.source, "__class__"), name)
782                )
783                # TODO(jansel): add a guard to check for monkey patching?
784                from ..mutation_guard import unpatched_nn_module_init
785
786                if method is torch.nn.Module.__init__:
787                    method = unpatched_nn_module_init
788                return UserMethodVariable(method, self, source=source).call_function(
789                    tx, args, kwargs
790                )
791
792            if method is list.__len__ and self.source and not (args or kwargs):
793                install_guard(self.source.make_guard(GuardBuilder.SEQUENCE_LENGTH))
794                return ConstantVariable(len(self.value))
795
796        return super().call_method(tx, name, args, kwargs)
797
798    def method_setattr_standard(self, tx: "InstructionTranslator", name, value):
799        try:
800            name = name.as_python_constant()
801        except NotImplementedError:
802            unimplemented(f"non-const setattr name: {name}")
803        if not tx.output.side_effects.is_attribute_mutation(self):
804            unimplemented(f"setattr({self}, {name}, ...)")
805
806        tx.output.side_effects.store_attr(self, name, value)
807        return variables.ConstantVariable(None)
808
809    def needs_slow_setattr(self):
810        return not is_standard_setattr(
811            inspect.getattr_static(self.value, "__setattr__", None)
812        )
813
814    def unpack_var_sequence(self, tx):
815        if (
816            self.source
817            and self._maybe_get_baseclass_method("__iter__") is list.__iter__
818            and self._maybe_get_baseclass_method("__len__") is list.__len__
819            and self._maybe_get_baseclass_method("__getitem__") is list.__getitem__
820        ):
821            install_guard(self.source.make_guard(GuardBuilder.SEQUENCE_LENGTH))
822            return [
823                variables.LazyVariableTracker.create(
824                    self.value[k],
825                    source=GetItemSource(self.source, k),
826                )
827                for k in range(len(self.value))
828            ]
829        return super().unpack_var_sequence(tx)
830
831    def next_variable(self, tx):
832        return self.call_method(tx, "__next__", [], {})
833
834    def is_supported_random(self):
835        try:
836            return self.value in self._supported_random_functions()
837        except TypeError:
838            # TypeError: unhashable type
839            return False
840
    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Trace a call of the wrapped user object, i.e. ``self.value(*args, **kwargs)``.

        Dispatch order:
          1. A supported ``random`` function called with all-constant arguments
             is handled by ``call_random_fn``.
          2. A bound method (exactly ``types.MethodType``):
             - ``_DecoratorContextManager.clone`` on a recognised torch context
               manager (no args) is re-created as a ``TorchCtxManagerClassVariable``;
             - ``inference_mode.clone`` is simulated from ``obj.mode``;
             - otherwise ``__func__`` and ``__self__`` are rebuilt from their
               sources and the function is called with ``self`` prepended
               (graph-breaks when there is no source).
          3. A ``functools.partial`` (exactly) whose ``func`` maps to
             ``TorchInGraphFunctionVariable`` and whose stored args/keywords are
             all literals: guards on func/args/keywords when sourced, then merges
             the stored args/keywords with the call-site ones.
          4. Any other callable: guarded by FUNCTION_MATCH (when sourced) and
             traced as ``self.__call__(...)``.
          5. Otherwise defers to the base class.
        """
        from .. import trace_rules
        from .builder import VariableBuilder

        if (
            self.is_supported_random()
            and all(k.is_python_constant() for k in args)
            and all(v.is_python_constant() for v in kwargs.values())
        ):
            return call_random_fn(tx, self.value, args, kwargs)
        elif istype(self.value, types.MethodType):
            func = self.value.__func__
            obj = self.value.__self__
            # `ctx_manager_instance.clone()` with no args: model it as constructing
            # a fresh context manager of the same class.
            if (
                func is torch.utils._contextlib._DecoratorContextManager.clone
                and variables.TorchCtxManagerClassVariable.is_matching_cls(
                    obj.__class__
                )
                and not (args or kwargs)
            ):
                return variables.TorchCtxManagerClassVariable(
                    obj.__class__
                ).call_function(tx, args, kwargs)

            if (
                func is torch.autograd.grad_mode.inference_mode.clone
                and obj.__class__ is torch.autograd.grad_mode.inference_mode
            ):
                # simulate the inference_mode.clone implementation
                var = variables.ConstantVariable(obj.mode)
                return variables.TorchCtxManagerClassVariable(
                    obj.__class__
                ).call_function(tx, [var], kwargs)

            if self.source is None:
                unimplemented(
                    "Sourceless UserDefinedObjectVariable method not supported"
                )
            # Rebuild the unbound function and the bound receiver from their
            # sources and call the function with the receiver as first argument.
            func_src = AttrSource(self.source, "__func__")
            func_var = VariableBuilder(tx, func_src)(func)
            obj_src = AttrSource(self.source, "__self__")
            obj_var = VariableBuilder(tx, obj_src)(obj)
            return func_var.call_function(tx, [obj_var] + args, kwargs)
        elif (
            istype(self.value, functools.partial)
            and trace_rules.lookup(self.value.func)
            == variables.TorchInGraphFunctionVariable
            and all(
                variables.ConstantVariable.is_literal(v)
                for v in itertools.chain(self.value.args, self.value.keywords.values())
            )
        ):
            # The stored args/keywords are baked in as constants below, so guard
            # on them (and on the identity of func) when the partial is sourced.
            if self.source:
                install_guard(
                    AttrSource(self.source, "func").make_guard(GuardBuilder.ID_MATCH),
                    AttrSource(self.source, "args").make_guard(
                        GuardBuilder.CONSTANT_MATCH
                    ),
                    AttrSource(self.source, "keywords").make_guard(
                        GuardBuilder.CONSTANT_MATCH
                    ),
                )

            # Stored positional args come first; call-site kwargs override
            # stored keywords, matching functools.partial semantics.
            partial_args = [
                variables.ConstantVariable.create(v) for v in self.value.args
            ]
            partial_args.extend(args)
            partial_kwargs = {
                k: variables.ConstantVariable.create(v)
                for k, v in self.value.keywords.items()
            }
            partial_kwargs.update(kwargs)
            if is_utils_checkpoint(self.value.func):
                return build_checkpoint_variable().call_function(
                    tx, partial_args, partial_kwargs
                )
            return variables.TorchInGraphFunctionVariable(
                self.value.func
            ).call_function(tx, partial_args, partial_kwargs)
        elif callable(self.value):
            if self.source:
                install_guard(self.source.make_guard(GuardBuilder.FUNCTION_MATCH))
            return self.call_method(tx, "__call__", args, kwargs)

        return super().call_function(tx, args, kwargs)
931
932    def _check_for_getattribute(self):
933        if object_has_getattribute(self.value):
934            unimplemented("UserDefinedObjectVariable with custom __getattribute__")
935
936    def _check_for_getattr(self):
937        return get_custom_getattr(self.value)
938
939    def _is_c_defined_property(self, subobj):
940        if not isinstance(subobj, property):
941            return False
942
943        # pybind def_readwrite is implemented via PyCFunction. At the python level, it is visible as a property whose
944        # fget is an instancemethod wrapper - https://docs.python.org/3/c-api/method.html#c.PyInstanceMethod_Check
945
946        # If we have a PyCFunction, we make an assumption that there is no side effect.
947        return isinstance(
948            subobj.fget, types.BuiltinFunctionType
949        ) or torch._C._dynamo.utils.is_instancemethod(subobj.fget)
950
951    def _getattr_static(self, name):
952        subobj = inspect.getattr_static(self.value, name, NO_SUCH_SUBOBJ)
953        import _collections
954
955        # In some cases, we have to do dynamic lookup because getattr_static is not enough. For example, threading.local
956        # has side-effect free __getattribute__ and the attribute is not visible without a dynamic lookup.
957        if (
958            subobj is NO_SUCH_SUBOBJ  # e.g., threading.local
959            or isinstance(
960                subobj, _collections._tuplegetter
961            )  # namedtuple fields are represented by _tuplegetter
962            or (
963                inspect.ismemberdescriptor(subobj) and name in self.value.__slots__
964            )  # handle memberdecriptor and slots
965            or self._is_c_defined_property(subobj)
966        ):
967            # Call __getattribute__, we have already checked that this is not overridden and side-effect free. We don't
968            # want to call getattr because it can be user-overridden.
969            subobj = self.value.__getattribute__(name)
970
971        return subobj
972
973    def has_key_in_generic_dict(self, tx: "InstructionTranslator", key):
974        self._check_for_getattribute()
975        if tx.output.side_effects.has_pending_mutation_of_attr(self, key):
976            mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True)
977            return not isinstance(mutated_attr, variables.DeletedVariable)
978
979        return key in self.value.__dict__
980
981    def is_supported_nn_module_method(self, method):
982        return torch._dynamo.config.inline_inbuilt_nn_modules and method in (
983            torch.nn.Module.parameters,
984        )
985
986    def var_getattr(self, tx: "InstructionTranslator", name):
987        from .. import trace_rules
988        from . import ConstantVariable
989        from .builder import SourcelessBuilder, VariableBuilder
990
991        source = AttrSource(self.source, name) if self.source else None
992        self._check_for_getattribute()
993
994        if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
995            result = tx.output.side_effects.load_attr(self, name, deleted_ok=True)
996            if isinstance(result, variables.DeletedVariable):
997                raise_observed_exception(AttributeError, tx, self)
998            return result
999
1000        if name == "__dict__":
1001            options = {"source": source}
1002            return variables.GetAttrVariable(self, name, **options)
1003
1004        # TODO(anijain2305) - Investigate if we need specialization for more
1005        # dunder attrs. inspect.getattr_static does not return correct value for
1006        # them.
1007        if name == "__class__":
1008            cls_source = source
1009            if cls_source is None:
1010                cls_source = self.cls_source
1011            options = {"source": cls_source}
1012            return UserDefinedClassVariable(type(self.value), **options)
1013
1014        try:
1015            subobj = self._getattr_static(name)
1016        except AttributeError:
1017            subobj = NO_SUCH_SUBOBJ
1018            getattr_fn = self._check_for_getattr()
1019            if isinstance(getattr_fn, types.FunctionType):
1020                # Dynamo is going to trace the __getattr__ function with
1021                # args=name. Set the source accordingly.
1022                if getattr_fn is unpatched_nn_module_getattr and isinstance(
1023                    self, variables.UnspecializedNNModuleVariable
1024                ):
1025                    # Manually trace out the nn module __getattr__ to avoid large compilation latency.
1026                    out = self.manually_trace_nn_module_getattr(tx, name)
1027                else:
1028                    new_source = None
1029                    if self.source:
1030                        new_source = AttrSource(self.source, "__getattr__")
1031                    out = variables.UserMethodVariable(
1032                        getattr_fn, self, source=new_source
1033                    ).call_function(tx, [ConstantVariable.create(name)], {})
1034
1035                if self.source and getattr_fn is torch.nn.Module.__getattr__:
1036                    if isinstance(
1037                        out,
1038                        (
1039                            variables.UnspecializedNNModuleVariable,
1040                            variables.NNModuleVariable,
1041                        ),
1042                    ):
1043                        # nn_module_stack source is BC surface area. Ensure that
1044                        # mod._modules["linear"] is reflected as mod.linear for
1045                        # nn_module_stack.
1046                        out.set_nn_module_stack_source(
1047                            AttrSource(self.get_nn_module_stack_source(), name)
1048                        )
1049                return out
1050
1051            elif getattr_fn is not None:
1052                unimplemented("UserDefined with non-function __getattr__")
1053
1054        if isinstance(subobj, property):
1055            if self.source:
1056                # Read the class attribute to reach the property
1057                source = AttrSource(AttrSource(self.source, "__class__"), name)
1058                # Get the getter function
1059                source = AttrSource(source, "fget")
1060            return variables.UserMethodVariable(
1061                subobj.fget, self, source=source
1062            ).call_function(tx, [], {})
1063        elif isinstance(subobj, staticmethod):
1064            func = subobj.__get__(self.value)
1065            if source is not None:
1066                return trace_rules.lookup(func).create_with_source(func, source=source)
1067            else:
1068                return trace_rules.lookup(func)(func)
1069        elif isinstance(subobj, classmethod):
1070            return variables.UserMethodVariable(
1071                subobj.__func__, self.var_getattr(tx, "__class__"), source=source
1072            )
1073        elif isinstance(subobj, types.ClassMethodDescriptorType):
1074            # e.g.: inspect.getattr_static({}, "fromkeys")
1075            func = subobj.__get__(self.value, None)
1076            if source is not None:
1077                return VariableBuilder(tx, source)(func)
1078            else:
1079                return SourcelessBuilder.create(tx, func)
1080        elif inspect.ismethoddescriptor(subobj) and not is_wrapper_or_member_descriptor(
1081            subobj.__get__
1082        ):
1083            # Attribute has a __get__ method. Create a user defined object vt
1084            # for the subobj, and then trace the __get__ method.
1085            descriptor_var = UserDefinedObjectVariable(subobj, source=source)
1086
1087            get_source = self.source
1088            if self.source:
1089                get_source = AttrSource(self.source, "__get__")
1090
1091            # The arguments of the __get__ function are (self, instance, owner)
1092            # self - descriptor_var
1093            # instance - instance of the class, represented by self here
1094            # owner - class object
1095            owner_var = UserDefinedClassVariable(type(self.value))
1096            return variables.UserMethodVariable(
1097                subobj.__get__.__func__, descriptor_var, source=get_source
1098            ).call_function(tx, [descriptor_var, self, owner_var], {})
1099        elif isinstance(subobj, types.FunctionType) or (
1100            isinstance(subobj, types.MethodType)
1101            and isinstance(self.value, torch.nn.Module)
1102        ):
1103            if self.is_supported_nn_module_method(subobj):
1104                return variables.GetAttrVariable(self, name, source=source)
1105
1106            # Since we get subobj via self._getattr_static, which may not trigger dynamic lookup.
1107            # Static lookup can't tell us it's a method or function correctly,
1108            # so we trigger dynamic lookup here to get the correct type.
1109            dynamic_subobj = getattr(self.value, name)
1110
1111            while dynamic_subobj is subobj and hasattr(subobj, "_torchdynamo_inline"):
1112                subobj = subobj._torchdynamo_inline
1113                dynamic_subobj = subobj
1114                source = AttrSource(source, "_torchdynamo_inline") if source else None
1115
1116            if isinstance(subobj, types.MethodType):
1117                if dynamic_subobj.__self__ is not self.value:
1118                    unimplemented("__self__ mismatch for bound method")
1119                func = subobj.__func__
1120            else:
1121                assert isinstance(subobj, types.FunctionType)
1122                func = subobj
1123
1124            if inspect.ismethod(dynamic_subobj):
1125                return variables.UserMethodVariable(func, self, source=source)
1126            elif inspect.isfunction(dynamic_subobj):
1127                if is_utils_checkpoint(func):
1128                    return build_checkpoint_variable(source=source)
1129                elif source is not None:
1130                    return trace_rules.lookup(func).create_with_source(
1131                        func, source=source
1132                    )
1133                else:
1134                    return trace_rules.lookup(func)(func)
1135
1136        if (
1137            # wrap the source only if inline_inbuilt_nn_modules is set or fsdp modules. This is a temporary solution to
1138            # keep Dynamo behavior compatible with no inlining, as there will be some delay to turn on the flag in
1139            # fbcode.
1140            (
1141                torch._dynamo.config.inline_inbuilt_nn_modules
1142                or isinstance(self, variables.FSDPManagedNNModuleVariable)
1143            )
1144            and source
1145            and isinstance(self, variables.UnspecializedNNModuleVariable)
1146            # export has some awkwardness around specialized and unspecialized modules. Skip wrapping source for export
1147            # usecase for now.
1148            and not tx.output.export
1149        ):
1150            # Recalculate source for params/buffers
1151            if name in ("_buffers", "_parameters"):
1152                source = UnspecializedParamBufferSource(self.source, name)
1153            source = self._wrap_source(source)
1154
1155        if subobj is not NO_SUCH_SUBOBJ:
1156            if is_wrapper_or_member_descriptor(subobj):
1157                options = {"source": source}
1158                return variables.GetAttrVariable(self, name, **options)
1159            if source:
1160                return variables.LazyVariableTracker.create(subobj, source)
1161            else:
1162                # Check if the subobj is accessible from the class itself. If the class source is known, we can create a
1163                # sourceful variable tracker.
1164                if self.cls_source is not None:
1165                    subobj_from_class = inspect.getattr_static(
1166                        self.value.__class__, name, NO_SUCH_SUBOBJ
1167                    )
1168                    if subobj_from_class is subobj:
1169                        src_from_class = AttrSource(self.cls_source, name)
1170                        return variables.LazyVariableTracker.create(
1171                            subobj_from_class, src_from_class
1172                        )
1173
1174                return SourcelessBuilder.create(tx, subobj)
1175
1176        # Earlier we were returning GetAttrVariable but its incorrect. In absence of attr, Python raises AttributeError.
1177        raise_observed_exception(AttributeError, tx, self)
1178
1179    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
1180        if self._check_for_getattribute():
1181            unimplemented("hasattr with custom __getattribute__")
1182
1183        if self.source:
1184            install_guard(
1185                AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
1186            )
1187
1188        try:
1189            var_vt = self.var_getattr(tx, name)
1190            return variables.ConstantVariable.create(
1191                not isinstance(var_vt, variables.DeletedVariable)
1192            )
1193        except ObservedAttributeError:
1194            handle_observed_exception(tx)
1195            return variables.ConstantVariable.create(False)
1196
1197    def odict_getitem(self, tx: "InstructionTranslator", key):
1198        from .builder import VariableBuilder
1199        from .dicts import is_hashable
1200
1201        # TODO this should probably be merged with the dict handling
1202
1203        index = (
1204            key.source
1205            if is_hashable(key) and key.source is not None
1206            else key.as_python_constant()
1207        )
1208
1209        return VariableBuilder(
1210            tx,
1211            ODictGetItemSource(self.source, index),
1212        )(collections.OrderedDict.__getitem__(self.value, key.as_python_constant()))
1213
1214
class FrozenDataClassVariable(UserDefinedObjectVariable):
    """Tracks a frozen dataclass instance together with per-field trackers.

    ``self.fields`` maps field name -> VariableTracker. ``create`` seeds it from
    the real object, and ``method_setattr_standard`` keeps it current while the
    dataclass ``__init__`` is traced.
    """

    @staticmethod
    def create(tx, value, source):
        """Build a tracker for ``value``, with a sourced tracker for each field that is set."""
        from dataclasses import fields

        assert is_frozen_dataclass(value)

        from .builder import VariableBuilder

        field_map = {}
        for field in fields(value):
            # Fields may be absent on the instance (e.g. init=False without a
            # default), so only record the ones that actually exist.
            if hasattr(value, field.name):
                field_map[field.name] = VariableBuilder(
                    tx, AttrSource(source, field.name)
                )(getattr(value, field.name))

        return FrozenDataClassVariable(value, fields=field_map, source=source)

    def __init__(self, value, fields=None, **kwargs) -> None:
        super().__init__(value, **kwargs)
        if fields is None:
            fields = {}
        self.fields = fields

    def as_proxy(self):
        """Re-create the dataclass with each constructor field replaced by its proxy.

        Only fields accepted by the generated ``__init__`` (``init=True``) are
        passed. Keyword-only fields are passed by keyword; the rest are passed
        positionally in declaration order.
        """
        from dataclasses import fields

        args = []
        kwargs = {}
        for field in fields(self.value):
            if not field.init:
                # init=False fields are not __init__ parameters; passing them
                # would shift or overflow the positional arguments.
                continue
            proxy = self.fields[field.name].as_proxy()
            # `kw_only` only exists on Field objects from Python 3.10 onwards.
            if getattr(field, "kw_only", False):
                kwargs[field.name] = proxy
            else:
                args.append(proxy)

        return self.python_type()(*args, **kwargs)

    # NB: This is called during __init__ for a frozen dataclass
    # use this to accumulate the most up-to-date field values
    def method_setattr_standard(self, tx: "InstructionTranslator", name, value):
        self.fields[name.as_python_constant()] = value
        return super().method_setattr_standard(tx, name, value)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.value_type.__name__})"
1261
1262
class SourcelessGraphModuleVariable(UserDefinedObjectVariable):
    """User-object tracker whose method calls all inline ``value.forward``.

    ``call_method`` ignores the requested method ``name``. It always inlines the
    underlying ``forward`` function with this tracker prepended as ``self``.
    """

    def __init__(self, value, **kwargs) -> None:
        super().__init__(value, **kwargs)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        forward_fn = self.value.forward.__func__
        forward_var = variables.UserFunctionVariable(forward_fn)
        call_args = [self] + args
        return tx.inline_user_function_return(forward_var, call_args, kwargs)
1285
1286
class WeakRefVariable(UserDefinedObjectVariable):
    """Tracks a reference object that yields its referent when called with no arguments.

    Presumably a ``weakref.ref``, given ``WeakRefCallSource``; confirm at the builder.
    """

    _nonvar_fields = UserDefinedObjectVariable._nonvar_fields

    def __init__(self, value, **kwargs) -> None:
        super().__init__(value, **kwargs)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # Dereference eagerly at trace time; the call arguments are not used.
        referent = self.value()

        if not self.source:
            from .builder import SourcelessBuilder

            return SourcelessBuilder.create(tx, referent)

        from .builder import VariableBuilder

        # Source the result as "call of the weakref" so guards re-dereference it.
        return VariableBuilder(tx, WeakRefCallSource(self.source))(referent)
1311
1312
class KeyedJaggedTensorVariable(UserDefinedObjectVariable):
    """Tracker for torchrec ``KeyedJaggedTensor`` instances (exact type only)."""

    @staticmethod
    def is_matching_object(obj):
        """True iff torchrec's jagged_tensor module is already imported and
        ``obj`` is exactly a ``KeyedJaggedTensor`` (subclasses do not match)."""
        jagged_mod = sys.modules.get("torchrec.sparse.jagged_tensor")
        if jagged_mod is None:
            return False
        return type(obj) is jagged_mod.KeyedJaggedTensor

    def __init__(self, value, **kwargs) -> None:
        from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

        assert type(value) is KeyedJaggedTensor
        super().__init__(value, **kwargs)

    def var_getattr(self, tx: "InstructionTranslator", name):
        # With the config flag enabled, the per-key length/offset attributes
        # of a sourced KJT are read under the unbacked size-like int patch.
        patch_ints = (
            torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt
            and self.source is not None
            and name in ("_length_per_key", "_offset_per_key")
        )
        if not patch_ints:
            return super().var_getattr(tx, name)
        with TracingContext.patch(force_unspec_int_unbacked_size_like=True):
            return super().var_getattr(tx, name)
1334
1335
class RemovableHandleClass:
    """Placeholder type reported by ``RemovableHandleVariable.python_type``.

    It lets ``isinstance`` checks on traced hook handles resolve to a concrete
    class. It carries no state or behavior.
    """
1340
1341
class RemovableHandleVariable(VariableTracker):
    """Tracks the handle returned by a hook registration during tracing.

    ``idx`` indexes the side-effects-owned hook registration list. It is set to
    ``REMOVED`` once ``remove()`` has been traced, which makes ``remove()``
    idempotent.
    """

    REMOVED = -1

    def __init__(
        self,
        mutable_local=None,
        # index of the registration in the side_effects owned register_hook/handle list, used during removal.
        idx=None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.mutable_local = mutable_local
        self.idx = idx

    def call_method(self, tx: "InstructionTranslator", method_name, args, kwargs):
        """Trace ``handle.remove()``; any other method is delegated to the base class."""
        if method_name == "remove":
            if self.idx != self.REMOVED:
                tx.output.side_effects.remove_hook(self.idx)
                self.idx = self.REMOVED
            return variables.ConstantVariable.create(None)
        # Propagate the base-class result instead of silently returning None.
        return super().call_method(tx, method_name, args, kwargs)

    def reconstruct(self, codegen):
        if self.idx == self.REMOVED:
            # Hook has already been removed, return a dummy handle
            codegen.add_push_null(
                lambda: codegen.load_import_from(
                    "torch._dynamo.utils", "invalid_removeable_handle"
                )
            )
            codegen.extend_output(create_call_function(0, False))
            return
        # unreachable due to codegen.add_cache() when the hook is installed
        super().reconstruct(codegen)

    def python_type(self):
        return RemovableHandleClass
1380
class MutableMappingVariable(UserDefinedObjectVariable):
    """User-defined mapping tracker with a polyfilled default ``get``.

    When the object's type inherits ``collections.abc.Mapping.get`` unchanged,
    ``obj.get`` is traced via ``polyfills.mapping_get`` bound to this tracker.
    """

    _nonvar_fields = UserDefinedObjectVariable._nonvar_fields

    def __init__(self, value, **kwargs):
        super().__init__(value, **kwargs)

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        inherits_default_get = (
            name == "get" and type(self.value).get is collections.abc.Mapping.get
        )
        if not inherits_default_get:
            return super().var_getattr(tx, name)
        return variables.UserMethodVariable(polyfills.mapping_get, self)
1392
1393
class RandomVariable(UserDefinedObjectVariable):
    """Marker subclass of UserDefinedObjectVariable that adds no behavior.

    Presumably used for ``random.Random``-like objects (name-based assumption;
    confirm at the builder that constructs it).
    """
1396