# mypy: ignore-errors

import collections
import dataclasses
import functools
import inspect
import sys
from typing import Dict, List, Optional, TYPE_CHECKING

from torch._subclasses.fake_tensor import is_fake

from .. import polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..eval_frame import skip_code
from ..exc import raise_observed_exception, unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GetItemSource
from ..utils import dict_keys, dict_values, istype, specialize_symnode
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


# [Adding a new supported class within the keys of ConstDictVariable]
# - Add its tracker type to is_hashable
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl


def is_hashable(x):
    """Return True if the VariableTracker ``x`` may be used as a dict/set key.

    Tensors qualify when their proxy node carries an ``example_value``; tuples
    qualify when every element does; otherwise only the tracker types listed
    below are accepted.
    """
    if isinstance(x, variables.TensorVariable):
        # Tensors are hashable if they have an example_value (a fake tensor)
        # Most VT's should have one.
        # It'd be nice if at some point we could assert that they all have one
        return x.as_proxy().node.meta.get("example_value") is not None
    elif isinstance(x, variables.TupleVariable):
        return all(is_hashable(e) for e in x.items)
    else:
        return isinstance(
            x,
            (
                variables.BuiltinVariable,
                variables.SymNodeVariable,
                variables.ConstantVariable,
                variables.EnumVariable,
                variables.user_defined.UserDefinedClassVariable,
                variables.UserFunctionVariable,
                variables.SkipFunctionVariable,
                variables.misc.NumpyVariable,
                variables.NNModuleVariable,
                variables.UnspecializedNNModuleVariable,
                variables.MethodWrapperVariable,
                variables.TorchInGraphFunctionVariable,
                variables.TypingVariable,
                variables.FunctoolsPartialVariable,
            ),
        )


