xref: /aosp_15_r20/external/pytorch/torch/_dynamo/variables/dicts.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import collections
4import dataclasses
5import functools
6import inspect
7import sys
8from typing import Dict, List, Optional, TYPE_CHECKING
9
10from torch._subclasses.fake_tensor import is_fake
11
12from .. import polyfills, variables
13from ..bytecode_transformation import create_call_function, create_instruction
14from ..eval_frame import skip_code
15from ..exc import raise_observed_exception, unimplemented
16from ..guards import GuardBuilder, install_guard
17from ..source import AttrSource, GetItemSource
18from ..utils import dict_keys, dict_values, istype, specialize_symnode
19from .base import MutableLocal, VariableTracker
20from .constant import ConstantVariable
21
22
23if TYPE_CHECKING:
24    from torch._dynamo.symbolic_convert import InstructionTranslator
25
26
# [Adding a new supported class within the keys of ConstDictVariable]
28# - Add its tracker type to is_hashable
29# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
30
31
def is_hashable(x):
    """Return whether the tracker ``x`` may be used as a ConstDictVariable key.

    Tensors qualify only when their proxy node carries an ``example_value``;
    tuples qualify when every element does; otherwise ``x`` must be one of a
    fixed set of VariableTracker types.
    """
    if isinstance(x, variables.TensorVariable):
        # Tensors are hashable if they have an example_value (a fake tensor)
        # Most VT's should have one.
        # It'd be nice if at some point we could assert that they all have one
        example_value = x.as_proxy().node.meta.get("example_value")
        return example_value is not None

    if isinstance(x, variables.TupleVariable):
        return all(map(is_hashable, x.items))

    hashable_tracker_types = (
        variables.BuiltinVariable,
        variables.SymNodeVariable,
        variables.ConstantVariable,
        variables.EnumVariable,
        variables.user_defined.UserDefinedClassVariable,
        variables.UserFunctionVariable,
        variables.SkipFunctionVariable,
        variables.misc.NumpyVariable,
        variables.NNModuleVariable,
        variables.UnspecializedNNModuleVariable,
        variables.MethodWrapperVariable,
        variables.TorchInGraphFunctionVariable,
        variables.TypingVariable,
        variables.FunctoolsPartialVariable,
    )
    return isinstance(x, hashable_tracker_types)
