1# mypy: ignore-errors 2import collections 3import dataclasses 4import functools 5import inspect 6import itertools 7import random 8import re 9import sys 10import types 11from typing import Dict, List, Optional, TYPE_CHECKING 12 13import torch._C 14import torch._numpy as tnp 15import torch.utils._pytree as pytree 16 17from .. import config, variables 18from ..bytecode_transformation import create_call_function, create_instruction 19from ..create_parameter_op import do_not_convert_to_tracable_parameter 20from ..exc import unimplemented 21from ..guards import GuardBuilder, install_guard 22from ..mutation_guard import unpatched_nn_module_init 23from ..source import ( 24 AttrSource, 25 DefaultsSource, 26 GetItemSource, 27 ODictGetItemSource, 28 TypeSource, 29) 30from ..utils import ( 31 check_unspec_or_constant_args, 32 identity, 33 is_tensor_base_attr_getter, 34 proxy_args_kwargs, 35 set_example_value, 36) 37from .base import VariableTracker 38from .functions import ( 39 NestedUserFunctionVariable, 40 UserFunctionVariable, 41 UserMethodVariable, 42 wrap_bound_arg, 43) 44from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable 45 46 47if TYPE_CHECKING: 48 from torch._dynamo.symbolic_convert import InstructionTranslator 49 50 51class NO_SUCH_SUBOBJ: 52 pass 53 54 55class SuperVariable(VariableTracker): 56 _nonvar_fields = { 57 "specialized", 58 *VariableTracker._nonvar_fields, 59 } 60 61 def __init__(self, typevar, objvar=None, specialized=False, **kwargs) -> None: 62 super().__init__(**kwargs) 63 # typevar is the fist argument to super(). In the case where no argument 64 # is provided to super(), it is the __class__ object where 65 # the super() function is being called 66 self.typevar = typevar 67 # objvar here must be an instance or subtype of typevar. 68 # In the case where super() is called without arguments, it is the first argument 69 # to the current function where super() is called from (self for regular method, 70 # cls for a classmethod) 71 self.objvar = objvar 72 self.specialized = specialized # directly get attr from self.typevar if true 73 74 def reconstruct(self, codegen): 75 codegen.add_push_null(lambda: codegen(variables.BuiltinVariable(super))) 76 codegen(self.typevar) 77 if self.objvar is not None: 78 codegen(self.objvar) 79 codegen.extend_output(create_call_function(2, False)) 80 else: 81 codegen.extend_output(create_call_function(1, False)) 82 83 def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name): 84 assert self.objvar, "1-arg super not implemented" 85 if self.specialized: 86 return getattr(self.typevar.as_python_constant(), name) 87 search_type = self.typevar.as_python_constant() 88 89 # The rest of this function does two things: 90 # - Walk the mro to find where the attribute comes from to be 91 # able to provide accurate source 92 # - Call the getattr to get the object 93 94 # Find the class object, where the function lives. 95 # When objvar is "self", use type(self), when objvar is "cls", use it as-is 96 type_to_use = self.objvar.python_type() 97 type_to_use_source = ( 98 TypeSource(self.objvar.source) if self.objvar.source else None 99 ) 100 if issubclass(type_to_use, type): 101 type_to_use = self.objvar.value 102 type_to_use_source = self.objvar.source 103 104 source = None 105 resolved_class = None 106 resolved_attr = None 107 search_mro = type_to_use.__mro__ 108 109 try: 110 start_index = search_mro.index(search_type) + 1 111 except ValueError: 112 # Corner case where the typevar is not in the mro of the objvar 113 # https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8843-L8844 114 return getattr(super(search_type, type_to_use), name), None 115 # Implemented based on https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8812 116 # super has its getattro implementation. The key point is that instead of calling getattr, it checks the 117 # attribute in the class __dict__ 118 for index in range(start_index, len(search_mro)): 119 # Dont call getattr, just check the __dict__ of the class 120 if resolved_getattr := search_mro[index].__dict__.get(name, NO_SUCH_SUBOBJ): 121 if resolved_getattr is not NO_SUCH_SUBOBJ: 122 # Equivalent of something like type(L['self']).__mro__[1].attr_name 123 if type_to_use_source: 124 source = AttrSource( 125 GetItemSource( 126 AttrSource(type_to_use_source, "__mro__"), index 127 ), 128 name, 129 ) 130 return resolved_getattr, source 131 132 unimplemented("Unable to resolve super getattr") 133 134 def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": 135 # Check if getattr is a constant. If not, delay the actual work by 136 # wrapping the result in GetAttrVariable. Mostly super is called with a 137 # method, so most of the work is delayed to call_function. 138 # 139 # We could have just implemented a const_getattr. However, super is 140 # special when it comes to finding sources. Compared to other VTs, super 141 # requires the attr name to walk the mro and find the actual source (and 142 # not just AttrSource). 143 value, source = self._resolved_getattr_and_source(self, name) 144 if not variables.ConstantVariable.is_literal(value): 145 return GetAttrVariable(self, name) 146 if source: 147 install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH)) 148 return variables.ConstantVariable.create(value, source=source) 149 return variables.ConstantVariable.create(value) 150 151 def call_method( 152 self, 153 tx, 154 name, 155 args: "List[VariableTracker]", 156 kwargs: "Dict[str, VariableTracker]", 157 ) -> "VariableTracker": 158 inner_fn, source = self._resolved_getattr_and_source(self, name) 159 if inner_fn is object.__init__: 160 return LambdaVariable(identity) 161 elif inner_fn is torch.nn.Module.__init__: 162 objvar = self.objvar 163 from ..side_effects import AttributeMutationNew 164 165 if ( 166 isinstance(objvar, variables.UserDefinedObjectVariable) 167 and isinstance(objvar.mutable_local, AttributeMutationNew) 168 and not (args or kwargs) 169 ): 170 with do_not_convert_to_tracable_parameter(): 171 return variables.UserFunctionVariable( 172 unpatched_nn_module_init, source=source 173 ).call_function(tx, [self.objvar] + args, kwargs) 174 else: 175 unimplemented("super() nn.Module.__init__") 176 elif self.objvar.source and inner_fn is object.__new__: 177 return tx.output.side_effects.track_object_new_from_user_defined_class( 178 self.objvar 179 ) 180 elif isinstance(inner_fn, staticmethod) and isinstance( 181 inner_fn.__func__, types.FunctionType 182 ): 183 return variables.UserFunctionVariable( 184 inner_fn.__func__, source=source 185 ).call_function(tx, args, kwargs) 186 elif isinstance(inner_fn, classmethod) and isinstance( 187 inner_fn.__func__, types.FunctionType 188 ): 189 return variables.UserMethodVariable( 190 inner_fn.__func__, self.objvar, source=source 191 ).call_function(tx, args, kwargs) 192 elif isinstance(inner_fn, types.FunctionType): 193 return variables.UserFunctionVariable( 194 inner_fn, source=source 195 ).call_function(tx, [self.objvar] + args, kwargs) 196 elif isinstance(inner_fn, types.MethodType): 197 return variables.UserMethodVariable( 198 inner_fn.__func__, self.objvar, source=source 199 ).call_function(tx, args, kwargs) 200 elif ( 201 inner_fn is collections.OrderedDict.__getitem__ 202 and isinstance(self.objvar, variables.UserDefinedObjectVariable) 203 and self.objvar.source 204 and len(args) == 1 205 and len(kwargs) == 0 206 and args[0].is_python_constant() 207 ): 208 from .builder import VariableBuilder 209 210 key = args[0].as_python_constant() 211 return VariableBuilder(tx, ODictGetItemSource(self.objvar.source, key))( 212 collections.OrderedDict.__getitem__(self.objvar.value, key) 213 ) 214 elif inner_fn in ( 215 collections.OrderedDict.__setitem__, 216 object.__setattr__, 217 ) and isinstance(self.objvar, variables.CustomizedDictVariable): 218 assert not kwargs and len(args) == 2 219 return super(variables.CustomizedDictVariable, self.objvar).call_method( 220 tx, "__setitem__", args, kwargs 221 ) 222 elif inner_fn is collections.OrderedDict.__getitem__ and isinstance( 223 self.objvar, variables.CustomizedDictVariable 224 ): 225 return super(variables.CustomizedDictVariable, self.objvar).call_method( 226 tx, "__getitem__", args, kwargs 227 ) 228 elif is_standard_setattr(inner_fn) and isinstance( 229 self.objvar, UserDefinedObjectVariable 230 ): 231 return self.objvar.method_setattr_standard(tx, *args, **kwargs) 232 elif inner_fn is object.__delattr__: 233 attr = args[0] 234 try: 235 attr = attr.as_python_constant() 236 except NotImplementedError: 237 unimplemented(f"non-const delattr attr: {attr}") 238 if not tx.output.side_effects.is_attribute_mutation(self.objvar): 239 unimplemented(f"delattr({self.objvar}, {attr}, ...)") 240 241 tx.output.side_effects.store_attr( 242 self.objvar, attr, variables.DeletedVariable() 243 ) 244 return variables.ConstantVariable(None) 245 246 unimplemented(f"non-function or method super: {inner_fn}") 247 248 249class ExceptionVariable(VariableTracker): 250 def __init__(self, exc_type, args, **kwargs) -> None: 251 super().__init__(**kwargs) 252 self.exc_type = exc_type 253 self.args = args 254 255 def reconstruct(self, codegen): 256 codegen.add_push_null( 257 lambda: codegen.load_import_from("builtins", self.exc_type.__name__) 258 ) 259 codegen.foreach(self.args) 260 codegen.call_function(len(self.args), False) 261 262 263class UnknownVariable(VariableTracker): 264 """ 265 It could be anything! 266 """ 267 268 269class DelayGraphBreakVariable(UnknownVariable): 270 """ 271 Used to insert a dummy variable in the stack to do the graph break at CALL_FUNCTION. 272 """ 273 274 275class ComptimeVariable(VariableTracker): 276 """ 277 This variable is special, it lets you execute arbitrary code at 278 Dynamo compile time 279 """ 280 281 def reconstruct(self, codegen): 282 raise NotImplementedError("comptime is special form") 283 284 def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": 285 from ..comptime import comptime 286 287 # To support the comptime.print_graph convenience accessors 288 from .functions import UserFunctionVariable 289 290 return UserFunctionVariable( 291 getattr(comptime, name), source=AttrSource(self.source, name) 292 ) 293 294 def call_function( 295 self, 296 tx: "InstructionTranslator", 297 args: "List[VariableTracker]", 298 kwargs: "Dict[str, VariableTracker]", 299 ) -> "VariableTracker": 300 from ..comptime import ComptimeContext 301 302 # TODO: support an expression form as well 303 304 assert not kwargs 305 # Second argument is runtime lambda, ignored 306 assert len(args) <= 2 307 fn = args[0] 308 if isinstance(fn, UserFunctionVariable): 309 fn.get_function()(ComptimeContext(tx)) 310 elif isinstance(fn, NestedUserFunctionVariable): 311 # We have to manually bind the freevars ourselves 312 code = fn.get_code() 313 assert not fn.closure, ( 314 "comptime function must not have free variables, " 315 f"but these variables were free: {code.co_freevars}" 316 ) 317 func = types.FunctionType( 318 code, 319 fn.f_globals, 320 fn.fn_name.as_python_constant(), 321 tuple(fn.defaults.items) if fn.defaults else None, 322 # We could automatically promote free variables into 323 # ComptimeVar but this is confusing if you access 324 # a free variable that we actually DO have the runtime 325 # value for 326 # tuple(make_cell(ComptimeVar(i)) for i in fn.closure.items) 327 (), 328 ) 329 func(ComptimeContext(tx)) 330 else: 331 raise RuntimeError(f"unsupported argument to comptime: {type(fn)}") 332 333 return variables.ConstantVariable.create(None) 334 335 336class ClosureVariable(UnknownVariable): 337 _nonvar_fields = { 338 "name", 339 *UnknownVariable._nonvar_fields, 340 } 341 342 def __init__(self, name, **kwargs) -> None: 343 super().__init__(**kwargs) 344 self.name = name 345 346 def reconstruct(self, codegen): 347 codegen.append_output(codegen.create_load_closure(self.name)) 348 349 350# closure variable created by an inlined function 351class InlinedClosureVariable(UnknownVariable): 352 _nonvar_fields = { 353 "name", 354 *UnknownVariable._nonvar_fields, 355 } 356 357 def __init__(self, name, **kwargs) -> None: 358 super().__init__(**kwargs) 359 self.name = name 360 361 def reconstruct(self, codegen): 362 codegen.append_output(codegen.create_load_closure(self.name)) 363 364 365class NewCellVariable(VariableTracker): 366 def __init__(self, **kwargs) -> None: 367 super().__init__(**kwargs) 368 369 370class NewGlobalVariable(VariableTracker): 371 def __init__(self, **kwargs) -> None: 372 super().__init__(**kwargs) 373 374 375class InspectSignatureVariable(VariableTracker): 376 """represents inspect.signature(...)""" 377 378 _nonvar_fields = { 379 "signature", 380 "parameters", 381 *VariableTracker._nonvar_fields, 382 } 383 384 @staticmethod 385 def create(callable, **kwargs): 386 if kwargs: 387 unimplemented(f"inspect.signature with {kwargs}") 388 return InspectSignatureVariable( 389 callable, mutable_local=variables.base.MutableLocal() 390 ) 391 392 def __init__(self, inspected: VariableTracker, **kwargs) -> None: 393 super().__init__(**kwargs) 394 self.inspected = inspected 395 396 if isinstance(self.inspected, UserMethodVariable): 397 self.fn = self.inspected.get_function() 398 self.signature = inspect.signature(self.fn) 399 self.parameters = list(self.signature.parameters.items())[1:] 400 elif isinstance(self.inspected, UserFunctionVariable): 401 self.fn = self.inspected.get_function() 402 self.signature = inspect.signature(self.fn) 403 self.parameters = list(self.signature.parameters.items()) 404 else: 405 self.fn = self.inspected.as_python_constant() 406 self.signature = inspect.signature(self.fn) 407 self.parameters = list(self.signature.parameters.items()) 408 409 def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": 410 if name == "parameters": 411 return variables.ConstDictVariable( 412 { 413 variables.ConstantVariable.create( 414 param[0] 415 ): InspectParameterVariable(param[1]) 416 for param in self.parameters 417 }, 418 user_cls=dict, 419 ) 420 return super().var_getattr(tx, name) 421 422 def call_method( 423 self, 424 tx, 425 name, 426 args: "List[VariableTracker]", 427 kwargs: "Dict[str, VariableTracker]", 428 ) -> "VariableTracker": 429 if name == "bind": 430 if not hasattr(self.fn, "__kwdefaults__"): 431 unimplemented( 432 f"inspect.signature.bind with {self.fn} without __kwdefaults__" 433 ) 434 obj = self.signature.bind(*args, **kwargs) 435 436 # wrap function defaults in VTs 437 defaults = {} 438 if self.fn.__kwdefaults__: 439 wrap = functools.partial(wrap_bound_arg, tx=tx) 440 kwdefaults_sources = { 441 k: None 442 if self.source is None 443 else DefaultsSource(self.source, k, is_kw=True) 444 for k in self.fn.__kwdefaults__ 445 } 446 defaults = { 447 k: wrap(val=v, source=kwdefaults_sources[k]) 448 for k, v in self.fn.__kwdefaults__.items() 449 } 450 451 return InspectBoundArgumentsVariable( 452 obj, 453 defaults, 454 self, 455 ) 456 return super().call_method(tx, name, args, kwargs) 457 458 def reconstruct(self, codegen): 459 codegen.add_push_null( 460 lambda: codegen.extend_output( 461 [ 462 codegen.create_load_python_module(inspect), 463 codegen.create_load_attr("signature"), 464 ] 465 ) 466 ) 467 codegen(self.inspected) 468 codegen.extend_output(create_call_function(1, False)) 469 470 471class InspectParameterVariable(VariableTracker): 472 """represents inspect.Parameter(...)""" 473 474 def __init__(self, value, **kwargs) -> None: 475 super().__init__(**kwargs) 476 self.value = value 477 478 def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": 479 from .builder import SourcelessBuilder, VariableBuilder 480 481 try: 482 attr_value = getattr(self.value, name) 483 if self.source: 484 attr_source = AttrSource(self.source, name) 485 return VariableBuilder(tx, attr_source)(attr_value) 486 else: 487 return SourcelessBuilder.create(tx, attr_value) 488 except AttributeError: 489 unimplemented(f"getattr({self.value}, {name})") 490 491 492class InspectBoundArgumentsVariable(VariableTracker): 493 """represents inspect.signature(...).bind(...)""" 494 495 _nonvar_fields = { 496 "bound_arguments", 497 "packed_vars", 498 *VariableTracker._nonvar_fields, 499 } 500 501 # NOTE: we keep track of changes to arguments via bound_arguments_var, 502 # but we still keep a copy of the inspect.BoundArguments object in order 503 # to get the correct args/kwargs. 504 def __init__( 505 self, 506 bound_arguments: inspect.BoundArguments, 507 defaults: Dict[str, VariableTracker], 508 signature: InspectSignatureVariable, 509 **kwargs, 510 ): 511 super().__init__(**kwargs) 512 self.bound_arguments = bound_arguments 513 self.defaults = defaults 514 # used to convert from VT to tuple/dict when updating bound_arguments 515 self.packed_vars = set() 516 517 arguments_dict = {} 518 for key, val in bound_arguments.arguments.items(): 519 key_var = variables.ConstantVariable(key) 520 # convert val to VT 521 if isinstance(val, tuple): 522 arguments_dict[key_var] = variables.TupleVariable(list(val)) 523 self.packed_vars.add(key) 524 elif isinstance(val, dict): 525 self.packed_vars.add(key) 526 arguments_dict[key_var] = variables.ConstDictVariable( 527 {variables.ConstantVariable(k): v for k, v in val.items()} 528 ) 529 elif isinstance(val, VariableTracker): 530 arguments_dict[key_var] = val 531 else: 532 unimplemented( 533 "inspect.signature(...).bind(...).arguments contains non-variable/tuple/dict" 534 ) 535 536 self.bound_arguments_var = variables.ConstDictVariable( 537 arguments_dict, 538 type(bound_arguments.arguments), 539 mutable_local=variables.base.MutableLocal(), 540 ) 541 self.signature = signature 542 543 def _update_bound_arguments(self): 544 for key, val in self.bound_arguments_var.items.items(): 545 true_val = val 546 if key.underlying_value in self.packed_vars: 547 if isinstance(val, variables.TupleVariable): 548 true_val = tuple(val.items) 549 elif isinstance(val, variables.ConstDictVariable): 550 true_val = {k.underlying_value: v for k, v in val.items.items()} 551 else: 552 unimplemented( 553 "inspect.signature(...).bind(...) cannot update bound arguments" 554 ) 555 self.bound_arguments.arguments[key.underlying_value] = true_val 556 557 def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": 558 if name == "arguments": 559 return self.bound_arguments_var 560 elif name == "args": 561 self._update_bound_arguments() 562 return variables.TupleVariable(list(self.bound_arguments.args)) 563 elif name == "kwargs": 564 self._update_bound_arguments() 565 kw = { 566 variables.ConstantVariable(key): val 567 for key, val in self.bound_arguments.kwargs.items() 568 } 569 return variables.ConstDictVariable(kw) 570 elif name == "signature": 571 return self.signature 572 return super().var_getattr(tx, name) 573 574 def call_method( 575 self, 576 tx, 577 name, 578 args: "List[VariableTracker]", 579 kwargs: "Dict[str, VariableTracker]", 580 ) -> "VariableTracker": 581 if name == "apply_defaults": 582 # mimic calling apply_defaults 583 for key, val in self.defaults.items(): 584 key_var = variables.ConstantVariable(key) 585 if key_var not in self.bound_arguments_var: 586 self.bound_arguments_var.call_method( 587 tx, "__setitem__", [key_var, val], {} 588 ) 589 590 # actually apply the changes 591 self._update_bound_arguments() 592 593 return variables.ConstantVariable(None) 594 return super().call_method(tx, name, args, kwargs) 595 596 def reconstruct(self, codegen): 597 # reconstruct inspect.signature(...).bind(*bound_arguments.args, **bound_arguments.kwargs) 598 # NOTE the reconstructed inspect.signature(...) object might not be the same object 599 # as the Signature object that originally created the BoundArguments object. 600 self._update_bound_arguments() 601 602 def gen_fn(): 603 codegen(self.signature) 604 codegen.append_output(codegen.create_load_attr("bind")) 605 606 codegen.add_push_null(gen_fn, call_function_ex=True) 607 608 codegen.foreach(self.bound_arguments.args) 609 codegen.append_output( 610 create_instruction("BUILD_TUPLE", arg=len(self.bound_arguments.args)) 611 ) 612 613 for key, val in self.bound_arguments.kwargs.items(): 614 codegen.append_output(codegen.create_load_const(key)) 615 codegen(val) 616 codegen.extend_output( 617 [ 618 create_instruction("BUILD_MAP", arg=len(self.bound_arguments.kwargs)), 619 create_instruction("CALL_FUNCTION_EX", arg=1), 620 ] 621 ) 622 623 624def produce_trampoline_autograd_apply(fn_cls): 625 def trampoline_autograd_apply(*args, **kwargs): 626 return fn_cls.apply(*args, **kwargs) 627 628 trampoline_autograd_apply._origin = produce_trampoline_autograd_apply 629 return trampoline_autograd_apply 630 631 632class AutogradFunctionVariable(VariableTracker): 633 """represents a torch.autograd.Function subclass""" 634 635 _nonvar_fields = { 636 "fn_cls", 637 *VariableTracker._nonvar_fields, 638 } 639 640 def __init__(self, fn_cls, **kwargs) -> None: 641 super().__init__(**kwargs) 642 self.fn_cls = fn_cls 643 644 def call_apply(self, tx: "InstructionTranslator", args, kwargs): 645 requires_grad = False 646 647 def visit(node): 648 nonlocal requires_grad 649 if isinstance(node, variables.TensorVariable): 650 if node.requires_grad is not False: 651 requires_grad = True 652 if isinstance(node, variables.NNModuleVariable): 653 if node.is_training(tx): 654 requires_grad = True 655 656 VariableTracker.visit(visit, (args, kwargs)) 657 658 if ( 659 requires_grad 660 and torch.is_grad_enabled() 661 and config.capture_autograd_function 662 ): 663 from torch._functorch.autograd_function import ( 664 autograd_function_forward_rewritten, 665 ) 666 from torch.autograd.function import _is_setup_context_defined 667 668 forward_fn = self.fn_cls.forward 669 670 is_setup_ctx_defined = _is_setup_context_defined(self.fn_cls.setup_context) 671 if is_setup_ctx_defined: 672 # If setup_context is defined, we generate a new forward function which includes 673 # the original forward and setup_context function, and trace the new forward function. 674 forward_fn = autograd_function_forward_rewritten( 675 self.fn_cls.forward, self.fn_cls.setup_context 676 ) 677 678 vjp_fn = self.fn_cls.vjp # type: ignore[attr-defined] 679 if vjp_fn is not torch.autograd.Function.vjp: 680 unimplemented("NYI - User defind vjp") 681 682 jvp_fn = self.fn_cls.jvp # type: ignore[attr-defined] 683 if jvp_fn is not torch.autograd.Function.jvp: 684 unimplemented("NYI - User defind jvp") 685 686 from .higher_order_ops import AutogradFunctionApplyVariable 687 688 source = self.source 689 if source is None: 690 source = AttrSource( 691 tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__ 692 ) 693 694 val = AutogradFunctionApplyVariable( 695 forward_fn, 696 self.fn_cls.backward, 697 source, 698 source=AttrSource(source, member="apply"), 699 ).call_function(tx, args, kwargs) 700 # Inside of AutogradFunctionApplyVariable.call_function, we use sourceless variable wrapping 701 # the forward function, as we don't want to generate guards for new_forward.__closure__ 702 # if forward is rewritten by autograd_function_forward_rewritten. 703 # But we still need to generate correct guards for the original forward and setup_context 704 # functions, so we have to add guards manually. 705 if self.source: 706 fwd_src = AttrSource(self.source, "forward") 707 install_guard(fwd_src.make_guard(GuardBuilder.FUNCTION_MATCH)) 708 if is_setup_ctx_defined: 709 setup_ctx_src = AttrSource(self.source, "setup_context") 710 install_guard(setup_ctx_src.make_guard(GuardBuilder.FUNCTION_MATCH)) 711 712 return val 713 714 if self.source: 715 source = AttrSource(self.source, "forward") 716 else: 717 source = None 718 719 fn = self.fn_cls.forward 720 ctx = AutogradFunctionContextVariable.create(tx, args, kwargs) 721 args = [ctx, *args] 722 if isinstance(fn, types.FunctionType): 723 return variables.UserFunctionVariable(fn, source=source).call_function( 724 tx, args, kwargs 725 ) 726 elif isinstance(fn, types.MethodType): 727 return variables.UserMethodVariable( 728 fn.__func__, 729 variables.UserDefinedClassVariable(self.fn_cls), 730 source=source, 731 ).call_function(tx, args, kwargs) 732 else: 733 unimplemented( 734 f"non-function or method in subclass of torch.autograd.Function: {fn}" 735 ) 736 737 def call_backward(self, tx: "InstructionTranslator", args, kwargs): 738 fn = self.fn_cls.backward 739 self.source = AttrSource(self.source, "backward") 740 assert type(args[0].value) is torch._dynamo.external_utils.FakeBackwardCFunction 741 assert isinstance(fn, types.FunctionType) 742 743 return variables.UserFunctionVariable(fn, source=self.source).call_function( 744 tx, args, kwargs 745 ) 746 747 def call_function(self, tx: "InstructionTranslator", args, kwargs): 748 return AutogradFunctionVariable(self.fn_cls) 749 750 def call_method( 751 self, 752 tx, 753 name, 754 args: "List[VariableTracker]", 755 kwargs: "Dict[str, VariableTracker]", 756 ): 757 from ..trace_rules import is_callable_allowed 758 from .builder import wrap_fx_proxy 759 760 if name == "apply": 761 if is_callable_allowed(self.fn_cls): 762 trampoline_autograd_apply = produce_trampoline_autograd_apply( 763 self.fn_cls 764 ) 765 return wrap_fx_proxy( 766 tx=tx, 767 proxy=tx.output.create_proxy( 768 "call_function", 769 trampoline_autograd_apply, 770 *proxy_args_kwargs(args, kwargs), 771 ), 772 ) 773 else: 774 return self.call_apply(tx, args, kwargs) 775 776 elif name == "backward": 777 return self.call_backward(tx, args, kwargs) 778 else: 779 from .. import trace_rules 780 781 source = AttrSource(self.source, name) if self.source is not None else None 782 try: 783 obj = inspect.getattr_static(self.fn_cls, name) 784 except AttributeError: 785 obj = None 786 787 if isinstance(obj, staticmethod): 788 func = obj.__get__(self.fn_cls) 789 if source is not None: 790 return ( 791 trace_rules.lookup(func) 792 .create_with_source(func, source=source) 793 .call_function(tx, args, kwargs) 794 ) 795 else: 796 return trace_rules.lookup(func)(func).call_function( 797 tx, args, kwargs 798 ) 799 elif isinstance(obj, classmethod): 800 return variables.UserMethodVariable( 801 obj.__func__, self, source=source 802 ).call_function(tx, args, kwargs) 803 else: 804 unimplemented(f"Unsupported method: {name}") 805 806 807@dataclasses.dataclass 808class SavedTensorBox: 809 tensors: List[VariableTracker] = dataclasses.field(default_factory=list) 810 811 812class AutogradFunctionContextVariable(UserDefinedObjectVariable): 813 """ 814 Tracks an autograd.Function() context using mutation tracking in side_effects.py 815 """ 816 817 _nonvar_fields = { 818 "proxy", 819 "inference", 820 "saved_tensors", 821 *UserDefinedObjectVariable._nonvar_fields, 822 } 823 824 def __init__( 825 self, 826 value, 827 value_type=None, 828 inference=False, 829 proxy=None, 830 saved_tensors=None, 831 needs_input_grad=None, 832 non_differentiable=None, 833 **kwargs, 834 ) -> None: 835 super().__init__(value=value, value_type=value_type, **kwargs) 836 self.inference = inference 837 self.proxy = proxy 838 self.saved_tensors = saved_tensors 839 self.needs_input_grad = needs_input_grad 840 self.non_differentiable = non_differentiable 841 842 @staticmethod 843 def create(tx: "InstructionTranslator", args=None, kwargs=None): 844 needs_input_grad = None 845 if args and not kwargs: 846 needs_input_grad = tuple( 847 isinstance(x, variables.TensorVariable) and x.requires_grad 848 for x in args 849 ) 850 proxy = tx.output.create_proxy( 851 "call_function", torch.autograd.function.FunctionCtx, (), {} 852 ) 853 out = tx.output.side_effects.track_object_new( 854 None, 855 torch.autograd.function.FunctionCtx, 856 functools.partial( 857 AutogradFunctionContextVariable, 858 inference=True, 859 proxy=proxy, 860 saved_tensors=SavedTensorBox(), 861 needs_input_grad=needs_input_grad, 862 ), 863 {}, 864 ) 865 set_example_value(proxy.node, out.value) 866 867 return out 868 869 def as_proxy(self): 870 if self.proxy is None: 871 unimplemented("proxy not set") 872 return self.proxy 873 874 def call_method( 875 self, 876 tx, 877 name, 878 args: "List[VariableTracker]", 879 kwargs: "Dict[str, VariableTracker]", 880 ) -> "VariableTracker": 881 if name == "__setattr__": 882 return super().call_method(tx, name, args, kwargs) 883 elif name == "mark_non_differentiable": 884 assert len(kwargs) == 0 885 self.non_differentiable = proxy_args_kwargs(args, {})[0] 886 return variables.ConstantVariable.create(None) 887 888 if name != "save_for_backward": 889 unimplemented(f"autograd.Function context method: {name}") 890 if self.saved_tensors is None: 891 unimplemented( 892 "save_for_backward only supported on a newly constructed FunctionCtx" 893 ) 894 895 if not self.inference: 896 assert self.source and not kwargs 897 tx.output.side_effects.track_save_for_backward(self, args) 898 899 # In eager mode, multiple calls to .save_for_backward() will overwrite previous calls. 900 if len(self.saved_tensors.tensors) > 0: 901 self.saved_tensors.tensors = [] 902 for arg in args: 903 self.saved_tensors.tensors.append(arg) 904 return variables.ConstantVariable.create(None) 905 906 def var_getattr(self, tx: "InstructionTranslator", name): 907 if name in ["save_for_backward", "mark_non_differentiable"]: 908 return LambdaVariable( 909 lambda *args, **kwargs: self.call_method(tx, name, args, kwargs) 910 ) 911 if name == "saved_tensors" and self.saved_tensors is not None: 912 return variables.TupleVariable(list(self.saved_tensors.tensors)) 913 if name == "needs_input_grad": 914 if self.needs_input_grad is not None: 915 return variables.ConstantVariable.create(self.needs_input_grad) 916 if self.source: 917 from .builder import VariableBuilder 918 919 return VariableBuilder(tx, AttrSource(self.source, "needs_input_grad"))( 920 self.value.needs_input_grad 921 ) 922 return super().var_getattr(tx, name) 923 924 925class AutogradEngineVariable(UserDefinedObjectVariable): 926 """ 927 Represents a torch._C._ImperativeEngine instance. 928 """ 929 930 def __init__( 931 self, 932 value, 933 value_type=None, 934 **kwargs, 935 ) -> None: 936 super().__init__(value=value, value_type=value_type, **kwargs) 937 938 def call_method( 939 self, 940 tx, 941 name, 942 args: "List[VariableTracker]", 943 kwargs: "Dict[str, VariableTracker]", 944 ) -> "VariableTracker": 945 if name == "queue_callback": 946 if torch._dynamo.compiled_autograd.compiled_autograd_enabled: 947 assert ( 948 tx.one_graph 949 ), "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True" 950 return variables.UserFunctionVariable( 951 torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback, 952 source=self.source, 953 ).call_function( 954 tx, 955 (tx.output.side_effects.get_ca_final_callbacks_var(), *args), 956 kwargs, 957 ) 958 else: 959 unimplemented( 960 "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True" 961 ) 962 else: 963 unimplemented(f"torch._C._ImperativeEngine method: {name}") 964 965 966class LambdaVariable(VariableTracker): 967 def __init__(self, fn, **kwargs) -> None: 968 super().__init__(**kwargs) 969 self.fn = fn 970 971 def call_function( 972 self, 973 tx: "InstructionTranslator", 974 args: "List[VariableTracker]", 975 kwargs: "Dict[str, VariableTracker]", 976 ) -> "VariableTracker": 977 return self.fn(*args, **kwargs) 978 979 980class GetAttrVariable(VariableTracker): 981 _nonvar_fields = { 982 "name", 983 *VariableTracker._nonvar_fields, 984 } 985 986 def __init__(self, obj, name, **kwargs) -> None: 987 super().__init__(**kwargs) 988 assert isinstance(obj, VariableTracker) 989 assert isinstance(name, str) 990 self.obj = obj 991 self.name = name 992 993 def __str__(self) -> str: 994 return f"{self.__class__.__name__}({self.obj}, {self.name})" 995 996 @staticmethod 997 def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr): 998 return getattr(base_proxy, attr) 999 1000 def as_proxy(self): 1001 return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name) 1002 1003 def const_getattr(self, tx: "InstructionTranslator", name): 1004 if not isinstance(self.obj, variables.NNModuleVariable): 1005 raise NotImplementedError 1006 step1 = tx.output.get_submodule(self.obj.module_key) 1007 if self.name not in step1.__dict__: 1008 raise NotImplementedError 1009 step2 = inspect.getattr_static(step1, self.name) 1010 if name not in step2.__dict__: 1011 raise NotImplementedError 1012 return inspect.getattr_static(step2, name) 1013 1014 def reconstruct(self, codegen): 1015 codegen(self.obj) 1016 codegen.extend_output(codegen.create_load_attrs(self.name)) 1017 1018 def call_function( 1019 self, 1020 tx: "InstructionTranslator", 1021 args: "List[VariableTracker]", 1022 kwargs: "Dict[str, VariableTracker]", 1023 ) -> "VariableTracker": 1024 return self.obj.call_method(tx, self.name, args, kwargs) 1025 1026 def call_method( 1027 self, 1028 tx, 1029 name, 1030 args: List[VariableTracker], 1031 kwargs: Dict[str, VariableTracker], 1032 ) -> VariableTracker: 1033 if ( 1034 name in ("__getitem__", "get") 1035 and self.name == "__dict__" 1036 and not kwargs 1037 and args[0].is_python_constant() 1038 and isinstance( 1039 self.obj, 1040 ( 1041 variables.UserDefinedObjectVariable, 1042 variables.NNModuleVariable, 1043 variables.UserDefinedClassVariable, 1044 ), 1045 ) 1046 ): 1047 obj = self.obj 1048 key = args[0].as_python_constant() 1049 if obj.has_key_in_generic_dict(tx, key): 1050 # redirect to var_getattr on the original obj 1051 return obj.var_getattr(tx, key) 1052 1053 # Return the default value for get 1054 if name == "get": 1055 if len(args) == 2: 1056 return args[1] 1057 else: 1058 return variables.ConstantVariable(None) 1059 1060 elif ( 1061 name == "__contains__" 1062 and self.name == "__dict__" 1063 and len(args) == 1 1064 and args[0].is_python_constant() 1065 and not kwargs 1066 and isinstance( 1067 self.obj, 1068 ( 1069 variables.UserDefinedObjectVariable, 1070 variables.NNModuleVariable, 1071 variables.UserDefinedClassVariable, 1072 ), 1073 ) 1074 ): 1075 obj = self.obj 1076 key = args[0].as_python_constant() 1077 if obj.has_key_in_generic_dict(tx, key): 1078 return variables.ConstantVariable(True) 1079 else: 1080 return variables.ConstantVariable(False) 1081 1082 return super().call_method(tx, name, args, kwargs) 1083 1084 1085class MethodWrapperVariable(VariableTracker): 1086 def __init__(self, method_wrapper, **kwargs) -> None: 1087 super().__init__(**kwargs) 1088 self.method_wrapper = method_wrapper 1089 1090 def call_function( 1091 self, 1092 tx: "InstructionTranslator", 1093 args: "List[VariableTracker]", 1094 kwargs: "Dict[str, VariableTracker]", 1095 ) -> "VariableTracker": 1096 if is_tensor_base_attr_getter(self.method_wrapper) and isinstance( 1097 args[0], variables.TensorVariable 1098 ): 1099 assert len(args) == 1 and len(kwargs) == 0 1100 1101 return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__) 1102 1103 super().call_function(tx, args, kwargs) 1104 1105 def is_python_constant(self): 1106 return True 1107 1108 def as_python_constant(self): 1109 return self.method_wrapper 1110 1111 1112class GetSetDescriptorVariable(VariableTracker): 1113 def __init__(self, desc, **kwargs) -> None: 1114 super().__init__(**kwargs) 1115 self.desc = desc 1116 1117 def var_getattr(self, tx: "InstructionTranslator", name): 1118 if name == "__get__" and self.source: 1119 from .builder import VariableBuilder 1120 1121 return VariableBuilder(tx, AttrSource(self.source, "__get__"))( 1122 self.desc.__get__ 1123 ) 1124 else: 1125 return super().var_getattr(tx, name) 1126 1127 def is_python_constant(self): 1128 return True 1129 1130 def as_python_constant(self): 1131 return self.desc 1132 1133 1134class PythonModuleVariable(VariableTracker): 1135 _nonvar_fields = { 1136 "value", 1137 "is_torch", 1138 *VariableTracker._nonvar_fields, 1139 } 1140 1141 def __init__(self, value: types.ModuleType, **kwargs) -> None: 1142 super().__init__(**kwargs) 1143 self.value = value 1144 self.is_torch = self.value is torch or self.value.__name__.startswith("torch.") 1145 1146 def python_type(self): 1147 return types.ModuleType 1148 1149 def as_python_constant(self): 1150 return self.value 1151 1152 def __repr__(self) -> str: 1153 return f"PythonModuleVariable({self.value})" 1154 1155 def call_hasattr(self, tx: "InstructionTranslator", name): 1156 result = hasattr(self.value, name) 1157 return variables.ConstantVariable.create(result) 1158 1159 def var_getattr(self, tx: "InstructionTranslator", name): 1160 if tx.output.side_effects.has_pending_mutation_of_attr(self, name): 1161 return tx.output.side_effects.load_attr(self, name) 1162 1163 from .builder import SourcelessBuilder, VariableBuilder 1164 1165 if self.is_torch or name not in self.value.__dict__: 1166 attr_value = getattr(self.value, name) 1167 else: 1168 attr_value = self.value.__dict__[name] 1169 1170 if self.source: 1171 new_source = AttrSource(self.source, name) 1172 return VariableBuilder(tx, new_source)(attr_value) 1173 else: 1174 return SourcelessBuilder.create(tx, attr_value) 1175 1176 1177class TypingVariable(VariableTracker): 1178 def __init__(self, value, **kwargs) -> None: 1179 super().__init__(**kwargs) 1180 self.value = value 1181 1182 def call_method( 1183 self, 1184 tx, 1185 name, 1186 args: "List[VariableTracker]", 1187 kwargs: "Dict[str, VariableTracker]", 1188 ) -> "VariableTracker": 1189 if name == "__getitem__" and len(args) == 1: 1190 return variables.ConstantVariable.create( 1191 self.value[args[0].as_python_constant()], 1192 ) 1193 unimplemented("typing") 1194 1195 def as_python_constant(self): 1196 return self.value 1197 1198 1199@functools.lru_cache(maxsize=1) 1200def get_np_to_tnp_map(): 1201 from ..utils import NP_TO_TNP_MODULE 1202 1203 np_fn_to_tnp_fn = {} 1204 1205 for np_mod, tnp_mod in NP_TO_TNP_MODULE.items(): 1206 for fn_name, tnp_fn in tnp_mod.__dict__.items(): 1207 if callable(tnp_fn): 1208 # some internal details do leak from tnp 1209 # which are not part of numpy API. 1210 if np_fn := getattr(np_mod, fn_name, None): 1211 np_fn_to_tnp_fn[np_fn] = tnp_fn 1212 1213 return np_fn_to_tnp_fn 1214 1215 1216class NumpyVariable(VariableTracker): 1217 """ 1218 Wrapper around `numpy.*`. Currently, is able to trace a small subset of numpy functions as well as numpy dtypes. 1219 """ 1220 1221 constant_fold_functions = (tnp.issubdtype,) 1222 1223 def __init__(self, value, **kwargs) -> None: 1224 super().__init__(**kwargs) 1225 self.value = value 1226 1227 @classmethod 1228 def can_constant_fold_through(cls, fn): 1229 mod = fn.__module__.split(".") 1230 assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"] 1231 return fn in cls.constant_fold_functions 1232 1233 @classmethod 1234 def get_constant_collection_for_func(cls, fn): 1235 mod = fn.__module__.split(".") 1236 assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"] 1237 return np_constant_collections_map.get(fn, None) 1238 1239 def call_function( 1240 self, 1241 tx: "InstructionTranslator", 1242 args: "List[VariableTracker]", 1243 kwargs: "Dict[str, VariableTracker]", 1244 ) -> "VariableTracker": 1245 if not config.trace_numpy: 1246 unimplemented(f"numpy.{self.value}()") 1247 1248 from ..utils import numpy_to_tensor_wrapper 1249 from .tensor import NumpyNdarrayVariable 1250 1251 func = get_np_to_tnp_map().get(self.value) 1252 if func is None: 1253 unimplemented( 1254 f"Can't find numpy function {self.value} in torch._numpy. " 1255 " Please file an issue to request support for this function." 1256 ) 1257 1258 # We are dealing with a function that produces a const collection type (np.dtype, np.iinfo/np.finfo) 1259 if ( 1260 collection_variable_typ := self.get_constant_collection_for_func(func) 1261 ) is not None: 1262 try: 1263 return collection_variable_typ( 1264 self.value( 1265 *[x.as_python_constant() for x in args], 1266 **{k: v.as_python_constant() for k, v in kwargs.items()}, 1267 ) 1268 ) 1269 except NotImplementedError: 1270 unimplemented( 1271 f"{self.value.__name__} with non-const args: {args} {kwargs}" 1272 ) 1273 else: 1274 if ( 1275 func.__module__ == "torch._numpy.random" 1276 and config.use_numpy_random_stream 1277 ): 1278 msg = f"delegate '{func.__qualname__}' to NumPy itself via " 1279 msg += f"confg.use_numpy_random_stream={config.use_numpy_random_stream}" 1280 unimplemented(msg) 1281 1282 args, kwargs = NumpyNdarrayVariable.patch_args(func.__name__, args, kwargs) 1283 1284 if self.can_constant_fold_through(func) and ( 1285 check_unspec_or_constant_args(args, kwargs) 1286 ): 1287 # constant fold 1288 return variables.ConstantVariable.create( 1289 self.as_python_constant()( 1290 *[x.as_python_constant() for x in args], 1291 **{k: v.as_python_constant() for k, v in kwargs.items()}, 1292 ), 1293 ) 1294 1295 # TODO Add all the functions that go from constants to constants to can_constant_fold_through 1296 proxy = tx.output.create_proxy( 1297 "call_function", 1298 numpy_to_tensor_wrapper(func), 1299 *proxy_args_kwargs(args, kwargs), 1300 ) 1301 return NumpyNdarrayVariable.create(tx, proxy) 1302 1303 def call_method( 1304 self, 1305 tx, 1306 name, 1307 args: "List[VariableTracker]", 1308 kwargs: "Dict[str, VariableTracker]", 1309 ) -> "VariableTracker": 1310 unimplemented("numpy") 1311 1312 def as_python_constant(self): 1313 return self.value 1314 1315 def as_proxy(self): 1316 if config.trace_numpy and isinstance(self.value, type): 1317 # This handles numpy dtype attributes such as np.float32 1318 # We return a string as we don't want to serialize non-PyTorch objects in the output FX graph 1319 # In torch/_numpy we normalize strings to their dtypes when the input is a dtype, as NumPy does 1320 return self.value.__name__ 1321 1322 return super().as_proxy() 1323 1324 1325# Used to keep track of NULLs pushed on the stack for Python 3.11 function calls 1326class NullVariable(VariableTracker): 1327 def __init__(self, **kwargs) -> None: 1328 super().__init__(**kwargs) 1329 1330 def __str__(self) -> str: 1331 return "NullVariable" 1332 1333 def reconstruct(self, codegen): 1334 if sys.version_info < (3, 11): 1335 unimplemented("cannot reconstruct NullVariable in < Python 3.11") 1336 codegen.append_output(create_instruction("PUSH_NULL")) 1337 1338 1339class DeletedVariable(VariableTracker): 1340 """Marker used to implement delattr()""" 1341 1342 1343class StringFormatVariable(VariableTracker): 1344 """ 1345 Represents a call to str.format(), we delay calling format until after the graph. 1346 """ 1347 1348 _nonvar_fields = {"format_string", *VariableTracker._nonvar_fields} 1349 1350 @classmethod 1351 def create(cls, format_string, sym_args, sym_kwargs): 1352 if all( 1353 x.is_python_constant() 1354 for x in itertools.chain(sym_args, sym_kwargs.values()) 1355 ): 1356 return variables.ConstantVariable.create( 1357 format_string.format( 1358 *[v.as_python_constant() for v in sym_args], 1359 **{k: v.as_python_constant() for k, v in sym_kwargs.items()}, 1360 ) 1361 ) 1362 return cls(format_string, list(sym_args), dict(sym_kwargs)) 1363 1364 def __init__(self, format_string, sym_args, sym_kwargs, **kwargs) -> None: 1365 super().__init__(**kwargs) 1366 assert isinstance(format_string, str) 1367 self.format_string = format_string 1368 self.sym_args = sym_args 1369 self.sym_kwargs = sym_kwargs 1370 1371 def __repr__(self) -> str: 1372 return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})" 1373 1374 def reconstruct(self, codegen): 1375 codegen.add_push_null( 1376 lambda: codegen.extend_output( 1377 [ 1378 codegen.create_load_const(self.format_string), 1379 codegen.create_load_attr("format"), 1380 ] 1381 ), 1382 call_function_ex=True, 1383 ) 1384 codegen(variables.TupleVariable(self.sym_args)) 1385 kwargs = { 1386 variables.ConstantVariable.create(k): v for k, v in self.sym_kwargs.items() 1387 } 1388 codegen(variables.ConstDictVariable(kwargs)) 1389 codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=1)) 1390 1391 1392class DebuggingVariable(VariableTracker): 1393 """ 1394 Represents a call to a debugging function like print(), or something 1395 registered to config.reorderable_logging_functions. 1396 """ 1397 1398 def __init__(self, value, **kwargs) -> None: 1399 super().__init__(**kwargs) 1400 self.value = value 1401 1402 @staticmethod 1403 def is_reorderable_logging_function(obj): 1404 return ( 1405 callable(obj) 1406 and isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)) 1407 and obj in torch._dynamo.config.reorderable_logging_functions 1408 ) 1409 1410 def call_function(self, tx: "InstructionTranslator", args, kwargs): 1411 if tx.export: 1412 # For export cases, we can just make debugging functions no-ops 1413 return 1414 1415 if not self.can_reorder_logs(self.value, args, kwargs): 1416 unimplemented( 1417 f"Reordering debugging function {self.value} " 1418 f"with inputs {args} {kwargs} is not yet implemented." 1419 ) 1420 1421 tx.debug_locals.append((self, list(args))) 1422 1423 def reconstruct(self, codegen): 1424 return self.source.reconstruct(codegen) 1425 1426 @staticmethod 1427 def can_reorder_logs(fn, args, kwargs) -> True: 1428 """ 1429 Run some additional checks for what sort of function calls can we 1430 actually reorder. 1431 """ 1432 1433 allowed_input_types = ( 1434 variables.TensorVariable, 1435 variables.ConstantVariable, 1436 StringFormatVariable, 1437 ) 1438 1439 flat_args = pytree.tree_leaves([args, kwargs]) 1440 for arg in flat_args: 1441 if not isinstance(arg, allowed_input_types): 1442 return False 1443 1444 return True 1445 1446 1447class LoggingLoggerVariable(VariableTracker): 1448 """ 1449 Represents a call to any of logging.Logger methods 1450 """ 1451 1452 def __init__(self, value, **kwargs) -> None: 1453 super().__init__(**kwargs) 1454 1455 def call_method( 1456 self, 1457 tx, 1458 name, 1459 args: "List[VariableTracker]", 1460 kwargs: "Dict[str, VariableTracker]", 1461 ) -> "VariableTracker": 1462 if tx.export: 1463 # For export cases, we can just make debugging functions no-ops 1464 return 1465 unimplemented("Logger not supported for non-export cases") 1466 1467 1468class ConstantLikeVariable(VariableTracker): 1469 """self.value is a compile-time constant, but not a literal""" 1470 1471 _error_prefix = "ConstantLikeVariable" 1472 try: 1473 from numpy import ( 1474 dtype as np_dtype, 1475 floating as np_floating, 1476 generic as np_generic, 1477 ) 1478 except ImportError: 1479 np_floating = type("invalid_type", (), {}) 1480 np_dtype = type("invalid_type", (), {}) 1481 1482 def __init__(self, value, **kwargs) -> None: 1483 super().__init__(**kwargs) 1484 self.value = value 1485 1486 def as_python_constant(self): 1487 return self.value 1488 1489 def call_method( 1490 self, 1491 tx, 1492 name, 1493 args: List[VariableTracker], 1494 kwargs: Dict[str, VariableTracker], 1495 ) -> VariableTracker: 1496 try: 1497 # we only support constant propagation for methods 1498 cargs = [x.as_python_constant() for x in args] 1499 ckwargs = {k: v.as_python_constant() for k, v in kwargs.items()} 1500 except NotImplementedError: 1501 unimplemented(f"{self._error_prefix}.{name}(*{args}, **{kwargs})") 1502 1503 result = getattr(self.value, name)(*cargs, **ckwargs) 1504 1505 if variables.ConstantVariable.is_literal(result): 1506 return variables.ConstantVariable.create(result) 1507 if isinstance(result, re.Match): 1508 return ConstantRegexMatchVariable(result) 1509 1510 unimplemented(f"{self._error_prefix}.{name}() -> {result}") 1511 1512 def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: 1513 result = getattr(self.value, name) 1514 if isinstance(result, self.np_floating): 1515 result = float(result) 1516 if isinstance(result, self.np_dtype): 1517 return NumpyDTypeVariable(result) 1518 if isinstance(result, type) and issubclass(result, self.np_generic): 1519 # things like x.dtype.type 1520 return NumpyVariable(result) 1521 if variables.ConstantVariable.is_literal(result): 1522 return variables.ConstantVariable.create(result) 1523 return GetAttrVariable(self, name) 1524 1525 1526class RegexPatternVariable(ConstantLikeVariable): 1527 _error_prefix = "re.Pattern" 1528 1529 1530class ConstantRegexMatchVariable(ConstantLikeVariable): 1531 _error_prefix = "re.Match" 1532 1533 1534class TorchVersionVariable(ConstantLikeVariable): 1535 _error_prefix = "torch.__version__" 1536 1537 def __init__(self, **kwargs) -> None: 1538 kwargs.setdefault("value", torch.__version__) 1539 assert kwargs["value"] is torch.__version__ 1540 super().__init__(**kwargs) 1541 1542 1543class NumpyTypeInfoVariable(ConstantLikeVariable): 1544 _error_prefix = "np.iinfo/np.finfo" 1545 1546 1547class NumpyDTypeVariable(ConstantLikeVariable): 1548 _error_prefix = "np.dtype[...]" 1549 1550 def as_proxy(self): 1551 """Similar to how numpy dtype descriptors (e.g. np.float32 ) are handled by NumpyVariable: 1552 1553 np.dtype() objects are serialized as strings, torch._numpy wrappers will normalize to the torch dtype. 1554 This also handles unsupported things nicely (i.e. structured arrays and object arrays). 1555 """ 1556 return self.value.type.__name__ 1557 1558 1559np_constant_collections_map = { 1560 tnp.finfo: NumpyTypeInfoVariable, 1561 tnp.iinfo: NumpyTypeInfoVariable, 1562 tnp.dtype: NumpyDTypeVariable, 1563} 1564 1565 1566class RandomClassVariable(VariableTracker): 1567 """random.Random""" 1568 1569 def __init__(self, **kwargs) -> None: 1570 super().__init__(**kwargs) 1571 1572 def call_function(self, tx: "InstructionTranslator", args, kwargs): 1573 if len(args) > 1: 1574 unimplemented("random.Random() with > 1 arg") 1575 elif kwargs: 1576 unimplemented("random.Random() with kwargs") 1577 seed = variables.ConstantVariable.create(None) if len(args) == 0 else args[0] 1578 return RandomVariable(seed=seed, mutable_local=variables.base.MutableLocal()) 1579 1580 1581class RandomVariable(VariableTracker): 1582 """random.Random() 1583 1584 Implemented by wrapping a VariableTracker around a random.Random object. 1585 The supported methods for the random.Random object cannot be overriden. 1586 Assumes that random objects behave the same given a set seed or state. 1587 """ 1588 1589 _nonvar_fields = { 1590 "random", 1591 *VariableTracker._nonvar_fields, 1592 } 1593 1594 _supported_fn_names = { 1595 "random", 1596 "randint", 1597 "randrange", 1598 "uniform", 1599 } 1600 1601 def __init__( 1602 self, 1603 rand: Optional[random.Random] = None, 1604 seed: Optional[VariableTracker] = None, 1605 **kwargs, 1606 ) -> None: 1607 super().__init__(**kwargs) 1608 if rand is not None: 1609 assert self.is_supported_random_obj(rand) 1610 self.random = random.Random() 1611 self.random.setstate(rand.getstate()) 1612 else: 1613 seed = seed.as_python_constant() if seed is not None else None 1614 self.random = random.Random(seed) 1615 1616 def python_type(self): 1617 return random.Random 1618 1619 def as_python_constant(self): 1620 return self.random 1621 1622 @staticmethod 1623 def is_supported_random_obj(val): 1624 if type(val) is not random.Random: 1625 return False 1626 for name in itertools.chain( 1627 RandomVariable._supported_fn_names, ("seed", "getstate", "setstate") 1628 ): 1629 if not hasattr(val, name): 1630 return False 1631 meth = getattr(val, name) 1632 if inspect.isbuiltin(meth): 1633 # e.g. random.Random.random 1634 if meth != getattr(random.Random, name).__get__(val): 1635 return False 1636 else: 1637 if getattr(meth, "__func__", None) is not getattr(random.Random, name): 1638 return False 1639 return True 1640 1641 @staticmethod 1642 def check_state(state): 1643 assert type(state) is tuple 1644 assert type(state[0]) is int 1645 assert type(state[1]) is tuple 1646 assert all(type(x) is int for x in state[1]) 1647 assert state[2] is None or type(state[2]) is float 1648 1649 @staticmethod 1650 def wrap_state(state): 1651 RandomVariable.check_state(state) 1652 return variables.TupleVariable( 1653 [ 1654 variables.ConstantVariable.create(state[0]), 1655 variables.TupleVariable( 1656 [variables.ConstantVariable.create(x) for x in state[1]] 1657 ), 1658 variables.ConstantVariable.create(state[2]), 1659 ] 1660 ) 1661 1662 @staticmethod 1663 def unwrap_state(state): 1664 state_obj = state.as_python_constant() 1665 RandomVariable.check_state(state_obj) 1666 return state_obj 1667 1668 def call_method( 1669 self, 1670 tx, 1671 name, 1672 args: List[VariableTracker], 1673 kwargs: Dict[str, VariableTracker], 1674 ) -> VariableTracker: 1675 if name == "seed": 1676 tx.output.side_effects.mutation(self) 1677 self.random.seed( 1678 *[x.as_python_constant() for x in args], 1679 **{key: val.as_python_constant() for key, val in kwargs.items()}, 1680 ) 1681 return variables.ConstantVariable.create(None) 1682 elif name == "getstate": 1683 return self.wrap_state(self.random.getstate()) 1684 elif name == "setstate": 1685 tx.output.side_effects.mutation(self) 1686 self.random.setstate(self.unwrap_state(args[0])) 1687 return variables.ConstantVariable.create(None) 1688 elif name in self._supported_fn_names: 1689 tx.output.side_effects.mutation(self) 1690 state = self.random.getstate() 1691 1692 def call_random_meth(*args, **kwargs): 1693 r = random.Random() 1694 r.setstate(state) 1695 return getattr(r, name)(*args, **kwargs) 1696 1697 # self.random state not actually updated by call_random_meth, so update here 1698 # by calling the method 1699 getattr(self.random, name)( 1700 *[x.as_python_constant() for x in args], 1701 **{k: v.as_python_constant() for k, v in kwargs.items()}, 1702 ) 1703 1704 return call_random_fn(tx, call_random_meth, args, kwargs) 1705 return super().call_method(tx, name, args, kwargs) 1706 1707 def reconstruct(self, codegen): 1708 codegen.add_push_null( 1709 lambda: codegen.extend_output( 1710 [ 1711 codegen.create_load_python_module(random), 1712 codegen.create_load_attr("Random"), 1713 ] 1714 ) 1715 ) 1716 codegen.call_function(0, False) 1717 # NOTE using add_push_null may result in NULL being duplicated 1718 # so defer the push_null to call_function 1719 codegen.dup_top() 1720 codegen.load_attr("setstate") 1721 codegen(self.wrap_state(self.random.getstate())) 1722 codegen.call_function(1, True) 1723 codegen.pop_top() 1724