# mypy: ignore-errors

import contextlib
import functools
import inspect
import itertools
import logging
import math
import operator
import types
from collections import defaultdict, OrderedDict
from collections.abc import KeysView
from typing import Dict, List, TYPE_CHECKING

import torch
from torch import sym_float, sym_int
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

from .. import config, variables
from ..exc import (
    AttributeMutationError,
    unimplemented,
    Unsupported,
    UserError,
    UserErrorType,
)
from ..guards import GuardBuilder, install_guard
from ..replay_record import DummyModule
from ..source import AttrSource, GetItemSource, is_constant_source, TypeSource
from ..utils import (
    check_constant_args,
    check_numpy_ndarray_args,
    check_unspec_or_constant_args,
    check_unspec_python_args,
    does_not_override_dict_iter_methods,
    extract_fake_example_value,
    get_fake_value,
    guard_if_dyn,
    is_wrapper_or_member_descriptor,
    istype,
    numpy_operator_wrapper,
    proxy_args_kwargs,
    tensortype_to_dtype,
)
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
from .ctx_manager import EventVariable, StreamVariable
from .dicts import (
    ConstDictVariable,
    DefaultDictVariable,
    DictView,
    FrozensetVariable,
    is_hashable,
    SetVariable,
)
from .lists import (
    BaseListVariable,
    ListIteratorVariable,
    ListVariable,
    SizeVariable,
    TupleIteratorVariable,
    TupleVariable,
)
from .tensor import (
    FakeItemVariable,
    supported_comparison_ops,
    SymNodeVariable,
    TensorVariable,
    UnspecializedPythonVariable,
)
from .user_defined import UserDefinedObjectVariable, UserDefinedVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


log = logging.getLogger(__name__)


# In-place operator -> its out-of-place equivalent. Used by
# BuiltinVariable._handle_insert_op_in_graph to de-sugar an in-place op whose
# first argument is a ConstantVariable, so the graph records the
# out-of-place op instead.
IN_PLACE_DESUGARING_MAP = {
    operator.iadd: operator.add,
    operator.isub: operator.sub,
    operator.imul: operator.mul,
    operator.ifloordiv: operator.floordiv,
    operator.itruediv: operator.truediv,
    operator.imod: operator.mod,
    # `@=` de-sugars to `@`; mapping imatmul to itself would make the
    # desugaring a no-op for this operator.
    operator.imatmul: operator.matmul,
    operator.ilshift: operator.lshift,
    operator.irshift: operator.rshift,
    operator.ipow: operator.pow,
    operator.iand: operator.and_,
    operator.ior: operator.or_,
    operator.ixor: operator.xor,
}