class ConstDictVariable(VariableTracker):
    """Tracks a dict (or the class given by ``user_cls``) during tracing.

    ``self.items`` maps ``_HashableTracker``-wrapped key trackers to value
    trackers; ``self.user_cls`` records which Python class is being modeled.
    """

    _nonvar_fields = {
        "user_cls",
        *VariableTracker._nonvar_fields,
    }

    class _HashableTracker:
        """
        Auxiliary opaque internal class that wraps a VariableTracker and makes it hashable
        This should not be seen or touched by anything outside of ConstDictVariable and its children
        Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing
        """

        def __init__(self, vt) -> None:
            # We specialize SymNodes
            vt = specialize_symnode(vt)
            # TODO Temporarily remove to figure out what keys are we breaking on
            # and add proper support for them
            if not is_hashable(vt):
                unimplemented(f"Dict key of type {type(vt)}. Key: {vt}")
            self.vt = vt

        @property
        def underlying_value(self):
            """The Python object used for hashing/equality of the wrapped key."""
            if isinstance(self.vt, variables.TensorVariable):
                x = self.vt.as_proxy().node.meta["example_value"]
            elif isinstance(self.vt, variables.TupleVariable):
                Hashable = ConstDictVariable._HashableTracker
                x = tuple(Hashable(e).underlying_value for e in self.vt.items)
            elif isinstance(self.vt, variables.NNModuleVariable):
                return self.vt.module
            elif isinstance(self.vt, variables.UnspecializedNNModuleVariable):
                return self.vt.value
            elif isinstance(self.vt, variables.UserFunctionVariable):
                return self.vt.get_function()
            else:
                x = self.vt.as_python_constant()
            return x

        def __hash__(self):
            return hash(self.underlying_value)

        @staticmethod
        def _eq_impl(a, b):
            # TODO: Put this in utils and share it between variables/builtin.py and here
            if type(a) != type(b):
                return False
            elif isinstance(a, tuple):
                Hashable = ConstDictVariable._HashableTracker
                return len(a) == len(b) and all(
                    Hashable._eq_impl(u, v) for u, v in zip(a, b)
                )
            elif is_fake(a):
                # Fake tensors compare by identity, not by value.
                return a is b
            else:
                return a == b

        def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
            Hashable = ConstDictVariable._HashableTracker
            assert isinstance(other, Hashable) or ConstantVariable.is_literal(
                other
            ), type(other)
            if isinstance(other, Hashable):
                return Hashable._eq_impl(self.underlying_value, other.underlying_value)

            # constant
            return Hashable._eq_impl(self.underlying_value, other)

    def __init__(
        self, items: Dict[VariableTracker, VariableTracker], user_cls=dict, **kwargs
    ) -> None:
        super().__init__(**kwargs)

        Hashable = ConstDictVariable._HashableTracker

        # Keys will just be HashableTrackers when cloning, in any other case they'll be VariableTrackers
        assert all(
            isinstance(x, (VariableTracker, Hashable))
            and isinstance(v, VariableTracker)
            for x, v in items.items()
        )

        def make_hashable(key):
            return key if isinstance(key, Hashable) else Hashable(key)

        self.items = {make_hashable(x): v for x, v in items.items()}
        self.user_cls = user_cls

    def as_proxy(self):
        return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}

    def debug_repr(self):
        return (
            "{"
            + ", ".join(
                f"{k.vt.debug_repr()}: {v.debug_repr()}" for k, v in self.items.items()
            )
            + "}"
        )

    def as_python_constant(self):
        return {
            k.vt.as_python_constant(): v.as_python_constant()
            for k, v in self.items.items()
        }

    def keys_as_python_constant(self):
        """Map Python-constant keys to their (still tracked) value VTs."""
        return {k.vt.as_python_constant(): v for k, v in self.items.items()}

    def python_type(self):
        return self.user_cls

    def __contains__(self, vt) -> bool:
        """True if ``vt`` is a hashable key present and not marked deleted."""
        assert isinstance(vt, VariableTracker)
        Hashable = ConstDictVariable._HashableTracker
        return (
            is_hashable(vt)
            and Hashable(vt) in self.items
            and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable)
        )

    def len(self):
        """Number of entries whose value is not a DeletedVariable."""
        return len(
            [
                x
                for x in self.items.values()
                if not isinstance(x, variables.DeletedVariable)
            ]
        )

    def reconstruct(self, codegen):
        """Emit bytecode rebuilding this dict (via BUILD_MAP, optionally OrderedDict)."""
        # instructions to load collections.OrderedDict if necessary
        if self.user_cls is collections.OrderedDict:
            codegen.add_push_null(
                lambda: codegen.extend_output(
                    [
                        codegen.create_load_python_module(collections),
                        codegen.create_load_attr("OrderedDict"),
                    ]
                )
            )
        # instructions to build the dict keys and values
        for key, value in self.items.items():
            codegen(key.vt)
            codegen(value)
        # BUILD_MAP and calling collections.OrderedDict if necessary
        if self.user_cls is collections.OrderedDict:
            codegen.extend_output(
                [
                    create_instruction("BUILD_MAP", arg=len(self.items)),
                    *create_call_function(1, False),
                ]
            )
        # BUILD_MAP only if user_cls is dict
        else:
            codegen.append_output(create_instruction("BUILD_MAP", arg=len(self.items)))

    def getitem_const_raise_exception_if_absent(
        self, tx: "InstructionTranslator", arg: VariableTracker
    ):
        """Return the value for ``arg``; raise an observed KeyError if absent."""
        key = ConstDictVariable._HashableTracker(arg)
        if key not in self.items:
            raise_observed_exception(KeyError, tx, self)
        return self.items[key]

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Return the value for ``arg``; graph-break (unimplemented) if absent."""
        key = ConstDictVariable._HashableTracker(arg)
        if key not in self.items:
            # Not every hashable key tracker has a ``.value`` attribute (e.g.
            # TensorVariable / TupleVariable keys); fall back to the tracker
            # itself so a missing key yields a graph break, not AttributeError.
            unimplemented(f"dict KeyError: {getattr(arg, 'value', arg)}")
        return self.items[key]

    def maybe_getitem_const(self, arg: VariableTracker):
        """Return the value for ``arg``, or None if it is not a key."""
        key = ConstDictVariable._HashableTracker(arg)
        if key not in self.items:
            return None
        return self.items[key]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Model a dict method call.

        Mutating methods (``__setitem__``, ``__delitem__``, ``pop``,
        ``update``, ``setdefault``) are only handled here when
        ``self.mutable_local`` is set, and they report the mutation via
        ``tx.output.side_effects.mutation``. ``items``/``keys``/``values`` add
        the dict's source name to ``tx.output.guard_on_key_order``. Anything
        not matched is delegated to ``VariableTracker.call_method``.
        """
        from . import (
            BuiltinVariable,
            ConstantVariable,
            ListIteratorVariable,
            ListVariable,
            TupleVariable,
            UserDefinedObjectVariable,
        )

        Hashable = ConstDictVariable._HashableTracker

        arg_hashable = args and is_hashable(args[0])

        if name == "__getitem__":
            assert len(args) == 1
            return self.getitem_const_raise_exception_if_absent(tx, args[0])
        elif name == "items":
            assert not (args or kwargs)
            if self.source:
                tx.output.guard_on_key_order.add(self.source.name())
            return TupleVariable(
                [TupleVariable([k.vt, v]) for k, v in self.items.items()]
            )
        elif name == "keys":
            if self.source:
                tx.output.guard_on_key_order.add(self.source.name())
            assert not (args or kwargs)
            return DictKeys(self)
        elif name == "values":
            if self.source:
                tx.output.guard_on_key_order.add(self.source.name())
            assert not (args or kwargs)
            return DictValues(self)
        elif name == "copy":
            assert not (args or kwargs)
            return self.clone(items=self.items.copy(), mutable_local=MutableLocal())
        elif name == "__len__":
            assert not (args or kwargs)
            return ConstantVariable.create(len(self.items))
        elif name == "__setitem__" and arg_hashable and self.mutable_local:
            assert not kwargs and len(args) == 2
            tx.output.side_effects.mutation(self)
            self.items[Hashable(args[0])] = args[1]
            return ConstantVariable.create(None)
        elif name == "__delitem__" and arg_hashable and self.mutable_local:
            tx.output.side_effects.mutation(self)
            self.items.__delitem__(Hashable(args[0]))
            return ConstantVariable.create(None)
        elif name in ("pop", "get") and len(args) in (1, 2) and args[0] not in self:
            # missing item, return the default value
            if len(args) == 1:
                return ConstantVariable(None)
            else:
                return args[1]
        elif name == "pop" and arg_hashable and self.mutable_local:
            tx.output.side_effects.mutation(self)
            return self.items.pop(Hashable(args[0]))
        elif name == "clear":
            tx.output.side_effects.mutation(self)
            self.items.clear()
            return ConstantVariable.create(None)
        elif (
            name == "update"
            and len(args) == 1
            and isinstance(
                args[0],
                (
                    ConstDictVariable,
                    ListVariable,
                    TupleVariable,
                    ListIteratorVariable,
                    variables.IteratorVariable,
                    UserDefinedObjectVariable,
                ),
            )
            and self.mutable_local
        ):
            tx.output.side_effects.mutation(self)
            if isinstance(args[0], ConstDictVariable):
                dict_vt = args[0]
            else:
                dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
            self.items.update(dict_vt.items)
            # Wrap strings
            kwargs = {
                Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items()
            }
            self.items.update(kwargs)
            return ConstantVariable.create(None)
        elif name in ("get", "__getattr__") and args[0] in self:
            return self.getitem_const(tx, args[0])
        elif name == "__contains__" and len(args) == 1:
            return ConstantVariable.create(args[0] in self)
        elif name == "setdefault" and arg_hashable and self.mutable_local:
            assert not kwargs
            assert len(args) <= 2
            value = self.maybe_getitem_const(args[0])
            if value is not None:
                return value
            else:
                if len(args) == 1:
                    x = ConstantVariable.create(None)
                else:
                    x = args[1]
                tx.output.side_effects.mutation(self)
                self.items[Hashable(args[0])] = x
                return x
        else:
            return super().call_method(tx, name, args, kwargs)

    def unpack_var_sequence(self, tx):
        """Iterating a dict yields its key trackers."""
        return [x.vt for x in self.items.keys()]

    def call_hasattr(self, tx, name):
        # dict not allow setting arbitrary attributes. To check for hasattr, we can just check the __dict__ of the dict.
        # OrderedDict though requires side effects tracking because it supports arbitrary setattr.
        if self.user_cls is dict:
            if name in self.user_cls.__dict__:
                return ConstantVariable.create(True)
            return ConstantVariable.create(False)
        unimplemented(f"hasattr on {self.user_cls} is not supported")


