xref: /aosp_15_r20/external/pytorch/torch/_dynamo/variables/tensor.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import functools
4import inspect
5import logging
6import operator
7import textwrap
8import traceback
9import types
10import unittest
11from typing import Dict, List, TYPE_CHECKING
12
13import sympy
14
15import torch._numpy as tnp
16import torch.fx
17import torch.random
18from torch._dynamo import compiled_autograd
19from torch._subclasses.meta_utils import is_sparse_any
20from torch.fx.experimental.symbolic_shapes import (
21    guard_scalar,
22    GuardOnDataDependentSymNode,
23    has_free_symbols,
24    is_symbolic,
25    SymTypes,
26)
27from torch.utils._python_dispatch import is_traceable_wrapper_subclass
28
29from .. import config, variables
30from .._trace_wrapped_higher_order_op import trace_wrapped
31from ..exc import unimplemented, UserError, UserErrorType
32from ..external_utils import call_hook_from_backward_state
33from ..guards import GuardBuilder, install_guard
34from ..source import AttrSource
35from ..utils import (
36    fqn,
37    get_custom_getattr,
38    get_fake_value,
39    get_real_value,
40    guard_if_dyn,
41    object_has_getattribute,
42    product,
43    proxy_args_kwargs,
44    set_example_value,
45    tensortype_to_dtype,
46)
47from .base import VariableTracker
48from .constant import ConstantVariable
49from .lists import SizeVariable
50
51
52try:
53    import numpy as np
54except ModuleNotFoundError:
55    np = None
56
57
58if TYPE_CHECKING:
59    from torch._dynamo.symbolic_convert import InstructionTranslator
60
61
62log = logging.getLogger(__name__)
63
64# Ops that allow tensor <op> tensor
65supported_tensor_comparison_ops = {
66    ">": operator.gt,
67    "<": operator.lt,
68    ">=": operator.ge,
69    "<=": operator.le,
70    "==": operator.eq,
71    "!=": operator.ne,
72}
73# Ops that allow tensor <op> None
74supported_const_comparison_ops = {
75    "is": operator.is_,
76    "is not": operator.is_not,
77    "==": operator.eq,
78    "!=": operator.ne,
79}
80supported_comparison_ops = {
81    **supported_tensor_comparison_ops,
82    **supported_const_comparison_ops,
83}
84supported_tensor_comparison_op_values = dict.fromkeys(
85    supported_tensor_comparison_ops.values()
86)
87supported_const_comparison_op_values = dict.fromkeys(
88    supported_const_comparison_ops.values()
89)
90
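# Illustrative lookup (a sketch; these tables are consumed elsewhere in Dynamo, not
# in this module directly): when Dynamo sees a comparison such as `a >= b` on
# tensors, the string op maps to the plain Python operator via these tables, e.g.
#   >>> supported_tensor_comparison_ops[">="] is operator.ge
#   True
#   >>> operator.ge(torch.tensor([1, 2]), torch.tensor([2, 1]))
#   tensor([False,  True])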
91
92class TensorVariable(VariableTracker):
93    """A torch.Tensor input or an intermediate value in the FX graph"""
94
95    _nonvar_fields = {
96        "proxy",
97        "dtype",
98        "device",
99        "layout",
100        "ndim",
101        "size",
102        "stride",
103        "requires_grad",
104        "is_quantized",
105        "is_contiguous",
106        "is_sparse",
107        "class_type",
108        "specialized_value",
109        "_is_name_set",
110        *VariableTracker._nonvar_fields,
111    }
112
113    def get_real_value(self):
114        """
115        Get the actual value represented by this variable if computation is run
116        using the user-provided inputs.
117        NOTE: this runs actual tensor computation and may be
118        slow and memory-intensive.
119        """
120        return get_real_value(self.proxy.node, self.proxy.tracer)
121
122    def __init__(
123        self,
124        proxy: torch.fx.Proxy,
125        *,
126        dtype,
127        device,
128        layout,
129        ndim,
130        requires_grad,
131        is_quantized,
132        is_sparse,
133        class_type,
134        has_grad_fn,
135        size=None,
136        stride=None,
137        is_contiguous=None,
138        _is_name_set=None,
139        **kwargs,
140    ) -> None:
141        super().__init__(**kwargs)
142        self.proxy = proxy
143        self.dtype = dtype
144        self.device = device
145        self.layout = layout
146        self.ndim = ndim
147        self.size = size
148        self.stride = stride
149        self.requires_grad = requires_grad
150        self.is_quantized = is_quantized
151        self.is_contiguous = is_contiguous
152        self.is_sparse = is_sparse
153        self.class_type = class_type
154        self.has_grad_fn = has_grad_fn
155        if _is_name_set is None:
156            # no need to rename inputs
157            _is_name_set = self.proxy.node.op == "placeholder"
158        self._is_name_set: bool = _is_name_set
159
160    def debug_repr(self):
161        # TODO: strip off fake tensor from repr here
162        return repr(self.proxy.node.meta["example_value"])
163
164    def as_proxy(self):
165        return self.proxy
166
167    def python_type(self):
168        return self.class_type
169
170    @staticmethod
171    def specialize(value: torch.Tensor):
172        props = {
173            "dtype": value.dtype,
174            "device": value.device,
175            "layout": value.layout,
176            "ndim": int(value.ndim),
177            "requires_grad": value.requires_grad,
178            "is_quantized": value.is_quantized,
179            "is_sparse": value.is_sparse,
180            "class_type": type(value),
181        }
182        try:
183            props["has_grad_fn"] = value.grad_fn is not None
184        except Exception:
185            # Workaround for issues with create_parameter_op in Dynamo. Reading
186            # grad_fn should never cause an issue.
187            props["has_grad_fn"] = False
188
189        if is_sparse_any(value) and not has_free_symbols(value):
190            props["size"] = tuple(
191                [int(s) if is_symbolic(s) else s for s in value.size()]
192            )
193        elif not has_free_symbols(value):
194            # this is a fully static shape, and the keys on props here inform specialization.
195            # We have to cast to int here, because these might get accessed as ConstantVariable, which has
196            # a strict no-symint policy. If we got here due to not having free symbols, this is a known constant
197            # already. We could remove the discrepancy here, by having ConstantVariable be more permissive for
198            # constant backed SymInts, but that assert being strict has led to some good signal in hunting bugs, and
199            # I'd like to keep it around for now.
200            props["size"] = tuple(
201                # the non is_symbolic case applies to the jagged layout
202                # NestedTensor case as singleton ints are not symbolic
203                [int(s) if is_symbolic(s) else s for s in value.size()]
204            )
205            props["stride"] = tuple(value.stride())
206            if torch._C._functorch.is_batchedtensor(value):
207                # Batched tensors do not support contiguity patterns, so
208                # we refrain from computing the `is_contiguous` property
209                props["is_contiguous"] = None
210            else:
211                props["is_contiguous"] = tuple(
212                    [
213                        x
214                        for x in torch._prims_common._memory_formats
215                        if value.is_contiguous(memory_format=x)
216                    ]
217                )
218        return props
219
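    # Rough illustration of what specialize() above returns for a plain dense
    # tensor (the concrete values below are an assumed example, not computed by
    # this module):
    #   >>> TensorVariable.specialize(torch.randn(2, 3))
    #   {'dtype': torch.float32, 'device': device(type='cpu'), 'layout': torch.strided,
    #    'ndim': 2, 'requires_grad': False, 'is_quantized': False, 'is_sparse': False,
    #    'class_type': <class 'torch.Tensor'>, 'has_grad_fn': False, 'size': (2, 3),
    #    'stride': (3, 1), 'is_contiguous': (...memory formats the tensor satisfies...)}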
220    def dynamic_getattr(self, tx: "InstructionTranslator", name):
221        fake_val = self.proxy.node.meta["example_value"]
222        # For getattrs on tensors without sources,
223        # we can do better than the default (creating a GetAttrVariable)
224        # if:
225        # (1) the tensor is a traceable tensor subclass
226        # (2) We are getattr'ing an inner tensor from that subclass
227        if not self.source and is_traceable_wrapper_subclass(fake_val):
228            fake_val = self.proxy.node.meta["example_value"]
229            attrs, ctx = fake_val.__tensor_flatten__()
230            proxy = getattr(self.as_proxy(), name)
231            example_value = getattr(fake_val, name)
232            if name in attrs:
233                # attrs returned from tensor_flatten are always tensors
234                assert isinstance(example_value, torch.Tensor)
235                from .builder import wrap_fx_proxy
236
237                return wrap_fx_proxy(tx=tx, proxy=proxy, example_value=example_value)
238            # any other attributes on the subclass (that are not methods)
239            # are assumed to be constant metadata.
240            elif not callable(example_value):
241                from .builder import SourcelessBuilder
242
243                return SourcelessBuilder.create(tx, example_value)
244
245        if not (self.source and self.source.subguards_allowed()):
246            raise NotImplementedError
247
248        # For local source, we associate the real value. We use this real value
249        # for implementing getattr fallthrough on the variable tracker base class.
250
251        # Note - this scope construction is mirrored in guards
252        # A subsequent PR will introduce a util.
253        scope = {"L": tx.output.local_scope, "G": tx.output.global_scope}
254        try:
255            # We raise in case we get a TypeError bug w/ SuperSource.
256            # SuperSource has bugs in it atm, and can produce code like
257            # eval("super(L['mod'].model.model.encoder.embed_positions.forward__class__,
258            # L['mod'].model.model.encoder.embed_positions)", scope)
259            # Which is incorrect, and violates the invariant that all sources should be eval()-able against the scope.
260            _input_associated_real_value = eval(self.source.name(), scope)
261        except Exception as exc:
262            raise NotImplementedError from exc
263
264        if _input_associated_real_value is None:
265            raise NotImplementedError
266
267        if object_has_getattribute(_input_associated_real_value):
268            raise NotImplementedError
269
270        if get_custom_getattr(_input_associated_real_value):
271            raise NotImplementedError
272
273        real_value = getattr(_input_associated_real_value, name)
274        if callable(real_value):
275            # Callables have more nuanced handling, and we should let the existing system delegate here.
276            # Raising was past behavior and so should always be sound to fall back.
277            # Note - at a certain point we may want to handle callables here as well.
278            raise NotImplementedError
279
280        from ..guards import GuardBuilder
281        from .builder import VariableBuilder
282
283        attr_source = AttrSource(self.source, name)
284        install_guard(attr_source.make_guard(GuardBuilder.HASATTR))
285        return VariableBuilder(tx, attr_source)(real_value)
286
287    def method_attr_ndim(self, tx):
288        if self.ndim is not None:
289            return ConstantVariable.create(self.ndim)
290        else:
291            return self.call_method(tx, "dim", [], {})
292
293    def method_attr_dtype(self, tx):
294        if self.dtype is not None:
295            return ConstantVariable.create(self.dtype)
296
297    def method_attr_device(self, tx):
298        if self.device is not None:
299            return ConstantVariable.create(self.device)
300
301    def method_attr_layout(self, tx):
302        if self.layout is not None:
303            return ConstantVariable.create(self.layout)
304
305    def method_attr_is_cuda(self, tx):
306        if self.device is not None:
307            return ConstantVariable.create(self.device.type == "cuda")
308
309    def method_attr_shape(self, tx):
310        if self.size is not None:
311            sizes = [variables.ConstantVariable.create(x) for x in self.size]
312            return SizeVariable(sizes)
313        else:
314            return self.call_method(tx, "size", [], {})
315
316    def method_attr_requires_grad(self, tx):
317        if self.requires_grad is not None:
318            return ConstantVariable.create(self.requires_grad)
319
320    def method_attr_is_quantized(self, tx):
321        if self.is_quantized is not None:
322            return ConstantVariable.create(self.is_quantized)
323
324    def method_attr_is_sparse(self, tx):
325        if self.is_sparse is not None:
326            return ConstantVariable.create(self.is_sparse)
327
328    def method_attr_data(self, tx):
329        return variables.TorchInGraphFunctionVariable(
330            torch._C._autograd._get_data_attr
331        ).call_function(tx, [self], {})
332
333    def method_attr_grad_fn(self, tx):
334        if self.has_grad_fn:
335            unimplemented("TensorVariable has a grad_fn")
336        else:
337            return variables.ConstantVariable(None)
338
339    def method_attr__version(self, tx):
340        from ..tensor_version_op import _tensor_version
341
342        return variables.TorchInGraphFunctionVariable(_tensor_version).call_function(
343            tx, [self], {}
344        )
345
346    def call_hasattr(self, tx: "InstructionTranslator", name):
347        from . import GetAttrVariable
348        from .builtin import BuiltinVariable
349
350        try:
351            var = BuiltinVariable(getattr).call_function(
352                tx, [self, ConstantVariable(name)], {}
353            )
354            # in the event that TensorVariable returns NotImplemented
355            # BuiltinVariable.call_getattr returns GetAttrVariable
356            ret_val = not isinstance(var, GetAttrVariable)
357        except AttributeError:
358            ret_val = False
359
360        if self.source:
361            install_guard(
362                AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
363            )
364
365        return ConstantVariable(ret_val)
366
367    def var_getattr(self, tx: "InstructionTranslator", name):
368        from . import UserDefinedClassVariable
369
370        if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops():
371            unimplemented(f"Illegal getattr invocation {name} in strict mode")
372
373        if name == "__class__":
374            return UserDefinedClassVariable(self.python_type())
375
376        handler = getattr(self, f"method_attr_{name}", None)
377        result = handler(tx) if handler is not None else None
378
379        # Add a guard for type matching. These guards are checked before tensor guards;
380        # in some cases, a <tensor>.<attr> guard can be evaluated first, and break if
381        # <tensor> is later changed to another type.
382        if (
383            result is not None
384            and self.source
385            and self.source.subguards_allowed()
386            and not (
387                name not in ("grad", "requires_grad") and result.is_python_constant()
388            )
389        ):
390            install_guard(self.make_guard(GuardBuilder.TYPE_MATCH))
391            result.source = AttrSource(self.source, name)
392
393        # It's hard to get inplace view (metadata mutation) on graph input to work properly across
394        # dynamo/aot/inductor, just fall back.
395        if self.source is not None and hasattr(torch.ops.aten, name):
396            fn = getattr(torch.ops.aten, name)
397            if (
398                hasattr(fn, "overloads")
399                and hasattr(fn, fn.overloads()[0])
400                and torch.Tag.inplace_view in getattr(fn, fn.overloads()[0]).tags
401            ):
402                # Delay the graph break to the actual call of unsqueeze_/resize_/resize_as_ etc.
403                return variables.misc.DelayGraphBreakVariable(
404                    source=AttrSource(self.source, name)
405                )
406
407        # For attributes (not methods) that were not caught in the special handling above,
408        # (e.g. tensor.real), we handle these generically, assuming that the output type is
409        # a tensor.
410        if result is None and name != "grad":
411
412            def try_generic_attr_handling():
413                from .builder import wrap_fx_proxy
414                from .misc import GetAttrVariable
415
416                try:
417                    static_attr = inspect.getattr_static(torch.Tensor, name)
418                except AttributeError:
419                    return None
420
421                # Make sure this is an attribute, not a method.
422                # type(torch.Tensor.H) should be "getset_descriptor"
423                # This is because of the CPython implementation; see THPVariableType:
424                # these attributes are implemented under tp_getset, which appear
425                # as `getset_descriptor`s (compared to, say, methods which appear
426                # as `method_descriptor`s)
427                if type(static_attr) != types.GetSetDescriptorType:
428                    return None
429
430                proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name)
431                if self.source is not None:
432                    return wrap_fx_proxy(
433                        tx=tx, proxy=proxy, source=AttrSource(self.source, name)
434                    )
435                else:
436                    return wrap_fx_proxy(tx=tx, proxy=proxy)
437
438            result = try_generic_attr_handling()
439
440        if result is None:
441            result = self.dynamic_getattr(tx, name)
442
443        if result is None:
444            raise NotImplementedError
445        return result
446
447    def call_id(self, tx):
448        if not self.source:
449            unimplemented("call_id not supported for sourceless TensorVariable")
450
451        # For local source, we associate the real value. We use this real value for its id().
452        scope = {"L": tx.output.local_scope, "G": tx.output.global_scope}
453        try:
454            _input_associated_real_value = eval(self.source.name(), scope)
455        except Exception as exc:
456            unimplemented(f"error getting associated real value: {exc}")
457
458        if _input_associated_real_value is None:
459            unimplemented("call_id without associated real value")
460
461        install_guard(self.source.make_guard(GuardBuilder.ID_MATCH))
462        id_value = id(_input_associated_real_value)
463        return ConstantVariable.create(id_value)
464
465    def has_unpack_var_sequence(self, tx):
466        return self.ndim > 0
467
468    def unpack_var_sequence(self, tx: "InstructionTranslator", idxes=None):
469        from .builder import wrap_fx_proxy_cls
470
471        if self.size:
472            size_len = len(self.size)
473        else:
474            size_var = self.call_method(tx, "size", [], {})
475            assert isinstance(size_var, SizeVariable)
476            size_len = len(size_var.items)
477        # Ensure we don't unpack a scalar tensor.
478        assert size_len != 0, "Can't unpack scalar tensors."
479
480        if self.size:
481            length = self.size[0]
482        else:
483            dyn_length = self.call_method(tx, "size", [ConstantVariable.create(0)], {})
484            # SymNodeVariable for symbolic sizes, ConstantVariable for constants OR values produced through
485            # symbolic_shapes, but that end up as int/sympy.Integer
486            assert isinstance(dyn_length, (SymNodeVariable, ConstantVariable))
487            if isinstance(dyn_length, SymNodeVariable):
488                length = dyn_length.evaluate_expr(tx.output)
489            else:
490                length = dyn_length.value
491
492        if idxes is None:
493            idxes = range(length)
494        else:
495            assert (
496                len(idxes) == length
497            ), f"Can't unpack a tensor of {length} rows into a tuple of {len(idxes)} elements."
498        return [
499            wrap_fx_proxy_cls(target_cls=type(self), tx=tx, proxy=self.as_proxy()[i])
500            for i in idxes
501        ]
502
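    # Illustrative note on unpack_var_sequence() above: it mirrors eager unpacking
    # over the leading dimension, e.g. `a, b = torch.randn(2, 3)` yields two rows of
    # shape (3,); each element is produced as `self.as_proxy()[i]`, so the indexing
    # becomes visible in the traced graph.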
503    def _strict_mode_banned_ops(self):
504        return torch._dynamo.config._autograd_backward_strict_mode_banned_ops
505
506    def call_method(
507        self,
508        tx,
509        name,
510        args: "List[VariableTracker]",
511        kwargs: "Dict[str, VariableTracker]",
512    ) -> "VariableTracker":
513        if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops():
514            unimplemented(f"Illegal method invocation {name} in strict mode")
515
516        """
517        Dispatch to a method-specific handler defined below.  If the
518        handler returns None (or doesn't exist) we put the method call
519        in the graph.
520        """
521        try:
522            handler_method = getattr(self, f"method_{name}")
523        except AttributeError:
524            pass
525        else:
526            try:
527                result = handler_method(*args, **kwargs)
528                if result:
529                    return result
530            except TypeError as e:
531                unimplemented(f"unhandled args for {name}: {e}")
532
533        from .builder import wrap_fx_proxy
534
535        return wrap_fx_proxy(
536            tx,
537            tx.output.create_proxy(
538                "call_method",
539                name,
540                *proxy_args_kwargs([self, *args], kwargs),
541            ),
542        )
543
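    # Dispatch sketch for call_method() above (an illustration; `tv`, `tx` and
    # `other` are placeholder names): a call like `x.dim()` is routed to
    # `method_dim` below and constant-folded, while a call with no constant-folding
    # handler, e.g. `x.add_(y)` without `alpha`, falls through and is recorded as a
    # "call_method" node in the FX graph:
    #   >>> tv.call_method(tx, "dim", [], {})        # ConstantVariable, no graph node
    #   >>> tv.call_method(tx, "add_", [other], {})  # TensorVariable backed by a proxy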
544    def method_size(self, *args, **kwargs):
545        return self._method_size_stride("size", *args, **kwargs)
546
547    def method_stride(self, *args, **kwargs):
548        return self._method_size_stride("stride", *args, **kwargs)
549
550    def _method_size_stride(self, name, dim=None):
551        dim = guard_if_dyn(dim)
552
553        def make_const_size_variable(x, **options):
554            return SizeVariable(
555                [ConstantVariable.create(y, **options) for y in x], **options
556            )
557
558        RetVariable = (
559            make_const_size_variable if name == "size" else ConstantVariable.create
560        )
561
562        # Technically, this should not be necessary, but I'm including it
563        # for enhanced BC, in case example_value is sometimes not set
564        # (it really should always be set though!)
565        if (r := getattr(self, name)) is not None:
566            if dim is None:
567                return RetVariable(r)
568            else:
569                return ConstantVariable.create(r[dim])
570
571        # It might still be constant!  Consult the fake tensor and see
572        if (fake := self.proxy.node.meta.get("example_value")) is not None:
573            if dim is None:
574                fake_r = getattr(fake, name)()
575                if not has_free_symbols(fake_r):
576                    # int conversion for safety, in case a SymInt refined
577                    # to constant
578                    return RetVariable(tuple(int(r) for r in fake_r))
579            else:
580                fake_r = getattr(fake, name)(dim)
581                if not has_free_symbols(fake_r):
582                    return ConstantVariable.create(int(fake_r))
583
584    def method_numel(self):
585        if self.size is not None:
586            return ConstantVariable.create(product(self.size))
587
588        # It might still be constant!  Consult the fake tensor and see
589        if (fake := self.proxy.node.meta.get("example_value")) is not None:
590            fake_r = fake.numel()
591            if not has_free_symbols(fake_r):
592                return ConstantVariable.create(int(fake_r))
593
594    method_nelement = method_numel
595
596    def method_dim(self):
597        if self.ndim is not None:
598            return ConstantVariable.create(self.ndim)
599
600    method_ndimension = method_dim
601
602    def method_is_floating_point(self):
603        if self.dtype is not None:
604            return ConstantVariable.create(self.dtype.is_floating_point)
605
606    def method_is_complex(self):
607        if self.dtype is not None:
608            return ConstantVariable.create(self.dtype.is_complex)
609
610    def method_is_contiguous(self, memory_format=None):
611        memory_format = (
612            memory_format.as_python_constant()
613            if memory_format is not None
614            else torch.contiguous_format
615        )
616        if self.is_contiguous is not None:
617            return ConstantVariable.create(memory_format in self.is_contiguous)
618        elif (fake := self.proxy.node.meta.get("example_value")) is not None:
619            return ConstantVariable.create(
620                fake.is_contiguous(memory_format=memory_format)
621            )
622
623    def method_type(self, dtype=None, non_blocking=False, **kwargs):
624        if (
625            dtype is None
626            and self.dtype is not None
627            and isinstance(self.device, torch.device)
628        ):
629            tensortype = next(
630                k for k, v in tensortype_to_dtype.items() if self.dtype in v
631            )
632            if self.device.type == "cuda":
633                return ConstantVariable.create(f"torch.cuda.{tensortype.__name__}")
634            else:
635                return ConstantVariable.create(f"torch.{tensortype.__name__}")
636        elif (
637            dtype is not None
638            and fqn(type(dtype.as_python_constant())) == "torch.tensortype"
639        ):
640            # torch.FloatTensor, etc. are all of type "torch.tensortype".
641            # torch.fx's tracer fails on these types, because it doesn't support arguments of torch.tensortype type.
642            # So, we pass it in as a string (which is also supported, see above implementation for .type() with 0 args)
643            tensor_type = dtype.as_python_constant()
644            tensor_type_const = ConstantVariable.create(fqn(tensor_type))
645
646            from ..symbolic_convert import InstructionTranslator
647            from .builder import wrap_fx_proxy
648
649            tx = InstructionTranslator.current_tx()
650
651            if non_blocking:
652                kwargs = {"non_blocking": non_blocking, **kwargs}
653
654            return wrap_fx_proxy(
655                tx,
656                tx.output.create_proxy(
657                    "call_method",
658                    "type",
659                    *proxy_args_kwargs([self, tensor_type_const], kwargs),
660                ),
661            )
662
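    # Eager reference for the zero-argument branch of method_type() above
    # (illustrative only; the second line assumes a CUDA build):
    #   >>> torch.randn(3).type()
    #   'torch.FloatTensor'
    #   >>> torch.randn(3, device="cuda").type()
    #   'torch.cuda.FloatTensor'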
663    def method_as_subclass(self, cls):
664        if isinstance(cls, TensorSubclassVariable) and cls.source:
665            from ..symbolic_convert import InstructionTranslator
666            from .builder import VariableBuilder
667            from .torch_function import TensorWithTFOverrideVariable
668
669            tx = InstructionTranslator.current_tx()
670
671            # [Note: __torch_function__] coerce this tensor variable into a TensorWithTFOverrideVariable
672            # in eager, this is just a type change. This isn't sound if a __torch_function__ tensor subclass
673            # defines a constructor, but if only a __torch_function__ impl is defined, this is okay to call.
674            # It is up to the user whether this is correct behavior or not.
675            py_cls = cls.as_python_constant()
676            torch_fn = VariableBuilder(
677                tx,
678                AttrSource(AttrSource(cls.source, "__torch_function__"), "__func__"),
679            )(py_cls.__torch_function__.__func__)
680
681            return TensorWithTFOverrideVariable.from_tensor_var(
682                tx, self, py_cls, torch_fn
683            )
684
685    def method_get_device(self):
686        if isinstance(self.device, torch.device):
687            index = self.device.index if self.device.type != "cpu" else -1
688            return ConstantVariable.create(index)
689
690    def method_element_size(self):
691        return ConstantVariable.create(self.dtype.itemsize)
692
693    def method_numpy(self, *, force=False):
694        if not config.trace_numpy:
695            unimplemented("Tensor.numpy(). config.trace_numpy is False")
696        if not np:
697            unimplemented("Tensor.numpy(). NumPy is not available")
698        if self.layout != torch.strided:
699            raise TypeError(
700                f"can't convert {self.layout} layout tensor to numpy. Use Tensor.dense() first"
701            )
702        from ..symbolic_convert import InstructionTranslator
703
704        tx = InstructionTranslator.current_tx()
705
706        # We don't check that the tensor is on CPU when force is False, as this
707        # allows us to execute NumPy code on CUDA. Same for requires_grad=True
708        if force and force.as_python_constant():
709            # If the user set force=True we try to preserve the semantics (no gradients, move to CPU...)
710            t = self.call_method(tx, "detach", [], {})
711            proxy = tx.output.create_proxy("call_method", "cpu", (t.as_proxy(),), {})
712        else:
713            # Hacky way to create a view of self that will be marked as NumpyNdarrayVariable
714            proxy = tx.output.create_proxy(
715                "call_method", "view_as", *proxy_args_kwargs([self, self], {})
716            )
717        return NumpyNdarrayVariable.create(tx, proxy)
718
719    def method_tolist(self):
720        from ..symbolic_convert import InstructionTranslator
721        from .builder import SourcelessBuilder
722
723        tx = InstructionTranslator.current_tx()
724
725        def tolist(tensor, sub_proxy):
726            def wrap(i, sub_proxy):
727                # Sigh, we forgot to gate this, so this data-dependent behavior is on
728                # by default and is load-bearing in CI
729                with unittest.mock.patch.object(
730                    tx.fake_mode, "allow_scalar_outputs", True
731                ):
732                    return SymNodeVariable.create(
733                        tx,
734                        sub_proxy.item(),
735                    )
736
737            if tensor.dtype not in [
738                torch.int8,
739                torch.int16,
740                torch.int32,
741                torch.int64,
742            ]:
743                unimplemented("Input tensor for tolist must be an integer tensor")
744
745            if tensor.dim() == 0:
746                return wrap(tensor, sub_proxy)
747
748            if tensor.dim() == 1:
749                return [wrap(val, sub_proxy[i]) for i, val in enumerate(tensor)]
750
751            return [
752                tolist(sub_tensor, sub_proxy=sub_proxy[i])
753                for i, sub_tensor in enumerate(tensor)
754            ]
755
756        tensor = self.as_proxy().node.meta["example_value"]
757        out = tolist(tensor, self.as_proxy())
758        return SourcelessBuilder.create(tx, out)
759
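    # Eager semantics modelled by method_tolist() above (illustrative): tolist()
    # turns an integer tensor into (possibly nested) Python lists, and each leaf
    # is traced here through `.item()` as an unbacked scalar:
    #   >>> torch.tensor([[1, 2], [3, 4]]).tolist()
    #   [[1, 2], [3, 4]]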
760    def method_backward(self, *args, **kwargs):
761        unimplemented("Tensor.backward")
762
763    def method_data_ptr(self, *args, **kwargs):
764        unimplemented("Tensor.data_ptr")
765
766    def method_item(self, *args, **kwargs):
767        if not config.capture_scalar_outputs:
768            self._warn_capture_scalar_outputs()
769            unimplemented("Tensor.item")
770
771    @staticmethod
772    @functools.lru_cache(None)
773    def _warn_capture_scalar_outputs():
774        user_stack = torch._guards.TracingContext.extract_stack()
775        user_stack_formatted = "".join(traceback.format_list(user_stack))
776        log.warning(
777            textwrap.dedent(
778                """\
779                    Graph break from `Tensor.item()`, consider setting:
780                        torch._dynamo.config.capture_scalar_outputs = True
781                    or:
782                        env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
783                    to include these operations in the captured graph.
784
785                    Graph break: from user code at:
786                    %s
787                """
788            ),
789            user_stack_formatted,
790        )
791
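    # Usage sketch for the config flag named in the warning above (illustrative;
    # the lambda is just an example workload):
    #   >>> torch._dynamo.config.capture_scalar_outputs = True
    #   >>> torch.compile(lambda t: t.sum().item() + 1.0)(torch.ones(3))
    #   4.0
    # With the flag left at its default, the `.item()` call graph-breaks instead.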
792    def method___len__(self):
793        from ..symbolic_convert import InstructionTranslator
794
795        tx = InstructionTranslator.current_tx()
796        return self.call_method(tx, "size", [ConstantVariable.create(0)], {})
797
798    def method_addcmul_(self, tensor1, tensor2, *, value=None):
799        from ..symbolic_convert import InstructionTranslator
800
801        tx = InstructionTranslator.current_tx()
802        if value is not None:
803            from .. import polyfills
804            from .builder import SourcelessBuilder
805
806            return tx.inline_user_function_return(
807                SourcelessBuilder.create(tx, polyfills.addcmul_inplace),
808                [self, tensor1, tensor2, value],
809                {},
810            )
811
812    def method___setitem__(self, key, value):
813        def has_bool_key(v):
814            if isinstance(v, TensorVariable):
815                return v.dtype in (torch.bool, torch.int8)
816            elif isinstance(v, variables.TupleVariable):
817                return any(has_bool_key(item) for item in v.items)
818            else:
819                return False
820
821        if (
822            has_bool_key(key)
823            and isinstance(value, TensorVariable)
824            and value.requires_grad
825            and torch.is_grad_enabled()
826        ):
827            unimplemented(
828                "boolean masking setitem backwards, see https://github.com/pytorch/pytorch/issues/114123"
829            )
830        from ..symbolic_convert import InstructionTranslator
831
832        tx = InstructionTranslator.current_tx()
833        tx.output.create_proxy(
834            "call_function",
835            operator.setitem,
836            *proxy_args_kwargs([self, key, value], {}),
837        )
838        return ConstantVariable.create(None)
839
840    def method_resize_(self, *args, **kwargs):
841        unimplemented("Tensor.resize_")
842
843    def method_resize_as_(self, *args, **kwargs):
844        unimplemented("Tensor.resize_as_")
845
846    def method_sparse_resize_(self, *args, **kwargs):
847        unimplemented("Tensor.sparse_resize_")
848
849    def method_sparse_resize_and_clear_(self, *args, **kwargs):
850        unimplemented("Tensor.sparse_resize_and_clear_")
851
852    def method_set_(self, *args, **kwargs):
853        if len(args) > 1:
854            # torch.Tensor.set_() has several overloads.
855            # aten::set_.source_Tensor(Tensor) gets special handling
856            # in AOTAutograd and functionalization, because it is the most common
857            # overload and is used by FSDP.
858            # graph-breaking on aten::set_source_Tensor_storage_offset for now,
859            # unless we find that we need to make it work.
860            unimplemented("Tensor.set_.source_Tensor_storage_offset")
861
862    def method_add_(self, other, *, alpha=None):
863        if alpha is not None:
864            from ..symbolic_convert import InstructionTranslator
865
866            tx = InstructionTranslator.current_tx()
867            result = variables.TorchInGraphFunctionVariable(torch.mul).call_function(
868                tx, [other, alpha], {}
869            )
870            return self.call_method(tx, "add_", [result], {})
871
872    def method_addcdiv_(self, tensor1, tensor2, *, value=None):
873        from ..symbolic_convert import InstructionTranslator
874
875        tx = InstructionTranslator.current_tx()
876        if value is not None:
877            result = variables.TorchInGraphFunctionVariable(torch.div).call_function(
878                tx, [tensor1, tensor2], {}
879            )
880            result = variables.TorchInGraphFunctionVariable(torch.mul).call_function(
881                tx, [result, value], {}
882            )
883            return self.call_method(tx, "add_", [result], {})
884
885    def method___contains__(self, arg):
886        from ..symbolic_convert import InstructionTranslator
887
888        tx = InstructionTranslator.current_tx()
889
890        # Rewrite __contains__ here so that downstream passes can trace through
891        # without dealing with unbacked symbool. Roughly the code we translate is:
892        # def __contains__(self, x):
893        #     return (x == self).any().item()
894        result = variables.TorchInGraphFunctionVariable(torch.eq).call_function(
895            tx, [self, arg], {}
896        )
897        result = variables.TorchInGraphFunctionVariable(torch.any).call_function(
898            tx, [result], {}
899        )
900        return result.call_method(tx, "item", [], {})
901
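    # Eager equivalent of the rewrite in method___contains__ above (illustrative):
    #   >>> t = torch.tensor([1, 2, 3])
    #   >>> 2 in t
    #   True
    #   >>> bool(torch.any(torch.eq(t, 2)).item())  # the form that gets traced
    #   True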
902    def method_redistribute(self, *args, **kwargs):
903        from ..symbolic_convert import InstructionTranslator
904
905        tx = InstructionTranslator.current_tx()
906        # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
907        # and rewrite args to have only proxyable args, then insert call_function
908        args_as_value = [x.as_python_constant() for x in args]
909        kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()}
910
911        def redistribute_fn_with_prim_types(x):
912            return x.redistribute(*args_as_value, **kwargs_as_value)
913
914        # attach the same function name for better debugging
915        redistribute_fn_with_prim_types.__name__ = "prim_redistribute"
916
917        from .builder import wrap_fx_proxy
918
919        return wrap_fx_proxy(
920            tx=tx,
921            proxy=tx.output.create_proxy(
922                "call_function",
923                redistribute_fn_with_prim_types,
924                *proxy_args_kwargs([self], {}),
925            ),
926        )
927
928    def method_to_local(self, *args, **kwargs):
929        from ..symbolic_convert import InstructionTranslator
930
931        tx = InstructionTranslator.current_tx()
932        # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
933        # and rewrite args to have only proxyable args, then insert call_function
934        args_as_value = [x.as_python_constant() for x in args]
935        kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()}
936
937        def to_local_fn_with_prim_types(x):
938            return x.to_local(*args_as_value, **kwargs_as_value)
939
940        # attach the same function name for better debugging
941        to_local_fn_with_prim_types.__name__ = "prim_to_local"
942
943        from .builder import wrap_fx_proxy
944
945        return wrap_fx_proxy(
946            tx=tx,
947            proxy=tx.output.create_proxy(
948                "call_function",
949                to_local_fn_with_prim_types,
950                *proxy_args_kwargs([self], {}),
951            ),
952        )
953
954    def method_register_hook(self, *args, **kwargs):
955        return self._method_register_hook("register_hook", *args, **kwargs)
956
957    def method_register_post_accumulate_grad_hook(self, *args, **kwargs):
958        return self._method_register_hook(
959            "register_post_accumulate_grad_hook", *args, **kwargs
960        )
961
962    def _method_register_hook(self, name: str, hook: VariableTracker):
963        # Note - do not arbitrarily add hooks here - make sure they match the same contract
964        # see [On tensor.register_hook]
965        from ..symbolic_convert import InstructionTranslator
966
967        tx = InstructionTranslator.current_tx()
968
969        if not self.source:
970            if not compiled_autograd.compiled_autograd_enabled:
971                # TODO(voz):
972                # We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary
973                # python state.
974                # We *Must* be in compiled_autograd here because backward hooks can contain anything, and it is unsafe to run
975                # them in a compiled bwd without re-entering dynamo as compiled_autograd does.
976                #
977                # Discussion point 1 - Should we bypass this if nopython/fullgraph = True?
978                #   No. Because this was going to be a graph break anyway - this check does not
979                # introduce new graph breaks where there were none.
980                #
981                # Discussion point 2 - Should we defer this check to backwards?
982                #   No. Because compiled autograd is not yet ready for prime time. As such, if we defer, a user
983                # would have no recourse - their forward traces just fine, but will fail at backwards unless
984                # compiled_autograd is enabled. If compiled_autograd fails (there are a lot of failures today)
985                # then they have nothing they can do except disable compile.
986                unimplemented(
987                    "Compilation of intermediate hooks requires compiled autograd"
988                )
989
990            hook_name, bw_state_proxy = tx.output.add_backward_state_hook(hook)
991
992            def _register_hook_trampoline(tensor, bw_state):
993                register_hook = getattr(tensor, name)
994                register_hook(
995                    functools.partial(
996                        trace_wrapped,
997                        fn=call_hook_from_backward_state,
998                        bw_state=bw_state,
999                        hook_name=hook_name,
1000                    )
1001                )
1002                # TODO(jansel): returning None here is wrong, it should be
1003                # RemovableHandle, but we need some extra work to support
1004                # this properly.
1005                return None
1006
1007            from .builder import wrap_fx_proxy
1008
1009            return wrap_fx_proxy(
1010                tx,
1011                tx.output.create_proxy(
1012                    "call_function",
1013                    _register_hook_trampoline,
1014                    (self.as_proxy(), bw_state_proxy),
1015                    {},
1016                ),
1017            )
1018
1019        handle_variable = variables.RemovableHandleVariable(
1020            mutable_local=variables.base.MutableLocal(),
1021        )
1022        tx.output.side_effects.register_hook(self, hook, handle_variable, name)
1023        return handle_variable
1024
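    # User-facing shape of what the hook machinery above supports (a sketch with
    # assumed surrounding code, not executed here):
    #   >>> x = torch.randn(3, requires_grad=True)      # graph input, so it has a source
    #   >>> def fn(x):
    #   ...     x.register_hook(lambda grad: grad * 2)  # recorded via side_effects
    #   ...     return x.sin()
    # Hooks on intermediates (tensors without a source) additionally require
    # compiled autograd, as discussed in the comment above.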
1025    def method_requires_grad_(self, requires_grad=True):
1026        if requires_grad is not True:
1027            requires_grad = requires_grad.as_python_constant()
1028
1029        if self.as_proxy().node.meta["example_value"].requires_grad != requires_grad:
1030            unimplemented("Tensor.requires_grad_")
1031        else:
1032            return self
1033
1034    def method_new(self, *args, **kwargs):
1035        # Convert x.new(torch.Size) into x.new_empty(torch.Size),
1036        # as Tensor.new acts differently with a Size input versus a tuple input.
1037        if (len(args) == 1 and isinstance(args[0], SizeVariable)) or (
1038            len(args) >= 1
1039            and all(
1040                isinstance(a, ConstantVariable) and a.python_type() == int for a in args
1041            )
1042        ):
1043            from ..symbolic_convert import InstructionTranslator
1044
1045            return self.call_method(
1046                InstructionTranslator.current_tx(), "new_empty", args, kwargs
1047            )
1048
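    # Illustration of the rewrite in method_new() above: `x.new(2, 3)` and
    # `x.new(torch.Size([2, 3]))` are traced as the equivalent `new_empty` calls,
    # which match eager semantics (same dtype/device, uninitialized data), e.g.
    #   >>> x = torch.randn(1, dtype=torch.float64)
    #   >>> x.new(2, 3).dtype == x.new_empty(2, 3).dtype == torch.float64
    #   True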
1049    def method_untyped_storage(self):
1050        return UntypedStorageVariable(
1051            self, self.as_proxy().node.meta["example_value"].untyped_storage()
1052        )
1053
1054    def set_name_hint(self, name: str):
1055        if not self._is_name_set:
1056            self.proxy.node._rename(name)
1057            self._is_name_set = True
1058
1059
1060class SymNodeVariable(VariableTracker):
1061    """
1062    Represents a symbolic scalar, either int, float or bool.  This is most commonly used to
1063    handle symbolic size computation, e.g., tensor.size(0), but it is also used to
1064    handle logic like float_tensor.item() or unspecialized float inputs.
1065    """
1066
1067    _nonvar_fields = {
1068        "proxy",
1069        "sym_num",
1070        *VariableTracker._nonvar_fields,
1071    }
1072
1073    def debug_repr(self):
1074        return repr(self.sym_num)
1075
1076    @classmethod
1077    def create(cls, tx, proxy, sym_num=None, **options):
1078        if sym_num is None:
1079            sym_num = get_fake_value(proxy.node, tx)
1080        if "example_value" in proxy.node.meta:
1081            assert proxy.node.meta["example_value"] == sym_num
1082        set_example_value(proxy.node, sym_num)
1083
1084        if isinstance(sym_num, (sympy.Integer, int, bool)):
1085            sym_num = int(sym_num) if isinstance(sym_num, sympy.Integer) else sym_num
1086            return ConstantVariable.create(sym_num)
1087
1088        return SymNodeVariable(proxy, sym_num, **options)
1089
1090    def __init__(self, proxy, sym_num, **kwargs) -> None:
1091        super().__init__(**kwargs)
1092        self.proxy = proxy
1093        # TODO: Should we allow non SymTypes here?  Today it is allowed
1094        self.sym_num = sym_num
1095        self._tensor_var = None
1096
1097    def python_type(self):
1098        if isinstance(self.sym_num, SymTypes):
1099            return self.sym_num.node.pytype
1100        else:
1101            return type(self.sym_num)
1102
1103    def as_proxy(self):
1104        return self.proxy
1105
1106    def as_tensor(self, tx):
1107        if self._tensor_var is None:
1108            from .builder import SourcelessBuilder
1109
1110            self._tensor_var = SourcelessBuilder.create(
1111                tx, torch.scalar_tensor
1112            ).call_function(tx, [self], {})
1113        return self._tensor_var
1114
1115    def evaluate_expr(self, output_graph=None):
1116        try:
1117            return guard_scalar(self.sym_num)
1118        except GuardOnDataDependentSymNode as e:
1119            raise UserError(  # noqa: B904
1120                UserErrorType.ANTI_PATTERN,
1121                f"Consider annotating your code using torch._check*(). {str(e)}",
1122                case_name="constrain_as_size_example",
1123            )
1124
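    # Sketch of the remediation suggested by the UserError above (illustrative;
    # `t` is a placeholder tensor): torch._check() records a property of a
    # data-dependent value so later guards need not branch on its runtime data.
    #   >>> n = (t > 0).sum().item()   # unbacked SymInt under capture_scalar_outputs
    #   >>> torch._check(n >= 1)
    #   >>> y = t.new_zeros(n)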
1125    def call_method(
1126        self,
1127        tx,
1128        name,
1129        args: "List[VariableTracker]",
1130        kwargs: "Dict[str, VariableTracker]",
1131    ) -> "VariableTracker":
1132        from .builder import wrap_fx_proxy
1133
1134        return wrap_fx_proxy(
1135            tx,
1136            tx.output.create_proxy(
1137                "call_method",
1138                name,
1139                *proxy_args_kwargs([self, *args], kwargs),
1140            ),
1141        )
1142
1143
1144class NumpyNdarrayVariable(TensorVariable):
1145    """
1146    Represents an np.ndarray, but backed by a torch Tensor via torch._numpy.ndarray.
1147    Use this for the Tensor.numpy() call.
1148    """
1149
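    # Illustrative origin of these variables (assumes config.trace_numpy is True
    # and NumPy is installed): inside a compiled region, `t.numpy()` is traced
    # through torch._numpy instead of leaving the graph, e.g.
    #   >>> @torch.compile
    #   ... def f(t):
    #   ...     return (t.numpy() ** 2).sum()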
1150    @staticmethod
1151    def create(tx: "InstructionTranslator", proxy, **options):
1152        from .builder import wrap_fx_proxy_cls
1153
1154        return wrap_fx_proxy_cls(
1155            target_cls=NumpyNdarrayVariable,
1156            tx=tx,
1157            proxy=proxy,
1158            **options,
1159        )
1160
1161    def var_getattr(self, tx: "InstructionTranslator", name):
1162        # NB: This INTENTIONALLY does not call super(), because there is
1163        # no intrinsic reason ndarray properties are related to Tensor
1164        # properties.  The inheritance here is for implementation sharing.
1165
1166        from ..utils import numpy_attr_wrapper
1167        from .builder import wrap_fx_proxy
1168
1169        result = None
1170
1171        example_value = self.as_proxy().node.meta["example_value"]
1172        example_ndarray = tnp.ndarray(example_value)
1173
1174        def insert_into_graph():
1175            return wrap_fx_proxy(
1176                tx,
1177                tx.output.create_proxy(
1178                    "call_function", numpy_attr_wrapper, (self.as_proxy(), name), {}
1179                ),
1180            )
1181
1182        if name in ["T", "real", "imag"]:
1183            proxy = tx.output.create_proxy(
1184                "call_function",
1185                numpy_attr_wrapper,
1186                (self.as_proxy(), name),
1187                {},
1188            )
1189            result = NumpyNdarrayVariable.create(tx, proxy)
1190
1191        # These are awkward to implement.  The standard playbook for torch._numpy
1192        # interop is to trace a call into the torch._numpy wrapper which works for
1193        # Tensor operations.  However, we don't want to do this for calls
1194        # that don't return Tensors, because in those cases we may not want
1195        # to trace the attribute access into the graph at all (it is sort
1196        # of harmless to do so, because AOTAutograd will eliminate them,
1197        # but it's best not to trace them in to begin with.)  But in any
1198        # case, tracing these into the graph is like trying to fit a square
1199        # peg into a round hole; best not to do it.  So instead we
1200        # painstakingly implement these by hand
1201        #
1202        # NB: only ALWAYS specialized attributes can go here; notably,
1203        # size/shape not allowed!
1204        elif name in ("ndim", "itemsize"):
1205            return ConstantVariable.create(getattr(example_ndarray, name))
1206        elif name in ("shape", "stride"):
1207            if not has_free_symbols(r := getattr(example_ndarray, name)):
1208                return ConstantVariable.create(tuple(int(r) for r in r))
1209            return insert_into_graph()
1210        elif name == "size":
1211            if not has_free_symbols(r := example_ndarray.size):
1212                return ConstantVariable.create(int(r))
1213            return insert_into_graph()
1214        elif name in ["base", "flags", "dtype"]:
1215            unimplemented(f"TODO: add support for ndarray.{name}")
1216        elif name in ["__version__"]:
1217            unimplemented("delegate np.__version__ to NumPy")
1218        if result is None:
1219            raise NotImplementedError
1220        return result
1221
1222    @staticmethod
1223    def patch_args(name, args, kwargs):
1224        if name == "clip":
1225            kwargs_rename = {"a_min": "min", "a_max": "max"}
1226            kwargs = {kwargs_rename.get(k, k): v for k, v in kwargs.items()}
1227        return args, kwargs
1228
1229    def call_method(
1230        self,
1231        tx,
1232        name,
1233        args: "List[VariableTracker]",
1234        kwargs: "Dict[str, VariableTracker]",
1235    ) -> "VariableTracker":
1236        from ..utils import numpy_method_wrapper
1237
1238        args, kwargs = self.patch_args(name, args, kwargs)
1239
1240        if name in ["__len__", "size", "tolist"]:
1241            # delegate back to TensorVariable
1242            return super().call_method(tx, name, args, kwargs)
1243        if name in ("tostring", "tobytes"):
1244            unimplemented(f"{name} is not modelled in torch._numpy")
1245        proxy = tx.output.create_proxy(
1246            "call_function",
1247            numpy_method_wrapper(name),
1248            *proxy_args_kwargs([self] + list(args), kwargs),
1249        )
1250        return NumpyNdarrayVariable.create(tx, proxy)
1251
1252    def python_type(self):
1253        return np.ndarray
1254
1255
1256class UnspecializedPythonVariable(TensorVariable):
1257    """
1258    This is a 1-element tensor that represents an unspecialized python float/int.
1259    """
1260
1261    _nonvar_fields = {
1262        "raw_value",
1263        "need_unwrap",
1264        *TensorVariable._nonvar_fields,
1265    }
1266
1267    def __init__(
1268        self, proxy: torch.fx.Proxy, *, raw_value=None, need_unwrap=True, **kwargs
1269    ) -> None:
1270        super().__init__(proxy, **kwargs)
1271        self.raw_value = raw_value
1272        self.need_unwrap = need_unwrap
1273
1274    @classmethod
1275    def from_tensor_variable(cls, tensor_variable, raw_value, need_unwrap=True):
1276        # Convert a `TensorVariable` instance into an `UnspecializedPythonVariable` instance.
1277        return UnspecializedPythonVariable(
1278            **dict(tensor_variable.__dict__),
1279            raw_value=raw_value,
1280            need_unwrap=need_unwrap,
1281        )
1282
1283
1284class FakeItemVariable(TensorVariable):
1285    """An unspecialized python variable which prevents access to the underlying raw value.
1286    This is needed if item is called on a FakeTensor."""
1287
1288    _nonvar_fields = {
1289        "need_unwrap",
1290        *TensorVariable._nonvar_fields,
1291    }
1292
1293    def __init__(self, proxy: torch.fx.Proxy, **kwargs) -> None:
1294        need_unwrap = kwargs.pop("need_unwrap", False)
1295        super().__init__(proxy, **kwargs)
1296        self.need_unwrap = need_unwrap
1297
1298    @classmethod
1299    def from_tensor_variable(cls, tensor_variable):
1300        return FakeItemVariable(**dict(tensor_variable.__dict__))
1301
1302
1303class TensorSubclassVariable(VariableTracker):
1304    def __init__(self, value, *args, **kwargs) -> None:
1305        self.value = value
1306        super().__init__(*args, **kwargs)
1307
1308    def call_function(
1309        self,
1310        tx: "InstructionTranslator",
1311        args: List[VariableTracker],
1312        kwargs: Dict[str, VariableTracker],
1313    ) -> VariableTracker:
1314        if len(args) == 1 and isinstance(args[0], TensorVariable):
1315            from .builder import VariableBuilder
1316            from .torch_function import TensorWithTFOverrideVariable
1317
1318            torch_fn = VariableBuilder(
1319                tx, AttrSource(self.source, "__torch_function__")
1320            )(self.value.__torch_function__)
1321
1322            return TensorWithTFOverrideVariable.from_tensor_var(
1323                tx, args[0], self.value, torch_fn
1324            )
1325
1326        return super().call_function(tx, args, kwargs)
1327
1328    def as_python_constant(self):
1329        return self.value
1330
1331
1332class UntypedStorageVariable(VariableTracker):
1333    _nonvar_fields = {
1334        "example_value",
1335        *VariableTracker._nonvar_fields,
1336    }
1337
1338    def __init__(
1339        self,
1340        from_tensor: TensorVariable,
1341        example_value: torch.UntypedStorage,
1342        **kwargs,
1343    ) -> None:
1344        super().__init__(**kwargs)
1345        self.from_tensor = from_tensor
1346        # Example_value will always have device="meta"
1347        self.example_value = example_value
1348
1349    def call_method(
1350        self,
1351        tx,
1352        name,
1353        args: List[VariableTracker],
1354        kwargs: Dict[str, VariableTracker],
1355    ) -> VariableTracker:
1356        if name == "size":
1357            assert not args
1358            assert not kwargs
1359            result = self.example_value.size()
1360            if not has_free_symbols(result):
1361                # avoid creating a node in the graph
1362                return ConstantVariable.create(int(result))
1363            else:
1364                from ..external_utils import untyped_storage_size
1365                from .builder import wrap_fx_proxy
1366
1367                return wrap_fx_proxy(
1368                    tx,
1369                    tx.output.create_proxy(
1370                        "call_function",
1371                        untyped_storage_size,
1372                        (self.from_tensor.as_proxy(),),
1373                        {},
1374                    ),
1375                )
1376        if name == "resize_" and len(args) == 1:
1377            assert not kwargs
1378            tx.output.create_proxy(
1379                "call_function",
1380                torch.ops.inductor.resize_storage_bytes_,
1381                (self.from_tensor.as_proxy(), args[0].as_proxy()),
1382                {},
1383            )
1384            return self
1385
1386        return super().call_method(tx, name, args, kwargs)
1387
1388    def reconstruct(self, codegen):
1389        codegen(self.from_tensor)
1390        codegen.load_method("untyped_storage")
1391        codegen.call_method(0)
1392