class BuiltinVariable(VariableTracker):
    """Tracks a call target that is a Python builtin or an ``operator`` function.

    The wrapped callable is stored in ``self.fn``. ``call_function`` routes
    each call to a dispatch function built by ``_make_handler``. Those
    dispatch functions are cached in ``call_function_handler_cache``, keyed
    on ``fn``, the argument VariableTracker types, and whether kwargs were
    passed.
    """

    _SENTINEL = object()
    _nonvar_fields = {
        "fn",
        *VariableTracker._nonvar_fields,
    }

    @classmethod
    def create_with_source(cls, value, source):
        """Wrap ``value`` and install a ``BUILTIN_MATCH`` guard on ``source``."""
        install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH))
        return cls(value, source=source)

    @staticmethod
    @functools.lru_cache(None)
    def _constant_fold_functions():
        """Return the set of callables that may be evaluated eagerly on constants.

        In addition to the explicit list, this includes every op in
        ``supported_comparison_ops`` and every builtin-function object found
        in the ``math`` module namespace. The result is memoized by
        ``lru_cache``.
        """
        fns = {
            abs,
            all,
            any,
            bool,
            callable,
            chr,
            divmod,
            float,
            getattr,
            int,
            len,
            max,
            min,
            ord,
            pow,
            repr,
            round,
            str,
            str.format,
            sum,
            type,
            operator.abs,
            operator.pos,
            operator.neg,
            operator.not_,
            operator.truth,
            operator.invert,
            operator.pow,
            operator.mul,
            operator.matmul,
            operator.floordiv,
            operator.truediv,
            operator.mod,
            operator.add,
            operator.sub,
            operator.getitem,
            operator.length_hint,
            operator.lshift,
            operator.rshift,
            operator.and_,
            operator.or_,
            operator.xor,
            operator.ipow,
            operator.imul,
            operator.imatmul,
            operator.ifloordiv,
            operator.itruediv,
            operator.imod,
            operator.iadd,
            operator.isub,
            operator.ilshift,
            operator.irshift,
            operator.iand,
            operator.ixor,
            operator.ior,
            operator.index,
        }
        from .tensor import supported_comparison_ops

        fns.update(supported_comparison_ops.values())
        fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt)))
        return fns

    def can_constant_fold_through(self):
        """Return True if ``self.fn`` is in ``_constant_fold_functions()``."""
        return self.fn in self._constant_fold_functions()

    @staticmethod
    @functools.lru_cache(None)
    def _fx_graph_functions():
        """Return the ``operator`` functions that ``_make_handler`` may emit
        directly as FX ``call_function`` nodes (see ``can_insert_in_graph``)."""
        fns = {
            operator.abs,
            operator.pos,
            operator.neg,
            operator.not_,
            operator.invert,
            operator.pow,
            operator.mul,
            operator.matmul,
            operator.floordiv,
            operator.truediv,
            operator.mod,
            operator.add,
            operator.lt,
            operator.gt,
            operator.ge,
            operator.le,
            operator.ne,
            operator.eq,
            operator.sub,
            operator.getitem,
            operator.length_hint,
            operator.lshift,
            operator.rshift,
            operator.and_,
            operator.or_,
            operator.xor,
            operator.ipow,
            operator.imul,
            operator.imatmul,
            operator.ifloordiv,
            operator.itruediv,
            operator.imod,
            operator.iadd,
            operator.isub,
            operator.ilshift,
            operator.irshift,
            operator.iand,
            operator.ixor,
            operator.ior,
        }
        return fns

    @staticmethod
    @functools.lru_cache(None)
    def _binops():
        """Return ``{op: ([forward, reverse, in-place] magic names, in-place op)}``."""
        # function -> ([forward name, reverse name, in-place name], in-place op)
        fns = {
            operator.add: (["__add__", "__radd__", "__iadd__"], operator.iadd),
            operator.sub: (["__sub__", "__rsub__", "__isub__"], operator.isub),
            operator.mul: (["__mul__", "__rmul__", "__imul__"], operator.imul),
            operator.truediv: (
                ["__truediv__", "__rtruediv__", "__itruediv__"],
                operator.itruediv,
            ),
            operator.floordiv: (
                ["__floordiv__", "__rfloordiv__", "__ifloordiv__"],
                operator.ifloordiv,
            ),
            operator.mod: (["__mod__", "__rmod__", "__imod__"], operator.imod),
            pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow),
            operator.pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow),
            operator.lshift: (
                ["__lshift__", "__rlshift__", "__ilshift__"],
                operator.ilshift,
            ),
            operator.rshift: (
                ["__rshift__", "__rrshift__", "__irshift__"],
                operator.irshift,
            ),
            # NB: The following binary operators are not supported for now, since the
            # corresponding magic methods aren't defined on SymInt / SymFloat:
            # operator.matmul
            # divmod
            # operator.and_
            # operator.or_
            # operator.xor
        }
        return fns

    @staticmethod
    @functools.lru_cache(None)
    def _binop_handlers():
        """Build the per-operator table of ``((type_a, type_b), handler)`` entries.

        Covers the ops from ``_binops()``, their in-place variants, and every
        op in ``supported_comparison_ops``. ``_find_binop_handler`` consults
        the table in registration order.
        """
        # Multiple dispatch mechanism defining custom binop behavior for certain type
        # combinations. Handlers are attempted in order, and will be used if the type checks
        # match. They are expected to have the signature:
        # fn(tx, arg0: VariableTracker, arg1: VariableTracker) -> VariableTracker
        from .dicts import DictKeys, SetVariable
        from .functions import BaseUserFunctionVariable, UserFunctionVariable
        from .nn_module import NNModuleVariable
        from .tensor import supported_const_comparison_ops
        from .torch import BaseTorchVariable
        from .user_defined import (
            UserDefinedClassVariable,
            UserDefinedObjectVariable,
            UserDefinedVariable,
        )

        # Override table contains: op_fn -> [list of handlers]
        op_handlers = {}
        for (
            op,
            (magic_method_names, in_place_op),
        ) in BuiltinVariable._binops().items():
            op_handlers[op] = []
            op_handlers[in_place_op] = []

            forward_name, reverse_name, inplace_name = magic_method_names

            # User-defined args (highest precedence)
            def user_defined_handler(
                tx,
                a,
                b,
                *,
                forward_name=forward_name,
                reverse_name=reverse_name,
            ):
                # Manually handle reversing logic if needed (e.g. call __radd__)

                # TODO: If we expand this to handle tensor args, we need to manually
                # handle cases like this:
                #
                # class A(int):
                #     def __radd__(self, other):
                #         print("woof")
                # torch.randn(3) + A(3)
                #
                # In this example, A.__radd__() is not called -> nothing is printed, because
                # Tensor.__add__ only does a subtype test against int, ignoring the subclass.
                # To be fully correct, we should not call A.__radd__() here, and there may be
                # other cases to reason about and add exceptions for.
                if isinstance(a, UserDefinedVariable):
                    return a.call_method(tx, forward_name, [b], {})
                else:
                    return b.call_method(tx, reverse_name, [a], {})

            op_handlers[op].append(
                ((UserDefinedVariable, VariableTracker), user_defined_handler)
            )
            op_handlers[op].append(
                ((VariableTracker, UserDefinedVariable), user_defined_handler)
            )

            def user_defined_inplace_handler(
                tx: "InstructionTranslator", a, b, *, forward_name=inplace_name
            ):
                return a.call_method(tx, forward_name, [b], {})

            op_handlers[in_place_op].append(
                ((UserDefinedVariable, VariableTracker), user_defined_inplace_handler)
            )
            op_handlers[in_place_op].append(
                ((VariableTracker, UserDefinedVariable), user_defined_inplace_handler)
            )

            # Dynamic shape args
            def dynamic_handler(tx: "InstructionTranslator", a, b, *, fn=op):
                from .builder import wrap_fx_proxy

                return wrap_fx_proxy(
                    tx,
                    tx.output.create_proxy(
                        "call_function", fn, *proxy_args_kwargs([a, b], {})
                    ),
                )

            op_handlers[op].append(
                ((SymNodeVariable, VariableTracker), dynamic_handler)
            )
            op_handlers[op].append(
                ((VariableTracker, SymNodeVariable), dynamic_handler)
            )

            # NB: Prefer out-of-place op when calling in-place op to generate valid graph
            op_handlers[in_place_op].append(
                ((SymNodeVariable, VariableTracker), dynamic_handler)
            )
            op_handlers[in_place_op].append(
                ((VariableTracker, SymNodeVariable), dynamic_handler)
            )

        # Special cases - lower precedence but still prefer these over constant folding

        # List-like addition (e.g. [1, 2] + [3, 4])
        def tuple_add_handler(tx: "InstructionTranslator", a, b):
            return TupleVariable([*a.items, *b.unpack_var_sequence(tx)])

        def size_add_handler(tx: "InstructionTranslator", a, b):
            return SizeVariable([*a.items, *b.unpack_var_sequence(tx)])

        list_like_addition_handlers = [
            # NB: Prefer the tuple-specific logic over base logic because of
            # some SizeVariable weirdness. Specifically, the tuple-specific logic
            # drops the subclass type (e.g. SizeVariable) and returns TupleVariables.
            (
                (SizeVariable, SizeVariable),
                size_add_handler,
            ),
            (
                (TupleVariable, TupleVariable),
                tuple_add_handler,
            ),
            (
                (TupleVariable, ConstantVariable),
                tuple_add_handler,
            ),
            (
                (ConstantVariable, TupleVariable),
                lambda tx, a, b: TupleVariable(
                    [*a.unpack_var_sequence(tx), *b.items],
                ),
            ),
            (
                (
                    ListVariable,
                    (BaseListVariable, ConstantVariable, ListIteratorVariable),
                ),
                lambda tx, a, b: ListVariable(
                    [*a.items, *b.unpack_var_sequence(tx)], mutable_local=MutableLocal()
                ),
            ),
            (
                (BaseListVariable, BaseListVariable),
                lambda tx, a, b: type(a)([*a.items, *b.items]),
            ),
        ]
        op_handlers[operator.add].extend(list_like_addition_handlers)

        def list_iadd_handler(tx: "InstructionTranslator", a, b):
            if not a.mutable_local or not b.has_unpack_var_sequence(tx):
                # Handler doesn't apply
                return None

            seq = b.unpack_var_sequence(tx)
            tx.output.side_effects.mutation(a)
            a.items.extend(seq)
            return a

        list_like_iadd_handlers = [
            (
                (ListVariable, VariableTracker),
                list_iadd_handler,
            ),
            (
                (TupleVariable, TupleVariable),
                tuple_add_handler,
            ),
            (
                (TupleVariable, ConstantVariable),
                tuple_add_handler,
            ),
        ]
        op_handlers[operator.iadd].extend(list_like_iadd_handlers)

        # List-like expansion (e.g. [1, 2, 3] * 3)
        def expand_list_like(tx: "InstructionTranslator", lst, const):
            if isinstance(lst, ConstantVariable):
                lst, const = const, lst
            return lst.__class__(
                items=lst.items * const.as_python_constant(),
                mutable_local=MutableLocal(),
            )

        list_like_expansion_handlers = [
            ((ListVariable, ConstantVariable), expand_list_like),
            ((TupleVariable, ConstantVariable), expand_list_like),
            ((ConstantVariable, ListVariable), expand_list_like),
            ((ConstantVariable, TupleVariable), expand_list_like),
        ]
        op_handlers[operator.mul].extend(list_like_expansion_handlers)

        size_or_tuple = (SizeVariable, TupleVariable)
        has_set_items = (SetVariable, DictKeys)

        def create_cmp_op_handlers(op):
            """Return the ordered handler list for comparison operator ``op``."""

            def compare_by_value(tx: "InstructionTranslator", a, b):
                return ConstantVariable(op(a.value, b.value))

            result = [((ConstantVariable, ConstantVariable), compare_by_value)]

            if op in supported_const_comparison_ops.values():
                # Tensor is None, List is not None, etc
                none_result = op(object(), None)
                if op.__name__.startswith("is_"):

                    def never(tx: "InstructionTranslator", a, b):
                        return ConstantVariable(none_result)

                    obj_op_none = never
                    none_op_obj = never
                else:

                    def obj_op_none(
                        tx: "InstructionTranslator", a, b: ConstantVariable
                    ):
                        if b.value is None or b.value is True or b.value is False:
                            return ConstantVariable(none_result)

                    def none_op_obj(
                        tx: "InstructionTranslator", a: ConstantVariable, b
                    ):
                        if a.value is None or a.value is True or a.value is False:
                            return ConstantVariable(none_result)

                types_that_are_never_none = (
                    TensorVariable,
                    SymNodeVariable,
                    NNModuleVariable,
                    BaseListVariable,
                    UserDefinedVariable,
                    BaseUserFunctionVariable,
                    ConstDictVariable,
                    BaseTorchVariable,
                )
                result.extend(
                    [
                        (
                            (types_that_are_never_none, ConstantVariable),
                            obj_op_none,
                        ),
                        (
                            (ConstantVariable, types_that_are_never_none),
                            none_op_obj,
                        ),
                    ]
                )

            def list_compare_nocheck(tx: "InstructionTranslator", left, right):
                return BaseListVariable.list_compare(tx, op, left, right)

            def list_compare_check(tx: "InstructionTranslator", left, right):
                if type(left) is not type(
                    right
                ):  # Mismatch in BaseListVariable subclasses
                    unimplemented(f"{op.__name__}({left}, {right})")
                return BaseListVariable.list_compare(tx, op, left, right)

            def compare_set_items(tx: "InstructionTranslator", left, right):
                return ConstantVariable(op(left.set_items, right.set_items))

            def compare_via_method(tx: "InstructionTranslator", left, right):
                return left.call_method(tx, f"__{op.__name__}__", [right], {})

            if op.__name__.startswith("is_"):
                compare_user_defined = compare_by_value
            else:
                compare_user_defined = compare_via_method

            op_var = BuiltinVariable(op)
            result.extend(
                [
                    (
                        (
                            (UserFunctionVariable, BuiltinVariable),
                            (UserFunctionVariable, BuiltinVariable),
                        ),
                        lambda tx, a, b: ConstantVariable(op(a.fn, b.fn)),
                    ),
                    (
                        (
                            NNModuleVariable,
                            NNModuleVariable,
                        ),
                        lambda tx, a, b: ConstantVariable(
                            op(
                                tx.output.get_submodule(a.module_key),
                                tx.output.get_submodule(b.module_key),
                            )
                        ),
                    ),
                    ((size_or_tuple, size_or_tuple), list_compare_nocheck),
                    (
                        (variables.BaseListVariable, variables.BaseListVariable),
                        list_compare_check,
                    ),
                    ((has_set_items, has_set_items), compare_set_items),
                    (
                        (UserDefinedObjectVariable, UserDefinedObjectVariable),
                        compare_user_defined,
                    ),
                    (
                        (UserDefinedClassVariable, UserDefinedClassVariable),
                        compare_user_defined,
                    ),
                    (
                        (
                            (StreamVariable, EventVariable, ConstantVariable),
                            (StreamVariable, EventVariable, ConstantVariable),
                        ),
                        compare_by_value,
                    ),
                    (
                        (TensorVariable, VariableTracker),
                        op_var._comparison_with_tensor,
                    ),
                    (
                        (VariableTracker, TensorVariable),
                        op_var._comparison_with_tensor,
                    ),
                    (
                        (SymNodeVariable, VariableTracker),
                        op_var._comparison_with_symnode,
                    ),
                    (
                        (VariableTracker, SymNodeVariable),
                        op_var._comparison_with_symnode,
                    ),
                ]
            )

            if op.__name__.startswith("is_"):

                def handle_is(tx: "InstructionTranslator", left, right):
                    # If the two objects are of different type, we can safely return False
                    # and True for `is` and `is not`, respectively
                    if type(left) is not type(right):
                        return ConstantVariable.create(op.__name__ != "is_")

                result.append(((VariableTracker, VariableTracker), handle_is))

            return result

        for op in supported_comparison_ops.values():
            assert callable(op)
            assert op not in op_handlers
            op_handlers[op] = create_cmp_op_handlers(op)

        return op_handlers

    @staticmethod
    def _find_binop_handler(op, a_type, b_type):
        """Return the handlers for ``op`` whose type pattern matches the arg types.

        Matches are returned in registration order (possibly empty). Returns
        ``None`` when ``op`` has no entry in ``_binop_handlers()``.
        """
        handlers = BuiltinVariable._binop_handlers().get(op)
        if handlers is None:
            return None

        matches = []
        for (type1, type2), handler in handlers:
            if issubclass(a_type, type1) and issubclass(b_type, type2):
                matches.append(handler)
        return matches

    def can_insert_in_graph(self):
        """Return True if ``self.fn`` is in ``_fx_graph_functions()``."""
        return self.fn in self._fx_graph_functions()

    def __init__(self, fn, **kwargs) -> None:
        super().__init__(**kwargs)
        self.fn = fn

    def __str__(self) -> str:
        if self.fn is None:
            name = "None"
        else:
            name = self.fn.__name__

        return f"{self.__class__.__name__}({name})"

    def as_python_constant(self):
        return self.fn

    def as_proxy(self):
        """Map ``bool``/``int``/``float`` to their torch dtypes; otherwise defer to the base class."""
        DTYPE = {
            bool: torch.bool,
            int: torch.int64,
            float: torch.float64,
        }
        if self.fn in DTYPE:
            return DTYPE[self.fn]
        return super().as_proxy()

    def reconstruct(self, codegen):
        """Emit a global load of ``self.fn`` by name.

        Only valid for functions from the ``builtins`` module whose name is
        not shadowed in the frame's globals (both asserted).
        """
        name = self.fn.__name__
        assert self.fn.__module__ == "builtins"
        assert name not in codegen.tx.f_globals, "shadowed global"
        codegen.append_output(codegen.create_load_global(name, False, add=True))

    def constant_args(self, *args, **kwargs):
        return check_constant_args(args, kwargs)

    def tensor_args(self, *args):
        """Return True if any arg is a TensorVariable.

        Returns False immediately if any arg is a GetAttrVariable.
        """
        any_tensor = False
        for arg in args:
            if isinstance(arg, variables.GetAttrVariable):
                return False
            any_tensor = any_tensor or isinstance(arg, variables.TensorVariable)
        return any_tensor

    def tensor_args_type(self, arg_types):
        """Type-level counterpart of ``tensor_args`` operating on classes."""
        any_tensor = False
        for arg_type in arg_types:
            if issubclass(arg_type, variables.GetAttrVariable):
                return False
            any_tensor = any_tensor or issubclass(arg_type, variables.TensorVariable)
        return any_tensor

    def python_and_tensor_constant_only(self, *args, **kwargs):
        """Return True if every tensor arg has a constant source and all other
        args are constants (tensor args without a source count as non-constant)."""
        tensor_args = []
        non_tensor_args = []
        for i in itertools.chain(args, kwargs.values()):
            if isinstance(i, variables.TensorVariable):
                tensor_args.append(i)
            else:
                non_tensor_args.append(i)
        return all(
            is_constant_source(t.source) if t.source is not None else False
            for t in tensor_args
        ) and self.constant_args(*non_tensor_args)

    @staticmethod
    def unwrap_unspec_args_kwargs(args, kwargs):
        """Convert every arg and kwarg value to its Python constant."""
        return [x.as_python_constant() for x in args], {
            k: v.as_python_constant() for k, v in kwargs.items()
        }

    def has_constant_handler(self, args, kwargs):
        """Return True if ``self.fn`` can be constant-folded for these args."""
        return self.can_constant_fold_through() and check_unspec_or_constant_args(
            args, kwargs
        )

    @staticmethod
    def _make_handler(fn, arg_types: List[type], has_kwargs: bool):
        """Build the dispatch function for calls to ``fn`` with the given arg types.

        Candidate handlers are collected in this order:
        1. Lazy args: realize them and re-enter ``call_function``.
        2. Exception classes: build an ``ExceptionVariable``.
        3. FX-insertable ops: insert them into the graph.
        4. Registered binary-op handlers.
        5. A ``call_<fn name>`` method on this class.
        6. Constant folding.

        Cases 1 and 2, and case 3 when tensor args are present, return a
        handler directly. Otherwise the returned dispatcher tries each
        candidate in order, returns the first truthy result, and falls back
        to ``unimplemented``.
        """
        from .builder import SourcelessBuilder
        from .lazy import LazyVariableTracker

        obj = BuiltinVariable(fn)
        handlers = []

        if any(issubclass(t, LazyVariableTracker) for t in arg_types):
            return lambda tx, args, kwargs: obj.call_function(
                tx, [v.realize() for v in args], kwargs
            )

        if inspect.isclass(fn) and issubclass(fn, Exception):

            def create_exception_class_object(
                tx: "InstructionTranslator", args, kwargs
            ):
                if fn is AssertionError and not all(
                    isinstance(x, variables.ConstantVariable)
                    and isinstance(x.value, str)
                    for x in args
                ):
                    unimplemented("assert with non-string message")

                return variables.ExceptionVariable(fn, args, **kwargs)

            return create_exception_class_object

        if obj.can_insert_in_graph() and not (
            fn is operator.getitem
            and not issubclass(arg_types[0], variables.TensorVariable)
        ):
            if obj.tensor_args_type(arg_types):
                return obj._handle_insert_op_in_graph
            elif has_kwargs:
                # need runtime check for kwargs
                handlers.append(obj._handle_insert_op_in_graph)

        # Handle binary ops (e.g. __add__ / __radd__, __iadd__, etc.)
        # NB: Tensor args are handled above and not here
        if len(arg_types) == 2 and not has_kwargs:
            # Try to find a handler for the arg types; otherwise, fall through to constant handler
            binop_handlers = BuiltinVariable._find_binop_handler(fn, *arg_types)
            if not binop_handlers:
                pass
            elif len(binop_handlers) == 1:
                (binop_handler,) = binop_handlers
                handlers.append(lambda tx, args, _: binop_handler(tx, *args))
            else:

                def call_binop_handlers(tx: "InstructionTranslator", args, _):
                    # Loop variable named so it does not shadow the outer `fn`.
                    for binop_fn in binop_handlers:
                        rv = binop_fn(tx, *args)
                        if rv:
                            return rv

                handlers.append(call_binop_handlers)

        self_handler = getattr(obj, f"call_{fn.__name__}", None)
        if self_handler:

            def call_self_handler(tx: "InstructionTranslator", args, kwargs):
                try:
                    result = self_handler(tx, *args, **kwargs)
                    if result is not None:
                        return result
                except TypeError:
                    # Check if binding is bad. inspect signature bind is expensive.
                    # So check only when handler call fails.
                    try:
                        inspect.signature(self_handler).bind(tx, *args, **kwargs)
                    except TypeError as e:
                        has_constant_handler = obj.has_constant_handler(args, kwargs)
                        if not has_constant_handler:
                            log.warning(
                                "incorrect arg count %s %s and no constant handler",
                                self_handler,
                                e,
                            )
                            unimplemented(
                                f"invalid handler args {self_handler} {args} {kwargs}"
                            )
                    else:
                        # The signature binds, so the TypeError came from inside
                        # the handler itself: propagate it.
                        raise
                except Unsupported as exc:
                    has_constant_handler = obj.has_constant_handler(args, kwargs)
                    if not has_constant_handler:
                        raise
                    # Actually, we will handle this just fine
                    exc.remove_from_stats()

            handlers.append(call_self_handler)

        if obj.can_constant_fold_through():
            builder = SourcelessBuilder.create

            if (
                all(issubclass(x, ConstantVariable) for x in arg_types)
                and not has_kwargs
            ):

                def constant_fold_handler(tx: "InstructionTranslator", args, kwargs):
                    # fast path
                    try:
                        res = fn(
                            *[x.as_python_constant() for x in args],
                        )
                    except Exception as exc:
                        unimplemented(f"constant fold exception: {repr(exc)}")
                    return builder(tx, res)

            else:

                def constant_fold_handler(tx: "InstructionTranslator", args, kwargs):
                    # path with a runtime check
                    if check_unspec_or_constant_args(args, kwargs):
                        try:
                            res = fn(
                                *[x.as_python_constant() for x in args],
                                **{
                                    k: v.as_python_constant() for k, v in kwargs.items()
                                },
                            )
                        except Exception as exc:
                            unimplemented(f"constant fold exception: {repr(exc)}")
                        return builder(tx, res)

            handlers.append(constant_fold_handler)

        error_msg = f"builtin: {fn.__name__} {arg_types} {has_kwargs}"
        if len(handlers) == 0:
            return lambda *args: unimplemented(error_msg)
        elif len(handlers) == 1:
            (handler,) = handlers

            def builtin_dispatch(tx: "InstructionTranslator", args, kwargs):
                rv = handler(tx, args, kwargs)
                if rv:
                    return rv
                unimplemented(error_msg)

        else:

            def builtin_dispatch(tx: "InstructionTranslator", args, kwargs):
                # Loop variable named so it does not shadow the outer `fn`.
                for candidate in handlers:
                    rv = candidate(tx, args, kwargs)
                    if rv:
                        return rv
                unimplemented(error_msg)

        return builtin_dispatch

    def _handle_insert_op_in_graph(self, tx: "InstructionTranslator", args, kwargs):
        """Emit ``self.fn`` as an FX ``call_function`` node and wrap the result.

        Returns None (letting dispatch continue) when kwargs are given but no
        argument is a tensor. A ``NotImplementedError`` raised while building
        the node is turned into ``unimplemented``.
        """
        from .builder import wrap_fx_proxy, wrap_fx_proxy_cls

        if kwargs and not self.tensor_args(*args, *kwargs.values()):
            return

        fn = self.fn
        try:
            # Constant fold for constant tensor and python constants
            if self.python_and_tensor_constant_only(*args, **kwargs):
                from ..bytecode_transformation import unique_id
                from .functions import invoke_and_store_as_constant

                return invoke_and_store_as_constant(
                    tx, fn, unique_id(fn.__name__), args, kwargs
                )

            if fn in IN_PLACE_DESUGARING_MAP and isinstance(
                args[0], variables.ConstantVariable
            ):
                # In-place operators like += usually mutate tensor
                # values, but in the edge case of immutable values they
                # re-bind the variable.
                #
                # The easiest way to keep the graph consistent in this
                # scenario is to de-sugar eagerly.
                fn, args = IN_PLACE_DESUGARING_MAP[fn], [args[0], args[1]]

            if fn is operator.getitem and isinstance(args[1], SymNodeVariable):
                # Standard indexing will force specialization due to
                # __index__.  Rewrite as a regular torch op which will
                # trace fine
                fn, args = torch.select, [
                    args[0],
                    variables.ConstantVariable.create(0),
                    args[1],
                ]

            # Interaction between ndarray and tensors:
            #   We prefer the tensor op whenever there are tensors involved
            # (exact-type check: subclasses of TensorVariable do not count).
            if check_numpy_ndarray_args(args, kwargs) and not any(
                type(arg) is variables.TensorVariable for arg in args
            ):
                proxy = tx.output.create_proxy(
                    "call_function",
                    numpy_operator_wrapper(fn),
                    *proxy_args_kwargs(args, kwargs),
                )

                return wrap_fx_proxy_cls(variables.NumpyNdarrayVariable, tx, proxy)

            proxy = tx.output.create_proxy(
                "call_function",
                fn,
                *proxy_args_kwargs(args, kwargs),
            )
            if any(isinstance(arg, FakeItemVariable) for arg in args):
                return wrap_fx_proxy_cls(
                    FakeItemVariable,
                    tx,
                    proxy,
                )
            elif check_unspec_python_args(args, kwargs):
                _args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs)
                raw_value = fn(*_args, **_kwargs)

                need_unwrap = any(
                    x.need_unwrap
                    for x in itertools.chain(args, kwargs.values())
                    if isinstance(x, variables.UnspecializedPythonVariable)
                )

                return wrap_fx_proxy_cls(
                    UnspecializedPythonVariable,
                    tx,
                    proxy,
                    raw_value=raw_value,
                    need_unwrap=need_unwrap,
                )
            elif all(isinstance(x, SymNodeVariable) for x in args):
                return SymNodeVariable.create(tx, proxy, None)
            else:
                # Work around for vision_maskrcnn due to precision difference
                # specialize the dividend when float divide by tensor
                if fn is operator.truediv and isinstance(
                    args[0], variables.UnspecializedPythonVariable
                ):
                    args[0] = args[0].convert_to_constant(tx)
                return wrap_fx_proxy(tx, proxy)

        except NotImplementedError:
            unimplemented(f"partial tensor op: {self} {args} {kwargs}")

    # Class-level cache shared by all instances:
    # (fn, *arg types[, True if kwargs]) -> dispatch function from _make_handler.
    call_function_handler_cache = {}

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Dispatch a call to ``self.fn`` via a cached, type-specialized handler."""
        if kwargs:
            kwargs = {k: v.realize() for k, v in kwargs.items()}
            key = (self.fn, *(type(x) for x in args), True)
        else:
            key = (self.fn, *(type(x) for x in args))

        handler = self.call_function_handler_cache.get(key)
        if not handler:
            self.call_function_handler_cache[key] = handler = self._make_handler(
                self.fn, [type(x) for x in args], bool(kwargs)
            )
        return handler(tx, args, kwargs)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle ``object.__setattr__``, ``object.__new__`` and ``dict.fromkeys``
        specially; everything else goes to the base implementation."""
        if self.fn is object and name == "__setattr__":
            assert len(args) == 3
            assert len(kwargs) == 0
            obj, name_var, val = args
            obj = obj.realize()
            if (
                isinstance(obj, UserDefinedObjectVariable)
                and tx.output.side_effects.is_attribute_mutation(obj)
                and name_var.is_python_constant()
            ):
                return obj.method_setattr_standard(tx, name_var, val)
        if self.fn is object and name == "__new__":
            assert len(args) == 1
            assert len(kwargs) == 0
            return tx.output.side_effects.track_object_new_from_user_defined_class(
                args[0]
            )
        if self.fn is dict and name == "fromkeys":
            return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)
        return super().call_method(tx, name, args, kwargs)

    def _call_int_float(self, tx: "InstructionTranslator", arg):
        """Trace ``int(x)`` / ``float(x)`` on a tensor or SymNode as ``sym_int`` /
        ``sym_float`` (tensors are first reduced with ``.item()``).

        Returns None for any other argument so dispatch can continue.
        """
        # Handle cases like int(torch.seed())
        # Also handle sym_float to sym_int cases
        if isinstance(arg, (SymNodeVariable, variables.TensorVariable)):
            if isinstance(arg, variables.TensorVariable):
                item = arg.call_method(tx, "item", [], {})
            else:
                item = arg
            fn_ = sym_int if self.fn is int else sym_float
            from torch._dynamo.variables.builder import wrap_fx_proxy

            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    fn_,
                    (item.as_proxy(),),
                    {},
                ),
            )

    call_int = _call_int_float
    call_float = _call_int_float

    def call_str(self, tx: "InstructionTranslator", arg):
        """Handle ``str()`` of a user-defined function or object.

        Returns None for other argument kinds, or when the object's ``__str__``
        cannot be evaluated or inlined.
        """
        # Handle `str` on a user defined function or object
        if isinstance(arg, (variables.UserFunctionVariable)):
            return variables.ConstantVariable.create(value=str(arg.fn))
        elif isinstance(arg, (variables.UserDefinedObjectVariable)):
            # Check if object has __str__ method
            if hasattr(arg.value, "__str__"):
                str_method = arg.value.__str__
            elif hasattr(arg.value, "__repr__"):
                # account for __repr__ functions when __str__ is absent
                str_method = arg.value.__repr__
            else:
                unimplemented("user defined object has no __str__ or __repr__ method")

            if type(arg.value).__str__ is object.__str__:
                # Rely on the object str method
                try:
                    return variables.ConstantVariable.create(value=str_method())
                except AttributeError:
                    # Graph break
                    return
            elif is_wrapper_or_member_descriptor(str_method):
                unimplemented(f"{type(arg.value)} has a C/C++ based str method")
            else:
                # Overrides for custom str method
                # Pass method as function to call tx.inline_user_function_return
                bound_method = str_method.__func__

                try:
                    # Only supports certain function types
                    user_func_variable = variables.UserFunctionVariable(bound_method)
                except AssertionError as e:
                    # Won't be able to inline the str method, return to avoid graph break
                    log.warning("Failed to create UserFunctionVariable: %s", e)
                    return

                # Inline the user function
                return tx.inline_user_function_return(user_func_variable, [arg], {})

    def _call_min_max(self, tx: "InstructionTranslator", *args):
        """Dispatch ``min``/``max`` by arity.

        Handles a single force-unpackable iterable, exactly two values, or
        more than two values. Returns None in any other case.
        """
        if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
            items = args[0].force_unpack_var_sequence(tx)
            return self._call_min_max_seq(tx, items)
        elif len(args) == 2:
            return self._call_min_max_binary(tx, args[0], args[1])
        elif len(args) > 2:
            return self._call_min_max_seq(tx, args)

    def _call_min_max_seq(self, tx: "InstructionTranslator", items):
        """Left-fold ``_call_min_max_binary`` over a non-empty sequence."""
        assert len(items) > 0
        if len(items) == 1:
            return items[0]

        return functools.reduce(functools.partial(self._call_min_max_binary, tx), items)

    def _call_min_max_binary(self, tx: "InstructionTranslator", a, b):
        """Trace ``min(a, b)`` / ``max(a, b)``.

        With tensor args, emits ``torch.clamp`` / ``torch.maximum`` /
        ``torch.minimum`` (or the numpy equivalents for ndarrays). With SymNode
        args, emits ``torch.sym_max`` / ``torch.sym_min``. Returns None
        otherwise, including when either input is None from an earlier failed
        reduction step.
        """
        if a is None or b is None:
            # a or b could be none if we reduce and _call_min_max_binary failed
            # to return something
            return
        if self.tensor_args(a, b):
            if not isinstance(a, variables.TensorVariable):
                a, b = b, a
            assert isinstance(a, variables.TensorVariable)

            # result of an item call is a scalar convert to a tensor
            if isinstance(a, FakeItemVariable):
                a = variables.TorchInGraphFunctionVariable(torch.tensor).call_function(
                    tx, [a], {}
                )

            # Dynamic input does not get resolved, rather, gets stored as call_function
            if isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable):
                from .builder import wrap_fx_proxy_cls

                return wrap_fx_proxy_cls(
                    type(a),
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_function",
                        self.fn,
                        *proxy_args_kwargs([a, b], {}),
                    ),
                )

            # convert min/max to torch ops
            if b.is_python_constant():
                if isinstance(a, variables.NumpyNdarrayVariable):
                    import numpy as np

                    fn = variables.NumpyVariable(np.clip)
                else:
                    fn = variables.TorchInGraphFunctionVariable(torch.clamp)
                kwargs = {"min": b} if (self.fn is max) else {"max": b}
                result = fn.call_function(tx, [a], kwargs)
            else:
                if isinstance(a, variables.NumpyNdarrayVariable):
                    import numpy as np

                    fn = {max: np.maximum, min: np.minimum}[self.fn]
                    fn = variables.NumpyVariable(fn)
                else:
                    fn = {max: torch.maximum, min: torch.minimum}[self.fn]
                    fn = variables.TorchInGraphFunctionVariable(fn)
                result = fn.call_function(tx, [a, b], {})

            # return unspec if both a, b are unspec or const
            if all(
                isinstance(
                    i,
                    (
                        variables.UnspecializedPythonVariable,
                        variables.ConstantVariable,
                    ),
                )
                for i in [a, b]
            ):
                if any(isinstance(val, FakeItemVariable) for val in [a, b]):
                    return variables.FakeItemVariable.from_tensor_variable(result)

                if b.is_python_constant():
                    raw_b = b.as_python_constant()
                else:
                    raw_b = b.raw_value
                if self.fn is max:
                    raw_res = max(a.raw_value, raw_b)
                else:
                    raw_res = min(a.raw_value, raw_b)

                need_unwrap = any(
                    x.need_unwrap
                    for x in [a, b]
                    if isinstance(x, variables.UnspecializedPythonVariable)
                )
                return variables.UnspecializedPythonVariable.from_tensor_variable(
                    result, raw_res, need_unwrap
                )
            # otherwise return tensor
            else:
                return result
        elif isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable):
            fn = torch.sym_max if self.fn is max else torch.sym_min
            proxy = tx.output.create_proxy(
                "call_function", fn, *proxy_args_kwargs([a, b], {})
            )
            return SymNodeVariable.create(tx, proxy, None)

    call_min = _call_min_max
    call_max = _call_min_max

    def call_abs(self, tx: "InstructionTranslator", arg: "VariableTracker"):
        """Trace ``abs(arg)`` as ``arg.__abs__()``."""
        # Call arg.__abs__()
        abs_method = BuiltinVariable(getattr).call_function(
            tx, [arg, ConstantVariable.create("__abs__")], {}
        )
        return abs_method.call_function(tx, [], {})

    def call_pos(self, tx: "InstructionTranslator", arg: "VariableTracker"):
        """Trace ``+arg`` as ``arg.__pos__()``."""
        # Call arg.__pos__()
        pos_method = BuiltinVariable(getattr).call_function(
            tx, [arg, ConstantVariable.create("__pos__")], {}
        )
        return pos_method.call_function(tx, [], {})

    def call_index(self, tx: "InstructionTranslator", arg: "VariableTracker"):
        """Evaluate ``operator.index(arg)`` to a constant.

        Graph-breaks on tensors; dynamic values are specialized via
        ``guard_if_dyn`` first.
        """
        if isinstance(arg, variables.TensorVariable):
            unimplemented("unsupported index(tensor)")

        arg = guard_if_dyn(arg)
        constant_value = operator.index(arg)
        return variables.ConstantVariable.create(constant_value)

    def call_round(self, tx: "InstructionTranslator", arg, *args, **kwargs):
        """Trace ``round(arg, ...)`` as ``arg.__round__(...)``."""
        # Call arg.__round__()
        round_method = BuiltinVariable(getattr).call_function(
            tx, [arg, ConstantVariable.create("__round__")], {}
        )
        return round_method.call_function(tx, args, kwargs)

    def call_range(self, tx: "InstructionTranslator", *args):
        """Build a RangeVariable from constant args, or from SymNode args after
        specializing them with ``guard_if_dyn``; otherwise return None."""
        if check_unspec_or_constant_args(args, {}):
            return variables.RangeVariable(args)
        elif self._dynamic_args(*args):
            args = [
                variables.ConstantVariable.create(guard_if_dyn(arg)) for arg in args
            ]
            return variables.RangeVariable(args)
        # None no-ops this handler and lets the driving function proceed
        return None

    def _dynamic_args(self, *args, **kwargs):
        """Return True if any positional or keyword arg is a SymNodeVariable."""
        return any(isinstance(x, SymNodeVariable) for x in args) or any(
            isinstance(x, SymNodeVariable) for x in kwargs.values()
        )

    def call_slice(self, tx: "InstructionTranslator", *args):
        return variables.SliceVariable(args)

    def _dyn_proxy(self, tx: "InstructionTranslator", *args, **kwargs):
        """Emit ``self.fn(*args, **kwargs)`` as an FX node and wrap the result."""
        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_function", self.fn, *proxy_args_kwargs(args, kwargs)
            ),
        )

    # NOTE must handle IteratorVariable separately!
