# Owner(s): ["module: dynamo"]
import functools
import itertools
import unittest
import unittest.mock
from functools import partial

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._functorch.config
import torch.utils._pytree as pytree
import torch.utils.checkpoint
from torch._dynamo.testing import normalize_gm
from torch._higher_order_ops.wrap import wrap
from torch.fx.experimental.symbolic_shapes import (
    DimDynamic,
    ShapeEnv,
    StatelessSymbolicContext,
)
from torch.nested._internal.nested_tensor import (
    jagged_from_list,
    jagged_from_tensor_and_lengths,
    nested_view_from_values_offsets,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    NestedTensorTestCase,
    parametrize,
    subtest,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._python_dispatch import return_and_correct_aliasing


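# Usage (illustrative): `with traceable_subclass(MySubclass): fn(x)` lets
# Dynamo trace through MySubclass's __torch_function__ inside the block.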
def traceable_subclass(c):
    return torch._dynamo.config.patch("traceable_tensor_subclasses", {c})


def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles):
    actual_recompiles = _recompiles_for_inputs(fn, inputs1, inputs2)
    self.assertEqual(actual_recompiles, expected_recompiles)


def get_jagged_tensor(nested_size, offsets, requires_grad=True):
    # Makes a jagged tensor with N constituent tensors with size
    # as specified ((S0, S1, S2), D)
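    # Example (illustrative): nested_size=((2, 3, 4), 3) yields constituents of
    # shapes (2, 3), (3, 3), (4, 3); returns a (nested_tensor, offsets) pair.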
    D = nested_size[1]
    out = []
    for s in nested_size[0]:
        out.append(torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64))
    return jagged_from_list(out, offsets)


def get_view_test_cases():
    # Test all cases with both an NT base and a dense base
    # Subclass -> Subclass
    # Dense -> Subclass

    # NB: Don't close over loop variables; they will not get copied into the
    # closure
    #
    # NB: These return functions so we don't generate tensors during test
    # collection time

    def mk_basic(base_is_nt):
        # There are three cases to consider here based on the logic in
        # meta_utils.py
        #
        # (1) basic case:
        # view is not a leaf and has the same requires_grad as its base
        x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
        x = x.clone() if base_is_nt else x
        assert not x.is_leaf
        return x.unsqueeze(-1)

    def mk_leaf(base_is_nt, requires_grad_1, requires_grad_2):
        x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=requires_grad_1)
        x = x.clone() if base_is_nt else x
        with torch.no_grad():
            x_view = x.unsqueeze(-1)
            # The issue is this doesn't quite work
            x_view.requires_grad_(requires_grad_2)

        return x_view

    def mk_obscure(base_is_nt):
        x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
        x = x.clone() if base_is_nt else x
        # intermediate leaf view
        with torch.no_grad():
            x_view = x.unsqueeze(-1)
        x_view.requires_grad_(True)
        x_view_view = x_view.unsqueeze(-1)
        return x_view_view

    for base_is_nt in [False, True]:
        prefix = f"base_is_nt_{base_is_nt}"

        yield partial(mk_basic, base_is_nt), f"{prefix}_basic"

        # (2) leaf view case:
        # the view has to be a leaf (w/ requires_grad True or requires_grad False)
        # base w/ requires_grad True or requires_grad False
        for requires_grad_1, requires_grad_2 in itertools.product(
            [True, False], repeat=2
        ):
            yield partial(
                mk_leaf, base_is_nt, requires_grad_1, requires_grad_2
            ), f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}"

        # (3) obscure case:
        # view is not a leaf (implies requires_grad True)
        # base w/ requires_grad False
        yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure"

    # Subclass -> Dense
    yield lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[
        0
    ].clone(), "subclass_dense"

    # Dense -> Subclass -> Dense -> Subclass
    def mk_dense_subclass_dense_subclass():
        values = torch.randn(10, 5)
        offsets = torch.tensor([0, 3, 6, 10])
        offsets2 = offsets.clone().detach()
        return nested_view_from_values_offsets(
            nested_view_from_values_offsets(values, offsets).values(), offsets
        )

    yield mk_dense_subclass_dense_subclass, "dense_subclass_dense_subclass"

    def mk_subclass_dense_subclass_dense():
        x = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
        offsets2 = x.offsets().clone().detach()
        nt_view = nested_view_from_values_offsets(x.values(), offsets2).values()
        return nt_view

    yield mk_subclass_dense_subclass_dense, "subclass_dense_subclass_dense"


VIEW_TEST_CASES = {k: v for v, k in get_view_test_cases()}
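# Maps case name -> zero-arg factory; e.g. VIEW_TEST_CASES["subclass_dense"]()
# builds its view lazily at test time rather than at collection time.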


requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")

compile_full_eager = torch.compile(backend="eager", fullgraph=True)


class BaseTorchFunction(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)


class MockSubclass(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)


class AttrSubclass(torch.Tensor):
    x: int = 10
    size: int = 10

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        return func(*args, **kwargs)


class DummyNDim(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        if func == torch.Tensor.ndim.__get__:
            return 10

        return super().__torch_function__(func, types, args, kwargs)


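# Non-Tensor wrapper: __torch_function__ unwraps WrapperSubclass args back to
# the inner tensors before dispatching, so results are plain tensors.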
class WrapperSubclass:
    def __init__(self, tensor):
        self.tensor = tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        args = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, args)
        kwargs = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, kwargs)

        return func(*args, **kwargs)


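# Reroutes Tensor.sigmoid() to Tensor.exp() via __torch_function__; used to
# check that tensor-method calls dispatch through the override under compile.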
class SigmoidToExpSubclass(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        if func == torch.Tensor.sigmoid:
            return super().__torch_function__(torch.Tensor.exp, types, args, kwargs)

        return super().__torch_function__(func, types, args, kwargs)


# Wrapper subclass with two inner tensors: data and scale
# data has same shape as outer, and scale has single dim size
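# Example (illustrative): ScaledTensor(torch.randn(4), torch.tensor(2.0))
# flattens via __tensor_flatten__ to (["_data", "_scale"], {"_constant": 0}).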
class ScaledTensor(torch.Tensor):
    def __new__(
        cls,
        data: torch.Tensor,
        scale: torch.Tensor,
        *,
        constant: int = 0,
    ):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
            strides=data.stride(),
            storage_offset=data.storage_offset(),
            dtype=data.dtype,
            layout=data.layout,
            requires_grad=data.requires_grad,
            device=data.device,
        )

    def __init__(self, data: torch.Tensor, scale: torch.Tensor, constant: int = 0):
        self._data = data
        self._scale = scale
        self._constant = constant

    def __tensor_flatten__(self):
        ctx = {"_constant": self._constant}
        return ["_data", "_scale"], ctx

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        assert len(inner_tensors) == 2
        return ScaledTensor(
            inner_tensors["_data"],
            inner_tensors["_scale"],
            constant=metadata["_constant"],
        )

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        scaled_tensor = args[0]
        out = func(scaled_tensor._data, *args[1:], **kwargs)
        return ScaledTensor(out, scaled_tensor._scale, constant=scaled_tensor._constant)

    def __repr__(self):
        return f"{self._data.__repr__()}\n{self._scale.__repr__()}"


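# Like ScaledTensor, but `scale` may be None; __tensor_flatten__ then reports
# only ["_data"], exercising the optional-inner-tensor path.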
class OptionalScaledTensor(torch.Tensor):
    def __new__(
        cls,
        data,
        scale,
        *,
        constant: int = 0,
    ):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
            strides=data.stride(),
            storage_offset=data.storage_offset(),
            dtype=data.dtype,
            layout=data.layout,
            requires_grad=data.requires_grad,
            device=data.device,
        )

    def __init__(self, data: torch.Tensor, scale, constant: int = 0):
        self._data = data
        self._scale = scale
        self._constant = constant

    def __tensor_flatten__(self):
        ctx = {"_constant": self._constant}
        if self._scale is not None:
            return ["_data", "_scale"], ctx
        else:
            return ["_data"], ctx

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        return OptionalScaledTensor(
            inner_tensors["_data"],
            inner_tensors["_scale"] if "_scale" in inner_tensors else None,
            constant=metadata["_constant"],
        )

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        scaled_tensor = args[0]
        out = func(scaled_tensor._data, *args[1:], **kwargs)
        if scaled_tensor._scale is not None:
            out = out * scaled_tensor._scale
        return OptionalScaledTensor(
            out, scaled_tensor._scale, constant=scaled_tensor._constant
        )

    def __repr__(self):
        return (
            f"OptionalScaledTensor({self._data.__repr__()}\n{self._scale.__repr__()})"
        )


class CtxSubclassTensor(torch.Tensor):
    """
    Class used to verify guarding on the subclass metadata
    """

    @staticmethod
    def __new__(cls, a, constant):
        shape = a.shape
        kwargs = {}
        kwargs["strides"] = a.stride()
        kwargs["storage_offset"] = a.storage_offset()
        kwargs["device"] = a.device
        kwargs["layout"] = a.layout
        kwargs["requires_grad"] = a.requires_grad
        kwargs["dtype"] = a.dtype
        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
        return out

    def __init__(self, a, constant):
        self.a = a
        self.constant = constant

    def __repr__(self):
        a_repr = repr(self.a)
        return f"CtxSubclassTensor({a_repr})"

    def __tensor_flatten__(self):
        return ["a"], (self.constant,)

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, sizes, strides):
        constant = meta[0]
        a = inner_tensors["a"]
        return CtxSubclassTensor(a, constant)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        from torch.utils._python_dispatch import return_and_correct_aliasing

        if kwargs is None:
            kwargs = {}
        biggest_constant = max(
            [
                x.constant
                for x in pytree.tree_flatten(args)[0]
                if isinstance(x, CtxSubclassTensor)
            ]
        )
        args_a = pytree.tree_map(
            lambda x: x.a if isinstance(x, CtxSubclassTensor) else x, args
        )
        kwargs_a = pytree.tree_map(
            lambda x: x.a if isinstance(x, CtxSubclassTensor) else x, kwargs
        )
        out_a = func(*args_a, **kwargs_a)
        out = pytree.tree_map(
            lambda x: CtxSubclassTensor(x, biggest_constant)
            if isinstance(x, torch.Tensor)
            else x,
            out_a,
        )

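        # Special-case mul so the flatten-context constant observably affects
        # the output, letting tests detect guard behavior on the metadata.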
        if func == torch.ops.aten.mul.Tensor:
            out = out + out.constant

        return return_and_correct_aliasing(func, args, kwargs, out)


def func(a):
    return a.sin()


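# Minimal Dynamo backend: records each captured GraphModule and its example
# inputs, then returns the gm unchanged so it runs eagerly.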
class EagerRecordGraphAndInputs:
    def __init__(self) -> None:
        self.graphs = []
        self.example_inputs = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        self.graphs.append(gm)
        self.example_inputs.append(example_inputs)
        return gm


GLOBAL_TEST_SUBCLASSES = {
    MockSubclass,
    DummyNDim,
    SigmoidToExpSubclass,
    BaseTorchFunction,
}


# Returns True if the function recompiles between inputs1 and inputs2 with the
# specified dynamic setting.
def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True):
    compile_count = [0]

    def counter(gm, example_inputs):
        compile_count[0] += 1
        return gm

    compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic)
    compiled_f(*inputs1)
    compiled_f(*inputs2)
    return compile_count[0] > 1
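# Example (illustrative): with dynamic=False, shape-changing inputs such as
# (torch.randn(2),) then (torch.randn(3),) trigger a second compile, so the
# helper returns True; with dynamic=True the second call typically reuses the
# symbolic-shape graph and it returns False.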


class SubclassTests(torch._dynamo.test_case.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack.enter_context(
            torch._dynamo.config.patch(
                "traceable_tensor_subclasses", GLOBAL_TEST_SUBCLASSES
            )
        )

    @classmethod
    def tearDownClass(cls):
        cls._exit_stack.close()

    def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles):
        _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles)

    def test_no_call_to_new(self):
        class BadNewTorchFunction(torch.Tensor):
            def __new__(cls, *args, **kwargs):
                raise RuntimeError("Oops!")

            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return super().__torch_function__(func, types, args, kwargs)

        with torch._dynamo.config.patch(
            "traceable_tensor_subclasses", {BadNewTorchFunction}
        ):

            @torch.compile(backend="eager", fullgraph=True)
            def fn(x):
                return torch.add(x, 1)

            input = torch.ones(2, 2).as_subclass(BadNewTorchFunction)

            res = fn(input)
            self.assertIsInstance(res, BadNewTorchFunction)

    def test_no_torch_function_recompiles(self):
        class NJT:
            def __repr__(self):
                return f"NJT(shape={self.shape})"

            def __init__(self, values, offsets):
                self._values = values
                self._offsets = offsets

            def sin(self):
                return torch.sin(self)

            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                if func == torch.sin:
                    self = args[0]
                    return NJT(func(self._values), self._offsets)
                raise AssertionError("should not get here")

        values1 = torch.randn(10, 3, 4, requires_grad=True)
        values2 = torch.randn(10, 3, 4, requires_grad=True)
        offsets = torch.tensor([0, 3, 10])
        njt1 = NJT(values1, offsets)
        njt2 = NJT(values2, offsets)

        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            return torch.sin(x)

        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            f(njt1)
            f(njt2)

    def test_base_torch_function_tracing(self):
        def fn(x):
            return torch.add(x, 1)

        input = torch.ones(2, 2).as_subclass(BaseTorchFunction)
        out = fn(input)
        out_opt = compile_full_eager(fn)(input)
        self.assertIsInstance(out, BaseTorchFunction)
        self.assertEqual(out, out_opt)

    def test_torch_function_state_graph_break(self):
        @torch.compile(backend="eager")
        def fn(x):
            with torch._C.DisableTorchFunctionSubclass():
                torch._dynamo.graph_break()
                return torch._C._is_torch_function_enabled(), torch.add(x, 1.0)

        input = torch.ones(2, 2)
        res, _ = fn(input)
        self.assertFalse(res)

    def test_torch_function_state_nested(self):
        @torch.compile(backend="eager")
        def fn(x):
            with torch._C.DisableTorchFunctionSubclass():
                with torch._C.DisableTorchFunctionSubclass():
                    x = x + 1
                # Should reset to the outer state (disabled) after exiting ctx manager
                return torch._C._is_torch_function_enabled(), torch.add(x, 1.0)

        input = torch.ones(2, 2)
        res, _ = fn(input)
        self.assertFalse(res)

    def test_torch_function_state_tracing(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            with torch._C.DisableTorchFunctionSubclass():
                torch.add(x, 1.0)

        input = torch.ones(2, 2)

        res = fn(input)

    def test_torch_function_state_guards(self):
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x):
            torch.add(x, 1.0)

        input = torch.ones(2, 2)

        with torch._C.DisableTorchFunctionSubclass():
            res = fn(input)

        res = fn(input)

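        # The two calls ran under different torch-function states, so Dynamo's
        # guards force a second compilation.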
560*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 2)
561*da0073e9SAndroid Build Coastguard Worker
562*da0073e9SAndroid Build Coastguard Worker    def test_return_subclass(self):
563*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager", fullgraph=True)
564*da0073e9SAndroid Build Coastguard Worker        def fn(x):
565*da0073e9SAndroid Build Coastguard Worker            return MockSubclass(torch.add(x, 1.0))
566*da0073e9SAndroid Build Coastguard Worker
567*da0073e9SAndroid Build Coastguard Worker        input = torch.ones(2, 2)
568*da0073e9SAndroid Build Coastguard Worker
569*da0073e9SAndroid Build Coastguard Worker        res = fn(input)
570*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(res, MockSubclass)
571*da0073e9SAndroid Build Coastguard Worker
572*da0073e9SAndroid Build Coastguard Worker    def test_return_as_subclass(self):
573*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager", fullgraph=True)
574*da0073e9SAndroid Build Coastguard Worker        def fn(x):
575*da0073e9SAndroid Build Coastguard Worker            return torch.add(x, 1.0).as_subclass(MockSubclass)
576*da0073e9SAndroid Build Coastguard Worker
577*da0073e9SAndroid Build Coastguard Worker        input = torch.ones(2, 2)
578*da0073e9SAndroid Build Coastguard Worker
579*da0073e9SAndroid Build Coastguard Worker        res = fn(input)
580*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(res, MockSubclass)
581*da0073e9SAndroid Build Coastguard Worker
582*da0073e9SAndroid Build Coastguard Worker    def test_return_local_subclass(self):
583*da0073e9SAndroid Build Coastguard Worker        class LocalSubclass(torch.Tensor):
584*da0073e9SAndroid Build Coastguard Worker            @classmethod
585*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(cls, func, types, args=(), kwargs=None):
586*da0073e9SAndroid Build Coastguard Worker                if kwargs is None:
587*da0073e9SAndroid Build Coastguard Worker                    kwargs = {}
588*da0073e9SAndroid Build Coastguard Worker                return func(*args, **kwargs)
589*da0073e9SAndroid Build Coastguard Worker
590*da0073e9SAndroid Build Coastguard Worker        with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
591*da0073e9SAndroid Build Coastguard Worker
592*da0073e9SAndroid Build Coastguard Worker            @torch.compile(backend="eager", fullgraph=True)
593*da0073e9SAndroid Build Coastguard Worker            def fn(x):
594*da0073e9SAndroid Build Coastguard Worker                return LocalSubclass(torch.add(x, 1.0))
595*da0073e9SAndroid Build Coastguard Worker
596*da0073e9SAndroid Build Coastguard Worker            input = torch.ones(2, 2)
597*da0073e9SAndroid Build Coastguard Worker
598*da0073e9SAndroid Build Coastguard Worker            res = fn(input)
599*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(res, LocalSubclass)
600*da0073e9SAndroid Build Coastguard Worker
601*da0073e9SAndroid Build Coastguard Worker    def test_torch_function_list_args(self):
602*da0073e9SAndroid Build Coastguard Worker        HANDLED_FUNCTIONS = {}
603*da0073e9SAndroid Build Coastguard Worker
604*da0073e9SAndroid Build Coastguard Worker        class MyClass:
605*da0073e9SAndroid Build Coastguard Worker            def __init__(self, foo):
606*da0073e9SAndroid Build Coastguard Worker                self.foo = foo
607*da0073e9SAndroid Build Coastguard Worker
608*da0073e9SAndroid Build Coastguard Worker            @classmethod
609*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(
610*da0073e9SAndroid Build Coastguard Worker                cls,
611*da0073e9SAndroid Build Coastguard Worker                func,
612*da0073e9SAndroid Build Coastguard Worker                types,
613*da0073e9SAndroid Build Coastguard Worker                args=(),
614*da0073e9SAndroid Build Coastguard Worker                kwargs=None,
615*da0073e9SAndroid Build Coastguard Worker            ):
616*da0073e9SAndroid Build Coastguard Worker                if kwargs is None:
617*da0073e9SAndroid Build Coastguard Worker                    kwargs = {}
618*da0073e9SAndroid Build Coastguard Worker                if func not in HANDLED_FUNCTIONS or not all(  # noqa: C419
619*da0073e9SAndroid Build Coastguard Worker                    [  # noqa: C419
620*da0073e9SAndroid Build Coastguard Worker                        issubclass(t, (torch.Tensor, MyClass)) for t in types
621*da0073e9SAndroid Build Coastguard Worker                    ]
622*da0073e9SAndroid Build Coastguard Worker                ):
623*da0073e9SAndroid Build Coastguard Worker                    return NotImplemented
624*da0073e9SAndroid Build Coastguard Worker                return HANDLED_FUNCTIONS[func](*args, **kwargs)
625*da0073e9SAndroid Build Coastguard Worker
626*da0073e9SAndroid Build Coastguard Worker        def _stack(input, dim=0, *, out=None):
627*da0073e9SAndroid Build Coastguard Worker            return MyClass(sum([x.foo for x in input]))
628*da0073e9SAndroid Build Coastguard Worker
629*da0073e9SAndroid Build Coastguard Worker        HANDLED_FUNCTIONS[torch.stack] = _stack
630*da0073e9SAndroid Build Coastguard Worker
631*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager", fullgraph=True)
632*da0073e9SAndroid Build Coastguard Worker        def fn(v0, v1):
633*da0073e9SAndroid Build Coastguard Worker            return torch.stack([v0, v1])
634*da0073e9SAndroid Build Coastguard Worker
635*da0073e9SAndroid Build Coastguard Worker        ret = fn(MyClass(1), MyClass(1))
636*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ret.foo, 2)
637*da0073e9SAndroid Build Coastguard Worker
638*da0073e9SAndroid Build Coastguard Worker    @parametrize(
639*da0073e9SAndroid Build Coastguard Worker        "comparison",
640*da0073e9SAndroid Build Coastguard Worker        [
641*da0073e9SAndroid Build Coastguard Worker            subtest(isinstance, "isinstance"),
642*da0073e9SAndroid Build Coastguard Worker            subtest(lambda instance, type_: type(instance) == type_, "equality"),
643*da0073e9SAndroid Build Coastguard Worker            subtest(lambda instance, type_: type(instance) is type_, "identity"),
644*da0073e9SAndroid Build Coastguard Worker        ],
645*da0073e9SAndroid Build Coastguard Worker    )
646*da0073e9SAndroid Build Coastguard Worker    @parametrize(
647*da0073e9SAndroid Build Coastguard Worker        "input_type",
648*da0073e9SAndroid Build Coastguard Worker        [
649*da0073e9SAndroid Build Coastguard Worker            subtest(torch.Tensor, "tensor"),
650*da0073e9SAndroid Build Coastguard Worker            subtest(DummyNDim, "subclass"),
651*da0073e9SAndroid Build Coastguard Worker        ],
652*da0073e9SAndroid Build Coastguard Worker    )
653*da0073e9SAndroid Build Coastguard Worker    def test_type_check(self, comparison, input_type):
654*da0073e9SAndroid Build Coastguard Worker        with torch._dynamo.config.patch("traceable_tensor_subclasses", {DummyNDim}):
655*da0073e9SAndroid Build Coastguard Worker
656*da0073e9SAndroid Build Coastguard Worker            def fn(x):
657*da0073e9SAndroid Build Coastguard Worker                if comparison(x, DummyNDim):
658*da0073e9SAndroid Build Coastguard Worker                    return torch.ones(1, 1)
659*da0073e9SAndroid Build Coastguard Worker                else:
660*da0073e9SAndroid Build Coastguard Worker                    return torch.zeros(2, 2)
661*da0073e9SAndroid Build Coastguard Worker
662*da0073e9SAndroid Build Coastguard Worker            input = torch.ones(2, 2).as_subclass(input_type)
663*da0073e9SAndroid Build Coastguard Worker            exp_res = fn(input)
664*da0073e9SAndroid Build Coastguard Worker            act_res = torch.compile(backend="eager", fullgraph=True)(fn)(input)
665*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(exp_res, act_res)
666*da0073e9SAndroid Build Coastguard Worker
667*da0073e9SAndroid Build Coastguard Worker    def test_torch_function_call_on_method(self):
668*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(2, 2)
669*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(2, 2)
670*da0073e9SAndroid Build Coastguard Worker        z = torch.ones(2, 2)
671*da0073e9SAndroid Build Coastguard Worker        wrapped = x.as_subclass(SigmoidToExpSubclass)
672*da0073e9SAndroid Build Coastguard Worker        wrapped2 = y.as_subclass(SigmoidToExpSubclass)
673*da0073e9SAndroid Build Coastguard Worker
674*da0073e9SAndroid Build Coastguard Worker        def fn(w):
675*da0073e9SAndroid Build Coastguard Worker            return w.sigmoid()
676*da0073e9SAndroid Build Coastguard Worker
677*da0073e9SAndroid Build Coastguard Worker        fn_opt = compile_full_eager(fn)
678*da0073e9SAndroid Build Coastguard Worker
679*da0073e9SAndroid Build Coastguard Worker        res_exp = fn(wrapped)
680*da0073e9SAndroid Build Coastguard Worker        res_act = fn_opt(wrapped2)
681*da0073e9SAndroid Build Coastguard Worker        res_exp2 = z.exp()
682*da0073e9SAndroid Build Coastguard Worker
683*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_exp, res_act)
684*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_exp, res_exp2)
685*da0073e9SAndroid Build Coastguard Worker
686*da0073e9SAndroid Build Coastguard Worker    def test_user_overidden_method_unsupported(self):
687*da0073e9SAndroid Build Coastguard Worker        class LocalSubclass(torch.Tensor):
688*da0073e9SAndroid Build Coastguard Worker            @classmethod
689*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(cls, func, types, args=(), kwargs=None):
690*da0073e9SAndroid Build Coastguard Worker                if kwargs is None:
691*da0073e9SAndroid Build Coastguard Worker                    kwargs = {}
692*da0073e9SAndroid Build Coastguard Worker                return super().__torch_function__(func, types, args, kwargs)
693*da0073e9SAndroid Build Coastguard Worker
694*da0073e9SAndroid Build Coastguard Worker            def sigmoid(self):
695*da0073e9SAndroid Build Coastguard Worker                return None
696*da0073e9SAndroid Build Coastguard Worker
697*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager", fullgraph=True)
698*da0073e9SAndroid Build Coastguard Worker        def fn(x):
699*da0073e9SAndroid Build Coastguard Worker            x.sigmoid()
700*da0073e9SAndroid Build Coastguard Worker
701*da0073e9SAndroid Build Coastguard Worker        msg = (
702*da0073e9SAndroid Build Coastguard Worker            "Accessing overridden method/attribute sigmoid on a tensor"
703*da0073e9SAndroid Build Coastguard Worker            " subclass with a __torch_function__ override is not supported"
704*da0073e9SAndroid Build Coastguard Worker        )
705*da0073e9SAndroid Build Coastguard Worker        with torch._dynamo.config.patch(
706*da0073e9SAndroid Build Coastguard Worker            "traceable_tensor_subclasses", {LocalSubclass}
707*da0073e9SAndroid Build Coastguard Worker        ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
708*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(2, 2).as_subclass(LocalSubclass)
709*da0073e9SAndroid Build Coastguard Worker            fn(x)
710*da0073e9SAndroid Build Coastguard Worker
711*da0073e9SAndroid Build Coastguard Worker    def test_user_overidden_attr_unsupported(self):
712*da0073e9SAndroid Build Coastguard Worker        class LocalSubclass(torch.Tensor):
713*da0073e9SAndroid Build Coastguard Worker            @classmethod
714*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(cls, func, types, args=(), kwargs=None):
715*da0073e9SAndroid Build Coastguard Worker                if kwargs is None:
716*da0073e9SAndroid Build Coastguard Worker                    kwargs = {}
717*da0073e9SAndroid Build Coastguard Worker                return super().__torch_function__(func, types, args, kwargs)
718*da0073e9SAndroid Build Coastguard Worker
719*da0073e9SAndroid Build Coastguard Worker            ndim = 10
720*da0073e9SAndroid Build Coastguard Worker
721*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager", fullgraph=True)
722*da0073e9SAndroid Build Coastguard Worker        def fn(x):
723*da0073e9SAndroid Build Coastguard Worker            return x.ndim
724*da0073e9SAndroid Build Coastguard Worker
725*da0073e9SAndroid Build Coastguard Worker        msg = (
726*da0073e9SAndroid Build Coastguard Worker            "Accessing overridden method/attribute ndim on a tensor"
727*da0073e9SAndroid Build Coastguard Worker            " subclass with a __torch_function__ override is not supported"
728*da0073e9SAndroid Build Coastguard Worker        )
729*da0073e9SAndroid Build Coastguard Worker        with torch._dynamo.config.patch(
730*da0073e9SAndroid Build Coastguard Worker            "traceable_tensor_subclasses", {LocalSubclass}
731*da0073e9SAndroid Build Coastguard Worker        ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
732*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(2, 2).as_subclass(LocalSubclass)
733*da0073e9SAndroid Build Coastguard Worker            fn(x)
734*da0073e9SAndroid Build Coastguard Worker
735*da0073e9SAndroid Build Coastguard Worker    def test_user_overidden_property_unsupported(self):
        class LocalSubclass(torch.Tensor):
            def __init__(self) -> None:
                self._ndim = 10

            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return super().__torch_function__(func, types, args, kwargs)

            @property
            def ndim(self):
                return self._ndim

            @ndim.setter
            def ndim(self, value):
                self._ndim = value

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return x.ndim

        msg = (
            "Accessing overridden method/attribute ndim on a tensor"
            " subclass with a __torch_function__ override is not supported"
        )
        with torch._dynamo.config.patch(
            "traceable_tensor_subclasses", {LocalSubclass}
        ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
            x = torch.ones(2, 2).as_subclass(LocalSubclass)
            fn(x)

    def test_overridden_method_guarding(self):
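        # Dynamo should guard on the subclass's set of overridden methods:
        # fresh instances of the same subclass reuse the cache (so
        # error_on_recompile passes below), while monkey-patching sigmoid
        # afterwards presumably invalidates the guard and forces a retrace
        # that fails once the attribute is no longer callable.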
        class LocalSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return super().__torch_function__(func, types, args, kwargs)

        @torch.compile(backend="eager")
        def fn(x):
            return x.sigmoid()

        with torch._dynamo.config.patch(
            error_on_recompile=True, traceable_tensor_subclasses={LocalSubclass}
        ):
            x = torch.ones(2, 2).as_subclass(LocalSubclass)
            fn(x)
            fn(x)
            x = torch.ones(2, 2).as_subclass(LocalSubclass)
            fn(x)

        with torch._dynamo.config.patch(
            traceable_tensor_subclasses={LocalSubclass}
        ), self.assertRaisesRegex(
            TypeError,
            "'bool' object is not callable",
        ):
            LocalSubclass.sigmoid = False
            fn(x)

    def test_torch_function_call_on_attr(self):
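        # DummyNDim (defined earlier in this file) redirects ndim through its
        # __torch_function__ override to return 10; compiled attribute access
        # should dispatch through the override exactly as eager mode does.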
        x = torch.ones(2, 2)
        wrapped = x.as_subclass(DummyNDim)

        def fn(w):
            return w.ndim + torch.ones(2)

        fn_opt = compile_full_eager(fn)

        res_exp = fn(wrapped)
        res_act = fn_opt(wrapped)

        self.assertEqual(res_exp, res_act)
        self.assertEqual(res_exp, torch.ones(2) + 10)

    def test_torch_function_wrapper_class(self):
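        # WrapperSubclass (defined earlier in this file) wraps an inner
        # tensor; compiled and eager results for a plain torch.add on the
        # wrapper should match.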
        x = torch.ones(2, 2)
        wrapped = WrapperSubclass(x)

        def fn(w):
            return torch.add(w, 1.0)

        fn_opt = compile_full_eager(fn)

        res_exp = fn(wrapped)
        res_act = fn_opt(wrapped)
        self.assertEqual(res_exp, res_act)

    def test_torch_function_wrapper_class_with_kwargs(self):
        x = torch.ones(2, 2)
        wrapped = WrapperSubclass(x)

        def fn(w):
            return torch.add(w, 1.0, alpha=2.0)

        fn_opt = compile_full_eager(fn)

        res_exp = fn(wrapped)
        res_act = fn_opt(wrapped)
        self.assertEqual(res_exp, res_act)

    def test_tensor_subclass_custom_attr(self):
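        # A plain (non-tensor) class attribute on a traceable subclass should
        # be readable inside the compiled region without a graph break
        # (fullgraph=True below enforces this).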
        class AttrSubclass(torch.Tensor):
            x: int = 10

            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}

                return super().__torch_function__(func, types, args, kwargs)

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return x.x + torch.ones(2, 2)

        with traceable_subclass(AttrSubclass):
            input = torch.ones(2, 2).as_subclass(AttrSubclass)
            fn_opt = compile_full_eager(fn)

            res_exp = fn(input)
            res_act = fn_opt(input)
            self.assertEqual(res_exp, res_act)

    def test_compile_with_fake_tensor_dynamic_dim(self):
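        # Compile directly on FakeTensor inputs built under a fresh ShapeEnv
        # and check how many frames/ops Dynamo compiles for each DimDynamic
        # policy.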
        x = torch.randn([3, 4])

        def f(x):
            return torch.sin(x)

        def test_dynamic_dim(f, x, dim_dynamic, exp_frame_count, exp_op_count):
            torch._dynamo.reset()
            cnt = torch._dynamo.testing.CompileCounter()

            opt_f = torch.compile(f, backend=cnt, fullgraph=True)

            x1 = torch.rand_like(x)
            f(x)
            f(torch.randn([4, 3]))
            shape_env = ShapeEnv()
            with torch._subclasses.fake_tensor.FakeTensorMode(
                shape_env=shape_env
            ) as fake_mode:
                x_fake = fake_mode.from_tensor(
                    x,
                    symbolic_context=StatelessSymbolicContext(
                        dynamic_sizes=[dim_dynamic for _ in range(x.dim())]
                    ),
                )
                x1_fake = fake_mode.from_tensor(
                    x1,
                    symbolic_context=StatelessSymbolicContext(
                        dynamic_sizes=[dim_dynamic for _ in range(x.dim())]
                    ),
                )
                opt_f(x_fake)
                opt_f(x1_fake)

            self.assertEqual(cnt.frame_count, exp_frame_count)
            self.assertEqual(cnt.op_count, exp_op_count)

        test_dynamic_dim(f, x, DimDynamic.DYNAMIC, 1, 1)
        test_dynamic_dim(f, x, DimDynamic.DUCK, 1, 1)
        test_dynamic_dim(f, x, DimDynamic.STATIC, 1, 1)

    def test_compile_with_fake_tensor_automatic_dynamic(self):
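        # Same setup as the test above, but feeding a sequence of
        # differently-shaped fake inputs to exercise automatic dynamic-shape
        # recompilation.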
        def f(x):
            return torch.sin(x)

        def test_automatic_dynamic(f, inps, dim_dynamic, exp_frame_count, exp_op_count):
            torch._dynamo.reset()
            cnt = torch._dynamo.testing.CompileCounter()
            opt_f = torch.compile(f, backend=cnt, fullgraph=True)

            shape_env = ShapeEnv()
            with torch._subclasses.fake_tensor.FakeTensorMode(
                shape_env=shape_env
            ) as fake_mode:
                for inp in inps:
                    fake_inp = fake_mode.from_tensor(
                        inp,
                        symbolic_context=StatelessSymbolicContext(
                            dynamic_sizes=[dim_dynamic for _ in range(inp.dim())]
                        ),
                    )
                    opt_f(fake_inp)
            self.assertEqual(cnt.frame_count, exp_frame_count)
            self.assertEqual(cnt.op_count, exp_op_count)

        x = torch.randn([3, 4])
        y = torch.randn([4, 5])
        z = torch.randn([5, 6])
        a = torch.randn([3, 5])
        b = torch.randn([4, 4])
        # When an input's DimDynamic is DYNAMIC or DUCK, the inputs to opt_f
        # are tensors with SymInt sizes. Dynamo treats such inputs as dynamic
        # automatically and compiles only once.
        for dim_dynamic in [DimDynamic.DYNAMIC, DimDynamic.DUCK]:
            test_automatic_dynamic(f, [x, y, z], dim_dynamic, 1, 1)
            test_automatic_dynamic(f, [x, a, z], dim_dynamic, 1, 1)
            test_automatic_dynamic(f, [x, b, z], dim_dynamic, 1, 1)

        for dim_dynamic in [DimDynamic.STATIC]:
            # Recompiles once, with dims 0 and 1 becoming dynamic together.
            test_automatic_dynamic(f, [x, y, z], dim_dynamic, 2, 2)
            # Recompiles twice: first dim 1 becomes dynamic, then dim 0 does.
            test_automatic_dynamic(f, [x, a, z], dim_dynamic, 3, 3)
            # Recompiles twice: first dim 0 becomes dynamic, then dim 1 does.
            test_automatic_dynamic(f, [x, b, z], dim_dynamic, 3, 3)

    def test_compile_with_functionalization(self):
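        # Compile a mutating function three ways: eagerly, under
        # torch.func.functionalize, and under the manual functionalization
        # wrapper defined below. Each run recompiles, but all three should
        # record the same graph.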
        x = torch.randn([3, 4])
        x_clone = x.clone()
        x_clone2 = x.clone()
        backend = EagerRecordGraphAndInputs()
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            return x.add_(1.0) + torch.nn.functional.relu_(x)

        f_out = f(x)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 3)
        self.assertEqual(len(backend.graphs), 1)
        self.assertEqual(len(backend.example_inputs), 1)

        actual = normalize_gm(backend.graphs[0].print_readable(print_output=False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 4]"):
        l_x_ = L_x_

        add_: "f32[3, 4]" = l_x_.add_(1.0)
        relu_: "f32[3, 4]" = torch.relu_(l_x_);  l_x_ = None
        add: "f32[3, 4]" = add_ + relu_;  add_ = relu_ = None
        return (add,)
""",
        )

        ff = torch.func.functionalize(f)
        ff_out = ff(x_clone)

        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 6)
        self.assertEqual(len(backend.graphs), 2)
        self.assertEqual(len(backend.example_inputs), 2)
        actual = normalize_gm(backend.graphs[1].print_readable(print_output=False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 4]"):
        l_x_ = L_x_

        add_: "f32[3, 4]" = l_x_.add_(1.0)
        relu_: "f32[3, 4]" = torch.relu_(l_x_);  l_x_ = None
        add: "f32[3, 4]" = add_ + relu_;  add_ = relu_ = None
        return (add,)
""",
        )
        self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))

        # Cannot re-use the version from AOTAutograd, since that uses python functional tensors.
        def to_fun(x):
            x_functional = torch._to_functional_tensor(x)
            torch._mirror_autograd_meta_to(x, x_functional)
            return x_functional

        def aot_f_wrapper(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                torch._enable_functionalization(reapply_views=False)
                try:
                    func_args = pytree.tree_map(to_fun, args)
                    func_kwargs = pytree.tree_map(to_fun, kwargs)
                    return func(*func_args, **func_kwargs)
                finally:
                    torch._disable_functionalization()

            return wrapper

        aot_ff = aot_f_wrapper(f)
        aot_ff_out = aot_ff(x_clone2)

        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 9)
        self.assertEqual(len(backend.graphs), 3)
        self.assertEqual(len(backend.example_inputs), 3)
        actual = normalize_gm(backend.graphs[2].print_readable(print_output=False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 4]"):
        l_x_ = L_x_

        add_: "f32[3, 4]" = l_x_.add_(1.0)
        relu_: "f32[3, 4]" = torch.relu_(l_x_);  l_x_ = None
        add: "f32[3, 4]" = add_ + relu_;  add_ = relu_ = None
        return (add,)
""",
        )
        self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))

        self.assertEqual(f_out, ff_out)
        self.assertEqual(f_out, aot_ff_out)

        try:
            torch._enable_functionalization(reapply_views=False)
            xf = pytree.tree_map(to_fun, x)
            x_view = xf.t()
            with self.assertRaisesRegex(RuntimeError, "Cannot safely fakify a view"):
                f(x_view)
        finally:
            torch._disable_functionalization()

    def test_compile_higher_order_with_functionalization(self):
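        # Like the test above, but the mutation happens inside a higher-order
        # wrap() call, so functionalization has to work through the HOP body.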
        backend = EagerRecordGraphAndInputs()
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            return wrap(lambda x: x.add_(1.0), x)

        def check_count_and_graph(
            exp_frame_count, exp_op_count, exp_n_graph, exp_graph
        ):
            self.assertEqual(cnt.frame_count, exp_frame_count)
            self.assertEqual(cnt.op_count, exp_op_count)
            self.assertEqual(len(backend.graphs), exp_n_graph)
            actual = normalize_gm(
                backend.graphs[exp_n_graph - 1].print_readable(print_output=False)
            )
            self.assertExpectedInline(actual, exp_graph, skip=1)

        t = torch.randn([3, 4])
        t_clone = t.clone()
        t_clone2 = t.clone()
        f(t)

        check_count_and_graph(
            1,
            2,
            1,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 4]"):
        l_x_ = L_x_

        wrap_body_0 = self.wrap_body_0
        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_);  wrap_body_0 = l_x_ = None
        getitem: "f32[3, 4]" = wrap[0];  wrap = None
        return (getitem,)

    class wrap_body_0(torch.nn.Module):
        def forward(self, l_x_: "f32[3, 4]"):
            add_: "f32[3, 4]" = l_x_.add_(1.0);  l_x_ = None
            return (add_,)
""",
        )

        ff = torch.func.functionalize(f)
        ff_out = ff(t_clone)
        # frame count and op count are incremented due to re-compilation
        check_count_and_graph(
            2,
            4,
            2,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 4]"):
        l_x_ = L_x_

        wrap_body_0 = self.wrap_body_0
        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_);  wrap_body_0 = l_x_ = None
        getitem: "f32[3, 4]" = wrap[0];  wrap = None
        return (getitem,)

    class wrap_body_0(torch.nn.Module):
        def forward(self, l_x_: "f32[3, 4]"):
            add_: "f32[3, 4]" = l_x_.add_(1.0);  l_x_ = None
            return (add_,)
""",
        )

        try:
            x = torch._to_functional_tensor(t_clone2)
            torch._mirror_autograd_meta_to(t_clone2, x)
            torch._enable_functionalization(reapply_views=False)
            aot_f_out = f(x)
        finally:
            torch._disable_functionalization()

        # frame count and op count are incremented due to re-compilation
        check_count_and_graph(
            3,
            6,
            3,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 4]"):
        l_x_ = L_x_

        wrap_body_0 = self.wrap_body_0
        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_);  wrap_body_0 = l_x_ = None
        getitem: "f32[3, 4]" = wrap[0];  wrap = None
        return (getitem,)

    class wrap_body_0(torch.nn.Module):
        def forward(self, l_x_: "f32[3, 4]"):
            add_: "f32[3, 4]" = l_x_.add_(1.0);  l_x_ = None
            return (add_,)
""",
        )

    def test_has_torch_function(self):
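        # has_torch_function_unary/variadic should produce the same answers
        # under compile as in eager, both for objects implementing
        # __torch_function__ (tensor subclass or not) and for plain ints.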
        class MyTensor:
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}

                if func is torch.max:
                    return torch.tensor(123)
                return func(*args, **kwargs)

        class LocalSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return func(*args, **kwargs)

        def fn(x):
            return torch.overrides.has_torch_function_unary(
                x
            ), torch.overrides.has_torch_function_variadic(x)

        for test_class in [MyTensor, LocalSubclass]:
            x = test_class()
            ref0 = fn(x)
            ref1 = fn(4)
            opt_fn = torch._dynamo.optimize("eager")(fn)
            res0 = opt_fn(x)
            res1 = opt_fn(4)
            self.assertEqual(ref0, res0)
            self.assertEqual(ref1, res1)

    def test_wrapper_subclass_guards_on_inner_tensor(self):
        # Holds an inner tensor that has a distinct shape from the outer wrapper tensor.
        # Also adds additional guards on the inner tensor's sizes.
        # When the first input to an op has an inner shape[0] > 3, we insert an extra add node.
        class DoubleSizeMaybeAddGeThreeTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, inner):
                # Double the outer-most dimension
                outer_shape = (inner.shape[0] * 2,) + inner.shape[1:]
                return torch.Tensor._make_wrapper_subclass(
                    # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
                    # Calling the overload that has kwargs causes us to go down the first overload path,
                    # which will **always** specialize sizes.
                    # We should probably eventually fix this so that the first overload can just handle dynamic shapes.
                    cls,
                    outer_shape,
                    inner.stride(),
                    None,
                    None,
                    inner.dtype,
                    inner.layout,
                    inner.device,
                    False,
                    inner.requires_grad,
                )

            def __init__(self, inner):
                self.inner_elem = inner

            def __tensor_flatten__(self):
                return ["inner_elem"], None

            @staticmethod
            def __tensor_unflatten__(inner_tensors, _, outer_size, outer_stride):
                return DoubleSizeMaybeAddGeThreeTensor(inner_tensors["inner_elem"])

            def __repr__(self):
                return f"DoubleSizeMaybeAddGeThreeTensor({repr(self.inner_elem)})"

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}

                args_inner = torch.utils._pytree.tree_map_only(
                    DoubleSizeMaybeAddGeThreeTensor, lambda x: x.inner_elem, args
                )
                out_inner = func(*args_inner, **kwargs)

                # Add guards on the inner tensor's sizes
                if args_inner[0].shape[0] > 3:
                    out_inner += 2

                return DoubleSizeMaybeAddGeThreeTensor(out_inner)

        curr_var_to_val = None
        curr_var_to_sources = None
        guards = None

        def backend(gm, args):
            context = torch._guards.TracingContext.get()

            # Grab info on sources and guards from the shapeenv
            nonlocal curr_var_to_val
            nonlocal curr_var_to_sources
            nonlocal guards

            guards = [str(g.expr) for g in context.fake_mode.shape_env.guards]
            curr_var_to_val = {
                str(k): v for k, v in context.fake_mode.shape_env.var_to_val.items()
            }
            curr_var_to_sources = {
                str(k): v[0].name()
                for k, v in context.fake_mode.shape_env.var_to_sources.items()
            }
            return gm

        @torch.compile(backend=backend)
        def fn(x):
            if x.shape[0] < 13:
                return torch.mul(x, x)
            else:
                return torch.div(x, x)

        inp = torch.ones(4, 4)

        x = DoubleSizeMaybeAddGeThreeTensor(inp)
        torch._dynamo.mark_dynamic(x, 0)
        res = fn(x)
        # During fakeifying, we end up allocating a separate symint
        # for the outer and inner tensor (in this test, s0 is unused).
        expected_var_to_val = {
            "s0": 8,
            "s1": 4,
        }
        expected_var_to_sources = {
            "s0": "L['x'].size()[0]",
            "s1": "L['x'].inner_elem.size()[0]",
        }
        self.assertEqual(curr_var_to_val, expected_var_to_val)
        self.assertEqual(curr_var_to_sources, expected_var_to_sources)
        self.assertExpectedInline(
            "\n".join(guards),
            """\
Eq(2*s1, s0)
2*s1 < 13
s1 > 3""",
        )

    def test_wrapper_subclass_with_same_sized_inner_tensor(self):
        # shouldn't recompile for different sizes when dynamic=True
        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
        sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(7))
        self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=True))

        # should recompile for different data size when dynamic=False
        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
        sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6))
        self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))

        # avoid recompile using manual mark_dynamic() for different data size
        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
        # NB: mark_dynamic() on outer tensor should translate to inner tensors of the same size
        torch._dynamo.mark_dynamic(sub1, 0)
        torch._dynamo.mark_dynamic(sub1, 1)
        sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6))
        self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))

    def test_wrapper_subclass_with_differently_sized_inner_tensor(self):
        # should recompile for different scale size when dynamic=False
        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3))
        sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5))
        self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))

        # still recompiles using manual mark_dynamic() on outer for different scale size
        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3))
        # NB: mark_dynamic() on outer tensor doesn't translate to inner tensors of different size
        torch._dynamo.mark_dynamic(sub1, 0)
        torch._dynamo.mark_dynamic(sub1, 1)
        sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5))
        self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))

    def test_recompiles_with_optional_inner_tensor(self):
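        # OptionalScaledTensor (defined earlier in this file) holds an inner
        # tensor that may be None; toggling between None and a real tensor
        # changes the subclass structure and must trigger a recompile.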
        def f(x):
            return x + 1

        # sub1 does not have the optional tensor specified while sub2 does
        sub1 = OptionalScaledTensor(torch.randn(2, 4), None)
        sub2 = OptionalScaledTensor(torch.randn(2, 4), torch.randn(2, 4))

        # sanity check; don't recompile for same input
        self.assertFalse(_recompiles_for_inputs(f, (sub1,), (sub1,), dynamic=True))
        self.assertFalse(_recompiles_for_inputs(f, (sub2,), (sub2,), dynamic=True))

        # these should recompile; optional tensor changes between specified and unspecified
        self.assertTrue(_recompiles_for_inputs(f, (sub1,), (sub2,), dynamic=True))
        self.assertTrue(_recompiles_for_inputs(f, (sub2,), (sub1,), dynamic=True))

        f_compiled = torch.compile(f, backend="aot_eager")
        self.assertEqual(f(sub1)._data, f_compiled(sub1)._data)
        self.assertEqual(f(sub2)._data, f_compiled(sub2)._data)

    def test_torch_dispatch_subclass_guard_recompile(self):
        x = torch.ones(2, 2)
        x_two = TwoTensor(x.clone(), x.clone())

        def fn(w):
            return torch.add(w, 1.0)

        fn_opt = torch.compile(backend="eager")(fn)

        ref = fn(x_two)
        res = fn_opt(x_two)
        self.assertEqual(ref, res)

        # ensure no recompilation on same input type
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn_opt(TwoTensor(x + 1, x + 2))

        # recompile!
        ref = fn(x)
        res = fn_opt(x)
        self.assertEqual(ref, res)

    def test_tensor_subclass_ctx_guards(self):
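        # CtxSubclassTensor (defined earlier in this file) presumably carries
        # its constant in its flatten context; a matching constant should hit
        # the cache while a changed constant recompiles.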
        x = CtxSubclassTensor(torch.ones(2), 3)
        x2 = CtxSubclassTensor(torch.ones(2), 3)
        x3 = CtxSubclassTensor(torch.ones(2), 4)
        _check_recompiles(self, lambda x: x * x, (x,), (x2,), False)
        _check_recompiles(self, lambda x: x * x, (x,), (x3,), True)

    def test_tensor_subclass_ctx_recursive_guards(self):
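        # The metadata guard should apply recursively when the subclass with a
        # context is nested inside another wrapper subclass (TwoTensor here).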
        x0 = torch.ones(2, 2)
        x1 = CtxSubclassTensor(x0.clone(), 2)
        x2 = CtxSubclassTensor(x0.clone(), 3)
        tt0 = TwoTensor(x0.clone(), x1)
        tt1 = TwoTensor(x0.clone(), x2)

        _check_recompiles(self, lambda x: x * x, (tt0,), (tt1,), True)

    def test_tensor_subclass_ctx_custom_guards_override(self):
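        # __metadata_guard__ lets a subclass relax the default metadata
        # equality check; here the guard passes whenever the stored constant
        # does not decrease, so 2 -> 3 is cached but 2 -> 1 recompiles.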
        class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor):
            @classmethod
            def __metadata_guard__(cls, orig_data, other):
                return orig_data[0] <= other[0]

        x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 2)
        x2 = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3)
        x3 = CtxSubclassTensorCustomGuardFn(torch.ones(2), 1)
        _check_recompiles(self, lambda x: x * x, (x,), (x2,), False)
        _check_recompiles(self, lambda x: x * x, (x,), (x3,), True)

    def test_tensor_subclass_ctx_custom_guards_error_arg_num(self):
        import torch._dynamo.exc

        class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor):
            @classmethod
            def __metadata_guard__(cls, y):
                # Shouldn't reach here
                return False

        x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3)
        self.assertRaisesRegex(
            torch._dynamo.exc.InternalTorchDynamoError,
            "Tensor subclass method __metadata_guard__ must take exactly two subclass metadata arguments",
            lambda: torch.compile(lambda x: x * x)(x),
        )

    def test_tensor_subclass_ctx_custom_guards_error_not_classmethod(self):
        import torch._dynamo.exc

        class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor):
            def __metadata_guard__(self, x, y):
                return False

        x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3)
        self.assertRaisesRegex(
            torch._dynamo.exc.InternalTorchDynamoError,
            "Tensor subclass method __metadata_guard__ must be a classmethod",
            lambda: torch.compile(lambda x: x * x)(x),
        )

    def test_subclass_constructor_proxying(self):
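        # Constructing wrapper subclasses inside the compiled region, with
        # non-tensor metadata (a frozen dataclass or a namedtuple), should be
        # proxied into the graph correctly.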
        import dataclasses
        from collections import namedtuple
        from typing import Any

        @dataclasses.dataclass(frozen=True)
        class SubclassTensorArgs:
            original_shape: torch.Size
            device: torch.device
            inner_meta: Any

        SubclassTensorArgs2 = namedtuple(
            "SubclassTensorArgs2",
            [
                "original_shape",
                "device",
                "inner_meta",
            ],
        )

        class SubclassTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, a, meta):
                shape = a.shape
                kwargs = {}
                kwargs["strides"] = a.stride()
                kwargs["storage_offset"] = a.storage_offset()
                kwargs["device"] = a.device
                kwargs["layout"] = a.layout
                kwargs["requires_grad"] = a.requires_grad
                kwargs["dtype"] = a.dtype
                out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
                return out

            def __init__(self, a, meta):
                self.a = a
                self.meta = meta

            def __repr__(self):
                a_repr = repr(self.a)
                return f"SubclassTensor({a_repr})"

            def __tensor_flatten__(self):
                return ["a"], self.meta

            @staticmethod
            def __tensor_unflatten__(inner_tensors, meta, _, __):
                a = inner_tensors["a"]
                return SubclassTensor(a, meta)

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if kwargs is None:
                    kwargs = {}
                args_a = pytree.tree_map(
                    lambda x: x.a if isinstance(x, SubclassTensor) else x, args
                )
                kwargs_a = pytree.tree_map(
                    lambda x: x.a if isinstance(x, SubclassTensor) else x, kwargs
                )
                out_a = func(*args_a, **kwargs_a)
                out = pytree.tree_map(
                    lambda x: SubclassTensor(
                        x, SubclassTensorArgs2(x.shape, x.device, None)
                    )
1496*da0073e9SAndroid Build Coastguard Worker                    if isinstance(x, torch.Tensor)
1497*da0073e9SAndroid Build Coastguard Worker                    else x,
1498*da0073e9SAndroid Build Coastguard Worker                    out_a,
1499*da0073e9SAndroid Build Coastguard Worker                )
1500*da0073e9SAndroid Build Coastguard Worker                return return_and_correct_aliasing(func, args, kwargs, out)
1501*da0073e9SAndroid Build Coastguard Worker
1502*da0073e9SAndroid Build Coastguard Worker        @torch.compile(fullgraph=True)
1503*da0073e9SAndroid Build Coastguard Worker        def f1(x):
1504*da0073e9SAndroid Build Coastguard Worker            meta = SubclassTensorArgs(
1505*da0073e9SAndroid Build Coastguard Worker                x.shape, x.device, SubclassTensorArgs(x.shape, x.device, None)
1506*da0073e9SAndroid Build Coastguard Worker            )
1507*da0073e9SAndroid Build Coastguard Worker            out = SubclassTensor(x, meta)
1508*da0073e9SAndroid Build Coastguard Worker            return out * out
1509*da0073e9SAndroid Build Coastguard Worker
1510*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3)
1511*da0073e9SAndroid Build Coastguard Worker        f1(x)
1512*da0073e9SAndroid Build Coastguard Worker
1513*da0073e9SAndroid Build Coastguard Worker        @torch.compile(fullgraph=True)
1514*da0073e9SAndroid Build Coastguard Worker        def f1(x):
1515*da0073e9SAndroid Build Coastguard Worker            meta = SubclassTensorArgs2(
1516*da0073e9SAndroid Build Coastguard Worker                x.shape, x.device, SubclassTensorArgs2(x.shape, x.device, None)
1517*da0073e9SAndroid Build Coastguard Worker            )
1518*da0073e9SAndroid Build Coastguard Worker            out = SubclassTensor(x, meta)
1519*da0073e9SAndroid Build Coastguard Worker            return out * out
1520*da0073e9SAndroid Build Coastguard Worker
1521*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3)
1522*da0073e9SAndroid Build Coastguard Worker        f1(x)
1523*da0073e9SAndroid Build Coastguard Worker
1524*da0073e9SAndroid Build Coastguard Worker    def test_torch_function_subclass_survives_into_aot_autograd(self):
1525*da0073e9SAndroid Build Coastguard Worker        # If you have a tensor subclass that relies on dispatch into the same op
1526*da0073e9SAndroid Build Coastguard Worker        # without unwrapping and calling torch._C.DisableTorchFunctionSubclass(),
1527*da0073e9SAndroid Build Coastguard Worker        # the torch function-ness will survive into AOTAutograd. Today, NestedTensor
1528*da0073e9SAndroid Build Coastguard Worker        # actually relies on this behavior! Because that torch function logic
1529*da0073e9SAndroid Build Coastguard Worker        # runs during AOTAutograd, this test verifies that no logic below
1530*da0073e9SAndroid Build Coastguard Worker        # relies on torch function behavior that gets unexpectedly disabled after we
1531*da0073e9SAndroid Build Coastguard Worker        # redispatch from the subclass's torch function.
1532*da0073e9SAndroid Build Coastguard Worker        class SubTensor(torch.Tensor):
1533*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1534*da0073e9SAndroid Build Coastguard Worker            def __new__(cls, t):
1535*da0073e9SAndroid Build Coastguard Worker                return torch.Tensor._make_wrapper_subclass(
1536*da0073e9SAndroid Build Coastguard Worker                    cls,
1537*da0073e9SAndroid Build Coastguard Worker                    t.shape,
1538*da0073e9SAndroid Build Coastguard Worker                    t.stride(),
1539*da0073e9SAndroid Build Coastguard Worker                    t.storage_offset(),
1540*da0073e9SAndroid Build Coastguard Worker                    torch.contiguous_format,
1541*da0073e9SAndroid Build Coastguard Worker                    t.dtype,
1542*da0073e9SAndroid Build Coastguard Worker                    torch.strided,
1543*da0073e9SAndroid Build Coastguard Worker                    t.device,
1544*da0073e9SAndroid Build Coastguard Worker                    False,
1545*da0073e9SAndroid Build Coastguard Worker                    t.requires_grad,
1546*da0073e9SAndroid Build Coastguard Worker                    "sizes",
1547*da0073e9SAndroid Build Coastguard Worker                    False,
1548*da0073e9SAndroid Build Coastguard Worker                    False,
1549*da0073e9SAndroid Build Coastguard Worker                    None,
1550*da0073e9SAndroid Build Coastguard Worker                )
1551*da0073e9SAndroid Build Coastguard Worker
1552*da0073e9SAndroid Build Coastguard Worker            def __init__(self, t):
1553*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1554*da0073e9SAndroid Build Coastguard Worker                self._t = t
1555*da0073e9SAndroid Build Coastguard Worker
1556*da0073e9SAndroid Build Coastguard Worker            def __tensor_flatten__(self):
1557*da0073e9SAndroid Build Coastguard Worker                return ["_t"], {}
1558*da0073e9SAndroid Build Coastguard Worker
1559*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1560*da0073e9SAndroid Build Coastguard Worker            def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
1561*da0073e9SAndroid Build Coastguard Worker                t = inner_tensors["_t"]
1562*da0073e9SAndroid Build Coastguard Worker                return SubTensor(t)
1563*da0073e9SAndroid Build Coastguard Worker
1564*da0073e9SAndroid Build Coastguard Worker            def __repr__(self):
1565*da0073e9SAndroid Build Coastguard Worker                return f"SubTensor({self._t})"
1566*da0073e9SAndroid Build Coastguard Worker
1567*da0073e9SAndroid Build Coastguard Worker            @classmethod
1568*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(cls, func, types, args=(), kwargs=None):
1569*da0073e9SAndroid Build Coastguard Worker                if kwargs is None:
1570*da0073e9SAndroid Build Coastguard Worker                    kwargs = {}
1571*da0073e9SAndroid Build Coastguard Worker
1572*da0073e9SAndroid Build Coastguard Worker                with torch._C.DisableTorchFunctionSubclass():
1573*da0073e9SAndroid Build Coastguard Worker                    return func(*args, **kwargs)
1574*da0073e9SAndroid Build Coastguard Worker
1575*da0073e9SAndroid Build Coastguard Worker            @classmethod
1576*da0073e9SAndroid Build Coastguard Worker            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
1577*da0073e9SAndroid Build Coastguard Worker                kwargs = {} if kwargs is None else kwargs
1578*da0073e9SAndroid Build Coastguard Worker                new_args = pytree.tree_map_only(SubTensor, lambda s: s._t, args)
1579*da0073e9SAndroid Build Coastguard Worker                output = func(*new_args, **kwargs)
1580*da0073e9SAndroid Build Coastguard Worker                output = pytree.tree_map_only(
1581*da0073e9SAndroid Build Coastguard Worker                    torch.Tensor, lambda t: SubTensor(t), output
1582*da0073e9SAndroid Build Coastguard Worker                )
1583*da0073e9SAndroid Build Coastguard Worker                return output
1584*da0073e9SAndroid Build Coastguard Worker
1585*da0073e9SAndroid Build Coastguard Worker        @torch.compile(dynamic=True)
1586*da0073e9SAndroid Build Coastguard Worker        def f(x):
1587*da0073e9SAndroid Build Coastguard Worker            return x.unflatten(-1, [2, 5])
1588*da0073e9SAndroid Build Coastguard Worker
1589*da0073e9SAndroid Build Coastguard Worker        s = SubTensor(torch.randn(3, 10))
1590*da0073e9SAndroid Build Coastguard Worker        f(s)
1591*da0073e9SAndroid Build Coastguard Worker
1592*da0073e9SAndroid Build Coastguard Worker    # Guard validation upsets the guard
1593*da0073e9SAndroid Build Coastguard Worker    # https://github.com/pytorch/pytorch/issues/129936
1594*da0073e9SAndroid Build Coastguard Worker    @unittest.expectedFailure
1595*da0073e9SAndroid Build Coastguard Worker    def test_recompile_with_symbool_inputs(self):
1596*da0073e9SAndroid Build Coastguard Worker        def f(pred: bool):
1597*da0073e9SAndroid Build Coastguard Worker            if pred:
1598*da0073e9SAndroid Build Coastguard Worker                return torch.ones([3, 4])
1599*da0073e9SAndroid Build Coastguard Worker            else:
1600*da0073e9SAndroid Build Coastguard Worker                return torch.ones([4, 3])
1601*da0073e9SAndroid Build Coastguard Worker
1602*da0073e9SAndroid Build Coastguard Worker        def test_recompilation(
1603*da0073e9SAndroid Build Coastguard Worker            f, x, sizes, exp_graphs, exp_frame_count, exp_shape_env_guards
1604*da0073e9SAndroid Build Coastguard Worker        ):
1605*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.reset()
1606*da0073e9SAndroid Build Coastguard Worker            shape_env = ShapeEnv()
1607*da0073e9SAndroid Build Coastguard Worker            backend = torch._dynamo.testing.EagerAndRecordGraphs()
1608*da0073e9SAndroid Build Coastguard Worker            cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
1609*da0073e9SAndroid Build Coastguard Worker            f_cond = torch.compile(f, backend=cnt, fullgraph=True)
1610*da0073e9SAndroid Build Coastguard Worker            with torch._subclasses.fake_tensor.FakeTensorMode(
1611*da0073e9SAndroid Build Coastguard Worker                shape_env=shape_env
1612*da0073e9SAndroid Build Coastguard Worker            ) as fake_mode:
1613*da0073e9SAndroid Build Coastguard Worker                fake_inp = fake_mode.from_tensor(
1614*da0073e9SAndroid Build Coastguard Worker                    x,
1615*da0073e9SAndroid Build Coastguard Worker                    symbolic_context=StatelessSymbolicContext(
1616*da0073e9SAndroid Build Coastguard Worker                        dynamic_sizes=[DimDynamic.DYNAMIC for _ in range(x.dim())]
1617*da0073e9SAndroid Build Coastguard Worker                    ),
1618*da0073e9SAndroid Build Coastguard Worker                )
1619*da0073e9SAndroid Build Coastguard Worker                for i, size in enumerate(sizes):
1620*da0073e9SAndroid Build Coastguard Worker                    pred = fake_inp.size(0) == size
1621*da0073e9SAndroid Build Coastguard Worker                    f_cond(pred)
1622*da0073e9SAndroid Build Coastguard Worker                    actual = normalize_gm(
1623*da0073e9SAndroid Build Coastguard Worker                        backend.graphs[exp_frame_count[i] - 1].print_readable(
1624*da0073e9SAndroid Build Coastguard Worker                            print_output=False
1625*da0073e9SAndroid Build Coastguard Worker                        )
1626*da0073e9SAndroid Build Coastguard Worker                    )
1627*da0073e9SAndroid Build Coastguard Worker                    actual_guard_str = [str(guard.expr) for guard in shape_env.guards]
1628*da0073e9SAndroid Build Coastguard Worker                    self.assertExpectedInline(actual, exp_graphs[i])
1629*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(cnt.frame_count, exp_frame_count[i])
1630*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(actual_guard_str, exp_shape_env_guards[i])
1631*da0073e9SAndroid Build Coastguard Worker
1632*da0073e9SAndroid Build Coastguard Worker        true_graph = """\
1633*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module):
1634*da0073e9SAndroid Build Coastguard Worker    def forward(self):
1635*da0073e9SAndroid Build Coastguard Worker        ones: "f32[3, 4]" = torch.ones([3, 4])
1636*da0073e9SAndroid Build Coastguard Worker        return (ones,)
1637*da0073e9SAndroid Build Coastguard Worker"""
1638*da0073e9SAndroid Build Coastguard Worker        false_graph = """\
1639*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module):
1640*da0073e9SAndroid Build Coastguard Worker    def forward(self):
1641*da0073e9SAndroid Build Coastguard Worker        ones: "f32[4, 3]" = torch.ones([4, 3])
1642*da0073e9SAndroid Build Coastguard Worker        return (ones,)
1643*da0073e9SAndroid Build Coastguard Worker"""
1644*da0073e9SAndroid Build Coastguard Worker        test_recompilation(
1645*da0073e9SAndroid Build Coastguard Worker            f,
1646*da0073e9SAndroid Build Coastguard Worker            torch.randn([3, 4]),
1647*da0073e9SAndroid Build Coastguard Worker            [3, 3, 4, 5],
1648*da0073e9SAndroid Build Coastguard Worker            exp_graphs=[true_graph, true_graph, false_graph, false_graph],
1649*da0073e9SAndroid Build Coastguard Worker            exp_frame_count=[1, 1, 2, 2],
1650*da0073e9SAndroid Build Coastguard Worker            exp_shape_env_guards=[
1651*da0073e9SAndroid Build Coastguard Worker                [],
1652*da0073e9SAndroid Build Coastguard Worker                # s0 is specialized and guarded in the outer shape_env when dynamo checks the guards
1653*da0073e9SAndroid Build Coastguard Worker                ["Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)"],
1654*da0073e9SAndroid Build Coastguard Worker                [
1655*da0073e9SAndroid Build Coastguard Worker                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
1656*da0073e9SAndroid Build Coastguard Worker                    "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
1657*da0073e9SAndroid Build Coastguard Worker                ],
1658*da0073e9SAndroid Build Coastguard Worker                [
1659*da0073e9SAndroid Build Coastguard Worker                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
1660*da0073e9SAndroid Build Coastguard Worker                    "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
1661*da0073e9SAndroid Build Coastguard Worker                    "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
1662*da0073e9SAndroid Build Coastguard Worker                ],
1663*da0073e9SAndroid Build Coastguard Worker            ],
1664*da0073e9SAndroid Build Coastguard Worker        )
1665*da0073e9SAndroid Build Coastguard Worker
1666*da0073e9SAndroid Build Coastguard Worker        test_recompilation(
1667*da0073e9SAndroid Build Coastguard Worker            f,
1668*da0073e9SAndroid Build Coastguard Worker            torch.randn([3, 4]),
1669*da0073e9SAndroid Build Coastguard Worker            [4, 5, 3, 3],
1670*da0073e9SAndroid Build Coastguard Worker            exp_graphs=[false_graph, false_graph, true_graph, true_graph],
1671*da0073e9SAndroid Build Coastguard Worker            exp_frame_count=[1, 1, 2, 2],
1672*da0073e9SAndroid Build Coastguard Worker            exp_shape_env_guards=[
1673*da0073e9SAndroid Build Coastguard Worker                [],
1674*da0073e9SAndroid Build Coastguard Worker                # s0 is specialized and guarded in the outer shape_env when dynamo checks the guards
1675*da0073e9SAndroid Build Coastguard Worker                ["Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)"],
1676*da0073e9SAndroid Build Coastguard Worker                [
1677*da0073e9SAndroid Build Coastguard Worker                    "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
1678*da0073e9SAndroid Build Coastguard Worker                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
1679*da0073e9SAndroid Build Coastguard Worker                ],
1680*da0073e9SAndroid Build Coastguard Worker                [
1681*da0073e9SAndroid Build Coastguard Worker                    "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
1682*da0073e9SAndroid Build Coastguard Worker                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
1683*da0073e9SAndroid Build Coastguard Worker                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
1684*da0073e9SAndroid Build Coastguard Worker                ],
1685*da0073e9SAndroid Build Coastguard Worker            ],
1686*da0073e9SAndroid Build Coastguard Worker        )
1687*da0073e9SAndroid Build Coastguard Worker
1688*da0073e9SAndroid Build Coastguard Worker    def test_wrapper_subclass_dynamo_attribute_access_on_intermediate(self):
1689*da0073e9SAndroid Build Coastguard Worker        def f(x_subclass):
1690*da0073e9SAndroid Build Coastguard Worker            tmp_subclass = torch.add(x_subclass, 1)
1691*da0073e9SAndroid Build Coastguard Worker            return torch.mul(tmp_subclass._scale, tmp_subclass._constant)
1692*da0073e9SAndroid Build Coastguard Worker
1693*da0073e9SAndroid Build Coastguard Worker        x = ScaledTensor(torch.randn(2, 4), torch.randn(3), constant=2)
1694*da0073e9SAndroid Build Coastguard Worker        out_ref = f(x)
1695*da0073e9SAndroid Build Coastguard Worker        out_test = torch.compile(f, backend="aot_eager", fullgraph=True)(x)
1696*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, out_test)
1697*da0073e9SAndroid Build Coastguard Worker
1698*da0073e9SAndroid Build Coastguard Worker    def test_support_bases(self):
1699*da0073e9SAndroid Build Coastguard Worker        import abc
1700*da0073e9SAndroid Build Coastguard Worker
1701*da0073e9SAndroid Build Coastguard Worker        import torch.fx._symbolic_trace
1702*da0073e9SAndroid Build Coastguard Worker
1703*da0073e9SAndroid Build Coastguard Worker        class Meta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta):
1704*da0073e9SAndroid Build Coastguard Worker            def __new__(cls, name, bases, dct):
1705*da0073e9SAndroid Build Coastguard Worker                x = super().__new__(cls, name, bases, dct)
1706*da0073e9SAndroid Build Coastguard Worker                x.attr = 100
1707*da0073e9SAndroid Build Coastguard Worker                return x
1708*da0073e9SAndroid Build Coastguard Worker
1709*da0073e9SAndroid Build Coastguard Worker        class Multistreamable(abc.ABC):  # noqa: B024
1710*da0073e9SAndroid Build Coastguard Worker            pass
1711*da0073e9SAndroid Build Coastguard Worker
1712*da0073e9SAndroid Build Coastguard Worker        class Foo(Multistreamable, metaclass=Meta):
1713*da0073e9SAndroid Build Coastguard Worker            pass
1714*da0073e9SAndroid Build Coastguard Worker
1715*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager", fullgraph=True)
1716*da0073e9SAndroid Build Coastguard Worker        def f(x):
1717*da0073e9SAndroid Build Coastguard Worker            typ = type(Foo())
1718*da0073e9SAndroid Build Coastguard Worker            typ.__bases__
1719*da0073e9SAndroid Build Coastguard Worker            return typ.__bases__
1720*da0073e9SAndroid Build Coastguard Worker
1721*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(f(torch.randn(1)), (Multistreamable,))
1722*da0073e9SAndroid Build Coastguard Worker
1723*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager", fullgraph=True)
1724*da0073e9SAndroid Build Coastguard Worker        def g(x):
1725*da0073e9SAndroid Build Coastguard Worker            typ = type(Foo())
1726*da0073e9SAndroid Build Coastguard Worker            typ.__base__
1727*da0073e9SAndroid Build Coastguard Worker            return typ.__base__
1728*da0073e9SAndroid Build Coastguard Worker
1729*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(g(torch.randn(1)), Multistreamable)
1730*da0073e9SAndroid Build Coastguard Worker
1731*da0073e9SAndroid Build Coastguard Worker    @parametrize("dynamic", [False, True])
1732*da0073e9SAndroid Build Coastguard Worker    def test_subclass_views(self, dynamic):
1733*da0073e9SAndroid Build Coastguard Worker        def _get_views(t):  # yields (view: Tensor, expects_raises: bool)
1734*da0073e9SAndroid Build Coastguard Worker            # Note that any closed-over SymInts will be symbolicized during fake-ification.
1735*da0073e9SAndroid Build Coastguard Worker            yield t.narrow(dim=-1, start=3, length=8), False
1736*da0073e9SAndroid Build Coastguard Worker            yield t.split(5, -1)[2], False
1737*da0073e9SAndroid Build Coastguard Worker            yield t.split_with_sizes([9, 6], -1)[1], False
1738*da0073e9SAndroid Build Coastguard Worker            yield t.unsqueeze(-1).expand(4, 15, 10), False
1739*da0073e9SAndroid Build Coastguard Worker            yield t.select(-1, 6), False
1740*da0073e9SAndroid Build Coastguard Worker            # https://github.com/pytorch/pytorch/issues/128649
1741*da0073e9SAndroid Build Coastguard Worker            yield t[2:3, 5:9], dynamic
1742*da0073e9SAndroid Build Coastguard Worker            yield t.view(-1, 15), False
1743*da0073e9SAndroid Build Coastguard Worker
1744*da0073e9SAndroid Build Coastguard Worker        def f(x):
1745*da0073e9SAndroid Build Coastguard Worker            return x * 2
1746*da0073e9SAndroid Build Coastguard Worker
1747*da0073e9SAndroid Build Coastguard Worker        compiled_f = torch.compile(
1748*da0073e9SAndroid Build Coastguard Worker            f, backend="aot_eager", fullgraph=True, dynamic=dynamic
1749*da0073e9SAndroid Build Coastguard Worker        )
1750*da0073e9SAndroid Build Coastguard Worker
1751*da0073e9SAndroid Build Coastguard Worker        # Take a view of a subclass to pass as input.
1752*da0073e9SAndroid Build Coastguard Worker        t = TwoTensor(torch.randn(4, 15), torch.randn(4, 15))
1753*da0073e9SAndroid Build Coastguard Worker        for view, expects_raises in _get_views(t):
1754*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.reset()
1755*da0073e9SAndroid Build Coastguard Worker            out_ref = f(view)
1756*da0073e9SAndroid Build Coastguard Worker            if expects_raises:
1757*da0073e9SAndroid Build Coastguard Worker                with self.assertRaises(AssertionError):
1758*da0073e9SAndroid Build Coastguard Worker                    out_test = compiled_f(view)
1759*da0073e9SAndroid Build Coastguard Worker            else:
1760*da0073e9SAndroid Build Coastguard Worker                out_test = compiled_f(view)
1761*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out_ref, out_test)
1762*da0073e9SAndroid Build Coastguard Worker
1763*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
1764*da0073e9SAndroid Build Coastguard Worker    def test_mark_static_with_subclass_desugaring(self):
1765*da0073e9SAndroid Build Coastguard Worker        from typing import Any, Callable, Dict, List, Optional
1766*da0073e9SAndroid Build Coastguard Worker
1767*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo.decorators import mark_static_address
1768*da0073e9SAndroid Build Coastguard Worker        from torch._inductor.compile_fx import compile_fx
1769*da0073e9SAndroid Build Coastguard Worker        from torch._inductor.cudagraph_utils import BoxedDeviceIndex
1770*da0073e9SAndroid Build Coastguard Worker        from torch._inductor.utils import BoxedBool
1771*da0073e9SAndroid Build Coastguard Worker
1772*da0073e9SAndroid Build Coastguard Worker        x_inner = torch.ones(4)
1773*da0073e9SAndroid Build Coastguard Worker        x = TwoTensor(x_inner, x_inner)
1774*da0073e9SAndroid Build Coastguard Worker        mark_static_address(x, guard=False)
1775*da0073e9SAndroid Build Coastguard Worker
1776*da0073e9SAndroid Build Coastguard Worker        def inner_compile(
1777*da0073e9SAndroid Build Coastguard Worker            gm: torch.fx.GraphModule,
1778*da0073e9SAndroid Build Coastguard Worker            example_inputs: List[torch.Tensor],
1779*da0073e9SAndroid Build Coastguard Worker            cudagraphs: Optional[BoxedBool] = None,
1780*da0073e9SAndroid Build Coastguard Worker            static_input_idxs: Optional[List[int]] = None,
1781*da0073e9SAndroid Build Coastguard Worker            is_backward: bool = False,
1782*da0073e9SAndroid Build Coastguard Worker            graph_id: Optional[int] = None,
1783*da0073e9SAndroid Build Coastguard Worker            cpp_wrapper: bool = False,
1784*da0073e9SAndroid Build Coastguard Worker            aot_mode: bool = False,
1785*da0073e9SAndroid Build Coastguard Worker            is_inference: bool = False,
1786*da0073e9SAndroid Build Coastguard Worker            boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
1787*da0073e9SAndroid Build Coastguard Worker            user_visible_outputs: Optional[Dict[str, None]] = None,
1788*da0073e9SAndroid Build Coastguard Worker            layout_opt: Optional[bool] = None,
1789*da0073e9SAndroid Build Coastguard Worker            extern_node_serializer: Optional[Callable[[List[Any]], Any]] = None,
1790*da0073e9SAndroid Build Coastguard Worker        ):
1791*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(static_input_idxs, [1, 2])
1792*da0073e9SAndroid Build Coastguard Worker            return gm
1793*da0073e9SAndroid Build Coastguard Worker
1794*da0073e9SAndroid Build Coastguard Worker        compiler = functools.partial(compile_fx, inner_compile=inner_compile)
1795*da0073e9SAndroid Build Coastguard Worker
1796*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend=compiler)
1797*da0073e9SAndroid Build Coastguard Worker        def fn(t0, t1, t2):
1798*da0073e9SAndroid Build Coastguard Worker            return t0 + t1 + t2 + 2
1799*da0073e9SAndroid Build Coastguard Worker
1800*da0073e9SAndroid Build Coastguard Worker        fn(torch.ones(4), x, torch.ones(4))
1801*da0073e9SAndroid Build Coastguard Worker
1802*da0073e9SAndroid Build Coastguard Worker
1803*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(SubclassTests)
1804*da0073e9SAndroid Build Coastguard Worker
1805*da0073e9SAndroid Build Coastguard Worker
1806*da0073e9SAndroid Build Coastguard Workerclass TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
1807*da0073e9SAndroid Build Coastguard Worker    def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True):
1808*da0073e9SAndroid Build Coastguard Worker        return get_jagged_tensor(nested_size, offsets, requires_grad)
1809*da0073e9SAndroid Build Coastguard Worker
1810*da0073e9SAndroid Build Coastguard Worker    def _get_nc_jagged_tensor(self, inner_dim, starts, lengths, requires_grad=True):
1811*da0073e9SAndroid Build Coastguard Worker        # Makes a non-contiguous jagged tensor from per-row starts and lengths,
1812*da0073e9SAndroid Build Coastguard Worker        # with a trailing dim of size inner_dim
1813*da0073e9SAndroid Build Coastguard Worker        max_dim = (starts + lengths).max()
1814*da0073e9SAndroid Build Coastguard Worker        values_tensor = torch.randn(
1815*da0073e9SAndroid Build Coastguard Worker            starts.shape[0],
1816*da0073e9SAndroid Build Coastguard Worker            max_dim.item(),
1817*da0073e9SAndroid Build Coastguard Worker            inner_dim,
1818*da0073e9SAndroid Build Coastguard Worker            requires_grad=requires_grad,
1819*da0073e9SAndroid Build Coastguard Worker            dtype=torch.float64,
1820*da0073e9SAndroid Build Coastguard Worker        )
1821*da0073e9SAndroid Build Coastguard Worker        return jagged_from_tensor_and_lengths(values_tensor, starts, lengths)
1822*da0073e9SAndroid Build Coastguard Worker
1823*da0073e9SAndroid Build Coastguard Worker    def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles):
1824*da0073e9SAndroid Build Coastguard Worker        _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles)
1825*da0073e9SAndroid Build Coastguard Worker
1826*da0073e9SAndroid Build Coastguard Worker    def test_unary_does_not_recompile(self):
1827*da0073e9SAndroid Build Coastguard Worker        nt1, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
1828*da0073e9SAndroid Build Coastguard Worker        nt2, _ = self._get_jagged_tensor(((3, 4, 5, 6), 4), None)
1829*da0073e9SAndroid Build Coastguard Worker        self._check_recompiles(lambda nt1: nt1.sin(), (nt1,), (nt2,), False)
1830*da0073e9SAndroid Build Coastguard Worker
1831*da0073e9SAndroid Build Coastguard Worker    def test_binary_does_not_recompile(self):
1832*da0073e9SAndroid Build Coastguard Worker        def binary(nt1, nt2):
1833*da0073e9SAndroid Build Coastguard Worker            if nt1.shape == nt2.shape:
1834*da0073e9SAndroid Build Coastguard Worker                return nt1 + nt2
1835*da0073e9SAndroid Build Coastguard Worker            else:
1836*da0073e9SAndroid Build Coastguard Worker                return nt1.sin()
1837*da0073e9SAndroid Build Coastguard Worker
1838*da0073e9SAndroid Build Coastguard Worker        # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0).
1839*da0073e9SAndroid Build Coastguard Worker        # This causes a recompile later on when it realizes the batch and last dim
1840*da0073e9SAndroid Build Coastguard Worker        # should not always be equal. To avoid that, we use (3, j0, 5) here; see the sketch after this test.
1841*da0073e9SAndroid Build Coastguard Worker        nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None)
1842*da0073e9SAndroid Build Coastguard Worker        nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets)
1843*da0073e9SAndroid Build Coastguard Worker        nt3, offsets = self._get_jagged_tensor(((3, 4, 5), 4), None)
1844*da0073e9SAndroid Build Coastguard Worker        nt4, _ = self._get_jagged_tensor(((3, 4, 5), 4), offsets)
1845*da0073e9SAndroid Build Coastguard Worker        self._check_recompiles(binary, (nt1, nt2), (nt3, nt4), False)
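
    # A minimal, hypothetical sketch (not part of the original suite) of the
    # duck-sizing pitfall noted above: with shape (3, j0, 3), the batch and
    # last dims are duck-sized to the same symbol, so an input where those
    # dims differ should fail the guard and force a recompile.
    def _sketch_duck_sizing_recompile(self):
        nt1, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)  # shape (3, j0, 3)
        nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)  # shape (3, j1, 5)
        self._check_recompiles(lambda nt: nt.sin(), (nt1,), (nt2,), True)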
1846*da0073e9SAndroid Build Coastguard Worker
1847*da0073e9SAndroid Build Coastguard Worker    def test_binary_recompiles(self):
1848*da0073e9SAndroid Build Coastguard Worker        def binary(nt1, nt2):
1849*da0073e9SAndroid Build Coastguard Worker            if nt1.shape == nt2.shape:
1850*da0073e9SAndroid Build Coastguard Worker                return nt1 + nt2
1851*da0073e9SAndroid Build Coastguard Worker            else:
1852*da0073e9SAndroid Build Coastguard Worker                return nt1.sin()
1853*da0073e9SAndroid Build Coastguard Worker
1854*da0073e9SAndroid Build Coastguard Worker        # Binary recompiles because singleton ints no longer match
1855*da0073e9SAndroid Build Coastguard Worker        nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None)
1856*da0073e9SAndroid Build Coastguard Worker        nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets)
1857*da0073e9SAndroid Build Coastguard Worker        nt3, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
1858*da0073e9SAndroid Build Coastguard Worker        self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True)
1859*da0073e9SAndroid Build Coastguard Worker
1860*da0073e9SAndroid Build Coastguard Worker    def _validate_compile(self, fn, arg_fn):
1861*da0073e9SAndroid Build Coastguard Worker        def _gen_grad_outputs(out_val):
1862*da0073e9SAndroid Build Coastguard Worker            if isinstance(out_val, (list, tuple)):
1863*da0073e9SAndroid Build Coastguard Worker                return tuple(torch.ones_like(c) for c in out_val)
1864*da0073e9SAndroid Build Coastguard Worker            else:
1865*da0073e9SAndroid Build Coastguard Worker                return (torch.ones_like(out_val),)
1866*da0073e9SAndroid Build Coastguard Worker
1867*da0073e9SAndroid Build Coastguard Worker        with self.branch_nested_state():
1868*da0073e9SAndroid Build Coastguard Worker            from torch.nested._internal.nested_tensor import _tensor_symint_registry
1869*da0073e9SAndroid Build Coastguard Worker
1870*da0073e9SAndroid Build Coastguard Worker            # Validate that compilation does not modify eager state
1871*da0073e9SAndroid Build Coastguard Worker            registry_before = list(_tensor_symint_registry.items())
1872*da0073e9SAndroid Build Coastguard Worker            count_before = torch.nested._internal.nested_tensor._tensor_id_counter
1873*da0073e9SAndroid Build Coastguard Worker
1874*da0073e9SAndroid Build Coastguard Worker            guards_exported = []
1875*da0073e9SAndroid Build Coastguard Worker            guards_failed = []
1876*da0073e9SAndroid Build Coastguard Worker
1877*da0073e9SAndroid Build Coastguard Worker            def append_guard_export(guards):
1878*da0073e9SAndroid Build Coastguard Worker                for g in guards:
1879*da0073e9SAndroid Build Coastguard Worker                    if g.code_list is not None:
1880*da0073e9SAndroid Build Coastguard Worker                        guards_exported.append(g.code_list[0])
1881*da0073e9SAndroid Build Coastguard Worker
1882*da0073e9SAndroid Build Coastguard Worker            def append_guard_fail(guards):
1883*da0073e9SAndroid Build Coastguard Worker                guards_failed.extend(guards)
1884*da0073e9SAndroid Build Coastguard Worker
1885*da0073e9SAndroid Build Coastguard Worker            compiled = torch._dynamo.optimize(
1886*da0073e9SAndroid Build Coastguard Worker                nopython=True,
1887*da0073e9SAndroid Build Coastguard Worker                backend="aot_eager",
1888*da0073e9SAndroid Build Coastguard Worker                guard_export_fn=append_guard_export,
1889*da0073e9SAndroid Build Coastguard Worker                guard_fail_fn=append_guard_fail,
1890*da0073e9SAndroid Build Coastguard Worker            )(fn)
1891*da0073e9SAndroid Build Coastguard Worker            registry_after = list(_tensor_symint_registry.items())
1892*da0073e9SAndroid Build Coastguard Worker            count_after = torch.nested._internal.nested_tensor._tensor_id_counter
1893*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(registry_before, registry_after)
1894*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(count_before, count_after)
1895*da0073e9SAndroid Build Coastguard Worker
1896*da0073e9SAndroid Build Coastguard Worker            args = arg_fn()
1897*da0073e9SAndroid Build Coastguard Worker            compile_out = compiled(*args)
1898*da0073e9SAndroid Build Coastguard Worker            compile_grads = []
1899*da0073e9SAndroid Build Coastguard Worker            g_args = [arg for arg in args if arg.requires_grad]
1900*da0073e9SAndroid Build Coastguard Worker            if len(g_args) > 0:
1901*da0073e9SAndroid Build Coastguard Worker                compile_grad_outputs = _gen_grad_outputs(compile_out)
1902*da0073e9SAndroid Build Coastguard Worker                compile_grads = torch.autograd.grad(
1903*da0073e9SAndroid Build Coastguard Worker                    compile_out, inputs=g_args, grad_outputs=compile_grad_outputs
1904*da0073e9SAndroid Build Coastguard Worker                )
1905*da0073e9SAndroid Build Coastguard Worker
1906*da0073e9SAndroid Build Coastguard Worker        with self.branch_nested_state():
1907*da0073e9SAndroid Build Coastguard Worker            args = arg_fn()
1908*da0073e9SAndroid Build Coastguard Worker            ref_out = fn(*args)
1909*da0073e9SAndroid Build Coastguard Worker            ref_grads = []
1910*da0073e9SAndroid Build Coastguard Worker            g_args = [arg for arg in args if arg.requires_grad]
1911*da0073e9SAndroid Build Coastguard Worker            if len(g_args) > 0:
1912*da0073e9SAndroid Build Coastguard Worker                ref_grad_outputs = _gen_grad_outputs(ref_out)
1913*da0073e9SAndroid Build Coastguard Worker                ref_grads = torch.autograd.grad(
1914*da0073e9SAndroid Build Coastguard Worker                    ref_out, inputs=g_args, grad_outputs=ref_grad_outputs
1915*da0073e9SAndroid Build Coastguard Worker                )
1916*da0073e9SAndroid Build Coastguard Worker
1917*da0073e9SAndroid Build Coastguard Worker        # Validate correctness forward
1918*da0073e9SAndroid Build Coastguard Worker        if isinstance(compile_out, (list, tuple)):
1919*da0073e9SAndroid Build Coastguard Worker            # TODO: Fix assertEqual() to support NJTs so this isn't necessary
1920*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(compile_out), len(ref_out))
1921*da0073e9SAndroid Build Coastguard Worker            for c, r in zip(compile_out, ref_out):
1922*da0073e9SAndroid Build Coastguard Worker                self.assertEqualIgnoringNestedInts(c, r)
1923*da0073e9SAndroid Build Coastguard Worker        else:
1924*da0073e9SAndroid Build Coastguard Worker            self.assertEqualIgnoringNestedInts(compile_out, ref_out)
1925*da0073e9SAndroid Build Coastguard Worker
1926*da0073e9SAndroid Build Coastguard Worker        # Validate correctness backward
1927*da0073e9SAndroid Build Coastguard Worker        for compile_grad, ref_grad in zip(compile_grads, ref_grads):
1928*da0073e9SAndroid Build Coastguard Worker            self.assertEqualIgnoringNestedInts(compile_grad, ref_grad)
1929*da0073e9SAndroid Build Coastguard Worker
1930*da0073e9SAndroid Build Coastguard Worker        return guards_exported, guards_failed
1931*da0073e9SAndroid Build Coastguard Worker
1932*da0073e9SAndroid Build Coastguard Worker    # Note: [What kind of guards are involved in nested tensor compilation]
1933*da0073e9SAndroid Build Coastguard Worker    #
1934*da0073e9SAndroid Build Coastguard Worker    # Until we implement UnionFind, dynamic shapes guards are not involved;
1935*da0073e9SAndroid Build Coastguard Worker    # we rely only on dynamo's tensor aliasing guards.
1936*da0073e9SAndroid Build Coastguard Worker    #
1937*da0073e9SAndroid Build Coastguard Worker    # This is possible because dynamo is able to generate tensor aliasing guards
1938*da0073e9SAndroid Build Coastguard Worker    # not only for the outer tensor, but also for the inner tensor.
1939*da0073e9SAndroid Build Coastguard Worker    #
1940*da0073e9SAndroid Build Coastguard Worker    # The case where dynamic shapes guards would eventually come into play is
1941*da0073e9SAndroid Build Coastguard Worker    # when my inputs are (1) two non-aliased tensors, but (2) declared as
1942*da0073e9SAndroid Build Coastguard Worker    # equal using a "trust me assert equal" API.
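
    # Hypothetical sketch (not part of the original suite) of the tensor
    # aliasing guards described above: passing the same NJT for both inputs
    # installs an aliasing guard, so de-aliasing the inputs on a later call
    # should trigger a recompile.
    def _sketch_aliasing_guard_recompile(self):
        nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None)
        nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets)
        self._check_recompiles(lambda a, b: a + b, (nt1, nt1), (nt1, nt2), True)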
1943*da0073e9SAndroid Build Coastguard Worker
1944*da0073e9SAndroid Build Coastguard Worker    # Note: [Compiling nested tensor global state]
1945*da0073e9SAndroid Build Coastguard Worker    #
1946*da0073e9SAndroid Build Coastguard Worker    # Today there are two pieces of global eager state that NJTs deal with:
1947*da0073e9SAndroid Build Coastguard Worker    # - tensor_id_counter: a global counter that assigns unique ids to tensors
1948*da0073e9SAndroid Build Coastguard Worker    # - tensor_symint_registry: maps tensor to nested int
1949*da0073e9SAndroid Build Coastguard Worker    #   - this is used in eager only (we should get rid of this because it is
1950*da0073e9SAndroid Build Coastguard Worker    #     not necessary to cache nested int in eager)
1951*da0073e9SAndroid Build Coastguard Worker    #   - during tracing, we DO need to cache nested int, but we do so on
1952*da0073e9SAndroid Build Coastguard Worker    #     the FakeTensor.
1953*da0073e9SAndroid Build Coastguard Worker    #
1954*da0073e9SAndroid Build Coastguard Worker    # Ideally we would like to satisfy the following:
1955*da0073e9SAndroid Build Coastguard Worker    # - (1) The eager state is not mutated during tracing
1956*da0073e9SAndroid Build Coastguard Worker    # - (2) Running the compiled function should mutate the eager state in the
1957*da0073e9SAndroid Build Coastguard Worker    #       same way that running the eager function would
1958*da0073e9SAndroid Build Coastguard Worker    #       (a) The global counter should be incremented
1959*da0073e9SAndroid Build Coastguard Worker    #       (b) The registry is updated in the same way
1960*da0073e9SAndroid Build Coastguard Worker    #
1961*da0073e9SAndroid Build Coastguard Worker    # Today we can satisfy (1) and (2a), but cannot satisfy (2b).
1962*da0073e9SAndroid Build Coastguard Worker    #
1963*da0073e9SAndroid Build Coastguard Worker    # Today, (1) is satisfied because we maintain a separate counter during
1964*da0073e9SAndroid Build Coastguard Worker    # tracing, and cache nested int on FakeTensor instead of relying on
1965*da0073e9SAndroid Build Coastguard Worker    # tensor_symint_registry.
1966*da0073e9SAndroid Build Coastguard Worker    #
1967*da0073e9SAndroid Build Coastguard Worker    # (2) cannot be completely satisfied because we trace away the
1968*da0073e9SAndroid Build Coastguard Worker    # side-effectful operations (we can fix this by wrapping the
1969*da0073e9SAndroid Build Coastguard Worker    # side-effectful operations in a custom op and threading through effect
1970*da0073e9SAndroid Build Coastguard Worker    # tokens). The current plan is to do that in the UnionFind impl.
1971*da0073e9SAndroid Build Coastguard Worker    #
1972*da0073e9SAndroid Build Coastguard Worker    # Interestingly, despite this, the state is mutated in a way that is somewhat
1973*da0073e9SAndroid Build Coastguard Worker    # close to what we want, e.g. if I construct a nested tensor using an
1974*da0073e9SAndroid Build Coastguard Worker    # offsets tensor in the compiled region and return it, the AOTAutograd runtime
1975*da0073e9SAndroid Build Coastguard Worker    # wrapper must rewrap the inner->inner graph outputs back into the subclass. This
1976*da0073e9SAndroid Build Coastguard Worker    # triggers the eager logic to run, updating the counter and registry.
1977*da0073e9SAndroid Build Coastguard Worker    #
1978*da0073e9SAndroid Build Coastguard Worker    # Notably however, compile differs in two ways from eager:
1979*da0073e9SAndroid Build Coastguard Worker    # (1) The order in which the offsets are assigned ids is different:
1980*da0073e9SAndroid Build Coastguard Worker    #     the registry would be set in the order the offsets are returned,
1981*da0073e9SAndroid Build Coastguard Worker    #     which is not necessarily the same order as they were constructed.
1982*da0073e9SAndroid Build Coastguard Worker    # (2) If a NestedTensor is not returned, then the AOTAutograd wrapping
1983*da0073e9SAndroid Build Coastguard Worker    #     logic will not be triggered.
1984*da0073e9SAndroid Build Coastguard Worker    #
1985*da0073e9SAndroid Build Coastguard Worker    # I claim that correctness is not affected by these differences today.
1986*da0073e9SAndroid Build Coastguard Worker    # E.g., there is never a case where two distinct offsets silently share
1987*da0073e9SAndroid Build Coastguard Worker    # the same id.
1988*da0073e9SAndroid Build Coastguard Worker    #
1989*da0073e9SAndroid Build Coastguard Worker    # (1) is clearly not a problem, and (2) should only be a problem if
1990*da0073e9SAndroid Build Coastguard Worker    # the nested int is returned on its own, without the corresponding NJT
1991*da0073e9SAndroid Build Coastguard Worker    # being returned. This is not a problem in the current implementation
1992*da0073e9SAndroid Build Coastguard Worker    # because returning only a shape is not supported!
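
    # Hypothetical sketch (not part of the original suite) of point (2a)
    # above: returning an NJT constructed inside the compiled region makes
    # the AOTAutograd runtime wrapper rewrap the graph outputs into the
    # subclass, which bumps the global eager id counter just as eager
    # construction would.
    def _sketch_counter_increment_on_return(self):
        from torch.nested._internal import nested_tensor as nt_module

        def fn(values, offsets):
            return torch.nested.nested_tensor_from_jagged(values, offsets)

        values = torch.randn(10, 5)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
        count_before = nt_module._tensor_id_counter
        torch.compile(fn, backend="aot_eager")(values, offsets)
        self.assertGreater(nt_module._tensor_id_counter, count_before)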
1993*da0073e9SAndroid Build Coastguard Worker
1994*da0073e9SAndroid Build Coastguard Worker    # Note: [Creating symbolic nested int]
1995*da0073e9SAndroid Build Coastguard Worker    #
1996*da0073e9SAndroid Build Coastguard Worker    # We must create a symbolic nested int when we construct a nested tensor
1997*da0073e9SAndroid Build Coastguard Worker    # from a tensor. There are two main cases:
1998*da0073e9SAndroid Build Coastguard Worker    #
1999*da0073e9SAndroid Build Coastguard Worker    # 1. The offsets tensor has NOT been used to construct an NJT
2000*da0073e9SAndroid Build Coastguard Worker    #    - Create a new plain nested int with the current val of the fake nt id counter
2001*da0073e9SAndroid Build Coastguard Worker    #    - Increment the fake nt id counter
2002*da0073e9SAndroid Build Coastguard Worker    #    - Create a new symint with plain nested int as hint
2003*da0073e9SAndroid Build Coastguard Worker    # 2. The offsets tensor HAS been used to construct an NJT
2004*da0073e9SAndroid Build Coastguard Worker    #    - Create a new symint with plain nested int as hint
2005*da0073e9SAndroid Build Coastguard Worker    #
2006*da0073e9SAndroid Build Coastguard Worker    # More details on case 2:
2007*da0073e9SAndroid Build Coastguard Worker    # - During fakification of the offsets, we check the eager registry, and
2008*da0073e9SAndroid Build Coastguard Worker    #   if the tensor HAS been used to construct an NJT,
2009*da0073e9SAndroid Build Coastguard Worker    #   we create a symint, with the existing nested int as hint, and cache
2010*da0073e9SAndroid Build Coastguard Worker    #   it onto the FakeTensor.
2011*da0073e9SAndroid Build Coastguard Worker    #
2012*da0073e9SAndroid Build Coastguard Worker    # [ Always use ephemeral source ]
2013*da0073e9SAndroid Build Coastguard Worker    #
2014*da0073e9SAndroid Build Coastguard Worker    # We ALWAYS create the new symint with an ephemeral source, whether that is
2015*da0073e9SAndroid Build Coastguard Worker    # in case (1) or (2), even though we could've had a proper source for case (2).
2016*da0073e9SAndroid Build Coastguard Worker    # Using a proper source would enable a few more (edge) cases, but since
2017*da0073e9SAndroid Build Coastguard Worker    # we plan to handle things more holistically in the future anyway, we don't
2018*da0073e9SAndroid Build Coastguard Worker    # bother doing so today.
2019*da0073e9SAndroid Build Coastguard Worker    #
2020*da0073e9SAndroid Build Coastguard Worker    # Using an ephemeral source has some consequences. But we are happy if
2021*da0073e9SAndroid Build Coastguard Worker    # - We do not silently miss recompiles, e.g. we guard when necessary.
2022*da0073e9SAndroid Build Coastguard Worker    #   We know that this is true, because dynamo guards alone are already
2023*da0073e9SAndroid Build Coastguard Worker    #   sufficient.
2024*da0073e9SAndroid Build Coastguard Worker    # - We are not producing errors for the cases we care about
2025*da0073e9SAndroid Build Coastguard Worker    #
2026*da0073e9SAndroid Build Coastguard Worker    # The main case we care about is when we guard that two shapes are equal.
2027*da0073e9SAndroid Build Coastguard Worker    # In this case, the replacements logic would simplify away the ephemeral
2028*da0073e9SAndroid Build Coastguard Worker    # symbol, and there is no error produced.
2029*da0073e9SAndroid Build Coastguard Worker    # The unsupported case is when we guard that two shapes are not equal, in
2030*da0073e9SAndroid Build Coastguard Worker    # which case we will try, and fail, to generate a guard.
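
    # A minimal sketch (hypothetical, not in the original suite) of the
    # supported case from [ Always use ephemeral source ] above: guarding
    # that two nested shapes are EQUAL lets the replacements logic simplify
    # away the ephemeral symbol, so compilation succeeds.
    def _sketch_ephemeral_equal_shape_guard(self):
        def fn(nt1, nt2):
            if nt1.shape == nt2.shape:  # shape equality guard on the nested int
                return nt1 + nt2
            return nt1.sin()

        # Sharing offsets gives both NJTs the same nested int, so the
        # equality branch is taken.
        nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None)
        nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets)
        torch.compile(fn, backend="aot_eager", fullgraph=True)(nt1, nt2)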
2031*da0073e9SAndroid Build Coastguard Worker
2032*da0073e9SAndroid Build Coastguard Worker    #
2033*da0073e9SAndroid Build Coastguard Worker    # Case 1: in-graph construction where the offsets are passed as inputs
2034*da0073e9SAndroid Build Coastguard Worker    #
2035*da0073e9SAndroid Build Coastguard Worker    def test_in_graph_construction_from_input(self):
2036*da0073e9SAndroid Build Coastguard Worker        # The offsets is passed as an input
2037*da0073e9SAndroid Build Coastguard Worker        def fn(values, offsets):
2038*da0073e9SAndroid Build Coastguard Worker            return torch.nested.nested_tensor_from_jagged(values * 2, offsets) * 2
2039*da0073e9SAndroid Build Coastguard Worker
2040*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(10, 5, requires_grad=True)
2041*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
2042*da0073e9SAndroid Build Coastguard Worker        self._validate_compile(fn, arg_fn=lambda: (values, offsets))
2043*da0073e9SAndroid Build Coastguard Worker
2044*da0073e9SAndroid Build Coastguard Worker        # Do not specialize on the offsets
2045*da0073e9SAndroid Build Coastguard Worker        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
2046*da0073e9SAndroid Build Coastguard Worker            different_offsets = torch.tensor([0, 1, 5, 10], dtype=torch.int64)
2047*da0073e9SAndroid Build Coastguard Worker            self._validate_compile(fn, arg_fn=lambda: (values, different_offsets))
2048*da0073e9SAndroid Build Coastguard Worker
2049*da0073e9SAndroid Build Coastguard Worker    def test_in_graph_construction_from_input_2(self):
2050*da0073e9SAndroid Build Coastguard Worker        # Construct two NJTs, both are passed as inputs
2051*da0073e9SAndroid Build Coastguard Worker        def fn(values, offsets1, offsets2):
2052*da0073e9SAndroid Build Coastguard Worker            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets1)
2053*da0073e9SAndroid Build Coastguard Worker            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2)
2054*da0073e9SAndroid Build Coastguard Worker            return nt2, nt1
2055*da0073e9SAndroid Build Coastguard Worker
2056*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(10, 5, requires_grad=True)
2057*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
2058*da0073e9SAndroid Build Coastguard Worker        offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64)
2059*da0073e9SAndroid Build Coastguard Worker        # 1. Offsets are different
2060*da0073e9SAndroid Build Coastguard Worker        guards_exported, guards_failed = self._validate_compile(
2061*da0073e9SAndroid Build Coastguard Worker            fn, arg_fn=lambda: (values, offsets, offsets2)
2062*da0073e9SAndroid Build Coastguard Worker        )
2063*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(guards_failed), 0)
2064*da0073e9SAndroid Build Coastguard Worker        self.assertNotIn("L['offsets1'] is L['offsets2']", guards_exported)
2065*da0073e9SAndroid Build Coastguard Worker
2066*da0073e9SAndroid Build Coastguard Worker        # TODO
2067*da0073e9SAndroid Build Coastguard Worker        # 2. Offsets are the same
2068*da0073e9SAndroid Build Coastguard Worker        new_guards_exported, _ = self._validate_compile(
2069*da0073e9SAndroid Build Coastguard Worker            fn, arg_fn=lambda: (values, offsets, offsets)
2070*da0073e9SAndroid Build Coastguard Worker        )
2071*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(any("Duplicate tensors found" in g for g in guards_failed))
2072*da0073e9SAndroid Build Coastguard Worker        self.assertIn("L['offsets1'] is L['offsets2']", new_guards_exported)
2073*da0073e9SAndroid Build Coastguard Worker
2074*da0073e9SAndroid Build Coastguard Worker        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
2075*da0073e9SAndroid Build Coastguard Worker            offsets3 = offsets.clone()
2076*da0073e9SAndroid Build Coastguard Worker            self._validate_compile(fn, arg_fn=lambda: (values, offsets3, offsets3))
2077*da0073e9SAndroid Build Coastguard Worker
2078*da0073e9SAndroid Build Coastguard Worker        # Do a binary op
2079*da0073e9SAndroid Build Coastguard Worker        def fn(values, offsets, offsets2):
2080*da0073e9SAndroid Build Coastguard Worker            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets)
2081*da0073e9SAndroid Build Coastguard Worker            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2)
2082*da0073e9SAndroid Build Coastguard Worker            return nt1 * nt2
2083*da0073e9SAndroid Build Coastguard Worker
2084*da0073e9SAndroid Build Coastguard Worker        self._validate_compile(fn, arg_fn=lambda: (values, offsets, offsets))
2085*da0073e9SAndroid Build Coastguard Worker
2086*da0073e9SAndroid Build Coastguard Worker    def test_in_graph_construction_from_input_4(self):
2087*da0073e9SAndroid Build Coastguard Worker        # The offsets is taken from an NJT input
2088*da0073e9SAndroid Build Coastguard Worker        def fn(nt, other_values):
2089*da0073e9SAndroid Build Coastguard Worker            nt2 = torch.nested.nested_tensor_from_jagged(other_values, nt.offsets())
2090*da0073e9SAndroid Build Coastguard Worker            return nt + nt2
2091*da0073e9SAndroid Build Coastguard Worker
2092*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(9, 5, requires_grad=True)
2093*da0073e9SAndroid Build Coastguard Worker        other_values = torch.randn(9, 5, requires_grad=True)
2094*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64)
2095*da0073e9SAndroid Build Coastguard Worker
2096*da0073e9SAndroid Build Coastguard Worker        def arg_fn(values=values, other_values=other_values, offsets=offsets):
2097*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
2098*da0073e9SAndroid Build Coastguard Worker            return nt, other_values
2099*da0073e9SAndroid Build Coastguard Worker
2100*da0073e9SAndroid Build Coastguard Worker        self._validate_compile(fn, arg_fn=arg_fn)
2101*da0073e9SAndroid Build Coastguard Worker
2102*da0073e9SAndroid Build Coastguard Worker        # Do not specialize on the offsets
2103*da0073e9SAndroid Build Coastguard Worker        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
2104*da0073e9SAndroid Build Coastguard Worker            different_offsets = offsets.clone()
2105*da0073e9SAndroid Build Coastguard Worker
2106*da0073e9SAndroid Build Coastguard Worker            def arg_fn(
2107*da0073e9SAndroid Build Coastguard Worker                values=values, other_values=other_values, offsets=different_offsets
2108*da0073e9SAndroid Build Coastguard Worker            ):
2109*da0073e9SAndroid Build Coastguard Worker                nt = torch.nested.nested_tensor_from_jagged(values, offsets)
2110*da0073e9SAndroid Build Coastguard Worker                return nt, other_values
2111*da0073e9SAndroid Build Coastguard Worker
2112*da0073e9SAndroid Build Coastguard Worker            self._validate_compile(fn, arg_fn=arg_fn)
2113*da0073e9SAndroid Build Coastguard Worker
2114*da0073e9SAndroid Build Coastguard Worker    def test_in_graph_construction_from_input_5(self):
2115*da0073e9SAndroid Build Coastguard Worker        # Construct from lengths instead of offsets
2116*da0073e9SAndroid Build Coastguard Worker        def fn(values, lengths):
2117*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor_from_jagged(values, lengths=lengths)
2118*da0073e9SAndroid Build Coastguard Worker            return nt.sin()
2119*da0073e9SAndroid Build Coastguard Worker
2120*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(9, 5, requires_grad=True)
2121*da0073e9SAndroid Build Coastguard Worker        lengths = torch.tensor([2, 4, 3])
2122*da0073e9SAndroid Build Coastguard Worker        self._validate_compile(fn, arg_fn=lambda: (values, lengths))
2123*da0073e9SAndroid Build Coastguard Worker
2124*da0073e9SAndroid Build Coastguard Worker    #
2125*da0073e9SAndroid Build Coastguard Worker    # Case 2: in-graph construction where offsets are graph intermediates
2126*da0073e9SAndroid Build Coastguard Worker    #
2127*da0073e9SAndroid Build Coastguard Worker    def test_in_graph_construction_from_intermediate(self):
2128*da0073e9SAndroid Build Coastguard Worker        # offsets is an intermediate computed from lengths
2129*da0073e9SAndroid Build Coastguard Worker        def fn(values, lengths):
2130*da0073e9SAndroid Build Coastguard Worker            offsets = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)])
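            # e.g. lengths [2, 4, 3] -> offsets [0, 2, 6, 9]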
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            nt2 = torch.nested.nested_tensor_from_jagged(values, offsets)
            return (nt * nt2).sin()

        values = torch.randn(9, 5, requires_grad=True)
        lengths = torch.tensor([2, 4, 3])
        self._validate_compile(fn, arg_fn=lambda: (values, lengths))

        # Do not specialize on the lengths
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            different_lengths = lengths.clone()
            self._validate_compile(fn, arg_fn=lambda: (values, different_lengths))

    def test_in_graph_construction_from_intermediate_2(self):
        def fn(values, offsets):
            return torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone())

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
        self._validate_compile(fn, arg_fn=lambda: (values, offsets))

    def test_in_graph_construction_from_intermediate_3(self):
        # Note that due to CSE (common subexpression elimination), clone is not
        # necessarily called twice!
        def fn(values, offsets):
            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone())
            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets.clone())
            return nt2, nt1

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
        self._validate_compile(fn, arg_fn=lambda: (values, offsets))

    def test_in_graph_construction_from_intermediate_4(self):
        # Shared intermediate (should be same as case #1)
        def fn(values):
            offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            values2 = torch.ones_like(values)
            nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets)
            return nt * nt2

        values = torch.randn(10, 5).requires_grad_(True)
        self._validate_compile(fn, arg_fn=lambda: (values,))

    # AssertionError: s2 (could be from ['<ephemeral: intermediate_offsets_or_lengths>',
    @unittest.expectedFailure
    def test_in_graph_construction_from_intermediate_5(self):
        # non-shared intermediate
        def fn(values):
            offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            values2 = torch.ones_like(values)
            nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets.clone())
            if nt2.shape[1] != nt.shape[1]:
                return nt * 2
            else:
                return nt * 3

        values = torch.randn(10, 5).requires_grad_(True)
        self._validate_compile(fn, arg_fn=lambda: (values,))

    #
    # Case 3: in-graph construction where offsets are both direct graph inputs
    #         and also passed in as part of an NJT input's offsets.
    #
    def test_in_graph_construction_mixed(self):
        def fn(nt, values, offsets):
            nt2 = torch.nested.nested_tensor_from_jagged(values, offsets)
            return nt * nt2

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)

        def arg_fn(values=values, offsets=offsets):
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            return nt, values, offsets

        self._validate_compile(fn, arg_fn)

    # See Note: [Creating symbolic nested int]
    # AssertionError: s2 (could be from ['<ephemeral: intermediate_offsets_or_lengths>',
    @unittest.expectedFailure
    def test_in_graph_construction_mixed_2(self):
        def fn(nt, values, offsets, nt2):
            # The intermediate offsets tensor has an ephemeral source
            intermediate_nt = torch.nested.nested_tensor_from_jagged(
                values, offsets.clone()
            )
            # This creates a dynamic shapes neq guard
            if nt2.shape[1] != intermediate_nt.shape[1]:
                # We should always go here.
                nt = nt * 2
            return nt

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
        offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64)

        def arg_fn(values=values, offsets=offsets, offsets2=offsets2):
            # values is shared between nt and nt2, but it shouldn't matter
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            nt2 = torch.nested.nested_tensor_from_jagged(values, offsets2)
            return nt, values, offsets, nt2

        self._validate_compile(fn, arg_fn)

    def test_in_graph_construction_mixed_3(self):
        # More involved mixed case
        def fn(nt, values, offsets):
            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets)
            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets)
            return nt1 + nt2 + nt

        values = torch.randn(9, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64)

        def arg_fn(values=values, offsets=offsets):
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            return nt, values, offsets

        self._validate_compile(fn, arg_fn)

    def test_return_shape(self):
        nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)

        def fn(nt):
            return (nt * 2).shape

        compiled = torch.compile(fn, fullgraph=True, backend="aot_eager")
        compiled(nt)

    def test_inference_tensor(self):
        with torch.inference_mode():
            nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)

        def fn(n):
            return n * 2

        torch.compile(fn, backend="eager")(nt)

    # TODO: cannot parametrize this test class with device for some reason
    def _test_autograd(self, backend):
        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)
        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64)
        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
        # TODO: Switch to public API when it exists
        nt2, _ = jagged_from_list([a, b, c], nt.offsets())

        def fn1(nt1, nt2):
            return (nt1 + nt2).sin().cos()

        compiled_f = torch.compile(fn1, fullgraph=True, backend=backend, dynamic=True)
        out = compiled_f(nt, nt2)
        out_buffer = out.values()
        ga, gb, gc = torch.autograd.grad(out_buffer.sum(), (a, b, c))

        out_ref = fn1(nt, nt2)
        out_buffer_ref = out_ref.values()
        ga_ref, gb_ref, gc_ref = torch.autograd.grad(out_buffer_ref.sum(), (a, b, c))

        self.assertTrue(torch.allclose(ga, ga_ref))
        self.assertTrue(torch.allclose(gb, gb_ref))
        self.assertTrue(torch.allclose(gc, gc_ref))

    def test_basic_autograd(self):
        self._test_autograd("aot_eager")

    @requires_cuda
    def test_basic_autograd_inductor(self):
        self._test_autograd("inductor")

    def test_subclass_with_mutation_in_graph(self):
        # This graph contains an in-graph mutation, i.e. a mutation that is normally
        # allowed to remain in the graph, but is not allowed when the graph handles
        # subclasses at all.
        # Whether or not the mutation is kept in the graph alters the number of
        # outputs from the forward graph. Previously, a bug in this handling meant
        # that sometimes the expected and actual numbers of outputs from the
        # joint graph did not match, causing assertion failures.
        def fn(x, y):
            z = x.sin()
            y.sin_()
            return z.cos(), y.cos()

        fn_c = torch.compile(fn, backend="inductor")

        values = [torch.rand((i, 8), requires_grad=True) for i in range(1, 6)]
        values_copy = [x.detach().clone().requires_grad_(True) for x in values]
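        # detach().clone().requires_grad_(True) yields independent leaf copies so
        # the eager reference run accumulates its own gradients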

        nt, offsets = jagged_from_list(values, None)
        nt_copy, offsets = jagged_from_list(values_copy, offsets)
        y = torch.rand((4, 8))
        y_copy = y.clone()

        ret = fn_c(nt, y)[0]
        ref = fn(nt_copy, y_copy)[0]

        self.assertEqual(ret.values(), ref.values())

        ret.values().sum().backward()
        ref.values().sum().backward()
        for ref_v, res_v in zip(values_copy, values):
            self.assertEqual(ref_v.grad, res_v.grad)

    @torch._dynamo.config.patch({"capture_scalar_outputs": True})
    def test_unbind(self):
        # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0).
        # This causes a recompile later on when it realizes the batch and last dim
        # should not always be equal. To avoid that, we use (3, j0, 5) here.
        nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
        nt2, _ = self._get_jagged_tensor(((2, 3, 5), 2), None)
        nt3, _ = self._get_jagged_tensor(((2, 3, 4, 5), 3), None)

        def fn(x):
            return x.unbind()

        compiled_f = torch.compile(fn, fullgraph=True, backend="eager", dynamic=True)
        out = compiled_f(nt)

        out_ref = fn(nt)

        # correctness
        self.assertEqual(len(out), len(out_ref))
        for x, x_ref in zip(out, out_ref):
            self.assertTrue(torch.allclose(x, x_ref))

        # We specialize on the length of offsets, i.e. (1) we recompile if the
        # length of the offsets is different, and (2) we don't recompile if the
        # length of the offsets is the same, even if the sizes of the constituent
        # tensors are different.
        self._check_recompiles(fn, (nt,), (nt2,), False)
        self._check_recompiles(fn, (nt,), (nt3,), True)

    def test_inline_nested_tensor_from_jagged(self):
        nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)

        def fn(x):
            return torch.nested.nested_tensor_from_jagged(x.values() * 2, x.offsets())

        torch.compile(fn, fullgraph=True, backend="aot_eager")(nt)

    # The test here: nn.Parameters that are secretly subclasses
    # have a metaclass that overrides __instancecheck__,
    # which dynamo needs to respect when it inlines the if statement.
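    # Roughly, a minimal sketch of the mechanism (hypothetical name; the actual
    # metaclass lives in torch.nn.parameter):
    #
    #   class _ParamMeta(type):
    #       def __instancecheck__(cls, instance):
    #           # Tensor subclasses flagged as parameters also pass the check
    #           return super().__instancecheck__(instance) or (
    #               isinstance(instance, torch.Tensor)
    #               and getattr(instance, "_is_param", False)
    #           )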
    def test_param_subclass_isinstance_input(self):
        x_inner = torch.randn(16, 16, requires_grad=True)
        x = torch.nn.Parameter(TwoTensor(x_inner, x_inner))
        m = torch.nn.Linear(16, 16)
        m.weight = x

        def fn():
            if isinstance(m.weight, torch.nn.Parameter):
                return m.weight + 1
            else:
                return m.weight + 2

        out_ref = fn()
        out_test = torch.compile(fn, backend="aot_eager")()
        self.assertEqual(out_ref, out_test)

    def _input_view_test(self, nt_view_name):
        nt_view = VIEW_TEST_CASES[nt_view_name]()

        def fn(x):
            return x.sin()

        out_ref = fn(nt_view)
        torch._dynamo.reset()
        compile_fn = torch.compile(
            fn, fullgraph=True, backend="aot_eager", dynamic=True
        )
        out = compile_fn(nt_view)

        # Check metadata and values are correct
        self.assertTrue(out.size() == out_ref.size())
        self.assertTrue(out.stride() == out_ref.stride())
        if out.is_nested:
            self.assertTrue(torch.allclose(out.values(), out_ref.values()))
        else:
            self.assertTrue(torch.allclose(out, out_ref))

        # Check that no upper/lower bound guards are incurred
        def backend(gm, args):
            context = torch._guards.TracingContext.get()
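            # Collect the symbolic shape guard expressions accumulated on the
            # ShapeEnv during tracing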
            guards = [str(g.expr) for g in context.fake_mode.shape_env.guards]

            # The expected guards vary based on the type of view
            guard_str = "\n".join(guards)
            if nt_view_name == "subclass_dense":
                self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""")
            elif nt_view_name == "dense_subclass_dense_subclass":
                self.assertExpectedInline(
                    guard_str,
                    """\
Eq(s5 - 1, s2)
Eq(s12 - 1, s7)
Eq(s11, s9)""",
                )
            elif nt_view_name.startswith("base_is_nt_True"):
                self.assertExpectedInline(
                    guard_str,
                    """Eq(s3 - 1, s0)""",
                )
            else:
                self.assertExpectedInline(
                    guard_str,
                    """\
Eq(s4 - 1, s1)
Eq(s13 - 1, s8)
Eq(s12, s10)""",
                )
            return gm

        torch._dynamo.reset()
        compile_fn = torch.compile(fn, fullgraph=True, backend=backend, dynamic=True)
        out = compile_fn(nt_view)

    @parametrize(
        "nt_view_name",
        [k for k in VIEW_TEST_CASES.keys() if k != "subclass_dense_subclass_dense"],
    )
    def test_inputs_to_compiled_fn_are_views(self, nt_view_name):
        self._input_view_test(nt_view_name)

    def test_subclass_gives_static_shapes_when_dynamic_false(self):
        def check_graph(gm, *args):
            first_node_example_val = next(iter(gm.graph.nodes)).meta["example_value"]
            # We compiled with dynamic=False, so we expect no SymInt sizes on our placeholders
            self.assertTrue(
                all(isinstance(x, int) for x in first_node_example_val.shape)
            )
            return gm

        @torch.compile(backend=check_graph, dynamic=False)
        def f(x):
            return x + 1

        x_inner = torch.ones(4)
        x = TwoTensor(x_inner, x_inner)
        x_view = x.view(2, 2)
        out = f(x_view)

    # NJT1 -> Dense -> NJT2 -> Dense view
    # During view replay, the Dense -> NJT2 part will construct an intermediate,
    # symbolically-sized NJT that is immediately deconstructed to return the final dense
    # view. To construct this intermediate properly, we need the associated nested int
    # to be symbolic. This view is expected to fail compilation until symbolic nested ints
    # are cached onto fake offsets to solve this problem.
    @unittest.expectedFailure
    def test_subclass_dense_subclass_dense_view(self):
        self._input_view_test("subclass_dense_subclass_dense")


instantiate_parametrized_tests(TestNestedTensor)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()