class DefaultDictVariable(ConstDictVariable):
    """Tracks a ``collections.defaultdict``; ``default_factory`` is stored as given."""

    def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None:
        super().__init__(items, user_cls, **kwargs)
        assert user_cls is collections.defaultdict
        self.default_factory = default_factory

    def is_python_constant(self):
        # Return false for unsupported defaults. This ensures that a bad handler
        # path is not taken in BuiltinVariable for getitem.
        if self.default_factory not in [list, tuple, dict] and not self.items:
            return False
        return super().is_python_constant()

    def debug_repr(self):
        return (
            f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})"
        )

    @staticmethod
    def is_supported_arg(arg):
        """True if ``arg`` is a factory this tracker can model."""
        if isinstance(arg, variables.BuiltinVariable):
            return arg.fn in (list, tuple, dict, set)
        else:
            return isinstance(arg, variables.functions.BaseUserFunctionVariable)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__getitem__":
            assert len(args) == 1

            if args[0] in self:
                return self.getitem_const(tx, args[0])
            else:
                if self.default_factory is None:
                    # Without a factory a missing key is a plain KeyError in the
                    # user program; model it as an observed exception, the same
                    # way ConstDictVariable.__getitem__ does, instead of letting
                    # a raw KeyError escape from Dynamo itself.
                    raise_observed_exception(KeyError, tx, self)
                else:
                    # Call the factory and store its result under the key,
                    # mirroring defaultdict.__missing__.
                    default_var = self.default_factory.call_function(tx, [], {})
                    super().call_method(
                        tx, "__setitem__", (args[0], default_var), kwargs
                    )
                    return default_var
        else:
            return super().call_method(tx, name, args, kwargs)