1230 def _call_iter_tuple_list( 1231 self, tx: "InstructionTranslator", obj=None, *args, **kwargs 1232 ): 1233 assert not isinstance(obj, variables.IteratorVariable) 1234 1235 if self._dynamic_args(*args, **kwargs): 1236 return self._dyn_proxy(tx, *args, **kwargs) 1237 1238 cls = variables.BaseListVariable.cls_for(self.fn) 1239 if obj is None: 1240 return cls( 1241 [], 1242 mutable_local=MutableLocal(), 1243 ) 1244 elif obj.has_unpack_var_sequence(tx): 1245 if obj.source and not is_constant_source(obj.source): 1246 if isinstance(obj, TupleIteratorVariable): 1247 install_guard( 1248 obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN) 1249 ) 1250 else: 1251 if ( 1252 getattr(obj, "source", False) 1253 and isinstance(obj, ConstDictVariable) 1254 and not istype(obj, SetVariable) 1255 ): 1256 tx.output.guard_on_key_order.add(obj.source.name()) 1257 1258 install_guard(obj.source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) 1259 1260 return cls( 1261 list(obj.unpack_var_sequence(tx)), 1262 mutable_local=MutableLocal(), 1263 ) 1264 1265 def _call_tuple_list(self, tx, obj=None, *args, **kwargs): 1266 if isinstance(obj, variables.IteratorVariable): 1267 cls = variables.BaseListVariable.cls_for(self.fn) 1268 return cls( 1269 list(obj.force_unpack_var_sequence(tx)), 1270 mutable_local=MutableLocal(), 1271 ) 1272 else: 1273 return self._call_iter_tuple_list(tx, obj, *args, **kwargs) 1274 1275 def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs): 1276 if isinstance(obj, variables.IteratorVariable): 1277 ret = obj 1278 else: 1279 # Handle the case where we are iterating over a tuple, list or iterator 1280 ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs) 1281 1282 if ret is None: 1283 # If the object doesn't implement a __iter__ method, it will be an error in eager mode when calling iter on it anyway. 1284 # If the object implements a __iter__ method, inlining effectively forwards the call to another iter call 1285 # (e.g. 
when __iter__ just returns iter(self.list)) or return a user-defined iterator. 1286 return obj.call_method(tx, "__iter__", args, kwargs) 1287 return ret 1288 1289 call_tuple = _call_tuple_list 1290 call_list = _call_tuple_list 1291 1292 def call_callable(self, tx: "InstructionTranslator", arg): 1293 from .functions import BaseUserFunctionVariable 1294 from .nn_module import NNModuleVariable 1295 1296 if isinstance( 1297 arg, 1298 ( 1299 variables.UserDefinedClassVariable, 1300 BaseUserFunctionVariable, 1301 NNModuleVariable, 1302 ), 1303 ): 1304 return variables.ConstantVariable.create(True) 1305 elif isinstance(arg, UserDefinedVariable): 1306 return variables.ConstantVariable.create(callable(arg.value)) 1307 elif isinstance( 1308 arg, 1309 ( 1310 ConstantVariable, 1311 SymNodeVariable, 1312 TensorVariable, 1313 ListVariable, 1314 TupleVariable, 1315 ListIteratorVariable, 1316 ), 1317 ): 1318 return variables.ConstantVariable.create(False) 1319 1320 def call_cast(self, _, *args, **kwargs): 1321 if len(args) == 2: 1322 return args[1] 1323 1324 unimplemented(f"unsupported args to builtin cast(): {args} {kwargs}") 1325 1326 def call_dict(self, tx: "InstructionTranslator", *args, **kwargs): 1327 return BuiltinVariable.call_custom_dict(tx, dict, *args, **kwargs) 1328 1329 @staticmethod 1330 def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs): 1331 from .builder import SourcelessBuilder 1332 1333 if not kwargs: 1334 if not args: 1335 args = ({},) 1336 assert len(args) == 1 1337 arg = args[0] 1338 if isinstance(arg, dict): 1339 return ConstDictVariable(arg, user_cls, mutable_local=MutableLocal()) 1340 elif isinstance(arg, variables.ConstDictVariable): 1341 return arg.clone(user_cls=user_cls, mutable_local=MutableLocal()) 1342 elif isinstance( 1343 arg, 1344 ( 1345 ListVariable, 1346 TupleVariable, 1347 ListIteratorVariable, 1348 variables.IteratorVariable, 1349 ), 1350 ): 1351 items = dict( 1352 x.force_unpack_var_sequence(tx) 1353 for x in 
arg.force_unpack_var_sequence(tx) 1354 ) 1355 return ConstDictVariable(items, user_cls, mutable_local=MutableLocal()) 1356 elif isinstance(arg, variables.MutableMappingVariable): 1357 # This is applicable for user defined objects which seem like dict, but are not really dicts. For 1358 # example, TensorDict derives from MutableMapping. For such cases, we can directly inline the .items 1359 # method and create a new dict. 1360 if does_not_override_dict_iter_methods(type(arg.value)): 1361 # These are implemeted in C, so we will have to manually construct the items 1362 1363 if tx.output.side_effects.has_pending_mutation(arg): 1364 unimplemented( 1365 f"{user_cls.__name__}.items(): {args} {kwargs} - object is mutated" 1366 ) 1367 1368 new_dict = dict(arg.value.items()) 1369 return SourcelessBuilder.create(tx, new_dict) 1370 else: 1371 func_var = arg.var_getattr(tx, "items") 1372 if not isinstance(func_var, variables.UserFunctionVariable): 1373 unimplemented(f"{user_cls.__name__}.items(): {args} {kwargs}") 1374 out = tx.inline_user_function_return(func_var, args, kwargs) 1375 if isinstance(out, ConstDictVariable): 1376 return out 1377 return BuiltinVariable(user_cls).call_custom_dict(tx, user_cls, out) 1378 elif not args and kwargs: 1379 items = {ConstantVariable.create(k): v for k, v in kwargs.items()} 1380 return variables.ConstDictVariable( 1381 items, user_cls=user_cls, mutable_local=MutableLocal() 1382 ) 1383 unimplemented(f"{user_cls.__name__}(): {args} {kwargs}") 1384 1385 @staticmethod 1386 def call_custom_dict_fromkeys( 1387 tx: "InstructionTranslator", user_cls, *args, **kwargs 1388 ): 1389 assert user_cls in {dict, OrderedDict, defaultdict} 1390 if kwargs: 1391 # Only `OrderedDict.fromkeys` accepts `value` passed by keyword 1392 assert user_cls is OrderedDict 1393 assert len(args) == 1 and len(kwargs) == 1 and "value" in kwargs 1394 args = (*args, kwargs.pop("value")) 1395 if len(args) == 0: 1396 raise UserError(TypeError, "fromkeys expected at least 1 
argument, got 0") 1397 if len(args) == 1: 1398 args = (*args, ConstantVariable.create(None)) 1399 assert len(args) == 2 1400 arg, value = args 1401 DictVariableType = ( 1402 ConstDictVariable if user_cls is not defaultdict else DefaultDictVariable 1403 ) 1404 1405 if isinstance(arg, dict): 1406 arg = [ConstantVariable.create(k) for k in arg.keys()] 1407 return DictVariableType( 1408 dict.fromkeys(arg, value), user_cls, mutable_local=MutableLocal() 1409 ) 1410 elif arg.has_force_unpack_var_sequence(tx): 1411 keys = arg.force_unpack_var_sequence(tx) 1412 if all(is_hashable(v) for v in keys): 1413 return DictVariableType( 1414 dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal() 1415 ) 1416 unimplemented(f"{user_cls.__name__}.fromkeys(): {args} {kwargs}") 1417 1418 def call_set(self, tx: "InstructionTranslator", *args, **kwargs): 1419 # Can we merge this implementation and call_dict's one? 1420 assert not kwargs 1421 if not args: 1422 return SetVariable([], mutable_local=MutableLocal()) 1423 assert len(args) == 1 1424 arg = args[0] 1425 if isinstance(arg, variables.SetVariable): 1426 return arg.clone(mutable_local=MutableLocal()) 1427 elif arg.has_force_unpack_var_sequence(tx): 1428 items = arg.force_unpack_var_sequence(tx) 1429 return SetVariable(items, mutable_local=MutableLocal()) 1430 elif isinstance(arg, variables.UserDefinedObjectVariable) and isinstance( 1431 arg.value, KeysView 1432 ): 1433 iter_fn = arg.var_getattr(tx, "__iter__") 1434 if isinstance(iter_fn, variables.UserMethodVariable): 1435 out = tx.inline_user_function_return(iter_fn, args, kwargs) 1436 if isinstance(out, SetVariable): 1437 return out 1438 return BuiltinVariable(set).call_set(tx, out) 1439 else: 1440 unimplemented(f"set(): {args} {kwargs}") 1441 else: 1442 unimplemented(f"set(): {args} {kwargs}") 1443 1444 def call_frozenset(self, tx: "InstructionTranslator", *args, **kwargs): 1445 assert not kwargs 1446 if not args: 1447 return FrozensetVariable([]) 1448 assert len(args) == 
1 1449 arg = args[0] 1450 if isinstance(arg, variables.FrozensetVariable): 1451 return FrozensetVariable([x.vt for x in arg.set_items]) 1452 elif arg.has_unpack_var_sequence(tx): 1453 items = arg.unpack_var_sequence(tx) 1454 return FrozensetVariable(items) 1455 else: 1456 unimplemented(f"frozenset(): {args} {kwargs}") 1457 1458 def call_zip(self, tx: "InstructionTranslator", *args, **kwargs): 1459 if kwargs: 1460 assert len(kwargs) == 1 and "strict" in kwargs 1461 strict = kwargs.pop("strict", False) 1462 args = [ 1463 arg.unpack_var_sequence(tx) if arg.has_unpack_var_sequence(tx) else arg 1464 for arg in args 1465 ] 1466 return variables.ZipVariable(args, strict=strict, mutable_local=MutableLocal()) 1467 1468 def call_len(self, tx: "InstructionTranslator", *args, **kwargs): 1469 return args[0].call_method(tx, "__len__", args[1:], kwargs) 1470 1471 def call_getitem(self, tx: "InstructionTranslator", *args, **kwargs): 1472 return args[0].call_method(tx, "__getitem__", args[1:], kwargs) 1473 1474 def call_isinstance(self, tx: "InstructionTranslator", arg, isinstance_type): 1475 try: 1476 arg_type = arg.python_type() 1477 except NotImplementedError: 1478 unimplemented( 1479 f"isinstance({arg}, {isinstance_type}): can't determine type of {arg}" 1480 ) 1481 1482 isinstance_type = isinstance_type.as_python_constant() 1483 1484 if isinstance(arg, variables.TensorVariable) and arg.dtype is not None: 1485 1486 def _tensor_isinstance(tensor_var, tensor_type): 1487 def check_type(ty): 1488 if ty not in tensortype_to_dtype: 1489 example_val = arg.as_proxy().node.meta["example_value"] 1490 if ( 1491 is_traceable_wrapper_subclass(example_val) 1492 and ty is torch.nn.parameter.Parameter 1493 ): 1494 # N.B: we are calling isinstance directly on the example value. 1495 # torch.nn.Parameter has a meta-class that overrides __isinstance__, 1496 # the isinstance check here allows us to invoke that logic. 
1497 return isinstance(example_val, ty) 1498 else: 1499 return issubclass(arg.python_type(), ty) 1500 1501 dtypes = tensortype_to_dtype[ty] 1502 return arg.dtype in dtypes 1503 1504 if type(tensor_type) is tuple: 1505 return any(check_type(ty) for ty in tensor_type) 1506 else: 1507 return check_type(tensor_type) 1508 1509 return variables.ConstantVariable.create( 1510 _tensor_isinstance(arg, isinstance_type) 1511 ) 1512 # UserDefinedObject with C extensions can have torch.Tensor attributes, 1513 # so break graph. 1514 if isinstance(arg, variables.UserDefinedObjectVariable) and isinstance( 1515 arg.value, types.MemberDescriptorType 1516 ): 1517 unimplemented( 1518 f"isinstance called on UserDefinedClass {arg} {isinstance_type}" 1519 ) 1520 # handle __instancecheck__ defined in user class 1521 if ( 1522 isinstance(arg, variables.UserDefinedObjectVariable) 1523 and "__instancecheck__" in isinstance_type.__class__.__dict__ 1524 ): 1525 return variables.ConstantVariable.create( 1526 isinstance_type.__class__.__instancecheck__(isinstance_type, arg.value) 1527 ) 1528 1529 try: 1530 val = issubclass(arg_type, isinstance_type) 1531 except TypeError: 1532 val = arg_type is isinstance_type 1533 return variables.ConstantVariable.create(val) 1534 1535 def call_issubclass(self, tx: "InstructionTranslator", left_ty, right_ty): 1536 """Checks if first arg is subclass of right arg""" 1537 try: 1538 left_ty_py = left_ty.as_python_constant() 1539 right_ty_py = right_ty.as_python_constant() 1540 except NotImplementedError: 1541 unimplemented( 1542 f"call_issubclass args not constant left_ty: {left_ty}, right_ty: {right_ty}" 1543 ) 1544 1545 return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py)) 1546 1547 def call_super(self, tx: "InstructionTranslator", a, b): 1548 return variables.SuperVariable(a, b) 1549 1550 def call_next(self, tx: "InstructionTranslator", arg: VariableTracker): 1551 try: 1552 return arg.next_variable(tx) 1553 except Unsupported as ex: 1554 if 
isinstance(arg, variables.BaseListVariable): 1555 ex.remove_from_stats() 1556 return arg.items[0] 1557 raise 1558 1559 def call_hasattr(self, tx: "InstructionTranslator", obj, attr): 1560 if attr.is_python_constant(): 1561 name = attr.as_python_constant() 1562 if isinstance(obj, variables.BuiltinVariable): 1563 return variables.ConstantVariable(hasattr(obj.fn, name)) 1564 return obj.call_hasattr(tx, name) 1565 1566 def call_map(self, tx: "InstructionTranslator", fn, *seqs): 1567 seqs = [ 1568 seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq 1569 for seq in seqs 1570 ] 1571 return variables.MapVariable(fn, seqs, mutable_local=MutableLocal()) 1572 1573 def call_filter(self, tx: "InstructionTranslator", fn, seq): 1574 if seq.has_unpack_var_sequence(tx): 1575 seq_unpacked = seq.unpack_var_sequence(tx) 1576 try: 1577 items = list( 1578 filter( 1579 lambda x: fn.call_function(tx, [x], {}).as_python_constant(), 1580 seq_unpacked, 1581 ) 1582 ) 1583 return variables.TupleVariable(items) 1584 except NotImplementedError: 1585 return 1586 1587 def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL): 1588 # Special case for sum on tuple of floats and ints 1589 if isinstance(seq, (variables.ListVariable, variables.TupleVariable)) and all( 1590 isinstance(x, variables.ConstantVariable) 1591 and isinstance(x.value, (int, float)) 1592 for x in seq.items 1593 ): 1594 if start is self._SENTINEL: 1595 return variables.ConstantVariable.create( 1596 sum(x.value for x in seq.items), 1597 ) 1598 if isinstance(start, variables.ConstantVariable) and isinstance( 1599 start.value, (int, float) 1600 ): 1601 return variables.ConstantVariable.create( 1602 sum((x.value for x in seq.items), start=start.value), 1603 ) 1604 if seq.has_force_unpack_var_sequence(tx): 1605 if start is self._SENTINEL: 1606 start = variables.ConstantVariable.create(0) 1607 items = seq.force_unpack_var_sequence(tx) 1608 return BuiltinVariable(functools.reduce).call_function( 1609 tx, 
1610 [ 1611 BuiltinVariable(operator.add), 1612 variables.TupleVariable(items), 1613 start, 1614 ], 1615 {}, 1616 ) 1617 1618 def call_reduce( 1619 self, tx: "InstructionTranslator", function, iterable, initial=_SENTINEL 1620 ): 1621 if iterable.has_force_unpack_var_sequence(tx): 1622 items = iterable.force_unpack_var_sequence(tx) 1623 if initial is self._SENTINEL: 1624 value, items = items[0], items[1:] 1625 else: 1626 value = initial 1627 for element in items: 1628 value = function.call_function(tx, [value, element], {}) 1629 return value 1630 1631 def call_getattr( 1632 self, 1633 tx: "InstructionTranslator", 1634 obj: VariableTracker, 1635 name_var: VariableTracker, 1636 default=None, 1637 ): 1638 from .. import trace_rules 1639 from . import ( 1640 ConstantVariable, 1641 GetAttrVariable, 1642 TorchInGraphFunctionVariable, 1643 UserFunctionVariable, 1644 ) 1645 from .builder import SourcelessBuilder, VariableBuilder 1646 1647 name = name_var.as_python_constant() 1648 1649 if not name_var.is_python_constant(): 1650 unimplemented("non-const getattr() name") 1651 1652 if tx.output.side_effects.is_attribute_mutation(obj): 1653 if isinstance(obj, variables.UnspecializedNNModuleVariable): 1654 if ( 1655 name 1656 in ( 1657 "named_parameters", 1658 "parameters", 1659 "named_buffers", 1660 "buffers", 1661 "named_modules", 1662 "modules", 1663 ) 1664 and obj.is_state_mutated 1665 and tx.output.side_effects.has_pending_mutation(obj) 1666 ): 1667 unimplemented( 1668 f"pending mutation on nn module, so graph breaking at {name!r} call" 1669 ) 1670 1671 if tx.output.side_effects.has_pending_mutation_of_attr(obj, name): 1672 return tx.output.side_effects.load_attr(obj, name) 1673 1674 if default is not None: 1675 hasattr_var = self.call_hasattr(tx, obj, name_var) 1676 assert hasattr_var.as_python_constant() in (True, False) 1677 if not hasattr_var.as_python_constant(): 1678 return default 1679 1680 options = {} 1681 if obj.source: 1682 source = AttrSource(obj.source, name) 
1683 options["source"] = source 1684 else: 1685 source = None 1686 1687 if name in {"__bases__", "__base__", "__flags__"}: 1688 try: 1689 value = obj.as_python_constant() 1690 if isinstance(value, type): 1691 if name == "__bases__": 1692 bases = value.__bases__ 1693 if source is not None: 1694 tuple_args = [ 1695 VariableBuilder(tx, GetItemSource(source, i))(b) 1696 for i, b in enumerate(bases) 1697 ] 1698 else: 1699 tuple_args = [ 1700 SourcelessBuilder.create(tx, b) for b in bases 1701 ] 1702 return variables.TupleVariable(tuple_args, **options) 1703 if name == "__base__": 1704 base = value.__base__ 1705 if source is not None: 1706 return VariableBuilder(tx, source)(base) 1707 return SourcelessBuilder.create(tx, base) 1708 if name == "__flags__": 1709 return ConstantVariable.create(value.__flags__) 1710 except NotImplementedError: 1711 pass 1712 1713 if isinstance(obj, variables.NNModuleVariable): 1714 return obj.var_getattr(tx, name) 1715 elif isinstance( 1716 obj, 1717 ( 1718 variables.TensorVariable, 1719 variables.NamedTupleVariable, 1720 variables.ConstantVariable, 1721 variables.DistributedVariable, 1722 variables.UserDefinedClassVariable, 1723 variables.UserDefinedObjectVariable, 1724 ), 1725 ): 1726 try: 1727 return obj.var_getattr(tx, name) 1728 except NotImplementedError: 1729 return GetAttrVariable(obj, name, **options) 1730 elif isinstance(obj, TorchInGraphFunctionVariable): 1731 # Get OpOverload from an OpOverloadPacket, e.g., torch.ops.aten.add.default. 1732 member = getattr(obj.value, name) 1733 if isinstance( 1734 member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload) 1735 ) and trace_rules.is_aten_op_or_tensor_method(member): 1736 return TorchInGraphFunctionVariable(member, **options) 1737 elif isinstance(obj, DummyModule): 1738 # TODO(mlazos) - Do we need this? 
1739 if obj.is_torch or name not in obj.value.__dict__: 1740 member = getattr(obj.value, name) 1741 else: 1742 member = obj.value.__dict__[name] 1743 1744 if config.replay_record_enabled: 1745 tx.exec_recorder.record_module_access(obj.value, name, member) 1746 1747 if source is not None: 1748 return VariableBuilder(tx, source)(member) 1749 else: 1750 return SourcelessBuilder.create(tx, member) 1751 elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"): 1752 return ConstantVariable.create(getattr(obj.fn, name)) 1753 else: 1754 try: 1755 return obj.var_getattr(tx, name) 1756 except NotImplementedError: 1757 return GetAttrVariable(obj, name, **options) 1758 1759 def call_setattr( 1760 self, 1761 tx: "InstructionTranslator", 1762 obj: VariableTracker, 1763 name_var: VariableTracker, 1764 val: VariableTracker, 1765 ): 1766 if isinstance( 1767 obj, 1768 ( 1769 variables.CustomizedDictVariable, 1770 variables.PlacementVariable, 1771 variables.UserDefinedObjectVariable, 1772 ), 1773 ): 1774 return obj.call_method(tx, "__setattr__", [name_var, val], {}) 1775 elif ( 1776 tx.output.side_effects.is_attribute_mutation(obj) 1777 and name_var.is_python_constant() 1778 ): 1779 name = name_var.as_python_constant() 1780 if isinstance(obj, variables.TensorVariable): 1781 from .builder import wrap_fx_proxy 1782 1783 if name == "requires_grad": 1784 # TODO(voz): Make it work properly 1785 unimplemented( 1786 "mutating requires_grad can introduce a new leaf from non-leaf or vice versa in " 1787 "the middle of the graph, which aot_autograd does not currently know how to handle. " 1788 ) 1789 if name == "data": 1790 # Remove the old reference in tracked fakes - if we don't do this 1791 # new .data value size and shape differences will cause 1792 # tracked fakes to produce incorrect guards. This is sound because the TensorVariable 1793 # coming out of set_() below will be a new one, and get 1794 # installed in tracked fakes. 
1795 to_remove = [] 1796 for tf in tx.output.tracked_fakes: 1797 if tf.source == obj.source: 1798 to_remove.append(tf) 1799 for tf in to_remove: 1800 tx.output.tracked_fakes.remove(tf) 1801 1802 # Step 1 - disable grads 1803 with dynamo_disable_grad(tx), torch.no_grad(): 1804 # Step 2 - call `set_` 1805 out = wrap_fx_proxy( 1806 tx, 1807 tx.output.create_proxy( 1808 "call_function", 1809 torch.Tensor.set_, 1810 *proxy_args_kwargs([obj, val], {}), 1811 ), 1812 ) 1813 1814 # Step 3 - drop the version counter - this is a step required to get 1815 # .data setting to play correctly with the autograd engine. 1816 # Essentially, dynamo is trying to faithfully preserve the (absurd) 1817 # behavior of .data= from eager mode 1818 def _lower_version_count_by_1(x): 1819 version = x._version 1820 if version > 0: 1821 version = version - 1 1822 torch._C._autograd._unsafe_set_version_counter(x, version) 1823 return x 1824 1825 tx.output.create_proxy( 1826 "call_function", 1827 _lower_version_count_by_1, 1828 (out.as_proxy(),), 1829 {}, 1830 ) 1831 _lower_version_count_by_1(obj.as_proxy().node.meta["example_value"]) 1832 # This handles options prop, guards and ends with a clone 1833 # Step 4 - replace all reference to the current object with the new one 1834 return out 1835 1836 tx.output.side_effects.store_attr(obj, name, val) 1837 if name == "_grad": 1838 tx.output.side_effects.store_attr(obj, "grad", val) 1839 1840 return val 1841 elif isinstance(obj, variables.UserDefinedObjectVariable): 1842 unimplemented( 1843 f"setattr(UserDefinedObjectVariable) {type(obj.value).__setattr__}" 1844 ) 1845 elif isinstance(obj, variables.NNModuleVariable): 1846 if not tx.output.is_root_tracer(): 1847 raise AttributeMutationError( 1848 "Can't inplace modify module params/buffers inside HigherOrderOp" 1849 ) 1850 if name_var.is_python_constant() and isinstance( 1851 val, variables.TensorVariable 1852 ): 1853 assigning_fake_val = get_fake_value(val.as_proxy().node, tx) 1854 1855 try: 1856 
getattr_var = obj.var_getattr(tx, name_var.as_python_constant()) 1857 except AttributeError: 1858 getattr_var = None 1859 1860 if isinstance(getattr_var, variables.TensorVariable): 1861 # get_fake_val will get the same fake tensor 1862 existing_fake_attr = get_fake_value(getattr_var.as_proxy().node, tx) 1863 1864 # same tensor identiy, setattr is a no-op 1865 mod_setattr = inspect.getattr_static(obj.module_type, "__setattr__") 1866 if ( 1867 existing_fake_attr is assigning_fake_val 1868 and mod_setattr is torch.nn.Module.__setattr__ 1869 ): 1870 return getattr_var 1871 1872 obj.convert_to_unspecialized(tx) 1873 # FIXME (tmanlaibaatar) this is utter hack to unblock HuggingFace export 1874 # Export generally doesn't want to allow mutations on objects directly, 1875 # but we don't have good way to do this rn. For now, we make it an undefined 1876 # behaviour and just set attributes directly on the PretrainedConfig object 1877 # for now. 1878 elif isinstance(obj, variables.dicts.HFPretrainedConfigVariable) and tx.export: 1879 if name_var.is_python_constant() and isinstance( 1880 val, variables.ConstantVariable 1881 ): 1882 setattr( 1883 obj.obj, name_var.as_python_constant(), val.as_python_constant() 1884 ) 1885 return ConstantVariable(None) 1886 1887 def call_delattr( 1888 self, 1889 tx: "InstructionTranslator", 1890 obj: VariableTracker, 1891 name_var: VariableTracker, 1892 ): 1893 return self.call_setattr(tx, obj, name_var, variables.DeletedVariable()) 1894 1895 def call_type(self, tx: "InstructionTranslator", obj: VariableTracker): 1896 from .builder import SourcelessBuilder, VariableBuilder 1897 1898 try: 1899 py_type = obj.python_type() 1900 except NotImplementedError as error: 1901 raise UserError( 1902 UserErrorType.INVALID_INPUT, 1903 str(error), 1904 case_name="unknown_python_type", 1905 ) from None 1906 1907 if obj.source is None: 1908 return SourcelessBuilder.create(tx, py_type) 1909 else: 1910 return VariableBuilder(tx, TypeSource(obj.source))(py_type) 
1911 1912 def call_reversed(self, tx: "InstructionTranslator", obj: VariableTracker): 1913 if obj.has_unpack_var_sequence(tx): 1914 items = list(reversed(obj.unpack_var_sequence(tx))) 1915 return variables.TupleVariable(items) 1916 1917 def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwargs): 1918 if obj.has_force_unpack_var_sequence(tx) and not isinstance( 1919 obj, variables.TensorVariable 1920 ): 1921 unpacked = obj.force_unpack_var_sequence(tx) 1922 if not all(x.is_python_constant() for x in unpacked): 1923 return 1924 function = kwargs.pop("key", None) 1925 reverse = kwargs.pop( 1926 "reverse", ConstantVariable.create(False) 1927 ).as_python_constant() 1928 assert len(kwargs) == 0 1929 if function: 1930 items = sorted( 1931 unpacked, 1932 key=lambda x: function.call_function( 1933 tx, [x], {} 1934 ).as_python_constant(), 1935 reverse=reverse, 1936 ) 1937 else: 1938 items = sorted( 1939 unpacked, 1940 key=lambda x: x.as_python_constant(), 1941 reverse=reverse, 1942 ) 1943 return variables.ListVariable(items) 1944 1945 # neg is a constant fold function, so we only get here if constant fold is not valid 1946 def call_neg(self, tx: "InstructionTranslator", a): 1947 if isinstance(a, SymNodeVariable): 1948 return SymNodeVariable.create( 1949 tx, 1950 (operator.neg)(a.as_proxy()), 1951 sym_num=None, 1952 ) 1953 # None no-ops this handler and lets the driving function proceed 1954 return None 1955 1956 def call_format(self, tx: "InstructionTranslator", _format_string, *args, **kwargs): 1957 format_string = _format_string.as_python_constant() 1958 return variables.StringFormatVariable.create(format_string, args, kwargs) 1959 1960 def call_id(self, tx: "InstructionTranslator", *args): 1961 if len(args) > 0 and isinstance(args[0], variables.NNModuleVariable): 1962 nn_mod_variable = args[0] 1963 mod = tx.output.get_submodule(nn_mod_variable.module_key) 1964 return variables.ConstantVariable.create(id(mod)) 1965 elif len(args) == 1 and 
isinstance( 1966 args[0], variables.UserDefinedObjectVariable 1967 ): 1968 install_guard(args[0].source.make_guard(GuardBuilder.ID_MATCH)) 1969 constant_result = id(args[0].value) 1970 return variables.ConstantVariable.create(constant_result) 1971 elif len(args) == 1 and isinstance(args[0], TensorVariable): 1972 tensor_variable = args[0] 1973 return tensor_variable.call_id(tx) 1974 else: 1975 unimplemented(f"call_id with args {args}") 1976 1977 def call_deepcopy(self, tx: "InstructionTranslator", x): 1978 unimplemented(f"copy.deepcopy {repr(x)}") 1979 1980 def _comparison_with_tensor(self, tx: "InstructionTranslator", left, right): 1981 from .builder import wrap_fx_proxy_cls 1982 from .tensor import supported_tensor_comparison_op_values 1983 1984 op = self.fn 1985 1986 if op in [operator.is_, operator.is_not]: 1987 is_result = ( 1988 isinstance(left, TensorVariable) 1989 and isinstance(right, TensorVariable) 1990 and id(extract_fake_example_value(left.as_proxy().node)) 1991 == id(extract_fake_example_value(right.as_proxy().node)) 1992 ) 1993 if op is operator.is_: 1994 return ConstantVariable.create(is_result) 1995 else: 1996 return ConstantVariable.create(not is_result) 1997 1998 if op not in supported_tensor_comparison_op_values: 1999 unimplemented(f"{op.__name__}({left}, {right})") 2000 if ( 2001 isinstance(left, TensorVariable) 2002 and isinstance(right, TensorVariable) 2003 and (left.size and right.size) is not None 2004 and left.size != right.size 2005 ): 2006 try: 2007 torch.broadcast_shapes(left.size, right.size) 2008 except RuntimeError: 2009 # not broadcastable, can't be compared 2010 unimplemented(f"{op.__name__}({left}, {right})") 2011 tensor_cls = left if isinstance(left, TensorVariable) else right 2012 proxy = tx.output.create_proxy( 2013 "call_function", op, (left.as_proxy(), right.as_proxy()), {} 2014 ) 2015 return wrap_fx_proxy_cls( 2016 type(tensor_cls), # handle Ndarrays and Tensors 2017 tx, 2018 proxy, 2019 ) 2020 2021 def 
_comparison_with_symnode(self, tx: "InstructionTranslator", left, right): 2022 from .tensor import supported_tensor_comparison_op_values 2023 2024 op = self.fn 2025 2026 if op not in supported_tensor_comparison_op_values: 2027 unimplemented(f"{op.__name__}({left}, {right})") 2028 2029 proxy = tx.output.create_proxy( 2030 "call_function", op, (left.as_proxy(), right.as_proxy()), {} 2031 ) 2032 return SymNodeVariable.create( 2033 tx, 2034 proxy, 2035 sym_num=None, 2036 ) 2037 2038 def call_and_(self, tx: "InstructionTranslator", a, b): 2039 # Rely on constant_handler 2040 if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): 2041 return None 2042 if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( 2043 b, (SymNodeVariable, ConstantVariable) 2044 ): 2045 return SymNodeVariable.create( 2046 tx, 2047 tx.output.create_proxy( 2048 "call_function", operator.and_, *proxy_args_kwargs([a, b], {}) 2049 ), 2050 sym_num=None, 2051 ) 2052 if hasattr(a, "set_items") and hasattr(b, "set_items"): 2053 return SetVariable(list(a.set_items & b.set_items)) 2054 # None no-ops this handler and lets the driving function proceed 2055 2056 def call_or_(self, tx: "InstructionTranslator", a, b): 2057 # Rely on constant_handler 2058 if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): 2059 return None 2060 if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( 2061 b, (SymNodeVariable, ConstantVariable) 2062 ): 2063 return SymNodeVariable.create( 2064 tx, 2065 tx.output.create_proxy( 2066 "call_function", operator.or_, *proxy_args_kwargs([a, b], {}) 2067 ), 2068 sym_num=None, 2069 ) 2070 if hasattr(a, "set_items") and hasattr(b, "set_items"): 2071 return SetVariable(list(a.set_items | b.set_items)) 2072 # None no-ops this handler and lets the driving function proceed 2073 return None 2074 2075 def call_not_(self, tx: "InstructionTranslator", a): 2076 if isinstance(a, SymNodeVariable): 2077 return SymNodeVariable.create( 2078 
tx, 2079 tx.output.create_proxy( 2080 "call_function", operator.not_, *proxy_args_kwargs([a], {}) 2081 ), 2082 sym_num=None, 2083 ) 2084 2085 # Unwrap the underlying ConstDictVariable 2086 if isinstance(a, DictView): 2087 a = a.dv_dict 2088 if isinstance(a, (ListVariable, ConstDictVariable)): 2089 return ConstantVariable.create(len(a.items) == 0) 2090 2091 return None 2092 2093 def call_contains( 2094 self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker 2095 ): 2096 return a.call_method(tx, "__contains__", [b], {}) 2097 2098 2099@contextlib.contextmanager 2100def dynamo_disable_grad(tx): 2101 from . import GradModeVariable 2102 2103 org_value = torch.is_grad_enabled() 2104 gmv = GradModeVariable.create(tx, False) 2105 try: 2106 gmv.enter(tx) 2107 yield 2108 finally: 2109 gmv.exit(tx) 2110