1# Owner(s): ["module: dynamo"] 2import functools 3import itertools 4import unittest 5from functools import partial 6 7import torch 8import torch._dynamo.test_case 9import torch._dynamo.testing 10import torch._functorch.config 11import torch.utils._pytree as pytree 12import torch.utils.checkpoint 13from torch._dynamo.testing import normalize_gm 14from torch._higher_order_ops.wrap import wrap 15from torch.fx.experimental.symbolic_shapes import ( 16 DimDynamic, 17 ShapeEnv, 18 StatelessSymbolicContext, 19) 20from torch.nested._internal.nested_tensor import ( 21 jagged_from_list, 22 jagged_from_tensor_and_lengths, 23 nested_view_from_values_offsets, 24) 25from torch.testing._internal.common_utils import ( 26 instantiate_parametrized_tests, 27 NestedTensorTestCase, 28 parametrize, 29 subtest, 30) 31from torch.testing._internal.inductor_utils import HAS_CUDA 32from torch.testing._internal.two_tensor import TwoTensor 33from torch.utils._python_dispatch import return_and_correct_aliasing 34 35 36def traceable_subclass(c): 37 return torch._dynamo.config.patch("traceable_tensor_subclasses", {c}) 38 39 40def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles): 41 actual_recompiles = _recompiles_for_inputs(fn, inputs1, inputs2) 42 self.assertEqual(actual_recompiles, expected_recompiles) 43 44 45def get_jagged_tensor(nested_size, offsets, requires_grad=True): 46 # Makes a jagged tensor with N constituent tensors with size 47 # as specified ((S0, S1, S2), D) 48 D = nested_size[1] 49 out = [] 50 for s in nested_size[0]: 51 out.append(torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64)) 52 return jagged_from_list(out, offsets) 53 54 55def get_view_test_cases(): 56 # Test all cases with both an NT base and a dense base 57 # Subclass -> Subclass 58 # Dense -> Subclass 59 60 # NB: Don't close over loop variables, they will not get copied into the 61 # closure 62 # 63 # NB: These return functions so we don't generate tensors during test 64 # collection time 65 66 def mk_basic(base_is_nt): 67 # There are three cases to consider here based on the logic in 68 # meta_utils.py 69 # 70 # (1) basic case: 71 # view is not a leaf and has the same requires grad as its basic case 72 x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True) 73 x = x.clone() if base_is_nt else x 74 assert not x.is_leaf 75 return x.unsqueeze(-1) 76 77 def mk_leaf(base_is_nt, requires_grad_1, requires_grad_2): 78 x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=requires_grad_1) 79 x = x.clone() if base_is_nt else x 80 with torch.no_grad(): 81 x_view = x.unsqueeze(-1) 82 # The issue is this doesn't quite work 83 x_view.requires_grad_(requires_grad_2) 84 85 return x_view 86 87 def mk_obscure(base_is_nt): 88 x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False) 89 x = x.clone() if base_is_nt else x 90 # intermediate leaf view 91 with torch.no_grad(): 92 x_view = x.unsqueeze(-1) 93 x_view.requires_grad_(True) 94 x_view_view = x_view.unsqueeze(-1) 95 return x_view_view 96 97 for base_is_nt in [False, True]: 98 prefix = f"base_is_nt_{base_is_nt}" 99 100 yield partial(mk_basic, base_is_nt), f"{prefix}_basic" 101 102 # (2) leaf view case: 103 # the view has to be a leaf (w/ requires_grad True or requires_grad False) 104 # base w/ requires_grad True or requires_grad False 105 for requires_grad_1, requires_grad_2 in itertools.product( 106 [True, False], repeat=2 107 ): 108 yield partial( 109 mk_leaf, base_is_nt, requires_grad_1, requires_grad_2 110 ), f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}" 111 112 # (3) obscure case: 113 # view is not a leaf (implies requires_grad True) 114 # base w/ requires_grad False) 115 yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure" 116 117 # Subclass -> Dense 118 yield lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[ 119 0 120 ].clone(), "subclass_dense" 121 122 # Dense -> Subclass -> Dense -> Subclass 123 def mk_dense_subclass_dense_subclass(): 124 values = torch.randn(10, 5) 125 offsets = torch.tensor([0, 3, 6, 10]) 126 offsets2 = offsets.clone().detach() 127 return nested_view_from_values_offsets( 128 nested_view_from_values_offsets(values, offsets).values(), offsets 129 ) 130 131 yield mk_dense_subclass_dense_subclass, "dense_subclass_dense_subclass" 132 133 def mk_subclass_dense_subclass_dense(): 134 x = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone() 135 offsets2 = x.offsets().clone().detach() 136 nt_view = nested_view_from_values_offsets(x.values(), offsets2).values() 137 138 yield mk_subclass_dense_subclass_dense, "subclass_dense_subclass_dense" 139 140 141VIEW_TEST_CASES = {k: v for v, k in get_view_test_cases()} 142 143 144requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") 145 146compile_full_eager = torch.compile(backend="eager", fullgraph=True) 147 148 149class BaseTorchFunction(torch.Tensor): 150 @classmethod 151 def __torch_function__(cls, func, types, args=(), kwargs=None): 152 if kwargs is None: 153 kwargs = {} 154 return super().__torch_function__(func, types, args, kwargs) 155 156 157class MockSubclass(torch.Tensor): 158 @classmethod 159 def __torch_function__(cls, func, types, args=(), kwargs=None): 160 if kwargs is None: 161 kwargs = {} 162 return func(*args, **kwargs) 163 164 165class AttrSubclass(torch.Tensor): 166 x: int = 10 167 size: int = 10 168 169 @classmethod 170 def __torch_function__(cls, func, types, args=(), kwargs=None): 171 if kwargs is None: 172 kwargs = {} 173 174 return func(*args, **kwargs) 175 176 177class DummyNDim(torch.Tensor): 178 @classmethod 179 def __torch_function__(cls, func, types, args=(), kwargs=None): 180 if kwargs is None: 181 kwargs = {} 182 183 if func == torch.Tensor.ndim.__get__: 184 return 10 185 186 return super().__torch_function__(func, types, args, kwargs) 187 188 189class WrapperSubclass: 190 def __init__(self, tensor): 191 self.tensor = tensor 192 193 @classmethod 194 def __torch_function__(cls, func, types, args=(), kwargs=None): 195 if kwargs is None: 196 kwargs = {} 197 198 args = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, args) 199 kwargs = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, kwargs) 200 201 return func(*args, **kwargs) 202 203 204class SigmoidToExpSubclass(torch.Tensor): 205 @classmethod 206 def __torch_function__(cls, func, types, args=(), kwargs=None): 207 if kwargs is None: 208 kwargs = {} 209 210 if func == torch.Tensor.sigmoid: 211 return super().__torch_function__(torch.Tensor.exp, types, args, kwargs) 212 213 return super().__torch_function__(func, types, args, kwargs) 214 215 216# Wrapper subclass with two inner tensors: data and scale 217# data has same shape as outer, and scale has single dim size 218class ScaledTensor(torch.Tensor): 219 def __new__( 220 cls, 221 data: torch.Tensor, 222 scale: torch.Tensor, 223 *, 224 constant: int = 0, 225 ): 226 return torch.Tensor._make_wrapper_subclass( 227 cls, 228 data.size(), 229 strides=data.stride(), 230 storage_offset=data.storage_offset(), 231 dtype=data.dtype, 232 layout=data.layout, 233 requires_grad=data.requires_grad, 234 device=data.device, 235 ) 236 237 def __init__(self, data: torch.Tensor, scale: torch.Tensor, constant: int = 0): 238 self._data = data 239 self._scale = scale 240 self._constant = constant 241 242 def __tensor_flatten__(self): 243 ctx = {"_constant": self._constant} 244 return ["_data", "_scale"], ctx 245 246 @staticmethod 247 def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): 248 assert len(inner_tensors) == 2 249 return ScaledTensor( 250 inner_tensors["_data"], 251 inner_tensors["_scale"], 252 constant=metadata["_constant"], 253 ) 254 255 @classmethod 256 def __torch_dispatch__(cls, func, types, args, kwargs=None): 257 scaled_tensor = args[0] 258 out = func(scaled_tensor._data, *args[1:], **kwargs) 259 return ScaledTensor(out, scaled_tensor._scale, constant=scaled_tensor._constant) 260 261 def __repr__(self): 262 return f"{self._data.__repr__()}\n{self._scale.__repr__()}" 263 264 265class OptionalScaledTensor(torch.Tensor): 266 def __new__( 267 cls, 268 data, 269 scale, 270 *, 271 constant: int = 0, 272 ): 273 return torch.Tensor._make_wrapper_subclass( 274 cls, 275 data.size(), 276 strides=data.stride(), 277 storage_offset=data.storage_offset(), 278 dtype=data.dtype, 279 layout=data.layout, 280 requires_grad=data.requires_grad, 281 device=data.device, 282 ) 283 284 def __init__(self, data: torch.Tensor, scale, constant: int = 0): 285 self._data = data 286 self._scale = scale 287 self._constant = constant 288 289 def __tensor_flatten__(self): 290 ctx = {"_constant": self._constant} 291 if self._scale is not None: 292 return ["_data", "_scale"], ctx 293 else: 294 return ["_data"], ctx 295 296 @staticmethod 297 def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): 298 return OptionalScaledTensor( 299 inner_tensors["_data"], 300 inner_tensors["_scale"] if "_scale" in inner_tensors else None, 301 constant=metadata["_constant"], 302 ) 303 304 @classmethod 305 def __torch_dispatch__(cls, func, types, args, kwargs=None): 306 scaled_tensor = args[0] 307 out = func(scaled_tensor._data, *args[1:], **kwargs) 308 if scaled_tensor._scale is not None: 309 out = out * scaled_tensor._scale 310 return OptionalScaledTensor( 311 out, scaled_tensor._scale, constant=scaled_tensor._constant 312 ) 313 314 def __repr__(self): 315 return ( 316 f"OptionalScaledTensor({self._data.__repr__()}\n{self._scale.__repr__()})" 317 ) 318 319 320class CtxSubclassTensor(torch.Tensor): 321 """ 322 Class used to verify guarding on the subclass metadata 323 """ 324 325 @staticmethod 326 def __new__(cls, a, constant): 327 shape = a.shape 328 kwargs = {} 329 kwargs["strides"] = a.stride() 330 kwargs["storage_offset"] = a.storage_offset() 331 kwargs["device"] = a.device 332 kwargs["layout"] = a.layout 333 kwargs["requires_grad"] = a.requires_grad 334 kwargs["dtype"] = a.dtype 335 out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) 336 return out 337 338 def __init__(self, a, constant): 339 self.a = a 340 self.constant = constant 341 342 def __repr__(self): 343 a_repr = repr(self.a) 344 return f"CtxSubclassTensor({a_repr})" 345 346 def __tensor_flatten__(self): 347 return ["a"], (self.constant,) 348 349 @staticmethod 350 def __tensor_unflatten__(inner_tensors, meta, sizes, strides): 351 constant = meta[0] 352 a = inner_tensors["a"] 353 return CtxSubclassTensor(a, constant) 354 355 @classmethod 356 def __torch_dispatch__(cls, func, types, args, kwargs): 357 from torch.utils._python_dispatch import return_and_correct_aliasing 358 359 if kwargs is None: 360 kwargs = {} 361 biggest_constant = max( 362 [ 363 x.constant 364 for x in pytree.tree_flatten(args)[0] 365 if isinstance(x, CtxSubclassTensor) 366 ] 367 ) 368 args_a = pytree.tree_map( 369 lambda x: x.a if isinstance(x, CtxSubclassTensor) else x, args 370 ) 371 kwargs_a = pytree.tree_map( 372 lambda x: x.a if isinstance(x, CtxSubclassTensor) else x, kwargs 373 ) 374 out_a = func(*args_a, **kwargs_a) 375 out = pytree.tree_map( 376 lambda x: CtxSubclassTensor(x, biggest_constant) 377 if isinstance(x, torch.Tensor) 378 else x, 379 out_a, 380 ) 381 382 if func == torch.ops.aten.mul.Tensor: 383 out = out + out.constant 384 385 return return_and_correct_aliasing(func, args, kwargs, out) 386 387 388def func(a): 389 return a.sin() 390 391 392class EagerRecordGraphAndInputs: 393 def __init__(self) -> None: 394 self.graphs = [] 395 self.example_inputs = [] 396 397 def __call__(self, gm: torch.fx.GraphModule, example_inputs): 398 self.graphs.append(gm) 399 self.example_inputs.append(example_inputs) 400 return gm 401 402 403GLOBAL_TEST_SUBCLASSES = { 404 MockSubclass, 405 DummyNDim, 406 SigmoidToExpSubclass, 407 BaseTorchFunction, 408} 409 410 411# Returns True if the function recompiles between inputs1 and inputs2 with the 412# specified dynamic setting. 413def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True): 414 compile_count = [0] 415 416 def counter(gm, example_inputs): 417 compile_count[0] += 1 418 return gm 419 420 compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic) 421 compiled_f(*inputs1) 422 compiled_f(*inputs2) 423 return compile_count[0] > 1 424 425 426class SubclassTests(torch._dynamo.test_case.TestCase): 427 @classmethod 428 def setUpClass(cls): 429 super().setUpClass() 430 cls._exit_stack.enter_context( 431 torch._dynamo.config.patch( 432 "traceable_tensor_subclasses", GLOBAL_TEST_SUBCLASSES 433 ) 434 ) 435 436 @classmethod 437 def tearDownClass(cls): 438 cls._exit_stack.close() 439 440 def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles): 441 _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles) 442 443 def test_no_call_to_new(self): 444 class BadNewTorchFunction(torch.Tensor): 445 def __new__(cls, *args, **kwargs): 446 raise RuntimeError("Oops!") 447 448 @classmethod 449 def __torch_function__(cls, func, types, args=(), kwargs=None): 450 if kwargs is None: 451 kwargs = {} 452 return super().__torch_function__(func, types, args, kwargs) 453 454 with torch._dynamo.config.patch( 455 "traceable_tensor_subclasses", {BadNewTorchFunction} 456 ): 457 458 @torch.compile(backend="eager", fullgraph=True) 459 def fn(x): 460 return torch.add(x, 1) 461 462 input = torch.ones(2, 2).as_subclass(BadNewTorchFunction) 463 464 res = fn(input) 465 self.assertIsInstance(res, BadNewTorchFunction) 466 467 def test_no_torch_function_recompiles(self): 468 class NJT: 469 def __repr__(self): 470 return f"NJT(shape={self.shape})" 471 472 def __init__(self, values, offsets): 473 self._values = values 474 self._offsets = offsets 475 476 def sin(self): 477 return torch.sin(self) 478 479 @classmethod 480 def __torch_function__(cls, func, types, args=(), kwargs=None): 481 if kwargs is None: 482 kwargs = {} 483 if func == torch.sin: 484 self = args[0] 485 return NJT(func(self._values), self._offsets) 486 raise AssertionError("should not get here") 487 488 values1 = torch.randn(10, 3, 4, requires_grad=True) 489 values2 = torch.randn(10, 3, 4, requires_grad=True) 490 offsets = torch.tensor([0, 3, 10]) 491 njt1 = NJT(values1, offsets) 492 njt2 = NJT(values2, offsets) 493 494 @torch.compile(backend="eager", fullgraph=True) 495 def f(x): 496 return torch.sin(x) 497 498 with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 499 f(njt1) 500 f(njt2) 501 502 def test_base_torch_function_tracing(self): 503 def fn(x): 504 return torch.add(x, 1) 505 506 input = torch.ones(2, 2).as_subclass(BaseTorchFunction) 507 out = fn(input) 508 out_opt = compile_full_eager(fn)(input) 509 self.assertIsInstance(out, BaseTorchFunction) 510 self.assertEqual(out, out_opt) 511 512 def test_torch_function_state_graph_break(self): 513 @torch.compile(backend="eager") 514 def fn(x): 515 with torch._C.DisableTorchFunctionSubclass(): 516 torch._dynamo.graph_break() 517 return torch._C._is_torch_function_enabled(), torch.add(x, 1.0) 518 519 input = torch.ones(2, 2) 520 res, _ = fn(input) 521 self.assertFalse(res) 522 523 def test_torch_function_state_nested(self): 524 @torch.compile(backend="eager") 525 def fn(x): 526 with torch._C.DisableTorchFunctionSubclass(): 527 with torch._C.DisableTorchFunctionSubclass(): 528 x = x + 1 529 # Should reset to the outer state (disabled) after exiting ctx manager 530 return torch._C._is_torch_function_enabled(), torch.add(x, 1.0) 531 532 input = torch.ones(2, 2) 533 res, _ = fn(input) 534 self.assertFalse(res) 535 536 def test_torch_function_state_tracing(self): 537 @torch.compile(backend="eager", fullgraph=True) 538 def fn(x): 539 with torch._C.DisableTorchFunctionSubclass(): 540 torch.add(x, 1.0) 541 542 input = torch.ones(2, 2) 543 544 res = fn(input) 545 546 def test_torch_function_state_guards(self): 547 cnt = torch._dynamo.testing.CompileCounter() 548 549 @torch.compile(backend=cnt, fullgraph=True) 550 def fn(x): 551 torch.add(x, 1.0) 552 553 input = torch.ones(2, 2) 554 555 with torch._C.DisableTorchFunctionSubclass(): 556 res = fn(input) 557 558 res = fn(input) 559 560 self.assertEqual(cnt.frame_count, 2) 561 562 def test_return_subclass(self): 563 @torch.compile(backend="eager", fullgraph=True) 564 def fn(x): 565 return MockSubclass(torch.add(x, 1.0)) 566 567 input = torch.ones(2, 2) 568 569 res = fn(input) 570 self.assertIsInstance(res, MockSubclass) 571 572 def test_return_as_subclass(self): 573 @torch.compile(backend="eager", fullgraph=True) 574 def fn(x): 575 return torch.add(x, 1.0).as_subclass(MockSubclass) 576 577 input = torch.ones(2, 2) 578 579 res = fn(input) 580 self.assertIsInstance(res, MockSubclass) 581 582 def test_return_local_subclass(self): 583 class LocalSubclass(torch.Tensor): 584 @classmethod 585 def __torch_function__(cls, func, types, args=(), kwargs=None): 586 if kwargs is None: 587 kwargs = {} 588 return func(*args, **kwargs) 589 590 with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): 591 592 @torch.compile(backend="eager", fullgraph=True) 593 def fn(x): 594 return LocalSubclass(torch.add(x, 1.0)) 595 596 input = torch.ones(2, 2) 597 598 res = fn(input) 599 self.assertIsInstance(res, LocalSubclass) 600 601 def test_torch_function_list_args(self): 602 HANDLED_FUNCTIONS = {} 603 604 class MyClass: 605 def __init__(self, foo): 606 self.foo = foo 607 608 @classmethod 609 def __torch_function__( 610 cls, 611 func, 612 types, 613 args=(), 614 kwargs=None, 615 ): 616 if kwargs is None: 617 kwargs = {} 618 if func not in HANDLED_FUNCTIONS or not all( # noqa: C419 619 [ # noqa: C419 620 issubclass(t, (torch.Tensor, MyClass)) for t in types 621 ] 622 ): 623 return NotImplemented 624 return HANDLED_FUNCTIONS[func](*args, **kwargs) 625 626 def _stack(input, dim=0, *, out=None): 627 return MyClass(sum([x.foo for x in input])) 628 629 HANDLED_FUNCTIONS[torch.stack] = _stack 630 631 @torch.compile(backend="eager", fullgraph=True) 632 def fn(v0, v1): 633 return torch.stack([v0, v1]) 634 635 ret = fn(MyClass(1), MyClass(1)) 636 self.assertEqual(ret.foo, 2) 637 638 @parametrize( 639 "comparison", 640 [ 641 subtest(isinstance, "isinstance"), 642 subtest(lambda instance, type_: type(instance) == type_, "equality"), 643 subtest(lambda instance, type_: type(instance) is type_, "identity"), 644 ], 645 ) 646 @parametrize( 647 "input_type", 648 [ 649 subtest(torch.Tensor, "tensor"), 650 subtest(DummyNDim, "subclass"), 651 ], 652 ) 653 def test_type_check(self, comparison, input_type): 654 with torch._dynamo.config.patch("traceable_tensor_subclasses", {DummyNDim}): 655 656 def fn(x): 657 if comparison(x, DummyNDim): 658 return torch.ones(1, 1) 659 else: 660 return torch.zeros(2, 2) 661 662 input = torch.ones(2, 2).as_subclass(input_type) 663 exp_res = fn(input) 664 act_res = torch.compile(backend="eager", fullgraph=True)(fn)(input) 665 self.assertEqual(exp_res, act_res) 666 667 def test_torch_function_call_on_method(self): 668 x = torch.ones(2, 2) 669 y = torch.ones(2, 2) 670 z = torch.ones(2, 2) 671 wrapped = x.as_subclass(SigmoidToExpSubclass) 672 wrapped2 = y.as_subclass(SigmoidToExpSubclass) 673 674 def fn(w): 675 return w.sigmoid() 676 677 fn_opt = compile_full_eager(fn) 678 679 res_exp = fn(wrapped) 680 res_act = fn_opt(wrapped2) 681 res_exp2 = z.exp() 682 683 self.assertEqual(res_exp, res_act) 684 self.assertEqual(res_exp, res_exp2) 685 686 def test_user_overidden_method_unsupported(self): 687 class LocalSubclass(torch.Tensor): 688 @classmethod 689 def __torch_function__(cls, func, types, args=(), kwargs=None): 690 if kwargs is None: 691 kwargs = {} 692 return super().__torch_function__(func, types, args, kwargs) 693 694 def sigmoid(self): 695 return None 696 697 @torch.compile(backend="eager", fullgraph=True) 698 def fn(x): 699 x.sigmoid() 700 701 msg = ( 702 "Accessing overridden method/attribute sigmoid on a tensor" 703 " subclass with a __torch_function__ override is not supported" 704 ) 705 with torch._dynamo.config.patch( 706 "traceable_tensor_subclasses", {LocalSubclass} 707 ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): 708 x = torch.ones(2, 2).as_subclass(LocalSubclass) 709 fn(x) 710 711 def test_user_overidden_attr_unsupported(self): 712 class LocalSubclass(torch.Tensor): 713 @classmethod 714 def __torch_function__(cls, func, types, args=(), kwargs=None): 715 if kwargs is None: 716 kwargs = {} 717 return super().__torch_function__(func, types, args, kwargs) 718 719 ndim = 10 720 721 @torch.compile(backend="eager", fullgraph=True) 722 def fn(x): 723 return x.ndim 724 725 msg = ( 726 "Accessing overridden method/attribute ndim on a tensor" 727 " subclass with a __torch_function__ override is not supported" 728 ) 729 with torch._dynamo.config.patch( 730 "traceable_tensor_subclasses", {LocalSubclass} 731 ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): 732 x = torch.ones(2, 2).as_subclass(LocalSubclass) 733 fn(x) 734 735 def test_user_overidden_property_unsupported(self): 736 class LocalSubclass(torch.Tensor): 737 def __init__(self) -> None: 738 self._ndim = 10 739 740 @classmethod 741 def __torch_function__(cls, func, types, args=(), kwargs=None): 742 if kwargs is None: 743 kwargs = {} 744 return super().__torch_function__(func, types, args, kwargs) 745 746 @property 747 def ndim(self): 748 return self._ndim 749 750 @ndim.setter 751 def ndim(self, value): 752 self._ndim = value 753 754 @torch.compile(backend="eager", fullgraph=True) 755 def fn(x): 756 return x.ndim 757 758 msg = ( 759 "Accessing overridden method/attribute ndim on a tensor" 760 " subclass with a __torch_function__ override is not supported" 761 ) 762 with torch._dynamo.config.patch( 763 "traceable_tensor_subclasses", {LocalSubclass} 764 ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): 765 x = torch.ones(2, 2).as_subclass(LocalSubclass) 766 fn(x) 767 768 def test_overridden_method_guarding(self): 769 class LocalSubclass(torch.Tensor): 770 @classmethod 771 def __torch_function__(cls, func, types, args=(), kwargs=None): 772 if kwargs is None: 773 kwargs = {} 774 return super().__torch_function__(func, types, args, kwargs) 775 776 @torch.compile(backend="eager") 777 def fn(x): 778 return x.sigmoid() 779 780 with torch._dynamo.config.patch( 781 error_on_recompile=True, traceable_tensor_subclasses={LocalSubclass} 782 ): 783 x = torch.ones(2, 2).as_subclass(LocalSubclass) 784 fn(x) 785 fn(x) 786 x = torch.ones(2, 2).as_subclass(LocalSubclass) 787 fn(x) 788 789 with torch._dynamo.config.patch( 790 traceable_tensor_subclasses={LocalSubclass} 791 ), self.assertRaisesRegex( 792 TypeError, 793 "'bool' object is not callable", 794 ): 795 LocalSubclass.sigmoid = False 796 fn(x) 797 798 def test_torch_function_call_on_attr(self): 799 x = torch.ones(2, 2) 800 wrapped = x.as_subclass(DummyNDim) 801 802 def fn(w): 803 return w.ndim + torch.ones(2) 804 805 fn_opt = compile_full_eager(fn) 806 807 res_exp = fn(wrapped) 808 res_act = fn_opt(wrapped) 809 810 self.assertEqual(res_exp, res_act) 811 self.assertEqual(res_exp, torch.ones(2) + 10) 812 813 def test_torch_function_wrapper_class(self): 814 x = torch.ones(2, 2) 815 wrapped = WrapperSubclass(x) 816 817 def fn(w): 818 return torch.add(w, 1.0) 819 820 fn_opt = compile_full_eager(fn) 821 822 res_exp = fn(wrapped) 823 res_act = fn_opt(wrapped) 824 self.assertEqual(res_exp, res_act) 825 826 def test_torch_function_wrapper_class_with_kwargs(self): 827 x = torch.ones(2, 2) 828 wrapped = WrapperSubclass(x) 829 830 def fn(w): 831 return torch.add(w, 1.0, alpha=2.0) 832 833 fn_opt = compile_full_eager(fn) 834 835 res_exp = fn(wrapped) 836 res_act = fn_opt(wrapped) 837 self.assertEqual(res_exp, res_act) 838 839 def test_tensor_subclass_custom_attr(self): 840 class AttrSubclass(torch.Tensor): 841 x: int = 10 842 843 @classmethod 844 def __torch_function__(cls, func, types, args=(), kwargs=None): 845 if kwargs is None: 846 kwargs = {} 847 848 return super().__torch_function__(func, types, args, kwargs) 849 850 @torch.compile(backend="eager", fullgraph=True) 851 def fn(x): 852 return x.x + torch.ones(2, 2) 853 854 with traceable_subclass(AttrSubclass): 855 input = torch.ones(2, 2).as_subclass(AttrSubclass) 856 fn_opt = compile_full_eager(fn) 857 858 res_exp = fn(input) 859 res_act = fn_opt(input) 860 self.assertEqual(res_exp, res_act) 861 862 def test_compile_with_fake_tensor_dynamic_dim(self): 863 x = torch.randn([3, 4]) 864 865 def f(x): 866 return torch.sin(x) 867 868 def test_dynamic_dim(f, x, dim_dynamic, exp_frame_count, exp_op_count): 869 torch._dynamo.reset() 870 cnt = torch._dynamo.testing.CompileCounter() 871 872 opt_f = torch.compile(f, backend=cnt, fullgraph=True) 873 874 x1 = torch.rand_like(x) 875 f(x) 876 f(torch.randn([4, 3])) 877 shape_env = ShapeEnv() 878 with torch._subclasses.fake_tensor.FakeTensorMode( 879 shape_env=shape_env 880 ) as fake_mode: 881 x_fake = fake_mode.from_tensor( 882 x, 883 symbolic_context=StatelessSymbolicContext( 884 dynamic_sizes=[dim_dynamic for i in range(x.dim())] 885 ), 886 ) 887 x1_fake = fake_mode.from_tensor( 888 x1, 889 symbolic_context=StatelessSymbolicContext( 890 dynamic_sizes=[dim_dynamic for i in range(x.dim())] 891 ), 892 ) 893 opt_f(x_fake) 894 opt_f(x1_fake) 895 896 self.assertEqual(cnt.frame_count, exp_frame_count) 897 self.assertEqual(cnt.op_count, exp_op_count) 898 899 test_dynamic_dim(f, x, DimDynamic.DYNAMIC, 1, 1) 900 test_dynamic_dim(f, x, DimDynamic.DUCK, 1, 1) 901 test_dynamic_dim(f, x, DimDynamic.STATIC, 1, 1) 902 903 def test_compile_with_fake_tensor_automatic_dynamic(self): 904 def f(x): 905 return torch.sin(x) 906 907 def test_automatic_dynamic(f, inps, dim_dynamic, exp_frame_count, exp_op_count): 908 torch._dynamo.reset() 909 cnt = torch._dynamo.testing.CompileCounter() 910 opt_f = torch.compile(f, backend=cnt, fullgraph=True) 911 912 shape_env = ShapeEnv() 913 with torch._subclasses.fake_tensor.FakeTensorMode( 914 shape_env=shape_env 915 ) as fake_mode: 916 for inp in inps: 917 fake_inp = fake_mode.from_tensor( 918 inp, 919 symbolic_context=StatelessSymbolicContext( 920 [dim_dynamic for i in range(x.dim())] 921 ), 922 ) 923 opt_f(fake_inp) 924 self.assertEqual(cnt.frame_count, exp_frame_count) 925 self.assertEqual(cnt.op_count, exp_op_count) 926 927 x = torch.randn([3, 4]) 928 y = torch.randn([4, 5]) 929 z = torch.randn([5, 6]) 930 a = torch.randn([3, 5]) 931 b = torch.randn([4, 4]) 932 # When inputs' DimDynamic is DYNAMIC or DUCK, the inputs 933 # to opt_f will be tensors with SymInt sizes. Dynamo will treat input 934 # as dynamic automatically and will only compile once 935 for dim_dynamic in [DimDynamic.DYNAMIC, DimDynamic.DUCK]: 936 test_automatic_dynamic(f, [x, y, z], dim_dynamic, 1, 1) 937 test_automatic_dynamic(f, [x, a, z], dim_dynamic, 1, 1) 938 test_automatic_dynamic(f, [x, b, z], dim_dynamic, 1, 1) 939 940 for dim_dynamic in [DimDynamic.STATIC]: 941 # Recompile once, first with dim 0 and 1 become Dynamic 942 test_automatic_dynamic(f, [x, y, z], dim_dynamic, 2, 2) 943 # Recompile 2 times, first with dim 1 become Dynamic, second with dim 0 becomes Dynamic. 944 test_automatic_dynamic(f, [x, a, z], dim_dynamic, 3, 3) 945 # Recompile 2 times, first with dim 0 become Dynamic, second with dim 1 becomes Dynamic. 946 test_automatic_dynamic(f, [x, b, z], dim_dynamic, 3, 3) 947 948 def test_compile_with_functionalization(self): 949 x = torch.randn([3, 4]) 950 x_clone = x.clone() 951 x_clone2 = x.clone() 952 backend = EagerRecordGraphAndInputs() 953 cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) 954 955 @torch.compile(backend=cnt, fullgraph=True) 956 def f(x): 957 return x.add_(1.0) + torch.nn.functional.relu_(x) 958 959 f_out = f(x) 960 self.assertEqual(cnt.frame_count, 1) 961 self.assertEqual(cnt.op_count, 3) 962 self.assertEqual(len(backend.graphs), 1) 963 self.assertEqual(len(backend.example_inputs), 1) 964 965 actual = normalize_gm(backend.graphs[0].print_readable(print_output=False)) 966 self.assertExpectedInline( 967 actual, 968 """\ 969class GraphModule(torch.nn.Module): 970 def forward(self, L_x_: "f32[3, 4]"): 971 l_x_ = L_x_ 972 973 add_: "f32[3, 4]" = l_x_.add_(1.0) 974 relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None 975 add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None 976 return (add,) 977""", 978 ) 979 980 ff = torch.func.functionalize(f) 981 ff_out = ff(x_clone) 982 983 self.assertEqual(cnt.frame_count, 2) 984 self.assertEqual(cnt.op_count, 6) 985 self.assertEqual(len(backend.graphs), 2) 986 self.assertEqual(len(backend.example_inputs), 2) 987 actual = normalize_gm(backend.graphs[1].print_readable(print_output=False)) 988 self.assertExpectedInline( 989 actual, 990 """\ 991class GraphModule(torch.nn.Module): 992 def forward(self, L_x_: "f32[3, 4]"): 993 l_x_ = L_x_ 994 995 add_: "f32[3, 4]" = l_x_.add_(1.0) 996 relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None 997 add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None 998 return (add,) 999""", 1000 ) 1001 self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0])) 1002 1003 # Cannot re-use the version from AOTAutograd, since that uses python functional tensors. 1004 def to_fun(x): 1005 x_functional = torch._to_functional_tensor(x) 1006 torch._mirror_autograd_meta_to(x, x_functional) 1007 return x_functional 1008 1009 def aot_f_wrapper(func): 1010 @functools.wraps(func) 1011 def wrapper(*args, **kwargs): 1012 torch._enable_functionalization(reapply_views=False) 1013 try: 1014 func_args = pytree.tree_map(to_fun, args) 1015 func_kwargs = pytree.tree_map(to_fun, kwargs) 1016 return func(*func_args, **func_kwargs) 1017 finally: 1018 torch._disable_functionalization() 1019 1020 return wrapper 1021 1022 aot_ff = aot_f_wrapper(f) 1023 aot_ff_out = aot_ff(x_clone2) 1024 1025 self.assertEqual(cnt.frame_count, 3) 1026 self.assertEqual(cnt.op_count, 9) 1027 self.assertEqual(len(backend.graphs), 3) 1028 self.assertEqual(len(backend.example_inputs), 3) 1029 actual = normalize_gm(backend.graphs[2].print_readable(print_output=False)) 1030 self.assertExpectedInline( 1031 actual, 1032 """\ 1033class GraphModule(torch.nn.Module): 1034 def forward(self, L_x_: "f32[3, 4]"): 1035 l_x_ = L_x_ 1036 1037 add_: "f32[3, 4]" = l_x_.add_(1.0) 1038 relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None 1039 add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None 1040 return (add,) 1041""", 1042 ) 1043 self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0])) 1044 1045 self.assertEqual(f_out, ff_out) 1046 self.assertEqual(f_out, aot_ff_out) 1047 1048 try: 1049 torch._enable_functionalization(reapply_views=False) 1050 xf = pytree.tree_map(to_fun, x) 1051 x_view = xf.t() 1052 with self.assertRaisesRegex(RuntimeError, "Cannot safely fakify a view"): 1053 f(x_view) 1054 finally: 1055 torch._disable_functionalization() 1056 1057 def test_compile_higher_order_with_functionalization(self): 1058 backend = EagerRecordGraphAndInputs() 1059 cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) 1060 1061 @torch.compile(backend=cnt, fullgraph=True) 1062 def f(x): 1063 return wrap(lambda x: x.add_(1.0), x) 1064 1065 def check_count_and_graph( 1066 exp_frame_count, exp_op_count, exp_n_graph, exp_graph 1067 ): 1068 self.assertEqual(cnt.frame_count, exp_frame_count) 1069 self.assertEqual(cnt.op_count, exp_op_count) 1070 self.assertEqual(len(backend.graphs), exp_n_graph) 1071 actual = normalize_gm( 1072 backend.graphs[exp_n_graph - 1].print_readable(print_output=False) 1073 ) 1074 self.assertExpectedInline(actual, exp_graph, skip=1) 1075 1076 t = torch.randn([3, 4]) 1077 t_clone = t.clone() 1078 t_clone2 = t.clone() 1079 f(t) 1080 1081 check_count_and_graph( 1082 1, 1083 2, 1084 1, 1085 """\ 1086class GraphModule(torch.nn.Module): 1087 def forward(self, L_x_: "f32[3, 4]"): 1088 l_x_ = L_x_ 1089 1090 wrap_body_0 = self.wrap_body_0 1091 wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None 1092 getitem: "f32[3, 4]" = wrap[0]; wrap = None 1093 return (getitem,) 1094 1095 class wrap_body_0(torch.nn.Module): 1096 def forward(self, l_x_: "f32[3, 4]"): 1097 add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None 1098 return (add_,) 1099""", 1100 ) 1101 1102 ff = torch.func.functionalize(f) 1103 ff_out = ff(t_clone) 1104 # frame count and op count are incremented due to re-compilation 1105 check_count_and_graph( 1106 2, 1107 4, 1108 2, 1109 """\ 1110class GraphModule(torch.nn.Module): 1111 def forward(self, L_x_: "f32[3, 4]"): 1112 l_x_ = L_x_ 1113 1114 wrap_body_0 = self.wrap_body_0 1115 wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None 1116 getitem: "f32[3, 4]" = wrap[0]; wrap = None 1117 return (getitem,) 1118 1119 class wrap_body_0(torch.nn.Module): 1120 def forward(self, l_x_: "f32[3, 4]"): 1121 add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None 1122 return (add_,) 1123""", 1124 ) 1125 1126 try: 1127 x = torch._to_functional_tensor(t_clone2) 1128 torch._mirror_autograd_meta_to(t_clone2, x) 1129 torch._enable_functionalization(reapply_views=False) 1130 aot_f_out = f(x) 1131 finally: 1132 torch._disable_functionalization() 1133 1134 # frame count and op count are incremented due to re-compilation 1135 check_count_and_graph( 1136 3, 1137 6, 1138 3, 1139 """\ 1140class GraphModule(torch.nn.Module): 1141 def forward(self, L_x_: "f32[3, 4]"): 1142 l_x_ = L_x_ 1143 1144 wrap_body_0 = self.wrap_body_0 1145 wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None 1146 getitem: "f32[3, 4]" = wrap[0]; wrap = None 1147 return (getitem,) 1148 1149 class wrap_body_0(torch.nn.Module): 1150 def forward(self, l_x_: "f32[3, 4]"): 1151 add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None 1152 return (add_,) 1153""", 1154 ) 1155 1156 def test_has_torch_function(self): 1157 class MyTensor: 1158 @classmethod 1159 def __torch_function__(cls, func, types, args=(), kwargs=None): 1160 if kwargs is None: 1161 kwargs = {} 1162 1163 if func is torch.max: 1164 return torch.tensor(123) 1165 return func(*args, **kwargs) 1166 1167 class LocalSubclass(torch.Tensor): 1168 @classmethod 1169 def __torch_function__(cls, func, types, args=(), kwargs=None): 1170 if kwargs is None: 1171 kwargs = {} 1172 return func(*args, **kwargs) 1173 1174 def fn(x): 1175 return torch.overrides.has_torch_function_unary( 1176 x 1177 ), torch.overrides.has_torch_function_variadic(x) 1178 1179 for test_class in [MyTensor, LocalSubclass]: 1180 x = test_class() 1181 ref0 = fn(x) 1182 ref1 = fn(4) 1183 opt_fn = torch._dynamo.optimize("eager")(fn) 1184 res0 = opt_fn(x) 1185 res1 = opt_fn(4) 1186 self.assertEqual(ref0, res0) 1187 self.assertEqual(ref1, res1) 1188 1189 def test_wrapper_subclass_guards_on_inner_tensor(self): 1190 # Holds an inner tensor, that has a distinct shape from the outer wrapper tensor. 1191 # Also adds additional guards on the inner tensor's sizes. 1192 # When the first input to an op has x.shape[0] > 5, we insert an extra add node. 1193 class DoubleSizeMaybeAddGeThreeTensor(torch.Tensor): 1194 @staticmethod 1195 def __new__(cls, inner): 1196 # Double the outer-most dimension 1197 outer_shape = (inner.shape[0] * 2,) + inner.shape[1:] 1198 return torch.Tensor._make_wrapper_subclass( 1199 # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great. 1200 # Calling the overload that has kwargs causes us to go down the first overload path, 1201 # which will **always** specialize sizes. 1202 # We should probably eventually fix this so that the first overload can just handle dynamic shapes. 1203 cls, 1204 outer_shape, 1205 inner.stride(), 1206 None, 1207 None, 1208 inner.dtype, 1209 inner.layout, 1210 inner.device, 1211 False, 1212 inner.requires_grad, 1213 ) 1214 1215 def __init__(self, inner): 1216 self.inner_elem = inner 1217 1218 def __tensor_flatten__(self): 1219 return ["inner_elem"], None 1220 1221 @staticmethod 1222 def __tensor_unflatten__(inner_tensors, _, outer_size, outer_stride): 1223 return DoubleSizeMaybeAddGeThreeTensor(inner_tensors["inner_elem"]) 1224 1225 def __repr__(self): 1226 return f"DoubleSizeMayberAddGeThreeTensor({repr(self.inner_elem)})" 1227 1228 @classmethod 1229 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1230 if kwargs is None: 1231 kwargs = {} 1232 1233 args_inner = torch.utils._pytree.tree_map_only( 1234 DoubleSizeMaybeAddGeThreeTensor, lambda x: x.inner_elem, args 1235 ) 1236 out_inner = func(*args_inner, **kwargs) 1237 1238 # Add guards on the inner tensor's sizes 1239 if args_inner[0].shape[0] > 3: 1240 out_inner += 2 1241 1242 return DoubleSizeMaybeAddGeThreeTensor(out_inner) 1243 1244 curr_var_to_val = None 1245 curr_var_to_sources = None 1246 guards = None 1247 1248 def backend(gm, args): 1249 context = torch._guards.TracingContext.get() 1250 1251 # Grab info on sources and guards from the shapeenv 1252 nonlocal curr_var_to_val 1253 nonlocal curr_var_to_sources 1254 nonlocal guards 1255 1256 guards = [str(g.expr) for g in context.fake_mode.shape_env.guards] 1257 curr_var_to_val = { 1258 str(k): v for k, v in context.fake_mode.shape_env.var_to_val.items() 1259 } 1260 curr_var_to_sources = { 1261 str(k): v[0].name() 1262 for k, v in context.fake_mode.shape_env.var_to_sources.items() 1263 } 1264 return gm 1265 1266 @torch.compile(backend=backend) 1267 def fn(x): 1268 if x.shape[0] < 13: 1269 return torch.mul(x, x) 1270 else: 1271 return torch.div(x, x) 1272 1273 inp = torch.ones(4, 4) 1274 1275 x = DoubleSizeMaybeAddGeThreeTensor(inp) 1276 torch._dynamo.mark_dynamic(x, 0) 1277 res = fn(x) 1278 # During fakeifying, we end up allocating a separate symint 1279 # for the outer and inner tensor (in this test, s0 is unused). 1280 expected_var_to_val = { 1281 "s0": 8, 1282 "s1": 4, 1283 } 1284 expected_var_to_sources = { 1285 "s0": "L['x'].size()[0]", 1286 "s1": "L['x'].inner_elem.size()[0]", 1287 } 1288 self.assertEqual(curr_var_to_val, expected_var_to_val) 1289 self.assertEqual(curr_var_to_sources, expected_var_to_sources) 1290 self.assertExpectedInline( 1291 "\n".join(guards), 1292 """\ 1293Eq(2*s1, s0) 12942*s1 < 13 1295s1 > 3""", 1296 ) 1297 1298 def test_wrapper_subclass_with_same_sized_inner_tensor(self): 1299 # shouldn't recompile for different sizes when dynamic=True 1300 sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6)) 1301 sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(7)) 1302 self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=True)) 1303 1304 # should recompile for different data size when dynamic=False 1305 sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6)) 1306 sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6)) 1307 self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) 1308 1309 # avoid recompile using manual mark_dynamic() for different data size 1310 sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6)) 1311 # NB: mark_dynamic() on outer tensor should translate to inner tensors of the same size 1312 torch._dynamo.mark_dynamic(sub1, 0) 1313 torch._dynamo.mark_dynamic(sub1, 1) 1314 sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6)) 1315 self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) 1316 1317 def test_wrapper_subclass_with_differently_sized_inner_tensor(self): 1318 # should recompile for different scale size when dynamic=False 1319 sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3)) 1320 sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5)) 1321 self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) 1322 1323 # still recompiles using manual mark_dynamic() on outer for different scale size 1324 sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3)) 1325 # NB: mark_dynamic() on outer tensor doesn't translate to inner tensors of different size 1326 torch._dynamo.mark_dynamic(sub1, 0) 1327 torch._dynamo.mark_dynamic(sub1, 1) 1328 sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5)) 1329 self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) 1330 1331 def test_recompiles_with_optional_inner_tensor(self): 1332 def f(x): 1333 return x + 1 1334 1335 # sub1 does not have the optional tensor specified while sub2 does 1336 sub1 = OptionalScaledTensor(torch.randn(2, 4), None) 1337 sub2 = OptionalScaledTensor(torch.randn(2, 4), torch.randn(2, 4)) 1338 1339 # sanity check; don't recompile for same input 1340 self.assertFalse(_recompiles_for_inputs(f, (sub1,), (sub1,), dynamic=True)) 1341 self.assertFalse(_recompiles_for_inputs(f, (sub2,), (sub2,), dynamic=True)) 1342 1343 # these should recompile; optional tensor changes between specified and unspecified 1344 self.assertTrue(_recompiles_for_inputs(f, (sub1,), (sub2,), dynamic=True)) 1345 self.assertTrue(_recompiles_for_inputs(f, (sub2,), (sub1,), dynamic=True)) 1346 1347 f_compiled = torch.compile(f, backend="aot_eager") 1348 self.assertEqual(f(sub1)._data, f_compiled(sub1)._data) 1349 self.assertEqual(f(sub2)._data, f_compiled(sub2)._data) 1350 1351 def test_torch_dispatch_subclass_guard_recompile(self): 1352 x = torch.ones(2, 2) 1353 x_two = TwoTensor(x.clone(), x.clone()) 1354 1355 def fn(w): 1356 return torch.add(w, 1.0) 1357 1358 fn_opt = torch.compile(backend="eager")(fn) 1359 1360 ref = fn(x_two) 1361 res = fn_opt(x_two) 1362 self.assertEqual(ref, res) 1363 1364 # ensure no recompilation on same input type 1365 with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 1366 fn_opt(TwoTensor(x + 1, x + 2)) 1367 1368 # recompile! 1369 ref = fn(x) 1370 res = fn_opt(x) 1371 self.assertEqual(ref, res) 1372 1373 def test_tensor_subclass_ctx_guards(self): 1374 x = CtxSubclassTensor(torch.ones(2), 3) 1375 x2 = CtxSubclassTensor(torch.ones(2), 3) 1376 x3 = CtxSubclassTensor(torch.ones(2), 4) 1377 _check_recompiles(self, lambda x: x * x, (x,), (x2,), False) 1378 _check_recompiles(self, lambda x: x * x, (x,), (x3,), True) 1379 1380 def test_tensor_subclass_ctx_recursive_guards(self): 1381 x0 = torch.ones(2, 2) 1382 x1 = CtxSubclassTensor(x0.clone(), 2) 1383 x2 = CtxSubclassTensor(x0.clone(), 3) 1384 tt0 = TwoTensor(x0.clone(), x1) 1385 tt1 = TwoTensor(x0.clone(), x2) 1386 1387 _check_recompiles(self, lambda x: x * x, (tt0,), (tt1,), True) 1388 1389 def test_tensor_subclass_ctx_custom_guards_override(self): 1390 class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor): 1391 @classmethod 1392 def __metadata_guard__(cls, orig_data, other): 1393 return orig_data[0] <= other[0] 1394 1395 x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 2) 1396 x2 = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3) 1397 x3 = CtxSubclassTensorCustomGuardFn(torch.ones(2), 1) 1398 _check_recompiles(self, lambda x: x * x, (x,), (x2,), False) 1399 _check_recompiles(self, lambda x: x * x, (x,), (x3,), True) 1400 1401 def test_tensor_subclass_ctx_custom_guards_error_arg_num(self): 1402 import torch._dynamo.exc 1403 1404 class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor): 1405 @classmethod 1406 def __metadata_guard__(cls, y): 1407 # Shouldn't reach here 1408 return False 1409 1410 x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3) 1411 self.assertRaisesRegex( 1412 torch._dynamo.exc.InternalTorchDynamoError, 1413 "Tensor subclass method __metadata_guard__ must take exactly two subclass metadata arguments", 1414 lambda: torch.compile(lambda x: x * x)(x), 1415 ) 1416 1417 def test_tensor_subclass_ctx_custom_guards_error_not_classmethod(self): 1418 import torch._dynamo.exc 1419 1420 class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor): 1421 def __metadata_guard__(self, x, y): 1422 return False 1423 1424 x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3) 1425 self.assertRaisesRegex( 1426 torch._dynamo.exc.InternalTorchDynamoError, 1427 "Tensor subclass method __metadata_guard__ must be a classmethod", 1428 lambda: torch.compile(lambda x: x * x)(x), 1429 ) 1430 1431 def test_subclass_constructor_proxying(self): 1432 import dataclasses 1433 from collections import namedtuple 1434 from typing import Any 1435 1436 @dataclasses.dataclass(frozen=True) 1437 class SubclassTensorArgs: 1438 original_shape: torch.Size 1439 device: torch.device 1440 inner_meta: Any 1441 1442 SubclassTensorArgs2 = namedtuple( 1443 "SubclassTensorArgs2", 1444 [ 1445 "original_shape", 1446 "device", 1447 "inner_meta", 1448 ], 1449 ) 1450 1451 class SubclassTensor(torch.Tensor): 1452 @staticmethod 1453 def __new__(cls, a, meta): 1454 shape = a.shape 1455 kwargs = {} 1456 kwargs["strides"] = a.stride() 1457 kwargs["storage_offset"] = a.storage_offset() 1458 kwargs["device"] = a.device 1459 kwargs["layout"] = a.layout 1460 kwargs["requires_grad"] = a.requires_grad 1461 kwargs["dtype"] = a.dtype 1462 out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) 1463 return out 1464 1465 def __init__(self, a, meta): 1466 self.a = a 1467 self.meta = meta 1468 1469 def __repr__(self): 1470 a_repr = repr(self.a) 1471 return f"SubclassTensor({a_repr})" 1472 1473 def __tensor_flatten__(self): 1474 return ["a"], self.meta 1475 1476 @staticmethod 1477 def __tensor_unflatten__(inner_tensors, meta, _, __): 1478 a = inner_tensors["a"] 1479 return SubclassTensor(a, meta) 1480 1481 @classmethod 1482 def __torch_dispatch__(cls, func, types, args, kwargs): 1483 if kwargs is None: 1484 kwargs = {} 1485 args_a = pytree.tree_map( 1486 lambda x: x.a if isinstance(x, SubclassTensor) else x, args 1487 ) 1488 kwargs_a = pytree.tree_map( 1489 lambda x: x.a if isinstance(x, SubclassTensor) else x, kwargs 1490 ) 1491 out_a = func(*args_a, **kwargs_a) 1492 out = pytree.tree_map( 1493 lambda x: SubclassTensor( 1494 x, SubclassTensorArgs2(x.shape, x.device, None) 1495 ) 1496 if isinstance(x, torch.Tensor) 1497 else x, 1498 out_a, 1499 ) 1500 return return_and_correct_aliasing(func, args, kwargs, out) 1501 1502 @torch.compile(fullgraph=True) 1503 def f1(x): 1504 meta = SubclassTensorArgs( 1505 x.shape, x.device, SubclassTensorArgs(x.shape, x.device, None) 1506 ) 1507 out = SubclassTensor(x, meta) 1508 return out * out 1509 1510 x = torch.randn(3, 3) 1511 f1(x) 1512 1513 @torch.compile(fullgraph=True) 1514 def f1(x): 1515 meta = SubclassTensorArgs2( 1516 x.shape, x.device, SubclassTensorArgs2(x.shape, x.device, None) 1517 ) 1518 out = SubclassTensor(x, meta) 1519 return out * out 1520 1521 x = torch.randn(3, 3) 1522 f1(x) 1523 1524 def test_torch_function_subclass_survives_into_aot_autograd(self): 1525 # If you have a tensor subclass that relies on dispatch into the same op 1526 # without unwrapping and calling torch._C.DisableTorchFunctionSubclass(), 1527 # the torch function-ness will survive into AOTAutograd. Today, NestedTensor 1528 # actually relies on this behavior! Because that torch function logic 1529 # runs during AOTAutograd, this test tests that there is no logic below 1530 # that relies torch function that gets unexpectedly disabled after we 1531 # redispatch from the subclass's torch function. 1532 class SubTensor(torch.Tensor): 1533 @staticmethod 1534 def __new__(cls, t): 1535 return torch.Tensor._make_wrapper_subclass( 1536 cls, 1537 t.shape, 1538 t.stride(), 1539 t.storage_offset(), 1540 torch.contiguous_format, 1541 t.dtype, 1542 torch.strided, 1543 t.device, 1544 False, 1545 t.requires_grad, 1546 "sizes", 1547 False, 1548 False, 1549 None, 1550 ) 1551 1552 def __init__(self, t): 1553 super().__init__() 1554 self._t = t 1555 1556 def __tensor_flatten__(self): 1557 return ["_t"], {} 1558 1559 @staticmethod 1560 def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): 1561 t = inner_tensors["_t"] 1562 return SubTensor(t) 1563 1564 def __repr__(self): 1565 return f"SubTensor({self._t})" 1566 1567 @classmethod 1568 def __torch_function__(cls, func, types, args=(), kwargs=None): 1569 if kwargs is None: 1570 kwargs = {} 1571 1572 with torch._C.DisableTorchFunctionSubclass(): 1573 return func(*args, **kwargs) 1574 1575 @classmethod 1576 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1577 kwargs = {} if kwargs is None else kwargs 1578 new_args = pytree.tree_map_only(SubTensor, lambda s: s._t, args) 1579 output = func(*new_args, **kwargs) 1580 output = pytree.tree_map_only( 1581 torch.Tensor, lambda t: SubTensor(t), output 1582 ) 1583 return output 1584 1585 @torch.compile(dynamic=True) 1586 def f(x): 1587 return x.unflatten(-1, [2, 5]) 1588 1589 s = SubTensor(torch.randn(3, 10)) 1590 f(s) 1591 1592 # Guard validation upsets the guard 1593 # https://github.com/pytorch/pytorch/issues/129936 1594 @unittest.expectedFailure 1595 def test_recompile_with_symbool_inputs(self): 1596 def f(pred: bool): 1597 if pred: 1598 return torch.ones([3, 4]) 1599 else: 1600 return torch.ones([4, 3]) 1601 1602 def test_recompilation( 1603 f, x, sizes, exp_graphs, exp_frame_count, exp_shape_env_guards 1604 ): 1605 torch._dynamo.reset() 1606 shape_env = ShapeEnv() 1607 backend = torch._dynamo.testing.EagerAndRecordGraphs() 1608 cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) 1609 f_cond = torch.compile(f, backend=cnt, fullgraph=True) 1610 with torch._subclasses.fake_tensor.FakeTensorMode( 1611 shape_env=shape_env 1612 ) as fake_mode: 1613 fake_inp = fake_mode.from_tensor( 1614 x, 1615 symbolic_context=StatelessSymbolicContext( 1616 dynamic_sizes=[DimDynamic.DYNAMIC for i in range(x.dim())] 1617 ), 1618 ) 1619 for i, size in enumerate(sizes): 1620 pred = fake_inp.size(0) == size 1621 f_cond(pred) 1622 actual = normalize_gm( 1623 backend.graphs[exp_frame_count[i] - 1].print_readable( 1624 print_output=False 1625 ) 1626 ) 1627 actual_guard_str = [str(guard.expr) for guard in shape_env.guards] 1628 self.assertExpectedInline(actual, exp_graphs[i]) 1629 self.assertEqual(cnt.frame_count, exp_frame_count[i]) 1630 self.assertEqual(actual_guard_str, exp_shape_env_guards[i]) 1631 1632 true_graph = """\ 1633class GraphModule(torch.nn.Module): 1634 def forward(self): 1635 ones: "f32[3, 4]" = torch.ones([3, 4]) 1636 return (ones,) 1637""" 1638 false_graph = """\ 1639class GraphModule(torch.nn.Module): 1640 def forward(self): 1641 ones: "f32[4, 3]" = torch.ones([4, 3]) 1642 return (ones,) 1643""" 1644 test_recompilation( 1645 f, 1646 torch.randn([3, 4]), 1647 [3, 3, 4, 5], 1648 exp_graphs=[true_graph, true_graph, false_graph, false_graph], 1649 exp_frame_count=[1, 1, 2, 2], 1650 exp_shape_env_guards=[ 1651 [], 1652 # s0 is specialized and guarded in outter shape_env when dynamo checks the guards 1653 ["Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)"], 1654 [ 1655 "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", 1656 "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)", 1657 ], 1658 [ 1659 "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", 1660 "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)", 1661 "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)", 1662 ], 1663 ], 1664 ) 1665 1666 test_recompilation( 1667 f, 1668 torch.randn([3, 4]), 1669 [4, 5, 3, 3], 1670 exp_graphs=[false_graph, false_graph, true_graph, true_graph], 1671 exp_frame_count=[1, 1, 2, 2], 1672 exp_shape_env_guards=[ 1673 [], 1674 # s0 is specialized and guarded in outter shape_env when dynamo checks the guards 1675 ["Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)"], 1676 [ 1677 "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)", 1678 "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", 1679 ], 1680 [ 1681 "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)", 1682 "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", 1683 "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", 1684 ], 1685 ], 1686 ) 1687 1688 def test_wrapper_subclass_dynamo_attribute_access_on_intermediate(self): 1689 def f(x_subclass): 1690 tmp_subclass = torch.add(x, 1) 1691 return torch.mul(tmp_subclass._scale, tmp_subclass._constant) 1692 1693 x = ScaledTensor(torch.randn(2, 4), torch.randn(3), constant=2) 1694 out_ref = f(x) 1695 out_test = torch.compile(f, backend="aot_eager", fullgraph=True)(x) 1696 self.assertEqual(out_ref, out_test) 1697 1698 def test_support_bases(self): 1699 import abc 1700 1701 import torch.fx._symbolic_trace 1702 1703 class Meta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta): 1704 def __new__(cls, name, bases, dct): 1705 x = super().__new__(cls, name, bases, dct) 1706 x.attr = 100 1707 return x 1708 1709 class Multistreamable(abc.ABC): # noqa: B024 1710 pass 1711 1712 class Foo(Multistreamable, metaclass=Meta): 1713 pass 1714 1715 @torch.compile(backend="eager", fullgraph=True) 1716 def f(x): 1717 typ = type(Foo()) 1718 typ.__bases__ 1719 return typ.__bases__ 1720 1721 self.assertEqual(f(torch.randn(1)), (Multistreamable,)) 1722 1723 @torch.compile(backend="eager", fullgraph=True) 1724 def g(x): 1725 typ = type(Foo()) 1726 typ.__base__ 1727 return typ.__base__ 1728 1729 self.assertEqual(g(torch.randn(1)), Multistreamable) 1730 1731 @parametrize("dynamic", [False, True]) 1732 def test_subclass_views(self, dynamic): 1733 def _get_views(t): # returns (view: Tensor, expects_raises_false) 1734 # Note that any closed-over SymInts will be symbolicized during fake-ification. 1735 yield t.narrow(dim=-1, start=3, length=8), False 1736 yield t.split(5, -1)[2], False 1737 yield t.split_with_sizes([9, 6], -1)[1], False 1738 yield t.unsqueeze(-1).expand(4, 15, 10), False 1739 yield t.select(-1, 6), False 1740 # https://github.com/pytorch/pytorch/issues/128649 1741 yield t[2:3, 5:9], dynamic 1742 yield t.view(-1, 15), False 1743 1744 def f(x): 1745 return x * 2 1746 1747 compiled_f = torch.compile( 1748 f, backend="aot_eager", fullgraph=True, dynamic=dynamic 1749 ) 1750 1751 # Take a view of a subclass to pass as input. 1752 t = TwoTensor(torch.randn(4, 15), torch.randn(4, 15)) 1753 for view, expects_raises in _get_views(t): 1754 torch._dynamo.reset() 1755 out_ref = f(view) 1756 if expects_raises: 1757 with self.assertRaises(AssertionError): 1758 out_test = compiled_f(view) 1759 else: 1760 out_test = compiled_f(view) 1761 self.assertEqual(out_ref, out_test) 1762 1763 @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) 1764 def test_mark_static_with_subclass_desugaring(self): 1765 from typing import Any, Callable, Dict, List, Optional 1766 1767 from torch._dynamo.decorators import mark_static_address 1768 from torch._inductor.compile_fx import compile_fx 1769 from torch._inductor.cudagraph_utils import BoxedDeviceIndex 1770 from torch._inductor.utils import BoxedBool 1771 1772 x_inner = torch.ones(4) 1773 x = TwoTensor(x_inner, x_inner) 1774 mark_static_address(x, guard=False) 1775 1776 def inner_compile( 1777 gm: torch.fx.GraphModule, 1778 example_inputs: List[torch.Tensor], 1779 cudagraphs: Optional[BoxedBool] = None, 1780 static_input_idxs: Optional[List[int]] = None, 1781 is_backward: bool = False, 1782 graph_id: Optional[int] = None, 1783 cpp_wrapper: bool = False, 1784 aot_mode: bool = False, 1785 is_inference: bool = False, 1786 boxed_forward_device_index: Optional[BoxedDeviceIndex] = None, 1787 user_visible_outputs: Optional[Dict[str, None]] = None, 1788 layout_opt: Optional[bool] = None, 1789 extern_node_serializer: Optional[Callable[[List[Any]], Any]] = None, 1790 ): 1791 self.assertEqual(static_input_idxs, [1, 2]) 1792 return gm 1793 1794 compiler = functools.partial(compile_fx, inner_compile=inner_compile) 1795 1796 @torch.compile(backend=compiler) 1797 def fn(t0, t1, t2): 1798 return t0 + t1 + t2 + 2 1799 1800 fn(torch.ones(4), x, torch.ones(4)) 1801 1802 1803instantiate_parametrized_tests(SubclassTests) 1804 1805 1806class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase): 1807 def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True): 1808 return get_jagged_tensor(nested_size, offsets, requires_grad) 1809 1810 def _get_nc_jagged_tensor(self, inner_dim, starts, lengths, requires_grad=True): 1811 # Makes a jagged tensor with N constituent tensors with size 1812 # as specified ((S0, S1, S2), D) 1813 max_dim = (starts + lengths).max() 1814 values_tensor = torch.randn( 1815 starts.shape[0], 1816 max_dim.item(), 1817 inner_dim, 1818 requires_grad=requires_grad, 1819 dtype=torch.float64, 1820 ) 1821 return jagged_from_tensor_and_lengths(values_tensor, starts, lengths) 1822 1823 def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles): 1824 _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles) 1825 1826 def test_unary_does_not_recompile(self): 1827 nt1, _ = self._get_jagged_tensor(((2, 3, 4), 3), None) 1828 nt2, _ = self._get_jagged_tensor(((3, 4, 5, 6), 4), None) 1829 self._check_recompiles(lambda nt1: nt1.sin(), (nt1,), (nt2,), False) 1830 1831 def test_binary_does_not_recompile(self): 1832 def binary(nt1, nt2): 1833 if nt1.shape == nt2.shape: 1834 return nt1 + nt2 1835 else: 1836 return nt1.sin() 1837 1838 # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0). 1839 # This causes a recompile later on when it realizes the batch and last dim 1840 # should not always be equal. To avoid that, we use (3, j0, 5) here. 1841 nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None) 1842 nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets) 1843 nt3, offsets = self._get_jagged_tensor(((3, 4, 5), 4), None) 1844 nt4, _ = self._get_jagged_tensor(((3, 4, 5), 4), offsets) 1845 self._check_recompiles(binary, (nt1, nt2), (nt3, nt4), False) 1846 1847 def test_binary_recompiles(self): 1848 def binary(nt1, nt2): 1849 if nt1.shape == nt2.shape: 1850 return nt1 + nt2 1851 else: 1852 return nt1.sin() 1853 1854 # Binary recompiles because singleton ints no longer match 1855 nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None) 1856 nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets) 1857 nt3, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) 1858 self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True) 1859 1860 def _validate_compile(self, fn, arg_fn): 1861 def _gen_grad_outputs(out_val): 1862 if isinstance(out_val, (list, tuple)): 1863 return tuple(torch.ones_like(c) for c in out_val) 1864 else: 1865 return (torch.ones_like(out_val),) 1866 1867 with self.branch_nested_state(): 1868 from torch.nested._internal.nested_tensor import _tensor_symint_registry 1869 1870 # Validate that compilation does not modify eager state 1871 registry_before = list(_tensor_symint_registry.items()) 1872 count_before = torch.nested._internal.nested_tensor._tensor_id_counter 1873 1874 guards_exported = [] 1875 guards_failed = [] 1876 1877 def append_guard_export(guards): 1878 for g in guards: 1879 if g.code_list is not None: 1880 guards_exported.append(g.code_list[0]) 1881 1882 def append_guard_fail(guards): 1883 guards_failed.extend(guards) 1884 1885 compiled = torch._dynamo.optimize( 1886 nopython=True, 1887 backend="aot_eager", 1888 guard_export_fn=append_guard_export, 1889 guard_fail_fn=append_guard_fail, 1890 )(fn) 1891 registry_after = list(_tensor_symint_registry.items()) 1892 count_after = torch.nested._internal.nested_tensor._tensor_id_counter 1893 self.assertEqual(registry_before, registry_after) 1894 self.assertEqual(count_before, count_after) 1895 1896 args = arg_fn() 1897 compile_out = compiled(*args) 1898 compile_grads = [] 1899 g_args = [arg for arg in args if arg.requires_grad] 1900 if len(g_args) > 0: 1901 compile_grad_outputs = _gen_grad_outputs(compile_out) 1902 compile_grads = torch.autograd.grad( 1903 compile_out, inputs=g_args, grad_outputs=compile_grad_outputs 1904 ) 1905 1906 with self.branch_nested_state(): 1907 args = arg_fn() 1908 ref_out = fn(*args) 1909 ref_grads = [] 1910 g_args = [arg for arg in args if arg.requires_grad] 1911 if len(g_args) > 0: 1912 ref_grad_outputs = _gen_grad_outputs(ref_out) 1913 ref_grads = torch.autograd.grad( 1914 ref_out, inputs=g_args, grad_outputs=ref_grad_outputs 1915 ) 1916 1917 # Validate correctness forward 1918 if isinstance(compile_out, (list, tuple)): 1919 # TODO: Fix assertEqual() to support NJTs so this isn't necessary 1920 self.assertEqual(len(compile_out), len(ref_out)) 1921 for c, r in zip(compile_out, ref_out): 1922 self.assertEqualIgnoringNestedInts(c, r) 1923 else: 1924 self.assertEqualIgnoringNestedInts(compile_out, ref_out) 1925 1926 # Validate correctness backward 1927 for compile_grad, ref_grad in zip(compile_grads, ref_grads): 1928 self.assertEqualIgnoringNestedInts(compile_grad, ref_grad) 1929 1930 return guards_exported, guards_failed 1931 1932 # Note: [What kind of guards are involved in nested tensor compilation] 1933 # 1934 # Until we implement UnionFind, dynamic shapes guards are not involved. 1935 # we rely only on dynamo's tensor aliasing guards. 1936 # 1937 # This is possible because dynamo able to generate tensor aliasing guards 1938 # not only for the outer tensor, but also for the inner tensor. 1939 # 1940 # The case where dynamic shapes guards would eventually come into play is 1941 # when my inputs are (1) two non-aliased tensors, but (2) declared as 1942 # equal using a "trust me assert equal" API. 1943 1944 # Note: [Compiling nested tensor global state] 1945 # 1946 # Today there are two pieces of global eager state that NJTs deals with: 1947 # - tensor_id_counter: a global counter that assigns unique ids to tensors 1948 # - tensor_symint_registry: maps tensor to nested int 1949 # - this is used in eager only (we should get rid of this because it is 1950 # not necessary to cache nested int in eager) 1951 # - during tracing, we DO need to cache nested int, but we do so on 1952 # the FakeTensor. 1953 # 1954 # Ideally we would like to satisfy the following: 1955 # - (1) The eager state is not mutated during tracing 1956 # - (2) Running the compiled function should mutate the eager state in the 1957 # same way that running the eager function would 1958 # (a) The global counter should be incremented 1959 # (b) The registry is updated in the same way 1960 # 1961 # Today we can satisfy (1) and (2a) but cannot satisfy (2b) 1962 # 1963 # Today, (1) is satisfied because we maintain a separate counter during 1964 # tracing, and cache nested int on FakeTensor instead of relying on 1965 # tensor_symint_registry. 1966 # 1967 # (2) is cannot be completely satisfied because we trace away the 1968 # side-effectful operations (which we can fix this by wrapping the 1969 # side-effectful operations in a custom op, and threading through effect 1970 # tokens.) The current plan is to do that in the UnionFind impl. 1971 # 1972 # Interestingly, despite this, the state is mutated in a way that is somewhat 1973 # close to what we want, e.g. if I construct a nested tensor using an 1974 # offsets in the compiled region and return it, AOTAutograd runtime wrapper 1975 # must rewrap the inner->inner graph outputs back into subclass. This 1976 # triggers the eager logic to run, updating the counter and registry. 1977 # 1978 # Notably however, compile differs in two ways from eager: 1979 # (1) The order in which the offsets are assigned ids is differnet 1980 # the registry would be set in the order the offsets are returned 1981 # which is not necessarily the same order as they were constructed. 1982 # (2) If a NestedTensor is not returned, then the AOTAutograd wrapping 1983 # logic will not be triggered. 1984 # 1985 # I claim that correctness is not affected by these differences today. 1986 # e.g. there is never the case where two distinct offsets silently share 1987 # the same id. 1988 # 1989 # (1) is clearly not a problem, and (2) should only be a problem if 1990 # the nested int is returned on its own, without the corresponding NJT 1991 # being returned. This is not a problem in the current implementation 1992 # because returning only a shape is not supported! 1993 1994 # Note: [Creating symbolic nested int] 1995 # 1996 # We must create a symbolic nested int when we construct a nested tensor 1997 # from a tensor. There are two main cases: 1998 # 1999 # 1. The offsets has NOT been used to construct a NJT 2000 # - Create a new plain nested int with current val of fake nt id counter 2001 # - Increment the fake nt id counter 2002 # - Create a new symint with plain nested int as hint 2003 # 2. The offsets HAS been used to construct a NJT 2004 # - Create a new symint with plain nested int as hint 2005 # 2006 # More details on case 2: 2007 # - During fakification of the offsets, we check the eager registry, and 2008 # if the tensor HAS been used to construct a NJT, 2009 # we create a symint, with the existing nested int as hint, and cache 2010 # it on to the FakeTensor. 2011 # 2012 # [ Always use ephemeral source ] 2013 # 2014 # We create the new symint ALWAYS with ephemeral source whether that is 2015 # in case (1) or (2) even though we could've had a proper source for case (2). 2016 # Using a proper source would enable a few more (edge) cases, but since 2017 # we plan to handle things more holistically in the future anyway, we don't 2018 # bother doing so today. 2019 # 2020 # Using an ephemeral source has some consequences. But we are happy if 2021 # - We do not silently miss recompiles, e.g. we guard when necessary. 2022 # We know that this is true, because dynamo guards alone are already 2023 # sufficient. 2024 # - We are not producing errors for the cases we care about 2025 # 2026 # The main case we care about is when we guard that two shapes are equal. 2027 # In this case, the replacements logic would simplify away the ephemeral 2028 # symbol, and there is no error produced. 2029 # The unsupported case is when we guard that two shapes are not equal, in 2030 # which, we will try and fail to generate a guard. 2031 2032 # 2033 # Case 1: in-graph construction where the offsets are passed as inputs 2034 # 2035 def test_in_graph_construction_from_input(self): 2036 # The offsets is passed as an input 2037 def fn(values, offsets): 2038 return torch.nested.nested_tensor_from_jagged(values * 2, offsets) * 2 2039 2040 values = torch.randn(10, 5, requires_grad=True) 2041 offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) 2042 self._validate_compile(fn, arg_fn=lambda: (values, offsets)) 2043 2044 # Do not specialize on the offsets 2045 with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 2046 different_offsets = torch.tensor([0, 1, 5, 10], dtype=torch.int64) 2047 self._validate_compile(fn, arg_fn=lambda: (values, different_offsets)) 2048 2049 def test_in_graph_construction_from_input_2(self): 2050 # Construct two NJTs, both are passed as inputs 2051 def fn(values, offsets1, offsets2): 2052 nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets1) 2053 nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2) 2054 return nt2, nt1 2055 2056 values = torch.randn(10, 5, requires_grad=True) 2057 offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) 2058 offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64) 2059 # 1. Offsets are different 2060 guards_exported, guards_failed = self._validate_compile( 2061 fn, arg_fn=lambda: (values, offsets, offsets2) 2062 ) 2063 self.assertEqual(len(guards_failed), 0) 2064 self.assertNotIn("L['offsets1'] is L['offsets2']", guards_exported) 2065 2066 # TODO 2067 # 2. Offsets are the same 2068 new_guards_exported, _ = self._validate_compile( 2069 fn, arg_fn=lambda: (values, offsets, offsets) 2070 ) 2071 self.assertTrue(any("Duplicate tensors found" in g for g in guards_failed)) 2072 self.assertIn("L['offsets1'] is L['offsets2']", new_guards_exported) 2073 2074 with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 2075 offsets3 = offsets.clone() 2076 self._validate_compile(fn, arg_fn=lambda: (values, offsets3, offsets3)) 2077 2078 # Do a binary op 2079 def fn(values, offsets, offsets2): 2080 nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets) 2081 nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2) 2082 return nt1 * nt2 2083 2084 self._validate_compile(fn, arg_fn=lambda: (values, offsets, offsets)) 2085 2086 def test_in_graph_construction_from_input_4(self): 2087 # The offsets is taken from an NJT input 2088 def fn(nt, other_values): 2089 nt2 = torch.nested.nested_tensor_from_jagged(other_values, nt.offsets()) 2090 return nt + nt2 2091 2092 values = torch.randn(9, 5, requires_grad=True) 2093 other_values = torch.randn(9, 5, requires_grad=True) 2094 offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64) 2095 2096 def arg_fn(values=values, other_values=other_values, offsets=offsets): 2097 nt = torch.nested.nested_tensor_from_jagged(values, offsets) 2098 return nt, other_values 2099 2100 self._validate_compile(fn, arg_fn=arg_fn) 2101 2102 # Do not specialize on the offsets 2103 with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 2104 different_offsets = offsets.clone() 2105 2106 def arg_fn( 2107 values=values, other_values=other_values, offsets=different_offsets 2108 ): 2109 nt = torch.nested.nested_tensor_from_jagged(values, different_offsets) 2110 return nt, other_values 2111 2112 self._validate_compile(fn, arg_fn=arg_fn) 2113 2114 def test_in_graph_construction_from_input_5(self): 2115 # Construct from lengths instead of offsets 2116 def fn(values, lengths): 2117 nt = torch.nested.nested_tensor_from_jagged(values, lengths=lengths) 2118 return nt.sin() 2119 2120 values = torch.randn(9, 5, requires_grad=True) 2121 lengths = torch.tensor([2, 4, 3]) 2122 self._validate_compile(fn, arg_fn=lambda: (values, lengths)) 2123 2124 # 2125 # Case 2: in-graph construction where offsets are graph intermediates 2126 # 2127 def test_in_graph_construction_from_intermediate(self): 2128 # offsets is an intermediate computed from lengths 2129 def fn(values, lengths): 2130 offsets = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)]) 2131 nt = torch.nested.nested_tensor_from_jagged(values, offsets) 2132 nt2 = torch.nested.nested_tensor_from_jagged(values, offsets) 2133 return (nt * nt2).sin() 2134 2135 values = torch.randn(9, 5, requires_grad=True) 2136 lengths = torch.tensor([2, 4, 3]) 2137 self._validate_compile(fn, arg_fn=lambda: (values, lengths)) 2138 2139 # Do not specialize on the lengths 2140 with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 2141 different_lengths = lengths.clone() 2142 self._validate_compile(fn, arg_fn=lambda: (values, different_lengths)) 2143 2144 def test_in_graph_construction_from_intermediate_2(self): 2145 def fn(values, offsets): 2146 return torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone()) 2147 2148 values = torch.randn(10, 5, requires_grad=True) 2149 offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) 2150 self._validate_compile(fn, arg_fn=lambda: (values, offsets)) 2151 2152 def test_in_graph_construction_from_intermediate_3(self): 2153 # Note that due to CSE, clone is not necessarily called twice! 2154 def fn(values, offsets): 2155 nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone()) 2156 nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets.clone()) 2157 return nt2, nt1 2158 2159 values = torch.randn(10, 5, requires_grad=True) 2160 offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) 2161 self._validate_compile(fn, arg_fn=lambda: (values, offsets)) 2162 2163 def test_in_graph_construction_from_intermediate_4(self): 2164 # Shared intermediate (should be same as case #1) 2165 def fn(values): 2166 offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) 2167 nt = torch.nested.nested_tensor_from_jagged(values, offsets) 2168 values2 = torch.ones_like(values) 2169 nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets) 2170 return nt * nt2 2171 2172 values = torch.randn(10, 5).requires_grad_(True) 2173 self._validate_compile(fn, arg_fn=lambda: (values,)) 2174 2175 # AssertionError: s2 (could be from ['<ephemeral: intermediate_offsets_or_lengths>', 2176 @unittest.expectedFailure 2177 def test_in_graph_construction_from_intermediate_5(self): 2178 # non-shared intermediate 2179 def fn(values): 2180 offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) 2181 nt = torch.nested.nested_tensor_from_jagged(values, offsets) 2182 values2 = torch.ones_like(values) 2183 nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets.clone()) 2184 if nt2.shape[1] != nt.shape[1]: 2185 return nt * 2 2186 else: 2187 return nt * 3 2188 2189 values = torch.randn(10, 5).requires_grad_(True) 2190 self._validate_compile(fn, arg_fn=lambda: (values,)) 2191 2192 # 2193 # Case 3: in-graph construction where offsets are both direct graph inputs 2194 # and passed in as part of an NJT's offsets. 2195 # 2196 def test_in_graph_construction_mixed(self): 2197 def fn(nt, values, offsets): 2198 nt2 = torch.nested.nested_tensor_from_jagged(values, offsets) 2199 return nt * nt2 2200 2201 values = torch.randn(10, 5, requires_grad=True) 2202 offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) 2203 2204 def arg_fn(values=values, offsets=offsets): 2205 nt = torch.nested.nested_tensor_from_jagged(values, offsets) 2206 return nt, values, offsets 2207 2208 self._validate_compile(fn, arg_fn) 2209 2210 # See Note: [Creating symbolic nested int] 2211 # AssertionError: s2 (could be from ['<ephemeral: intermediate_offsets_or_lengths>', 2212 @unittest.expectedFailure 2213 def test_in_graph_construction_mixed_2(self): 2214 def fn(nt, values, offsets, nt2): 2215 # Intermediate offsets has ephemeral source 2216 intermediate_nt = torch.nested.nested_tensor_from_jagged( 2217 values, offsets.clone() 2218 ) 2219 # This creates a dynamic shapes neq guard 2220 if nt2.shape[1] != intermediate_nt.shape[1]: 2221 # We should always go here. 2222 nt = nt * 2 2223 return nt 2224 2225 values = torch.randn(10, 5, requires_grad=True) 2226 offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) 2227 offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64) 2228 2229 def arg_fn(values=values, offsets=offsets, offsets2=offsets2): 2230 # Values is shared, but it shouldn't matter 2231 nt = torch.nested.nested_tensor_from_jagged(values, offsets) 2232 nt2 = torch.nested.nested_tensor_from_jagged(values, offsets2) 2233 return nt, values, offsets, nt2 2234 2235 self._validate_compile(fn, arg_fn) 2236 2237 def test_in_graph_construction_mixed_3(self): 2238 # More involved mixed case 2239 def fn(nt, values, offsets): 2240 nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets) 2241 nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets) 2242 return nt1 + nt2 + nt 2243 2244 values = torch.randn(9, 5, requires_grad=True) 2245 offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64) 2246 2247 def arg_fn(values=values, offsets=offsets): 2248 nt = torch.nested.nested_tensor_from_jagged(values, offsets) 2249 return nt, values, offsets 2250 2251 self._validate_compile(fn, arg_fn) 2252 2253 def test_return_shape(self): 2254 nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) 2255 2256 def fn(nt): 2257 return (nt * 2).shape 2258 2259 compiled = torch.compile(fn, fullgraph=True, backend="aot_eager") 2260 compiled(nt) 2261 2262 def test_inference_tensor(self): 2263 with torch.inference_mode(): 2264 nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) 2265 2266 def fn(n): 2267 return n * 2 2268 2269 torch.compile(fn, backend="eager")(nt) 2270 2271 # TODO: cannot parametrize this test class with device for some reason 2272 def _test_autograd(self, backend): 2273 a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64) 2274 b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64) 2275 c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64) 2276 nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) 2277 # TODO: Switch to public API when it exists 2278 nt2, _ = jagged_from_list([a, b, c], nt.offsets()) 2279 2280 def fn1(nt1, nt2): 2281 return (nt1 + nt2).sin().cos() 2282 2283 compiled_f = torch.compile(fn1, fullgraph=True, backend=backend, dynamic=True) 2284 out = compiled_f(nt, nt2) 2285 out_buffer = out.values() 2286 ga, gb, gc = torch.autograd.grad(out_buffer.sum(), (a, b, c)) 2287 2288 out_ref = fn1(nt, nt2) 2289 out_buffer_ref = out_ref.values() 2290 ga_ref, gb_ref, gc_ref = torch.autograd.grad(out_buffer_ref.sum(), (a, b, c)) 2291 2292 self.assertTrue(torch.allclose(ga, ga_ref)) 2293 self.assertTrue(torch.allclose(gb, gb_ref)) 2294 self.assertTrue(torch.allclose(gc, gc_ref)) 2295 2296 def test_basic_autograd(self): 2297 self._test_autograd("aot_eager") 2298 2299 @requires_cuda 2300 def test_basic_autograd_inductor(self): 2301 self._test_autograd("inductor") 2302 2303 def test_subclass_with_mutation_in_graph(self): 2304 # In this graph, we have an in-graph mutation, i.e. a mutation that is allowed 2305 # to remain in the graph. Normally this is allowed, but it's not allowed if 2306 # the graph handles subclasses at all. 2307 # Whether the mutation is allowed or not allowed in the graph alters the number 2308 # of outputs from the forward graph. Previously, a bug in this handling meant 2309 # that sometimes the expected number and actual number of outputs from the 2310 # joint graph did not match, causing assertion failures. 2311 def fn(x, y): 2312 z = x.sin() 2313 y.sin_() 2314 return z.cos(), y.cos() 2315 2316 fn_c = torch.compile(fn, backend="inductor") 2317 2318 values = [torch.rand((i, 8), requires_grad=True) for i in range(1, 6)] 2319 values_copy = [x.detach().clone().requires_grad_(True) for x in values] 2320 2321 nt, offsets = jagged_from_list(values, None) 2322 nt_copy, offsets = jagged_from_list(values_copy, offsets) 2323 y = torch.rand((4, 8)) 2324 y_copy = y.clone() 2325 2326 ret = fn_c(nt, y)[0] 2327 ref = fn(nt_copy, y_copy)[0] 2328 2329 self.assertEqual(ret.values(), ref.values()) 2330 2331 ret.values().sum().backward() 2332 ref.values().sum().backward() 2333 for ref_v, res_v in zip(values_copy, values): 2334 self.assertEqual(ref_v.grad, res_v.grad) 2335 2336 @torch._dynamo.config.patch({"capture_scalar_outputs": True}) 2337 def test_unbind(self): 2338 # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0). 2339 # This causes a recompile later on when it realizes the batch and last dim 2340 # should not always be equal. To avoid that, we use (3, j0, 5) here. 2341 nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) 2342 nt2, _ = self._get_jagged_tensor(((2, 3, 5), 2), None) 2343 nt3, _ = self._get_jagged_tensor(((2, 3, 4, 5), 3), None) 2344 2345 def fn(x): 2346 return x.unbind() 2347 2348 compiled_f = torch.compile(fn, fullgraph=True, backend="eager", dynamic=True) 2349 out = compiled_f(nt) 2350 2351 out_ref = fn(nt) 2352 2353 # correctness 2354 self.assertEqual(len(out), len(out_ref)) 2355 for x, x_ref in zip(out, out_ref): 2356 self.assertTrue(torch.allclose(x, x_ref)) 2357 2358 # We specialize on the length of offsets, e.g. (1) we recompile if the 2359 # length of the offsets is different. (2) we don't recompile if the 2360 # length of the offsets is the same, even if the size of the constituent 2361 # tensors are different. 2362 self._check_recompiles(fn, (nt,), (nt2,), False) 2363 self._check_recompiles(fn, (nt,), (nt3,), True) 2364 2365 def test_inline_nested_tensor_from_jagged(self): 2366 nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) 2367 2368 def fn(x): 2369 return torch.nested.nested_tensor_from_jagged(x.values() * 2, x.offsets()) 2370 2371 torch.compile(fn, fullgraph=True, backend="aot_eager")(nt) 2372 2373 # The test here: nn.Parameters that are secretly subclasses 2374 # have a metaclass that overrides __isinstance__, 2375 # that dynamo needs to respect when it inlines the if statement. 2376 def test_param_subclass_isinstance_input(self): 2377 x_inner = torch.randn(16, 16, requires_grad=True) 2378 x = torch.nn.Parameter(TwoTensor(x_inner, x_inner)) 2379 m = torch.nn.Linear(16, 16) 2380 m.weight = x 2381 2382 def fn(): 2383 if isinstance(m.weight, torch.nn.Parameter): 2384 return m.weight + 1 2385 else: 2386 return m.weight + 2 2387 2388 out_ref = fn() 2389 out_test = torch.compile(fn, backend="aot_eager")() 2390 self.assertEqual(out_ref, out_test) 2391 2392 def _input_view_test(self, nt_view_name): 2393 nt_view = VIEW_TEST_CASES[nt_view_name]() 2394 2395 def fn(x): 2396 return x.sin() 2397 2398 out_ref = fn(nt_view) 2399 torch._dynamo.reset() 2400 compile_fn = torch.compile( 2401 fn, fullgraph=True, backend="aot_eager", dynamic=True 2402 ) 2403 out = compile_fn(nt_view) 2404 2405 # Check metadata and values are correct 2406 self.assertTrue(out.size() == out_ref.size()) 2407 self.assertTrue(out.stride() == out_ref.stride()) 2408 if out.is_nested: 2409 self.assertTrue(torch.allclose(out.values(), out_ref.values())) 2410 else: 2411 self.assertTrue(torch.allclose(out, out_ref)) 2412 2413 # Check that no upper/lower bound guards are incurred 2414 def backend(gm, args): 2415 context = torch._guards.TracingContext.get() 2416 guards = [str(g.expr) for g in context.fake_mode.shape_env.guards] 2417 2418 # varies based on the type of view 2419 guard_str = "\n".join(guards) 2420 if nt_view_name == "subclass_dense": 2421 self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""") 2422 elif nt_view_name == "dense_subclass_dense_subclass": 2423 self.assertExpectedInline( 2424 guard_str, 2425 """\ 2426Eq(s5 - 1, s2) 2427Eq(s12 - 1, s7) 2428Eq(s11, s9)""", 2429 ) 2430 elif nt_view_name.startswith("base_is_nt_True"): 2431 self.assertExpectedInline( 2432 guard_str, 2433 """Eq(s3 - 1, s0)""", 2434 ) 2435 else: 2436 self.assertExpectedInline( 2437 guard_str, 2438 """\ 2439Eq(s4 - 1, s1) 2440Eq(s13 - 1, s8) 2441Eq(s12, s10)""", 2442 ) 2443 return gm 2444 2445 torch._dynamo.reset() 2446 compile_fn = torch.compile(fn, fullgraph=True, backend=backend, dynamic=True) 2447 out = compile_fn(nt_view) 2448 2449 @parametrize( 2450 "nt_view_name", 2451 [k for k in VIEW_TEST_CASES.keys() if k != "subclass_dense_subclass_dense"], 2452 ) 2453 def test_inputs_to_compiled_fn_are_views(self, nt_view_name): 2454 self._input_view_test(nt_view_name) 2455 2456 def test_subclass_gives_static_shapes_when_dynamic_false(self): 2457 def check_graph(gm, *args): 2458 first_node_example_val = next(iter(gm.graph.nodes)).meta["example_value"] 2459 # We compiled with dynamic=False, expect no SymInt sizes on our placeholders 2460 self.assertTrue( 2461 all(isinstance(x, int) for x in first_node_example_val.shape) 2462 ) 2463 return gm 2464 2465 @torch.compile(backend=check_graph, dynamic=False) 2466 def f(x): 2467 return x + 1 2468 2469 x_inner = torch.ones(4) 2470 x = TwoTensor(x_inner, x_inner) 2471 x_view = x.view(2, 2) 2472 out = f(x_view) 2473 2474 # NJT1 -> Dense -> NJT2 -> Dense view 2475 # During view replay, the Dense -> NJT2 part will construct an intermediate, 2476 # symbolically-sized NJT that is immediately deconstructed to return the final dense 2477 # view. To construct this intermediate properly, we need the associated nested int 2478 # to be symbolic. This view is expected to fail compilation until symbolic nested ints 2479 # are cached onto fake offsets to solve this problem. 2480 @unittest.expectedFailure 2481 def test_subclass_dense_subclass_dense_view(self): 2482 self._input_view_test("subclass_dense_subclass_dense") 2483 2484 2485instantiate_parametrized_tests(TestNestedTensor) 2486 2487 2488if __name__ == "__main__": 2489 from torch._dynamo.test_case import run_tests 2490 2491 run_tests() 2492