# mypy: ignore-errors

import functools
import inspect
import logging
import operator
import textwrap
import traceback
import types
import unittest
from typing import Dict, List, TYPE_CHECKING

import sympy

import torch._numpy as tnp
import torch.fx
import torch.random
from torch._dynamo import compiled_autograd
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx.experimental.symbolic_shapes import (
    guard_scalar,
    GuardOnDataDependentSymNode,
    has_free_symbols,
    is_symbolic,
    SymTypes,
)
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

from .. import config, variables
from .._trace_wrapped_higher_order_op import trace_wrapped
from ..exc import unimplemented, UserError, UserErrorType
from ..external_utils import call_hook_from_backward_state
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource
from ..utils import (
    fqn,
    get_custom_getattr,
    get_fake_value,
    get_real_value,
    guard_if_dyn,
    object_has_getattribute,
    product,
    proxy_args_kwargs,
    set_example_value,
    tensortype_to_dtype,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .lists import SizeVariable


try:
    import numpy as np
except ModuleNotFoundError:
    np = None


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


log = logging.getLogger(__name__)

# Ops that allow tensor <op> tensor
supported_tensor_comparison_ops = {
    ">": operator.gt,
    "<": operator.lt,
    ">=": operator.ge,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
}
# Ops that allow tensor <op> None
supported_const_comparison_ops = {
    "is": operator.is_,
    "is not": operator.is_not,
    "==": operator.eq,
    "!=": operator.ne,
}
supported_comparison_ops = {
    **supported_tensor_comparison_ops,
    **supported_const_comparison_ops,
}
supported_tensor_comparison_op_values = dict.fromkeys(
    supported_tensor_comparison_ops.values()
)
supported_const_comparison_op_values = dict.fromkeys(
    supported_const_comparison_ops.values()
)


class TensorVariable(VariableTracker):
    """A torch.Tensor input or an intermediate value in the FX graph"""

    _nonvar_fields = {
        "proxy",
        "dtype",
        "device",
        "layout",
        "ndim",
        "size",
        "stride",
        "requires_grad",
        "is_quantized",
        "is_contiguous",
        "is_sparse",
        "class_type",
        "specialized_value",
        "_is_name_set",
        *VariableTracker._nonvar_fields,
    }

    def get_real_value(self):
        """
        Get the actual value represented by this variable if computation is run
        using the user-provided inputs.
        NOTE: this runs actual tensor computation and may be
        slow and memory-intensive.
        """
        return get_real_value(self.proxy.node, self.proxy.tracer)

    def __init__(
        self,
        proxy: torch.fx.Proxy,
        *,
        dtype,
        device,
        layout,
        ndim,
        requires_grad,
        is_quantized,
        is_sparse,
        class_type,
        has_grad_fn,
        size=None,
        stride=None,
        is_contiguous=None,
        _is_name_set=None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.proxy = proxy
        self.dtype = dtype
        self.device = device
        self.layout = layout
        self.ndim = ndim
        self.size = size
        self.stride = stride
        self.requires_grad = requires_grad
        self.is_quantized = is_quantized
        self.is_contiguous = is_contiguous
        self.is_sparse = is_sparse
        self.class_type = class_type
        self.has_grad_fn = has_grad_fn
        if _is_name_set is None:
            # no need to rename inputs
            _is_name_set = self.proxy.node.op == "placeholder"
        self._is_name_set: bool = _is_name_set

    def debug_repr(self):
        # TODO: strip off fake tensor from repr here
        return repr(self.proxy.node.meta["example_value"])

    def as_proxy(self):
        return self.proxy

    def python_type(self):
        return self.class_type

    @staticmethod
    def specialize(value: torch.Tensor):
        props = {
            "dtype": value.dtype,
            "device": value.device,
            "layout": value.layout,
            "ndim": int(value.ndim),
            "requires_grad": value.requires_grad,
            "is_quantized": value.is_quantized,
            "is_sparse": value.is_sparse,
            "class_type": type(value),
        }
        try:
            props["has_grad_fn"] = value.grad_fn is not None
        except Exception:
            # Workaround for issues with create_parameter_op in Dynamo. Reading
            # grad_fn should never cause an issue.
            props["has_grad_fn"] = False

        if is_sparse_any(value) and not has_free_symbols(value):
            props["size"] = tuple(
                [int(s) if is_symbolic(s) else s for s in value.size()]
            )
        elif not has_free_symbols(value):
            # this is a fully static shape, and the keys on props here inform specialization.
            # We have to cast to int here, because these might get accessed as ConstantVariable, which has
            # a strict no-symint policy. If we got here due to not having free symbols, this is a known constant
            # already. We could remove the discrepancy here, by having ConstantVariable be more permissive for
            # constant backed SymInts, but that assert being strict has led to some good signal in hunting bugs, and
            # I'd like to keep it around for now.
            props["size"] = tuple(
                # the non is_symbolic case applies to the jagged layout
                # NestedTensor case as singleton ints are not symbolic
                [int(s) if is_symbolic(s) else s for s in value.size()]
            )
            props["stride"] = tuple(value.stride())
            if torch._C._functorch.is_batchedtensor(value):
                # Batched tensors do not support contiguity patterns, so
                # we refrain from computing the `is_contiguous` property
                props["is_contiguous"] = None
            else:
                props["is_contiguous"] = tuple(
                    [
                        x
                        for x in torch._prims_common._memory_formats
                        if value.is_contiguous(memory_format=x)
                    ]
                )
        return props

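    # dynamic_getattr is the last-resort attribute lookup used by var_getattr below.
    # For sourceless tensors that are traceable wrapper subclasses, it reads inner
    # tensors (re-proxied into the graph) or constant metadata straight off the fake
    # value. Otherwise it evaluates this variable's source against the frame's
    # local/global scope, installs a HASATTR guard, and builds a VariableTracker from
    # the real attribute value. Raising NotImplementedError tells callers further up
    # (e.g. BuiltinVariable.call_getattr) to fall back to a GetAttrVariable.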
    def dynamic_getattr(self, tx: "InstructionTranslator", name):
        fake_val = self.proxy.node.meta["example_value"]
        # For getattrs on tensors without sources,
        # we can do better than the default (creating a GetAttrVariable)
        # if:
        # (1) the tensor is a traceable tensor subclass
        # (2) We are getattr'ing an inner tensor from that subclass
        if not self.source and is_traceable_wrapper_subclass(fake_val):
            fake_val = self.proxy.node.meta["example_value"]
            attrs, ctx = fake_val.__tensor_flatten__()
            proxy = getattr(self.as_proxy(), name)
            example_value = getattr(fake_val, name)
            if name in attrs:
                # attrs returned from tensor_flatten are always tensors
                assert isinstance(example_value, torch.Tensor)
                from .builder import wrap_fx_proxy

                return wrap_fx_proxy(tx=tx, proxy=proxy, example_value=example_value)
            # any other attributes on the subclass (that are not methods)
            # are assumed to be constant metadata.
            elif not callable(example_value):
                from .builder import SourcelessBuilder

                return SourcelessBuilder.create(tx, example_value)

        if not (self.source and self.source.subguards_allowed()):
            raise NotImplementedError

        # For local source, we associate the real value. We use this real value
        # for implementing getattr fallthrough on the variable tracker base class.

        # Note - this scope construction is mirrored in guards
        # A subsequent PR will introduce a util.
        scope = {"L": tx.output.local_scope, "G": tx.output.global_scope}
        try:
            # We raise in case we hit a TypeError bug w/ SuperSource.
            # SuperSource has bugs in it atm, and can produce code like
            # eval("super(L['mod'].model.model.encoder.embed_positions.forward__class__,
            # L['mod'].model.model.encoder.embed_positions)", scope)
            # which is incorrect and violates the invariant that all sources should be eval()-able against the scope.
            _input_associated_real_value = eval(self.source.name(), scope)
        except Exception as exc:
            raise NotImplementedError from exc

        if _input_associated_real_value is None:
            raise NotImplementedError

        if object_has_getattribute(_input_associated_real_value):
            raise NotImplementedError

        if get_custom_getattr(_input_associated_real_value):
            raise NotImplementedError

        real_value = getattr(_input_associated_real_value, name)
        if callable(real_value):
            # Callables have more nuanced handling, and we should let the existing system delegate here.
            # Raising was past behavior and so should always be sound to fall back.
            # Note - at a certain point we may want to handle
            raise NotImplementedError

        from ..guards import GuardBuilder
        from .builder import VariableBuilder

        attr_source = AttrSource(self.source, name)
        install_guard(attr_source.make_guard(GuardBuilder.HASATTR))
        return VariableBuilder(tx, attr_source)(real_value)

    def method_attr_ndim(self, tx):
        if self.ndim is not None:
            return ConstantVariable.create(self.ndim)
        else:
            return self.call_method(tx, "dim", [], {})

    def method_attr_dtype(self, tx):
        if self.dtype is not None:
            return ConstantVariable.create(self.dtype)

    def method_attr_device(self, tx):
        if self.device is not None:
            return ConstantVariable.create(self.device)

    def method_attr_layout(self, tx):
        if self.layout is not None:
            return ConstantVariable.create(self.layout)

    def method_attr_is_cuda(self, tx):
        if self.device is not None:
            return ConstantVariable.create(self.device.type == "cuda")

    def method_attr_shape(self, tx):
        if self.size is not None:
            sizes = [variables.ConstantVariable.create(x) for x in self.size]
            return SizeVariable(sizes)
        else:
            return self.call_method(tx, "size", [], {})

    def method_attr_requires_grad(self, tx):
        if self.requires_grad is not None:
            return ConstantVariable.create(self.requires_grad)

    def method_attr_is_quantized(self, tx):
        if self.is_quantized is not None:
            return ConstantVariable.create(self.is_quantized)

    def method_attr_is_sparse(self, tx):
        if self.is_sparse is not None:
            return ConstantVariable.create(self.is_sparse)

    def method_attr_data(self, tx):
        return variables.TorchInGraphFunctionVariable(
            torch._C._autograd._get_data_attr
        ).call_function(tx, [self], {})

    def method_attr_grad_fn(self, tx):
        if self.has_grad_fn:
            unimplemented("TensorVariable has a grad_fn")
        else:
            return variables.ConstantVariable(None)

    def method_attr__version(self, tx):
        from ..tensor_version_op import _tensor_version

        return variables.TorchInGraphFunctionVariable(_tensor_version).call_function(
            tx, [self], {}
        )

    def call_hasattr(self, tx: "InstructionTranslator", name):
        from . import GetAttrVariable
        from .builtin import BuiltinVariable

        try:
            var = BuiltinVariable(getattr).call_function(
                tx, [self, ConstantVariable(name)], {}
            )
            # in the event that TensorVariable returns NotImplemented
            # BuiltinVariable.call_getattr returns GetAttrVariable
            ret_val = not isinstance(var, GetAttrVariable)
        except AttributeError:
            ret_val = False

        if self.source:
            install_guard(
                AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
            )

        return ConstantVariable(ret_val)

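    # var_getattr resolves a tensor attribute in this order:
    #   1. strict-mode ban check, then the special-cased __class__
    #   2. a method_attr_<name> handler, if one is defined above
    #   3. a delayed graph break for inplace-view attributes on graph inputs
    #   4. generic handling of C-level getset_descriptor attributes (e.g. tensor.real)
    #      through a getattr proxy in the graph
    #   5. dynamic_getattr as the final fallback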
    def var_getattr(self, tx: "InstructionTranslator", name):
        from . import UserDefinedClassVariable

        if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops():
            unimplemented(f"Illegal getattr invocation {name} in strict mode")

        if name == "__class__":
            return UserDefinedClassVariable(self.python_type())

        handler = getattr(self, f"method_attr_{name}", None)
        result = handler(tx) if handler is not None else None

        # Add a guard for type matching; these guards are checked before tensor guards.
        # In some cases, a <tensor>.<attr> guard can be evaluated first, and break if
        # <tensor> is later changed to another type
        if (
            result is not None
            and self.source
            and self.source.subguards_allowed()
            and not (
                name not in ("grad", "requires_grad") and result.is_python_constant()
            )
        ):
            install_guard(self.make_guard(GuardBuilder.TYPE_MATCH))
            result.source = AttrSource(self.source, name)

        # It's hard to get inplace view (metadata mutation) on a graph input to work
        # properly across dynamo/aot/inductor, so just fall back.
        if self.source is not None and hasattr(torch.ops.aten, name):
            fn = getattr(torch.ops.aten, name)
            if (
                hasattr(fn, "overloads")
                and hasattr(fn, fn.overloads()[0])
                and torch.Tag.inplace_view in getattr(fn, fn.overloads()[0]).tags
            ):
                # Delay the graph break to the actual call of unsqueeze_/resize_/resize_as_ etc.
                return variables.misc.DelayGraphBreakVariable(
                    source=AttrSource(self.source, name)
                )

        # For attributes (not methods) that were not caught in the special handling
        # above (e.g. tensor.real), we handle these generically, assuming that the
        # output type is a tensor.
        if result is None and name != "grad":

            def try_generic_attr_handling():
                from .builder import wrap_fx_proxy
                from .misc import GetAttrVariable

                try:
                    static_attr = inspect.getattr_static(torch.Tensor, name)
                except AttributeError:
                    return None

                # Make sure this is an attribute, not a method.
                # type(torch.Tensor.H) should be "getset_descriptor"
                # This is because of the CPython implementation, see THPVariableType:
                # these attributes are implemented under tp_getset, which appear
                # as `getset_descriptor`s (compared to, say, methods, which appear
                # as `method_descriptor`s)
                if type(static_attr) != types.GetSetDescriptorType:
                    return None

                proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name)
                if self.source is not None:
                    return wrap_fx_proxy(
                        tx=tx, proxy=proxy, source=AttrSource(self.source, name)
                    )
                else:
                    return wrap_fx_proxy(tx=tx, proxy=proxy)

            result = try_generic_attr_handling()

        if result is None:
            result = self.dynamic_getattr(tx, name)

        if result is None:
            raise NotImplementedError
        return result

    def call_id(self, tx):
        if not self.source:
            unimplemented("call_id not supported for sourceless TensorVariable")

        # For local source, we associate the real value. We use this real value
        # to compute id() below.
        scope = {"L": tx.output.local_scope, "G": tx.output.global_scope}
        try:
            _input_associated_real_value = eval(self.source.name(), scope)
        except Exception as exc:
            unimplemented(f"error getting associated real value: {exc}")

        if _input_associated_real_value is None:
            unimplemented("call_id without associated real value")

        install_guard(self.source.make_guard(GuardBuilder.ID_MATCH))
        id_value = id(_input_associated_real_value)
        return ConstantVariable.create(id_value)

    def has_unpack_var_sequence(self, tx):
        return self.ndim > 0

    def unpack_var_sequence(self, tx: "InstructionTranslator", idxes=None):
        from .builder import wrap_fx_proxy_cls

        if self.size:
            size_len = len(self.size)
        else:
            size_var = self.call_method(tx, "size", [], {})
            assert isinstance(size_var, SizeVariable)
            size_len = len(size_var.items)
        # Ensure we don't unpack a scalar tensor.
        assert size_len != 0, "Can't unpack scalar tensors."

        if self.size:
            length = self.size[0]
        else:
            dyn_length = self.call_method(tx, "size", [ConstantVariable.create(0)], {})
            # SymNodeVariable for symbolic sizes, ConstantVariable for constants OR values produced through
            # symbolic_shapes, but that end up as int/sympy.Integer
            assert isinstance(dyn_length, (SymNodeVariable, ConstantVariable))
            if isinstance(dyn_length, SymNodeVariable):
                length = dyn_length.evaluate_expr(tx.output)
            else:
                length = dyn_length.value

        if idxes is None:
            idxes = range(length)
        else:
            assert (
                len(idxes) == length
            ), f"Can't unpack a tensor of {length} rows into a tuple of {len(idxes)} elements."
        return [
            wrap_fx_proxy_cls(target_cls=type(self), tx=tx, proxy=self.as_proxy()[i])
            for i in idxes
        ]

    def _strict_mode_banned_ops(self):
        return torch._dynamo.config._autograd_backward_strict_mode_banned_ops

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops():
            unimplemented(f"Illegal method invocation {name} in strict mode")

        """
        Dispatch to a method-specific handler defined below. If the
        handler returns None (or doesn't exist) we put the method call
        in the graph.
        """
        try:
            handler_method = getattr(self, f"method_{name}")
        except AttributeError:
            pass
        else:
            try:
                result = handler_method(*args, **kwargs)
                if result:
                    return result
            except TypeError as e:
                unimplemented(f"unhandled args for {name}: {e}")

        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_method",
                name,
                *proxy_args_kwargs([self, *args], kwargs),
            ),
        )

    def method_size(self, *args, **kwargs):
        return self._method_size_stride("size", *args, **kwargs)

    def method_stride(self, *args, **kwargs):
        return self._method_size_stride("stride", *args, **kwargs)

    def _method_size_stride(self, name, dim=None):
        dim = guard_if_dyn(dim)

        def make_const_size_variable(x, **options):
            return SizeVariable(
                [ConstantVariable.create(y, **options) for y in x], **options
            )

        RetVariable = (
            make_const_size_variable if name == "size" else ConstantVariable.create
        )

        # Technically, this should not be necessary, but I'm including it
        # for enhanced BC, in case example_value is sometimes not set
        # (it really should always be set though!)
        if (r := getattr(self, name)) is not None:
            if dim is None:
                return RetVariable(r)
            else:
                return ConstantVariable.create(r[dim])

        # It might still be constant! Consult the fake tensor and see
        if (fake := self.proxy.node.meta.get("example_value")) is not None:
            if dim is None:
                fake_r = getattr(fake, name)()
                if not has_free_symbols(fake_r):
                    # int conversion for safety, in case a SymInt refined
                    # to constant
                    return RetVariable(tuple(int(r) for r in fake_r))
            else:
                fake_r = getattr(fake, name)(dim)
                if not has_free_symbols(fake_r):
                    return ConstantVariable.create(int(fake_r))

    def method_numel(self):
        if self.size is not None:
            return ConstantVariable.create(product(self.size))

        # It might still be constant! Consult the fake tensor and see
        if (fake := self.proxy.node.meta.get("example_value")) is not None:
            fake_r = fake.numel()
            if not has_free_symbols(fake_r):
                return ConstantVariable.create(int(fake_r))

    method_nelement = method_numel

    def method_dim(self):
        if self.ndim is not None:
            return ConstantVariable.create(self.ndim)

    method_ndimension = method_dim

    def method_is_floating_point(self):
        if self.dtype is not None:
            return ConstantVariable.create(self.dtype.is_floating_point)

    def method_is_complex(self):
        if self.dtype is not None:
            return ConstantVariable.create(self.dtype.is_complex)

    def method_is_contiguous(self, memory_format=None):
        memory_format = (
            memory_format.as_python_constant()
            if memory_format is not None
            else torch.contiguous_format
        )
        if self.is_contiguous is not None:
            return ConstantVariable.create(memory_format in self.is_contiguous)
        elif (fake := self.proxy.node.meta.get("example_value")) is not None:
            return ConstantVariable.create(
                fake.is_contiguous(memory_format=memory_format)
            )

    def method_type(self, dtype=None, non_blocking=False, **kwargs):
        if (
            dtype is None
            and self.dtype is not None
            and isinstance(self.device, torch.device)
        ):
            tensortype = next(
                k for k, v in tensortype_to_dtype.items() if self.dtype in v
            )
            if self.device.type == "cuda":
                return ConstantVariable.create(f"torch.cuda.{tensortype.__name__}")
            else:
                return ConstantVariable.create(f"torch.{tensortype.__name__}")
        elif (
            dtype is not None
            and fqn(type(dtype.as_python_constant())) == "torch.tensortype"
        ):
            # torch.FloatTensor, etc. are all of type "torch.tensortype".
            # torch.fx's tracer fails on these types, because it doesn't support arguments of torch.tensortype type.
            # So, we pass it in as a string (which is also supported, see above implementation for .type() with 0 args)
            tensor_type = dtype.as_python_constant()
            tensor_type_const = ConstantVariable.create(fqn(tensor_type))

            from ..symbolic_convert import InstructionTranslator
            from .builder import wrap_fx_proxy

            tx = InstructionTranslator.current_tx()

            if non_blocking:
                kwargs = {"non_blocking": non_blocking, **kwargs}

            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    "type",
                    *proxy_args_kwargs([self, tensor_type_const], kwargs),
                ),
            )

    def method_as_subclass(self, cls):
        if isinstance(cls, TensorSubclassVariable) and cls.source:
            from ..symbolic_convert import InstructionTranslator
            from .builder import VariableBuilder
            from .torch_function import TensorWithTFOverrideVariable

            tx = InstructionTranslator.current_tx()

            # [Note: __torch_function__] coerce this tensor variable into a TensorWithTFOverrideVariable.
            # In eager, this is just a type change. This isn't sound if a __torch_function__ tensor subclass
            # defines a constructor, but if only a __torch_function__ impl is defined, this is okay to call.
            # It is up to the user whether this is correct behavior or not.
            py_cls = cls.as_python_constant()
            torch_fn = VariableBuilder(
                tx,
                AttrSource(AttrSource(cls.source, "__torch_function__"), "__func__"),
            )(py_cls.__torch_function__.__func__)

            return TensorWithTFOverrideVariable.from_tensor_var(
                tx, self, py_cls, torch_fn
            )

    def method_get_device(self):
        if isinstance(self.device, torch.device):
            index = self.device.index if self.device.type != "cpu" else -1
            return ConstantVariable.create(index)

    def method_element_size(self):
        return ConstantVariable.create(self.dtype.itemsize)

    def method_numpy(self, *, force=False):
        if not config.trace_numpy:
            unimplemented("Tensor.numpy(). config.trace_numpy is False")
        if not np:
            unimplemented("Tensor.numpy(). NumPy is not available")
        if self.layout != torch.strided:
            raise TypeError(
                f"can't convert {self.layout} layout tensor to numpy. Use Tensor.dense() first"
            )
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()

        # We don't check that the tensor is on CPU when force is False, as this
        # allows us to execute NumPy code on CUDA. Same for requires_grad=True
        if force and force.as_python_constant():
            # If the user set force=True we try to preserve the semantics (no gradients, move to CPU...)
            t = self.call_method(tx, "detach", [], {})
            proxy = tx.output.create_proxy("call_method", "cpu", (t.as_proxy(),), {})
        else:
            # Hacky way to create a view of self that will be marked as NumpyNdarrayVariable
            proxy = tx.output.create_proxy(
                "call_method", "view_as", *proxy_args_kwargs([self, self], {})
            )
        return NumpyNdarrayVariable.create(tx, proxy)

    def method_tolist(self):
        from ..symbolic_convert import InstructionTranslator
        from .builder import SourcelessBuilder

        tx = InstructionTranslator.current_tx()

        def tolist(tensor, sub_proxy):
            def wrap(i, sub_proxy):
                # Sigh, we forgot to gate this, so this data-dependent behavior is on
                # by default and is load-bearing in CI
                with unittest.mock.patch.object(
                    tx.fake_mode, "allow_scalar_outputs", True
                ):
                    return SymNodeVariable.create(
                        tx,
                        sub_proxy.item(),
                    )

            if tensor.dtype not in [
                torch.int8,
                torch.int16,
                torch.int32,
                torch.int64,
            ]:
                unimplemented("Input tensor for tolist must be an integer tensor")

            if tensor.dim() == 0:
                return wrap(tensor, sub_proxy)

            if tensor.dim() == 1:
                return [wrap(val, sub_proxy[i]) for i, val in enumerate(tensor)]

            return [
                tolist(sub_tensor, sub_proxy=sub_proxy[i])
                for i, sub_tensor in enumerate(tensor)
            ]

        tensor = self.as_proxy().node.meta["example_value"]
        out = tolist(tensor, self.as_proxy())
        return SourcelessBuilder.create(tx, out)

    def method_backward(self, *args, **kwargs):
        unimplemented("Tensor.backward")

    def method_data_ptr(self, *args, **kwargs):
        unimplemented("Tensor.data_ptr")

    def method_item(self, *args, **kwargs):
        if not config.capture_scalar_outputs:
            self._warn_capture_scalar_outputs()
        unimplemented("Tensor.item")

    @staticmethod
    @functools.lru_cache(None)
    def _warn_capture_scalar_outputs():
        user_stack = torch._guards.TracingContext.extract_stack()
        user_stack_formatted = "".join(traceback.format_list(user_stack))
        log.warning(
            textwrap.dedent(
                """\
                Graph break from `Tensor.item()`, consider setting:
                    torch._dynamo.config.capture_scalar_outputs = True
                or:
                    env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
                to include these operations in the captured graph.

                Graph break: from user code at:
                %s
                """
            ),
            user_stack_formatted,
        )

    def method___len__(self):
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()
        return self.call_method(tx, "size", [ConstantVariable.create(0)], {})

    def method_addcmul_(self, tensor1, tensor2, *, value=None):
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()
        if value is not None:
            from .. import polyfills
            from .builder import SourcelessBuilder

            return tx.inline_user_function_return(
                SourcelessBuilder.create(tx, polyfills.addcmul_inplace),
                [self, tensor1, tensor2, value],
                {},
            )

    def method___setitem__(self, key, value):
        def has_bool_key(v):
            if isinstance(v, TensorVariable):
                return v.dtype in (torch.bool, torch.int8)
            elif isinstance(v, variables.TupleVariable):
                return any(has_bool_key(item) for item in v.items)
            else:
                return False

        if (
            has_bool_key(key)
            and isinstance(value, TensorVariable)
            and value.requires_grad
            and torch.is_grad_enabled()
        ):
            unimplemented(
                "boolean masking setitem backwards, see https://github.com/pytorch/pytorch/issues/114123"
            )
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()
        tx.output.create_proxy(
            "call_function",
            operator.setitem,
            *proxy_args_kwargs([self, key, value], {}),
        )
        return ConstantVariable.create(None)

    def method_resize_(self, *args, **kwargs):
        unimplemented("Tensor.resize_")

    def method_resize_as_(self, *args, **kwargs):
        unimplemented("Tensor.resize_as_")

    def method_sparse_resize_(self, *args, **kwargs):
        unimplemented("Tensor.sparse_resize_")

    def method_sparse_resize_and_clear_(self, *args, **kwargs):
        unimplemented("Tensor.sparse_resize_and_clear_")

    def method_set_(self, *args, **kwargs):
        if len(args) > 1:
            # torch.Tensor.set_() has several overloads.
            # aten::set_.source_Tensor(Tensor) gets special handling
            # in AOTAutograd and functionalization, because it is the most common
            # overload and is used by FSDP.
            # graph-breaking on aten::set_source_Tensor_storage_offset for now,
            # unless we find that we need to make it work.
            unimplemented("Tensor.set_.source_Tensor_storage_offset")

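    # In-place arithmetic with extra scalars is decomposed into ops Dynamo already
    # traces: add_(other, alpha=a) becomes add_(other * a), and addcdiv_(t1, t2,
    # value=v) becomes add_((t1 / t2) * v). When no alpha/value is given, the handlers
    # return None and call_method falls through to the generic proxy path.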
    def method_add_(self, other, *, alpha=None):
        if alpha is not None:
            from ..symbolic_convert import InstructionTranslator

            tx = InstructionTranslator.current_tx()
            result = variables.TorchInGraphFunctionVariable(torch.mul).call_function(
                tx, [other, alpha], {}
            )
            return self.call_method(tx, "add_", [result], {})

    def method_addcdiv_(self, tensor1, tensor2, *, value=None):
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()
        if value is not None:
            result = variables.TorchInGraphFunctionVariable(torch.div).call_function(
                tx, [tensor1, tensor2], {}
            )
            result = variables.TorchInGraphFunctionVariable(torch.mul).call_function(
                tx, [result, value], {}
            )
            return self.call_method(tx, "add_", [result], {})

    def method___contains__(self, arg):
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()

        # Rewrite __contains__ here so that downstream passes can trace through
        # without dealing with unbacked symbool. Roughly the code we translate is:
        # def __contains__(self, x):
        #     return (x == self).any().item()
        result = variables.TorchInGraphFunctionVariable(torch.eq).call_function(
            tx, [self, arg], {}
        )
        result = variables.TorchInGraphFunctionVariable(torch.any).call_function(
            tx, [result], {}
        )
        return result.call_method(tx, "item", [], {})

    def method_redistribute(self, *args, **kwargs):
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()
        # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
        # and rewrite args to have only proxyable args, then insert call_function
        args_as_value = [x.as_python_constant() for x in args]
        kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()}

        def redistribute_fn_with_prim_types(x):
            return x.redistribute(*args_as_value, **kwargs_as_value)

        # attach the same function name for better debugging
        redistribute_fn_with_prim_types.__name__ = "prim_redistribute"

        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx=tx,
            proxy=tx.output.create_proxy(
                "call_function",
                redistribute_fn_with_prim_types,
                *proxy_args_kwargs([self], {}),
            ),
        )

    def method_to_local(self, *args, **kwargs):
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()
        # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
        # and rewrite args to have only proxyable args, then insert call_function
        args_as_value = [x.as_python_constant() for x in args]
        kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()}

        def to_local_fn_with_prim_types(x):
            return x.to_local(*args_as_value, **kwargs_as_value)

        # attach the same function name for better debugging
        to_local_fn_with_prim_types.__name__ = "prim_to_local"

        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx=tx,
            proxy=tx.output.create_proxy(
                "call_function",
                to_local_fn_with_prim_types,
                *proxy_args_kwargs([self], {}),
            ),
        )

    def method_register_hook(self, *args, **kwargs):
        return self._method_register_hook("register_hook", *args, **kwargs)

    def method_register_post_accumulate_grad_hook(self, *args, **kwargs):
        return self._method_register_hook(
            "register_post_accumulate_grad_hook", *args, **kwargs
        )

    def _method_register_hook(self, name: str, hook: VariableTracker):
        # Note - do not arbitrarily add hooks here - make sure they match the same contract
        # see [On tensor.register_hook]
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()

        if not self.source:
            if not compiled_autograd.compiled_autograd_enabled:
                # TODO(voz):
                # We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary
                # python state.
                # We *Must* be in compiled_autograd here because backward hooks can contain anything, and it is unsafe to run
                # them in a compiled bwd without re-entering dynamo as compiled_autograd does.
                #
                # Discussion point 1 - Should we bypass this if nopython/fullgraph = True?
                # No. Because this was going to be a graph break anyway - this check does not
                # introduce new graph breaks where there were none.
                #
                # Discussion point 2 - Should we defer this check to backwards?
                # No. Because compiled autograd is not yet ready for prime time. As such, if we defer, a user
                # would have no recourse - their forward traces just fine, but will fail at backwards unless
                # compiled_autograd is enabled. If compiled_autograd fails (there are a lot of failures today)
                # then they have nothing they can do except disable compile.
                unimplemented(
                    "Compilation of intermediate hooks requires compiled autograd"
                )

            hook_name, bw_state_proxy = tx.output.add_backward_state_hook(hook)

            def _register_hook_trampoline(tensor, bw_state):
                register_hook = getattr(tensor, name)
                register_hook(
                    functools.partial(
                        trace_wrapped,
                        fn=call_hook_from_backward_state,
                        bw_state=bw_state,
                        hook_name=hook_name,
                    )
                )
                # TODO(jansel): returning None here is wrong, it should be
                # RemovableHandle, but we need some extra work to support
                # this properly.
                return None

            from .builder import wrap_fx_proxy

            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_function",
                    _register_hook_trampoline,
                    (self.as_proxy(), bw_state_proxy),
                    {},
                ),
            )

        handle_variable = variables.RemovableHandleVariable(
            mutable_local=variables.base.MutableLocal(),
        )
        tx.output.side_effects.register_hook(self, hook, handle_variable, name)
        return handle_variable

    def method_requires_grad_(self, requires_grad=True):
        if requires_grad is not True:
            requires_grad = requires_grad.as_python_constant()

        if self.as_proxy().node.meta["example_value"].requires_grad != requires_grad:
            unimplemented("Tensor.requires_grad_")
        else:
            return self

    def method_new(self, *args, **kwargs):
        # Convert x.new(torch.Size) into x.new_empty(torch.Size),
        # as Tensor.new acts differently with a Size input versus a tuple input.
        if (len(args) == 1 and isinstance(args[0], SizeVariable)) or (
            len(args) >= 1
            and all(
                isinstance(a, ConstantVariable) and a.python_type() == int for a in args
            )
        ):
            from ..symbolic_convert import InstructionTranslator

            return self.call_method(
                InstructionTranslator.current_tx(), "new_empty", args, kwargs
            )

    def method_untyped_storage(self):
        return UntypedStorageVariable(
            self, self.as_proxy().node.meta["example_value"].untyped_storage()
        )

    def set_name_hint(self, name: str):
        if not self._is_name_set:
            self.proxy.node._rename(name)
            self._is_name_set = True


class SymNodeVariable(VariableTracker):
    """
    Represents a symbolic scalar, either int, float or bool. This is most commonly used to
    handle symbolic size computation, e.g., tensor.size(0), but it is also used to
    handle logic like float_tensor.item() or unspecialized float inputs.
    """

    _nonvar_fields = {
        "proxy",
        "sym_num",
        *VariableTracker._nonvar_fields,
    }

    def debug_repr(self):
        return repr(self.sym_num)

    @classmethod
    def create(cls, tx, proxy, sym_num=None, **options):
        if sym_num is None:
            sym_num = get_fake_value(proxy.node, tx)
        if "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == sym_num
        set_example_value(proxy.node, sym_num)

        if isinstance(sym_num, (sympy.Integer, int, bool)):
            sym_num = int(sym_num) if isinstance(sym_num, sympy.Integer) else sym_num
            return ConstantVariable.create(sym_num)

        return SymNodeVariable(proxy, sym_num, **options)

    def __init__(self, proxy, sym_num, **kwargs) -> None:
        super().__init__(**kwargs)
        self.proxy = proxy
        # TODO: Should we allow non SymTypes here? Today it is allowed
        self.sym_num = sym_num
        self._tensor_var = None

    def python_type(self):
        if isinstance(self.sym_num, SymTypes):
            return self.sym_num.node.pytype
        else:
            return type(self.sym_num)

    def as_proxy(self):
        return self.proxy

    def as_tensor(self, tx):
        if self._tensor_var is None:
            from .builder import SourcelessBuilder

            self._tensor_var = SourcelessBuilder.create(
                tx, torch.scalar_tensor
            ).call_function(tx, [self], {})
        return self._tensor_var

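    # evaluate_expr specializes: guard_scalar() returns the current concrete value of
    # the symbolic scalar and installs a guard on it. Unbacked (data-dependent) values,
    # e.g. from Tensor.item(), cannot be guarded on; they surface as a UserError that
    # suggests torch._check*() annotations.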
    def evaluate_expr(self, output_graph=None):
        try:
            return guard_scalar(self.sym_num)
        except GuardOnDataDependentSymNode as e:
            raise UserError(  # noqa: B904
                UserErrorType.ANTI_PATTERN,
                f"Consider annotating your code using torch._check*(). {str(e)}",
                case_name="constrain_as_size_example",
            )

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_method",
                name,
                *proxy_args_kwargs([self, *args], kwargs),
            ),
        )


class NumpyNdarrayVariable(TensorVariable):
    """
    Represents an np.ndarray, but backed by a torch Tensor via torch._numpy.ndarray.
    Use this for Tensor.numpy() calls.
    """

    @staticmethod
    def create(tx: "InstructionTranslator", proxy, **options):
        from .builder import wrap_fx_proxy_cls

        return wrap_fx_proxy_cls(
            target_cls=NumpyNdarrayVariable,
            tx=tx,
            proxy=proxy,
            **options,
        )

    def var_getattr(self, tx: "InstructionTranslator", name):
        # NB: This INTENTIONALLY does not call super(), because there is
        # no intrinsic reason ndarray properties are related to Tensor
        # properties. The inheritance here is for implementation sharing.

        from ..utils import numpy_attr_wrapper
        from .builder import wrap_fx_proxy

        result = None

        example_value = self.as_proxy().node.meta["example_value"]
        example_ndarray = tnp.ndarray(example_value)

        def insert_into_graph():
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_function", numpy_attr_wrapper, (self.as_proxy(), name), {}
                ),
            )

        if name in ["T", "real", "imag"]:
            proxy = tx.output.create_proxy(
                "call_function",
                numpy_attr_wrapper,
                (self.as_proxy(), name),
                {},
            )
            result = NumpyNdarrayVariable.create(tx, proxy)

        # These are awkward to implement. The standard playbook for torch._numpy
        # interop is to trace a call into the torch._numpy wrapper which works for
        # Tensor operations. However, we don't want to do this for calls
        # that don't return Tensors, because in those cases we may not want
        # to trace the attribute access into the graph at all (it is sort
        # of harmless to do so, because AOTAutograd will eliminate them,
        # but it's best not to trace them in to begin with.) But in any
        # case, tracing these into the graph is like trying to fit a square
        # peg into a round hole; best not to do it. So instead we
        # painstakingly implement these by hand.
        #
        # NB: only ALWAYS specialized attributes can go here; notably,
        # size/shape not allowed!
        elif name in ("ndim", "itemsize"):
            return ConstantVariable.create(getattr(example_ndarray, name))
        elif name in ("shape", "stride"):
            if not has_free_symbols(r := getattr(example_ndarray, name)):
                return ConstantVariable.create(tuple(int(r) for r in r))
            return insert_into_graph()
        elif name == "size":
            if not has_free_symbols(r := example_ndarray.size):
                return ConstantVariable.create(int(r))
            return insert_into_graph()
        elif name in ["base", "flags", "dtype"]:
            unimplemented(f"TODO: add support for ndarray.{name}")
        elif name in ["__version__"]:
            unimplemented("delegate np.__version__ to NumPy")
        if result is None:
            raise NotImplementedError
        return result

    @staticmethod
    def patch_args(name, args, kwargs):
        if name == "clip":
            kwargs_rename = {"a_min": "min", "a_max": "max"}
            kwargs = {kwargs_rename.get(k, k): v for k, v in kwargs.items()}
        return args, kwargs

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from ..utils import numpy_method_wrapper

        args, kwargs = self.patch_args(name, args, kwargs)

        if name in ["__len__", "size", "tolist"]:
            # delegate back to TensorVariable
            return super().call_method(tx, name, args, kwargs)
        if name in ("tostring", "tobytes"):
            unimplemented(f"{name} is not modelled in torch._numpy")
        proxy = tx.output.create_proxy(
            "call_function",
            numpy_method_wrapper(name),
            *proxy_args_kwargs([self] + list(args), kwargs),
        )
        return NumpyNdarrayVariable.create(tx, proxy)

    def python_type(self):
        return np.ndarray


class UnspecializedPythonVariable(TensorVariable):
    """
    This is a 1-element tensor that represents an unspecialized Python float/int.
    """

    _nonvar_fields = {
        "raw_value",
        "need_unwrap",
        *TensorVariable._nonvar_fields,
    }

    def __init__(
        self, proxy: torch.fx.Proxy, *, raw_value=None, need_unwrap=True, **kwargs
    ) -> None:
        super().__init__(proxy, **kwargs)
        self.raw_value = raw_value
        self.need_unwrap = need_unwrap

    @classmethod
    def from_tensor_variable(cls, tensor_variable, raw_value, need_unwrap=True):
        # Convert a `TensorVariable` instance into an `UnspecializedPythonVariable` instance.
        return UnspecializedPythonVariable(
            **dict(tensor_variable.__dict__),
            raw_value=raw_value,
            need_unwrap=need_unwrap,
        )


class FakeItemVariable(TensorVariable):
    """An unspecialized Python variable which prevents access to the underlying raw value.
    This is needed if item is called on a FakeTensor."""

    _nonvar_fields = {
        "need_unwrap",
        *TensorVariable._nonvar_fields,
    }

    def __init__(self, proxy: torch.fx.Proxy, **kwargs) -> None:
        need_unwrap = kwargs.pop("need_unwrap", False)
        super().__init__(proxy, **kwargs)
        self.need_unwrap = need_unwrap

    @classmethod
    def from_tensor_variable(cls, tensor_variable):
        return FakeItemVariable(**dict(tensor_variable.__dict__))


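# TensorSubclassVariable tracks a user-defined torch.Tensor subclass type (the class
# object itself). Calling it with a single TensorVariable argument wraps that tensor
# as a TensorWithTFOverrideVariable wired up to the subclass's __torch_function__
# implementation (see also TensorVariable.method_as_subclass above).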
class TensorSubclassVariable(VariableTracker):
    def __init__(self, value, *args, **kwargs) -> None:
        self.value = value
        super().__init__(*args, **kwargs)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> VariableTracker:
        if len(args) == 1 and isinstance(args[0], TensorVariable):
            from .builder import VariableBuilder
            from .torch_function import TensorWithTFOverrideVariable

            torch_fn = VariableBuilder(
                tx, AttrSource(self.source, "__torch_function__")
            )(self.value.__torch_function__)

            return TensorWithTFOverrideVariable.from_tensor_var(
                tx, args[0], self.value, torch_fn
            )

        return super().call_function(tx, args, kwargs)

    def as_python_constant(self):
        return self.value


class UntypedStorageVariable(VariableTracker):
    _nonvar_fields = {
        "example_value",
        *VariableTracker._nonvar_fields,
    }

    def __init__(
        self,
        from_tensor: TensorVariable,
        example_value: torch.UntypedStorage,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.from_tensor = from_tensor
        # example_value will always have device="meta"
        self.example_value = example_value

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> VariableTracker:
        if name == "size":
            assert not args
            assert not kwargs
            result = self.example_value.size()
            if not has_free_symbols(result):
                # avoid creating a node in the graph
                return ConstantVariable.create(int(result))
            else:
                from ..external_utils import untyped_storage_size
                from .builder import wrap_fx_proxy

                return wrap_fx_proxy(
                    tx,
                    tx.output.create_proxy(
                        "call_function",
                        untyped_storage_size,
                        (self.from_tensor.as_proxy(),),
                        {},
                    ),
                )
        if name == "resize_" and len(args) == 1:
            assert not kwargs
            tx.output.create_proxy(
                "call_function",
                torch.ops.inductor.resize_storage_bytes_,
                (self.from_tensor.as_proxy(), args[0].as_proxy()),
                {},
            )
            return self

        return super().call_method(tx, name, args, kwargs)

    def reconstruct(self, codegen):
        codegen(self.from_tensor)
        codegen.load_method("untyped_storage")
        codegen.call_method(0)