xref: /aosp_15_r20/external/pytorch/test/dynamo/test_subclasses.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2import functools
3import itertools
4import unittest
5from functools import partial
6
7import torch
8import torch._dynamo.test_case
9import torch._dynamo.testing
10import torch._functorch.config
11import torch.utils._pytree as pytree
12import torch.utils.checkpoint
13from torch._dynamo.testing import normalize_gm
14from torch._higher_order_ops.wrap import wrap
15from torch.fx.experimental.symbolic_shapes import (
16    DimDynamic,
17    ShapeEnv,
18    StatelessSymbolicContext,
19)
20from torch.nested._internal.nested_tensor import (
21    jagged_from_list,
22    jagged_from_tensor_and_lengths,
23    nested_view_from_values_offsets,
24)
25from torch.testing._internal.common_utils import (
26    instantiate_parametrized_tests,
27    NestedTensorTestCase,
28    parametrize,
29    subtest,
30)
31from torch.testing._internal.inductor_utils import HAS_CUDA
32from torch.testing._internal.two_tensor import TwoTensor
33from torch.utils._python_dispatch import return_and_correct_aliasing
34
35
36def traceable_subclass(c):
37    return torch._dynamo.config.patch("traceable_tensor_subclasses", {c})
38
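# Illustrative sketch (hypothetical helper, not used by the tests): how
# traceable_subclass() is meant to be consumed. It returns a config.patch
# context manager so Dynamo traces through the given __torch_function__
# subclass only for the duration of the block.
def _example_traceable_subclass_usage(fn, x):
    class _Example(torch.Tensor):
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            return super().__torch_function__(func, types, args, kwargs or {})

    with traceable_subclass(_Example):
        # Compile and run while the subclass is registered as traceable.
        return torch.compile(fn, backend="eager", fullgraph=True)(
            x.as_subclass(_Example)
        )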
39
40def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles):
41    actual_recompiles = _recompiles_for_inputs(fn, inputs1, inputs2)
42    self.assertEqual(actual_recompiles, expected_recompiles)
43
44
45def get_jagged_tensor(nested_size, offsets, requires_grad=True):
46    # Make a jagged tensor with N constituent tensors whose sizes are
47    # specified as ((S0, S1, S2), D)
48    D = nested_size[1]
49    out = []
50    for s in nested_size[0]:
51        out.append(torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64))
52    return jagged_from_list(out, offsets)
53
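# Illustrative sketch (hypothetical helper, not used by the tests): what the
# ((S0, S1, S2), D) spec above produces. The exact offsets values are an
# assumption based on cumulative constituent lengths.
def _example_jagged_tensor_shapes():
    nt, offsets = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
    # nt packs three constituents of shapes (2, 3), (3, 3), (4, 3) along a
    # ragged dimension; offsets should hold the cumulative boundaries,
    # i.e. tensor([0, 2, 5, 9]).
    return nt, offsets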
54
55def get_view_test_cases():
56    # Test all cases with both an NT base and a dense base
57    # Subclass -> Subclass
58    # Dense -> Subclass
59
60    # NB: Don't close over loop variables; they will not get copied into the
61    # closure.
62    #
63    # NB: These return factory functions so we don't construct tensors at test
64    # collection time.
65
66    def mk_basic(base_is_nt):
67        # There are three cases to consider here based on the logic in
68        # meta_utils.py
69        #
70        # (1) basic case:
71        # view is not a leaf and has the same requires_grad as its base
72        x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
73        x = x.clone() if base_is_nt else x
74        assert not x.is_leaf
75        return x.unsqueeze(-1)
76
77    def mk_leaf(base_is_nt, requires_grad_1, requires_grad_2):
78        x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=requires_grad_1)
79        x = x.clone() if base_is_nt else x
80        with torch.no_grad():
81            x_view = x.unsqueeze(-1)
82            # The issue is this doesn't quite work
83            x_view.requires_grad_(requires_grad_2)
84
85        return x_view
86
87    def mk_obscure(base_is_nt):
88        x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
89        x = x.clone() if base_is_nt else x
90        # intermediate leaf view
91        with torch.no_grad():
92            x_view = x.unsqueeze(-1)
93        x_view.requires_grad_(True)
94        x_view_view = x_view.unsqueeze(-1)
95        return x_view_view
96
97    for base_is_nt in [False, True]:
98        prefix = f"base_is_nt_{base_is_nt}"
99
100        yield partial(mk_basic, base_is_nt), f"{prefix}_basic"
101
102        # (2) leaf view case:
103        # the view has to be a leaf (w/ requires_grad True or requires_grad False)
104        # base w/ requires_grad True or requires_grad False
105        for requires_grad_1, requires_grad_2 in itertools.product(
106            [True, False], repeat=2
107        ):
108            yield partial(
109                mk_leaf, base_is_nt, requires_grad_1, requires_grad_2
110            ), f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}"
111
112        # (3) obscure case:
113        # view is not a leaf (implies requires_grad True)
114        # base w/ requires_grad False
115        yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure"
116
117    # Subclass -> Dense
118    yield lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[
119        0
120    ].clone(), "subclass_dense"
121
122    # Dense -> Subclass -> Dense -> Subclass
123    def mk_dense_subclass_dense_subclass():
124        values = torch.randn(10, 5)
125        offsets = torch.tensor([0, 3, 6, 10])
126        offsets2 = offsets.clone().detach()
127        return nested_view_from_values_offsets(
128            nested_view_from_values_offsets(values, offsets).values(), offsets2
129        )
130
131    yield mk_dense_subclass_dense_subclass, "dense_subclass_dense_subclass"
132
133    def mk_subclass_dense_subclass_dense():
134        x = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
135        offsets2 = x.offsets().clone().detach()
136        return nested_view_from_values_offsets(x.values(), offsets2).values()
137
138    yield mk_subclass_dense_subclass_dense, "subclass_dense_subclass_dense"
139
140
141VIEW_TEST_CASES = {k: v for v, k in get_view_test_cases()}
142
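# Illustrative sketch (hypothetical helper, not used by the tests):
# VIEW_TEST_CASES maps a readable name to a factory, so tensors are only
# materialized inside a test body, never at collection time.
def _example_materialize_view_case(name="base_is_nt_False_basic"):
    # "base_is_nt_False_basic" is one of the names yielded above.
    return VIEW_TEST_CASES[name]()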
143
144requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
145
146compile_full_eager = torch.compile(backend="eager", fullgraph=True)
147
148
149class BaseTorchFunction(torch.Tensor):
150    @classmethod
151    def __torch_function__(cls, func, types, args=(), kwargs=None):
152        if kwargs is None:
153            kwargs = {}
154        return super().__torch_function__(func, types, args, kwargs)
155
156
157class MockSubclass(torch.Tensor):
158    @classmethod
159    def __torch_function__(cls, func, types, args=(), kwargs=None):
160        if kwargs is None:
161            kwargs = {}
162        return func(*args, **kwargs)
163
164
165class AttrSubclass(torch.Tensor):
166    x: int = 10
167    size: int = 10
168
169    @classmethod
170    def __torch_function__(cls, func, types, args=(), kwargs=None):
171        if kwargs is None:
172            kwargs = {}
173
174        return func(*args, **kwargs)
175
176
177class DummyNDim(torch.Tensor):
178    @classmethod
179    def __torch_function__(cls, func, types, args=(), kwargs=None):
180        if kwargs is None:
181            kwargs = {}
182
183        if func == torch.Tensor.ndim.__get__:
184            return 10
185
186        return super().__torch_function__(func, types, args, kwargs)
187
188
189class WrapperSubclass:
190    def __init__(self, tensor):
191        self.tensor = tensor
192
193    @classmethod
194    def __torch_function__(cls, func, types, args=(), kwargs=None):
195        if kwargs is None:
196            kwargs = {}
197
198        args = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, args)
199        kwargs = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, kwargs)
200
201        return func(*args, **kwargs)
202
203
204class SigmoidToExpSubclass(torch.Tensor):
205    @classmethod
206    def __torch_function__(cls, func, types, args=(), kwargs=None):
207        if kwargs is None:
208            kwargs = {}
209
210        if func == torch.Tensor.sigmoid:
211            return super().__torch_function__(torch.Tensor.exp, types, args, kwargs)
212
213        return super().__torch_function__(func, types, args, kwargs)
214
215
216# Wrapper subclass with two inner tensors: data and scale.
217# data has the same shape as the outer tensor, and scale is one-dimensional.
218class ScaledTensor(torch.Tensor):
219    def __new__(
220        cls,
221        data: torch.Tensor,
222        scale: torch.Tensor,
223        *,
224        constant: int = 0,
225    ):
226        return torch.Tensor._make_wrapper_subclass(
227            cls,
228            data.size(),
229            strides=data.stride(),
230            storage_offset=data.storage_offset(),
231            dtype=data.dtype,
232            layout=data.layout,
233            requires_grad=data.requires_grad,
234            device=data.device,
235        )
236
237    def __init__(self, data: torch.Tensor, scale: torch.Tensor, constant: int = 0):
238        self._data = data
239        self._scale = scale
240        self._constant = constant
241
242    def __tensor_flatten__(self):
243        ctx = {"_constant": self._constant}
244        return ["_data", "_scale"], ctx
245
246    @staticmethod
247    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
248        assert len(inner_tensors) == 2
249        return ScaledTensor(
250            inner_tensors["_data"],
251            inner_tensors["_scale"],
252            constant=metadata["_constant"],
253        )
254
255    @classmethod
256    def __torch_dispatch__(cls, func, types, args, kwargs=None):
257        scaled_tensor = args[0]
258        out = func(scaled_tensor._data, *args[1:], **(kwargs or {}))
259        return ScaledTensor(out, scaled_tensor._scale, constant=scaled_tensor._constant)
260
261    def __repr__(self):
262        return f"{self._data.__repr__()}\n{self._scale.__repr__()}"
263
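# Illustrative sketch (hypothetical helper, not used by the tests): the
# flatten/unflatten protocol above exposes two inner tensors plus the
# `constant` metadata, and __torch_dispatch__ re-wraps outputs so the scale
# travels with the data.
def _example_scaled_tensor_roundtrip():
    st = ScaledTensor(torch.randn(2, 4), torch.randn(6), constant=1)
    out = st.sin()  # sin() is applied to _data; _scale and _constant carry over
    inner_names, ctx = out.__tensor_flatten__()
    # Expected: inner_names == ["_data", "_scale"] and ctx == {"_constant": 1}
    return inner_names, ctx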
264
265class OptionalScaledTensor(torch.Tensor):
266    def __new__(
267        cls,
268        data,
269        scale,
270        *,
271        constant: int = 0,
272    ):
273        return torch.Tensor._make_wrapper_subclass(
274            cls,
275            data.size(),
276            strides=data.stride(),
277            storage_offset=data.storage_offset(),
278            dtype=data.dtype,
279            layout=data.layout,
280            requires_grad=data.requires_grad,
281            device=data.device,
282        )
283
284    def __init__(self, data: torch.Tensor, scale, constant: int = 0):
285        self._data = data
286        self._scale = scale
287        self._constant = constant
288
289    def __tensor_flatten__(self):
290        ctx = {"_constant": self._constant}
291        if self._scale is not None:
292            return ["_data", "_scale"], ctx
293        else:
294            return ["_data"], ctx
295
296    @staticmethod
297    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
298        return OptionalScaledTensor(
299            inner_tensors["_data"],
300            inner_tensors["_scale"] if "_scale" in inner_tensors else None,
301            constant=metadata["_constant"],
302        )
303
304    @classmethod
305    def __torch_dispatch__(cls, func, types, args, kwargs=None):
306        scaled_tensor = args[0]
307        out = func(scaled_tensor._data, *args[1:], **(kwargs or {}))
308        if scaled_tensor._scale is not None:
309            out = out * scaled_tensor._scale
310        return OptionalScaledTensor(
311            out, scaled_tensor._scale, constant=scaled_tensor._constant
312        )
313
314    def __repr__(self):
315        return (
316            f"OptionalScaledTensor({self._data.__repr__()}\n{self._scale.__repr__()})"
317        )
318
319
320class CtxSubclassTensor(torch.Tensor):
321    """
322    Class used to verify guarding on the subclass metadata
323    """
324
325    @staticmethod
326    def __new__(cls, a, constant):
327        shape = a.shape
328        kwargs = {}
329        kwargs["strides"] = a.stride()
330        kwargs["storage_offset"] = a.storage_offset()
331        kwargs["device"] = a.device
332        kwargs["layout"] = a.layout
333        kwargs["requires_grad"] = a.requires_grad
334        kwargs["dtype"] = a.dtype
335        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
336        return out
337
338    def __init__(self, a, constant):
339        self.a = a
340        self.constant = constant
341
342    def __repr__(self):
343        a_repr = repr(self.a)
344        return f"CtxSubclassTensor({a_repr})"
345
346    def __tensor_flatten__(self):
347        return ["a"], (self.constant,)
348
349    @staticmethod
350    def __tensor_unflatten__(inner_tensors, meta, sizes, strides):
351        constant = meta[0]
352        a = inner_tensors["a"]
353        return CtxSubclassTensor(a, constant)
354
355    @classmethod
356    def __torch_dispatch__(cls, func, types, args, kwargs):
357        from torch.utils._python_dispatch import return_and_correct_aliasing
358
359        if kwargs is None:
360            kwargs = {}
361        biggest_constant = max(
362            [
363                x.constant
364                for x in pytree.tree_flatten(args)[0]
365                if isinstance(x, CtxSubclassTensor)
366            ]
367        )
368        args_a = pytree.tree_map(
369            lambda x: x.a if isinstance(x, CtxSubclassTensor) else x, args
370        )
371        kwargs_a = pytree.tree_map(
372            lambda x: x.a if isinstance(x, CtxSubclassTensor) else x, kwargs
373        )
374        out_a = func(*args_a, **kwargs_a)
375        out = pytree.tree_map(
376            lambda x: CtxSubclassTensor(x, biggest_constant)
377            if isinstance(x, torch.Tensor)
378            else x,
379            out_a,
380        )
381
382        if func == torch.ops.aten.mul.Tensor:
383            out = out + out.constant
384
385        return return_and_correct_aliasing(func, args, kwargs, out)
386
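# Illustrative sketch (hypothetical helper, not used by the tests): the tuple
# returned as context from CtxSubclassTensor.__tensor_flatten__ is subclass
# metadata, so Dynamo guards on it, and a different constant is expected to
# trigger a recompile (see test_tensor_subclass_ctx_guards below).
def _example_ctx_metadata_recompile():
    def fn(t):
        return t * t

    a = CtxSubclassTensor(torch.ones(2), 3)
    b = CtxSubclassTensor(torch.ones(2), 4)
    # Expected to be True because the metadata constants (3 vs. 4) differ.
    # _recompiles_for_inputs is defined later in this file; the reference is
    # resolved at call time.
    return _recompiles_for_inputs(fn, (a,), (b,))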
387
388def func(a):
389    return a.sin()
390
391
392class EagerRecordGraphAndInputs:
393    def __init__(self) -> None:
394        self.graphs = []
395        self.example_inputs = []
396
397    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
398        self.graphs.append(gm)
399        self.example_inputs.append(example_inputs)
400        return gm
401
402
403GLOBAL_TEST_SUBCLASSES = {
404    MockSubclass,
405    DummyNDim,
406    SigmoidToExpSubclass,
407    BaseTorchFunction,
408}
409
410
411# Returns True if the function recompiles between inputs1 and inputs2 with the
412# specified dynamic setting.
413def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True):
414    compile_count = [0]
415
416    def counter(gm, example_inputs):
417        compile_count[0] += 1
418        return gm
419
420    compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic)
421    compiled_f(*inputs1)
422    compiled_f(*inputs2)
423    return compile_count[0] > 1
424
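# Illustrative sketch (hypothetical helper, not used by the tests): with
# dynamic=True the very first compilation is already shape-dynamic, so plain
# tensors of different sizes should not trigger a second frame.
def _example_no_recompile_when_dynamic():
    def f(t):
        return t.sin()

    return _recompiles_for_inputs(
        f, (torch.randn(2, 3),), (torch.randn(4, 5),), dynamic=True
    )  # expected: False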
425
426class SubclassTests(torch._dynamo.test_case.TestCase):
427    @classmethod
428    def setUpClass(cls):
429        super().setUpClass()
430        cls._exit_stack.enter_context(
431            torch._dynamo.config.patch(
432                "traceable_tensor_subclasses", GLOBAL_TEST_SUBCLASSES
433            )
434        )
435
436    @classmethod
437    def tearDownClass(cls):
438        cls._exit_stack.close()
439
440    def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles):
441        _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles)
442
443    def test_no_call_to_new(self):
444        class BadNewTorchFunction(torch.Tensor):
445            def __new__(cls, *args, **kwargs):
446                raise RuntimeError("Oops!")
447
448            @classmethod
449            def __torch_function__(cls, func, types, args=(), kwargs=None):
450                if kwargs is None:
451                    kwargs = {}
452                return super().__torch_function__(func, types, args, kwargs)
453
454        with torch._dynamo.config.patch(
455            "traceable_tensor_subclasses", {BadNewTorchFunction}
456        ):
457
458            @torch.compile(backend="eager", fullgraph=True)
459            def fn(x):
460                return torch.add(x, 1)
461
462            input = torch.ones(2, 2).as_subclass(BadNewTorchFunction)
463
464            res = fn(input)
465            self.assertIsInstance(res, BadNewTorchFunction)
466
467    def test_no_torch_function_recompiles(self):
468        class NJT:
469            def __repr__(self):
470                return f"NJT(shape={self.shape})"
471
472            def __init__(self, values, offsets):
473                self._values = values
474                self._offsets = offsets
475
476            def sin(self):
477                return torch.sin(self)
478
479            @classmethod
480            def __torch_function__(cls, func, types, args=(), kwargs=None):
481                if kwargs is None:
482                    kwargs = {}
483                if func == torch.sin:
484                    self = args[0]
485                    return NJT(func(self._values), self._offsets)
486                raise AssertionError("should not get here")
487
488        values1 = torch.randn(10, 3, 4, requires_grad=True)
489        values2 = torch.randn(10, 3, 4, requires_grad=True)
490        offsets = torch.tensor([0, 3, 10])
491        njt1 = NJT(values1, offsets)
492        njt2 = NJT(values2, offsets)
493
494        @torch.compile(backend="eager", fullgraph=True)
495        def f(x):
496            return torch.sin(x)
497
498        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
499            f(njt1)
500            f(njt2)
501
502    def test_base_torch_function_tracing(self):
503        def fn(x):
504            return torch.add(x, 1)
505
506        input = torch.ones(2, 2).as_subclass(BaseTorchFunction)
507        out = fn(input)
508        out_opt = compile_full_eager(fn)(input)
509        self.assertIsInstance(out, BaseTorchFunction)
510        self.assertEqual(out, out_opt)
511
512    def test_torch_function_state_graph_break(self):
513        @torch.compile(backend="eager")
514        def fn(x):
515            with torch._C.DisableTorchFunctionSubclass():
516                torch._dynamo.graph_break()
517                return torch._C._is_torch_function_enabled(), torch.add(x, 1.0)
518
519        input = torch.ones(2, 2)
520        res, _ = fn(input)
521        self.assertFalse(res)
522
523    def test_torch_function_state_nested(self):
524        @torch.compile(backend="eager")
525        def fn(x):
526            with torch._C.DisableTorchFunctionSubclass():
527                with torch._C.DisableTorchFunctionSubclass():
528                    x = x + 1
529                # Should reset to the outer state (disabled) after exiting ctx manager
530                return torch._C._is_torch_function_enabled(), torch.add(x, 1.0)
531
532        input = torch.ones(2, 2)
533        res, _ = fn(input)
534        self.assertFalse(res)
535
536    def test_torch_function_state_tracing(self):
537        @torch.compile(backend="eager", fullgraph=True)
538        def fn(x):
539            with torch._C.DisableTorchFunctionSubclass():
540                torch.add(x, 1.0)
541
542        input = torch.ones(2, 2)
543
544        res = fn(input)
545
546    def test_torch_function_state_guards(self):
547        cnt = torch._dynamo.testing.CompileCounter()
548
549        @torch.compile(backend=cnt, fullgraph=True)
550        def fn(x):
551            torch.add(x, 1.0)
552
553        input = torch.ones(2, 2)
554
555        with torch._C.DisableTorchFunctionSubclass():
556            res = fn(input)
557
558        res = fn(input)
559
560        self.assertEqual(cnt.frame_count, 2)
561
562    def test_return_subclass(self):
563        @torch.compile(backend="eager", fullgraph=True)
564        def fn(x):
565            return MockSubclass(torch.add(x, 1.0))
566
567        input = torch.ones(2, 2)
568
569        res = fn(input)
570        self.assertIsInstance(res, MockSubclass)
571
572    def test_return_as_subclass(self):
573        @torch.compile(backend="eager", fullgraph=True)
574        def fn(x):
575            return torch.add(x, 1.0).as_subclass(MockSubclass)
576
577        input = torch.ones(2, 2)
578
579        res = fn(input)
580        self.assertIsInstance(res, MockSubclass)
581
582    def test_return_local_subclass(self):
583        class LocalSubclass(torch.Tensor):
584            @classmethod
585            def __torch_function__(cls, func, types, args=(), kwargs=None):
586                if kwargs is None:
587                    kwargs = {}
588                return func(*args, **kwargs)
589
590        with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
591
592            @torch.compile(backend="eager", fullgraph=True)
593            def fn(x):
594                return LocalSubclass(torch.add(x, 1.0))
595
596            input = torch.ones(2, 2)
597
598            res = fn(input)
599            self.assertIsInstance(res, LocalSubclass)
600
601    def test_torch_function_list_args(self):
602        HANDLED_FUNCTIONS = {}
603
604        class MyClass:
605            def __init__(self, foo):
606                self.foo = foo
607
608            @classmethod
609            def __torch_function__(
610                cls,
611                func,
612                types,
613                args=(),
614                kwargs=None,
615            ):
616                if kwargs is None:
617                    kwargs = {}
618                if func not in HANDLED_FUNCTIONS or not all(  # noqa: C419
619                    [  # noqa: C419
620                        issubclass(t, (torch.Tensor, MyClass)) for t in types
621                    ]
622                ):
623                    return NotImplemented
624                return HANDLED_FUNCTIONS[func](*args, **kwargs)
625
626        def _stack(input, dim=0, *, out=None):
627            return MyClass(sum([x.foo for x in input]))
628
629        HANDLED_FUNCTIONS[torch.stack] = _stack
630
631        @torch.compile(backend="eager", fullgraph=True)
632        def fn(v0, v1):
633            return torch.stack([v0, v1])
634
635        ret = fn(MyClass(1), MyClass(1))
636        self.assertEqual(ret.foo, 2)
637
638    @parametrize(
639        "comparison",
640        [
641            subtest(isinstance, "isinstance"),
642            subtest(lambda instance, type_: type(instance) == type_, "equality"),
643            subtest(lambda instance, type_: type(instance) is type_, "identity"),
644        ],
645    )
646    @parametrize(
647        "input_type",
648        [
649            subtest(torch.Tensor, "tensor"),
650            subtest(DummyNDim, "subclass"),
651        ],
652    )
653    def test_type_check(self, comparison, input_type):
654        with torch._dynamo.config.patch("traceable_tensor_subclasses", {DummyNDim}):
655
656            def fn(x):
657                if comparison(x, DummyNDim):
658                    return torch.ones(1, 1)
659                else:
660                    return torch.zeros(2, 2)
661
662            input = torch.ones(2, 2).as_subclass(input_type)
663            exp_res = fn(input)
664            act_res = torch.compile(backend="eager", fullgraph=True)(fn)(input)
665            self.assertEqual(exp_res, act_res)
666
667    def test_torch_function_call_on_method(self):
668        x = torch.ones(2, 2)
669        y = torch.ones(2, 2)
670        z = torch.ones(2, 2)
671        wrapped = x.as_subclass(SigmoidToExpSubclass)
672        wrapped2 = y.as_subclass(SigmoidToExpSubclass)
673
674        def fn(w):
675            return w.sigmoid()
676
677        fn_opt = compile_full_eager(fn)
678
679        res_exp = fn(wrapped)
680        res_act = fn_opt(wrapped2)
681        res_exp2 = z.exp()
682
683        self.assertEqual(res_exp, res_act)
684        self.assertEqual(res_exp, res_exp2)
685
686    def test_user_overidden_method_unsupported(self):
687        class LocalSubclass(torch.Tensor):
688            @classmethod
689            def __torch_function__(cls, func, types, args=(), kwargs=None):
690                if kwargs is None:
691                    kwargs = {}
692                return super().__torch_function__(func, types, args, kwargs)
693
694            def sigmoid(self):
695                return None
696
697        @torch.compile(backend="eager", fullgraph=True)
698        def fn(x):
699            x.sigmoid()
700
701        msg = (
702            "Accessing overridden method/attribute sigmoid on a tensor"
703            " subclass with a __torch_function__ override is not supported"
704        )
705        with torch._dynamo.config.patch(
706            "traceable_tensor_subclasses", {LocalSubclass}
707        ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
708            x = torch.ones(2, 2).as_subclass(LocalSubclass)
709            fn(x)
710
711    def test_user_overidden_attr_unsupported(self):
712        class LocalSubclass(torch.Tensor):
713            @classmethod
714            def __torch_function__(cls, func, types, args=(), kwargs=None):
715                if kwargs is None:
716                    kwargs = {}
717                return super().__torch_function__(func, types, args, kwargs)
718
719            ndim = 10
720
721        @torch.compile(backend="eager", fullgraph=True)
722        def fn(x):
723            return x.ndim
724
725        msg = (
726            "Accessing overridden method/attribute ndim on a tensor"
727            " subclass with a __torch_function__ override is not supported"
728        )
729        with torch._dynamo.config.patch(
730            "traceable_tensor_subclasses", {LocalSubclass}
731        ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
732            x = torch.ones(2, 2).as_subclass(LocalSubclass)
733            fn(x)
734
735    def test_user_overidden_property_unsupported(self):
736        class LocalSubclass(torch.Tensor):
737            def __init__(self) -> None:
738                self._ndim = 10
739
740            @classmethod
741            def __torch_function__(cls, func, types, args=(), kwargs=None):
742                if kwargs is None:
743                    kwargs = {}
744                return super().__torch_function__(func, types, args, kwargs)
745
746            @property
747            def ndim(self):
748                return self._ndim
749
750            @ndim.setter
751            def ndim(self, value):
752                self._ndim = value
753
754        @torch.compile(backend="eager", fullgraph=True)
755        def fn(x):
756            return x.ndim
757
758        msg = (
759            "Accessing overridden method/attribute ndim on a tensor"
760            " subclass with a __torch_function__ override is not supported"
761        )
762        with torch._dynamo.config.patch(
763            "traceable_tensor_subclasses", {LocalSubclass}
764        ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
765            x = torch.ones(2, 2).as_subclass(LocalSubclass)
766            fn(x)
767
768    def test_overridden_method_guarding(self):
769        class LocalSubclass(torch.Tensor):
770            @classmethod
771            def __torch_function__(cls, func, types, args=(), kwargs=None):
772                if kwargs is None:
773                    kwargs = {}
774                return super().__torch_function__(func, types, args, kwargs)
775
776        @torch.compile(backend="eager")
777        def fn(x):
778            return x.sigmoid()
779
780        with torch._dynamo.config.patch(
781            error_on_recompile=True, traceable_tensor_subclasses={LocalSubclass}
782        ):
783            x = torch.ones(2, 2).as_subclass(LocalSubclass)
784            fn(x)
785            fn(x)
786            x = torch.ones(2, 2).as_subclass(LocalSubclass)
787            fn(x)
788
789        with torch._dynamo.config.patch(
790            traceable_tensor_subclasses={LocalSubclass}
791        ), self.assertRaisesRegex(
792            TypeError,
793            "'bool' object is not callable",
794        ):
795            LocalSubclass.sigmoid = False
796            fn(x)
797
798    def test_torch_function_call_on_attr(self):
799        x = torch.ones(2, 2)
800        wrapped = x.as_subclass(DummyNDim)
801
802        def fn(w):
803            return w.ndim + torch.ones(2)
804
805        fn_opt = compile_full_eager(fn)
806
807        res_exp = fn(wrapped)
808        res_act = fn_opt(wrapped)
809
810        self.assertEqual(res_exp, res_act)
811        self.assertEqual(res_exp, torch.ones(2) + 10)
812
813    def test_torch_function_wrapper_class(self):
814        x = torch.ones(2, 2)
815        wrapped = WrapperSubclass(x)
816
817        def fn(w):
818            return torch.add(w, 1.0)
819
820        fn_opt = compile_full_eager(fn)
821
822        res_exp = fn(wrapped)
823        res_act = fn_opt(wrapped)
824        self.assertEqual(res_exp, res_act)
825
826    def test_torch_function_wrapper_class_with_kwargs(self):
827        x = torch.ones(2, 2)
828        wrapped = WrapperSubclass(x)
829
830        def fn(w):
831            return torch.add(w, 1.0, alpha=2.0)
832
833        fn_opt = compile_full_eager(fn)
834
835        res_exp = fn(wrapped)
836        res_act = fn_opt(wrapped)
837        self.assertEqual(res_exp, res_act)
838
839    def test_tensor_subclass_custom_attr(self):
840        class AttrSubclass(torch.Tensor):
841            x: int = 10
842
843            @classmethod
844            def __torch_function__(cls, func, types, args=(), kwargs=None):
845                if kwargs is None:
846                    kwargs = {}
847
848                return super().__torch_function__(func, types, args, kwargs)
849
850        @torch.compile(backend="eager", fullgraph=True)
851        def fn(x):
852            return x.x + torch.ones(2, 2)
853
854        with traceable_subclass(AttrSubclass):
855            input = torch.ones(2, 2).as_subclass(AttrSubclass)
856            fn_opt = compile_full_eager(fn)
857
858            res_exp = fn(input)
859            res_act = fn_opt(input)
860            self.assertEqual(res_exp, res_act)
861
862    def test_compile_with_fake_tensor_dynamic_dim(self):
863        x = torch.randn([3, 4])
864
865        def f(x):
866            return torch.sin(x)
867
868        def test_dynamic_dim(f, x, dim_dynamic, exp_frame_count, exp_op_count):
869            torch._dynamo.reset()
870            cnt = torch._dynamo.testing.CompileCounter()
871
872            opt_f = torch.compile(f, backend=cnt, fullgraph=True)
873
874            x1 = torch.rand_like(x)
875            f(x)
876            f(torch.randn([4, 3]))
877            shape_env = ShapeEnv()
878            with torch._subclasses.fake_tensor.FakeTensorMode(
879                shape_env=shape_env
880            ) as fake_mode:
881                x_fake = fake_mode.from_tensor(
882                    x,
883                    symbolic_context=StatelessSymbolicContext(
884                        dynamic_sizes=[dim_dynamic for i in range(x.dim())]
885                    ),
886                )
887                x1_fake = fake_mode.from_tensor(
888                    x1,
889                    symbolic_context=StatelessSymbolicContext(
890                        dynamic_sizes=[dim_dynamic for i in range(x.dim())]
891                    ),
892                )
893                opt_f(x_fake)
894                opt_f(x1_fake)
895
896            self.assertEqual(cnt.frame_count, exp_frame_count)
897            self.assertEqual(cnt.op_count, exp_op_count)
898
899        test_dynamic_dim(f, x, DimDynamic.DYNAMIC, 1, 1)
900        test_dynamic_dim(f, x, DimDynamic.DUCK, 1, 1)
901        test_dynamic_dim(f, x, DimDynamic.STATIC, 1, 1)
902
903    def test_compile_with_fake_tensor_automatic_dynamic(self):
904        def f(x):
905            return torch.sin(x)
906
907        def test_automatic_dynamic(f, inps, dim_dynamic, exp_frame_count, exp_op_count):
908            torch._dynamo.reset()
909            cnt = torch._dynamo.testing.CompileCounter()
910            opt_f = torch.compile(f, backend=cnt, fullgraph=True)
911
912            shape_env = ShapeEnv()
913            with torch._subclasses.fake_tensor.FakeTensorMode(
914                shape_env=shape_env
915            ) as fake_mode:
916                for inp in inps:
917                    fake_inp = fake_mode.from_tensor(
918                        inp,
919                        symbolic_context=StatelessSymbolicContext(
920                            [dim_dynamic for i in range(inp.dim())]
921                        ),
922                    )
923                    opt_f(fake_inp)
924            self.assertEqual(cnt.frame_count, exp_frame_count)
925            self.assertEqual(cnt.op_count, exp_op_count)
926
927        x = torch.randn([3, 4])
928        y = torch.randn([4, 5])
929        z = torch.randn([5, 6])
930        a = torch.randn([3, 5])
931        b = torch.randn([4, 4])
932        # When the inputs' DimDynamic is DYNAMIC or DUCK, the inputs to opt_f
933        # will be tensors with SymInt sizes. Dynamo will automatically treat them
934        # as dynamic and will only compile once.
935        for dim_dynamic in [DimDynamic.DYNAMIC, DimDynamic.DUCK]:
936            test_automatic_dynamic(f, [x, y, z], dim_dynamic, 1, 1)
937            test_automatic_dynamic(f, [x, a, z], dim_dynamic, 1, 1)
938            test_automatic_dynamic(f, [x, b, z], dim_dynamic, 1, 1)
939
940        for dim_dynamic in [DimDynamic.STATIC]:
941            # Recompiles once: dims 0 and 1 become dynamic.
942            test_automatic_dynamic(f, [x, y, z], dim_dynamic, 2, 2)
943            # Recompiles twice: first dim 1 becomes dynamic, then dim 0 becomes dynamic.
944            test_automatic_dynamic(f, [x, a, z], dim_dynamic, 3, 3)
945            # Recompiles twice: first dim 0 becomes dynamic, then dim 1 becomes dynamic.
946            test_automatic_dynamic(f, [x, b, z], dim_dynamic, 3, 3)
947
948    def test_compile_with_functionalization(self):
949        x = torch.randn([3, 4])
950        x_clone = x.clone()
951        x_clone2 = x.clone()
952        backend = EagerRecordGraphAndInputs()
953        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
954
955        @torch.compile(backend=cnt, fullgraph=True)
956        def f(x):
957            return x.add_(1.0) + torch.nn.functional.relu_(x)
958
959        f_out = f(x)
960        self.assertEqual(cnt.frame_count, 1)
961        self.assertEqual(cnt.op_count, 3)
962        self.assertEqual(len(backend.graphs), 1)
963        self.assertEqual(len(backend.example_inputs), 1)
964
965        actual = normalize_gm(backend.graphs[0].print_readable(print_output=False))
966        self.assertExpectedInline(
967            actual,
968            """\
969class GraphModule(torch.nn.Module):
970    def forward(self, L_x_: "f32[3, 4]"):
971        l_x_ = L_x_
972
973        add_: "f32[3, 4]" = l_x_.add_(1.0)
974        relu_: "f32[3, 4]" = torch.relu_(l_x_);  l_x_ = None
975        add: "f32[3, 4]" = add_ + relu_;  add_ = relu_ = None
976        return (add,)
977""",
978        )
979
980        ff = torch.func.functionalize(f)
981        ff_out = ff(x_clone)
982
983        self.assertEqual(cnt.frame_count, 2)
984        self.assertEqual(cnt.op_count, 6)
985        self.assertEqual(len(backend.graphs), 2)
986        self.assertEqual(len(backend.example_inputs), 2)
987        actual = normalize_gm(backend.graphs[1].print_readable(print_output=False))
988        self.assertExpectedInline(
989            actual,
990            """\
991class GraphModule(torch.nn.Module):
992    def forward(self, L_x_: "f32[3, 4]"):
993        l_x_ = L_x_
994
995        add_: "f32[3, 4]" = l_x_.add_(1.0)
996        relu_: "f32[3, 4]" = torch.relu_(l_x_);  l_x_ = None
997        add: "f32[3, 4]" = add_ + relu_;  add_ = relu_ = None
998        return (add,)
999""",
1000        )
1001        self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))
1002
1003        # Cannot re-use the version from AOTAutograd, since that uses python functional tensors.
1004        def to_fun(x):
1005            x_functional = torch._to_functional_tensor(x)
1006            torch._mirror_autograd_meta_to(x, x_functional)
1007            return x_functional
1008
1009        def aot_f_wrapper(func):
1010            @functools.wraps(func)
1011            def wrapper(*args, **kwargs):
1012                torch._enable_functionalization(reapply_views=False)
1013                try:
1014                    func_args = pytree.tree_map(to_fun, args)
1015                    func_kwargs = pytree.tree_map(to_fun, kwargs)
1016                    return func(*func_args, **func_kwargs)
1017                finally:
1018                    torch._disable_functionalization()
1019
1020            return wrapper
1021
1022        aot_ff = aot_f_wrapper(f)
1023        aot_ff_out = aot_ff(x_clone2)
1024
1025        self.assertEqual(cnt.frame_count, 3)
1026        self.assertEqual(cnt.op_count, 9)
1027        self.assertEqual(len(backend.graphs), 3)
1028        self.assertEqual(len(backend.example_inputs), 3)
1029        actual = normalize_gm(backend.graphs[2].print_readable(print_output=False))
1030        self.assertExpectedInline(
1031            actual,
1032            """\
1033class GraphModule(torch.nn.Module):
1034    def forward(self, L_x_: "f32[3, 4]"):
1035        l_x_ = L_x_
1036
1037        add_: "f32[3, 4]" = l_x_.add_(1.0)
1038        relu_: "f32[3, 4]" = torch.relu_(l_x_);  l_x_ = None
1039        add: "f32[3, 4]" = add_ + relu_;  add_ = relu_ = None
1040        return (add,)
1041""",
1042        )
1043        self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))
1044
1045        self.assertEqual(f_out, ff_out)
1046        self.assertEqual(f_out, aot_ff_out)
1047
1048        try:
1049            torch._enable_functionalization(reapply_views=False)
1050            xf = pytree.tree_map(to_fun, x)
1051            x_view = xf.t()
1052            with self.assertRaisesRegex(RuntimeError, "Cannot safely fakify a view"):
1053                f(x_view)
1054        finally:
1055            torch._disable_functionalization()
1056
1057    def test_compile_higher_order_with_functionalization(self):
1058        backend = EagerRecordGraphAndInputs()
1059        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
1060
1061        @torch.compile(backend=cnt, fullgraph=True)
1062        def f(x):
1063            return wrap(lambda x: x.add_(1.0), x)
1064
1065        def check_count_and_graph(
1066            exp_frame_count, exp_op_count, exp_n_graph, exp_graph
1067        ):
1068            self.assertEqual(cnt.frame_count, exp_frame_count)
1069            self.assertEqual(cnt.op_count, exp_op_count)
1070            self.assertEqual(len(backend.graphs), exp_n_graph)
1071            actual = normalize_gm(
1072                backend.graphs[exp_n_graph - 1].print_readable(print_output=False)
1073            )
1074            self.assertExpectedInline(actual, exp_graph, skip=1)
1075
1076        t = torch.randn([3, 4])
1077        t_clone = t.clone()
1078        t_clone2 = t.clone()
1079        f(t)
1080
1081        check_count_and_graph(
1082            1,
1083            2,
1084            1,
1085            """\
1086class GraphModule(torch.nn.Module):
1087    def forward(self, L_x_: "f32[3, 4]"):
1088        l_x_ = L_x_
1089
1090        wrap_body_0 = self.wrap_body_0
1091        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_);  wrap_body_0 = l_x_ = None
1092        getitem: "f32[3, 4]" = wrap[0];  wrap = None
1093        return (getitem,)
1094
1095    class wrap_body_0(torch.nn.Module):
1096        def forward(self, l_x_: "f32[3, 4]"):
1097            add_: "f32[3, 4]" = l_x_.add_(1.0);  l_x_ = None
1098            return (add_,)
1099""",
1100        )
1101
1102        ff = torch.func.functionalize(f)
1103        ff_out = ff(t_clone)
1104        # frame count and op count are incremented due to re-compilation
1105        check_count_and_graph(
1106            2,
1107            4,
1108            2,
1109            """\
1110class GraphModule(torch.nn.Module):
1111    def forward(self, L_x_: "f32[3, 4]"):
1112        l_x_ = L_x_
1113
1114        wrap_body_0 = self.wrap_body_0
1115        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_);  wrap_body_0 = l_x_ = None
1116        getitem: "f32[3, 4]" = wrap[0];  wrap = None
1117        return (getitem,)
1118
1119    class wrap_body_0(torch.nn.Module):
1120        def forward(self, l_x_: "f32[3, 4]"):
1121            add_: "f32[3, 4]" = l_x_.add_(1.0);  l_x_ = None
1122            return (add_,)
1123""",
1124        )
1125
1126        try:
1127            x = torch._to_functional_tensor(t_clone2)
1128            torch._mirror_autograd_meta_to(t_clone2, x)
1129            torch._enable_functionalization(reapply_views=False)
1130            aot_f_out = f(x)
1131        finally:
1132            torch._disable_functionalization()
1133
1134        # frame count and op count are incremented due to re-compilation
1135        check_count_and_graph(
1136            3,
1137            6,
1138            3,
1139            """\
1140class GraphModule(torch.nn.Module):
1141    def forward(self, L_x_: "f32[3, 4]"):
1142        l_x_ = L_x_
1143
1144        wrap_body_0 = self.wrap_body_0
1145        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_);  wrap_body_0 = l_x_ = None
1146        getitem: "f32[3, 4]" = wrap[0];  wrap = None
1147        return (getitem,)
1148
1149    class wrap_body_0(torch.nn.Module):
1150        def forward(self, l_x_: "f32[3, 4]"):
1151            add_: "f32[3, 4]" = l_x_.add_(1.0);  l_x_ = None
1152            return (add_,)
1153""",
1154        )
1155
1156    def test_has_torch_function(self):
1157        class MyTensor:
1158            @classmethod
1159            def __torch_function__(cls, func, types, args=(), kwargs=None):
1160                if kwargs is None:
1161                    kwargs = {}
1162
1163                if func is torch.max:
1164                    return torch.tensor(123)
1165                return func(*args, **kwargs)
1166
1167        class LocalSubclass(torch.Tensor):
1168            @classmethod
1169            def __torch_function__(cls, func, types, args=(), kwargs=None):
1170                if kwargs is None:
1171                    kwargs = {}
1172                return func(*args, **kwargs)
1173
1174        def fn(x):
1175            return torch.overrides.has_torch_function_unary(
1176                x
1177            ), torch.overrides.has_torch_function_variadic(x)
1178
1179        for test_class in [MyTensor, LocalSubclass]:
1180            x = test_class()
1181            ref0 = fn(x)
1182            ref1 = fn(4)
1183            opt_fn = torch._dynamo.optimize("eager")(fn)
1184            res0 = opt_fn(x)
1185            res1 = opt_fn(4)
1186            self.assertEqual(ref0, res0)
1187            self.assertEqual(ref1, res1)
1188
1189    def test_wrapper_subclass_guards_on_inner_tensor(self):
1190        # Holds an inner tensor that has a distinct shape from the outer wrapper tensor.
1191        # Also adds additional guards on the inner tensor's sizes.
1192        # When the first input to an op has an inner shape[0] > 3, we insert an extra add node.
1193        class DoubleSizeMaybeAddGeThreeTensor(torch.Tensor):
1194            @staticmethod
1195            def __new__(cls, inner):
1196                # Double the outer-most dimension
1197                outer_shape = (inner.shape[0] * 2,) + inner.shape[1:]
1198                return torch.Tensor._make_wrapper_subclass(
1199                    # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
1200                    # Calling the overload that has kwargs causes us to go down the first overload path,
1201                    # which will **always** specialize sizes.
1202                    # We should probably eventually fix this so that the first overload can just handle dynamic shapes.
1203                    cls,
1204                    outer_shape,
1205                    inner.stride(),
1206                    None,
1207                    None,
1208                    inner.dtype,
1209                    inner.layout,
1210                    inner.device,
1211                    False,
1212                    inner.requires_grad,
1213                )
1214
1215            def __init__(self, inner):
1216                self.inner_elem = inner
1217
1218            def __tensor_flatten__(self):
1219                return ["inner_elem"], None
1220
1221            @staticmethod
1222            def __tensor_unflatten__(inner_tensors, _, outer_size, outer_stride):
1223                return DoubleSizeMaybeAddGeThreeTensor(inner_tensors["inner_elem"])
1224
1225            def __repr__(self):
1226                return f"DoubleSizeMayberAddGeThreeTensor({repr(self.inner_elem)})"
1227
1228            @classmethod
1229            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
1230                if kwargs is None:
1231                    kwargs = {}
1232
1233                args_inner = torch.utils._pytree.tree_map_only(
1234                    DoubleSizeMaybeAddGeThreeTensor, lambda x: x.inner_elem, args
1235                )
1236                out_inner = func(*args_inner, **kwargs)
1237
1238                # Add guards on the inner tensor's sizes
1239                if args_inner[0].shape[0] > 3:
1240                    out_inner += 2
1241
1242                return DoubleSizeMaybeAddGeThreeTensor(out_inner)
1243
1244        curr_var_to_val = None
1245        curr_var_to_sources = None
1246        guards = None
1247
1248        def backend(gm, args):
1249            context = torch._guards.TracingContext.get()
1250
1251            # Grab info on sources and guards from the shapeenv
1252            nonlocal curr_var_to_val
1253            nonlocal curr_var_to_sources
1254            nonlocal guards
1255
1256            guards = [str(g.expr) for g in context.fake_mode.shape_env.guards]
1257            curr_var_to_val = {
1258                str(k): v for k, v in context.fake_mode.shape_env.var_to_val.items()
1259            }
1260            curr_var_to_sources = {
1261                str(k): v[0].name()
1262                for k, v in context.fake_mode.shape_env.var_to_sources.items()
1263            }
1264            return gm
1265
1266        @torch.compile(backend=backend)
1267        def fn(x):
1268            if x.shape[0] < 13:
1269                return torch.mul(x, x)
1270            else:
1271                return torch.div(x, x)
1272
1273        inp = torch.ones(4, 4)
1274
1275        x = DoubleSizeMaybeAddGeThreeTensor(inp)
1276        torch._dynamo.mark_dynamic(x, 0)
1277        res = fn(x)
1278        # During fakeifying, we end up allocating a separate symint
1279        # for the outer and inner tensor (in this test, s0 is unused).
1280        expected_var_to_val = {
1281            "s0": 8,
1282            "s1": 4,
1283        }
1284        expected_var_to_sources = {
1285            "s0": "L['x'].size()[0]",
1286            "s1": "L['x'].inner_elem.size()[0]",
1287        }
1288        self.assertEqual(curr_var_to_val, expected_var_to_val)
1289        self.assertEqual(curr_var_to_sources, expected_var_to_sources)
1290        self.assertExpectedInline(
1291            "\n".join(guards),
1292            """\
1293Eq(2*s1, s0)
12942*s1 < 13
1295s1 > 3""",
1296        )
1297
1298    def test_wrapper_subclass_with_same_sized_inner_tensor(self):
1299        # shouldn't recompile for different sizes when dynamic=True
1300        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
1301        sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(7))
1302        self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=True))
1303
1304        # should recompile for different data size when dynamic=False
1305        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
1306        sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6))
1307        self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
1308
1309        # avoid recompile using manual mark_dynamic() for different data size
1310        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
1311        # NB: mark_dynamic() on outer tensor should translate to inner tensors of the same size
1312        torch._dynamo.mark_dynamic(sub1, 0)
1313        torch._dynamo.mark_dynamic(sub1, 1)
1314        sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6))
1315        self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
1316
1317    def test_wrapper_subclass_with_differently_sized_inner_tensor(self):
1318        # should recompile for different scale size when dynamic=False
1319        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3))
1320        sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5))
1321        self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
1322
1323        # still recompiles using manual mark_dynamic() on outer for different scale size
1324        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3))
1325        # NB: mark_dynamic() on outer tensor doesn't translate to inner tensors of different size
1326        torch._dynamo.mark_dynamic(sub1, 0)
1327        torch._dynamo.mark_dynamic(sub1, 1)
1328        sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5))
1329        self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
1330
1331    def test_recompiles_with_optional_inner_tensor(self):
1332        def f(x):
1333            return x + 1
1334
1335        # sub1 does not have the optional tensor specified while sub2 does
1336        sub1 = OptionalScaledTensor(torch.randn(2, 4), None)
1337        sub2 = OptionalScaledTensor(torch.randn(2, 4), torch.randn(2, 4))
1338
1339        # sanity check; don't recompile for same input
1340        self.assertFalse(_recompiles_for_inputs(f, (sub1,), (sub1,), dynamic=True))
1341        self.assertFalse(_recompiles_for_inputs(f, (sub2,), (sub2,), dynamic=True))
1342
1343        # these should recompile; optional tensor changes between specified and unspecified
1344        self.assertTrue(_recompiles_for_inputs(f, (sub1,), (sub2,), dynamic=True))
1345        self.assertTrue(_recompiles_for_inputs(f, (sub2,), (sub1,), dynamic=True))
1346
1347        f_compiled = torch.compile(f, backend="aot_eager")
1348        self.assertEqual(f(sub1)._data, f_compiled(sub1)._data)
1349        self.assertEqual(f(sub2)._data, f_compiled(sub2)._data)
1350
1351    def test_torch_dispatch_subclass_guard_recompile(self):
1352        x = torch.ones(2, 2)
1353        x_two = TwoTensor(x.clone(), x.clone())
1354
1355        def fn(w):
1356            return torch.add(w, 1.0)
1357
1358        fn_opt = torch.compile(backend="eager")(fn)
1359
1360        ref = fn(x_two)
1361        res = fn_opt(x_two)
1362        self.assertEqual(ref, res)
1363
1364        # ensure no recompilation on same input type
1365        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
1366            fn_opt(TwoTensor(x + 1, x + 2))
1367
1368        # recompile!
1369        ref = fn(x)
1370        res = fn_opt(x)
1371        self.assertEqual(ref, res)
1372
1373    def test_tensor_subclass_ctx_guards(self):
1374        x = CtxSubclassTensor(torch.ones(2), 3)
1375        x2 = CtxSubclassTensor(torch.ones(2), 3)
1376        x3 = CtxSubclassTensor(torch.ones(2), 4)
1377        _check_recompiles(self, lambda x: x * x, (x,), (x2,), False)
1378        _check_recompiles(self, lambda x: x * x, (x,), (x3,), True)
1379
1380    def test_tensor_subclass_ctx_recursive_guards(self):
1381        x0 = torch.ones(2, 2)
1382        x1 = CtxSubclassTensor(x0.clone(), 2)
1383        x2 = CtxSubclassTensor(x0.clone(), 3)
1384        tt0 = TwoTensor(x0.clone(), x1)
1385        tt1 = TwoTensor(x0.clone(), x2)
1386
1387        _check_recompiles(self, lambda x: x * x, (tt0,), (tt1,), True)
1388
1389    def test_tensor_subclass_ctx_custom_guards_override(self):
1390        class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor):
1391            @classmethod
1392            def __metadata_guard__(cls, orig_data, other):
1393                return orig_data[0] <= other[0]
1394
1395        x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 2)
1396        x2 = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3)
1397        x3 = CtxSubclassTensorCustomGuardFn(torch.ones(2), 1)
1398        _check_recompiles(self, lambda x: x * x, (x,), (x2,), False)
1399        _check_recompiles(self, lambda x: x * x, (x,), (x3,), True)
1400
1401    def test_tensor_subclass_ctx_custom_guards_error_arg_num(self):
1402        import torch._dynamo.exc
1403
1404        class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor):
1405            @classmethod
1406            def __metadata_guard__(cls, y):
1407                # Shouldn't reach here
1408                return False
1409
1410        x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3)
1411        self.assertRaisesRegex(
1412            torch._dynamo.exc.InternalTorchDynamoError,
1413            "Tensor subclass method __metadata_guard__ must take exactly two subclass metadata arguments",
1414            lambda: torch.compile(lambda x: x * x)(x),
1415        )
1416
1417    def test_tensor_subclass_ctx_custom_guards_error_not_classmethod(self):
1418        import torch._dynamo.exc
1419
1420        class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor):
1421            def __metadata_guard__(self, x, y):
1422                return False
1423
1424        x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3)
1425        self.assertRaisesRegex(
1426            torch._dynamo.exc.InternalTorchDynamoError,
1427            "Tensor subclass method __metadata_guard__ must be a classmethod",
1428            lambda: torch.compile(lambda x: x * x)(x),
1429        )
1430
1431    def test_subclass_constructor_proxying(self):
1432        import dataclasses
1433        from collections import namedtuple
1434        from typing import Any
1435
1436        @dataclasses.dataclass(frozen=True)
1437        class SubclassTensorArgs:
1438            original_shape: torch.Size
1439            device: torch.device
1440            inner_meta: Any
1441
1442        SubclassTensorArgs2 = namedtuple(
1443            "SubclassTensorArgs2",
1444            [
1445                "original_shape",
1446                "device",
1447                "inner_meta",
1448            ],
1449        )
1450
1451        class SubclassTensor(torch.Tensor):
1452            @staticmethod
1453            def __new__(cls, a, meta):
1454                shape = a.shape
1455                kwargs = {}
1456                kwargs["strides"] = a.stride()
1457                kwargs["storage_offset"] = a.storage_offset()
1458                kwargs["device"] = a.device
1459                kwargs["layout"] = a.layout
1460                kwargs["requires_grad"] = a.requires_grad
1461                kwargs["dtype"] = a.dtype
1462                out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
1463                return out
1464
1465            def __init__(self, a, meta):
1466                self.a = a
1467                self.meta = meta
1468
1469            def __repr__(self):
1470                a_repr = repr(self.a)
1471                return f"SubclassTensor({a_repr})"
1472
1473            def __tensor_flatten__(self):
1474                return ["a"], self.meta
1475
1476            @staticmethod
1477            def __tensor_unflatten__(inner_tensors, meta, _, __):
1478                a = inner_tensors["a"]
1479                return SubclassTensor(a, meta)
1480
1481            @classmethod
1482            def __torch_dispatch__(cls, func, types, args, kwargs):
1483                if kwargs is None:
1484                    kwargs = {}
1485                args_a = pytree.tree_map(
1486                    lambda x: x.a if isinstance(x, SubclassTensor) else x, args
1487                )
1488                kwargs_a = pytree.tree_map(
1489                    lambda x: x.a if isinstance(x, SubclassTensor) else x, kwargs
1490                )
1491                out_a = func(*args_a, **kwargs_a)
1492                out = pytree.tree_map(
1493                    lambda x: SubclassTensor(
1494                        x, SubclassTensorArgs2(x.shape, x.device, None)
1495                    )
1496                    if isinstance(x, torch.Tensor)
1497                    else x,
1498                    out_a,
1499                )
1500                return return_and_correct_aliasing(func, args, kwargs, out)
1501
1502        @torch.compile(fullgraph=True)
1503        def f1(x):
1504            meta = SubclassTensorArgs(
1505                x.shape, x.device, SubclassTensorArgs(x.shape, x.device, None)
1506            )
1507            out = SubclassTensor(x, meta)
1508            return out * out
1509
1510        x = torch.randn(3, 3)
1511        f1(x)
1512
1513        @torch.compile(fullgraph=True)
1514        def f1(x):
1515            meta = SubclassTensorArgs2(
1516                x.shape, x.device, SubclassTensorArgs2(x.shape, x.device, None)
1517            )
1518            out = SubclassTensor(x, meta)
1519            return out * out
1520
1521        x = torch.randn(3, 3)
1522        f2(x)
1523
1524    def test_torch_function_subclass_survives_into_aot_autograd(self):
1525        # If you have a tensor subclass that relies on dispatch into the same op
1526        # without unwrapping and calling torch._C.DisableTorchFunctionSubclass(),
1527        # the torch function-ness will survive into AOTAutograd. Today, NestedTensor
1528        # actually relies on this behavior! Because that torch function logic
1529        # runs during AOTAutograd, this test checks that there is no logic below
1530        # that relies on torch function behavior that gets unexpectedly disabled after we
1531        # redispatch from the subclass's torch function.
1532        class SubTensor(torch.Tensor):
1533            @staticmethod
1534            def __new__(cls, t):
1535                return torch.Tensor._make_wrapper_subclass(
1536                    cls,
1537                    t.shape,
1538                    t.stride(),
1539                    t.storage_offset(),
1540                    torch.contiguous_format,
1541                    t.dtype,
1542                    torch.strided,
1543                    t.device,
1544                    False,
1545                    t.requires_grad,
1546                    "sizes",
1547                    False,
1548                    False,
1549                    None,
1550                )
1551
1552            def __init__(self, t):
1553                super().__init__()
1554                self._t = t
1555
1556            def __tensor_flatten__(self):
1557                return ["_t"], {}
1558
1559            @staticmethod
1560            def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
1561                t = inner_tensors["_t"]
1562                return SubTensor(t)
1563
1564            def __repr__(self):
1565                return f"SubTensor({self._t})"
1566
1567            @classmethod
1568            def __torch_function__(cls, func, types, args=(), kwargs=None):
1569                if kwargs is None:
1570                    kwargs = {}
1571
1572                with torch._C.DisableTorchFunctionSubclass():
1573                    return func(*args, **kwargs)
1574
1575            @classmethod
1576            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
1577                kwargs = {} if kwargs is None else kwargs
1578                new_args = pytree.tree_map_only(SubTensor, lambda s: s._t, args)
1579                output = func(*new_args, **kwargs)
1580                output = pytree.tree_map_only(
1581                    torch.Tensor, lambda t: SubTensor(t), output
1582                )
1583                return output
1584
1585        @torch.compile(dynamic=True)
1586        def f(x):
1587            return x.unflatten(-1, [2, 5])
1588
1589        s = SubTensor(torch.randn(3, 10))
1590        f(s)
1591
1592    # Guard validation upsets the guard
1593    # https://github.com/pytorch/pytorch/issues/129936
1594    @unittest.expectedFailure
1595    def test_recompile_with_symbool_inputs(self):
1596        def f(pred: bool):
1597            if pred:
1598                return torch.ones([3, 4])
1599            else:
1600                return torch.ones([4, 3])
1601
1602        def test_recompilation(
1603            f, x, sizes, exp_graphs, exp_frame_count, exp_shape_env_guards
1604        ):
1605            torch._dynamo.reset()
1606            shape_env = ShapeEnv()
1607            backend = torch._dynamo.testing.EagerAndRecordGraphs()
1608            cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
1609            f_cond = torch.compile(f, backend=cnt, fullgraph=True)
1610            with torch._subclasses.fake_tensor.FakeTensorMode(
1611                shape_env=shape_env
1612            ) as fake_mode:
1613                fake_inp = fake_mode.from_tensor(
1614                    x,
1615                    symbolic_context=StatelessSymbolicContext(
1616                        dynamic_sizes=[DimDynamic.DYNAMIC for i in range(x.dim())]
1617                    ),
1618                )
1619                for i, size in enumerate(sizes):
1620                    pred = fake_inp.size(0) == size
1621                    f_cond(pred)
1622                    actual = normalize_gm(
1623                        backend.graphs[exp_frame_count[i] - 1].print_readable(
1624                            print_output=False
1625                        )
1626                    )
1627                    actual_guard_str = [str(guard.expr) for guard in shape_env.guards]
1628                    self.assertExpectedInline(actual, exp_graphs[i])
1629                    self.assertEqual(cnt.frame_count, exp_frame_count[i])
1630                    self.assertEqual(actual_guard_str, exp_shape_env_guards[i])
1631
1632        true_graph = """\
1633class GraphModule(torch.nn.Module):
1634    def forward(self):
1635        ones: "f32[3, 4]" = torch.ones([3, 4])
1636        return (ones,)
1637"""
1638        false_graph = """\
1639class GraphModule(torch.nn.Module):
1640    def forward(self):
1641        ones: "f32[4, 3]" = torch.ones([4, 3])
1642        return (ones,)
1643"""
1644        test_recompilation(
1645            f,
1646            torch.randn([3, 4]),
1647            [3, 3, 4, 5],
1648            exp_graphs=[true_graph, true_graph, false_graph, false_graph],
1649            exp_frame_count=[1, 1, 2, 2],
1650            exp_shape_env_guards=[
1651                [],
1652                # s0 is specialized and guarded in outer shape_env when dynamo checks the guards
1653                ["Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)"],
1654                [
1655                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
1656                    "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
1657                ],
1658                [
1659                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
1660                    "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
1661                    "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
1662                ],
1663            ],
1664        )
1665
1666        test_recompilation(
1667            f,
1668            torch.randn([3, 4]),
1669            [4, 5, 3, 3],
1670            exp_graphs=[false_graph, false_graph, true_graph, true_graph],
1671            exp_frame_count=[1, 1, 2, 2],
1672            exp_shape_env_guards=[
1673                [],
1674                # s0 is specialized and guarded in outer shape_env when dynamo checks the guards
1675                ["Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)"],
1676                [
1677                    "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
1678                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
1679                ],
1680                [
1681                    "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
1682                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
1683                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
1684                ],
1685            ],
1686        )
1687
1688    def test_wrapper_subclass_dynamo_attribute_access_on_intermediate(self):
1689        def f(x_subclass):
1690            tmp_subclass = torch.add(x_subclass, 1)
1691            return torch.mul(tmp_subclass._scale, tmp_subclass._constant)
1692
1693        x = ScaledTensor(torch.randn(2, 4), torch.randn(3), constant=2)
1694        out_ref = f(x)
1695        out_test = torch.compile(f, backend="aot_eager", fullgraph=True)(x)
1696        self.assertEqual(out_ref, out_test)
1697
1698    def test_support_bases(self):
1699        import abc
1700
1701        import torch.fx._symbolic_trace
1702
1703        class Meta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta):
1704            def __new__(cls, name, bases, dct):
1705                x = super().__new__(cls, name, bases, dct)
1706                x.attr = 100
1707                return x
1708
1709        class Multistreamable(abc.ABC):  # noqa: B024
1710            pass
1711
1712        class Foo(Multistreamable, metaclass=Meta):
1713            pass
1714
1715        @torch.compile(backend="eager", fullgraph=True)
1716        def f(x):
1717            typ = type(Foo())
1718            typ.__bases__
1719            return typ.__bases__
1720
1721        self.assertEqual(f(torch.randn(1)), (Multistreamable,))
1722
1723        @torch.compile(backend="eager", fullgraph=True)
1724        def g(x):
1725            typ = type(Foo())
1726            typ.__base__
1727            return typ.__base__
1728
1729        self.assertEqual(g(torch.randn(1)), Multistreamable)
1730
1731    @parametrize("dynamic", [False, True])
1732    def test_subclass_views(self, dynamic):
1733        def _get_views(t):  # yields (view: Tensor, expects_raises: bool)
1734            # Note that any closed-over SymInts will be symbolicized during fake-ification.
1735            yield t.narrow(dim=-1, start=3, length=8), False
1736            yield t.split(5, -1)[2], False
1737            yield t.split_with_sizes([9, 6], -1)[1], False
1738            yield t.unsqueeze(-1).expand(4, 15, 10), False
1739            yield t.select(-1, 6), False
1740            # https://github.com/pytorch/pytorch/issues/128649
1741            yield t[2:3, 5:9], dynamic
1742            yield t.view(-1, 15), False
1743
1744        def f(x):
1745            return x * 2
1746
1747        compiled_f = torch.compile(
1748            f, backend="aot_eager", fullgraph=True, dynamic=dynamic
1749        )
1750
1751        # Take a view of a subclass to pass as input.
1752        t = TwoTensor(torch.randn(4, 15), torch.randn(4, 15))
1753        for view, expects_raises in _get_views(t):
1754            torch._dynamo.reset()
1755            out_ref = f(view)
1756            if expects_raises:
1757                with self.assertRaises(AssertionError):
1758                    out_test = compiled_f(view)
1759            else:
1760                out_test = compiled_f(view)
1761                self.assertEqual(out_ref, out_test)
1762
1763    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
1764    def test_mark_static_with_subclass_desugaring(self):
1765        from typing import Any, Callable, Dict, List, Optional
1766
1767        from torch._dynamo.decorators import mark_static_address
1768        from torch._inductor.compile_fx import compile_fx
1769        from torch._inductor.cudagraph_utils import BoxedDeviceIndex
1770        from torch._inductor.utils import BoxedBool
1771
1772        x_inner = torch.ones(4)
1773        x = TwoTensor(x_inner, x_inner)
1774        mark_static_address(x, guard=False)
1775
1776        def inner_compile(
1777            gm: torch.fx.GraphModule,
1778            example_inputs: List[torch.Tensor],
1779            cudagraphs: Optional[BoxedBool] = None,
1780            static_input_idxs: Optional[List[int]] = None,
1781            is_backward: bool = False,
1782            graph_id: Optional[int] = None,
1783            cpp_wrapper: bool = False,
1784            aot_mode: bool = False,
1785            is_inference: bool = False,
1786            boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
1787            user_visible_outputs: Optional[Dict[str, None]] = None,
1788            layout_opt: Optional[bool] = None,
1789            extern_node_serializer: Optional[Callable[[List[Any]], Any]] = None,
1790        ):
1791            self.assertEqual(static_input_idxs, [1, 2])
1792            return gm
1793
1794        compiler = functools.partial(compile_fx, inner_compile=inner_compile)
1795
1796        @torch.compile(backend=compiler)
1797        def fn(t0, t1, t2):
1798            return t0 + t1 + t2 + 2
1799
1800        fn(torch.ones(4), x, torch.ones(4))
1801
1802
1803instantiate_parametrized_tests(SubclassTests)
1804
1805
1806class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
1807    def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True):
1808        return get_jagged_tensor(nested_size, offsets, requires_grad)
1809
1810    def _get_nc_jagged_tensor(self, inner_dim, starts, lengths, requires_grad=True):
1811        # Makes a non-contiguous jagged tensor with N constituent tensors whose
1812        # ragged extents are given by starts/lengths, each with inner dim `inner_dim`
1813        max_dim = (starts + lengths).max()
1814        values_tensor = torch.randn(
1815            starts.shape[0],
1816            max_dim.item(),
1817            inner_dim,
1818            requires_grad=requires_grad,
1819            dtype=torch.float64,
1820        )
1821        return jagged_from_tensor_and_lengths(values_tensor, starts, lengths)
1822
1823    def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles):
1824        _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles)
1825
1826    def test_unary_does_not_recompile(self):
1827        nt1, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
1828        nt2, _ = self._get_jagged_tensor(((3, 4, 5, 6), 4), None)
1829        self._check_recompiles(lambda nt1: nt1.sin(), (nt1,), (nt2,), False)
1830
1831    def test_binary_does_not_recompile(self):
1832        def binary(nt1, nt2):
1833            if nt1.shape == nt2.shape:
1834                return nt1 + nt2
1835            else:
1836                return nt1.sin()
1837
1838        # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0).
1839        # This causes a recompile later on when it realizes the batch and last dim
1840        # should not always be equal. To avoid that, we use (3, j0, 5) here (see the sketch below).
1841        nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None)
1842        nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets)
1843        nt3, offsets = self._get_jagged_tensor(((3, 4, 5), 4), None)
1844        nt4, _ = self._get_jagged_tensor(((3, 4, 5), 4), offsets)
1845        self._check_recompiles(binary, (nt1, nt2), (nt3, nt4), False)
1846
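    # A minimal sketch of the duck sizing behavior mentioned in the NB above
    # (hypothetical helper, not collected as a test): under duck shaping, dims
    # that share the same concrete size are assigned the same symbol, which is
    # why a (3, j0, 3) shape would fakify to (s0, s1, s0).
    def _duck_sizing_sketch(self):
        shape_env = ShapeEnv()
        with torch._subclasses.fake_tensor.FakeTensorMode(
            shape_env=shape_env
        ) as fake_mode:
            fake = fake_mode.from_tensor(
                torch.randn(3, 7, 3),
                symbolic_context=StatelessSymbolicContext(
                    dynamic_sizes=[DimDynamic.DUCK for _ in range(3)]
                ),
            )
        # The two size-3 dims are expected to share a symbol here, i.e. the
        # fake shape should look like (s0, s1, s0).
        return fake.shape
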
1847    def test_binary_recompiles(self):
1848        def binary(nt1, nt2):
1849            if nt1.shape == nt2.shape:
1850                return nt1 + nt2
1851            else:
1852                return nt1.sin()
1853
1854        # Binary recompiles because singleton ints no longer match
1855        nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None)
1856        nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets)
1857        nt3, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
1858        self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True)
1859
1860    def _validate_compile(self, fn, arg_fn):
1861        def _gen_grad_outputs(out_val):
1862            if isinstance(out_val, (list, tuple)):
1863                return tuple(torch.ones_like(c) for c in out_val)
1864            else:
1865                return (torch.ones_like(out_val),)
1866
1867        with self.branch_nested_state():
1868            from torch.nested._internal.nested_tensor import _tensor_symint_registry
1869
1870            # Validate that compilation does not modify eager state
1871            registry_before = list(_tensor_symint_registry.items())
1872            count_before = torch.nested._internal.nested_tensor._tensor_id_counter
1873
1874            guards_exported = []
1875            guards_failed = []
1876
1877            def append_guard_export(guards):
1878                for g in guards:
1879                    if g.code_list is not None:
1880                        guards_exported.append(g.code_list[0])
1881
1882            def append_guard_fail(guards):
1883                guards_failed.extend(guards)
1884
1885            compiled = torch._dynamo.optimize(
1886                nopython=True,
1887                backend="aot_eager",
1888                guard_export_fn=append_guard_export,
1889                guard_fail_fn=append_guard_fail,
1890            )(fn)
1891            registry_after = list(_tensor_symint_registry.items())
1892            count_after = torch.nested._internal.nested_tensor._tensor_id_counter
1893            self.assertEqual(registry_before, registry_after)
1894            self.assertEqual(count_before, count_after)
1895
1896            args = arg_fn()
1897            compile_out = compiled(*args)
1898            compile_grads = []
1899            g_args = [arg for arg in args if arg.requires_grad]
1900            if len(g_args) > 0:
1901                compile_grad_outputs = _gen_grad_outputs(compile_out)
1902                compile_grads = torch.autograd.grad(
1903                    compile_out, inputs=g_args, grad_outputs=compile_grad_outputs
1904                )
1905
1906        with self.branch_nested_state():
1907            args = arg_fn()
1908            ref_out = fn(*args)
1909            ref_grads = []
1910            g_args = [arg for arg in args if arg.requires_grad]
1911            if len(g_args) > 0:
1912                ref_grad_outputs = _gen_grad_outputs(ref_out)
1913                ref_grads = torch.autograd.grad(
1914                    ref_out, inputs=g_args, grad_outputs=ref_grad_outputs
1915                )
1916
1917        # Validate correctness forward
1918        if isinstance(compile_out, (list, tuple)):
1919            # TODO: Fix assertEqual() to support NJTs so this isn't necessary
1920            self.assertEqual(len(compile_out), len(ref_out))
1921            for c, r in zip(compile_out, ref_out):
1922                self.assertEqualIgnoringNestedInts(c, r)
1923        else:
1924            self.assertEqualIgnoringNestedInts(compile_out, ref_out)
1925
1926        # Validate correctness backward
1927        for compile_grad, ref_grad in zip(compile_grads, ref_grads):
1928            self.assertEqualIgnoringNestedInts(compile_grad, ref_grad)
1929
1930        return guards_exported, guards_failed
1931
1932    # Note: [What kind of guards are involved in nested tensor compilation]
1933    #
1934    # Until we implement UnionFind, dynamic shapes guards are not involved.
1935    # We rely only on dynamo's tensor aliasing guards.
1936    #
1937    # This is possible because dynamo is able to generate tensor aliasing guards
1938    # not only for the outer tensor, but also for the inner tensor.
1939    #
1940    # The case where dynamic shapes guards would eventually come into play is
1941    # when the inputs are (1) two non-aliased tensors, but (2) declared as
1942    # equal using a "trust me assert equal" API.
1943
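    # A minimal sketch of the note above (hypothetical helper, not collected as
    # a test): compile a function over two NJTs that share an offsets tensor
    # and record the guards dynamo exports; the tensor aliasing guards over the
    # inner offsets are what we rely on instead of dynamic shapes guards.
    def _aliasing_guards_sketch(self):
        exported = []

        def record_guards(guards):
            for g in guards:
                if g.code_list is not None:
                    exported.extend(g.code_list)

        def fn(nt1, nt2):
            return nt1 + nt2

        values = torch.randn(10, 5)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
        nt1 = torch.nested.nested_tensor_from_jagged(values, offsets)
        nt2 = torch.nested.nested_tensor_from_jagged(values * 2, offsets)
        compiled = torch._dynamo.optimize(
            nopython=True, backend="aot_eager", guard_export_fn=record_guards
        )(fn)
        compiled(nt1, nt2)
        # `exported` is expected to contain aliasing checks on the inner
        # offsets tensors rather than dynamic shapes guards.
        return exported
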
1944    # Note: [Compiling nested tensor global state]
1945    #
1946    # Today there are two pieces of global eager state that NJTs deal with:
1947    # - tensor_id_counter: a global counter that assigns unique ids to tensors
1948    # - tensor_symint_registry: maps tensor to nested int
1949    #   - this is used in eager only (we should get rid of this because it is
1950    #     not necessary to cache nested int in eager)
1951    #   - during tracing, we DO need to cache nested int, but we do so on
1952    #     the FakeTensor.
1953    #
1954    # Ideally we would like to satisfy the following:
1955    # - (1) The eager state is not mutated during tracing
1956    # - (2) Running the compiled function should mutate the eager state in the
1957    #       same way that running the eager function would
1958    #       (a) The global counter should be incremented
1959    #       (b) The registry is updated in the same way
1960    #
1961    # Today we can satisfy (1) and (2a) but cannot satisfy (2b)
1962    #
1963    # Today, (1) is satisfied because we maintain a separate counter during
1964    # tracing, and cache nested int on FakeTensor instead of relying on
1965    # tensor_symint_registry.
1966    #
1967    # (2) cannot be completely satisfied because we trace away the
1968    # side-effectful operations (we could fix this by wrapping the
1969    # side-effectful operations in a custom op and threading through effect
1970    # tokens). The current plan is to do that in the UnionFind impl.
1971    #
1972    # Interestingly, despite this, the state is mutated in a way that is somewhat
1973    # close to what we want, e.g. if we construct a nested tensor using an
1974    # offsets tensor in the compiled region and return it, the AOTAutograd runtime
1975    # wrapper must rewrap the inner graph outputs back into the subclass. This
1976    # triggers the eager logic to run, updating the counter and registry.
1977    #
1978    # Notably however, compile differs in two ways from eager:
1979    # (1) The order in which the offsets are assigned ids is different:
1980    #     the registry is set in the order the offsets are returned,
1981    #     which is not necessarily the same order in which they were constructed.
1982    # (2) If a NestedTensor is not returned, then the AOTAutograd wrapping
1983    #     logic will not be triggered.
1984    #
1985    # I claim that correctness is not affected by these differences today.
1986    # e.g. there is never the case where two distinct offsets silently share
1987    # the same id.
1988    #
1989    # (1) is clearly not a problem, and (2) should only be a problem if
1990    # the nested int is returned on its own, without the corresponding NJT
1991    # being returned. This is not a problem in the current implementation
1992    # because returning only a shape is not supported!
1993
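    # A minimal sketch of the eager state described above (hypothetical helper,
    # not collected as a test): constructing an NJT in eager is expected to
    # register its offsets in tensor_symint_registry and bump the global
    # tensor_id_counter; this is exactly the state that tracing must not mutate.
    def _global_state_sketch(self):
        from torch.nested._internal.nested_tensor import _tensor_symint_registry

        count_before = torch.nested._internal.nested_tensor._tensor_id_counter
        values = torch.randn(10, 5)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
        # Construction (not any later op) is what touches the global state.
        torch.nested.nested_tensor_from_jagged(values, offsets)
        count_after = torch.nested._internal.nested_tensor._tensor_id_counter
        # The offsets tensor should now be tracked in the registry, and the
        # counter should have been incremented by the construction above.
        return offsets in _tensor_symint_registry, count_after - count_before
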
1994    # Note: [Creating symbolic nested int]
1995    #
1996    # We must create a symbolic nested int when we construct a nested tensor
1997    # from a tensor. There are two main cases:
1998    #
1999    # 1. The offsets tensor has NOT been used to construct an NJT
2000    #    - Create a new plain nested int with current val of fake nt id counter
2001    #    - Increment the fake nt id counter
2002    #    - Create a new symint with plain nested int as hint
2003    # 2. The offsets tensor HAS been used to construct an NJT
2004    #    - Create a new symint with plain nested int as hint
2005    #
2006    # More details on case 2:
2007    # - During fakification of the offsets, we check the eager registry, and
2008    #   if the tensor HAS been used to construct an NJT,
2009    #   we create a symint, with the existing nested int as hint, and cache
2010    #   it on to the FakeTensor.
2011    #
2012    # [ Always use ephemeral source ]
2013    #
2014    # We ALWAYS create the new symint with an ephemeral source, whether we are
2015    # in case (1) or (2), even though we could've had a proper source for case (2).
2016    # Using a proper source would enable a few more (edge) cases, but since
2017    # we plan to handle things more holistically in the future anyway, we don't
2018    # bother doing so today.
2019    #
2020    # Using an ephemeral source has some consequences. But we are happy if
2021    # - We do not silently miss recompiles, e.g. we guard when necessary.
2022    #   We know that this is true, because dynamo guards alone are already
2023    #   sufficient.
2024    # - We are not producing errors for the cases we care about
2025    #
2026    # The main case we care about is when we guard that two shapes are equal.
2027    # In this case, the replacements logic would simplify away the ephemeral
2028    # symbol, and there is no error produced.
2029    # The unsupported case is when we guard that two shapes are not equal, in
2030    # which case we will try and fail to generate a guard.
2031
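    # A minimal sketch of the "always ephemeral source" tradeoff above
    # (hypothetical helper, not collected as a test): jagged dims built from an
    # intermediate offsets tensor carry an ephemeral symbol. Ops that require
    # the jagged dims to be EQUAL still compile, since any equality guard is
    # simplified away; guarding that two such dims are NOT equal is the case
    # exercised by the expectedFailure tests below.
    def _ephemeral_source_sketch(self):
        def fn(values, lengths):
            offsets = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)])
            nt1 = torch.nested.nested_tensor_from_jagged(values, offsets)
            nt2 = torch.nested.nested_tensor_from_jagged(values * 2, offsets)
            # Both jagged dims come from the same intermediate offsets, so the
            # shape check required by the multiply below is satisfiable.
            return nt1 * nt2

        values = torch.randn(9, 5)
        lengths = torch.tensor([2, 4, 3])
        return torch.compile(fn, backend="aot_eager", fullgraph=True)(values, lengths)
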
2032    #
2033    # Case 1: in-graph construction where the offsets are passed as inputs
2034    #
2035    def test_in_graph_construction_from_input(self):
2036        # The offsets tensor is passed as an input
2037        def fn(values, offsets):
2038            return torch.nested.nested_tensor_from_jagged(values * 2, offsets) * 2
2039
2040        values = torch.randn(10, 5, requires_grad=True)
2041        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
2042        self._validate_compile(fn, arg_fn=lambda: (values, offsets))
2043
2044        # Do not specialize on the offsets
2045        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
2046            different_offsets = torch.tensor([0, 1, 5, 10], dtype=torch.int64)
2047            self._validate_compile(fn, arg_fn=lambda: (values, different_offsets))
2048
2049    def test_in_graph_construction_from_input_2(self):
2050        # Construct two NJTs, both are passed as inputs
2051        def fn(values, offsets1, offsets2):
2052            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets1)
2053            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2)
2054            return nt2, nt1
2055
2056        values = torch.randn(10, 5, requires_grad=True)
2057        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
2058        offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64)
2059        # 1. Offsets are different
2060        guards_exported, guards_failed = self._validate_compile(
2061            fn, arg_fn=lambda: (values, offsets, offsets2)
2062        )
2063        self.assertEqual(len(guards_failed), 0)
2064        self.assertNotIn("L['offsets1'] is L['offsets2']", guards_exported)
2065
2066        # TODO
2067        # 2. Offsets are the same
2068        new_guards_exported, _ = self._validate_compile(
2069            fn, arg_fn=lambda: (values, offsets, offsets)
2070        )
2071        self.assertTrue(any("Duplicate tensors found" in g for g in guards_failed))
2072        self.assertIn("L['offsets1'] is L['offsets2']", new_guards_exported)
2073
2074        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
2075            offsets3 = offsets.clone()
2076            self._validate_compile(fn, arg_fn=lambda: (values, offsets3, offsets3))
2077
2078        # Do a binary op
2079        def fn(values, offsets, offsets2):
2080            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets)
2081            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2)
2082            return nt1 * nt2
2083
2084        self._validate_compile(fn, arg_fn=lambda: (values, offsets, offsets))
2085
2086    def test_in_graph_construction_from_input_4(self):
2087        # The offsets tensor is taken from an NJT input
2088        def fn(nt, other_values):
2089            nt2 = torch.nested.nested_tensor_from_jagged(other_values, nt.offsets())
2090            return nt + nt2
2091
2092        values = torch.randn(9, 5, requires_grad=True)
2093        other_values = torch.randn(9, 5, requires_grad=True)
2094        offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64)
2095
2096        def arg_fn(values=values, other_values=other_values, offsets=offsets):
2097            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
2098            return nt, other_values
2099
2100        self._validate_compile(fn, arg_fn=arg_fn)
2101
2102        # Do not specialize on the offsets
2103        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
2104            different_offsets = offsets.clone()
2105
2106            def arg_fn(
2107                values=values, other_values=other_values, offsets=different_offsets
2108            ):
2109                nt = torch.nested.nested_tensor_from_jagged(values, different_offsets)
2110                return nt, other_values
2111
2112            self._validate_compile(fn, arg_fn=arg_fn)
2113
2114    def test_in_graph_construction_from_input_5(self):
2115        # Construct from lengths instead of offsets
2116        def fn(values, lengths):
2117            nt = torch.nested.nested_tensor_from_jagged(values, lengths=lengths)
2118            return nt.sin()
2119
2120        values = torch.randn(9, 5, requires_grad=True)
2121        lengths = torch.tensor([2, 4, 3])
2122        self._validate_compile(fn, arg_fn=lambda: (values, lengths))
2123
2124    #
2125    # Case 2: in-graph construction where offsets are graph intermediates
2126    #
2127    def test_in_graph_construction_from_intermediate(self):
2128        # offsets is an intermediate computed from lengths
2129        def fn(values, lengths):
2130            offsets = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)])
2131            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
2132            nt2 = torch.nested.nested_tensor_from_jagged(values, offsets)
2133            return (nt * nt2).sin()
2134
2135        values = torch.randn(9, 5, requires_grad=True)
2136        lengths = torch.tensor([2, 4, 3])
2137        self._validate_compile(fn, arg_fn=lambda: (values, lengths))
2138
2139        # Do not specialize on the lengths
2140        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
2141            different_lengths = lengths.clone()
2142            self._validate_compile(fn, arg_fn=lambda: (values, different_lengths))
2143
2144    def test_in_graph_construction_from_intermediate_2(self):
2145        def fn(values, offsets):
2146            return torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone())
2147
2148        values = torch.randn(10, 5, requires_grad=True)
2149        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
2150        self._validate_compile(fn, arg_fn=lambda: (values, offsets))
2151
2152    def test_in_graph_construction_from_intermediate_3(self):
2153        # Note that due to CSE, clone is not necessarily called twice!
2154        def fn(values, offsets):
2155            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone())
2156            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets.clone())
2157            return nt2, nt1
2158
2159        values = torch.randn(10, 5, requires_grad=True)
2160        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
2161        self._validate_compile(fn, arg_fn=lambda: (values, offsets))
2162
2163    def test_in_graph_construction_from_intermediate_4(self):
2164        # Shared intermediate (should be same as case #1)
2165        def fn(values):
2166            offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
2167            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
2168            values2 = torch.ones_like(values)
2169            nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets)
2170            return nt * nt2
2171
2172        values = torch.randn(10, 5).requires_grad_(True)
2173        self._validate_compile(fn, arg_fn=lambda: (values,))
2174
2175    # AssertionError: s2 (could be from ['<ephemeral: intermediate_offsets_or_lengths>',
2176    @unittest.expectedFailure
2177    def test_in_graph_construction_from_intermediate_5(self):
2178        # non-shared intermediate
2179        def fn(values):
2180            offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
2181            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
2182            values2 = torch.ones_like(values)
2183            nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets.clone())
2184            if nt2.shape[1] != nt.shape[1]:
2185                return nt * 2
2186            else:
2187                return nt * 3
2188
2189        values = torch.randn(10, 5).requires_grad_(True)
2190        self._validate_compile(fn, arg_fn=lambda: (values,))
2191
2192    #
2193    # Case 3: in-graph construction where offsets are both direct graph inputs
2194    #         and passed in as part of an NJT's offsets.
2195    #
2196    def test_in_graph_construction_mixed(self):
2197        def fn(nt, values, offsets):
2198            nt2 = torch.nested.nested_tensor_from_jagged(values, offsets)
2199            return nt * nt2
2200
2201        values = torch.randn(10, 5, requires_grad=True)
2202        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
2203
2204        def arg_fn(values=values, offsets=offsets):
2205            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
2206            return nt, values, offsets
2207
2208        self._validate_compile(fn, arg_fn)
2209
2210    # See Note: [Creating symbolic nested int]
2211    # AssertionError: s2 (could be from ['<ephemeral: intermediate_offsets_or_lengths>',
2212    @unittest.expectedFailure
2213    def test_in_graph_construction_mixed_2(self):
2214        def fn(nt, values, offsets, nt2):
2215            # Intermediate offsets has ephemeral source
2216            intermediate_nt = torch.nested.nested_tensor_from_jagged(
2217                values, offsets.clone()
2218            )
2219            # This creates a dynamic shapes neq guard
2220            if nt2.shape[1] != intermediate_nt.shape[1]:
2221                # We should always go here.
2222                nt = nt * 2
2223            return nt
2224
2225        values = torch.randn(10, 5, requires_grad=True)
2226        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
2227        offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64)
2228
2229        def arg_fn(values=values, offsets=offsets, offsets2=offsets2):
2230            # Values is shared, but it shouldn't matter
2231            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
2232            nt2 = torch.nested.nested_tensor_from_jagged(values, offsets2)
2233            return nt, values, offsets, nt2
2234
2235        self._validate_compile(fn, arg_fn)
2236
2237    def test_in_graph_construction_mixed_3(self):
2238        # More involved mixed case
2239        def fn(nt, values, offsets):
2240            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets)
2241            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets)
2242            return nt1 + nt2 + nt
2243
2244        values = torch.randn(9, 5, requires_grad=True)
2245        offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64)
2246
2247        def arg_fn(values=values, offsets=offsets):
2248            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
2249            return nt, values, offsets
2250
2251        self._validate_compile(fn, arg_fn)
2252
2253    def test_return_shape(self):
2254        nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
2255
2256        def fn(nt):
2257            return (nt * 2).shape
2258
2259        compiled = torch.compile(fn, fullgraph=True, backend="aot_eager")
2260        compiled(nt)
2261
2262    def test_inference_tensor(self):
2263        with torch.inference_mode():
2264            nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
2265
2266        def fn(n):
2267            return n * 2
2268
2269        torch.compile(fn, backend="eager")(nt)
2270
2271    # TODO: cannot parametrize this test class with device for some reason
2272    def _test_autograd(self, backend):
2273        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)
2274        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
2275        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64)
2276        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
2277        # TODO: Switch to public API when it exists
2278        nt2, _ = jagged_from_list([a, b, c], nt.offsets())
2279
2280        def fn1(nt1, nt2):
2281            return (nt1 + nt2).sin().cos()
2282
2283        compiled_f = torch.compile(fn1, fullgraph=True, backend=backend, dynamic=True)
2284        out = compiled_f(nt, nt2)
2285        out_buffer = out.values()
2286        ga, gb, gc = torch.autograd.grad(out_buffer.sum(), (a, b, c))
2287
2288        out_ref = fn1(nt, nt2)
2289        out_buffer_ref = out_ref.values()
2290        ga_ref, gb_ref, gc_ref = torch.autograd.grad(out_buffer_ref.sum(), (a, b, c))
2291
2292        self.assertTrue(torch.allclose(ga, ga_ref))
2293        self.assertTrue(torch.allclose(gb, gb_ref))
2294        self.assertTrue(torch.allclose(gc, gc_ref))
2295
2296    def test_basic_autograd(self):
2297        self._test_autograd("aot_eager")
2298
2299    @requires_cuda
2300    def test_basic_autograd_inductor(self):
2301        self._test_autograd("inductor")
2302
2303    def test_subclass_with_mutation_in_graph(self):
2304        # In this graph, we have an in-graph mutation, i.e. a mutation that is allowed
2305        # to remain in the graph. Normally this is fine, but it is not allowed if
2306        # the graph handles subclasses at all.
2307        # Whether or not the mutation is allowed in the graph alters the number
2308        # of outputs from the forward graph. Previously, a bug in this handling meant
2309        # that sometimes the expected number and actual number of outputs from the
2310        # joint graph did not match, causing assertion failures.
2311        def fn(x, y):
2312            z = x.sin()
2313            y.sin_()
2314            return z.cos(), y.cos()
2315
2316        fn_c = torch.compile(fn, backend="inductor")
2317
2318        values = [torch.rand((i, 8), requires_grad=True) for i in range(1, 6)]
2319        values_copy = [x.detach().clone().requires_grad_(True) for x in values]
2320
2321        nt, offsets = jagged_from_list(values, None)
2322        nt_copy, offsets = jagged_from_list(values_copy, offsets)
2323        y = torch.rand((4, 8))
2324        y_copy = y.clone()
2325
2326        ret = fn_c(nt, y)[0]
2327        ref = fn(nt_copy, y_copy)[0]
2328
2329        self.assertEqual(ret.values(), ref.values())
2330
2331        ret.values().sum().backward()
2332        ref.values().sum().backward()
2333        for ref_v, res_v in zip(values_copy, values):
2334            self.assertEqual(ref_v.grad, res_v.grad)
2335
2336    @torch._dynamo.config.patch({"capture_scalar_outputs": True})
2337    def test_unbind(self):
2338        # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0).
2339        # This causes a recompile later on when it realizes the batch and last dim
2340        # should not always be equal. To avoid that, we use (3, j0, 5) here.
2341        nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
2342        nt2, _ = self._get_jagged_tensor(((2, 3, 5), 2), None)
2343        nt3, _ = self._get_jagged_tensor(((2, 3, 4, 5), 3), None)
2344
2345        def fn(x):
2346            return x.unbind()
2347
2348        compiled_f = torch.compile(fn, fullgraph=True, backend="eager", dynamic=True)
2349        out = compiled_f(nt)
2350
2351        out_ref = fn(nt)
2352
2353        # correctness
2354        self.assertEqual(len(out), len(out_ref))
2355        for x, x_ref in zip(out, out_ref):
2356            self.assertTrue(torch.allclose(x, x_ref))
2357
2358        # We specialize on the length of offsets, i.e. (1) we recompile if the
2359        # length of the offsets is different, and (2) we don't recompile if the
2360        # length of the offsets is the same, even if the sizes of the constituent
2361        # tensors are different.
2362        self._check_recompiles(fn, (nt,), (nt2,), False)
2363        self._check_recompiles(fn, (nt,), (nt3,), True)
2364
2365    def test_inline_nested_tensor_from_jagged(self):
2366        nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
2367
2368        def fn(x):
2369            return torch.nested.nested_tensor_from_jagged(x.values() * 2, x.offsets())
2370
2371        torch.compile(fn, fullgraph=True, backend="aot_eager")(nt)
2372
2373    # The test here: nn.Parameters that are secretly subclasses
2374    # have a metaclass that overrides __instancecheck__, which dynamo needs
2375    # to respect when it inlines the if statement (see the sketch after this test).
2376    def test_param_subclass_isinstance_input(self):
2377        x_inner = torch.randn(16, 16, requires_grad=True)
2378        x = torch.nn.Parameter(TwoTensor(x_inner, x_inner))
2379        m = torch.nn.Linear(16, 16)
2380        m.weight = x
2381
2382        def fn():
2383            if isinstance(m.weight, torch.nn.Parameter):
2384                return m.weight + 1
2385            else:
2386                return m.weight + 2
2387
2388        out_ref = fn()
2389        out_test = torch.compile(fn, backend="aot_eager")()
2390        self.assertEqual(out_ref, out_test)
2391
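    # A minimal eager sketch of the metaclass behavior referenced above
    # (hypothetical helper, not collected as a test): wrapping a tensor
    # subclass in nn.Parameter still satisfies isinstance(_, nn.Parameter)
    # because the Parameter metaclass customizes instance checks.
    def _param_instancecheck_sketch(self):
        inner = torch.randn(4, 4, requires_grad=True)
        p = torch.nn.Parameter(TwoTensor(inner, inner))
        # Even though `p` is a TwoTensor under the hood, the metaclass-level
        # instance check is expected to report it as a Parameter, which is the
        # branch the compiled function above must agree with.
        return isinstance(p, torch.nn.Parameter), type(p)
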
2392    def _input_view_test(self, nt_view_name):
2393        nt_view = VIEW_TEST_CASES[nt_view_name]()
2394
2395        def fn(x):
2396            return x.sin()
2397
2398        out_ref = fn(nt_view)
2399        torch._dynamo.reset()
2400        compile_fn = torch.compile(
2401            fn, fullgraph=True, backend="aot_eager", dynamic=True
2402        )
2403        out = compile_fn(nt_view)
2404
2405        # Check metadata and values are correct
2406        self.assertTrue(out.size() == out_ref.size())
2407        self.assertTrue(out.stride() == out_ref.stride())
2408        if out.is_nested:
2409            self.assertTrue(torch.allclose(out.values(), out_ref.values()))
2410        else:
2411            self.assertTrue(torch.allclose(out, out_ref))
2412
2413        # Check that no upper/lower bound guards are incurred
2414        def backend(gm, args):
2415            context = torch._guards.TracingContext.get()
2416            guards = [str(g.expr) for g in context.fake_mode.shape_env.guards]
2417
2418            # varies based on the type of view
2419            guard_str = "\n".join(guards)
2420            if nt_view_name == "subclass_dense":
2421                self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""")
2422            elif nt_view_name == "dense_subclass_dense_subclass":
2423                self.assertExpectedInline(
2424                    guard_str,
2425                    """\
2426Eq(s5 - 1, s2)
2427Eq(s12 - 1, s7)
2428Eq(s11, s9)""",
2429                )
2430            elif nt_view_name.startswith("base_is_nt_True"):
2431                self.assertExpectedInline(
2432                    guard_str,
2433                    """Eq(s3 - 1, s0)""",
2434                )
2435            else:
2436                self.assertExpectedInline(
2437                    guard_str,
2438                    """\
2439Eq(s4 - 1, s1)
2440Eq(s13 - 1, s8)
2441Eq(s12, s10)""",
2442                )
2443            return gm
2444
2445        torch._dynamo.reset()
2446        compile_fn = torch.compile(fn, fullgraph=True, backend=backend, dynamic=True)
2447        out = compile_fn(nt_view)
2448
2449    @parametrize(
2450        "nt_view_name",
2451        [k for k in VIEW_TEST_CASES.keys() if k != "subclass_dense_subclass_dense"],
2452    )
2453    def test_inputs_to_compiled_fn_are_views(self, nt_view_name):
2454        self._input_view_test(nt_view_name)
2455
2456    def test_subclass_gives_static_shapes_when_dynamic_false(self):
2457        def check_graph(gm, *args):
2458            first_node_example_val = next(iter(gm.graph.nodes)).meta["example_value"]
2459            # We compiled with dynamic=False, so we expect no SymInt sizes on our placeholders
2460            self.assertTrue(
2461                all(isinstance(x, int) for x in first_node_example_val.shape)
2462            )
2463            return gm
2464
2465        @torch.compile(backend=check_graph, dynamic=False)
2466        def f(x):
2467            return x + 1
2468
2469        x_inner = torch.ones(4)
2470        x = TwoTensor(x_inner, x_inner)
2471        x_view = x.view(2, 2)
2472        out = f(x_view)
2473
2474    # NJT1 -> Dense -> NJT2 -> Dense view
2475    # During view replay, the Dense -> NJT2 part will construct an intermediate,
2476    # symbolically-sized NJT that is immediately deconstructed to return the final dense
2477    # view. To construct this intermediate properly, we need the associated nested int
2478    # to be symbolic. This view is expected to fail compilation until symbolic nested ints
2479    # are cached onto fake offsets to solve this problem.
2480    @unittest.expectedFailure
2481    def test_subclass_dense_subclass_dense_view(self):
2482        self._input_view_test("subclass_dense_subclass_dense")
2483
2484
2485instantiate_parametrized_tests(TestNestedTensor)
2486
2487
2488if __name__ == "__main__":
2489    from torch._dynamo.test_case import run_tests
2490
2491    run_tests()
2492