60
61
class ConstDictVariable(VariableTracker):
    """Tracks a dict-like object whose keys are known while tracing.

    Keys are stored in ``self.items`` wrapped in
    :class:`ConstDictVariable._HashableTracker`, so that VariableTrackers are
    hashed and compared by their underlying Python values; values are
    VariableTrackers. ``user_cls`` is the Python type returned by
    :meth:`python_type` (``dict`` by default) and selects how
    :meth:`reconstruct` rebuilds the object.
    """

    _nonvar_fields = {
        "user_cls",
        *VariableTracker._nonvar_fields,
    }

    class _HashableTracker:
        """
        Auxiliary opaque internal class that wraps a VariableTracker and makes it hashable
        This should not be seen or touched by anything outside of ConstDictVariable and its children
        Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing
        """

        def __init__(self, vt) -> None:
            # We specialize SymNodes
            vt = specialize_symnode(vt)
            # TODO Temporarily remove to figure out what keys are we breaking on
            # and add proper support for them
            if not is_hashable(vt):
                unimplemented(f"Dict key of type {type(vt)}. Key: {vt}")
            self.vt = vt

        @property
        def underlying_value(self):
            """Python object standing in for the wrapped tracker in hashing/equality."""
            if isinstance(self.vt, variables.TensorVariable):
                # Tensors are represented by their example value (see is_hashable)
                x = self.vt.as_proxy().node.meta["example_value"]
            elif isinstance(self.vt, variables.TupleVariable):
                Hashable = ConstDictVariable._HashableTracker
                x = tuple(Hashable(e).underlying_value for e in self.vt.items)
            elif isinstance(self.vt, variables.NNModuleVariable):
                return self.vt.module
            elif isinstance(self.vt, variables.UnspecializedNNModuleVariable):
                return self.vt.value
            elif isinstance(self.vt, variables.UserFunctionVariable):
                return self.vt.get_function()
            else:
                x = self.vt.as_python_constant()
            return x

        def __hash__(self):
            return hash(self.underlying_value)

        @staticmethod
        def _eq_impl(a, b):
            """Compare two underlying values; fake tensors compare by identity."""
            # TODO: Put this in utils and share it between variables/builtin.py and here
            if type(a) != type(b):
                return False
            elif isinstance(a, tuple):
                Hashable = ConstDictVariable._HashableTracker
                return len(a) == len(b) and all(
                    Hashable._eq_impl(u, v) for u, v in zip(a, b)
                )
            elif is_fake(a):
                return a is b
            else:
                return a == b

        def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
            # ``other`` may also be a raw Python literal, e.g. when a str is
            # looked up directly in ``self.items``.
            Hashable = ConstDictVariable._HashableTracker
            assert isinstance(other, Hashable) or ConstantVariable.is_literal(
                other
            ), type(other)
            if isinstance(other, Hashable):
                return Hashable._eq_impl(self.underlying_value, other.underlying_value)

            # constant
            return Hashable._eq_impl(self.underlying_value, other)

    def __init__(
        self, items: Dict[VariableTracker, VariableTracker], user_cls=dict, **kwargs
    ) -> None:
        """
        Args:
            items: mapping from key trackers (or already-wrapped
                ``_HashableTracker``s, as happens when cloning) to value trackers.
            user_cls: the Python type being modelled; returned by ``python_type()``.
        """
        super().__init__(**kwargs)

        Hashable = ConstDictVariable._HashableTracker

        # Keys will just be HashableTrackers when cloning, in any other case they'll be VariableTrackers
        assert all(
            isinstance(x, (VariableTracker, Hashable))
            and isinstance(v, VariableTracker)
            for x, v in items.items()
        )

        def make_hashable(key):
            return key if isinstance(key, Hashable) else Hashable(key)

        self.items = {make_hashable(x): v for x, v in items.items()}
        self.user_cls = user_cls

    def as_proxy(self):
        return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}

    def debug_repr(self):
        return (
            "{"
            + ", ".join(
                f"{k.vt.debug_repr()}: {v.debug_repr()}" for k, v in self.items.items()
            )
            + "}"
        )

    def as_python_constant(self):
        return {
            k.vt.as_python_constant(): v.as_python_constant()
            for k, v in self.items.items()
        }

    def keys_as_python_constant(self):
        """Map each key's Python constant to its (still wrapped) value tracker."""
        return {k.vt.as_python_constant(): v for k, v in self.items.items()}

    def python_type(self):
        return self.user_cls

    def __contains__(self, vt) -> bool:
        """True if ``vt`` is a hashable key present and not marked deleted."""
        assert isinstance(vt, VariableTracker)
        if not is_hashable(vt):
            return False
        # Build the wrapped key once and reuse it for both lookups.
        key = ConstDictVariable._HashableTracker(vt)
        return key in self.items and not isinstance(
            self.items[key], variables.DeletedVariable
        )

    def len(self):
        """Number of entries, ignoring values marked as DeletedVariable."""
        return len(
            [
                x
                for x in self.items.values()
                if not isinstance(x, variables.DeletedVariable)
            ]
        )

    def reconstruct(self, codegen):
        # instructions to load collections.OrderedDict if necessary
        if self.user_cls is collections.OrderedDict:
            codegen.add_push_null(
                lambda: codegen.extend_output(
                    [
                        codegen.create_load_python_module(collections),
                        codegen.create_load_attr("OrderedDict"),
                    ]
                )
            )
        # instructions to build the dict keys and values
        for key, value in self.items.items():
            codegen(key.vt)
            codegen(value)
        # BUILD_MAP and calling collections.OrderedDict if necessary
        if self.user_cls is collections.OrderedDict:
            codegen.extend_output(
                [
                    create_instruction("BUILD_MAP", arg=len(self.items)),
                    *create_call_function(1, False),
                ]
            )
        # BUILD_MAP only if user_cls is dict
        else:
            codegen.append_output(create_instruction("BUILD_MAP", arg=len(self.items)))

    def getitem_const_raise_exception_if_absent(
        self, tx: "InstructionTranslator", arg: VariableTracker
    ):
        """Return the value for ``arg``; raise an observed KeyError when absent."""
        key = ConstDictVariable._HashableTracker(arg)
        if key not in self.items:
            raise_observed_exception(KeyError, tx, self)
        return self.items[key]

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Return the value for ``arg``; graph-break (unimplemented) when absent."""
        key = ConstDictVariable._HashableTracker(arg)
        if key not in self.items:
            # Not every hashable key tracker has a ``.value`` attribute (e.g. a
            # TensorVariable), so fall back to the tracker itself for the message.
            unimplemented(f"dict KeyError: {getattr(arg, 'value', arg)}")
        return self.items[key]

    def maybe_getitem_const(self, arg: VariableTracker):
        """Return the value for ``arg`` or None when absent."""
        key = ConstDictVariable._HashableTracker(arg)
        if key not in self.items:
            return None
        return self.items[key]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Model the dict method ``name`` called with ``args``/``kwargs``.

        Unhandled cases are delegated to ``VariableTracker.call_method``.
        """
        from . import (
            BuiltinVariable,
            ConstantVariable,
            ListIteratorVariable,
            ListVariable,
            TupleVariable,
            UserDefinedObjectVariable,
        )

        Hashable = ConstDictVariable._HashableTracker

        arg_hashable = args and is_hashable(args[0])

        if name == "__getitem__":
            assert len(args) == 1
            return self.getitem_const_raise_exception_if_absent(tx, args[0])
        elif name == "items":
            assert not (args or kwargs)
            if self.source:
                tx.output.guard_on_key_order.add(self.source.name())
            return TupleVariable(
                [TupleVariable([k.vt, v]) for k, v in self.items.items()]
            )
        elif name == "keys":
            if self.source:
                tx.output.guard_on_key_order.add(self.source.name())
            assert not (args or kwargs)
            return DictKeys(self)
        elif name == "values":
            if self.source:
                tx.output.guard_on_key_order.add(self.source.name())
            assert not (args or kwargs)
            return DictValues(self)
        elif name == "copy":
            assert not (args or kwargs)
            if self.source:
                # The copy's key order comes from the source dict; guard on it
                # here because the copy itself will no longer carry the source.
                tx.output.guard_on_key_order.add(self.source.name())
            # A copy is a new object, so it must not keep the original's source
            # (the source identifies the original input object, not the copy).
            return self.clone(
                items=self.items.copy(), mutable_local=MutableLocal(), source=None
            )
        elif name == "__len__":
            assert not (args or kwargs)
            # Use len() so entries marked DeletedVariable are not counted,
            # consistent with __contains__.
            return ConstantVariable.create(self.len())
        elif name == "__setitem__" and arg_hashable and self.mutable_local:
            assert not kwargs and len(args) == 2
            tx.output.side_effects.mutation(self)
            self.items[Hashable(args[0])] = args[1]
            return ConstantVariable.create(None)
        elif name == "__delitem__" and arg_hashable and self.mutable_local:
            tx.output.side_effects.mutation(self)
            self.items.__delitem__(Hashable(args[0]))
            return ConstantVariable.create(None)
        elif name in ("pop", "get") and len(args) in (1, 2) and args[0] not in self:
            # missing item: return the default value if one was given
            if len(args) == 2:
                return args[1]
            if name == "pop":
                # dict.pop(missing_key) without a default raises KeyError
                raise_observed_exception(KeyError, tx, self)
            return ConstantVariable.create(None)
        elif name == "pop" and arg_hashable and self.mutable_local:
            tx.output.side_effects.mutation(self)
            return self.items.pop(Hashable(args[0]))
        elif name == "clear":
            tx.output.side_effects.mutation(self)
            self.items.clear()
            return ConstantVariable.create(None)
        elif (
            name == "update"
            and len(args) == 1
            and isinstance(
                args[0],
                (
                    ConstDictVariable,
                    ListVariable,
                    TupleVariable,
                    ListIteratorVariable,
                    variables.IteratorVariable,
                    UserDefinedObjectVariable,
                ),
            )
            and self.mutable_local
        ):
            tx.output.side_effects.mutation(self)
            if isinstance(args[0], ConstDictVariable):
                dict_vt = args[0]
            else:
                dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
            self.items.update(dict_vt.items)
            # Wrap strings
            kwargs = {
                Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items()
            }
            self.items.update(kwargs)
            return ConstantVariable.create(None)
        elif name in ("get", "__getattr__") and args and args[0] in self:
            return self.getitem_const(tx, args[0])
        elif name == "__contains__" and len(args) == 1:
            return ConstantVariable.create(args[0] in self)
        elif name == "setdefault" and arg_hashable and self.mutable_local:
            assert not kwargs
            assert len(args) <= 2
            value = self.maybe_getitem_const(args[0])
            if value is not None:
                return value
            else:
                if len(args) == 1:
                    x = ConstantVariable.create(None)
                else:
                    x = args[1]
                tx.output.side_effects.mutation(self)
                self.items[Hashable(args[0])] = x
                return x
        else:
            return super().call_method(tx, name, args, kwargs)

    def unpack_var_sequence(self, tx):
        """Iterating a dict yields its key trackers."""
        return [x.vt for x in self.items.keys()]

    def call_hasattr(self, tx, name):
        # dict not allow setting arbitrary attributes. To check for hasattr, we can just check the __dict__ of the dict.
        # OrderedDict though requires side effects tracking because it supports arbitrary setattr.
        if self.user_cls is dict:
            if name in self.user_cls.__dict__:
                return ConstantVariable.create(True)
            return ConstantVariable.create(False)
        unimplemented(f"hasattr on {self.user_cls} is not supported")