# TODO: Implementing this via inheritance rather than composition is a
# footgun, because self method calls in dict will route back to the set
# implementation, which is almost assuredly wrong
class SetVariable(ConstDictVariable):
    """We model a set as a dictionary with None values"""

    def __init__(
        self,
        items: List[VariableTracker],
        **kwargs,
    ) -> None:
        items = dict.fromkeys(items, SetVariable._default_value())
        super().__init__(items, **kwargs)

    def debug_repr(self):
        if not self.items:
            return "set()"
        else:
            return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"

    @property
    def set_items(self):
        return set(self.items.keys())

    @staticmethod
    def _default_value():
        # Variable to fill in the keys of the dictionary
        return ConstantVariable.create(None)

    def as_proxy(self):
        return {k.vt.as_proxy() for k in self.set_items}

    def python_type(self):
        return set

    def as_python_constant(self):
        return {k.vt.as_python_constant() for k in self.set_items}

    def reconstruct(self, codegen):
        codegen.foreach([x.vt for x in self.set_items])
        codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> "VariableTracker":
        from . import ListVariable, TupleVariable

        # We forward the calls to the dictionary model
        if name == "add":
            assert not kwargs
            assert len(args) == 1
            name = "__setitem__"
            args = (args[0], SetVariable._default_value())
        elif name == "pop":
            assert not kwargs
            assert not args
            # Choose an arbitrary item and pop it via the Dict.pop method
            result = self.set_items.pop().vt
            super().call_method(tx, name, (result,), kwargs)
            return result
        elif name == "isdisjoint":
            assert not kwargs
            assert len(args) == 1
            return variables.UserFunctionVariable(
                polyfills.set_isdisjoint
            ).call_function(tx, [self, args[0]], {})
        elif name == "intersection":
            assert not kwargs
            assert len(args) == 1
            return variables.UserFunctionVariable(
                polyfills.set_intersection
            ).call_function(tx, [self, args[0]], {})
        elif name == "union":
            assert not kwargs
            assert len(args) == 1
            return variables.UserFunctionVariable(polyfills.set_union).call_function(
                tx, [self, args[0]], {}
            )
        elif name == "difference":
            assert not kwargs
            assert len(args) == 1
            return variables.UserFunctionVariable(
                polyfills.set_difference
            ).call_function(tx, [self, args[0]], {})
        elif (
            name == "update"
            and len(args) == 1
            and isinstance(
                args[0],
                (
                    SetVariable,
                    ListVariable,
                    TupleVariable,
                ),
            )
            and self.mutable_local
        ):
            if isinstance(args[0], (ListVariable, TupleVariable)):
                arg = SetVariable(args[0].unpack_var_sequence(tx))
            else:
                arg = args[0]
            return super().call_method(tx, "update", (arg,), kwargs)
        elif name == "remove":
            assert not kwargs
            assert len(args) == 1
            if args[0] not in self:
                unimplemented("key does not exist")
            return super().call_method(tx, "pop", args, kwargs)
        elif name == "discard":
            assert not kwargs
            assert len(args) == 1
            if args[0] in self:
                return super().call_method(tx, "pop", args, kwargs)
            else:
                return ConstantVariable.create(value=None)
        return super().call_method(tx, name, args, kwargs)

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        raise RuntimeError("Illegal to getitem on a set")


