xref: /aosp_15_r20/external/pytorch/torch/_dynamo/variables/builtin.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import contextlib
4import functools
5import inspect
6import itertools
7import logging
8import math
9import operator
10import types
11from collections import defaultdict, OrderedDict
12from collections.abc import KeysView
13from typing import Dict, List, TYPE_CHECKING
14
15import torch
16from torch import sym_float, sym_int
17from torch.utils._python_dispatch import is_traceable_wrapper_subclass
18
19from .. import config, variables
20from ..exc import (
21    AttributeMutationError,
22    unimplemented,
23    Unsupported,
24    UserError,
25    UserErrorType,
26)
27from ..guards import GuardBuilder, install_guard
28from ..replay_record import DummyModule
29from ..source import AttrSource, GetItemSource, is_constant_source, TypeSource
30from ..utils import (
31    check_constant_args,
32    check_numpy_ndarray_args,
33    check_unspec_or_constant_args,
34    check_unspec_python_args,
35    does_not_override_dict_iter_methods,
36    extract_fake_example_value,
37    get_fake_value,
38    guard_if_dyn,
39    is_wrapper_or_member_descriptor,
40    istype,
41    numpy_operator_wrapper,
42    proxy_args_kwargs,
43    tensortype_to_dtype,
44)
45from .base import MutableLocal, VariableTracker
46from .constant import ConstantVariable
47from .ctx_manager import EventVariable, StreamVariable
48from .dicts import (
49    ConstDictVariable,
50    DefaultDictVariable,
51    DictView,
52    FrozensetVariable,
53    is_hashable,
54    SetVariable,
55)
56from .lists import (
57    BaseListVariable,
58    ListIteratorVariable,
59    ListVariable,
60    SizeVariable,
61    TupleIteratorVariable,
62    TupleVariable,
63)
64from .tensor import (
65    FakeItemVariable,
66    supported_comparison_ops,
67    SymNodeVariable,
68    TensorVariable,
69    UnspecializedPythonVariable,
70)
71from .user_defined import UserDefinedObjectVariable, UserDefinedVariable
72
73
74if TYPE_CHECKING:
75    from torch._dynamo.symbolic_convert import InstructionTranslator
76
77
78log = logging.getLogger(__name__)
79
80
81IN_PLACE_DESUGARING_MAP = {
82    operator.iadd: operator.add,
83    operator.isub: operator.sub,
84    operator.imul: operator.mul,
85    operator.ifloordiv: operator.floordiv,
86    operator.itruediv: operator.truediv,
87    operator.imod: operator.mod,
88    operator.imatmul: operator.imatmul,
89    operator.ilshift: operator.lshift,
90    operator.irshift: operator.rshift,
91    operator.ipow: operator.pow,
92    operator.iand: operator.and_,
93    operator.ior: operator.or_,
94    operator.ixor: operator.xor,
95}
96
97
98class BuiltinVariable(VariableTracker):
99    _SENTINEL = object()
100    _nonvar_fields = {
101        "fn",
102        *VariableTracker._nonvar_fields,
103    }
104
105    @classmethod
106    def create_with_source(cls, value, source):
107        install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH))
108        return cls(value, source=source)
109
110    @staticmethod
111    @functools.lru_cache(None)
112    def _constant_fold_functions():
113        fns = {
114            abs,
115            all,
116            any,
117            bool,
118            callable,
119            chr,
120            divmod,
121            float,
122            getattr,
123            int,
124            len,
125            max,
126            min,
127            ord,
128            pow,
129            repr,
130            round,
131            str,
132            str.format,
133            sum,
134            type,
135            operator.abs,
136            operator.pos,
137            operator.neg,
138            operator.not_,
139            operator.truth,
140            operator.invert,
141            operator.pow,
142            operator.mul,
143            operator.matmul,
144            operator.floordiv,
145            operator.truediv,
146            operator.mod,
147            operator.add,
148            operator.sub,
149            operator.getitem,
150            operator.length_hint,
151            operator.lshift,
152            operator.rshift,
153            operator.and_,
154            operator.or_,
155            operator.xor,
156            operator.ipow,
157            operator.imul,
158            operator.imatmul,
159            operator.ifloordiv,
160            operator.itruediv,
161            operator.imod,
162            operator.iadd,
163            operator.isub,
164            operator.ilshift,
165            operator.irshift,
166            operator.iand,
167            operator.ixor,
168            operator.ior,
169            operator.index,
170        }
171        from .tensor import supported_comparison_ops
172
173        fns.update(supported_comparison_ops.values())
174        fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt)))
175        return fns
176
177    def can_constant_fold_through(self):
178        return self.fn in self._constant_fold_functions()
179
180    @staticmethod
181    @functools.lru_cache(None)
182    def _fx_graph_functions():
183        fns = {
184            operator.abs,
185            operator.pos,
186            operator.neg,
187            operator.not_,
188            operator.invert,
189            operator.pow,
190            operator.mul,
191            operator.matmul,
192            operator.floordiv,
193            operator.truediv,
194            operator.mod,
195            operator.add,
196            operator.lt,
197            operator.gt,
198            operator.ge,
199            operator.le,
200            operator.ne,
201            operator.eq,
202            operator.sub,
203            operator.getitem,
204            operator.length_hint,
205            operator.lshift,
206            operator.rshift,
207            operator.and_,
208            operator.or_,
209            operator.xor,
210            operator.ipow,
211            operator.imul,
212            operator.imatmul,
213            operator.ifloordiv,
214            operator.itruediv,
215            operator.imod,
216            operator.iadd,
217            operator.isub,
218            operator.ilshift,
219            operator.irshift,
220            operator.iand,
221            operator.ixor,
222            operator.ior,
223        }
224        return fns
225
226    @staticmethod
227    @functools.lru_cache(None)
228    def _binops():
229        # function -> ([forward name, reverse name, in-place name], in-place op)
230        fns = {
231            operator.add: (["__add__", "__radd__", "__iadd__"], operator.iadd),
232            operator.sub: (["__sub__", "__rsub__", "__isub__"], operator.isub),
233            operator.mul: (["__mul__", "__rmul__", "__imul__"], operator.imul),
234            operator.truediv: (
235                ["__truediv__", "__rtruediv__", "__itruediv__"],
236                operator.itruediv,
237            ),
238            operator.floordiv: (
239                ["__floordiv__", "__rfloordiv__", "__ifloordiv__"],
240                operator.ifloordiv,
241            ),
242            operator.mod: (["__mod__", "__rmod__", "__imod__"], operator.imod),
243            pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow),
244            operator.pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow),
245            operator.lshift: (
246                ["__lshift__", "__rlshift__", "__ilshift__"],
247                operator.ilshift,
248            ),
249            operator.rshift: (
250                ["__rshift__", "__rrshift__", "__irshift__"],
251                operator.irshift,
252            ),
253            # NB: The follow binary operators are not supported for now, since the
254            # corresponding magic methods aren't defined on SymInt / SymFloat:
255            # operator.matmul
256            # divmod
257            # operator.and_
258            # operator.or_
259            # operator.xor
260        }
261        return fns
262
263    @staticmethod
264    @functools.lru_cache(None)
265    def _binop_handlers():
266        # Multiple dispatch mechanism defining custom binop behavior for certain type
267        # combinations. Handlers are attempted in order, and will be used if the type checks
268        # match. They are expected to have the signature:
269        # fn(tx, arg0: VariableTracker, arg1: VariableTracker) -> VariableTracker
270        from .dicts import DictKeys, SetVariable
271        from .functions import BaseUserFunctionVariable, UserFunctionVariable
272        from .nn_module import NNModuleVariable
273        from .tensor import supported_const_comparison_ops
274        from .torch import BaseTorchVariable
275        from .user_defined import (
276            UserDefinedClassVariable,
277            UserDefinedObjectVariable,
278            UserDefinedVariable,
279        )
280
281        # Override table contains: op_fn -> [list of handlers]
282        op_handlers = {}
283        for (
284            op,
285            (magic_method_names, in_place_op),
286        ) in BuiltinVariable._binops().items():
287            op_handlers[op] = []
288            op_handlers[in_place_op] = []
289
290            forward_name, reverse_name, inplace_name = magic_method_names
291
292            # User-defined args (highest precedence)
293            def user_defined_handler(
294                tx,
295                a,
296                b,
297                *,
298                forward_name=forward_name,
299                reverse_name=reverse_name,
300            ):
301                # Manually handle reversing logic if needed (e.g. call __radd__)
302
303                # TODO: If we expand this to handle tensor args, we need to manually
304                # handle cases like this:
305                #
306                # class A(int):
307                #     def __radd__(self, other):
308                #         print("woof")
309                # torch.randn(3) + A(3)
310                #
311                # In this example, A.__radd__() is not called -> nothing is printed, because
312                # Tensor.__add__ only does a subtype test against int, ignoring the subclass.
313                # To be fully correct, we should not call A.__radd__() here, and there may be
314                # other cases to reason about and add exceptions for.
315                if isinstance(a, UserDefinedVariable):
316                    return a.call_method(tx, forward_name, [b], {})
317                else:
318                    return b.call_method(tx, reverse_name, [a], {})
319
320            op_handlers[op].append(
321                ((UserDefinedVariable, VariableTracker), user_defined_handler)
322            )
323            op_handlers[op].append(
324                ((VariableTracker, UserDefinedVariable), user_defined_handler)
325            )
326
327            def user_defined_inplace_handler(
328                tx: "InstructionTranslator", a, b, *, forward_name=inplace_name
329            ):
330                return a.call_method(tx, forward_name, [b], {})
331
332            op_handlers[in_place_op].append(
333                ((UserDefinedVariable, VariableTracker), user_defined_inplace_handler)
334            )
335            op_handlers[in_place_op].append(
336                ((VariableTracker, UserDefinedVariable), user_defined_inplace_handler)
337            )
338
339            # Dynamic shape args
340            def dynamic_handler(tx: "InstructionTranslator", a, b, *, fn=op):
341                from .builder import wrap_fx_proxy
342
343                return wrap_fx_proxy(
344                    tx,
345                    tx.output.create_proxy(
346                        "call_function", fn, *proxy_args_kwargs([a, b], {})
347                    ),
348                )
349
350            op_handlers[op].append(
351                ((SymNodeVariable, VariableTracker), dynamic_handler)
352            )
353            op_handlers[op].append(
354                ((VariableTracker, SymNodeVariable), dynamic_handler)
355            )
356
357            # NB: Prefer out-of-place op when calling in-place op to generate valid graph
358            op_handlers[in_place_op].append(
359                ((SymNodeVariable, VariableTracker), dynamic_handler)
360            )
361            op_handlers[in_place_op].append(
362                ((VariableTracker, SymNodeVariable), dynamic_handler)
363            )
364
365        # Special cases - lower precedence but still prefer these over constant folding
366
367        # List-like addition (e.g. [1, 2] + [3, 4])
368        def tuple_add_handler(tx: "InstructionTranslator", a, b):
369            return TupleVariable([*a.items, *b.unpack_var_sequence(tx)])
370
371        def size_add_handler(tx: "InstructionTranslator", a, b):
372            return SizeVariable([*a.items, *b.unpack_var_sequence(tx)])
373
374        list_like_addition_handlers = [
375            # NB: Prefer the tuple-specific logic over base logic because of
376            # some SizeVariable weirdness. Specifically, the tuple-specific logic
377            # drops the subclass type (e.g. SizeVariable) and returns TupleVariables.
378            (
379                (SizeVariable, SizeVariable),
380                size_add_handler,
381            ),
382            (
383                (TupleVariable, TupleVariable),
384                tuple_add_handler,
385            ),
386            (
387                (TupleVariable, ConstantVariable),
388                tuple_add_handler,
389            ),
390            (
391                (ConstantVariable, TupleVariable),
392                lambda tx, a, b: TupleVariable(
393                    [*a.unpack_var_sequence(tx), *b.items],
394                ),
395            ),
396            (
397                (
398                    ListVariable,
399                    (BaseListVariable, ConstantVariable, ListIteratorVariable),
400                ),
401                lambda tx, a, b: ListVariable(
402                    [*a.items, *b.unpack_var_sequence(tx)], mutable_local=MutableLocal()
403                ),
404            ),
405            (
406                (BaseListVariable, BaseListVariable),
407                lambda tx, a, b: type(a)([*a.items, *b.items]),
408            ),
409        ]
410        op_handlers[operator.add].extend(list_like_addition_handlers)
411
412        def list_iadd_handler(tx: "InstructionTranslator", a, b):
413            if not a.mutable_local or not b.has_unpack_var_sequence(tx):
414                # Handler doesn't apply
415                return None
416
417            seq = b.unpack_var_sequence(tx)
418            tx.output.side_effects.mutation(a)
419            a.items.extend(seq)
420            return a
421
422        list_like_iadd_handlers = [
423            (
424                (ListVariable, VariableTracker),
425                list_iadd_handler,
426            ),
427            (
428                (TupleVariable, TupleVariable),
429                tuple_add_handler,
430            ),
431            (
432                (TupleVariable, ConstantVariable),
433                tuple_add_handler,
434            ),
435        ]
436        op_handlers[operator.iadd].extend(list_like_iadd_handlers)
437
438        # List-like expansion (e.g. [1, 2, 3] * 3)
439        def expand_list_like(tx: "InstructionTranslator", lst, const):
440            if isinstance(lst, ConstantVariable):
441                lst, const = const, lst
442            return lst.__class__(
443                items=lst.items * const.as_python_constant(),
444                mutable_local=MutableLocal(),
445            )
446
447        list_like_expansion_handlers = [
448            ((ListVariable, ConstantVariable), expand_list_like),
449            ((TupleVariable, ConstantVariable), expand_list_like),
450            ((ConstantVariable, ListVariable), expand_list_like),
451            ((ConstantVariable, TupleVariable), expand_list_like),
452        ]
453        op_handlers[operator.mul].extend(list_like_expansion_handlers)
454
455        size_or_tuple = (SizeVariable, TupleVariable)
456        has_set_items = (SetVariable, DictKeys)
457
458        def create_cmp_op_handlers(op):
459            def compare_by_value(tx: "InstructionTranslator", a, b):
460                return ConstantVariable(op(a.value, b.value))
461
462            result = [((ConstantVariable, ConstantVariable), compare_by_value)]
463
464            if op in supported_const_comparison_ops.values():
465                # Tensor is None, List is not None, etc
466                none_result = op(object(), None)
467                if op.__name__.startswith("is_"):
468
469                    def never(tx: "InstructionTranslator", a, b):
470                        return ConstantVariable(none_result)
471
472                    obj_op_none = never
473                    none_op_obj = never
474                else:
475
476                    def obj_op_none(
477                        tx: "InstructionTranslator", a, b: ConstantVariable
478                    ):
479                        if b.value is None or b.value is True or b.value is False:
480                            return ConstantVariable(none_result)
481
482                    def none_op_obj(
483                        tx: "InstructionTranslator", a: ConstantVariable, b
484                    ):
485                        if a.value is None or a.value is True or a.value is False:
486                            return ConstantVariable(none_result)
487
488                types_that_are_never_none = (
489                    TensorVariable,
490                    SymNodeVariable,
491                    NNModuleVariable,
492                    BaseListVariable,
493                    UserDefinedVariable,
494                    BaseUserFunctionVariable,
495                    ConstDictVariable,
496                    BaseTorchVariable,
497                )
498                result.extend(
499                    [
500                        (
501                            (types_that_are_never_none, ConstantVariable),
502                            obj_op_none,
503                        ),
504                        (
505                            (ConstantVariable, types_that_are_never_none),
506                            none_op_obj,
507                        ),
508                    ]
509                )
510
511            def list_compare_nocheck(tx: "InstructionTranslator", left, right):
512                return BaseListVariable.list_compare(tx, op, left, right)
513
514            def list_compare_check(tx: "InstructionTranslator", left, right):
515                if type(left) is not type(
516                    right
517                ):  # Mismatch in BaseListVariable subclasses
518                    unimplemented(f"{op.__name__}({left}, {right})")
519                return BaseListVariable.list_compare(tx, op, left, right)
520
521            def compare_set_items(tx: "InstructionTranslator", left, right):
522                return ConstantVariable(op(left.set_items, right.set_items))
523
524            def compare_via_method(tx: "InstructionTranslator", left, right):
525                return left.call_method(tx, f"__{op.__name__}__", [right], {})
526
527            if op.__name__.startswith("is_"):
528                compare_user_defined = compare_by_value
529            else:
530                compare_user_defined = compare_via_method
531
532            op_var = BuiltinVariable(op)
533            result.extend(
534                [
535                    (
536                        (
537                            (UserFunctionVariable, BuiltinVariable),
538                            (UserFunctionVariable, BuiltinVariable),
539                        ),
540                        lambda tx, a, b: ConstantVariable(op(a.fn, b.fn)),
541                    ),
542                    (
543                        (
544                            NNModuleVariable,
545                            NNModuleVariable,
546                        ),
547                        lambda tx, a, b: ConstantVariable(
548                            op(
549                                tx.output.get_submodule(a.module_key),
550                                tx.output.get_submodule(b.module_key),
551                            )
552                        ),
553                    ),
554                    ((size_or_tuple, size_or_tuple), list_compare_nocheck),
555                    (
556                        (variables.BaseListVariable, variables.BaseListVariable),
557                        list_compare_check,
558                    ),
559                    ((has_set_items, has_set_items), compare_set_items),
560                    (
561                        (UserDefinedObjectVariable, UserDefinedObjectVariable),
562                        compare_user_defined,
563                    ),
564                    (
565                        (UserDefinedClassVariable, UserDefinedClassVariable),
566                        compare_user_defined,
567                    ),
568                    (
569                        (
570                            (StreamVariable, EventVariable, ConstantVariable),
571                            (StreamVariable, EventVariable, ConstantVariable),
572                        ),
573                        compare_by_value,
574                    ),
575                    (
576                        (TensorVariable, VariableTracker),
577                        op_var._comparison_with_tensor,
578                    ),
579                    (
580                        (VariableTracker, TensorVariable),
581                        op_var._comparison_with_tensor,
582                    ),
583                    (
584                        (SymNodeVariable, VariableTracker),
585                        op_var._comparison_with_symnode,
586                    ),
587                    (
588                        (VariableTracker, SymNodeVariable),
589                        op_var._comparison_with_symnode,
590                    ),
591                ]
592            )
593
594            if op.__name__.startswith("is_"):
595
596                def handle_is(tx: "InstructionTranslator", left, right):
597                    # If the two objects are of different type, we can safely return False
598                    # and True for `is` and `is not`, respectively
599                    if type(left) is not type(right):
600                        return ConstantVariable.create(op.__name__ != "is_")
601
602                result.append(((VariableTracker, VariableTracker), handle_is))
603
604            return result
605
606        for op in supported_comparison_ops.values():
607            assert callable(op)
608            assert op not in op_handlers
609            op_handlers[op] = create_cmp_op_handlers(op)
610
611        return op_handlers
612
613    @staticmethod
614    def _find_binop_handler(op, a_type, b_type):
615        handlers = BuiltinVariable._binop_handlers().get(op)
616        if handlers is None:
617            return None
618
619        matches = []
620        for (type1, type2), handler in handlers:
621            if issubclass(a_type, type1) and issubclass(b_type, type2):
622                matches.append(handler)
623        return matches
624
625    def can_insert_in_graph(self):
626        return self.fn in self._fx_graph_functions()
627
628    def __init__(self, fn, **kwargs) -> None:
629        super().__init__(**kwargs)
630        self.fn = fn
631
632    def __str__(self) -> str:
633        if self.fn is None:
634            name = "None"
635        else:
636            name = self.fn.__name__
637
638        return f"{self.__class__.__name__}({name})"
639
640    def as_python_constant(self):
641        return self.fn
642
643    def as_proxy(self):
644        DTYPE = {
645            bool: torch.bool,
646            int: torch.int64,
647            float: torch.float64,
648        }
649        if self.fn in DTYPE:
650            return DTYPE[self.fn]
651        return super().as_proxy()
652
653    def reconstruct(self, codegen):
654        name = self.fn.__name__
655        assert self.fn.__module__ == "builtins"
656        assert name not in codegen.tx.f_globals, "shadowed global"
657        codegen.append_output(codegen.create_load_global(name, False, add=True))
658
659    def constant_args(self, *args, **kwargs):
660        return check_constant_args(args, kwargs)
661
662    def tensor_args(self, *args):
663        any_tensor = False
664        for arg in args:
665            if isinstance(arg, variables.GetAttrVariable):
666                return False
667            any_tensor = any_tensor or isinstance(arg, variables.TensorVariable)
668        return any_tensor
669
670    def tensor_args_type(self, arg_types):
671        any_tensor = False
672        for arg_type in arg_types:
673            if issubclass(arg_type, variables.GetAttrVariable):
674                return False
675            any_tensor = any_tensor or issubclass(arg_type, variables.TensorVariable)
676        return any_tensor
677
678    def python_and_tensor_constant_only(self, *args, **kwargs):
679        tensor_args = []
680        non_tensor_args = []
681        for i in itertools.chain(args, kwargs.values()):
682            if isinstance(i, variables.TensorVariable):
683                tensor_args.append(i)
684            else:
685                non_tensor_args.append(i)
686        return all(
687            is_constant_source(t.source) if t.source is not None else False
688            for t in tensor_args
689        ) and self.constant_args(*non_tensor_args)
690
691    @staticmethod
692    def unwrap_unspec_args_kwargs(args, kwargs):
693        return [x.as_python_constant() for x in args], {
694            k: v.as_python_constant() for k, v in kwargs.items()
695        }
696
697    def has_constant_handler(self, args, kwargs):
698        return self.can_constant_fold_through() and check_unspec_or_constant_args(
699            args, kwargs
700        )
701
702    @staticmethod
703    def _make_handler(fn, arg_types: List[type], has_kwargs: bool):
704        from .builder import SourcelessBuilder
705        from .lazy import LazyVariableTracker
706
707        obj = BuiltinVariable(fn)
708        handlers = []
709
710        if any(issubclass(t, LazyVariableTracker) for t in arg_types):
711            return lambda tx, args, kwargs: obj.call_function(
712                tx, [v.realize() for v in args], kwargs
713            )
714
715        if inspect.isclass(fn) and issubclass(fn, Exception):
716
717            def create_exception_class_object(
718                tx: "InstructionTranslator", args, kwargs
719            ):
720                if fn is AssertionError and not all(
721                    isinstance(x, variables.ConstantVariable)
722                    and isinstance(x.value, str)
723                    for x in args
724                ):
725                    unimplemented("assert with non-string message")
726
727                return variables.ExceptionVariable(fn, args, **kwargs)
728
729            return create_exception_class_object
730
731        if obj.can_insert_in_graph() and not (
732            fn is operator.getitem
733            and not issubclass(arg_types[0], variables.TensorVariable)
734        ):
735            if obj.tensor_args_type(arg_types):
736                return obj._handle_insert_op_in_graph
737            elif has_kwargs:
738                # need runtime check for kwargs
739                handlers.append(obj._handle_insert_op_in_graph)
740
741        # Handle binary ops (e.g. __add__ / __radd__, __iadd__, etc.)
742        # NB: Tensor args are handled above and not here
743        if len(arg_types) == 2 and not has_kwargs:
744            # Try to find a handler for the arg types; otherwise, fall through to constant handler
745            binop_handlers = BuiltinVariable._find_binop_handler(fn, *arg_types)
746            if not binop_handlers:
747                pass
748            elif len(binop_handlers) == 1:
749                (binop_handler,) = binop_handlers
750                handlers.append(lambda tx, args, _: binop_handler(tx, *args))
751            else:
752
753                def call_binop_handlers(tx: "InstructionTranslator", args, _):
754                    for fn in binop_handlers:
755                        rv = fn(tx, *args)
756                        if rv:
757                            return rv
758
759                handlers.append(call_binop_handlers)
760
761        self_handler = getattr(obj, f"call_{fn.__name__}", None)
762        if self_handler:
763
764            def call_self_handler(tx: "InstructionTranslator", args, kwargs):
765                try:
766                    result = self_handler(tx, *args, **kwargs)
767                    if result is not None:
768                        return result
769                except TypeError:
770                    # Check if binding is bad. inspect signature bind is expensive.
771                    # So check only when handler call fails.
772                    try:
773                        inspect.signature(self_handler).bind(tx, *args, **kwargs)
774                    except TypeError as e:
775                        has_constant_handler = obj.has_constant_handler(args, kwargs)
776                        if not has_constant_handler:
777                            log.warning(
778                                "incorrect arg count %s %s and no constant handler",
779                                self_handler,
780                                e,
781                            )
782                            unimplemented(
783                                f"invalid handler args {self_handler} {args} {kwargs}"
784                            )
785                    else:
786                        raise
787                except Unsupported as exc:
788                    has_constant_handler = obj.has_constant_handler(args, kwargs)
789                    if not has_constant_handler:
790                        raise
791                    # Actually, we will handle this just fine
792                    exc.remove_from_stats()
793
794            handlers.append(call_self_handler)
795
796        if obj.can_constant_fold_through():
797            builder = SourcelessBuilder.create
798
799            if (
800                all(issubclass(x, ConstantVariable) for x in arg_types)
801                and not has_kwargs
802            ):
803
804                def constant_fold_handler(tx: "InstructionTranslator", args, kwargs):
805                    # fast path
806                    try:
807                        res = fn(
808                            *[x.as_python_constant() for x in args],
809                        )
810                    except Exception as exc:
811                        unimplemented(f"constant fold exception: {repr(exc)}")
812                    return builder(tx, res)
813
814            else:
815
816                def constant_fold_handler(tx: "InstructionTranslator", args, kwargs):
817                    # path with a runtime check
818                    if check_unspec_or_constant_args(args, kwargs):
819                        try:
820                            res = fn(
821                                *[x.as_python_constant() for x in args],
822                                **{
823                                    k: v.as_python_constant() for k, v in kwargs.items()
824                                },
825                            )
826                        except Exception as exc:
827                            unimplemented(f"constant fold exception: {repr(exc)}")
828                        return builder(tx, res)
829
830            handlers.append(constant_fold_handler)
831
832        error_msg = f"builtin: {fn.__name__} {arg_types} {has_kwargs}"
833        if len(handlers) == 0:
834            return lambda *args: unimplemented(error_msg)
835        elif len(handlers) == 1:
836            (handler,) = handlers
837
838            def builtin_dispatch(tx: "InstructionTranslator", args, kwargs):
839                rv = handler(tx, args, kwargs)
840                if rv:
841                    return rv
842                unimplemented(error_msg)
843
844        else:
845
846            def builtin_dispatch(tx: "InstructionTranslator", args, kwargs):
847                for fn in handlers:
848                    rv = fn(tx, args, kwargs)
849                    if rv:
850                        return rv
851                unimplemented(error_msg)
852
853        return builtin_dispatch
854
855    def _handle_insert_op_in_graph(self, tx: "InstructionTranslator", args, kwargs):
856        from .builder import wrap_fx_proxy, wrap_fx_proxy_cls
857
858        if kwargs and not self.tensor_args(*args, *kwargs.values()):
859            return
860
861        fn = self.fn
862        try:
863            # Constant fold for constant tensor and python constants
864            if self.python_and_tensor_constant_only(*args, **kwargs):
865                from ..bytecode_transformation import unique_id
866                from .functions import invoke_and_store_as_constant
867
868                return invoke_and_store_as_constant(
869                    tx, fn, unique_id(fn.__name__), args, kwargs
870                )
871
872            if fn in IN_PLACE_DESUGARING_MAP and isinstance(
873                args[0], variables.ConstantVariable
874            ):
875                # In-place operators like += usually mustate tensor
876                # values, but in the edge case of immutable values they
877                # re-bind the variable.
878                #
879                # The easiest way to keep the graph consistent in this
880                # scenario is to de-sugar eagerly.
881                fn, args = IN_PLACE_DESUGARING_MAP[fn], [args[0], args[1]]
882
883            if fn is operator.getitem and isinstance(args[1], SymNodeVariable):
884                # Standard indexing will force specialization due to
885                # __index__.  Rewrite as a regular torch op which will
886                # trace fine
887                fn, args = torch.select, [
888                    args[0],
889                    variables.ConstantVariable.create(0),
890                    args[1],
891                ]
892
893            # Interaction between ndarray and tensors:
894            #   We prefer the tensor op whenever there are tensors involved
895            if check_numpy_ndarray_args(args, kwargs) and not any(
896                type(arg) == variables.TensorVariable for arg in args
897            ):
898                proxy = tx.output.create_proxy(
899                    "call_function",
900                    numpy_operator_wrapper(fn),
901                    *proxy_args_kwargs(args, kwargs),
902                )
903
904                return wrap_fx_proxy_cls(variables.NumpyNdarrayVariable, tx, proxy)
905
906            proxy = tx.output.create_proxy(
907                "call_function",
908                fn,
909                *proxy_args_kwargs(args, kwargs),
910            )
911            if any(isinstance(arg, FakeItemVariable) for arg in args):
912                return wrap_fx_proxy_cls(
913                    FakeItemVariable,
914                    tx,
915                    proxy,
916                )
917            elif check_unspec_python_args(args, kwargs):
918                _args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs)
919                raw_value = fn(*_args, **_kwargs)
920
921                need_unwrap = any(
922                    x.need_unwrap
923                    for x in itertools.chain(args, kwargs.values())
924                    if isinstance(x, variables.UnspecializedPythonVariable)
925                )
926
927                return wrap_fx_proxy_cls(
928                    UnspecializedPythonVariable,
929                    tx,
930                    proxy,
931                    raw_value=raw_value,
932                    need_unwrap=need_unwrap,
933                )
934            elif all(isinstance(x, SymNodeVariable) for x in args):
935                return SymNodeVariable.create(tx, proxy, None)
936            else:
937                # Work around for vision_maskrcnn due to precision difference
938                # specialize the dividend when float divide by tensor
939                if fn is operator.truediv and isinstance(
940                    args[0], variables.UnspecializedPythonVariable
941                ):
942                    args[0] = args[0].convert_to_constant(tx)
943                return wrap_fx_proxy(tx, proxy)
944
945        except NotImplementedError:
946            unimplemented(f"partial tensor op: {self} {args} {kwargs}")
947
948    call_function_handler_cache = {}
949
950    def call_function(
951        self,
952        tx: "InstructionTranslator",
953        args: "List[VariableTracker]",
954        kwargs: "Dict[str, VariableTracker]",
955    ) -> "VariableTracker":
956        if kwargs:
957            kwargs = {k: v.realize() for k, v in kwargs.items()}
958            key = (self.fn, *(type(x) for x in args), True)
959        else:
960            key = (self.fn, *(type(x) for x in args))
961
962        handler = self.call_function_handler_cache.get(key)
963        if not handler:
964            self.call_function_handler_cache[key] = handler = self._make_handler(
965                self.fn, [type(x) for x in args], bool(kwargs)
966            )
967        return handler(tx, args, kwargs)
968
969    def call_method(
970        self,
971        tx,
972        name,
973        args: "List[VariableTracker]",
974        kwargs: "Dict[str, VariableTracker]",
975    ) -> "VariableTracker":
976        if self.fn is object and name == "__setattr__":
977            assert len(args) == 3
978            assert len(kwargs) == 0
979            obj, name_var, val = args
980            obj = obj.realize()
981            if (
982                isinstance(obj, UserDefinedObjectVariable)
983                and tx.output.side_effects.is_attribute_mutation(obj)
984                and name_var.is_python_constant()
985            ):
986                return obj.method_setattr_standard(tx, name_var, val)
987        if self.fn is object and name == "__new__":
988            assert len(args) == 1
989            assert len(kwargs) == 0
990            return tx.output.side_effects.track_object_new_from_user_defined_class(
991                args[0]
992            )
993        if self.fn is dict and name == "fromkeys":
994            return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)
995        return super().call_method(tx, name, args, kwargs)
996
997    def _call_int_float(self, tx: "InstructionTranslator", arg):
998        # Handle cases like int(torch.seed())
999        # Also handle sym_float to sym_int cases
1000        if isinstance(arg, (SymNodeVariable, variables.TensorVariable)):
1001            if isinstance(arg, variables.TensorVariable):
1002                item = arg.call_method(tx, "item", [], {})
1003            else:
1004                item = arg
1005            fn_ = sym_int if self.fn is int else sym_float
1006            from torch._dynamo.variables.builder import wrap_fx_proxy
1007
1008            return wrap_fx_proxy(
1009                tx=tx,
1010                proxy=tx.output.create_proxy(
1011                    "call_function",
1012                    fn_,
1013                    (item.as_proxy(),),
1014                    {},
1015                ),
1016            )
1017
1018    call_int = _call_int_float
1019    call_float = _call_int_float
1020
1021    def call_str(self, tx: "InstructionTranslator", arg):
1022        # Handle `str` on a user defined function or object
1023        if isinstance(arg, (variables.UserFunctionVariable)):
1024            return variables.ConstantVariable.create(value=str(arg.fn))
1025        elif isinstance(arg, (variables.UserDefinedObjectVariable)):
1026            # Check if object has __str__ method
1027            if hasattr(arg.value, "__str__"):
1028                str_method = arg.value.__str__
1029            elif hasattr(arg.value, "__repr__"):
1030                # account for __repr__ functions when __str__ is absent
1031                str_method = arg.value.__repr__
1032            else:
1033                unimplemented("user defined object has no __str__ or __repr__ method")
1034
1035            if type(arg.value).__str__ is object.__str__:
1036                # Rely on the object str method
1037                try:
1038                    return variables.ConstantVariable.create(value=str_method())
1039                except AttributeError:
1040                    # Graph break
1041                    return
1042            elif is_wrapper_or_member_descriptor(str_method):
1043                unimplemented(f"{type(arg.value)} has a C/C++ based str method")
1044            else:
1045                # Overrides for custom str method
1046                # Pass method as function to call tx.inline_user_function_return
1047                bound_method = str_method.__func__
1048
1049                try:
1050                    # Only supports certain function types
1051                    user_func_variable = variables.UserFunctionVariable(bound_method)
1052                except AssertionError as e:
1053                    # Won't be able to do inline the str method, return to avoid graph break
1054                    log.warning("Failed to create UserFunctionVariable: %s", e)
1055                    return
1056
1057                # Inline the user function
1058                return tx.inline_user_function_return(user_func_variable, [arg], {})
1059
1060    def _call_min_max(self, tx: "InstructionTranslator", *args):
1061        if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
1062            items = args[0].force_unpack_var_sequence(tx)
1063            return self._call_min_max_seq(tx, items)
1064        elif len(args) == 2:
1065            return self._call_min_max_binary(tx, args[0], args[1])
1066        elif len(args) > 2:
1067            return self._call_min_max_seq(tx, args)
1068
1069    def _call_min_max_seq(self, tx: "InstructionTranslator", items):
1070        assert len(items) > 0
1071        if len(items) == 1:
1072            return items[0]
1073
1074        return functools.reduce(functools.partial(self._call_min_max_binary, tx), items)
1075
1076    def _call_min_max_binary(self, tx: "InstructionTranslator", a, b):
1077        if a is None or b is None:
1078            # a or b could be none if we reduce and _call_min_max_binary failed
1079            # to return something
1080            return
1081        if self.tensor_args(a, b):
1082            if not isinstance(a, variables.TensorVariable):
1083                a, b = b, a
1084            assert isinstance(a, variables.TensorVariable)
1085
1086            # result of an item call is a scalar convert to a tensor
1087            if isinstance(a, FakeItemVariable):
1088                a = variables.TorchInGraphFunctionVariable(torch.tensor).call_function(
1089                    tx, [a], {}
1090                )
1091
1092            # Dynamic input does not get resolved, rather, gets stored as call_function
1093            if isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable):
1094                from .builder import wrap_fx_proxy_cls
1095
1096                return wrap_fx_proxy_cls(
1097                    type(a),
1098                    tx=tx,
1099                    proxy=tx.output.create_proxy(
1100                        "call_function",
1101                        self.fn,
1102                        *proxy_args_kwargs([a, b], {}),
1103                    ),
1104                )
1105
1106            # convert min/max to torch ops
1107            if b.is_python_constant():
1108                if isinstance(a, variables.NumpyNdarrayVariable):
1109                    import numpy as np
1110
1111                    fn = variables.NumpyVariable(np.clip)
1112                else:
1113                    fn = variables.TorchInGraphFunctionVariable(torch.clamp)
1114                kwargs = {"min": b} if (self.fn is max) else {"max": b}
1115                result = fn.call_function(tx, [a], kwargs)
1116            else:
1117                if isinstance(a, variables.NumpyNdarrayVariable):
1118                    import numpy as np
1119
1120                    fn = {max: np.maximum, min: np.minimum}[self.fn]
1121                    fn = variables.NumpyVariable(fn)
1122                else:
1123                    fn = {max: torch.maximum, min: torch.minimum}[self.fn]
1124                    fn = variables.TorchInGraphFunctionVariable(fn)
1125                result = fn.call_function(tx, [a, b], {})
1126
1127            # return unspec if both a, b are unspec or const
1128            if all(
1129                isinstance(
1130                    i,
1131                    (
1132                        variables.UnspecializedPythonVariable,
1133                        variables.ConstantVariable,
1134                    ),
1135                )
1136                for i in [a, b]
1137            ):
1138                if any(isinstance(val, FakeItemVariable) for val in [a, b]):
1139                    return variables.FakeItemVariable.from_tensor_variable(result)
1140
1141                if b.is_python_constant():
1142                    raw_b = b.as_python_constant()
1143                else:
1144                    raw_b = b.raw_value
1145                if self.fn is max:
1146                    raw_res = max(a.raw_value, raw_b)
1147                else:
1148                    raw_res = min(a.raw_value, raw_b)
1149
1150                need_unwrap = any(
1151                    x.need_unwrap
1152                    for x in [a, b]
1153                    if isinstance(x, variables.UnspecializedPythonVariable)
1154                )
1155                return variables.UnspecializedPythonVariable.from_tensor_variable(
1156                    result, raw_res, need_unwrap
1157                )
1158            # otherwise return tensor
1159            else:
1160                return result
1161        elif isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable):
1162            fn = torch.sym_max if self.fn is max else torch.sym_min
1163            proxy = tx.output.create_proxy(
1164                "call_function", fn, *proxy_args_kwargs([a, b], {})
1165            )
1166            return SymNodeVariable.create(tx, proxy, None)
1167
1168    call_min = _call_min_max
1169    call_max = _call_min_max
1170
1171    def call_abs(self, tx: "InstructionTranslator", arg: "VariableTracker"):
1172        # Call arg.__abs__()
1173        abs_method = BuiltinVariable(getattr).call_function(
1174            tx, [arg, ConstantVariable.create("__abs__")], {}
1175        )
1176        return abs_method.call_function(tx, [], {})
1177
1178    def call_pos(self, tx: "InstructionTranslator", arg: "VariableTracker"):
1179        # Call arg.__pos__()
1180        pos_method = BuiltinVariable(getattr).call_function(
1181            tx, [arg, ConstantVariable.create("__pos__")], {}
1182        )
1183        return pos_method.call_function(tx, [], {})
1184
1185    def call_index(self, tx: "InstructionTranslator", arg: "VariableTracker"):
1186        if isinstance(arg, variables.TensorVariable):
1187            unimplemented("unsupported index(tensor)")
1188
1189        arg = guard_if_dyn(arg)
1190        constant_value = operator.index(arg)
1191        return variables.ConstantVariable.create(constant_value)
1192
1193    def call_round(self, tx: "InstructionTranslator", arg, *args, **kwargs):
1194        # Call arg.__round__()
1195        round_method = BuiltinVariable(getattr).call_function(
1196            tx, [arg, ConstantVariable.create("__round__")], {}
1197        )
1198        return round_method.call_function(tx, args, kwargs)
1199
1200    def call_range(self, tx: "InstructionTranslator", *args):
1201        if check_unspec_or_constant_args(args, {}):
1202            return variables.RangeVariable(args)
1203        elif self._dynamic_args(*args):
1204            args = [
1205                variables.ConstantVariable.create(guard_if_dyn(arg)) for arg in args
1206            ]
1207            return variables.RangeVariable(args)
1208        # None no-ops this handler and lets the driving function proceed
1209        return None
1210
1211    def _dynamic_args(self, *args, **kwargs):
1212        return any(isinstance(x, SymNodeVariable) for x in args) or any(
1213            isinstance(x, SymNodeVariable) for x in kwargs.values()
1214        )
1215
1216    def call_slice(self, tx: "InstructionTranslator", *args):
1217        return variables.SliceVariable(args)
1218
1219    def _dyn_proxy(self, tx: "InstructionTranslator", *args, **kwargs):
1220        from .builder import wrap_fx_proxy
1221
1222        return wrap_fx_proxy(
1223            tx,
1224            tx.output.create_proxy(
1225                "call_function", self.fn, *proxy_args_kwargs(args, kwargs)
1226            ),
1227        )
1228
1229    # NOTE must handle IteratorVariable separately!
1230    def _call_iter_tuple_list(
1231        self, tx: "InstructionTranslator", obj=None, *args, **kwargs
1232    ):
1233        assert not isinstance(obj, variables.IteratorVariable)
1234
1235        if self._dynamic_args(*args, **kwargs):
1236            return self._dyn_proxy(tx, *args, **kwargs)
1237
1238        cls = variables.BaseListVariable.cls_for(self.fn)
1239        if obj is None:
1240            return cls(
1241                [],
1242                mutable_local=MutableLocal(),
1243            )
1244        elif obj.has_unpack_var_sequence(tx):
1245            if obj.source and not is_constant_source(obj.source):
1246                if isinstance(obj, TupleIteratorVariable):
1247                    install_guard(
1248                        obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN)
1249                    )
1250                else:
1251                    if (
1252                        getattr(obj, "source", False)
1253                        and isinstance(obj, ConstDictVariable)
1254                        and not istype(obj, SetVariable)
1255                    ):
1256                        tx.output.guard_on_key_order.add(obj.source.name())
1257
1258                    install_guard(obj.source.make_guard(GuardBuilder.SEQUENCE_LENGTH))
1259
1260            return cls(
1261                list(obj.unpack_var_sequence(tx)),
1262                mutable_local=MutableLocal(),
1263            )
1264
1265    def _call_tuple_list(self, tx, obj=None, *args, **kwargs):
1266        if isinstance(obj, variables.IteratorVariable):
1267            cls = variables.BaseListVariable.cls_for(self.fn)
1268            return cls(
1269                list(obj.force_unpack_var_sequence(tx)),
1270                mutable_local=MutableLocal(),
1271            )
1272        else:
1273            return self._call_iter_tuple_list(tx, obj, *args, **kwargs)
1274
1275    def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs):
1276        if isinstance(obj, variables.IteratorVariable):
1277            ret = obj
1278        else:
1279            # Handle the case where we are iterating over a tuple, list or iterator
1280            ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs)
1281
1282        if ret is None:
1283            # If the object doesn't implement a __iter__ method, it will be an error in eager mode when calling iter on it anyway.
1284            # If the object implements a __iter__ method, inlining effectively forwards the call to another iter call
1285            # (e.g. when __iter__ just returns iter(self.list)) or return a user-defined iterator.
1286            return obj.call_method(tx, "__iter__", args, kwargs)
1287        return ret
1288
1289    call_tuple = _call_tuple_list
1290    call_list = _call_tuple_list
1291
1292    def call_callable(self, tx: "InstructionTranslator", arg):
1293        from .functions import BaseUserFunctionVariable
1294        from .nn_module import NNModuleVariable
1295
1296        if isinstance(
1297            arg,
1298            (
1299                variables.UserDefinedClassVariable,
1300                BaseUserFunctionVariable,
1301                NNModuleVariable,
1302            ),
1303        ):
1304            return variables.ConstantVariable.create(True)
1305        elif isinstance(arg, UserDefinedVariable):
1306            return variables.ConstantVariable.create(callable(arg.value))
1307        elif isinstance(
1308            arg,
1309            (
1310                ConstantVariable,
1311                SymNodeVariable,
1312                TensorVariable,
1313                ListVariable,
1314                TupleVariable,
1315                ListIteratorVariable,
1316            ),
1317        ):
1318            return variables.ConstantVariable.create(False)
1319
1320    def call_cast(self, _, *args, **kwargs):
1321        if len(args) == 2:
1322            return args[1]
1323
1324        unimplemented(f"unsupported args to builtin cast(): {args} {kwargs}")
1325
1326    def call_dict(self, tx: "InstructionTranslator", *args, **kwargs):
1327        return BuiltinVariable.call_custom_dict(tx, dict, *args, **kwargs)
1328
1329    @staticmethod
1330    def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs):
1331        from .builder import SourcelessBuilder
1332
1333        if not kwargs:
1334            if not args:
1335                args = ({},)
1336            assert len(args) == 1
1337            arg = args[0]
1338            if isinstance(arg, dict):
1339                return ConstDictVariable(arg, user_cls, mutable_local=MutableLocal())
1340            elif isinstance(arg, variables.ConstDictVariable):
1341                return arg.clone(user_cls=user_cls, mutable_local=MutableLocal())
1342            elif isinstance(
1343                arg,
1344                (
1345                    ListVariable,
1346                    TupleVariable,
1347                    ListIteratorVariable,
1348                    variables.IteratorVariable,
1349                ),
1350            ):
1351                items = dict(
1352                    x.force_unpack_var_sequence(tx)
1353                    for x in arg.force_unpack_var_sequence(tx)
1354                )
1355                return ConstDictVariable(items, user_cls, mutable_local=MutableLocal())
1356            elif isinstance(arg, variables.MutableMappingVariable):
1357                # This is applicable for user defined objects which seem like dict, but are not really dicts. For
1358                # example, TensorDict derives from MutableMapping. For such cases, we can directly inline the .items
1359                # method and create a new dict.
1360                if does_not_override_dict_iter_methods(type(arg.value)):
1361                    # These are implemeted in C, so we will have to manually construct the items
1362
1363                    if tx.output.side_effects.has_pending_mutation(arg):
1364                        unimplemented(
1365                            f"{user_cls.__name__}.items(): {args} {kwargs} - object is mutated"
1366                        )
1367
1368                    new_dict = dict(arg.value.items())
1369                    return SourcelessBuilder.create(tx, new_dict)
1370                else:
1371                    func_var = arg.var_getattr(tx, "items")
1372                    if not isinstance(func_var, variables.UserFunctionVariable):
1373                        unimplemented(f"{user_cls.__name__}.items(): {args} {kwargs}")
1374                    out = tx.inline_user_function_return(func_var, args, kwargs)
1375                    if isinstance(out, ConstDictVariable):
1376                        return out
1377                    return BuiltinVariable(user_cls).call_custom_dict(tx, user_cls, out)
1378        elif not args and kwargs:
1379            items = {ConstantVariable.create(k): v for k, v in kwargs.items()}
1380            return variables.ConstDictVariable(
1381                items, user_cls=user_cls, mutable_local=MutableLocal()
1382            )
1383        unimplemented(f"{user_cls.__name__}(): {args} {kwargs}")
1384
1385    @staticmethod
1386    def call_custom_dict_fromkeys(
1387        tx: "InstructionTranslator", user_cls, *args, **kwargs
1388    ):
1389        assert user_cls in {dict, OrderedDict, defaultdict}
1390        if kwargs:
1391            # Only `OrderedDict.fromkeys` accepts `value` passed by keyword
1392            assert user_cls is OrderedDict
1393            assert len(args) == 1 and len(kwargs) == 1 and "value" in kwargs
1394            args = (*args, kwargs.pop("value"))
1395        if len(args) == 0:
1396            raise UserError(TypeError, "fromkeys expected at least 1 argument, got 0")
1397        if len(args) == 1:
1398            args = (*args, ConstantVariable.create(None))
1399        assert len(args) == 2
1400        arg, value = args
1401        DictVariableType = (
1402            ConstDictVariable if user_cls is not defaultdict else DefaultDictVariable
1403        )
1404
1405        if isinstance(arg, dict):
1406            arg = [ConstantVariable.create(k) for k in arg.keys()]
1407            return DictVariableType(
1408                dict.fromkeys(arg, value), user_cls, mutable_local=MutableLocal()
1409            )
1410        elif arg.has_force_unpack_var_sequence(tx):
1411            keys = arg.force_unpack_var_sequence(tx)
1412            if all(is_hashable(v) for v in keys):
1413                return DictVariableType(
1414                    dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal()
1415                )
1416        unimplemented(f"{user_cls.__name__}.fromkeys(): {args} {kwargs}")
1417
1418    def call_set(self, tx: "InstructionTranslator", *args, **kwargs):
1419        # Can we merge this implementation and call_dict's one?
1420        assert not kwargs
1421        if not args:
1422            return SetVariable([], mutable_local=MutableLocal())
1423        assert len(args) == 1
1424        arg = args[0]
1425        if isinstance(arg, variables.SetVariable):
1426            return arg.clone(mutable_local=MutableLocal())
1427        elif arg.has_force_unpack_var_sequence(tx):
1428            items = arg.force_unpack_var_sequence(tx)
1429            return SetVariable(items, mutable_local=MutableLocal())
1430        elif isinstance(arg, variables.UserDefinedObjectVariable) and isinstance(
1431            arg.value, KeysView
1432        ):
1433            iter_fn = arg.var_getattr(tx, "__iter__")
1434            if isinstance(iter_fn, variables.UserMethodVariable):
1435                out = tx.inline_user_function_return(iter_fn, args, kwargs)
1436                if isinstance(out, SetVariable):
1437                    return out
1438                return BuiltinVariable(set).call_set(tx, out)
1439            else:
1440                unimplemented(f"set(): {args} {kwargs}")
1441        else:
1442            unimplemented(f"set(): {args} {kwargs}")
1443
1444    def call_frozenset(self, tx: "InstructionTranslator", *args, **kwargs):
1445        assert not kwargs
1446        if not args:
1447            return FrozensetVariable([])
1448        assert len(args) == 1
1449        arg = args[0]
1450        if isinstance(arg, variables.FrozensetVariable):
1451            return FrozensetVariable([x.vt for x in arg.set_items])
1452        elif arg.has_unpack_var_sequence(tx):
1453            items = arg.unpack_var_sequence(tx)
1454            return FrozensetVariable(items)
1455        else:
1456            unimplemented(f"frozenset(): {args} {kwargs}")
1457
1458    def call_zip(self, tx: "InstructionTranslator", *args, **kwargs):
1459        if kwargs:
1460            assert len(kwargs) == 1 and "strict" in kwargs
1461        strict = kwargs.pop("strict", False)
1462        args = [
1463            arg.unpack_var_sequence(tx) if arg.has_unpack_var_sequence(tx) else arg
1464            for arg in args
1465        ]
1466        return variables.ZipVariable(args, strict=strict, mutable_local=MutableLocal())
1467
1468    def call_len(self, tx: "InstructionTranslator", *args, **kwargs):
1469        return args[0].call_method(tx, "__len__", args[1:], kwargs)
1470
1471    def call_getitem(self, tx: "InstructionTranslator", *args, **kwargs):
1472        return args[0].call_method(tx, "__getitem__", args[1:], kwargs)
1473
1474    def call_isinstance(self, tx: "InstructionTranslator", arg, isinstance_type):
1475        try:
1476            arg_type = arg.python_type()
1477        except NotImplementedError:
1478            unimplemented(
1479                f"isinstance({arg}, {isinstance_type}): can't determine type of {arg}"
1480            )
1481
1482        isinstance_type = isinstance_type.as_python_constant()
1483
1484        if isinstance(arg, variables.TensorVariable) and arg.dtype is not None:
1485
1486            def _tensor_isinstance(tensor_var, tensor_type):
1487                def check_type(ty):
1488                    if ty not in tensortype_to_dtype:
1489                        example_val = arg.as_proxy().node.meta["example_value"]
1490                        if (
1491                            is_traceable_wrapper_subclass(example_val)
1492                            and ty is torch.nn.parameter.Parameter
1493                        ):
1494                            # N.B: we are calling isinstance directly on the example value.
1495                            # torch.nn.Parameter has a meta-class that overrides __isinstance__,
1496                            # the isinstance check here allows us to invoke that logic.
1497                            return isinstance(example_val, ty)
1498                        else:
1499                            return issubclass(arg.python_type(), ty)
1500
1501                    dtypes = tensortype_to_dtype[ty]
1502                    return arg.dtype in dtypes
1503
1504                if type(tensor_type) is tuple:
1505                    return any(check_type(ty) for ty in tensor_type)
1506                else:
1507                    return check_type(tensor_type)
1508
1509            return variables.ConstantVariable.create(
1510                _tensor_isinstance(arg, isinstance_type)
1511            )
1512        # UserDefinedObject with C extensions can have torch.Tensor attributes,
1513        # so break graph.
1514        if isinstance(arg, variables.UserDefinedObjectVariable) and isinstance(
1515            arg.value, types.MemberDescriptorType
1516        ):
1517            unimplemented(
1518                f"isinstance called on UserDefinedClass {arg} {isinstance_type}"
1519            )
1520        # handle __instancecheck__ defined in user class
1521        if (
1522            isinstance(arg, variables.UserDefinedObjectVariable)
1523            and "__instancecheck__" in isinstance_type.__class__.__dict__
1524        ):
1525            return variables.ConstantVariable.create(
1526                isinstance_type.__class__.__instancecheck__(isinstance_type, arg.value)
1527            )
1528
1529        try:
1530            val = issubclass(arg_type, isinstance_type)
1531        except TypeError:
1532            val = arg_type is isinstance_type
1533        return variables.ConstantVariable.create(val)
1534
1535    def call_issubclass(self, tx: "InstructionTranslator", left_ty, right_ty):
1536        """Checks if first arg is subclass of right arg"""
1537        try:
1538            left_ty_py = left_ty.as_python_constant()
1539            right_ty_py = right_ty.as_python_constant()
1540        except NotImplementedError:
1541            unimplemented(
1542                f"call_issubclass args not constant left_ty: {left_ty}, right_ty: {right_ty}"
1543            )
1544
1545        return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py))
1546
1547    def call_super(self, tx: "InstructionTranslator", a, b):
1548        return variables.SuperVariable(a, b)
1549
1550    def call_next(self, tx: "InstructionTranslator", arg: VariableTracker):
1551        try:
1552            return arg.next_variable(tx)
1553        except Unsupported as ex:
1554            if isinstance(arg, variables.BaseListVariable):
1555                ex.remove_from_stats()
1556                return arg.items[0]
1557            raise
1558
1559    def call_hasattr(self, tx: "InstructionTranslator", obj, attr):
1560        if attr.is_python_constant():
1561            name = attr.as_python_constant()
1562            if isinstance(obj, variables.BuiltinVariable):
1563                return variables.ConstantVariable(hasattr(obj.fn, name))
1564            return obj.call_hasattr(tx, name)
1565
1566    def call_map(self, tx: "InstructionTranslator", fn, *seqs):
1567        seqs = [
1568            seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq
1569            for seq in seqs
1570        ]
1571        return variables.MapVariable(fn, seqs, mutable_local=MutableLocal())
1572
1573    def call_filter(self, tx: "InstructionTranslator", fn, seq):
1574        if seq.has_unpack_var_sequence(tx):
1575            seq_unpacked = seq.unpack_var_sequence(tx)
1576            try:
1577                items = list(
1578                    filter(
1579                        lambda x: fn.call_function(tx, [x], {}).as_python_constant(),
1580                        seq_unpacked,
1581                    )
1582                )
1583                return variables.TupleVariable(items)
1584            except NotImplementedError:
1585                return
1586
1587    def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL):
1588        # Special case for sum on tuple of floats and ints
1589        if isinstance(seq, (variables.ListVariable, variables.TupleVariable)) and all(
1590            isinstance(x, variables.ConstantVariable)
1591            and isinstance(x.value, (int, float))
1592            for x in seq.items
1593        ):
1594            if start is self._SENTINEL:
1595                return variables.ConstantVariable.create(
1596                    sum(x.value for x in seq.items),
1597                )
1598            if isinstance(start, variables.ConstantVariable) and isinstance(
1599                start.value, (int, float)
1600            ):
1601                return variables.ConstantVariable.create(
1602                    sum((x.value for x in seq.items), start=start.value),
1603                )
1604        if seq.has_force_unpack_var_sequence(tx):
1605            if start is self._SENTINEL:
1606                start = variables.ConstantVariable.create(0)
1607            items = seq.force_unpack_var_sequence(tx)
1608            return BuiltinVariable(functools.reduce).call_function(
1609                tx,
1610                [
1611                    BuiltinVariable(operator.add),
1612                    variables.TupleVariable(items),
1613                    start,
1614                ],
1615                {},
1616            )
1617
1618    def call_reduce(
1619        self, tx: "InstructionTranslator", function, iterable, initial=_SENTINEL
1620    ):
1621        if iterable.has_force_unpack_var_sequence(tx):
1622            items = iterable.force_unpack_var_sequence(tx)
1623            if initial is self._SENTINEL:
1624                value, items = items[0], items[1:]
1625            else:
1626                value = initial
1627            for element in items:
1628                value = function.call_function(tx, [value, element], {})
1629            return value
1630
1631    def call_getattr(
1632        self,
1633        tx: "InstructionTranslator",
1634        obj: VariableTracker,
1635        name_var: VariableTracker,
1636        default=None,
1637    ):
1638        from .. import trace_rules
1639        from . import (
1640            ConstantVariable,
1641            GetAttrVariable,
1642            TorchInGraphFunctionVariable,
1643            UserFunctionVariable,
1644        )
1645        from .builder import SourcelessBuilder, VariableBuilder
1646
1647        name = name_var.as_python_constant()
1648
1649        if not name_var.is_python_constant():
1650            unimplemented("non-const getattr() name")
1651
1652        if tx.output.side_effects.is_attribute_mutation(obj):
1653            if isinstance(obj, variables.UnspecializedNNModuleVariable):
1654                if (
1655                    name
1656                    in (
1657                        "named_parameters",
1658                        "parameters",
1659                        "named_buffers",
1660                        "buffers",
1661                        "named_modules",
1662                        "modules",
1663                    )
1664                    and obj.is_state_mutated
1665                    and tx.output.side_effects.has_pending_mutation(obj)
1666                ):
1667                    unimplemented(
1668                        f"pending mutation on nn module, so graph breaking at {name!r} call"
1669                    )
1670
1671        if tx.output.side_effects.has_pending_mutation_of_attr(obj, name):
1672            return tx.output.side_effects.load_attr(obj, name)
1673
1674        if default is not None:
1675            hasattr_var = self.call_hasattr(tx, obj, name_var)
1676            assert hasattr_var.as_python_constant() in (True, False)
1677            if not hasattr_var.as_python_constant():
1678                return default
1679
1680        options = {}
1681        if obj.source:
1682            source = AttrSource(obj.source, name)
1683            options["source"] = source
1684        else:
1685            source = None
1686
1687        if name in {"__bases__", "__base__", "__flags__"}:
1688            try:
1689                value = obj.as_python_constant()
1690                if isinstance(value, type):
1691                    if name == "__bases__":
1692                        bases = value.__bases__
1693                        if source is not None:
1694                            tuple_args = [
1695                                VariableBuilder(tx, GetItemSource(source, i))(b)
1696                                for i, b in enumerate(bases)
1697                            ]
1698                        else:
1699                            tuple_args = [
1700                                SourcelessBuilder.create(tx, b) for b in bases
1701                            ]
1702                        return variables.TupleVariable(tuple_args, **options)
1703                    if name == "__base__":
1704                        base = value.__base__
1705                        if source is not None:
1706                            return VariableBuilder(tx, source)(base)
1707                        return SourcelessBuilder.create(tx, base)
1708                    if name == "__flags__":
1709                        return ConstantVariable.create(value.__flags__)
1710            except NotImplementedError:
1711                pass
1712
1713        if isinstance(obj, variables.NNModuleVariable):
1714            return obj.var_getattr(tx, name)
1715        elif isinstance(
1716            obj,
1717            (
1718                variables.TensorVariable,
1719                variables.NamedTupleVariable,
1720                variables.ConstantVariable,
1721                variables.DistributedVariable,
1722                variables.UserDefinedClassVariable,
1723                variables.UserDefinedObjectVariable,
1724            ),
1725        ):
1726            try:
1727                return obj.var_getattr(tx, name)
1728            except NotImplementedError:
1729                return GetAttrVariable(obj, name, **options)
1730        elif isinstance(obj, TorchInGraphFunctionVariable):
1731            # Get OpOverload from an OpOverloadPacket, e.g., torch.ops.aten.add.default.
1732            member = getattr(obj.value, name)
1733            if isinstance(
1734                member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)
1735            ) and trace_rules.is_aten_op_or_tensor_method(member):
1736                return TorchInGraphFunctionVariable(member, **options)
1737        elif isinstance(obj, DummyModule):
1738            # TODO(mlazos) - Do we need this?
1739            if obj.is_torch or name not in obj.value.__dict__:
1740                member = getattr(obj.value, name)
1741            else:
1742                member = obj.value.__dict__[name]
1743
1744            if config.replay_record_enabled:
1745                tx.exec_recorder.record_module_access(obj.value, name, member)
1746
1747            if source is not None:
1748                return VariableBuilder(tx, source)(member)
1749            else:
1750                return SourcelessBuilder.create(tx, member)
1751        elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
1752            return ConstantVariable.create(getattr(obj.fn, name))
1753        else:
1754            try:
1755                return obj.var_getattr(tx, name)
1756            except NotImplementedError:
1757                return GetAttrVariable(obj, name, **options)
1758
1759    def call_setattr(
1760        self,
1761        tx: "InstructionTranslator",
1762        obj: VariableTracker,
1763        name_var: VariableTracker,
1764        val: VariableTracker,
1765    ):
1766        if isinstance(
1767            obj,
1768            (
1769                variables.CustomizedDictVariable,
1770                variables.PlacementVariable,
1771                variables.UserDefinedObjectVariable,
1772            ),
1773        ):
1774            return obj.call_method(tx, "__setattr__", [name_var, val], {})
1775        elif (
1776            tx.output.side_effects.is_attribute_mutation(obj)
1777            and name_var.is_python_constant()
1778        ):
1779            name = name_var.as_python_constant()
1780            if isinstance(obj, variables.TensorVariable):
1781                from .builder import wrap_fx_proxy
1782
1783                if name == "requires_grad":
1784                    # TODO(voz): Make it work properly
1785                    unimplemented(
1786                        "mutating requires_grad can introduce a new leaf from non-leaf or vice versa in "
1787                        "the middle of the graph, which aot_autograd does not currently know how to handle. "
1788                    )
1789                if name == "data":
1790                    # Remove the old reference in tracked fakes - if we don't do this
1791                    # new .data value size and shape differences will cause
1792                    # tracked fakes to produce incorrect guards. This is sound because the TensorVariable
1793                    # coming out of set_() below will be a new one, and get
1794                    # installed in tracked fakes.
1795                    to_remove = []
1796                    for tf in tx.output.tracked_fakes:
1797                        if tf.source == obj.source:
1798                            to_remove.append(tf)
1799                    for tf in to_remove:
1800                        tx.output.tracked_fakes.remove(tf)
1801
1802                    # Step 1 - disable grads
1803                    with dynamo_disable_grad(tx), torch.no_grad():
1804                        # Step 2 - call `set_`
1805                        out = wrap_fx_proxy(
1806                            tx,
1807                            tx.output.create_proxy(
1808                                "call_function",
1809                                torch.Tensor.set_,
1810                                *proxy_args_kwargs([obj, val], {}),
1811                            ),
1812                        )
1813
1814                    # Step 3 - drop the version counter - this is a step required to get
1815                    # .data setting to play correctly with the autograd engine.
1816                    # Essentially, dynamo is trying to faithfully preserve the (absurd)
1817                    # behavior of .data= from eager mode
1818                    def _lower_version_count_by_1(x):
1819                        version = x._version
1820                        if version > 0:
1821                            version = version - 1
1822                        torch._C._autograd._unsafe_set_version_counter(x, version)
1823                        return x
1824
1825                    tx.output.create_proxy(
1826                        "call_function",
1827                        _lower_version_count_by_1,
1828                        (out.as_proxy(),),
1829                        {},
1830                    )
1831                    _lower_version_count_by_1(obj.as_proxy().node.meta["example_value"])
1832                    # This handles options prop, guards and ends with a clone
1833                    # Step 4 - replace all reference to the current object with the new one
1834                    return out
1835
1836            tx.output.side_effects.store_attr(obj, name, val)
1837            if name == "_grad":
1838                tx.output.side_effects.store_attr(obj, "grad", val)
1839
1840            return val
1841        elif isinstance(obj, variables.UserDefinedObjectVariable):
1842            unimplemented(
1843                f"setattr(UserDefinedObjectVariable) {type(obj.value).__setattr__}"
1844            )
1845        elif isinstance(obj, variables.NNModuleVariable):
1846            if not tx.output.is_root_tracer():
1847                raise AttributeMutationError(
1848                    "Can't inplace modify module params/buffers inside HigherOrderOp"
1849                )
1850            if name_var.is_python_constant() and isinstance(
1851                val, variables.TensorVariable
1852            ):
1853                assigning_fake_val = get_fake_value(val.as_proxy().node, tx)
1854
1855                try:
1856                    getattr_var = obj.var_getattr(tx, name_var.as_python_constant())
1857                except AttributeError:
1858                    getattr_var = None
1859
1860                if isinstance(getattr_var, variables.TensorVariable):
1861                    # get_fake_val will get the same fake tensor
1862                    existing_fake_attr = get_fake_value(getattr_var.as_proxy().node, tx)
1863
1864                    # same tensor identiy, setattr is a no-op
1865                    mod_setattr = inspect.getattr_static(obj.module_type, "__setattr__")
1866                    if (
1867                        existing_fake_attr is assigning_fake_val
1868                        and mod_setattr is torch.nn.Module.__setattr__
1869                    ):
1870                        return getattr_var
1871
1872            obj.convert_to_unspecialized(tx)
1873        # FIXME (tmanlaibaatar) this is utter hack to unblock HuggingFace export
1874        # Export generally doesn't want to allow mutations on objects directly,
1875        # but we don't have good way to do this rn. For now, we make it an undefined
1876        # behaviour and just set attributes directly on the PretrainedConfig object
1877        # for now.
1878        elif isinstance(obj, variables.dicts.HFPretrainedConfigVariable) and tx.export:
1879            if name_var.is_python_constant() and isinstance(
1880                val, variables.ConstantVariable
1881            ):
1882                setattr(
1883                    obj.obj, name_var.as_python_constant(), val.as_python_constant()
1884                )
1885                return ConstantVariable(None)
1886
1887    def call_delattr(
1888        self,
1889        tx: "InstructionTranslator",
1890        obj: VariableTracker,
1891        name_var: VariableTracker,
1892    ):
1893        return self.call_setattr(tx, obj, name_var, variables.DeletedVariable())
1894
1895    def call_type(self, tx: "InstructionTranslator", obj: VariableTracker):
1896        from .builder import SourcelessBuilder, VariableBuilder
1897
1898        try:
1899            py_type = obj.python_type()
1900        except NotImplementedError as error:
1901            raise UserError(
1902                UserErrorType.INVALID_INPUT,
1903                str(error),
1904                case_name="unknown_python_type",
1905            ) from None
1906
1907        if obj.source is None:
1908            return SourcelessBuilder.create(tx, py_type)
1909        else:
1910            return VariableBuilder(tx, TypeSource(obj.source))(py_type)
1911
1912    def call_reversed(self, tx: "InstructionTranslator", obj: VariableTracker):
1913        if obj.has_unpack_var_sequence(tx):
1914            items = list(reversed(obj.unpack_var_sequence(tx)))
1915            return variables.TupleVariable(items)
1916
1917    def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwargs):
1918        if obj.has_force_unpack_var_sequence(tx) and not isinstance(
1919            obj, variables.TensorVariable
1920        ):
1921            unpacked = obj.force_unpack_var_sequence(tx)
1922            if not all(x.is_python_constant() for x in unpacked):
1923                return
1924            function = kwargs.pop("key", None)
1925            reverse = kwargs.pop(
1926                "reverse", ConstantVariable.create(False)
1927            ).as_python_constant()
1928            assert len(kwargs) == 0
1929            if function:
1930                items = sorted(
1931                    unpacked,
1932                    key=lambda x: function.call_function(
1933                        tx, [x], {}
1934                    ).as_python_constant(),
1935                    reverse=reverse,
1936                )
1937            else:
1938                items = sorted(
1939                    unpacked,
1940                    key=lambda x: x.as_python_constant(),
1941                    reverse=reverse,
1942                )
1943            return variables.ListVariable(items)
1944
1945    # neg is a constant fold function, so we only get here if constant fold is not valid
1946    def call_neg(self, tx: "InstructionTranslator", a):
1947        if isinstance(a, SymNodeVariable):
1948            return SymNodeVariable.create(
1949                tx,
1950                (operator.neg)(a.as_proxy()),
1951                sym_num=None,
1952            )
1953        # None no-ops this handler and lets the driving function proceed
1954        return None
1955
1956    def call_format(self, tx: "InstructionTranslator", _format_string, *args, **kwargs):
1957        format_string = _format_string.as_python_constant()
1958        return variables.StringFormatVariable.create(format_string, args, kwargs)
1959
1960    def call_id(self, tx: "InstructionTranslator", *args):
1961        if len(args) > 0 and isinstance(args[0], variables.NNModuleVariable):
1962            nn_mod_variable = args[0]
1963            mod = tx.output.get_submodule(nn_mod_variable.module_key)
1964            return variables.ConstantVariable.create(id(mod))
1965        elif len(args) == 1 and isinstance(
1966            args[0], variables.UserDefinedObjectVariable
1967        ):
1968            install_guard(args[0].source.make_guard(GuardBuilder.ID_MATCH))
1969            constant_result = id(args[0].value)
1970            return variables.ConstantVariable.create(constant_result)
1971        elif len(args) == 1 and isinstance(args[0], TensorVariable):
1972            tensor_variable = args[0]
1973            return tensor_variable.call_id(tx)
1974        else:
1975            unimplemented(f"call_id with args {args}")
1976
1977    def call_deepcopy(self, tx: "InstructionTranslator", x):
1978        unimplemented(f"copy.deepcopy {repr(x)}")
1979
1980    def _comparison_with_tensor(self, tx: "InstructionTranslator", left, right):
1981        from .builder import wrap_fx_proxy_cls
1982        from .tensor import supported_tensor_comparison_op_values
1983
1984        op = self.fn
1985
1986        if op in [operator.is_, operator.is_not]:
1987            is_result = (
1988                isinstance(left, TensorVariable)
1989                and isinstance(right, TensorVariable)
1990                and id(extract_fake_example_value(left.as_proxy().node))
1991                == id(extract_fake_example_value(right.as_proxy().node))
1992            )
1993            if op is operator.is_:
1994                return ConstantVariable.create(is_result)
1995            else:
1996                return ConstantVariable.create(not is_result)
1997
1998        if op not in supported_tensor_comparison_op_values:
1999            unimplemented(f"{op.__name__}({left}, {right})")
2000        if (
2001            isinstance(left, TensorVariable)
2002            and isinstance(right, TensorVariable)
2003            and (left.size and right.size) is not None
2004            and left.size != right.size
2005        ):
2006            try:
2007                torch.broadcast_shapes(left.size, right.size)
2008            except RuntimeError:
2009                # not broadcastable, can't be compared
2010                unimplemented(f"{op.__name__}({left}, {right})")
2011        tensor_cls = left if isinstance(left, TensorVariable) else right
2012        proxy = tx.output.create_proxy(
2013            "call_function", op, (left.as_proxy(), right.as_proxy()), {}
2014        )
2015        return wrap_fx_proxy_cls(
2016            type(tensor_cls),  # handle Ndarrays and Tensors
2017            tx,
2018            proxy,
2019        )
2020
2021    def _comparison_with_symnode(self, tx: "InstructionTranslator", left, right):
2022        from .tensor import supported_tensor_comparison_op_values
2023
2024        op = self.fn
2025
2026        if op not in supported_tensor_comparison_op_values:
2027            unimplemented(f"{op.__name__}({left}, {right})")
2028
2029        proxy = tx.output.create_proxy(
2030            "call_function", op, (left.as_proxy(), right.as_proxy()), {}
2031        )
2032        return SymNodeVariable.create(
2033            tx,
2034            proxy,
2035            sym_num=None,
2036        )
2037
2038    def call_and_(self, tx: "InstructionTranslator", a, b):
2039        # Rely on constant_handler
2040        if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable):
2041            return None
2042        if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance(
2043            b, (SymNodeVariable, ConstantVariable)
2044        ):
2045            return SymNodeVariable.create(
2046                tx,
2047                tx.output.create_proxy(
2048                    "call_function", operator.and_, *proxy_args_kwargs([a, b], {})
2049                ),
2050                sym_num=None,
2051            )
2052        if hasattr(a, "set_items") and hasattr(b, "set_items"):
2053            return SetVariable(list(a.set_items & b.set_items))
2054        # None no-ops this handler and lets the driving function proceed
2055
2056    def call_or_(self, tx: "InstructionTranslator", a, b):
2057        # Rely on constant_handler
2058        if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable):
2059            return None
2060        if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance(
2061            b, (SymNodeVariable, ConstantVariable)
2062        ):
2063            return SymNodeVariable.create(
2064                tx,
2065                tx.output.create_proxy(
2066                    "call_function", operator.or_, *proxy_args_kwargs([a, b], {})
2067                ),
2068                sym_num=None,
2069            )
2070        if hasattr(a, "set_items") and hasattr(b, "set_items"):
2071            return SetVariable(list(a.set_items | b.set_items))
2072        # None no-ops this handler and lets the driving function proceed
2073        return None
2074
2075    def call_not_(self, tx: "InstructionTranslator", a):
2076        if isinstance(a, SymNodeVariable):
2077            return SymNodeVariable.create(
2078                tx,
2079                tx.output.create_proxy(
2080                    "call_function", operator.not_, *proxy_args_kwargs([a], {})
2081                ),
2082                sym_num=None,
2083            )
2084
2085        # Unwrap the underlying ConstDictVariable
2086        if isinstance(a, DictView):
2087            a = a.dv_dict
2088        if isinstance(a, (ListVariable, ConstDictVariable)):
2089            return ConstantVariable.create(len(a.items) == 0)
2090
2091        return None
2092
2093    def call_contains(
2094        self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker
2095    ):
2096        return a.call_method(tx, "__contains__", [b], {})
2097
2098
2099@contextlib.contextmanager
2100def dynamo_disable_grad(tx):
2101    from . import GradModeVariable
2102
2103    org_value = torch.is_grad_enabled()
2104    gmv = GradModeVariable.create(tx, False)
2105    try:
2106        gmv.enter(tx)
2107        yield
2108    finally:
2109        gmv.exit(tx)
2110