367
368
class DefaultDictVariable(ConstDictVariable):
    """Models ``collections.defaultdict``.

    ``default_factory`` is the tracker for the factory callable, or ``None`` when
    there is no factory. It is invoked through ``call_function`` in
    ``__getitem__`` to create values for missing keys.
    """

    def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None:
        super().__init__(items, user_cls, **kwargs)
        assert user_cls is collections.defaultdict
        self.default_factory = default_factory

    def is_python_constant(self):
        # Return false for unsupported defaults. This ensures that a bad handler
        # path is not taken in BuiltinVariable for getitem.
        # NOTE(review): default_factory is used as a VariableTracker elsewhere in
        # this class (call_function / debug_repr), so comparing it to raw Python
        # types likely never matches -- confirm the intended check.
        if self.default_factory not in [list, tuple, dict] and not self.items:
            return False
        return super().is_python_constant()

    def debug_repr(self):
        # default_factory may be None (the constructor default)
        factory_repr = (
            "None"
            if self.default_factory is None
            else self.default_factory.debug_repr()
        )
        return f"defaultdict({factory_repr}, {super().debug_repr()})"

    @staticmethod
    def is_supported_arg(arg):
        """Whether ``arg`` is an accepted tracker for a default factory."""
        if isinstance(arg, variables.BuiltinVariable):
            return arg.fn in (list, tuple, dict, set)
        else:
            return isinstance(arg, variables.functions.BaseUserFunctionVariable)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle ``__getitem__`` with default-factory semantics; delegate the rest."""
        if name == "__getitem__":
            assert len(args) == 1

            if args[0] in self:
                return self.getitem_const(tx, args[0])
            if self.default_factory is None:
                # A defaultdict without a factory raises KeyError like a plain
                # dict; raise it as an observed exception (as the parent's
                # __getitem__ does) so the traced program can handle it.
                raise_observed_exception(KeyError, tx, self)
            default_var = self.default_factory.call_function(tx, [], {})
            super().call_method(tx, "__setitem__", (args[0], default_var), kwargs)
            return default_var
        return super().call_method(tx, name, args, kwargs)
417
418
419# TODO: Implementing this via inheritance rather than composition is a
420# footgun, because self method calls in dict will route back to the set
421# implementation, which is almost assuredly wrong
class SetVariable(ConstDictVariable):
    """We model a set as a dictionary with None values"""

    def __init__(
        self,
        items: List[VariableTracker],
        **kwargs,
    ) -> None:
        items = dict.fromkeys(items, SetVariable._default_value())
        super().__init__(items, **kwargs)

    def debug_repr(self):
        if not self.items:
            return "set()"
        else:
            return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"

    @property
    def set_items(self):
        """A fresh ``set`` of the wrapped keys."""
        return set(self.items.keys())

    @staticmethod
    def _default_value():
        # Variable to fill in the values of the underlying dictionary
        return ConstantVariable.create(None)

    def as_proxy(self):
        return {k.vt.as_proxy() for k in self.set_items}

    def python_type(self):
        return set

    def as_python_constant(self):
        return {k.vt.as_python_constant() for k in self.set_items}

    def reconstruct(self, codegen):
        codegen.foreach([x.vt for x in self.set_items])
        codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> "VariableTracker":
        """Model set methods, mostly by forwarding to the dict implementation."""
        from . import ListVariable, TupleVariable

        # We forward the calls to the dictionary model
        if name == "add":
            assert not kwargs
            assert len(args) == 1
            name = "__setitem__"
            args = (args[0], SetVariable._default_value())
        elif name == "pop":
            assert not kwargs
            assert not args
            if not self.items:
                # set().pop() raises KeyError in Python; surface it as an
                # observed exception instead of crashing on an empty set.
                raise_observed_exception(KeyError, tx, self)
            # Choose an item at random and pop it via the Dict.pop method
            result = self.set_items.pop().vt
            super().call_method(tx, name, (result,), kwargs)
            return result
        elif name == "isdisjoint":
            assert not kwargs
            assert len(args) == 1
            return variables.UserFunctionVariable(
                polyfills.set_isdisjoint
            ).call_function(tx, [self, args[0]], {})
        elif name == "intersection":
            assert not kwargs
            assert len(args) == 1
            return variables.UserFunctionVariable(
                polyfills.set_intersection
            ).call_function(tx, [self, args[0]], {})
        elif name == "union":
            assert not kwargs
            assert len(args) == 1
            return variables.UserFunctionVariable(polyfills.set_union).call_function(
                tx, [self, args[0]], {}
            )
        elif name == "difference":
            assert not kwargs
            assert len(args) == 1
            return variables.UserFunctionVariable(
                polyfills.set_difference
            ).call_function(tx, [self, args[0]], {})
        elif (
            name == "update"
            and len(args) == 1
            and isinstance(
                args[0],
                (
                    SetVariable,
                    ListVariable,
                    TupleVariable,
                ),
            )
            and self.mutable_local
        ):
            if isinstance(args[0], (ListVariable, TupleVariable)):
                arg = SetVariable(args[0].unpack_var_sequence(tx))
            else:
                arg = args[0]
            return super().call_method(tx, "update", (arg,), kwargs)
        elif name == "remove":
            assert not kwargs
            assert len(args) == 1
            if args[0] not in self:
                unimplemented("key does not exist")
            return super().call_method(tx, "pop", args, kwargs)
        elif name == "discard":
            assert not kwargs
            assert len(args) == 1
            if args[0] in self:
                return super().call_method(tx, "pop", args, kwargs)
            else:
                return ConstantVariable.create(value=None)
        return super().call_method(tx, name, args, kwargs)

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        raise RuntimeError("Illegal to getitem on a set")
542
543
class FrozensetVariable(SetVariable):
    """Models ``frozenset``: a SetVariable whose mutating methods are rejected."""

    def __init__(
        self,
        items: List[VariableTracker],
        **kwargs,
    ) -> None:
        super().__init__(items, **kwargs)

    def debug_repr(self):
        if not self.items:
            return "frozenset()"
        else:
            # Match Python's repr so it is not mistaken for a plain set.
            return (
                "frozenset({"
                + ",".join(k.vt.debug_repr() for k in self.items.keys())
                + "})"
            )

    @property
    def set_items(self):
        return self.items.keys()

    def python_type(self):
        return frozenset

    def as_python_constant(self):
        # Return the same type that python_type() reports.
        return frozenset(k.vt.as_python_constant() for k in self.set_items)

    def reconstruct(self, codegen):
        # Emit: frozenset({item0, item1, ...})
        # The callable (with its NULL) must be on the stack before the argument.
        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
                    # add=True registers the name in co_names if the user code
                    # did not reference frozenset itself.
                    codegen.create_load_global("frozenset", add=True),
                ]
            )
        )
        codegen.foreach([x.vt for x in self.set_items])
        codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
        codegen.extend_output(create_call_function(1, False))

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> "VariableTracker":
        """Reject mutating set methods; delegate everything else to SetVariable."""
        if name in ["add", "pop", "update", "remove", "discard", "clear"]:
            raise RuntimeError(f"Illegal call_method {name} on a frozenset")
        return super().call_method(tx, name, args, kwargs)
589
590
class DictView(VariableTracker):
    """
    Models _PyDictViewObject

    This is an "abstract" class. Subclasses will override kv and the items method
    """

    # Name of the dict method the view mirrors ("keys" or "values"); set by subclasses.
    kv: Optional[str] = None

    def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
        super().__init__(**kwargs)
        assert self.kv in ("keys", "values")
        assert isinstance(dv_dict, ConstDictVariable)
        self.dv_dict = dv_dict

    @property
    def view_items(self):
        """The live ``keys()``/``values()`` view of the wrapped dict's ``items``."""
        view_method = getattr(self.dv_dict.items, self.kv)
        return view_method()

    @property
    def view_items_vt(self):
        # Returns an iterable of the unpacked items
        # Implement in the subclasses
        raise NotImplementedError

    def unpack_var_sequence(self, tx):
        # Keys are stored wrapped in _HashableTracker; values are plain trackers.
        if self.kv == "keys":
            return [hashable.vt for hashable in self.view_items]
        return list(self.view_items)

    def reconstruct(self, codegen):
        # Rebuild as <dict>.<kv>()
        codegen(self.dv_dict)
        codegen.load_method(self.kv)
        codegen.call_method(0)

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        # len(view) is the length of the underlying dict
        if name != "__len__":
            return super().call_method(tx, name, args, kwargs)
        return self.dv_dict.call_method(tx, name, args, kwargs)