class FrozensetVariable(SetVariable):
    """Tracks a frozenset; all mutating set methods are rejected."""

    def __init__(
        self,
        items: List[VariableTracker],
        **kwargs,
    ) -> None:
        super().__init__(items, **kwargs)

    def debug_repr(self):
        if not self.items:
            return "frozenset()"
        else:
            return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"

    @property
    def set_items(self):
        return self.items.keys()

    def python_type(self):
        return frozenset

    def as_python_constant(self):
        return {k.vt.as_python_constant() for k in self.set_items}

    def reconstruct(self, codegen):
        """Emit bytecode evaluating ``frozenset({<elements>})``.

        The callable (with the NULL required by the calling convention) must be
        on the stack *below* its argument, so ``frozenset`` is loaded first,
        then the elements are collected with BUILD_SET and passed as the single
        argument of the call.
        """
        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
                    codegen.create_load_global("frozenset", add=True),
                ]
            )
        )
        codegen.foreach([x.vt for x in self.set_items])
        codegen.extend_output(
            [
                create_instruction("BUILD_SET", arg=len(self.set_items)),
                *create_call_function(1, False),
            ]
        )

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> "VariableTracker":
        if name in ["add", "pop", "update", "remove", "discard", "clear"]:
            raise RuntimeError(f"Illegal call_method {name} on a frozenset")
        return super().call_method(tx, name, args, kwargs)


class DictView(VariableTracker):
    """
    Models _PyDictViewObject

    This is an "abstract" class. Subclasses will override kv and the items method
    """

    kv: Optional[str] = None

    def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
        super().__init__(**kwargs)
        assert self.kv in ("keys", "values")
        assert isinstance(dv_dict, ConstDictVariable)
        self.dv_dict = dv_dict

    @property
    def view_items(self):
        # Calls dict.keys() / dict.values() on the underlying items mapping.
        return getattr(self.dv_dict.items, self.kv)()

    @property
    def view_items_vt(self):
        # Returns an iterable of the unpacked items
        # Implement in the subclasses
        raise NotImplementedError

    def unpack_var_sequence(self, tx):
        def unwrap(x):
            # Keys are stored wrapped in _HashableTracker; values are plain VTs.
            return x.vt if self.kv == "keys" else x

        return [unwrap(x) for x in self.view_items]

    def reconstruct(self, codegen):
        codegen(self.dv_dict)
        codegen.load_method(self.kv)
        codegen.call_method(0)

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        if name == "__len__":
            return self.dv_dict.call_method(tx, name, args, kwargs)
        return super().call_method(tx, name, args, kwargs)


class DictKeys(DictView):
    kv = "keys"

    @property
    def set_items(self):
        return set(self.view_items)

    @property
    def view_items_vt(self):
        # Returns an iterable of the unpacked items
        return [x.vt for x in self.view_items]

    def python_type(self):
        return dict_keys

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        if name == "__contains__":
            return self.dv_dict.call_method(tx, name, args, kwargs)
        return super().call_method(tx, name, args, kwargs)


