# mypy: ignore-errors
import collections
import dataclasses
import functools
import inspect
import itertools
import random
import re
import sys
import types
from typing import Dict, List, Optional, TYPE_CHECKING

import torch._C
import torch._numpy as tnp
import torch.utils._pytree as pytree

from .. import config, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..create_parameter_op import do_not_convert_to_tracable_parameter
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..mutation_guard import unpatched_nn_module_init
from ..source import (
    AttrSource,
    DefaultsSource,
    GetItemSource,
    ODictGetItemSource,
    TypeSource,
)
from ..utils import (
    check_unspec_or_constant_args,
    identity,
    is_tensor_base_attr_getter,
    proxy_args_kwargs,
    set_example_value,
)
from .base import VariableTracker
from .functions import (
    NestedUserFunctionVariable,
    UserFunctionVariable,
    UserMethodVariable,
    wrap_bound_arg,
)
from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


class NO_SUCH_SUBOBJ:
    pass


class SuperVariable(VariableTracker):
    _nonvar_fields = {
        "specialized",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, typevar, objvar=None, specialized=False, **kwargs) -> None:
        super().__init__(**kwargs)
        # typevar is the first argument to super(). In the case where no argument
        # is provided to super(), it is the __class__ object where
        # the super() function is being called
        self.typevar = typevar
        # objvar here must be an instance or subtype of typevar.
        # In the case where super() is called without arguments, it is the first argument
        # to the current function where super() is called from (self for a regular method,
        # cls for a classmethod)
        self.objvar = objvar
        self.specialized = specialized  # directly get attr from self.typevar if true

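    # Illustrative example (editorial, not from the original source): for
    #
    #     class B(A):
    #         def forward(self, x):
    #             return super().forward(x)
    #
    # the zero-arg super() call is modeled with typevar = B (taken from the
    # compiler-provided __class__ cell) and objvar = the VariableTracker for
    # `self`.
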
    def reconstruct(self, codegen):
        codegen.add_push_null(lambda: codegen(variables.BuiltinVariable(super)))
        codegen(self.typevar)
        if self.objvar is not None:
            codegen(self.objvar)
            codegen.extend_output(create_call_function(2, False))
        else:
            codegen.extend_output(create_call_function(1, False))

    def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name):
        assert self.objvar, "1-arg super not implemented"
        if self.specialized:
            return getattr(self.typevar.as_python_constant(), name), None
        search_type = self.typevar.as_python_constant()

        # The rest of this function does two things:
        #   - Walk the mro to find where the attribute comes from, to be
        #     able to provide an accurate source
        #   - Call getattr to get the object

        # Find the class object where the function lives.
        # When objvar is "self", use type(self); when objvar is "cls", use it as-is
        type_to_use = self.objvar.python_type()
        type_to_use_source = (
            TypeSource(self.objvar.source) if self.objvar.source else None
        )
        if issubclass(type_to_use, type):
            type_to_use = self.objvar.value
            type_to_use_source = self.objvar.source

        source = None
        resolved_class = None
        resolved_attr = None
        search_mro = type_to_use.__mro__

        try:
            start_index = search_mro.index(search_type) + 1
        except ValueError:
            # Corner case where the typevar is not in the mro of the objvar
            # https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8843-L8844
            return getattr(super(search_type, type_to_use), name), None
        # Implemented based on https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8812
        # super has its own getattro implementation. The key point is that, instead of
        # calling getattr, it checks for the attribute in the class __dict__
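        # Illustrative (assumed) example: for a zero-arg super() inside
        # B.forward where class B(A), search_mro is type(self).__mro__ ==
        # (B, A, object) and start_index points just past B, so only
        # A.__dict__ and object.__dict__ are consulted.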
        for index in range(start_index, len(search_mro)):
            # Don't call getattr, just check the __dict__ of the class
            resolved_getattr = search_mro[index].__dict__.get(name, NO_SUCH_SUBOBJ)
            if resolved_getattr is not NO_SUCH_SUBOBJ:
                # Equivalent of something like type(L['self']).__mro__[1].attr_name
                if type_to_use_source:
                    source = AttrSource(
                        GetItemSource(
                            AttrSource(type_to_use_source, "__mro__"), index
                        ),
                        name,
                    )
                return resolved_getattr, source

        unimplemented("Unable to resolve super getattr")

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        # Check if getattr is a constant. If not, delay the actual work by
        # wrapping the result in GetAttrVariable. Most often super is called
        # with a method, so most of the work is delayed to call_function.
        #
        # We could have just implemented a const_getattr. However, super is
        # special when it comes to finding sources. Compared to other VTs, super
        # requires the attr name to walk the mro and find the actual source (and
        # not just AttrSource).
        value, source = self._resolved_getattr_and_source(tx, name)
        if not variables.ConstantVariable.is_literal(value):
            return GetAttrVariable(self, name)
        if source:
            install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
            return variables.ConstantVariable.create(value, source=source)
        return variables.ConstantVariable.create(value)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        inner_fn, source = self._resolved_getattr_and_source(tx, name)
        if inner_fn is object.__init__:
            return LambdaVariable(identity)
        elif inner_fn is torch.nn.Module.__init__:
            objvar = self.objvar
            from ..side_effects import AttributeMutationNew

            if (
                isinstance(objvar, variables.UserDefinedObjectVariable)
                and isinstance(objvar.mutable_local, AttributeMutationNew)
                and not (args or kwargs)
            ):
                with do_not_convert_to_tracable_parameter():
                    return variables.UserFunctionVariable(
                        unpatched_nn_module_init, source=source
                    ).call_function(tx, [self.objvar] + args, kwargs)
            else:
                unimplemented("super() nn.Module.__init__")
        elif self.objvar.source and inner_fn is object.__new__:
            return tx.output.side_effects.track_object_new_from_user_defined_class(
                self.objvar
            )
        elif isinstance(inner_fn, staticmethod) and isinstance(
            inner_fn.__func__, types.FunctionType
        ):
            return variables.UserFunctionVariable(
                inner_fn.__func__, source=source
            ).call_function(tx, args, kwargs)
        elif isinstance(inner_fn, classmethod) and isinstance(
            inner_fn.__func__, types.FunctionType
        ):
            return variables.UserMethodVariable(
                inner_fn.__func__, self.objvar, source=source
            ).call_function(tx, args, kwargs)
        elif isinstance(inner_fn, types.FunctionType):
            return variables.UserFunctionVariable(
                inner_fn, source=source
            ).call_function(tx, [self.objvar] + args, kwargs)
        elif isinstance(inner_fn, types.MethodType):
            return variables.UserMethodVariable(
                inner_fn.__func__, self.objvar, source=source
            ).call_function(tx, args, kwargs)
        elif (
            inner_fn is collections.OrderedDict.__getitem__
            and isinstance(self.objvar, variables.UserDefinedObjectVariable)
            and self.objvar.source
            and len(args) == 1
            and len(kwargs) == 0
            and args[0].is_python_constant()
        ):
            from .builder import VariableBuilder

            key = args[0].as_python_constant()
            return VariableBuilder(tx, ODictGetItemSource(self.objvar.source, key))(
                collections.OrderedDict.__getitem__(self.objvar.value, key)
            )
        elif inner_fn in (
            collections.OrderedDict.__setitem__,
            object.__setattr__,
        ) and isinstance(self.objvar, variables.CustomizedDictVariable):
            assert not kwargs and len(args) == 2
            return super(variables.CustomizedDictVariable, self.objvar).call_method(
                tx, "__setitem__", args, kwargs
            )
        elif inner_fn is collections.OrderedDict.__getitem__ and isinstance(
            self.objvar, variables.CustomizedDictVariable
        ):
            return super(variables.CustomizedDictVariable, self.objvar).call_method(
                tx, "__getitem__", args, kwargs
            )
        elif is_standard_setattr(inner_fn) and isinstance(
            self.objvar, UserDefinedObjectVariable
        ):
            return self.objvar.method_setattr_standard(tx, *args, **kwargs)
        elif inner_fn is object.__delattr__:
            attr = args[0]
            try:
                attr = attr.as_python_constant()
            except NotImplementedError:
                unimplemented(f"non-const delattr attr: {attr}")
            if not tx.output.side_effects.is_attribute_mutation(self.objvar):
                unimplemented(f"delattr({self.objvar}, {attr}, ...)")

            tx.output.side_effects.store_attr(
                self.objvar, attr, variables.DeletedVariable()
            )
            return variables.ConstantVariable(None)

        unimplemented(f"non-function or method super: {inner_fn}")


class ExceptionVariable(VariableTracker):
    def __init__(self, exc_type, args, **kwargs) -> None:
        super().__init__(**kwargs)
        self.exc_type = exc_type
        self.args = args

    def reconstruct(self, codegen):
        codegen.add_push_null(
            lambda: codegen.load_import_from("builtins", self.exc_type.__name__)
        )
        codegen.foreach(self.args)
        codegen.call_function(len(self.args), False)


class UnknownVariable(VariableTracker):
    """
    It could be anything!
    """


class DelayGraphBreakVariable(UnknownVariable):
    """
    Used to insert a dummy variable in the stack to do the graph break at CALL_FUNCTION.
    """


class ComptimeVariable(VariableTracker):
    """
    This variable is special: it lets you execute arbitrary code at
    Dynamo compile time
    """
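    # Illustrative usage (editorial sketch, assuming the public helpers in
    # torch._dynamo.comptime):
    #
    #     from torch._dynamo.comptime import comptime
    #
    #     @torch.compile
    #     def f(x):
    #         comptime.print_graph()  # runs while Dynamo is compiling f
    #         return x + 1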

    def reconstruct(self, codegen):
        raise NotImplementedError("comptime is a special form")

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        from ..comptime import comptime

        # To support the comptime.print_graph convenience accessors
        from .functions import UserFunctionVariable

        return UserFunctionVariable(
            getattr(comptime, name), source=AttrSource(self.source, name)
        )

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from ..comptime import ComptimeContext

        # TODO: support an expression form as well

        assert not kwargs
        # Second argument is the runtime lambda, which is ignored
        assert len(args) <= 2
        fn = args[0]
        if isinstance(fn, UserFunctionVariable):
            fn.get_function()(ComptimeContext(tx))
        elif isinstance(fn, NestedUserFunctionVariable):
            # We have to manually bind the freevars ourselves
            code = fn.get_code()
            assert not fn.closure, (
                "comptime function must not have free variables, "
                f"but these variables were free: {code.co_freevars}"
            )
            func = types.FunctionType(
                code,
                fn.f_globals,
                fn.fn_name.as_python_constant(),
                tuple(fn.defaults.items) if fn.defaults else None,
                # We could automatically promote free variables into
                # ComptimeVar, but this is confusing if you access
                # a free variable that we actually DO have the runtime
                # value for
                # tuple(make_cell(ComptimeVar(i)) for i in fn.closure.items)
                (),
            )
            func(ComptimeContext(tx))
        else:
            raise RuntimeError(f"unsupported argument to comptime: {type(fn)}")

        return variables.ConstantVariable.create(None)


class ClosureVariable(UnknownVariable):
    _nonvar_fields = {
        "name",
        *UnknownVariable._nonvar_fields,
    }

    def __init__(self, name, **kwargs) -> None:
        super().__init__(**kwargs)
        self.name = name

    def reconstruct(self, codegen):
        codegen.append_output(codegen.create_load_closure(self.name))


# closure variable created by an inlined function
class InlinedClosureVariable(UnknownVariable):
    _nonvar_fields = {
        "name",
        *UnknownVariable._nonvar_fields,
    }

    def __init__(self, name, **kwargs) -> None:
        super().__init__(**kwargs)
        self.name = name

    def reconstruct(self, codegen):
        codegen.append_output(codegen.create_load_closure(self.name))


class NewCellVariable(VariableTracker):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)


class NewGlobalVariable(VariableTracker):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)


class InspectSignatureVariable(VariableTracker):
    """represents inspect.signature(...)"""

    _nonvar_fields = {
        "signature",
        "parameters",
        *VariableTracker._nonvar_fields,
    }

    @staticmethod
    def create(callable, **kwargs):
        if kwargs:
            unimplemented(f"inspect.signature with {kwargs}")
        return InspectSignatureVariable(
            callable, mutable_local=variables.base.MutableLocal()
        )

    def __init__(self, inspected: VariableTracker, **kwargs) -> None:
        super().__init__(**kwargs)
        self.inspected = inspected

        if isinstance(self.inspected, UserMethodVariable):
            self.fn = self.inspected.get_function()
            self.signature = inspect.signature(self.fn)
            self.parameters = list(self.signature.parameters.items())[1:]
        elif isinstance(self.inspected, UserFunctionVariable):
            self.fn = self.inspected.get_function()
            self.signature = inspect.signature(self.fn)
            self.parameters = list(self.signature.parameters.items())
        else:
            self.fn = self.inspected.as_python_constant()
            self.signature = inspect.signature(self.fn)
            self.parameters = list(self.signature.parameters.items())

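    # Illustrative (assumed) example of what this class models:
    #
    #     sig = inspect.signature(fn)     # -> InspectSignatureVariable
    #     ba = sig.bind(*args, **kwargs)  # -> InspectBoundArgumentsVariable
    #
    # For bound methods, the implicit self/cls parameter is dropped from
    # self.parameters, matching inspect.signature's behavior.
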
    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        if name == "parameters":
            return variables.ConstDictVariable(
                {
                    variables.ConstantVariable.create(
                        param[0]
                    ): InspectParameterVariable(param[1])
                    for param in self.parameters
                },
                user_cls=dict,
            )
        return super().var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "bind":
            if not hasattr(self.fn, "__kwdefaults__"):
                unimplemented(
                    f"inspect.signature.bind with {self.fn} without __kwdefaults__"
                )
            obj = self.signature.bind(*args, **kwargs)

            # wrap function defaults in VTs
            defaults = {}
            if self.fn.__kwdefaults__:
                wrap = functools.partial(wrap_bound_arg, tx=tx)
                kwdefaults_sources = {
                    k: None
                    if self.source is None
                    else DefaultsSource(self.source, k, is_kw=True)
                    for k in self.fn.__kwdefaults__
                }
                defaults = {
                    k: wrap(val=v, source=kwdefaults_sources[k])
                    for k, v in self.fn.__kwdefaults__.items()
                }

            return InspectBoundArgumentsVariable(
                obj,
                defaults,
                self,
            )
        return super().call_method(tx, name, args, kwargs)

    def reconstruct(self, codegen):
        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
                    codegen.create_load_python_module(inspect),
                    codegen.create_load_attr("signature"),
                ]
            )
        )
        codegen(self.inspected)
        codegen.extend_output(create_call_function(1, False))


class InspectParameterVariable(VariableTracker):
    """represents inspect.Parameter(...)"""

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        from .builder import SourcelessBuilder, VariableBuilder

        try:
            attr_value = getattr(self.value, name)
            if self.source:
                attr_source = AttrSource(self.source, name)
                return VariableBuilder(tx, attr_source)(attr_value)
            else:
                return SourcelessBuilder.create(tx, attr_value)
        except AttributeError:
            unimplemented(f"getattr({self.value}, {name})")


class InspectBoundArgumentsVariable(VariableTracker):
    """represents inspect.signature(...).bind(...)"""

    _nonvar_fields = {
        "bound_arguments",
        "packed_vars",
        *VariableTracker._nonvar_fields,
    }

    # NOTE: we keep track of changes to arguments via bound_arguments_var,
    # but we still keep a copy of the inspect.BoundArguments object in order
    # to get the correct args/kwargs.
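    #
    # Illustrative (assumed) example of the eager round trip being modeled:
    #
    #     ba = inspect.signature(fn).bind(1, k=2)
    #     ba.apply_defaults()
    #     fn(*ba.args, **ba.kwargs)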
    def __init__(
        self,
        bound_arguments: inspect.BoundArguments,
        defaults: Dict[str, VariableTracker],
        signature: InspectSignatureVariable,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.bound_arguments = bound_arguments
        self.defaults = defaults
        # used to convert from VT to tuple/dict when updating bound_arguments
        self.packed_vars = set()

        arguments_dict = {}
        for key, val in bound_arguments.arguments.items():
            key_var = variables.ConstantVariable(key)
            # convert val to VT
            if isinstance(val, tuple):
                arguments_dict[key_var] = variables.TupleVariable(list(val))
                self.packed_vars.add(key)
            elif isinstance(val, dict):
                self.packed_vars.add(key)
                arguments_dict[key_var] = variables.ConstDictVariable(
                    {variables.ConstantVariable(k): v for k, v in val.items()}
                )
            elif isinstance(val, VariableTracker):
                arguments_dict[key_var] = val
            else:
                unimplemented(
                    "inspect.signature(...).bind(...).arguments contains non-variable/tuple/dict"
                )

        self.bound_arguments_var = variables.ConstDictVariable(
            arguments_dict,
            type(bound_arguments.arguments),
            mutable_local=variables.base.MutableLocal(),
        )
        self.signature = signature

    def _update_bound_arguments(self):
        for key, val in self.bound_arguments_var.items.items():
            true_val = val
            if key.underlying_value in self.packed_vars:
                if isinstance(val, variables.TupleVariable):
                    true_val = tuple(val.items)
                elif isinstance(val, variables.ConstDictVariable):
                    true_val = {k.underlying_value: v for k, v in val.items.items()}
                else:
                    unimplemented(
                        "inspect.signature(...).bind(...) cannot update bound arguments"
                    )
            self.bound_arguments.arguments[key.underlying_value] = true_val

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        if name == "arguments":
            return self.bound_arguments_var
        elif name == "args":
            self._update_bound_arguments()
            return variables.TupleVariable(list(self.bound_arguments.args))
        elif name == "kwargs":
            self._update_bound_arguments()
            kw = {
                variables.ConstantVariable(key): val
                for key, val in self.bound_arguments.kwargs.items()
            }
            return variables.ConstDictVariable(kw)
        elif name == "signature":
            return self.signature
        return super().var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "apply_defaults":
            # mimic calling apply_defaults
            for key, val in self.defaults.items():
                key_var = variables.ConstantVariable(key)
                if key_var not in self.bound_arguments_var:
                    self.bound_arguments_var.call_method(
                        tx, "__setitem__", [key_var, val], {}
                    )

            # actually apply the changes
            self._update_bound_arguments()

            return variables.ConstantVariable(None)
        return super().call_method(tx, name, args, kwargs)

    def reconstruct(self, codegen):
        # reconstruct inspect.signature(...).bind(*bound_arguments.args, **bound_arguments.kwargs)
        # NOTE: the reconstructed inspect.signature(...) object might not be the same object
        # as the Signature object that originally created the BoundArguments object.
        self._update_bound_arguments()

        def gen_fn():
            codegen(self.signature)
            codegen.append_output(codegen.create_load_attr("bind"))

        codegen.add_push_null(gen_fn, call_function_ex=True)

        codegen.foreach(self.bound_arguments.args)
        codegen.append_output(
            create_instruction("BUILD_TUPLE", arg=len(self.bound_arguments.args))
        )

        for key, val in self.bound_arguments.kwargs.items():
            codegen.append_output(codegen.create_load_const(key))
            codegen(val)
        codegen.extend_output(
            [
                create_instruction("BUILD_MAP", arg=len(self.bound_arguments.kwargs)),
                create_instruction("CALL_FUNCTION_EX", arg=1),
            ]
        )


def produce_trampoline_autograd_apply(fn_cls):
    def trampoline_autograd_apply(*args, **kwargs):
        return fn_cls.apply(*args, **kwargs)

    trampoline_autograd_apply._origin = produce_trampoline_autograd_apply
    return trampoline_autograd_apply

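# Editorial note (not in the original source): the trampoline gives the FX
# graph a plain callable target; when the autograd.Function subclass is
# allow-listed, call_method below inserts this trampoline as a call_function
# node instead of tracing into fn_cls.apply.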

class AutogradFunctionVariable(VariableTracker):
    """represents a torch.autograd.Function subclass"""

    _nonvar_fields = {
        "fn_cls",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, fn_cls, **kwargs) -> None:
        super().__init__(**kwargs)
        self.fn_cls = fn_cls

    def call_apply(self, tx: "InstructionTranslator", args, kwargs):
        requires_grad = False

        def visit(node):
            nonlocal requires_grad
            if isinstance(node, variables.TensorVariable):
                if node.requires_grad is not False:
                    requires_grad = True
            if isinstance(node, variables.NNModuleVariable):
                if node.is_training(tx):
                    requires_grad = True

        VariableTracker.visit(visit, (args, kwargs))

        if (
            requires_grad
            and torch.is_grad_enabled()
            and config.capture_autograd_function
        ):
            from torch._functorch.autograd_function import (
                autograd_function_forward_rewritten,
            )
            from torch.autograd.function import _is_setup_context_defined

            forward_fn = self.fn_cls.forward

            is_setup_ctx_defined = _is_setup_context_defined(self.fn_cls.setup_context)
            if is_setup_ctx_defined:
                # If setup_context is defined, we generate a new forward function which includes
                # the original forward and setup_context function, and trace the new forward function.
                forward_fn = autograd_function_forward_rewritten(
                    self.fn_cls.forward, self.fn_cls.setup_context
                )
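                # Editorial sketch (assumed) of the rewritten forward:
                #
                #     def new_forward(ctx, *args, **kwargs):
                #         output = forward(*args, **kwargs)
                #         setup_context(ctx, args, output)
                #         return output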

            vjp_fn = self.fn_cls.vjp  # type: ignore[attr-defined]
            if vjp_fn is not torch.autograd.Function.vjp:
                unimplemented("NYI - User defined vjp")

            jvp_fn = self.fn_cls.jvp  # type: ignore[attr-defined]
            if jvp_fn is not torch.autograd.Function.jvp:
                unimplemented("NYI - User defined jvp")

            from .higher_order_ops import AutogradFunctionApplyVariable

            source = self.source
            if source is None:
                source = AttrSource(
                    tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
                )

            val = AutogradFunctionApplyVariable(
                forward_fn,
                self.fn_cls.backward,
                source,
                source=AttrSource(source, member="apply"),
            ).call_function(tx, args, kwargs)
            # Inside of AutogradFunctionApplyVariable.call_function, we use a sourceless
            # variable to wrap the forward function, as we don't want to generate guards
            # for new_forward.__closure__ if forward is rewritten by
            # autograd_function_forward_rewritten.
            # But we still need to generate correct guards for the original forward and
            # setup_context functions, so we have to add guards manually.
            if self.source:
                fwd_src = AttrSource(self.source, "forward")
                install_guard(fwd_src.make_guard(GuardBuilder.FUNCTION_MATCH))
                if is_setup_ctx_defined:
                    setup_ctx_src = AttrSource(self.source, "setup_context")
                    install_guard(setup_ctx_src.make_guard(GuardBuilder.FUNCTION_MATCH))

            return val

        if self.source:
            source = AttrSource(self.source, "forward")
        else:
            source = None

        fn = self.fn_cls.forward
        ctx = AutogradFunctionContextVariable.create(tx, args, kwargs)
        args = [ctx, *args]
        if isinstance(fn, types.FunctionType):
            return variables.UserFunctionVariable(fn, source=source).call_function(
                tx, args, kwargs
            )
        elif isinstance(fn, types.MethodType):
            return variables.UserMethodVariable(
                fn.__func__,
                variables.UserDefinedClassVariable(self.fn_cls),
                source=source,
            ).call_function(tx, args, kwargs)
        else:
            unimplemented(
                f"non-function or method in subclass of torch.autograd.Function: {fn}"
            )

    def call_backward(self, tx: "InstructionTranslator", args, kwargs):
        fn = self.fn_cls.backward
        self.source = AttrSource(self.source, "backward")
        assert type(args[0].value) is torch._dynamo.external_utils.FakeBackwardCFunction
        assert isinstance(fn, types.FunctionType)

        return variables.UserFunctionVariable(fn, source=self.source).call_function(
            tx, args, kwargs
        )

    def call_function(self, tx: "InstructionTranslator", args, kwargs):
        return AutogradFunctionVariable(self.fn_cls)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ):
        from ..trace_rules import is_callable_allowed
        from .builder import wrap_fx_proxy

        if name == "apply":
            if is_callable_allowed(self.fn_cls):
                trampoline_autograd_apply = produce_trampoline_autograd_apply(
                    self.fn_cls
                )
                return wrap_fx_proxy(
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_function",
                        trampoline_autograd_apply,
                        *proxy_args_kwargs(args, kwargs),
                    ),
                )
            else:
                return self.call_apply(tx, args, kwargs)

        elif name == "backward":
            return self.call_backward(tx, args, kwargs)
        else:
            from .. import trace_rules

            source = AttrSource(self.source, name) if self.source is not None else None
            try:
                obj = inspect.getattr_static(self.fn_cls, name)
            except AttributeError:
                obj = None

            if isinstance(obj, staticmethod):
                func = obj.__get__(self.fn_cls)
                if source is not None:
                    return (
                        trace_rules.lookup(func)
                        .create_with_source(func, source=source)
                        .call_function(tx, args, kwargs)
                    )
                else:
                    return trace_rules.lookup(func)(func).call_function(
                        tx, args, kwargs
                    )
            elif isinstance(obj, classmethod):
                return variables.UserMethodVariable(
                    obj.__func__, self, source=source
                ).call_function(tx, args, kwargs)
            else:
                unimplemented(f"Unsupported method: {name}")


@dataclasses.dataclass
class SavedTensorBox:
    tensors: List[VariableTracker] = dataclasses.field(default_factory=list)


class AutogradFunctionContextVariable(UserDefinedObjectVariable):
    """
    Tracks an autograd.Function() context using mutation tracking in side_effects.py
    """

    _nonvar_fields = {
        "proxy",
        "inference",
        "saved_tensors",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    def __init__(
        self,
        value,
        value_type=None,
        inference=False,
        proxy=None,
        saved_tensors=None,
        needs_input_grad=None,
        non_differentiable=None,
        **kwargs,
    ) -> None:
        super().__init__(value=value, value_type=value_type, **kwargs)
        self.inference = inference
        self.proxy = proxy
        self.saved_tensors = saved_tensors
        self.needs_input_grad = needs_input_grad
        self.non_differentiable = non_differentiable

    @staticmethod
    def create(tx: "InstructionTranslator", args=None, kwargs=None):
        needs_input_grad = None
        if args and not kwargs:
            needs_input_grad = tuple(
                isinstance(x, variables.TensorVariable) and x.requires_grad
                for x in args
            )
        proxy = tx.output.create_proxy(
            "call_function", torch.autograd.function.FunctionCtx, (), {}
        )
        out = tx.output.side_effects.track_object_new(
            None,
            torch.autograd.function.FunctionCtx,
            functools.partial(
                AutogradFunctionContextVariable,
                inference=True,
                proxy=proxy,
                saved_tensors=SavedTensorBox(),
                needs_input_grad=needs_input_grad,
            ),
            {},
        )
        set_example_value(proxy.node, out.value)

        return out

    def as_proxy(self):
        if self.proxy is None:
            unimplemented("proxy not set")
        return self.proxy

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__setattr__":
            return super().call_method(tx, name, args, kwargs)
        elif name == "mark_non_differentiable":
            assert len(kwargs) == 0
            self.non_differentiable = proxy_args_kwargs(args, {})[0]
            return variables.ConstantVariable.create(None)

        if name != "save_for_backward":
            unimplemented(f"autograd.Function context method: {name}")
        if self.saved_tensors is None:
            unimplemented(
                "save_for_backward only supported on a newly constructed FunctionCtx"
            )

        if not self.inference:
            assert self.source and not kwargs
            tx.output.side_effects.track_save_for_backward(self, args)

        # In eager mode, multiple calls to .save_for_backward() will overwrite previous calls.
        if len(self.saved_tensors.tensors) > 0:
            self.saved_tensors.tensors = []
        for arg in args:
            self.saved_tensors.tensors.append(arg)
        return variables.ConstantVariable.create(None)

    def var_getattr(self, tx: "InstructionTranslator", name):
        if name in ["save_for_backward", "mark_non_differentiable"]:
            return LambdaVariable(
                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
            )
        if name == "saved_tensors" and self.saved_tensors is not None:
            return variables.TupleVariable(list(self.saved_tensors.tensors))
        if name == "needs_input_grad":
            if self.needs_input_grad is not None:
                return variables.ConstantVariable.create(self.needs_input_grad)
            if self.source:
                from .builder import VariableBuilder

                return VariableBuilder(tx, AttrSource(self.source, "needs_input_grad"))(
                    self.value.needs_input_grad
                )
        return super().var_getattr(tx, name)


class AutogradEngineVariable(UserDefinedObjectVariable):
    """
    Represents a torch._C._ImperativeEngine instance.
    """

    def __init__(
        self,
        value,
        value_type=None,
        **kwargs,
    ) -> None:
        super().__init__(value=value, value_type=value_type, **kwargs)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "queue_callback":
            if torch._dynamo.compiled_autograd.compiled_autograd_enabled:
                assert (
                    tx.one_graph
                ), "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
                return variables.UserFunctionVariable(
                    torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback,
                    source=self.source,
                ).call_function(
                    tx,
                    (tx.output.side_effects.get_ca_final_callbacks_var(), *args),
                    kwargs,
                )
            else:
                unimplemented(
                    "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
                )
        else:
            unimplemented(f"torch._C._ImperativeEngine method: {name}")


class LambdaVariable(VariableTracker):
    def __init__(self, fn, **kwargs) -> None:
        super().__init__(**kwargs)
        self.fn = fn

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        return self.fn(*args, **kwargs)


class GetAttrVariable(VariableTracker):
    _nonvar_fields = {
        "name",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, obj, name, **kwargs) -> None:
        super().__init__(**kwargs)
        assert isinstance(obj, VariableTracker)
        assert isinstance(name, str)
        self.obj = obj
        self.name = name

    def __str__(self) -> str:
        return f"{self.__class__.__name__}({self.obj}, {self.name})"

    @staticmethod
    def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr):
        return getattr(base_proxy, attr)

    def as_proxy(self):
        return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name)

    def const_getattr(self, tx: "InstructionTranslator", name):
        if not isinstance(self.obj, variables.NNModuleVariable):
            raise NotImplementedError
        step1 = tx.output.get_submodule(self.obj.module_key)
        if self.name not in step1.__dict__:
            raise NotImplementedError
        step2 = inspect.getattr_static(step1, self.name)
        if name not in step2.__dict__:
            raise NotImplementedError
        return inspect.getattr_static(step2, name)

    def reconstruct(self, codegen):
        codegen(self.obj)
        codegen.extend_output(codegen.create_load_attrs(self.name))

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        return self.obj.call_method(tx, self.name, args, kwargs)

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> VariableTracker:
        if (
            name in ("__getitem__", "get")
            and self.name == "__dict__"
            and not kwargs
            and args[0].is_python_constant()
            and isinstance(
                self.obj,
                (
                    variables.UserDefinedObjectVariable,
                    variables.NNModuleVariable,
                    variables.UserDefinedClassVariable,
                ),
            )
        ):
            obj = self.obj
            key = args[0].as_python_constant()
            if obj.has_key_in_generic_dict(tx, key):
                # redirect to var_getattr on the original obj
                return obj.var_getattr(tx, key)

            # Return the default value for get
            if name == "get":
                if len(args) == 2:
                    return args[1]
                else:
                    return variables.ConstantVariable(None)

        elif (
            name == "__contains__"
            and self.name == "__dict__"
            and len(args) == 1
            and args[0].is_python_constant()
            and not kwargs
            and isinstance(
                self.obj,
                (
                    variables.UserDefinedObjectVariable,
                    variables.NNModuleVariable,
                    variables.UserDefinedClassVariable,
                ),
            )
        ):
            obj = self.obj
            key = args[0].as_python_constant()
            if obj.has_key_in_generic_dict(tx, key):
                return variables.ConstantVariable(True)
            else:
                return variables.ConstantVariable(False)

        return super().call_method(tx, name, args, kwargs)


class MethodWrapperVariable(VariableTracker):
    def __init__(self, method_wrapper, **kwargs) -> None:
        super().__init__(**kwargs)
        self.method_wrapper = method_wrapper

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if is_tensor_base_attr_getter(self.method_wrapper) and isinstance(
            args[0], variables.TensorVariable
        ):
            assert len(args) == 1 and len(kwargs) == 0

            return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__)

        return super().call_function(tx, args, kwargs)

    def is_python_constant(self):
        return True

    def as_python_constant(self):
        return self.method_wrapper


class GetSetDescriptorVariable(VariableTracker):
    def __init__(self, desc, **kwargs) -> None:
        super().__init__(**kwargs)
        self.desc = desc

    def var_getattr(self, tx: "InstructionTranslator", name):
        if name == "__get__" and self.source:
            from .builder import VariableBuilder

            return VariableBuilder(tx, AttrSource(self.source, "__get__"))(
                self.desc.__get__
            )
        else:
            return super().var_getattr(tx, name)

    def is_python_constant(self):
        return True

    def as_python_constant(self):
        return self.desc


class PythonModuleVariable(VariableTracker):
    _nonvar_fields = {
        "value",
        "is_torch",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, value: types.ModuleType, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value
        self.is_torch = self.value is torch or self.value.__name__.startswith("torch.")

    def python_type(self):
        return types.ModuleType

    def as_python_constant(self):
        return self.value

    def __repr__(self) -> str:
        return f"PythonModuleVariable({self.value})"

    def call_hasattr(self, tx: "InstructionTranslator", name):
        result = hasattr(self.value, name)
        return variables.ConstantVariable.create(result)

    def var_getattr(self, tx: "InstructionTranslator", name):
        if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
            return tx.output.side_effects.load_attr(self, name)

        from .builder import SourcelessBuilder, VariableBuilder

        if self.is_torch or name not in self.value.__dict__:
            attr_value = getattr(self.value, name)
        else:
            attr_value = self.value.__dict__[name]

        if self.source:
            new_source = AttrSource(self.source, name)
            return VariableBuilder(tx, new_source)(attr_value)
        else:
            return SourcelessBuilder.create(tx, attr_value)


class TypingVariable(VariableTracker):
    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__getitem__" and len(args) == 1:
            return variables.ConstantVariable.create(
                self.value[args[0].as_python_constant()],
            )
        unimplemented("typing")

    def as_python_constant(self):
        return self.value


@functools.lru_cache(maxsize=1)
def get_np_to_tnp_map():
    from ..utils import NP_TO_TNP_MODULE

    np_fn_to_tnp_fn = {}

    for np_mod, tnp_mod in NP_TO_TNP_MODULE.items():
        for fn_name, tnp_fn in tnp_mod.__dict__.items():
            if callable(tnp_fn):
                # some internal details do leak from tnp
                # which are not part of the numpy API.
                if np_fn := getattr(np_mod, fn_name, None):
                    np_fn_to_tnp_fn[np_fn] = tnp_fn

    return np_fn_to_tnp_fn

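# Illustrative (assumed) example: get_np_to_tnp_map().get(numpy.sum) would be
# torch._numpy.sum, letting traced numpy calls be redirected to their
# torch._numpy counterparts below.
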

class NumpyVariable(VariableTracker):
    """
    Wrapper around `numpy.*`. Currently able to trace a small subset of numpy functions as well as numpy dtypes.
    """

    constant_fold_functions = (tnp.issubdtype,)

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    @classmethod
    def can_constant_fold_through(cls, fn):
        mod = fn.__module__.split(".")
        assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
        return fn in cls.constant_fold_functions

    @classmethod
    def get_constant_collection_for_func(cls, fn):
        mod = fn.__module__.split(".")
        assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
        return np_constant_collections_map.get(fn, None)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if not config.trace_numpy:
            unimplemented(f"numpy.{self.value}()")

        from ..utils import numpy_to_tensor_wrapper
        from .tensor import NumpyNdarrayVariable

        func = get_np_to_tnp_map().get(self.value)
        if func is None:
            unimplemented(
                f"Can't find numpy function {self.value} in torch._numpy. "
                "Please file an issue to request support for this function."
            )

        # We are dealing with a function that produces a const collection type (np.dtype, np.iinfo/np.finfo)
        if (
            collection_variable_typ := self.get_constant_collection_for_func(func)
        ) is not None:
            try:
                return collection_variable_typ(
                    self.value(
                        *[x.as_python_constant() for x in args],
                        **{k: v.as_python_constant() for k, v in kwargs.items()},
                    )
                )
            except NotImplementedError:
                unimplemented(
                    f"{self.value.__name__} with non-const args: {args} {kwargs}"
                )
        else:
            if (
                func.__module__ == "torch._numpy.random"
                and config.use_numpy_random_stream
            ):
                msg = f"delegate '{func.__qualname__}' to NumPy itself via "
                msg += f"config.use_numpy_random_stream={config.use_numpy_random_stream}"
                unimplemented(msg)

            args, kwargs = NumpyNdarrayVariable.patch_args(func.__name__, args, kwargs)

            if self.can_constant_fold_through(func) and (
                check_unspec_or_constant_args(args, kwargs)
            ):
                # constant fold
                return variables.ConstantVariable.create(
                    self.as_python_constant()(
                        *[x.as_python_constant() for x in args],
                        **{k: v.as_python_constant() for k, v in kwargs.items()},
                    ),
                )

            # TODO: add all the functions that go from constants to constants to can_constant_fold_through
            proxy = tx.output.create_proxy(
                "call_function",
                numpy_to_tensor_wrapper(func),
                *proxy_args_kwargs(args, kwargs),
            )
            return NumpyNdarrayVariable.create(tx, proxy)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        unimplemented("numpy")

    def as_python_constant(self):
        return self.value

    def as_proxy(self):
        if config.trace_numpy and isinstance(self.value, type):
            # This handles numpy dtype attributes such as np.float32
            # We return a string, as we don't want to serialize non-PyTorch objects in the output FX graph
            # In torch/_numpy we normalize strings to their dtypes when the input is a dtype, as NumPy does
            return self.value.__name__
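            # e.g. (editorial note) np.float32 -> "float32"; torch._numpy
            # accepts the string and normalizes it back to the dtype.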

        return super().as_proxy()


# Used to keep track of NULLs pushed on the stack for Python 3.11 function calls
class NullVariable(VariableTracker):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def __str__(self) -> str:
        return "NullVariable"

    def reconstruct(self, codegen):
        if sys.version_info < (3, 11):
            unimplemented("cannot reconstruct NullVariable in < Python 3.11")
        codegen.append_output(create_instruction("PUSH_NULL"))


class DeletedVariable(VariableTracker):
    """Marker used to implement delattr()"""


class StringFormatVariable(VariableTracker):
    """
    Represents a call to str.format(); we delay calling format until after the graph.
    """

    _nonvar_fields = {"format_string", *VariableTracker._nonvar_fields}

    @classmethod
    def create(cls, format_string, sym_args, sym_kwargs):
        if all(
            x.is_python_constant()
            for x in itertools.chain(sym_args, sym_kwargs.values())
        ):
            return variables.ConstantVariable.create(
                format_string.format(
                    *[v.as_python_constant() for v in sym_args],
                    **{k: v.as_python_constant() for k, v in sym_kwargs.items()},
                )
            )
        return cls(format_string, list(sym_args), dict(sym_kwargs))

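    # Illustrative (assumed) example: StringFormatVariable.create("{}-{n}",
    # [sym_x], {"n": sym_y}) constant-folds only when both inputs are
    # constants; otherwise the real str.format call is deferred until after
    # the graph runs.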
    def __init__(self, format_string, sym_args, sym_kwargs, **kwargs) -> None:
        super().__init__(**kwargs)
        assert isinstance(format_string, str)
        self.format_string = format_string
        self.sym_args = sym_args
        self.sym_kwargs = sym_kwargs

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})"

    def reconstruct(self, codegen):
        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
                    codegen.create_load_const(self.format_string),
                    codegen.create_load_attr("format"),
                ]
            ),
            call_function_ex=True,
        )
        codegen(variables.TupleVariable(self.sym_args))
        kwargs = {
            variables.ConstantVariable.create(k): v for k, v in self.sym_kwargs.items()
        }
        codegen(variables.ConstDictVariable(kwargs))
        codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=1))


class DebuggingVariable(VariableTracker):
    """
    Represents a call to a debugging function like print(), or something
    registered to config.reorderable_logging_functions.
    """

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    @staticmethod
    def is_reorderable_logging_function(obj):
        return (
            callable(obj)
            and isinstance(obj, (types.FunctionType, types.BuiltinFunctionType))
            and obj in torch._dynamo.config.reorderable_logging_functions
        )

    def call_function(self, tx: "InstructionTranslator", args, kwargs):
        if tx.export:
            # For export cases, we can just make debugging functions no-ops
            return

        if not self.can_reorder_logs(self.value, args, kwargs):
            unimplemented(
                f"Reordering debugging function {self.value} "
                f"with inputs {args} {kwargs} is not yet implemented."
            )

        tx.debug_locals.append((self, list(args)))

    def reconstruct(self, codegen):
        return self.source.reconstruct(codegen)

    @staticmethod
    def can_reorder_logs(fn, args, kwargs) -> bool:
        """
        Run some additional checks for what sort of function calls we can
        actually reorder.
        """

        allowed_input_types = (
            variables.TensorVariable,
            variables.ConstantVariable,
            StringFormatVariable,
        )

        flat_args = pytree.tree_leaves([args, kwargs])
        for arg in flat_args:
            if not isinstance(arg, allowed_input_types):
                return False

        return True
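
    # Illustrative (assumed) example: print(x.sum()) inside a compiled region
    # can be recorded in tx.debug_locals and replayed after the graph runs,
    # since its inputs are tensor/constant/string-format variables;
    # print(some_arbitrary_object) would graph-break instead.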


class LoggingLoggerVariable(VariableTracker):
    """
    Represents a call to any of the logging.Logger methods
    """

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if tx.export:
            # For export cases, we can just make logging calls no-ops
            return
        unimplemented("Logger not supported for non-export cases")


class ConstantLikeVariable(VariableTracker):
    """self.value is a compile-time constant, but not a literal"""

    _error_prefix = "ConstantLikeVariable"
    try:
        from numpy import (
            dtype as np_dtype,
            floating as np_floating,
            generic as np_generic,
        )
    except ImportError:
        np_floating = type("invalid_type", (), {})
        np_dtype = type("invalid_type", (), {})
        np_generic = type("invalid_type", (), {})

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    def as_python_constant(self):
        return self.value

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> VariableTracker:
        try:
            # we only support constant propagation for methods
            cargs = [x.as_python_constant() for x in args]
            ckwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
        except NotImplementedError:
            unimplemented(f"{self._error_prefix}.{name}(*{args}, **{kwargs})")

        result = getattr(self.value, name)(*cargs, **ckwargs)

        if variables.ConstantVariable.is_literal(result):
            return variables.ConstantVariable.create(result)
        if isinstance(result, re.Match):
            return ConstantRegexMatchVariable(result)

        unimplemented(f"{self._error_prefix}.{name}() -> {result}")
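
    # Rough sketch of the folding above for a RegexPatternVariable wrapping a
    # compiled pattern: the method runs at trace time on the real value.
    #
    #     pat = re.compile(r"\d+")
    #     m = pat.match("123abc")  # re.Match -> ConstantRegexMatchVariable
    #     m.group(0)               # "123", a literal -> ConstantVariable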

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
        result = getattr(self.value, name)
        if isinstance(result, self.np_floating):
            result = float(result)
        if isinstance(result, self.np_dtype):
            return NumpyDTypeVariable(result)
        if isinstance(result, type) and issubclass(result, self.np_generic):
            # things like x.dtype.type
            return NumpyVariable(result)
        if variables.ConstantVariable.is_literal(result):
            return variables.ConstantVariable.create(result)
        return GetAttrVariable(self, name)


class RegexPatternVariable(ConstantLikeVariable):
    _error_prefix = "re.Pattern"


class ConstantRegexMatchVariable(ConstantLikeVariable):
    _error_prefix = "re.Match"


class TorchVersionVariable(ConstantLikeVariable):
    _error_prefix = "torch.__version__"

    def __init__(self, **kwargs) -> None:
        kwargs.setdefault("value", torch.__version__)
        assert kwargs["value"] is torch.__version__
        super().__init__(**kwargs)


class NumpyTypeInfoVariable(ConstantLikeVariable):
    _error_prefix = "np.iinfo/np.finfo"


class NumpyDTypeVariable(ConstantLikeVariable):
    _error_prefix = "np.dtype[...]"

    def as_proxy(self):
        """Similar to how numpy dtype descriptors (e.g. np.float32) are handled by NumpyVariable:

        np.dtype() objects are serialized as strings, torch._numpy wrappers will normalize to the torch dtype.
        This also handles unsupported things nicely (e.g. structured arrays and object arrays).
        """
        return self.value.type.__name__
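
    # For example (assuming numpy is installed):
    #
    #     np.dtype(np.float32).type.__name__  # -> "float32"
    #
    # so the proxy is the plain string "float32", which the torch._numpy
    # layer normalizes back to the corresponding torch dtype.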


np_constant_collections_map = {
    tnp.finfo: NumpyTypeInfoVariable,
    tnp.iinfo: NumpyTypeInfoVariable,
    tnp.dtype: NumpyDTypeVariable,
}


class RandomClassVariable(VariableTracker):
    """random.Random"""

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def call_function(self, tx: "InstructionTranslator", args, kwargs):
        if len(args) > 1:
            unimplemented("random.Random() with > 1 arg")
        elif kwargs:
            unimplemented("random.Random() with kwargs")
        seed = variables.ConstantVariable.create(None) if len(args) == 0 else args[0]
        return RandomVariable(seed=seed, mutable_local=variables.base.MutableLocal())


class RandomVariable(VariableTracker):
    """random.Random()

    Implemented by wrapping a VariableTracker around a random.Random object.
    The supported methods for the random.Random object cannot be overridden.
    Assumes that random objects behave the same given a set seed or state.
    """

    _nonvar_fields = {
        "random",
        *VariableTracker._nonvar_fields,
    }

    _supported_fn_names = {
        "random",
        "randint",
        "randrange",
        "uniform",
    }

    def __init__(
        self,
        rand: Optional[random.Random] = None,
        seed: Optional[VariableTracker] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        if rand is not None:
            assert self.is_supported_random_obj(rand)
            self.random = random.Random()
            self.random.setstate(rand.getstate())
        else:
            seed = seed.as_python_constant() if seed is not None else None
            self.random = random.Random(seed)

    def python_type(self):
        return random.Random

    def as_python_constant(self):
        return self.random

    @staticmethod
    def is_supported_random_obj(val):
        if type(val) is not random.Random:
            return False
        for name in itertools.chain(
            RandomVariable._supported_fn_names, ("seed", "getstate", "setstate")
        ):
            if not hasattr(val, name):
                return False
            meth = getattr(val, name)
            if inspect.isbuiltin(meth):
                # e.g. random.Random.random
                if meth != getattr(random.Random, name).__get__(val):
                    return False
            else:
                if getattr(meth, "__func__", None) is not getattr(random.Random, name):
                    return False
        return True
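
    # Sketch of what the check above accepts and rejects: a pristine
    # random.Random passes, while subclasses (or instances with patched
    # methods) fail the exact-type or method-identity tests.
    #
    #     r = random.Random()
    #     RandomVariable.is_supported_random_obj(r)  # True
    #
    #     class MyRandom(random.Random):  # hypothetical subclass
    #         def random(self):
    #             return 0.5
    #
    #     RandomVariable.is_supported_random_obj(MyRandom())  # False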

    @staticmethod
    def check_state(state):
        assert type(state) is tuple
        assert type(state[0]) is int
        assert type(state[1]) is tuple
        assert all(type(x) is int for x in state[1])
        assert state[2] is None or type(state[2]) is float
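
    # The expected shape is CPython's Mersenne Twister state:
    #
    #     version, internal_state, gauss_next = random.Random().getstate()
    #     # version: int (3 on current CPython)
    #     # internal_state: tuple of 625 ints
    #     # gauss_next: None or float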

    @staticmethod
    def wrap_state(state):
        RandomVariable.check_state(state)
        return variables.TupleVariable(
            [
                variables.ConstantVariable.create(state[0]),
                variables.TupleVariable(
                    [variables.ConstantVariable.create(x) for x in state[1]]
                ),
                variables.ConstantVariable.create(state[2]),
            ]
        )

    @staticmethod
    def unwrap_state(state):
        state_obj = state.as_python_constant()
        RandomVariable.check_state(state_obj)
        return state_obj

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> VariableTracker:
        if name == "seed":
            tx.output.side_effects.mutation(self)
            self.random.seed(
                *[x.as_python_constant() for x in args],
                **{key: val.as_python_constant() for key, val in kwargs.items()},
            )
            return variables.ConstantVariable.create(None)
        elif name == "getstate":
            return self.wrap_state(self.random.getstate())
        elif name == "setstate":
            tx.output.side_effects.mutation(self)
            self.random.setstate(self.unwrap_state(args[0]))
            return variables.ConstantVariable.create(None)
        elif name in self._supported_fn_names:
            tx.output.side_effects.mutation(self)
            state = self.random.getstate()

            def call_random_meth(*args, **kwargs):
                r = random.Random()
                r.setstate(state)
                return getattr(r, name)(*args, **kwargs)
            # call_random_meth does not advance the state of self.random (it
            # runs on a copy), so advance it here by calling the method directly
            getattr(self.random, name)(
                *[x.as_python_constant() for x in args],
                **{k: v.as_python_constant() for k, v in kwargs.items()},
            )

            return call_random_fn(tx, call_random_meth, args, kwargs)
        return super().call_method(tx, name, args, kwargs)
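
    # The replay pattern used above, as plain Python (a sketch): snapshot the
    # RNG state and have the runtime thunk restore it into a fresh generator,
    # so the compiled call yields the same value the trace observed.
    #
    #     state = rng.getstate()
    #
    #     def replay():
    #         r = random.Random()
    #         r.setstate(state)
    #         return r.random()
    #
    #     assert replay() == replay()  # deterministic given the snapshot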

    def reconstruct(self, codegen):
        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
                    codegen.create_load_python_module(random),
                    codegen.create_load_attr("Random"),
                ]
            )
        )
        codegen.call_function(0, False)
        # NOTE using add_push_null may result in NULL being duplicated
        # so defer the push_null to call_function
        codegen.dup_top()
        codegen.load_attr("setstate")
        codegen(self.wrap_state(self.random.getstate()))
        codegen.call_function(1, True)
        codegen.pop_top()
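
    # The bytecode emitted by reconstruct() above is roughly equivalent to
    #
    #     r = random.Random()
    #     r.setstate(<wrapped copy of self.random.getstate()>)
    #
    # leaving `r` on the stack as the reconstructed random.Random object.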
1724