# Owner(s): ["module: dynamo"]
import functools
import itertools
import unittest
from functools import partial

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._functorch.config
import torch.utils._pytree as pytree
import torch.utils.checkpoint
from torch._dynamo.testing import normalize_gm
from torch._higher_order_ops.wrap import wrap
from torch.fx.experimental.symbolic_shapes import (
    DimDynamic,
    ShapeEnv,
    StatelessSymbolicContext,
)
from torch.nested._internal.nested_tensor import (
    jagged_from_list,
    jagged_from_tensor_and_lengths,
    nested_view_from_values_offsets,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    NestedTensorTestCase,
    parametrize,
    subtest,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._python_dispatch import return_and_correct_aliasing


def traceable_subclass(c):
    return torch._dynamo.config.patch("traceable_tensor_subclasses", {c})


def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles):
    actual_recompiles = _recompiles_for_inputs(fn, inputs1, inputs2)
    self.assertEqual(actual_recompiles, expected_recompiles)


def get_jagged_tensor(nested_size, offsets, requires_grad=True):
    # Makes a jagged tensor with N constituent tensors with size
    # as specified ((S0, S1, S2), D)
    D = nested_size[1]
    out = []
    for s in nested_size[0]:
        out.append(torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64))
    return jagged_from_list(out, offsets)
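
# A hedged usage sketch (not part of the original suite): get_jagged_tensor()
# builds an NJT from constituent (s, D) tensors via jagged_from_list and returns
# the jagged tensor together with its offsets. The helper below is purely
# illustrative and is not referenced by any test.
def _get_jagged_tensor_usage_sketch():
    nt, offsets = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
    # values() is the packed (2 + 3 + 4, 3) buffer; offsets delimit each component
    assert nt.values().shape == (9, 3)
    assert offsets.tolist() == [0, 2, 5, 9]
    return nt, offsets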


def get_view_test_cases():
    # Test all cases with both an NT base and a dense base
    # Subclass -> Subclass
    # Dense -> Subclass

    # NB: Don't close over loop variables, they will not get copied into the
    # closure
    #
    # NB: These return functions so we don't generate tensors during test
    # collection time

    def mk_basic(base_is_nt):
        # There are three cases to consider here based on the logic in
        # meta_utils.py
        #
        # (1) basic case:
        # view is not a leaf and has the same requires grad as its base
        x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
        x = x.clone() if base_is_nt else x
        assert not x.is_leaf
        return x.unsqueeze(-1)

    def mk_leaf(base_is_nt, requires_grad_1, requires_grad_2):
        x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=requires_grad_1)
        x = x.clone() if base_is_nt else x
        with torch.no_grad():
            x_view = x.unsqueeze(-1)
            # The issue is this doesn't quite work
            x_view.requires_grad_(requires_grad_2)
        return x_view

    def mk_obscure(base_is_nt):
        x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
        x = x.clone() if base_is_nt else x
        # intermediate leaf view
        with torch.no_grad():
            x_view = x.unsqueeze(-1)
        x_view.requires_grad_(True)
        x_view_view = x_view.unsqueeze(-1)
        return x_view_view

    for base_is_nt in [False, True]:
        prefix = f"base_is_nt_{base_is_nt}"

        yield partial(mk_basic, base_is_nt), f"{prefix}_basic"

        # (2) leaf view case:
        # the view has to be a leaf (w/ requires_grad True or requires_grad False)
        # base w/ requires_grad True or requires_grad False
        for requires_grad_1, requires_grad_2 in itertools.product(
            [True, False], repeat=2
        ):
            yield partial(
                mk_leaf, base_is_nt, requires_grad_1, requires_grad_2
            ), f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}"

        # (3) obscure case:
        # view is not a leaf (implies requires_grad True)
        # base w/ requires_grad False)
        yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure"

    # Subclass -> Dense
    yield lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[
        0
    ].clone(), "subclass_dense"

    # Dense -> Subclass -> Dense -> Subclass
    def mk_dense_subclass_dense_subclass():
        values = torch.randn(10, 5)
        offsets = torch.tensor([0, 3, 6, 10])
        offsets2 = offsets.clone().detach()
        return nested_view_from_values_offsets(
            nested_view_from_values_offsets(values, offsets).values(), offsets2
        )

    yield mk_dense_subclass_dense_subclass, "dense_subclass_dense_subclass"

    def mk_subclass_dense_subclass_dense():
        x = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
        offsets2 = x.offsets().clone().detach()
        nt_view = nested_view_from_values_offsets(x.values(), offsets2).values()
        return nt_view

    yield mk_subclass_dense_subclass_dense, "subclass_dense_subclass_dense"


VIEW_TEST_CASES = {k: v for v, k in get_view_test_cases()}

requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")

compile_full_eager = torch.compile(backend="eager", fullgraph=True)


class BaseTorchFunction(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)


class MockSubclass(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)


class AttrSubclass(torch.Tensor):
    x: int = 10
    size: int = 10

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)


class DummyNDim(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func == torch.Tensor.ndim.__get__:
            return 10
        return super().__torch_function__(func, types, args, kwargs)


class WrapperSubclass:
    def __init__(self, tensor):
        self.tensor = tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        args = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, args)
        kwargs = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, kwargs)
        return func(*args, **kwargs)


class SigmoidToExpSubclass(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func == torch.Tensor.sigmoid:
            return super().__torch_function__(torch.Tensor.exp, types, args, kwargs)
        return super().__torch_function__(func, types, args, kwargs)


# Wrapper subclass with two inner tensors: data and scale
# data has same shape as outer, and scale has single dim size
class ScaledTensor(torch.Tensor):
    def __new__(
        cls,
        data: torch.Tensor,
        scale: torch.Tensor,
        *,
        constant: int = 0,
    ):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
            strides=data.stride(),
            storage_offset=data.storage_offset(),
            dtype=data.dtype,
            layout=data.layout,
            requires_grad=data.requires_grad,
            device=data.device,
        )

    def __init__(self, data: torch.Tensor, scale: torch.Tensor, constant: int = 0):
        self._data = data
        self._scale = scale
        self._constant = constant

    def __tensor_flatten__(self):
        ctx = {"_constant": self._constant}
        return ["_data", "_scale"], ctx

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        assert len(inner_tensors) == 2
        return ScaledTensor(
            inner_tensors["_data"],
            inner_tensors["_scale"],
            constant=metadata["_constant"],
        )

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        scaled_tensor = args[0]
        out = func(scaled_tensor._data, *args[1:], **kwargs)
        return ScaledTensor(out, scaled_tensor._scale, constant=scaled_tensor._constant)

    def __repr__(self):
        return f"{self._data.__repr__()}\n{self._scale.__repr__()}"
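
# A hedged usage sketch (not part of the original suite) of the ScaledTensor
# wrapper subclass above: every op unwraps to `_data` in __torch_dispatch__ and
# the `_scale` / `_constant` metadata are re-attached to the result. Purely
# illustrative; not referenced by any test.
def _scaled_tensor_usage_sketch():
    st = ScaledTensor(torch.randn(2, 4), torch.randn(6), constant=1)
    out = st.sin()  # dispatches on st._data, rewraps the result as ScaledTensor
    assert isinstance(out, ScaledTensor)
    assert out._constant == 1
    return out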
f"{self._data.__repr__()}\n{self._scale.__repr__()}" class OptionalScaledTensor(torch.Tensor): def __new__( cls, data, scale, *, constant: int = 0, ): return torch.Tensor._make_wrapper_subclass( cls, data.size(), strides=data.stride(), storage_offset=data.storage_offset(), dtype=data.dtype, layout=data.layout, requires_grad=data.requires_grad, device=data.device, ) def __init__(self, data: torch.Tensor, scale, constant: int = 0): self._data = data self._scale = scale self._constant = constant def __tensor_flatten__(self): ctx = {"_constant": self._constant} if self._scale is not None: return ["_data", "_scale"], ctx else: return ["_data"], ctx @staticmethod def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): return OptionalScaledTensor( inner_tensors["_data"], inner_tensors["_scale"] if "_scale" in inner_tensors else None, constant=metadata["_constant"], ) @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): scaled_tensor = args[0] out = func(scaled_tensor._data, *args[1:], **kwargs) if scaled_tensor._scale is not None: out = out * scaled_tensor._scale return OptionalScaledTensor( out, scaled_tensor._scale, constant=scaled_tensor._constant ) def __repr__(self): return ( f"OptionalScaledTensor({self._data.__repr__()}\n{self._scale.__repr__()})" ) class CtxSubclassTensor(torch.Tensor): """ Class used to verify guarding on the subclass metadata """ @staticmethod def __new__(cls, a, constant): shape = a.shape kwargs = {} kwargs["strides"] = a.stride() kwargs["storage_offset"] = a.storage_offset() kwargs["device"] = a.device kwargs["layout"] = a.layout kwargs["requires_grad"] = a.requires_grad kwargs["dtype"] = a.dtype out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) return out def __init__(self, a, constant): self.a = a self.constant = constant def __repr__(self): a_repr = repr(self.a) return f"CtxSubclassTensor({a_repr})" def __tensor_flatten__(self): return ["a"], (self.constant,) @staticmethod def __tensor_unflatten__(inner_tensors, meta, sizes, strides): constant = meta[0] a = inner_tensors["a"] return CtxSubclassTensor(a, constant) @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): from torch.utils._python_dispatch import return_and_correct_aliasing if kwargs is None: kwargs = {} biggest_constant = max( [ x.constant for x in pytree.tree_flatten(args)[0] if isinstance(x, CtxSubclassTensor) ] ) args_a = pytree.tree_map( lambda x: x.a if isinstance(x, CtxSubclassTensor) else x, args ) kwargs_a = pytree.tree_map( lambda x: x.a if isinstance(x, CtxSubclassTensor) else x, kwargs ) out_a = func(*args_a, **kwargs_a) out = pytree.tree_map( lambda x: CtxSubclassTensor(x, biggest_constant) if isinstance(x, torch.Tensor) else x, out_a, ) if func == torch.ops.aten.mul.Tensor: out = out + out.constant return return_and_correct_aliasing(func, args, kwargs, out) def func(a): return a.sin() class EagerRecordGraphAndInputs: def __init__(self) -> None: self.graphs = [] self.example_inputs = [] def __call__(self, gm: torch.fx.GraphModule, example_inputs): self.graphs.append(gm) self.example_inputs.append(example_inputs) return gm GLOBAL_TEST_SUBCLASSES = { MockSubclass, DummyNDim, SigmoidToExpSubclass, BaseTorchFunction, } # Returns True if the function recompiles between inputs1 and inputs2 with the # specified dynamic setting. 
def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True): compile_count = [0] def counter(gm, example_inputs): compile_count[0] += 1 return gm compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic) compiled_f(*inputs1) compiled_f(*inputs2) return compile_count[0] > 1 class SubclassTests(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): super().setUpClass() cls._exit_stack.enter_context( torch._dynamo.config.patch( "traceable_tensor_subclasses", GLOBAL_TEST_SUBCLASSES ) ) @classmethod def tearDownClass(cls): cls._exit_stack.close() def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles): _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles) def test_no_call_to_new(self): class BadNewTorchFunction(torch.Tensor): def __new__(cls, *args, **kwargs): raise RuntimeError("Oops!") @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} return super().__torch_function__(func, types, args, kwargs) with torch._dynamo.config.patch( "traceable_tensor_subclasses", {BadNewTorchFunction} ): @torch.compile(backend="eager", fullgraph=True) def fn(x): return torch.add(x, 1) input = torch.ones(2, 2).as_subclass(BadNewTorchFunction) res = fn(input) self.assertIsInstance(res, BadNewTorchFunction) def test_no_torch_function_recompiles(self): class NJT: def __repr__(self): return f"NJT(shape={self.shape})" def __init__(self, values, offsets): self._values = values self._offsets = offsets def sin(self): return torch.sin(self) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} if func == torch.sin: self = args[0] return NJT(func(self._values), self._offsets) raise AssertionError("should not get here") values1 = torch.randn(10, 3, 4, requires_grad=True) values2 = torch.randn(10, 3, 4, requires_grad=True) offsets = torch.tensor([0, 3, 10]) njt1 = NJT(values1, offsets) njt2 = NJT(values2, offsets) @torch.compile(backend="eager", fullgraph=True) def f(x): return torch.sin(x) with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): f(njt1) f(njt2) def test_base_torch_function_tracing(self): def fn(x): return torch.add(x, 1) input = torch.ones(2, 2).as_subclass(BaseTorchFunction) out = fn(input) out_opt = compile_full_eager(fn)(input) self.assertIsInstance(out, BaseTorchFunction) self.assertEqual(out, out_opt) def test_torch_function_state_graph_break(self): @torch.compile(backend="eager") def fn(x): with torch._C.DisableTorchFunctionSubclass(): torch._dynamo.graph_break() return torch._C._is_torch_function_enabled(), torch.add(x, 1.0) input = torch.ones(2, 2) res, _ = fn(input) self.assertFalse(res) def test_torch_function_state_nested(self): @torch.compile(backend="eager") def fn(x): with torch._C.DisableTorchFunctionSubclass(): with torch._C.DisableTorchFunctionSubclass(): x = x + 1 # Should reset to the outer state (disabled) after exiting ctx manager return torch._C._is_torch_function_enabled(), torch.add(x, 1.0) input = torch.ones(2, 2) res, _ = fn(input) self.assertFalse(res) def test_torch_function_state_tracing(self): @torch.compile(backend="eager", fullgraph=True) def fn(x): with torch._C.DisableTorchFunctionSubclass(): torch.add(x, 1.0) input = torch.ones(2, 2) res = fn(input) def test_torch_function_state_guards(self): cnt = torch._dynamo.testing.CompileCounter() @torch.compile(backend=cnt, fullgraph=True) def fn(x): torch.add(x, 1.0) input = torch.ones(2, 2) with torch._C.DisableTorchFunctionSubclass(): res 
= fn(input) res = fn(input) self.assertEqual(cnt.frame_count, 2) def test_return_subclass(self): @torch.compile(backend="eager", fullgraph=True) def fn(x): return MockSubclass(torch.add(x, 1.0)) input = torch.ones(2, 2) res = fn(input) self.assertIsInstance(res, MockSubclass) def test_return_as_subclass(self): @torch.compile(backend="eager", fullgraph=True) def fn(x): return torch.add(x, 1.0).as_subclass(MockSubclass) input = torch.ones(2, 2) res = fn(input) self.assertIsInstance(res, MockSubclass) def test_return_local_subclass(self): class LocalSubclass(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} return func(*args, **kwargs) with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): @torch.compile(backend="eager", fullgraph=True) def fn(x): return LocalSubclass(torch.add(x, 1.0)) input = torch.ones(2, 2) res = fn(input) self.assertIsInstance(res, LocalSubclass) def test_torch_function_list_args(self): HANDLED_FUNCTIONS = {} class MyClass: def __init__(self, foo): self.foo = foo @classmethod def __torch_function__( cls, func, types, args=(), kwargs=None, ): if kwargs is None: kwargs = {} if func not in HANDLED_FUNCTIONS or not all( # noqa: C419 [ # noqa: C419 issubclass(t, (torch.Tensor, MyClass)) for t in types ] ): return NotImplemented return HANDLED_FUNCTIONS[func](*args, **kwargs) def _stack(input, dim=0, *, out=None): return MyClass(sum([x.foo for x in input])) HANDLED_FUNCTIONS[torch.stack] = _stack @torch.compile(backend="eager", fullgraph=True) def fn(v0, v1): return torch.stack([v0, v1]) ret = fn(MyClass(1), MyClass(1)) self.assertEqual(ret.foo, 2) @parametrize( "comparison", [ subtest(isinstance, "isinstance"), subtest(lambda instance, type_: type(instance) == type_, "equality"), subtest(lambda instance, type_: type(instance) is type_, "identity"), ], ) @parametrize( "input_type", [ subtest(torch.Tensor, "tensor"), subtest(DummyNDim, "subclass"), ], ) def test_type_check(self, comparison, input_type): with torch._dynamo.config.patch("traceable_tensor_subclasses", {DummyNDim}): def fn(x): if comparison(x, DummyNDim): return torch.ones(1, 1) else: return torch.zeros(2, 2) input = torch.ones(2, 2).as_subclass(input_type) exp_res = fn(input) act_res = torch.compile(backend="eager", fullgraph=True)(fn)(input) self.assertEqual(exp_res, act_res) def test_torch_function_call_on_method(self): x = torch.ones(2, 2) y = torch.ones(2, 2) z = torch.ones(2, 2) wrapped = x.as_subclass(SigmoidToExpSubclass) wrapped2 = y.as_subclass(SigmoidToExpSubclass) def fn(w): return w.sigmoid() fn_opt = compile_full_eager(fn) res_exp = fn(wrapped) res_act = fn_opt(wrapped2) res_exp2 = z.exp() self.assertEqual(res_exp, res_act) self.assertEqual(res_exp, res_exp2) def test_user_overidden_method_unsupported(self): class LocalSubclass(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} return super().__torch_function__(func, types, args, kwargs) def sigmoid(self): return None @torch.compile(backend="eager", fullgraph=True) def fn(x): x.sigmoid() msg = ( "Accessing overridden method/attribute sigmoid on a tensor" " subclass with a __torch_function__ override is not supported" ) with torch._dynamo.config.patch( "traceable_tensor_subclasses", {LocalSubclass} ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): x = torch.ones(2, 2).as_subclass(LocalSubclass) fn(x) def test_user_overidden_attr_unsupported(self): class 
LocalSubclass(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} return super().__torch_function__(func, types, args, kwargs) ndim = 10 @torch.compile(backend="eager", fullgraph=True) def fn(x): return x.ndim msg = ( "Accessing overridden method/attribute ndim on a tensor" " subclass with a __torch_function__ override is not supported" ) with torch._dynamo.config.patch( "traceable_tensor_subclasses", {LocalSubclass} ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): x = torch.ones(2, 2).as_subclass(LocalSubclass) fn(x) def test_user_overidden_property_unsupported(self): class LocalSubclass(torch.Tensor): def __init__(self) -> None: self._ndim = 10 @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} return super().__torch_function__(func, types, args, kwargs) @property def ndim(self): return self._ndim @ndim.setter def ndim(self, value): self._ndim = value @torch.compile(backend="eager", fullgraph=True) def fn(x): return x.ndim msg = ( "Accessing overridden method/attribute ndim on a tensor" " subclass with a __torch_function__ override is not supported" ) with torch._dynamo.config.patch( "traceable_tensor_subclasses", {LocalSubclass} ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): x = torch.ones(2, 2).as_subclass(LocalSubclass) fn(x) def test_overridden_method_guarding(self): class LocalSubclass(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} return super().__torch_function__(func, types, args, kwargs) @torch.compile(backend="eager") def fn(x): return x.sigmoid() with torch._dynamo.config.patch( error_on_recompile=True, traceable_tensor_subclasses={LocalSubclass} ): x = torch.ones(2, 2).as_subclass(LocalSubclass) fn(x) fn(x) x = torch.ones(2, 2).as_subclass(LocalSubclass) fn(x) with torch._dynamo.config.patch( traceable_tensor_subclasses={LocalSubclass} ), self.assertRaisesRegex( TypeError, "'bool' object is not callable", ): LocalSubclass.sigmoid = False fn(x) def test_torch_function_call_on_attr(self): x = torch.ones(2, 2) wrapped = x.as_subclass(DummyNDim) def fn(w): return w.ndim + torch.ones(2) fn_opt = compile_full_eager(fn) res_exp = fn(wrapped) res_act = fn_opt(wrapped) self.assertEqual(res_exp, res_act) self.assertEqual(res_exp, torch.ones(2) + 10) def test_torch_function_wrapper_class(self): x = torch.ones(2, 2) wrapped = WrapperSubclass(x) def fn(w): return torch.add(w, 1.0) fn_opt = compile_full_eager(fn) res_exp = fn(wrapped) res_act = fn_opt(wrapped) self.assertEqual(res_exp, res_act) def test_torch_function_wrapper_class_with_kwargs(self): x = torch.ones(2, 2) wrapped = WrapperSubclass(x) def fn(w): return torch.add(w, 1.0, alpha=2.0) fn_opt = compile_full_eager(fn) res_exp = fn(wrapped) res_act = fn_opt(wrapped) self.assertEqual(res_exp, res_act) def test_tensor_subclass_custom_attr(self): class AttrSubclass(torch.Tensor): x: int = 10 @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} return super().__torch_function__(func, types, args, kwargs) @torch.compile(backend="eager", fullgraph=True) def fn(x): return x.x + torch.ones(2, 2) with traceable_subclass(AttrSubclass): input = torch.ones(2, 2).as_subclass(AttrSubclass) fn_opt = compile_full_eager(fn) res_exp = fn(input) res_act = fn_opt(input) self.assertEqual(res_exp, res_act) def test_compile_with_fake_tensor_dynamic_dim(self): x = 
torch.randn([3, 4]) def f(x): return torch.sin(x) def test_dynamic_dim(f, x, dim_dynamic, exp_frame_count, exp_op_count): torch._dynamo.reset() cnt = torch._dynamo.testing.CompileCounter() opt_f = torch.compile(f, backend=cnt, fullgraph=True) x1 = torch.rand_like(x) f(x) f(torch.randn([4, 3])) shape_env = ShapeEnv() with torch._subclasses.fake_tensor.FakeTensorMode( shape_env=shape_env ) as fake_mode: x_fake = fake_mode.from_tensor( x, symbolic_context=StatelessSymbolicContext( dynamic_sizes=[dim_dynamic for i in range(x.dim())] ), ) x1_fake = fake_mode.from_tensor( x1, symbolic_context=StatelessSymbolicContext( dynamic_sizes=[dim_dynamic for i in range(x.dim())] ), ) opt_f(x_fake) opt_f(x1_fake) self.assertEqual(cnt.frame_count, exp_frame_count) self.assertEqual(cnt.op_count, exp_op_count) test_dynamic_dim(f, x, DimDynamic.DYNAMIC, 1, 1) test_dynamic_dim(f, x, DimDynamic.DUCK, 1, 1) test_dynamic_dim(f, x, DimDynamic.STATIC, 1, 1) def test_compile_with_fake_tensor_automatic_dynamic(self): def f(x): return torch.sin(x) def test_automatic_dynamic(f, inps, dim_dynamic, exp_frame_count, exp_op_count): torch._dynamo.reset() cnt = torch._dynamo.testing.CompileCounter() opt_f = torch.compile(f, backend=cnt, fullgraph=True) shape_env = ShapeEnv() with torch._subclasses.fake_tensor.FakeTensorMode( shape_env=shape_env ) as fake_mode: for inp in inps: fake_inp = fake_mode.from_tensor( inp, symbolic_context=StatelessSymbolicContext( [dim_dynamic for i in range(x.dim())] ), ) opt_f(fake_inp) self.assertEqual(cnt.frame_count, exp_frame_count) self.assertEqual(cnt.op_count, exp_op_count) x = torch.randn([3, 4]) y = torch.randn([4, 5]) z = torch.randn([5, 6]) a = torch.randn([3, 5]) b = torch.randn([4, 4]) # When inputs' DimDynamic is DYNAMIC or DUCK, the inputs # to opt_f will be tensors with SymInt sizes. Dynamo will treat input # as dynamic automatically and will only compile once for dim_dynamic in [DimDynamic.DYNAMIC, DimDynamic.DUCK]: test_automatic_dynamic(f, [x, y, z], dim_dynamic, 1, 1) test_automatic_dynamic(f, [x, a, z], dim_dynamic, 1, 1) test_automatic_dynamic(f, [x, b, z], dim_dynamic, 1, 1) for dim_dynamic in [DimDynamic.STATIC]: # Recompile once, first with dim 0 and 1 become Dynamic test_automatic_dynamic(f, [x, y, z], dim_dynamic, 2, 2) # Recompile 2 times, first with dim 1 become Dynamic, second with dim 0 becomes Dynamic. test_automatic_dynamic(f, [x, a, z], dim_dynamic, 3, 3) # Recompile 2 times, first with dim 0 become Dynamic, second with dim 1 becomes Dynamic. 
test_automatic_dynamic(f, [x, b, z], dim_dynamic, 3, 3) def test_compile_with_functionalization(self): x = torch.randn([3, 4]) x_clone = x.clone() x_clone2 = x.clone() backend = EagerRecordGraphAndInputs() cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) @torch.compile(backend=cnt, fullgraph=True) def f(x): return x.add_(1.0) + torch.nn.functional.relu_(x) f_out = f(x) self.assertEqual(cnt.frame_count, 1) self.assertEqual(cnt.op_count, 3) self.assertEqual(len(backend.graphs), 1) self.assertEqual(len(backend.example_inputs), 1) actual = normalize_gm(backend.graphs[0].print_readable(print_output=False)) self.assertExpectedInline( actual, """\ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3, 4]"): l_x_ = L_x_ add_: "f32[3, 4]" = l_x_.add_(1.0) relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None return (add,) """, ) ff = torch.func.functionalize(f) ff_out = ff(x_clone) self.assertEqual(cnt.frame_count, 2) self.assertEqual(cnt.op_count, 6) self.assertEqual(len(backend.graphs), 2) self.assertEqual(len(backend.example_inputs), 2) actual = normalize_gm(backend.graphs[1].print_readable(print_output=False)) self.assertExpectedInline( actual, """\ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3, 4]"): l_x_ = L_x_ add_: "f32[3, 4]" = l_x_.add_(1.0) relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None return (add,) """, ) self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0])) # Cannot re-use the version from AOTAutograd, since that uses python functional tensors. def to_fun(x): x_functional = torch._to_functional_tensor(x) torch._mirror_autograd_meta_to(x, x_functional) return x_functional def aot_f_wrapper(func): @functools.wraps(func) def wrapper(*args, **kwargs): torch._enable_functionalization(reapply_views=False) try: func_args = pytree.tree_map(to_fun, args) func_kwargs = pytree.tree_map(to_fun, kwargs) return func(*func_args, **func_kwargs) finally: torch._disable_functionalization() return wrapper aot_ff = aot_f_wrapper(f) aot_ff_out = aot_ff(x_clone2) self.assertEqual(cnt.frame_count, 3) self.assertEqual(cnt.op_count, 9) self.assertEqual(len(backend.graphs), 3) self.assertEqual(len(backend.example_inputs), 3) actual = normalize_gm(backend.graphs[2].print_readable(print_output=False)) self.assertExpectedInline( actual, """\ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3, 4]"): l_x_ = L_x_ add_: "f32[3, 4]" = l_x_.add_(1.0) relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None return (add,) """, ) self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0])) self.assertEqual(f_out, ff_out) self.assertEqual(f_out, aot_ff_out) try: torch._enable_functionalization(reapply_views=False) xf = pytree.tree_map(to_fun, x) x_view = xf.t() with self.assertRaisesRegex(RuntimeError, "Cannot safely fakify a view"): f(x_view) finally: torch._disable_functionalization() def test_compile_higher_order_with_functionalization(self): backend = EagerRecordGraphAndInputs() cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) @torch.compile(backend=cnt, fullgraph=True) def f(x): return wrap(lambda x: x.add_(1.0), x) def check_count_and_graph( exp_frame_count, exp_op_count, exp_n_graph, exp_graph ): self.assertEqual(cnt.frame_count, exp_frame_count) self.assertEqual(cnt.op_count, exp_op_count) self.assertEqual(len(backend.graphs), exp_n_graph) actual = 
normalize_gm( backend.graphs[exp_n_graph - 1].print_readable(print_output=False) ) self.assertExpectedInline(actual, exp_graph, skip=1) t = torch.randn([3, 4]) t_clone = t.clone() t_clone2 = t.clone() f(t) check_count_and_graph( 1, 2, 1, """\ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3, 4]"): l_x_ = L_x_ wrap_body_0 = self.wrap_body_0 wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None getitem: "f32[3, 4]" = wrap[0]; wrap = None return (getitem,) class wrap_body_0(torch.nn.Module): def forward(self, l_x_: "f32[3, 4]"): add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None return (add_,) """, ) ff = torch.func.functionalize(f) ff_out = ff(t_clone) # frame count and op count are incremented due to re-compilation check_count_and_graph( 2, 4, 2, """\ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3, 4]"): l_x_ = L_x_ wrap_body_0 = self.wrap_body_0 wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None getitem: "f32[3, 4]" = wrap[0]; wrap = None return (getitem,) class wrap_body_0(torch.nn.Module): def forward(self, l_x_: "f32[3, 4]"): add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None return (add_,) """, ) try: x = torch._to_functional_tensor(t_clone2) torch._mirror_autograd_meta_to(t_clone2, x) torch._enable_functionalization(reapply_views=False) aot_f_out = f(x) finally: torch._disable_functionalization() # frame count and op count are incremented due to re-compilation check_count_and_graph( 3, 6, 3, """\ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3, 4]"): l_x_ = L_x_ wrap_body_0 = self.wrap_body_0 wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None getitem: "f32[3, 4]" = wrap[0]; wrap = None return (getitem,) class wrap_body_0(torch.nn.Module): def forward(self, l_x_: "f32[3, 4]"): add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None return (add_,) """, ) def test_has_torch_function(self): class MyTensor: @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} if func is torch.max: return torch.tensor(123) return func(*args, **kwargs) class LocalSubclass(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} return func(*args, **kwargs) def fn(x): return torch.overrides.has_torch_function_unary( x ), torch.overrides.has_torch_function_variadic(x) for test_class in [MyTensor, LocalSubclass]: x = test_class() ref0 = fn(x) ref1 = fn(4) opt_fn = torch._dynamo.optimize("eager")(fn) res0 = opt_fn(x) res1 = opt_fn(4) self.assertEqual(ref0, res0) self.assertEqual(ref1, res1) def test_wrapper_subclass_guards_on_inner_tensor(self): # Holds an inner tensor, that has a distinct shape from the outer wrapper tensor. # Also adds additional guards on the inner tensor's sizes. # When the first input to an op has x.shape[0] > 5, we insert an extra add node. class DoubleSizeMaybeAddGeThreeTensor(torch.Tensor): @staticmethod def __new__(cls, inner): # Double the outer-most dimension outer_shape = (inner.shape[0] * 2,) + inner.shape[1:] return torch.Tensor._make_wrapper_subclass( # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great. # Calling the overload that has kwargs causes us to go down the first overload path, # which will **always** specialize sizes. # We should probably eventually fix this so that the first overload can just handle dynamic shapes. 
cls, outer_shape, inner.stride(), None, None, inner.dtype, inner.layout, inner.device, False, inner.requires_grad, ) def __init__(self, inner): self.inner_elem = inner def __tensor_flatten__(self): return ["inner_elem"], None @staticmethod def __tensor_unflatten__(inner_tensors, _, outer_size, outer_stride): return DoubleSizeMaybeAddGeThreeTensor(inner_tensors["inner_elem"]) def __repr__(self): return f"DoubleSizeMayberAddGeThreeTensor({repr(self.inner_elem)})" @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} args_inner = torch.utils._pytree.tree_map_only( DoubleSizeMaybeAddGeThreeTensor, lambda x: x.inner_elem, args ) out_inner = func(*args_inner, **kwargs) # Add guards on the inner tensor's sizes if args_inner[0].shape[0] > 3: out_inner += 2 return DoubleSizeMaybeAddGeThreeTensor(out_inner) curr_var_to_val = None curr_var_to_sources = None guards = None def backend(gm, args): context = torch._guards.TracingContext.get() # Grab info on sources and guards from the shapeenv nonlocal curr_var_to_val nonlocal curr_var_to_sources nonlocal guards guards = [str(g.expr) for g in context.fake_mode.shape_env.guards] curr_var_to_val = { str(k): v for k, v in context.fake_mode.shape_env.var_to_val.items() } curr_var_to_sources = { str(k): v[0].name() for k, v in context.fake_mode.shape_env.var_to_sources.items() } return gm @torch.compile(backend=backend) def fn(x): if x.shape[0] < 13: return torch.mul(x, x) else: return torch.div(x, x) inp = torch.ones(4, 4) x = DoubleSizeMaybeAddGeThreeTensor(inp) torch._dynamo.mark_dynamic(x, 0) res = fn(x) # During fakeifying, we end up allocating a separate symint # for the outer and inner tensor (in this test, s0 is unused). expected_var_to_val = { "s0": 8, "s1": 4, } expected_var_to_sources = { "s0": "L['x'].size()[0]", "s1": "L['x'].inner_elem.size()[0]", } self.assertEqual(curr_var_to_val, expected_var_to_val) self.assertEqual(curr_var_to_sources, expected_var_to_sources) self.assertExpectedInline( "\n".join(guards), """\ Eq(2*s1, s0) 2*s1 < 13 s1 > 3""", ) def test_wrapper_subclass_with_same_sized_inner_tensor(self): # shouldn't recompile for different sizes when dynamic=True sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6)) sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(7)) self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=True)) # should recompile for different data size when dynamic=False sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6)) sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6)) self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) # avoid recompile using manual mark_dynamic() for different data size sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6)) # NB: mark_dynamic() on outer tensor should translate to inner tensors of the same size torch._dynamo.mark_dynamic(sub1, 0) torch._dynamo.mark_dynamic(sub1, 1) sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6)) self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) def test_wrapper_subclass_with_differently_sized_inner_tensor(self): # should recompile for different scale size when dynamic=False sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3)) sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5)) self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) # still recompiles using manual mark_dynamic() on outer for different scale size sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3)) # NB: mark_dynamic() on outer 
tensor doesn't translate to inner tensors of different size torch._dynamo.mark_dynamic(sub1, 0) torch._dynamo.mark_dynamic(sub1, 1) sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5)) self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) def test_recompiles_with_optional_inner_tensor(self): def f(x): return x + 1 # sub1 does not have the optional tensor specified while sub2 does sub1 = OptionalScaledTensor(torch.randn(2, 4), None) sub2 = OptionalScaledTensor(torch.randn(2, 4), torch.randn(2, 4)) # sanity check; don't recompile for same input self.assertFalse(_recompiles_for_inputs(f, (sub1,), (sub1,), dynamic=True)) self.assertFalse(_recompiles_for_inputs(f, (sub2,), (sub2,), dynamic=True)) # these should recompile; optional tensor changes between specified and unspecified self.assertTrue(_recompiles_for_inputs(f, (sub1,), (sub2,), dynamic=True)) self.assertTrue(_recompiles_for_inputs(f, (sub2,), (sub1,), dynamic=True)) f_compiled = torch.compile(f, backend="aot_eager") self.assertEqual(f(sub1)._data, f_compiled(sub1)._data) self.assertEqual(f(sub2)._data, f_compiled(sub2)._data) def test_torch_dispatch_subclass_guard_recompile(self): x = torch.ones(2, 2) x_two = TwoTensor(x.clone(), x.clone()) def fn(w): return torch.add(w, 1.0) fn_opt = torch.compile(backend="eager")(fn) ref = fn(x_two) res = fn_opt(x_two) self.assertEqual(ref, res) # ensure no recompilation on same input type with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): fn_opt(TwoTensor(x + 1, x + 2)) # recompile! ref = fn(x) res = fn_opt(x) self.assertEqual(ref, res) def test_tensor_subclass_ctx_guards(self): x = CtxSubclassTensor(torch.ones(2), 3) x2 = CtxSubclassTensor(torch.ones(2), 3) x3 = CtxSubclassTensor(torch.ones(2), 4) _check_recompiles(self, lambda x: x * x, (x,), (x2,), False) _check_recompiles(self, lambda x: x * x, (x,), (x3,), True) def test_tensor_subclass_ctx_recursive_guards(self): x0 = torch.ones(2, 2) x1 = CtxSubclassTensor(x0.clone(), 2) x2 = CtxSubclassTensor(x0.clone(), 3) tt0 = TwoTensor(x0.clone(), x1) tt1 = TwoTensor(x0.clone(), x2) _check_recompiles(self, lambda x: x * x, (tt0,), (tt1,), True) def test_tensor_subclass_ctx_custom_guards_override(self): class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor): @classmethod def __metadata_guard__(cls, orig_data, other): return orig_data[0] <= other[0] x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 2) x2 = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3) x3 = CtxSubclassTensorCustomGuardFn(torch.ones(2), 1) _check_recompiles(self, lambda x: x * x, (x,), (x2,), False) _check_recompiles(self, lambda x: x * x, (x,), (x3,), True) def test_tensor_subclass_ctx_custom_guards_error_arg_num(self): import torch._dynamo.exc class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor): @classmethod def __metadata_guard__(cls, y): # Shouldn't reach here return False x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3) self.assertRaisesRegex( torch._dynamo.exc.InternalTorchDynamoError, "Tensor subclass method __metadata_guard__ must take exactly two subclass metadata arguments", lambda: torch.compile(lambda x: x * x)(x), ) def test_tensor_subclass_ctx_custom_guards_error_not_classmethod(self): import torch._dynamo.exc class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor): def __metadata_guard__(self, x, y): return False x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3) self.assertRaisesRegex( torch._dynamo.exc.InternalTorchDynamoError, "Tensor subclass method __metadata_guard__ must be a classmethod", lambda: 
torch.compile(lambda x: x * x)(x), ) def test_subclass_constructor_proxying(self): import dataclasses from collections import namedtuple from typing import Any @dataclasses.dataclass(frozen=True) class SubclassTensorArgs: original_shape: torch.Size device: torch.device inner_meta: Any SubclassTensorArgs2 = namedtuple( "SubclassTensorArgs2", [ "original_shape", "device", "inner_meta", ], ) class SubclassTensor(torch.Tensor): @staticmethod def __new__(cls, a, meta): shape = a.shape kwargs = {} kwargs["strides"] = a.stride() kwargs["storage_offset"] = a.storage_offset() kwargs["device"] = a.device kwargs["layout"] = a.layout kwargs["requires_grad"] = a.requires_grad kwargs["dtype"] = a.dtype out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) return out def __init__(self, a, meta): self.a = a self.meta = meta def __repr__(self): a_repr = repr(self.a) return f"SubclassTensor({a_repr})" def __tensor_flatten__(self): return ["a"], self.meta @staticmethod def __tensor_unflatten__(inner_tensors, meta, _, __): a = inner_tensors["a"] return SubclassTensor(a, meta) @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): if kwargs is None: kwargs = {} args_a = pytree.tree_map( lambda x: x.a if isinstance(x, SubclassTensor) else x, args ) kwargs_a = pytree.tree_map( lambda x: x.a if isinstance(x, SubclassTensor) else x, kwargs ) out_a = func(*args_a, **kwargs_a) out = pytree.tree_map( lambda x: SubclassTensor( x, SubclassTensorArgs2(x.shape, x.device, None) ) if isinstance(x, torch.Tensor) else x, out_a, ) return return_and_correct_aliasing(func, args, kwargs, out) @torch.compile(fullgraph=True) def f1(x): meta = SubclassTensorArgs( x.shape, x.device, SubclassTensorArgs(x.shape, x.device, None) ) out = SubclassTensor(x, meta) return out * out x = torch.randn(3, 3) f1(x) @torch.compile(fullgraph=True) def f1(x): meta = SubclassTensorArgs2( x.shape, x.device, SubclassTensorArgs2(x.shape, x.device, None) ) out = SubclassTensor(x, meta) return out * out x = torch.randn(3, 3) f1(x) def test_torch_function_subclass_survives_into_aot_autograd(self): # If you have a tensor subclass that relies on dispatch into the same op # without unwrapping and calling torch._C.DisableTorchFunctionSubclass(), # the torch function-ness will survive into AOTAutograd. Today, NestedTensor # actually relies on this behavior! Because that torch function logic # runs during AOTAutograd, this test tests that there is no logic below # that relies torch function that gets unexpectedly disabled after we # redispatch from the subclass's torch function. 
class SubTensor(torch.Tensor): @staticmethod def __new__(cls, t): return torch.Tensor._make_wrapper_subclass( cls, t.shape, t.stride(), t.storage_offset(), torch.contiguous_format, t.dtype, torch.strided, t.device, False, t.requires_grad, "sizes", False, False, None, ) def __init__(self, t): super().__init__() self._t = t def __tensor_flatten__(self): return ["_t"], {} @staticmethod def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): t = inner_tensors["_t"] return SubTensor(t) def __repr__(self): return f"SubTensor({self._t})" @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} with torch._C.DisableTorchFunctionSubclass(): return func(*args, **kwargs) @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): kwargs = {} if kwargs is None else kwargs new_args = pytree.tree_map_only(SubTensor, lambda s: s._t, args) output = func(*new_args, **kwargs) output = pytree.tree_map_only( torch.Tensor, lambda t: SubTensor(t), output ) return output @torch.compile(dynamic=True) def f(x): return x.unflatten(-1, [2, 5]) s = SubTensor(torch.randn(3, 10)) f(s) # Guard validation upsets the guard # https://github.com/pytorch/pytorch/issues/129936 @unittest.expectedFailure def test_recompile_with_symbool_inputs(self): def f(pred: bool): if pred: return torch.ones([3, 4]) else: return torch.ones([4, 3]) def test_recompilation( f, x, sizes, exp_graphs, exp_frame_count, exp_shape_env_guards ): torch._dynamo.reset() shape_env = ShapeEnv() backend = torch._dynamo.testing.EagerAndRecordGraphs() cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) f_cond = torch.compile(f, backend=cnt, fullgraph=True) with torch._subclasses.fake_tensor.FakeTensorMode( shape_env=shape_env ) as fake_mode: fake_inp = fake_mode.from_tensor( x, symbolic_context=StatelessSymbolicContext( dynamic_sizes=[DimDynamic.DYNAMIC for i in range(x.dim())] ), ) for i, size in enumerate(sizes): pred = fake_inp.size(0) == size f_cond(pred) actual = normalize_gm( backend.graphs[exp_frame_count[i] - 1].print_readable( print_output=False ) ) actual_guard_str = [str(guard.expr) for guard in shape_env.guards] self.assertExpectedInline(actual, exp_graphs[i]) self.assertEqual(cnt.frame_count, exp_frame_count[i]) self.assertEqual(actual_guard_str, exp_shape_env_guards[i]) true_graph = """\ class GraphModule(torch.nn.Module): def forward(self): ones: "f32[3, 4]" = torch.ones([3, 4]) return (ones,) """ false_graph = """\ class GraphModule(torch.nn.Module): def forward(self): ones: "f32[4, 3]" = torch.ones([4, 3]) return (ones,) """ test_recompilation( f, torch.randn([3, 4]), [3, 3, 4, 5], exp_graphs=[true_graph, true_graph, false_graph, false_graph], exp_frame_count=[1, 1, 2, 2], exp_shape_env_guards=[ [], # s0 is specialized and guarded in outter shape_env when dynamo checks the guards ["Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)"], [ "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)", ], [ "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)", "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)", ], ], ) test_recompilation( f, torch.randn([3, 4]), [4, 5, 3, 3], exp_graphs=[false_graph, false_graph, true_graph, true_graph], exp_frame_count=[1, 1, 2, 2], exp_shape_env_guards=[ [], # s0 is specialized and guarded in outter shape_env when dynamo checks the guards ["Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)"], [ "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)", "Eq(Piecewise((1, Eq(s0, 
3)), (0, True)), 1)", ], [ "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)", "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", ], ], ) def test_wrapper_subclass_dynamo_attribute_access_on_intermediate(self): def f(x_subclass): tmp_subclass = torch.add(x, 1) return torch.mul(tmp_subclass._scale, tmp_subclass._constant) x = ScaledTensor(torch.randn(2, 4), torch.randn(3), constant=2) out_ref = f(x) out_test = torch.compile(f, backend="aot_eager", fullgraph=True)(x) self.assertEqual(out_ref, out_test) def test_support_bases(self): import abc import torch.fx._symbolic_trace class Meta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta): def __new__(cls, name, bases, dct): x = super().__new__(cls, name, bases, dct) x.attr = 100 return x class Multistreamable(abc.ABC): # noqa: B024 pass class Foo(Multistreamable, metaclass=Meta): pass @torch.compile(backend="eager", fullgraph=True) def f(x): typ = type(Foo()) typ.__bases__ return typ.__bases__ self.assertEqual(f(torch.randn(1)), (Multistreamable,)) @torch.compile(backend="eager", fullgraph=True) def g(x): typ = type(Foo()) typ.__base__ return typ.__base__ self.assertEqual(g(torch.randn(1)), Multistreamable) @parametrize("dynamic", [False, True]) def test_subclass_views(self, dynamic): def _get_views(t): # returns (view: Tensor, expects_raises_false) # Note that any closed-over SymInts will be symbolicized during fake-ification. yield t.narrow(dim=-1, start=3, length=8), False yield t.split(5, -1)[2], False yield t.split_with_sizes([9, 6], -1)[1], False yield t.unsqueeze(-1).expand(4, 15, 10), False yield t.select(-1, 6), False # https://github.com/pytorch/pytorch/issues/128649 yield t[2:3, 5:9], dynamic yield t.view(-1, 15), False def f(x): return x * 2 compiled_f = torch.compile( f, backend="aot_eager", fullgraph=True, dynamic=dynamic ) # Take a view of a subclass to pass as input. 
t = TwoTensor(torch.randn(4, 15), torch.randn(4, 15)) for view, expects_raises in _get_views(t): torch._dynamo.reset() out_ref = f(view) if expects_raises: with self.assertRaises(AssertionError): out_test = compiled_f(view) else: out_test = compiled_f(view) self.assertEqual(out_ref, out_test) @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) def test_mark_static_with_subclass_desugaring(self): from typing import Any, Callable, Dict, List, Optional from torch._dynamo.decorators import mark_static_address from torch._inductor.compile_fx import compile_fx from torch._inductor.cudagraph_utils import BoxedDeviceIndex from torch._inductor.utils import BoxedBool x_inner = torch.ones(4) x = TwoTensor(x_inner, x_inner) mark_static_address(x, guard=False) def inner_compile( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], cudagraphs: Optional[BoxedBool] = None, static_input_idxs: Optional[List[int]] = None, is_backward: bool = False, graph_id: Optional[int] = None, cpp_wrapper: bool = False, aot_mode: bool = False, is_inference: bool = False, boxed_forward_device_index: Optional[BoxedDeviceIndex] = None, user_visible_outputs: Optional[Dict[str, None]] = None, layout_opt: Optional[bool] = None, extern_node_serializer: Optional[Callable[[List[Any]], Any]] = None, ): self.assertEqual(static_input_idxs, [1, 2]) return gm compiler = functools.partial(compile_fx, inner_compile=inner_compile) @torch.compile(backend=compiler) def fn(t0, t1, t2): return t0 + t1 + t2 + 2 fn(torch.ones(4), x, torch.ones(4)) instantiate_parametrized_tests(SubclassTests) class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase): def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True): return get_jagged_tensor(nested_size, offsets, requires_grad) def _get_nc_jagged_tensor(self, inner_dim, starts, lengths, requires_grad=True): # Makes a jagged tensor with N constituent tensors with size # as specified ((S0, S1, S2), D) max_dim = (starts + lengths).max() values_tensor = torch.randn( starts.shape[0], max_dim.item(), inner_dim, requires_grad=requires_grad, dtype=torch.float64, ) return jagged_from_tensor_and_lengths(values_tensor, starts, lengths) def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles): _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles) def test_unary_does_not_recompile(self): nt1, _ = self._get_jagged_tensor(((2, 3, 4), 3), None) nt2, _ = self._get_jagged_tensor(((3, 4, 5, 6), 4), None) self._check_recompiles(lambda nt1: nt1.sin(), (nt1,), (nt2,), False) def test_binary_does_not_recompile(self): def binary(nt1, nt2): if nt1.shape == nt2.shape: return nt1 + nt2 else: return nt1.sin() # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0). # This causes a recompile later on when it realizes the batch and last dim # should not always be equal. To avoid that, we use (3, j0, 5) here. 
nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None) nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets) nt3, offsets = self._get_jagged_tensor(((3, 4, 5), 4), None) nt4, _ = self._get_jagged_tensor(((3, 4, 5), 4), offsets) self._check_recompiles(binary, (nt1, nt2), (nt3, nt4), False) def test_binary_recompiles(self): def binary(nt1, nt2): if nt1.shape == nt2.shape: return nt1 + nt2 else: return nt1.sin() # Binary recompiles because singleton ints no longer match nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None) nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets) nt3, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True) def _validate_compile(self, fn, arg_fn): def _gen_grad_outputs(out_val): if isinstance(out_val, (list, tuple)): return tuple(torch.ones_like(c) for c in out_val) else: return (torch.ones_like(out_val),) with self.branch_nested_state(): from torch.nested._internal.nested_tensor import _tensor_symint_registry # Validate that compilation does not modify eager state registry_before = list(_tensor_symint_registry.items()) count_before = torch.nested._internal.nested_tensor._tensor_id_counter guards_exported = [] guards_failed = [] def append_guard_export(guards): for g in guards: if g.code_list is not None: guards_exported.append(g.code_list[0]) def append_guard_fail(guards): guards_failed.extend(guards) compiled = torch._dynamo.optimize( nopython=True, backend="aot_eager", guard_export_fn=append_guard_export, guard_fail_fn=append_guard_fail, )(fn) registry_after = list(_tensor_symint_registry.items()) count_after = torch.nested._internal.nested_tensor._tensor_id_counter self.assertEqual(registry_before, registry_after) self.assertEqual(count_before, count_after) args = arg_fn() compile_out = compiled(*args) compile_grads = [] g_args = [arg for arg in args if arg.requires_grad] if len(g_args) > 0: compile_grad_outputs = _gen_grad_outputs(compile_out) compile_grads = torch.autograd.grad( compile_out, inputs=g_args, grad_outputs=compile_grad_outputs ) with self.branch_nested_state(): args = arg_fn() ref_out = fn(*args) ref_grads = [] g_args = [arg for arg in args if arg.requires_grad] if len(g_args) > 0: ref_grad_outputs = _gen_grad_outputs(ref_out) ref_grads = torch.autograd.grad( ref_out, inputs=g_args, grad_outputs=ref_grad_outputs ) # Validate correctness forward if isinstance(compile_out, (list, tuple)): # TODO: Fix assertEqual() to support NJTs so this isn't necessary self.assertEqual(len(compile_out), len(ref_out)) for c, r in zip(compile_out, ref_out): self.assertEqualIgnoringNestedInts(c, r) else: self.assertEqualIgnoringNestedInts(compile_out, ref_out) # Validate correctness backward for compile_grad, ref_grad in zip(compile_grads, ref_grads): self.assertEqualIgnoringNestedInts(compile_grad, ref_grad) return guards_exported, guards_failed # Note: [What kind of guards are involved in nested tensor compilation] # # Until we implement UnionFind, dynamic shapes guards are not involved. # we rely only on dynamo's tensor aliasing guards. # # This is possible because dynamo able to generate tensor aliasing guards # not only for the outer tensor, but also for the inner tensor. # # The case where dynamic shapes guards would eventually come into play is # when my inputs are (1) two non-aliased tensors, but (2) declared as # equal using a "trust me assert equal" API. 
# Note: [Compiling nested tensor global state] # # Today there are two pieces of global eager state that NJTs deals with: # - tensor_id_counter: a global counter that assigns unique ids to tensors # - tensor_symint_registry: maps tensor to nested int # - this is used in eager only (we should get rid of this because it is # not necessary to cache nested int in eager) # - during tracing, we DO need to cache nested int, but we do so on # the FakeTensor. # # Ideally we would like to satisfy the following: # - (1) The eager state is not mutated during tracing # - (2) Running the compiled function should mutate the eager state in the # same way that running the eager function would # (a) The global counter should be incremented # (b) The registry is updated in the same way # # Today we can satisfy (1) and (2a) but cannot satisfy (2b) # # Today, (1) is satisfied because we maintain a separate counter during # tracing, and cache nested int on FakeTensor instead of relying on # tensor_symint_registry. # # (2) is cannot be completely satisfied because we trace away the # side-effectful operations (which we can fix this by wrapping the # side-effectful operations in a custom op, and threading through effect # tokens.) The current plan is to do that in the UnionFind impl. # # Interestingly, despite this, the state is mutated in a way that is somewhat # close to what we want, e.g. if I construct a nested tensor using an # offsets in the compiled region and return it, AOTAutograd runtime wrapper # must rewrap the inner->inner graph outputs back into subclass. This # triggers the eager logic to run, updating the counter and registry. # # Notably however, compile differs in two ways from eager: # (1) The order in which the offsets are assigned ids is differnet # the registry would be set in the order the offsets are returned # which is not necessarily the same order as they were constructed. # (2) If a NestedTensor is not returned, then the AOTAutograd wrapping # logic will not be triggered. # # I claim that correctness is not affected by these differences today. # e.g. there is never the case where two distinct offsets silently share # the same id. # # (1) is clearly not a problem, and (2) should only be a problem if # the nested int is returned on its own, without the corresponding NJT # being returned. This is not a problem in the current implementation # because returning only a shape is not supported! # Note: [Creating symbolic nested int] # # We must create a symbolic nested int when we construct a nested tensor # from a tensor. There are two main cases: # # 1. The offsets has NOT been used to construct a NJT # - Create a new plain nested int with current val of fake nt id counter # - Increment the fake nt id counter # - Create a new symint with plain nested int as hint # 2. The offsets HAS been used to construct a NJT # - Create a new symint with plain nested int as hint # # More details on case 2: # - During fakification of the offsets, we check the eager registry, and # if the tensor HAS been used to construct a NJT, # we create a symint, with the existing nested int as hint, and cache # it on to the FakeTensor. # # [ Always use ephemeral source ] # # We create the new symint ALWAYS with ephemeral source whether that is # in case (1) or (2) even though we could've had a proper source for case (2). # Using a proper source would enable a few more (edge) cases, but since # we plan to handle things more holistically in the future anyway, we don't # bother doing so today. 
# # Using an ephemeral source has some consequences. But we are happy if # - We do not silently miss recompiles, e.g. we guard when necessary. # We know that this is true, because dynamo guards alone are already # sufficient. # - We are not producing errors for the cases we care about # # The main case we care about is when we guard that two shapes are equal. # In this case, the replacements logic would simplify away the ephemeral # symbol, and there is no error produced. # The unsupported case is when we guard that two shapes are not equal, in # which, we will try and fail to generate a guard. # # Case 1: in-graph construction where the offsets are passed as inputs # def test_in_graph_construction_from_input(self): # The offsets is passed as an input def fn(values, offsets): return torch.nested.nested_tensor_from_jagged(values * 2, offsets) * 2 values = torch.randn(10, 5, requires_grad=True) offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) self._validate_compile(fn, arg_fn=lambda: (values, offsets)) # Do not specialize on the offsets with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): different_offsets = torch.tensor([0, 1, 5, 10], dtype=torch.int64) self._validate_compile(fn, arg_fn=lambda: (values, different_offsets)) def test_in_graph_construction_from_input_2(self): # Construct two NJTs, both are passed as inputs def fn(values, offsets1, offsets2): nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets1) nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2) return nt2, nt1 values = torch.randn(10, 5, requires_grad=True) offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64) # 1. Offsets are different guards_exported, guards_failed = self._validate_compile( fn, arg_fn=lambda: (values, offsets, offsets2) ) self.assertEqual(len(guards_failed), 0) self.assertNotIn("L['offsets1'] is L['offsets2']", guards_exported) # TODO # 2. 
        # 2. Offsets are the same
        new_guards_exported, _ = self._validate_compile(
            fn, arg_fn=lambda: (values, offsets, offsets)
        )
        self.assertTrue(any("Duplicate tensors found" in g for g in guards_failed))
        self.assertIn("L['offsets1'] is L['offsets2']", new_guards_exported)

        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            offsets3 = offsets.clone()
            self._validate_compile(fn, arg_fn=lambda: (values, offsets3, offsets3))

        # Do a binary op
        def fn(values, offsets, offsets2):
            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets)
            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2)
            return nt1 * nt2

        self._validate_compile(fn, arg_fn=lambda: (values, offsets, offsets))

    def test_in_graph_construction_from_input_4(self):
        # The offsets tensor is taken from an NJT input
        def fn(nt, other_values):
            nt2 = torch.nested.nested_tensor_from_jagged(other_values, nt.offsets())
            return nt + nt2

        values = torch.randn(9, 5, requires_grad=True)
        other_values = torch.randn(9, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64)

        def arg_fn(values=values, other_values=other_values, offsets=offsets):
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            return nt, other_values

        self._validate_compile(fn, arg_fn=arg_fn)

        # Do not specialize on the offsets
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            different_offsets = offsets.clone()

            def arg_fn(
                values=values, other_values=other_values, offsets=different_offsets
            ):
                nt = torch.nested.nested_tensor_from_jagged(values, different_offsets)
                return nt, other_values

            self._validate_compile(fn, arg_fn=arg_fn)

    def test_in_graph_construction_from_input_5(self):
        # Construct from lengths instead of offsets
        def fn(values, lengths):
            nt = torch.nested.nested_tensor_from_jagged(values, lengths=lengths)
            return nt.sin()

        values = torch.randn(9, 5, requires_grad=True)
        lengths = torch.tensor([2, 4, 3])
        self._validate_compile(fn, arg_fn=lambda: (values, lengths))
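    # A small worked example (illustrative only; not invoked by the suite, and
    # the helper name is hypothetical) of the relationship between `lengths` and
    # `offsets` used throughout these tests: offsets are the cumulative sum of
    # lengths with a leading zero, so lengths [2, 4, 3] correspond to offsets
    # [0, 2, 6, 9] and a `values` tensor with offsets[-1] == 9 rows.
    def _sketch_lengths_to_offsets(self):
        lengths = torch.tensor([2, 4, 3])
        offsets = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)])
        assert offsets.tolist() == [0, 2, 6, 9]
        values = torch.randn(int(offsets[-1]), 5)
        return torch.nested.nested_tensor_from_jagged(values, offsets)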
    #
    # Case 2: in-graph construction where offsets are graph intermediates
    #
    def test_in_graph_construction_from_intermediate(self):
        # offsets is an intermediate computed from lengths
        def fn(values, lengths):
            offsets = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)])
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            nt2 = torch.nested.nested_tensor_from_jagged(values, offsets)
            return (nt * nt2).sin()

        values = torch.randn(9, 5, requires_grad=True)
        lengths = torch.tensor([2, 4, 3])
        self._validate_compile(fn, arg_fn=lambda: (values, lengths))

        # Do not specialize on the lengths
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            different_lengths = lengths.clone()
            self._validate_compile(fn, arg_fn=lambda: (values, different_lengths))

    def test_in_graph_construction_from_intermediate_2(self):
        def fn(values, offsets):
            return torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone())

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
        self._validate_compile(fn, arg_fn=lambda: (values, offsets))

    def test_in_graph_construction_from_intermediate_3(self):
        # Note that due to CSE, clone is not necessarily called twice!
        def fn(values, offsets):
            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone())
            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets.clone())
            return nt2, nt1

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
        self._validate_compile(fn, arg_fn=lambda: (values, offsets))

    def test_in_graph_construction_from_intermediate_4(self):
        # Shared intermediate (should be same as case #1)
        def fn(values):
            offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            values2 = torch.ones_like(values)
            nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets)
            return nt * nt2

        values = torch.randn(10, 5).requires_grad_(True)
        self._validate_compile(fn, arg_fn=lambda: (values,))

    # AssertionError: s2 (could be from ['',
    @unittest.expectedFailure
    def test_in_graph_construction_from_intermediate_5(self):
        # non-shared intermediate
        def fn(values):
            offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            values2 = torch.ones_like(values)
            nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets.clone())
            if nt2.shape[1] != nt.shape[1]:
                return nt * 2
            else:
                return nt * 3

        values = torch.randn(10, 5).requires_grad_(True)
        self._validate_compile(fn, arg_fn=lambda: (values,))
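    # The expected failure above guards that two ragged dims are *not* equal,
    # which is the unsupported case described in [ Always use ephemeral source ].
    # The sketch below is the supported counterpart claimed by that note: an
    # equality guard between two ephemeral ragged dims should be simplified away
    # by the replacements logic. This is illustrative only, is not run by the
    # suite, and the helper name is made up.
    def _sketch_equal_ragged_dim_guard(self):
        def fn(values):
            offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            nt2 = torch.nested.nested_tensor_from_jagged(
                torch.ones_like(values), offsets.clone()
            )
            # Per the note above, guarding equality of the two ragged dims is
            # the supported case (contrast with the `!=` guard in the xfail).
            if nt.shape[1] == nt2.shape[1]:
                return nt * 3
            return nt * 2

        values = torch.randn(10, 5)
        return torch.compile(fn, backend="eager", fullgraph=True)(values)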
    #
    # Case 3: in-graph construction where offsets are both direct graph inputs
    # and passed in as part of an NJT's offsets.
    #
    def test_in_graph_construction_mixed(self):
        def fn(nt, values, offsets):
            nt2 = torch.nested.nested_tensor_from_jagged(values, offsets)
            return nt * nt2

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)

        def arg_fn(values=values, offsets=offsets):
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            return nt, values, offsets

        self._validate_compile(fn, arg_fn)

    # See Note: [Creating symbolic nested int]
    # AssertionError: s2 (could be from ['',
    @unittest.expectedFailure
    def test_in_graph_construction_mixed_2(self):
        def fn(nt, values, offsets, nt2):
            # Intermediate offsets has ephemeral source
            intermediate_nt = torch.nested.nested_tensor_from_jagged(
                values, offsets.clone()
            )
            # This creates a dynamic shapes neq guard
            if nt2.shape[1] != intermediate_nt.shape[1]:
                # We should always go here.
                nt = nt * 2
            return nt

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
        offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64)

        def arg_fn(values=values, offsets=offsets, offsets2=offsets2):
            # Values is shared, but it shouldn't matter
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            nt2 = torch.nested.nested_tensor_from_jagged(values, offsets2)
            return nt, values, offsets, nt2

        self._validate_compile(fn, arg_fn)

    def test_in_graph_construction_mixed_3(self):
        # More involved mixed case
        def fn(nt, values, offsets):
            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets)
            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets)
            return nt1 + nt2 + nt

        values = torch.randn(9, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64)

        def arg_fn(values=values, offsets=offsets):
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            return nt, values, offsets

        self._validate_compile(fn, arg_fn)

    def test_return_shape(self):
        nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)

        def fn(nt):
            return (nt * 2).shape

        compiled = torch.compile(fn, fullgraph=True, backend="aot_eager")
        compiled(nt)

    def test_inference_tensor(self):
        with torch.inference_mode():
            nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)

        def fn(n):
            return n * 2

        torch.compile(fn, backend="eager")(nt)

    # TODO: cannot parametrize this test class with device for some reason
    def _test_autograd(self, backend):
        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)
        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64)
        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
        # TODO: Switch to public API when it exists
        nt2, _ = jagged_from_list([a, b, c], nt.offsets())

        def fn1(nt1, nt2):
            return (nt1 + nt2).sin().cos()

        compiled_f = torch.compile(fn1, fullgraph=True, backend=backend, dynamic=True)
        out = compiled_f(nt, nt2)
        out_buffer = out.values()
        ga, gb, gc = torch.autograd.grad(out_buffer.sum(), (a, b, c))

        out_ref = fn1(nt, nt2)
        out_buffer_ref = out_ref.values()
        ga_ref, gb_ref, gc_ref = torch.autograd.grad(out_buffer_ref.sum(), (a, b, c))

        self.assertTrue(torch.allclose(ga, ga_ref))
        self.assertTrue(torch.allclose(gb, gb_ref))
        self.assertTrue(torch.allclose(gc, gc_ref))

    def test_basic_autograd(self):
        self._test_autograd("aot_eager")

    @requires_cuda
    def test_basic_autograd_inductor(self):
        self._test_autograd("inductor")

    def test_subclass_with_mutation_in_graph(self):
        # In this graph, we have an in-graph mutation, i.e. a mutation that is allowed
        # to remain in the graph. Normally this is allowed, but it's not allowed if
        # the graph handles subclasses at all.
        # Whether the mutation is allowed or not allowed in the graph alters the number
        # of outputs from the forward graph. Previously, a bug in this handling meant
        # that sometimes the expected number and actual number of outputs from the
        # joint graph did not match, causing assertion failures.
        def fn(x, y):
            z = x.sin()
            y.sin_()
            return z.cos(), y.cos()

        fn_c = torch.compile(fn, backend="inductor")

        values = [torch.rand((i, 8), requires_grad=True) for i in range(1, 6)]
        values_copy = [x.detach().clone().requires_grad_(True) for x in values]

        nt, offsets = jagged_from_list(values, None)
        nt_copy, offsets = jagged_from_list(values_copy, offsets)
        y = torch.rand((4, 8))
        y_copy = y.clone()

        ret = fn_c(nt, y)[0]
        ref = fn(nt_copy, y_copy)[0]

        self.assertEqual(ret.values(), ref.values())

        ret.values().sum().backward()
        ref.values().sum().backward()
        for ref_v, res_v in zip(values_copy, values):
            self.assertEqual(ref_v.grad, res_v.grad)

    @torch._dynamo.config.patch({"capture_scalar_outputs": True})
    def test_unbind(self):
        # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0).
        # This causes a recompile later on when it realizes the batch and last dim
        # should not always be equal. To avoid that, we use (3, j0, 5) here.
        nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
        nt2, _ = self._get_jagged_tensor(((2, 3, 5), 2), None)
        nt3, _ = self._get_jagged_tensor(((2, 3, 4, 5), 3), None)

        def fn(x):
            return x.unbind()

        compiled_f = torch.compile(fn, fullgraph=True, backend="eager", dynamic=True)
        out = compiled_f(nt)

        out_ref = fn(nt)

        # correctness
        self.assertEqual(len(out), len(out_ref))
        for x, x_ref in zip(out, out_ref):
            self.assertTrue(torch.allclose(x, x_ref))

        # We specialize on the length of offsets, i.e. (1) we recompile if the
        # length of the offsets is different, and (2) we don't recompile if the
        # length of the offsets is the same, even if the sizes of the constituent
        # tensors are different.
        self._check_recompiles(fn, (nt,), (nt2,), False)
        self._check_recompiles(fn, (nt,), (nt3,), True)

    def test_inline_nested_tensor_from_jagged(self):
        nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)

        def fn(x):
            return torch.nested.nested_tensor_from_jagged(x.values() * 2, x.offsets())

        torch.compile(fn, fullgraph=True, backend="aot_eager")(nt)

    # The test here: nn.Parameters that are secretly subclasses have a metaclass
    # that overrides __instancecheck__, which dynamo needs to respect when it
    # inlines the if statement.
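    # Hedged sketch of the metaclass behavior described in the comment above
    # (illustrative only; not invoked by the suite, and the helper name is made
    # up). The actual test follows below. Wrapping a tensor subclass in
    # nn.Parameter still reports True for isinstance(..., nn.Parameter) via the
    # metaclass __instancecheck__ hook.
    def _sketch_param_subclass_instancecheck(self):
        inner = torch.randn(4, 4)
        p = torch.nn.Parameter(TwoTensor(inner, inner))
        # True via the metaclass hook, even though the concrete type of `p` is
        # the subclass rather than Parameter itself.
        assert isinstance(p, torch.nn.Parameter)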
    def test_param_subclass_isinstance_input(self):
        x_inner = torch.randn(16, 16, requires_grad=True)
        x = torch.nn.Parameter(TwoTensor(x_inner, x_inner))

        m = torch.nn.Linear(16, 16)
        m.weight = x

        def fn():
            if isinstance(m.weight, torch.nn.Parameter):
                return m.weight + 1
            else:
                return m.weight + 2

        out_ref = fn()
        out_test = torch.compile(fn, backend="aot_eager")()
        self.assertEqual(out_ref, out_test)

    def _input_view_test(self, nt_view_name):
        nt_view = VIEW_TEST_CASES[nt_view_name]()

        def fn(x):
            return x.sin()

        out_ref = fn(nt_view)
        torch._dynamo.reset()
        compile_fn = torch.compile(
            fn, fullgraph=True, backend="aot_eager", dynamic=True
        )
        out = compile_fn(nt_view)

        # Check that metadata and values are correct
        self.assertTrue(out.size() == out_ref.size())
        self.assertTrue(out.stride() == out_ref.stride())
        if out.is_nested:
            self.assertTrue(torch.allclose(out.values(), out_ref.values()))
        else:
            self.assertTrue(torch.allclose(out, out_ref))

        # Check that no upper/lower bound guards are incurred
        def backend(gm, args):
            context = torch._guards.TracingContext.get()
            guards = [str(g.expr) for g in context.fake_mode.shape_env.guards]

            # varies based on the type of view
            guard_str = "\n".join(guards)
            if nt_view_name == "subclass_dense":
                self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""")
            elif nt_view_name == "dense_subclass_dense_subclass":
                self.assertExpectedInline(
                    guard_str,
                    """\
Eq(s5 - 1, s2)
Eq(s12 - 1, s7)
Eq(s11, s9)""",
                )
            elif nt_view_name.startswith("base_is_nt_True"):
                self.assertExpectedInline(
                    guard_str,
                    """Eq(s3 - 1, s0)""",
                )
            else:
                self.assertExpectedInline(
                    guard_str,
                    """\
Eq(s4 - 1, s1)
Eq(s13 - 1, s8)
Eq(s12, s10)""",
                )
            return gm

        torch._dynamo.reset()
        compile_fn = torch.compile(fn, fullgraph=True, backend=backend, dynamic=True)
        out = compile_fn(nt_view)

    @parametrize(
        "nt_view_name",
        [k for k in VIEW_TEST_CASES.keys() if k != "subclass_dense_subclass_dense"],
    )
    def test_inputs_to_compiled_fn_are_views(self, nt_view_name):
        self._input_view_test(nt_view_name)

    def test_subclass_gives_static_shapes_when_dynamic_false(self):
        def check_graph(gm, *args):
            first_node_example_val = next(iter(gm.graph.nodes)).meta["example_value"]
            # We compiled with dynamic=False; expect no SymInt sizes on our placeholders
            self.assertTrue(
                all(isinstance(x, int) for x in first_node_example_val.shape)
            )
            return gm

        @torch.compile(backend=check_graph, dynamic=False)
        def f(x):
            return x + 1

        x_inner = torch.ones(4)
        x = TwoTensor(x_inner, x_inner)
        x_view = x.view(2, 2)
        out = f(x_view)

    # NJT1 -> Dense -> NJT2 -> Dense view
    # During view replay, the Dense -> NJT2 part will construct an intermediate,
    # symbolically-sized NJT that is immediately deconstructed to return the final
    # dense view. To construct this intermediate properly, we need the associated
    # nested int to be symbolic. This view is expected to fail compilation until
    # symbolic nested ints are cached onto fake offsets to solve this problem.
    @unittest.expectedFailure
    def test_subclass_dense_subclass_dense_view(self):
        self._input_view_test("subclass_dense_subclass_dense")


instantiate_parametrized_tests(TestNestedTensor)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()