class DictValues(DictView):
    # DictValues is an iterable but cannot be compared.
    kv = "values"

    @property
    def view_items_vt(self):
        return list(self.view_items)

    def python_type(self):
        return dict_values


def _is_matching_transformers_cls(cls) -> bool:
    """True if ``cls`` subclasses transformers' ModelOutput (if transformers is loaded)."""
    mod = sys.modules.get("transformers.file_utils")
    if mod is None:
        mod = sys.modules.get("transformers.utils.generic")
    return mod is not None and issubclass(cls, mod.ModelOutput)


def _is_matching_diffusers_cls(cls) -> bool:
    """True if ``cls`` subclasses diffusers' BaseOutput (if diffusers is loaded)."""
    mod = sys.modules.get("diffusers.utils")
    return mod is not None and issubclass(cls, mod.BaseOutput)


def _call_hasattr_customobj(
    self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
    """Shared method between DataClassVariable and CustomizedDictVariable where items are attrs"""
    if tx.output.side_effects.is_attribute_mutation(self):
        try:
            result = tx.output.side_effects.load_attr(self, name, deleted_ok=True)
            return variables.ConstantVariable.create(
                not isinstance(result, variables.DeletedVariable)
            )
        except KeyError:
            pass
    if name in self.items or hasattr(self.user_cls, name):
        return ConstantVariable(True)
    elif istype(self.mutable_local, MutableLocal) and self.source is None:
        # Something created locally can't have any extra fields on it
        return ConstantVariable(False)
    elif self.source:
        # Maybe add a guard
        try:
            example = tx.output.root_tx.get_example_value(self.source)
            install_guard(
                AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
            )
            return ConstantVariable(hasattr(example, name))
        except KeyError:
            pass
    unimplemented(
        f"hasattr({self.__class__.__name__}, {name}) {self.mutable_local} {self.source}"
    )


class CustomizedDictVariable(ConstDictVariable):
    """Tracks OrderedDict subclasses and HF ModelOutput / diffusers BaseOutput objects."""

    @staticmethod
    def is_matching_cls_hf(cls):
        return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)

    @staticmethod
    def is_matching_cls(cls):
        # True if using default OrderedDict.__init__ and did not implement __post_init__
        if (
            issubclass(cls, collections.OrderedDict)
            and cls is not collections.OrderedDict
            and cls.__init__ is collections.OrderedDict.__init__
            and not hasattr(cls, "__post_init__")
        ):
            return True
        # hack for HF usecase:
        #   assume dataclass annotation for ModelOutput subclass
        #   assume self.create is AA to ModelOutput.__post_init__
        return CustomizedDictVariable.is_matching_cls_hf(cls)

    @classmethod
    def is_matching_object(cls, obj):
        return cls.is_matching_cls(type(obj))

    # called from user_defined.py
    # when is_matching_cls(cls) is true
    @classmethod
    def create(cls, user_cls, args, kwargs, options):
        # avoid tracing when returning ModelOutput from forward func
        for attr_name in ("__init__", "__post_init__", "__setattr__", "__setitem__"):
            if hasattr(user_cls, attr_name):
                fn = getattr(user_cls, attr_name)
                assert callable(fn), f"expect callable attr {attr_name}"
                if hasattr(fn, "__code__"):
                    skip_code(fn.__code__)

        if dataclasses.is_dataclass(user_cls):
            # @dataclass CustomDict(a=1, b=2)
            bound = inspect.signature(user_cls).bind(*args, **kwargs)
            bound.apply_defaults()

            def make_var(x):
                if isinstance(x, VariableTracker):
                    return x
                elif ConstantVariable.is_literal(x):
                    return ConstantVariable.create(x)
                else:
                    unimplemented(
                        "expect VariableTracker or ConstantVariable.is_literal"
                    )

            bound_args = {}
            if cls.is_matching_cls_hf(user_cls):
                # Skip none
                for k, v in bound.arguments.items():
                    if isinstance(v, ConstantVariable) and v.value is None or v is None:
                        continue
                    bound_args[k] = v
            else:
                bound_args = bound.arguments

            items = {
                ConstantVariable.create(k): make_var(v) for k, v in bound_args.items()
            }
        elif not args:
            # CustomDict(a=1, b=2) in the general (non-dataclass) case.
            items = {ConstantVariable.create(k): v for k, v in kwargs.items()}
        elif len(args) == 1 and isinstance(args[0], ConstDictVariable) and not kwargs:
            # CustomDict({'a': 1, 'b': 2})
            items = args[0].items
        else:
            unimplemented("custom dict init with args/kwargs unimplemented")

        return cls(items, user_cls, **options)

    # called from builder.py
    @classmethod
    def wrap(cls, builder, obj):
        user_cls = type(obj)

        if not cls.is_matching_cls_hf(user_cls):
            unimplemented("custom non-hf dict subclass wrap unimplemented")

        items = builder.__class__(tx=builder.tx, source=builder.source)(
            collections.OrderedDict(obj)
        ).items

        keys = [f.name for f in dataclasses.fields(user_cls)]
        for key in keys:
            # __init__ function of a dataclass might not have yet defined the key
            if hasattr(obj, key):
                val = getattr(obj, key)
                var = builder.__class__(
                    tx=builder.tx, source=AttrSource(builder.source, key)
                )(val)
                if val is not None:
                    key = ConstantVariable.create(key)
                    items[key] = var
        return cls(items, user_cls)

    def __init__(self, items, user_cls, **options) -> None:
        super().__init__(items, user_cls, **options)
        assert self.is_matching_cls(user_cls)

    def as_proxy(self):
        raise NotImplementedError

    # 'RETURN_VALUE triggered compile'
    # called from torch/_dynamo/codegen.py
    def reconstruct(self, codegen):
        is_hf_model_output = self.is_matching_cls_hf(self.user_cls)

        def gen_fn1():
            # If the user class is a ModelOutput, then wrap the instance creation in
            # torch._dynamo.disable(). Even though we mark the __post_init__ as skip
            # in `create` function, this is not enough. TorchDynamo can still get
            # triggered on the child functions of __post_init__. This upsets export.
            # Since, we know that ModelOutput __post_init__ is not worth optimizing,
            # we just wrap the instance creation in torch._dynamo.disable(),
            # regardless whether its export or not.
            if is_hf_model_output:
                # load torch._dynamo.disable
                def gen_fn2():
                    codegen.append_output(codegen.create_load_global("torch", add=True))
                    codegen.append_output(codegen.create_load_attr("_dynamo"))
                    codegen.append_output(codegen.create_load_attr("disable"))

                codegen.add_push_null(gen_fn2)

            codegen.extend_output([codegen._create_load_const(self.user_cls)])

            if is_hf_model_output:
                # Wrap user_cls with disable
                codegen.extend_output(create_call_function(1, False))

        codegen.add_push_null(gen_fn1)

        # All the keys are just wrapped strings
        d = self.keys_as_python_constant()
        codegen.foreach(d.values())
        keys = tuple(d.keys())
        codegen.extend_output(codegen.create_call_function_kw(len(keys), keys, False))

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        fn = getattr(self.user_cls, name)
        source = None if self.source is None else AttrSource(self.source, name)

        if hasattr(fn, "__objclass__") and fn.__objclass__ in (
            dict,
            collections.OrderedDict,
        ):
            # for python dict method without overridden
            return super().call_method(tx, name, args, kwargs)
        elif name in (
            "__getitem__",
            "to_tuple",
            "__setitem__",
            "__setattr__",
            "__post_init__",
        ):
            # for user overridden method
            return tx.inline_user_function_return(
                variables.UserFunctionVariable(fn, source=source),
                [self] + list(args),
                kwargs,
            )
        elif fn is getattr(collections.OrderedDict, name, None):
            return super().call_method(tx, name, args, kwargs)

        unimplemented(f"custom dict: call_method unimplemented name={name}")

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        name_vt = ConstantVariable.create(name)
        if name_vt in self:
            return self.call_method(tx, "__getitem__", [name_vt], {})
        if dataclasses.is_dataclass(self.user_cls):
            defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)}
            if name in defaults:
                assert variables.ConstantVariable.is_literal(defaults[name])
                return variables.ConstantVariable.create(defaults[name])
        return super().var_getattr(tx, name)

    call_hasattr = _call_hasattr_customobj