637
638
class DictKeys(DictView):
    """View returned by ``dict.keys()``."""

    kv = "keys"

    @property
    def set_items(self):
        """The wrapped keys as a ``set``."""
        return {hashable for hashable in self.view_items}

    @property
    def view_items_vt(self):
        # Returns an iterable of the unpacked items
        return [hashable.vt for hashable in self.view_items]

    def python_type(self):
        return dict_keys

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        # Membership in the keys view is membership in the dict itself
        if name != "__contains__":
            return super().call_method(tx, name, args, kwargs)
        return self.dv_dict.call_method(tx, name, args, kwargs)
664
665
class DictValues(DictView):
    """View returned by ``dict.values()``."""

    # DictValues is an iterable but cannot be compared.
    kv = "values"

    @property
    def view_items_vt(self):
        # Values are already plain trackers; just materialize them.
        return [value_vt for value_vt in self.view_items]

    def python_type(self):
        return dict_values
676
677
def _is_matching_transformers_cls(cls) -> bool:
    """Return True if transformers is already imported and ``cls`` subclasses its ModelOutput.

    ``transformers.file_utils`` is consulted first, then
    ``transformers.utils.generic``; the module is never imported here.
    """
    for module_name in ("transformers.file_utils", "transformers.utils.generic"):
        module = sys.modules.get(module_name)
        if module is not None:
            return issubclass(cls, module.ModelOutput)
    return False