@functools.lru_cache(None)
def _install_PretrainedConfig_patch():
    import transformers

    # We need to monkeypatch transformers here, sadly.
    # TODO(voz): Upstream to transformers lib

    def _dynamo_overriden_transformers_eq(self, other):
        if not hasattr(other, "__dict__"):
            return False
        return self.__dict__ == other.__dict__

    transformers.configuration_utils.PretrainedConfig.__eq__ = (
        _dynamo_overriden_transformers_eq
    )


class HFPretrainedConfigVariable(VariableTracker):
    """
    Hack for HuggingFace PretrainedConfig
    """

    @staticmethod
    def is_matching_cls(cls):
        mod = sys.modules.get("transformers.configuration_utils")
        is_match = mod is not None and issubclass(cls, mod.PretrainedConfig)

        # Lazily install monkeypatch the first time we see it in dynamo
        if is_match:
            _install_PretrainedConfig_patch()
        return is_match

    @classmethod
    def is_matching_object(cls, obj):
        return cls.is_matching_cls(type(obj))

    def __init__(self, obj, **kwargs) -> None:
        super().__init__(**kwargs)
        self.obj = obj
        assert self.is_matching_cls(type(obj))

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        from .builder import VariableBuilder

        try:
            attr_value = getattr(self.obj, name)
            attr_source = AttrSource(self.source, name)
            return VariableBuilder(tx, attr_source)(attr_value)

        except AttributeError:
            # This tracker stores the wrapped config in ``self.obj``; it has no
            # ``value`` attribute, so referencing one here would raise a second
            # AttributeError instead of the intended graph break.
            unimplemented(f"getattr({self.obj}, {name})")

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        return variables.ConstantVariable.create(hasattr(self.obj, name))


class PythonSysModulesVariable(VariableTracker):
    """Special case for sys.modules.

    Without this we will guard on the exact set of modules imported in the
    lifetime of the python program.
    """

    def python_type(self):
        return dict

    def reconstruct(self, codegen):
        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
                    codegen.create_load_python_module(sys),
                    codegen.create_load_attr("modules"),
                ]
            )
        )

    def call_method(
        self,
        tx: "InstructionTranslator",
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ):
        if name == "__getitem__":
            return self.call_getitem(tx, *args, **kwargs)
        elif name == "get":
            return self.call_get(tx, *args, **kwargs)
        elif name == "__contains__":
            return self.call_contains(tx, *args, **kwargs)
        unimplemented(f"sys.modules.{name}(*{args}, **{kwargs})")

    def _contains_helper(self, tx: "InstructionTranslator", key: VariableTracker):
        """Check ``key`` against sys.modules and guard only on that membership."""
        k = key.as_python_constant()
        has_key = k in sys.modules
        install_guard(
            self.make_guard(
                functools.partial(GuardBuilder.DICT_CONTAINS, key=k, invert=not has_key)
            )
        )
        return k, has_key

    def call_contains(self, tx: "InstructionTranslator", key: VariableTracker):
        k, has_key = self._contains_helper(tx, key)
        return ConstantVariable.create(value=has_key)

    def call_get(
        self,
        tx: "InstructionTranslator",
        key: VariableTracker,
        default: Optional[VariableTracker] = None,
    ):
        from .builder import VariableBuilder

        k, has_key = self._contains_helper(tx, key)

        if has_key:
            return VariableBuilder(
                tx,
                GetItemSource(self.source, k),
            )(sys.modules[k])

        if default is not None:
            return default

        return ConstantVariable.create(value=None)

    def call_getitem(self, tx: "InstructionTranslator", key: VariableTracker):
        from .builder import VariableBuilder

        k, has_key = self._contains_helper(tx, key)
        return VariableBuilder(
            tx,
            GetItemSource(self.source, k),
        )(sys.modules[k])