683
684
def _is_matching_diffusers_cls(cls) -> bool:
    """Return True if ``diffusers.utils`` is already imported and ``cls`` subclasses its BaseOutput."""
    diffusers_utils = sys.modules.get("diffusers.utils")
    if diffusers_utils is None:
        return False
    return issubclass(cls, diffusers_utils.BaseOutput)
688
689
def _call_hasattr_customobj(
    self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
    """Shared method between DataClassVariable and CustomizedDictVariable where items are attrs

    Resolution order:
      1. attribute state recorded in side effects (a deleted attribute -> False);
      2. a key of ``self.items`` or an attribute of ``self.user_cls`` -> True;
      3. a locally created object (MutableLocal, no source) -> False;
      4. an object with a source: answer from its example value and install a
         HASATTR guard.
    Otherwise tracing is unsupported (``unimplemented``).
    """
    side_effects = tx.output.side_effects
    if side_effects.is_attribute_mutation(self):
        try:
            loaded = side_effects.load_attr(self, name, deleted_ok=True)
        except KeyError:
            pass
        else:
            return variables.ConstantVariable.create(
                not isinstance(loaded, variables.DeletedVariable)
            )

    if name in self.items or hasattr(self.user_cls, name):
        return ConstantVariable(True)

    created_locally = istype(self.mutable_local, MutableLocal) and self.source is None
    if created_locally:
        # Something created locally can't have any extra fields on it
        return ConstantVariable(False)

    if self.source:
        # Maybe add a guard
        try:
            example = tx.output.root_tx.get_example_value(self.source)
            install_guard(
                AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
            )
            return ConstantVariable(hasattr(example, name))
        except KeyError:
            pass

    unimplemented(
        f"hasattr({self.__class__.__name__}, {name}) {self.mutable_local} {self.source}"
    )
720
721
722class CustomizedDictVariable(ConstDictVariable):
723    @staticmethod
724    def is_matching_cls_hf(cls):
725        return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)
726
727    @staticmethod
728    def is_matching_cls(cls):
729        # True if using default OrderedDict.__init__ and did not implement __post_init__
730        if (
731            issubclass(cls, collections.OrderedDict)
732            and cls is not collections.OrderedDict
733            and cls.__init__ is collections.OrderedDict.__init__
734            and not hasattr(cls, "__post_init__")
735        ):
736            return True
737        # hack for HF usecase:
738        #   assume dataclass annotation for ModelOutput subclass
739        #   assume self.create is AA to ModelOutput.__post_init__
740        return CustomizedDictVariable.is_matching_cls_hf(cls)
741
742    @classmethod
743    def is_matching_object(cls, obj):
744        return cls.is_matching_cls(type(obj))
745
746    # called from user_defined.py
747    # when is_matching_cls(cls) is true
748    @classmethod
749    def create(cls, user_cls, args, kwargs, options):
750        # avoid tracing when returning ModelOutput from forward func
751        for attr_name in ("__init__", "__post_init__", "__setattr__", "__setitem__"):
752            if hasattr(user_cls, attr_name):
753                fn = getattr(user_cls, attr_name)
754                assert callable(fn), f"expect callable attr {attr_name}"
755                if hasattr(fn, "__code__"):
756                    skip_code(fn.__code__)
757
758        if dataclasses.is_dataclass(user_cls):
759            # @dataclass CustomDict(a=1, b=2)
760            bound = inspect.signature(user_cls).bind(*args, **kwargs)
761            bound.apply_defaults()
762
763            def make_var(x):
764                if isinstance(x, VariableTracker):
765                    return x
766                elif ConstantVariable.is_literal(x):
767                    return ConstantVariable.create(x)
768                else:
769                    unimplemented(
770                        "expect VariableTracker or ConstantVariable.is_literal"
771                    )
772
773            bound_args = {}
774            if cls.is_matching_cls_hf(user_cls):
775                # Skip none
776                for k, v in bound.arguments.items():
777                    if isinstance(v, ConstantVariable) and v.value is None or v is None:
778                        continue
779                    bound_args[k] = v
780            else:
781                bound_args = bound.arguments
782
783            items = {
784                ConstantVariable.create(k): make_var(v) for k, v in bound_args.items()
785            }
786        elif not args:
787            # CustomDict(a=1, b=2) in the general (non-dataclass) case.
788            items = {ConstantVariable.create(k): v for k, v in kwargs.items()}
789        elif len(args) == 1 and isinstance(args[0], ConstDictVariable) and not kwargs:
790            # CustomDict({'a': 1, 'b': 2})
791            items = args[0].items
792        else:
793            unimplemented("custom dict init with args/kwargs unimplemented")
794
795        return cls(items, user_cls, **options)
796
797    # called from builder.py
798    @classmethod
799    def wrap(cls, builder, obj):
800        user_cls = type(obj)
801
802        if not cls.is_matching_cls_hf(user_cls):
803            unimplemented("custom non-hf dict subclass wrap unimplemented")
804
805        items = builder.__class__(tx=builder.tx, source=builder.source)(
806            collections.OrderedDict(obj)
807        ).items
808
809        keys = [f.name for f in dataclasses.fields(user_cls)]
810        for key in keys:
811            # __init__ function of a dataclass might not have yet defined the key
812            if hasattr(obj, key):
813                val = getattr(obj, key)
814                var = builder.__class__(
815                    tx=builder.tx, source=AttrSource(builder.source, key)
816                )(val)
817                if val is not None:
818                    key = ConstantVariable.create(key)
819                    items[key] = var
820        return cls(items, user_cls)
821
822    def __init__(self, items, user_cls, **options) -> None:
823        super().__init__(items, user_cls, **options)
824        assert self.is_matching_cls(user_cls)
825
826    def as_proxy(self):
827        raise NotImplementedError
828
829    # 'RETURN_VALUE triggered compile'
830    # called from torch/_dynamo/codegen.py
    def reconstruct(self, codegen):
        """Emit bytecode that rebuilds this object as ``user_cls(key=value, ...)``.

        For HF output classes the callable is ``torch._dynamo.disable(user_cls)``
        instead of ``user_cls`` (see the comment in ``gen_fn1``). Keys are passed
        as keyword arguments, so they must be Python constants usable as
        keyword names (``keys_as_python_constant``).
        """
        is_hf_model_output = self.is_matching_cls_hf(self.user_cls)

        # Pushes the callable (user_cls, possibly wrapped in disable) onto the stack.
        def gen_fn1():
            # If the user class is a ModelOutput, then wrap the instance creation in
            # torch._dynamo.disable(). Even though we mark the __post_init__ as skip
            # in `create` function, this is not enough. TorchDynamo can still get
            # triggered on the child functions of __post_init__. This upsets export.
            # Since, we know that ModelOutput __post_init__ is not worth optimizing,
            # we just wrap the instance creation in torch._dynamo.disable(),
            # regardless whether its export or not.
            if is_hf_model_output:
                # load torch._dynamo.disable
                def gen_fn2():
                    codegen.append_output(codegen.create_load_global("torch", add=True))
                    codegen.append_output(codegen.create_load_attr("_dynamo"))
                    codegen.append_output(codegen.create_load_attr("disable"))

                codegen.add_push_null(gen_fn2)

            codegen.extend_output([codegen._create_load_const(self.user_cls)])

            if is_hf_model_output:
                # Wrap user_cls with disable
                codegen.extend_output(create_call_function(1, False))

        codegen.add_push_null(gen_fn1)

        # All the keys are just wrapped strings
        d = self.keys_as_python_constant()
        # Push the values in key order, then call with those keys as keyword names.
        codegen.foreach(d.values())
        keys = tuple(d.keys())
        codegen.extend_output(codegen.create_call_function_kw(len(keys), keys, False))
864
865    def call_method(
866        self,
867        tx,
868        name,
869        args: "List[VariableTracker]",
870        kwargs: "Dict[str, VariableTracker]",
871    ) -> "VariableTracker":
872        fn = getattr(self.user_cls, name)
873        source = None if self.source is None else AttrSource(self.source, name)
874
875        if hasattr(fn, "__objclass__") and fn.__objclass__ in (
876            dict,
877            collections.OrderedDict,
878        ):
879            # for python dict method without overridden
880            return super().call_method(tx, name, args, kwargs)
881        elif name in (
882            "__getitem__",
883            "to_tuple",
884            "__setitem__",
885            "__setattr__",
886            "__post_init__",
887        ):
888            # for user overridden method
889            return tx.inline_user_function_return(
890                variables.UserFunctionVariable(fn, source=source),
891                [self] + list(args),
892                kwargs,
893            )
894        elif fn is getattr(collections.OrderedDict, name, None):
895            return super().call_method(tx, name, args, kwargs)
896
897        unimplemented(f"custom dict: call_method unimplemented name={name}")
898
899    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
900        name_vt = ConstantVariable.create(name)
901        if name_vt in self:
902            return self.call_method(tx, "__getitem__", [name_vt], {})
903        if dataclasses.is_dataclass(self.user_cls):
904            defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)}
905            if name in defaults:
906                assert variables.ConstantVariable.is_literal(defaults[name])
907                return variables.ConstantVariable.create(defaults[name])
908        return super().var_getattr(tx, name)
909
910    call_hasattr = _call_hasattr_customobj
911
912
913@functools.lru_cache(None)
914def _install_PretrainedConfig_patch():
915    import transformers
916
917    # We need to monkeypatch transformers here, sadly.
918    # TODO(voz): Upstream to transformers lib
919
920    def _dynamo_overriden_transformers_eq(self, other):
921        if not hasattr(other, "__dict__"):
922            return False
923        return self.__dict__ == other.__dict__
924
925    transformers.configuration_utils.PretrainedConfig.__eq__ = (
926        _dynamo_overriden_transformers_eq
927    )
928
929
930class HFPretrainedConfigVariable(VariableTracker):
931    """
932    Hack for HuggingFace PretrainedConfig
933    """
934
935    @staticmethod
936    def is_matching_cls(cls):
937        mod = sys.modules.get("transformers.configuration_utils")
938        is_match = mod is not None and issubclass(cls, mod.PretrainedConfig)
939
940        # Lazily install monkeypatch the first time we see it in dynamo
941        if is_match:
942            _install_PretrainedConfig_patch()
943        return is_match
944
945    @classmethod
946    def is_matching_object(cls, obj):
947        return cls.is_matching_cls(type(obj))
948
949    def __init__(self, obj, **kwargs) -> None:
950        super().__init__(**kwargs)
951        self.obj = obj
952        assert self.is_matching_cls(type(obj))
953
954    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
955        from .builder import VariableBuilder
956
957        try:
958            attr_value = getattr(self.obj, name)
959            attr_source = AttrSource(self.source, name)
960            return VariableBuilder(tx, attr_source)(attr_value)
961
962        except AttributeError:
963            unimplemented(f"getattr({self.value}, {name})")
964
965    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
966        return variables.ConstantVariable.create(hasattr(self.obj, name))
967
968
969class PythonSysModulesVariable(VariableTracker):
970    """Special case for sys.modules.
971
972    Without this we will guard on the exact set of modules imported in the
973    lifetime of the python program.
974    """
975
976    def python_type(self):
977        return dict
978
979    def reconstruct(self, codegen):
980        codegen.add_push_null(
981            lambda: codegen.extend_output(
982                [
983                    codegen.create_load_python_module(sys),
984                    codegen.create_load_attr("modules"),
985                ]
986            )
987        )
988
989    def call_method(
990        self,
991        tx: "InstructionTranslator",
992        name,
993        args: List[VariableTracker],
994        kwargs: Dict[str, VariableTracker],
995    ):
996        if name == "__getitem__":
997            return self.call_getitem(tx, *args, **kwargs)
998        elif name == "get":
999            return self.call_get(tx, *args, **kwargs)
1000        elif name == "__contains__":
1001            return self.call_contains(tx, *args, **kwargs)
1002        unimplemented(f"sys.modules.{name}(*{args}, **{kwargs})")
1003
1004    def _contains_helper(self, tx: "InstructionTranslator", key: VariableTracker):
1005        k = key.as_python_constant()
1006        has_key = k in sys.modules
1007        install_guard(
1008            self.make_guard(
1009                functools.partial(GuardBuilder.DICT_CONTAINS, key=k, invert=not has_key)
1010            )
1011        )
1012        return k, has_key
1013
1014    def call_contains(self, tx: "InstructionTranslator", key: VariableTracker):
1015        k, has_key = self._contains_helper(tx, key)
1016        return ConstantVariable.create(value=has_key)
1017
1018    def call_get(
1019        self,
1020        tx: "InstructionTranslator",
1021        key: VariableTracker,
1022        default: Optional[VariableTracker] = None,
1023    ):
1024        from .builder import VariableBuilder
1025
1026        k, has_key = self._contains_helper(tx, key)
1027
1028        if has_key:
1029            return VariableBuilder(
1030                tx,
1031                GetItemSource(self.source, k),
1032            )(sys.modules[k])
1033
1034        if default is not None:
1035            return default
1036
1037        return ConstantVariable.create(value=None)
1038
1039    def call_getitem(self, tx: "InstructionTranslator", key: VariableTracker):
1040        from .builder import VariableBuilder
1041
1042        k, has_key = self._contains_helper(tx, key)
1043        return VariableBuilder(
1044            tx,
1045            GetItemSource(self.source, k),
1046        )(sys.modules[k])
1047