1# Owner(s): ["module: nestedtensor"] 2 3import io 4import itertools 5import math 6import sys 7import unittest 8from functools import partial 9from typing import Optional, Tuple 10 11import numpy as np 12 13import torch 14import torch._dynamo 15import torch._dynamo.testing 16import torch.nn 17import torch.nn.functional as F 18from torch.nested._internal.nested_tensor import ( 19 buffer_from_jagged, 20 jagged_from_list, 21 nested_view_from_values_offsets, 22 NestedTensor, 23 ViewNestedFromBuffer, 24) 25from torch.testing._internal.common_cuda import ( 26 PLATFORM_SUPPORTS_FUSED_ATTENTION, 27 SM70OrLater, 28 SM80OrLater, 29) 30from torch.testing._internal.common_device_type import ( 31 dtypes, 32 dtypesIfCUDA, 33 instantiate_device_type_tests, 34 onlyCPU, 35 onlyCUDA, 36 ops, 37 PYTORCH_CUDA_MEMCHECK, 38 skipCPUIf, 39 skipCUDAIf, 40 skipCUDAIfRocm, 41 skipMeta, 42) 43from torch.testing._internal.common_dtype import floating_types_and_half 44from torch.testing._internal.common_utils import ( 45 decorateIf, 46 freeze_rng_state, 47 gradcheck, 48 instantiate_parametrized_tests, 49 IS_FBCODE, 50 IS_WINDOWS, 51 markDynamoStrictTest, 52 NestedTensorTestCase, 53 parametrize, 54 run_tests, 55 skipIfSlowGradcheckEnv, 56 skipIfTorchDynamo, 57 subtest, 58 TEST_WITH_ROCM, 59 xfailIfTorchDynamo, 60) 61from torch.testing._internal.opinfo.definitions.nested import njt_op_db 62from torch.utils._pytree import tree_flatten 63from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts 64 65 66# Tests are ported from pytorch/nestedtensor. 67# This makes porting as_nested_tensor easier in the future. 68 69 70def _iter_constructors(): 71 # yield as_nested_tensor 72 yield torch.nested.nested_tensor 73 74 75# Returns True if the function recompiles between inputs1 and inputs2 with the 76# specified dynamic setting. 
77def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True): 78 compile_count = [0] 79 80 def counter(gm, example_inputs): 81 compile_count[0] += 1 82 return gm 83 84 compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic) 85 compiled_f(*inputs1) 86 compiled_f(*inputs2) 87 return compile_count[0] > 1 88 89 90# Helper function to generate a pair of random nested tensors 91# one is contiguous, the other is not, but they appear to have same entries 92# an output nested tensor consists of 93# * `len(ragged_sizes)` matrices 94# * matrices[i].shape == (20, ragged_sizes[i]) 95 96 97def random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16): 98 xs = [] 99 for size in ragged_sizes: 100 xs.append(torch.randn((size, 20), device=device, dtype=dtype)) 101 # contiguous nested tensor 102 ys = [] 103 for x in xs: 104 ys.append(x.transpose(-1, -2)) 105 nt_contiguous = torch.nested.nested_tensor(ys) 106 # noncontiguous nested tensor 107 n = len(ragged_sizes) 108 nt_noncontiguous = torch.nested.nested_tensor(xs).transpose(-1, -2) 109 return nt_contiguous, nt_noncontiguous 110 111 112# Helper functions to pad a noncontiguous nested tensor 113# can be replaced once to_padded_tensor supports noncontiguous memory 114 115 116def noncontiguous_to_padded_tensor(input, shape=None): 117 tensors = input.unbind() 118 ntensors = len(tensors) 119 assert ntensors > 0 120 if shape is None: 121 shape = [] 122 for size in tensors[0].shape: 123 shape.append(size) 124 for i in range(1, ntensors): 125 new_shape = tensors[i].shape 126 for j in range(len(shape)): 127 shape[j] = max(shape[j], new_shape[j]) 128 shape = [ntensors] + shape 129 result = tensors[0].new_zeros(shape) 130 for itensor in range(ntensors): 131 tensor = tensors[itensor] 132 view = result[itensor] 133 for idim in range(tensor.dim()): 134 view = view.narrow(idim, 0, tensor.size(idim)) 135 view.copy_(tensor) 136 return result 137 138 139# Helper function to generate a random nested tensor 140 141 142def random_nt( 143 device, 144 dtype, 145 num_tensors, 146 max_dims, 147 min_dims=None, 148 layout=torch.strided, 149 require_non_empty=True, 150): 151 if min_dims is None: 152 min_dims = tuple([0] * len(max_dims)) 153 154 assert len(max_dims) == len(min_dims) 155 for min_dim, max_dim in zip(min_dims, max_dims): 156 assert max_dim > min_dim, "random_nt: max_dim must be greater than min_dim" 157 assert min_dim >= 0, "random_nt: min_dim must be non-negative" 158 if require_non_empty: 159 assert not ( 160 min_dim == 0 and max_dim == 1 161 ), "random_nt: zero cannot be the only possible value if require_non_empty is True" 162 163 if require_non_empty: 164 # Select a random idx that will be required to be non-empty 165 non_zero_idx = torch.randint(low=0, high=num_tensors, size=(1,)).item() 166 167 ts1 = [] 168 for i, _ in enumerate(range(num_tensors)): 169 tensor_dims = [] 170 for min_dim, max_dim in zip(min_dims, max_dims): 171 new_min_dim = min_dim 172 if require_non_empty and i == non_zero_idx and min_dim == 0: 173 new_min_dim = 1 174 tensor_dims.append( 175 torch.randint(low=new_min_dim, high=max_dim, size=(1,)).item() 176 ) 177 t1 = torch.randn(tensor_dims, device=device, dtype=dtype) 178 ts1.append(t1) 179 180 return torch.nested.nested_tensor(ts1, device=device, dtype=dtype, layout=layout) 181 182 183# Alternate approach to generating a random NT. 
184# dims should be something like [5, None, 10], with None indicating that a 185# random ragged structure should be used 186def random_nt_from_dims( 187 dims, device=None, dtype=None, layout=torch.strided, requires_grad=False 188): 189 sizes = [ 190 [ 191 d if d is not None else torch.randint(2, 10, size=(1,)).item() 192 for d in dims[1:] 193 ] 194 for d in range(dims[0]) 195 ] 196 return torch.nested.nested_tensor( 197 [torch.randn(*size) for size in sizes], 198 device=device, 199 dtype=dtype, 200 layout=layout, 201 requires_grad=requires_grad, 202 ) 203 204 205# Creates an NT matching another NT's number of components and 206# shape / ragged structure for all dims specified to be -1. 207def random_nt_from_similar(other, dims=None): 208 if dims is None: 209 return torch.randn_like(other) 210 assert len(dims) == other.dim() 211 assert dims[0] == -1 or dims[0] == other.size(0) 212 213 ret_sizes = [] 214 for t in other.unbind(): 215 other_size = t.shape 216 ret_size = [] 217 for i, d in enumerate(dims[1:]): 218 if d == -1: 219 ret_size.append(other_size[i]) 220 else: 221 ret_size.append(d) 222 ret_sizes.append(ret_size) 223 224 return torch.nested.nested_tensor( 225 [torch.randn(*size) for size in ret_sizes], device=other.device 226 ) 227 228 229# makes naming nice for tests that parametrize over layout. 230def layout_name(layout): 231 # e.g. "torch.jagged" -> "jagged" 232 return layout.__repr__().split(".")[-1] 233 234 235def get_op_name(layout): 236 # e.g. "<OpOverload(op='aten.sum', overload='dim_IntList')>" -> "sum" 237 return layout.__name__.split(".")[0].split("_")[-1] 238 239 240# Helper function for test_dummy_mha_with_nt 241@torch.fx.wrap 242def convert_dense_to_nested_tensor_legacy(values): 243 offsets = torch.arange( 244 0, values.shape[0] * values.shape[1] + 1, values.shape[1], device=values.device 245 ) 246 metadata_cache = {"max_seqlen": values.shape[1], "min_seqlen": 1} 247 nt = ViewNestedFromBuffer.apply( 248 values.view(-1, values.shape[-1]), offsets, metadata_cache 249 ) 250 return nt 251 252 253# Helper function for test_dummy_mha_with_nt 254@torch.fx.wrap 255def convert_jagged_to_nested_tensor_legacy( 256 values: torch.Tensor, offsets: torch.Tensor, max_length: int 257) -> torch.Tensor: 258 metadata_cache = {"max_seqlen": max_length, "min_seqlen": 1} 259 nt = ViewNestedFromBuffer.apply(values, offsets, metadata_cache) 260 return nt 261 262 263# Helper function for test_dummy_mha_with_nt 264@torch.fx.wrap 265def convert_nt_to_jagged_legacy(nt): 266 return buffer_from_jagged(nt) 267 268 269# Helper function for test_dummy_mha_with_nt 270@torch.fx.wrap 271def convert_dense_to_nested_tensor(values): 272 nt = torch.nested.as_nested_tensor(values, layout=torch.jagged) 273 return nt 274 275 276# Helper function for test_dummy_mha_with_nt 277@torch.fx.wrap 278def convert_jagged_to_nested_tensor( 279 values: torch.Tensor, offsets: torch.Tensor, max_length: int 280) -> torch.Tensor: 281 nt = torch.nested.nested_tensor_from_jagged( 282 values, offsets, lengths=None, min_seqlen=1, max_seqlen=max_length 283 ) 284 return nt 285 286 287# Helper function for test_dummy_mha_with_nt 288def convert_nt_to_jagged(nt): 289 return nt.values() 290 291 292@markDynamoStrictTest 293class TestNestedTensor(NestedTensorTestCase): 294 @parametrize("batch_size", [2, 4]) 295 @parametrize("max_seq_len", [3, 5]) 296 @parametrize("vocab_size", [10, 20]) 297 def test_2d_nested_tensor(self, batch_size, max_seq_len, vocab_size): 298 data = [] 299 nested_tensor_ref_list = [] 300 for _ in range(batch_size): 301 
if max_seq_len == 0: 302 length = 0 303 else: 304 length = np.random.randint(low=1, high=max_seq_len) 305 row = list(np.random.randint(low=0, high=vocab_size, size=(length,))) 306 data.append(row) 307 nested_tensor_ref_list.append(torch.Tensor(row)) 308 nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64) 309 nested_tensor_list = nested_tensor.unbind() 310 for id in range(batch_size): 311 self.assertEqual( 312 nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.int64) 313 ) 314 315 @parametrize("batch_size", [2, 4]) 316 @parametrize("max_seq_len", [3, 5]) 317 @parametrize("vocab_size", [10, 20]) 318 def test_3d_nested_tensor(self, batch_size, max_seq_len, vocab_size): 319 data = [] 320 nested_tensor_ref_list = [] 321 for _ in range(batch_size): 322 if max_seq_len == 0: 323 length = 0 324 else: 325 length = np.random.randint(low=1, high=max_seq_len) 326 row = list(np.random.randint(low=0, high=vocab_size, size=(length,))) 327 row = [list(item * np.arange(max_seq_len)) for item in row] 328 data.append(row) 329 nested_tensor_ref_list.append(torch.Tensor(row)) 330 nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64) 331 nested_tensor_list = nested_tensor.unbind() 332 for id in range(batch_size): 333 self.assertEqual( 334 nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.int64) 335 ) 336 337 @parametrize("batch_size", [2, 4]) 338 @parametrize("max_seq_len", [3, 5]) 339 @parametrize("vocab_size", [10, 20]) 340 def test_3d_nested_tensor_float(self, batch_size, max_seq_len, vocab_size): 341 data = [] 342 nested_tensor_ref_list = [] 343 for _ in range(batch_size): 344 if max_seq_len == 0: 345 length = 0 346 else: 347 length = np.random.randint(low=1, high=max_seq_len) 348 row = list( 349 np.random.randint(low=0, high=vocab_size, size=(length,)).astype(float) 350 ) 351 row = [list(item * np.arange(max_seq_len)) for item in row] 352 data.append(row) 353 nested_tensor_ref_list.append(torch.Tensor(row)) 354 nested_tensor = torch.nested.nested_tensor(data, dtype=torch.float) 355 nested_tensor_list = nested_tensor.unbind() 356 for id in range(batch_size): 357 self.assertEqual( 358 nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.float) 359 ) 360 361 @torch.inference_mode() 362 def _test_unbind_case(self, a, b): 363 nt = torch.nested.nested_tensor([a, b]) 364 a1, b1 = nt.unbind() 365 self.assertTrue(a is not a1) 366 self.assertTrue(b is not b1) 367 368 nt = torch.nested.nested_tensor([a, b], dtype=a.dtype) 369 a1, b1 = nt.unbind(0) 370 self.assertEqual(a, a1) 371 self.assertEqual(b, b1) 372 373 a = torch.randn((2, 3)).add_(1) 374 nt = torch.nested.nested_tensor([a]) 375 self.assertEqual(a, nt.unbind(0)[0]) 376 377 @torch.inference_mode() 378 def test_unbind_0(self): 379 self._test_unbind_case(torch.tensor([1, 2]), torch.tensor([7, 8])) 380 381 @torch.inference_mode() 382 def test_unbind_1(self): 383 self._test_unbind_case(torch.tensor([1]), torch.tensor([7])) 384 385 @torch.inference_mode() 386 def test_unbind_3(self): 387 self._test_unbind_case(torch.tensor([1.0]), torch.tensor([])) 388 389 @torch.inference_mode() 390 def test_unbind_4(self): 391 self._test_unbind_case(torch.tensor([]), torch.tensor([])) 392 393 @torch.inference_mode() 394 def test_unbind_dim(self): 395 def _test_fn(unbind_fn): 396 a = torch.rand(3, 2) 397 b = torch.rand(2, 3) 398 nt = torch.nested.nested_tensor([a, b]) 399 self.assertRaises(RuntimeError, lambda: unbind_fn(nt, 1)) 400 401 # Both of these tests are necessary, because we're using 402 # torch_function. 
403 _test_fn(lambda x, dim: x.unbind(dim)) 404 # TODO: Re-enable this once using torch_dispatch 405 # _test_fn(lambda x, dim: torch.unbind(x, dim)) 406 407 @torch.inference_mode() 408 def test_nested_tensor(self): 409 self.assertRaises( 410 TypeError, lambda: torch.nested.nested_tensor(torch.tensor([3.0])) 411 ) 412 self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(4.0)) 413 414 @torch.inference_mode() 415 def test_nested_tensor_matching_dim(self): 416 self.assertRaisesRegex( 417 RuntimeError, 418 "Found dimension 1 for Tensor at index 1 and dimension 0 for Tensor at index 0.", 419 lambda: torch.nested.nested_tensor([torch.tensor(1.0), torch.tensor([])]), 420 ) 421 self.assertRaisesRegex( 422 RuntimeError, 423 "Found dimension 1 for Tensor at index 2 and dimension 0 for Tensor at index 1.", 424 lambda: torch.nested.nested_tensor( 425 [torch.tensor(1.0), torch.tensor(2.0), torch.tensor([])] 426 ), 427 ) 428 429 @torch.inference_mode() 430 def test_default_nested_tensor(self): 431 self.assertRaises(TypeError, lambda: torch.nested.nested_tensor()) 432 default_nested_tensor = torch.nested.nested_tensor([]) 433 default_tensor = torch.tensor([]) 434 # self.assertEqual(default_nested_tensor.nested_dim(), 1) 435 # self.assertEqual(default_nested_tensor.nested_size(), ()) 436 self.assertEqual(default_nested_tensor.dim(), default_tensor.dim()) 437 self.assertEqual(default_nested_tensor.layout, default_tensor.layout) 438 self.assertEqual(default_nested_tensor.device, default_tensor.device) 439 self.assertEqual(default_nested_tensor.dtype, default_tensor.dtype) 440 self.assertEqual( 441 default_nested_tensor.requires_grad, default_tensor.requires_grad 442 ) 443 self.assertIsNone(default_tensor.grad) 444 # TODO: Re-enable once we have a performance driven 445 # use case and implementation. 
446 # self.assertEqual(default_nested_tensor.is_pinned(), 447 # default_tensor.is_pinned()) 448 449 @torch.inference_mode() 450 def test_dim(self): 451 for constructor in _iter_constructors(): 452 a1 = constructor([]) 453 self.assertEqual(a1.dim(), 1) 454 a1 = constructor([torch.tensor(3.0)]) 455 self.assertEqual(a1.dim(), 1) 456 a1 = constructor([torch.tensor([1, 2, 3, 4])]) 457 self.assertEqual(a1.dim(), 2) 458 459 @unittest.skipIf(IS_FBCODE, "numel is not virtual in fbcode.") 460 @torch.inference_mode() 461 def test_numel(self): 462 for constructor in _iter_constructors(): 463 a1 = constructor([]) 464 self.assertEqual(a1.numel(), 0) 465 a1 = constructor([torch.tensor(3.0), torch.tensor(4.0)]) 466 self.assertEqual(a1.numel(), 2) 467 a1 = constructor([torch.randn(2, 2, 2)]) 468 self.assertEqual(a1.numel(), 8) 469 a1 = constructor([torch.randn([1, 2, 3]), torch.randn(3, 2, 1)]) 470 self.assertEqual(a1.numel(), 12) 471 a1 = constructor([torch.randn([1, 1, 3]), torch.randn(3, 2, 4)]) 472 self.assertEqual(a1.numel(), 27) 473 a1 = constructor([torch.randn([5, 5, 5]), torch.randn(6, 6, 6)]) 474 self.assertEqual(a1.numel(), 341) 475 476 # Interesting edge case 477 a1 = constructor([torch.randn([1, 2, 3]), torch.randn(1, 2, 0)]) 478 self.assertEqual(a1.numel(), 6) 479 480 @torch.inference_mode() 481 def test_size(self): 482 for constructor in _iter_constructors(): 483 a1 = constructor([]) 484 self.assertRaisesRegex( 485 RuntimeError, 486 "NestedTensorImpl doesn't support sizes", 487 lambda: a1.size(), 488 ) 489 490 def test_size_dim(self): 491 a = torch.nested.nested_tensor([]) 492 self.assertEqual(a.size(0), 0) 493 494 a = torch.nested.nested_tensor([torch.tensor(1)]) 495 self.assertEqual(a.size(0), 1) 496 497 a = torch.nested.nested_tensor([torch.tensor(1), torch.tensor(2)]) 498 self.assertEqual(a.size(0), 2) 499 500 a = torch.nested.nested_tensor([torch.rand(1, 2), torch.rand(1, 8)]) 501 self.assertEqual(a.size(0), 2) 502 self.assertEqual(a.size(1), 1) 503 self.assertRaisesRegex( 504 RuntimeError, 505 "Given dimension 2 is irregular and does not have a size", 506 lambda: a.size(2), 507 ) 508 509 a = torch.nested.nested_tensor([torch.rand(3, 4), torch.rand(5, 4)]) 510 self.assertEqual(a.size(0), 2) 511 self.assertRaisesRegex( 512 RuntimeError, 513 "Given dimension 1 is irregular and does not have a size", 514 lambda: a.size(1), 515 ) 516 self.assertEqual(a.size(2), 4) 517 518 @unittest.skipIf(IS_FBCODE, "stride is not virtual in fbcode.") 519 @torch.inference_mode() 520 def test_stride(self): 521 for constructor in _iter_constructors(): 522 a1 = constructor([]) 523 self.assertRaisesRegex( 524 RuntimeError, 525 "NestedTensorImpl doesn't support strides", 526 lambda: a1.stride(), 527 ) 528 529 @unittest.skipIf(IS_FBCODE, "is_contiguous is not virtual in fbcode.") 530 @torch.inference_mode() 531 def test_is_contiguous(self): 532 # Test empty case 533 nt_empty = torch.nested.nested_tensor([]) 534 assert nt_empty.is_contiguous() 535 self.assertEqual(nt_empty, nt_empty.contiguous()) 536 537 nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7)) 538 539 # Test contiguous case 540 assert nt_contiguous.is_contiguous() 541 self.assertEqual(nt_contiguous, nt_contiguous.contiguous()) 542 543 # Test non_contiguous case 544 assert not nt_noncontiguous.is_contiguous() 545 self.assertEqual(nt_contiguous, nt_noncontiguous.contiguous()) 546 547 # Test querying by memory_format 548 self.assertTrue( 549 nt_contiguous.is_contiguous(memory_format=torch.contiguous_format) 550 ) 551 
self.assertTrue( 552 not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format) 553 ) 554 555 @torch.inference_mode() 556 def test_repr_string(self): 557 a = torch.nested.nested_tensor([]) 558 expected = "nested_tensor([\n\n])" 559 self.assertEqual(str(a), expected) 560 self.assertEqual(repr(a), expected) 561 562 a = torch.nested.nested_tensor([torch.tensor(1.0)]) 563 expected = "nested_tensor([\n tensor(1.)\n])" 564 self.assertEqual(str(a), expected) 565 self.assertEqual(repr(a), expected) 566 567 a = torch.nested.nested_tensor([torch.tensor([[1, 2]]), torch.tensor([[4, 5]])]) 568 expected = "nested_tensor([\n tensor([[1, 2]]),\n tensor([[4, 5]])\n])" 569 self.assertEqual(str(a), expected) 570 self.assertEqual(repr(a), expected) 571 572 def test_to_padded_tensor_on_empty_tensor(self): 573 nt = torch.nested.nested_tensor([]) 574 empty = torch.nested.to_padded_tensor(nt, 4) 575 self.assertEqual(empty, torch.tensor([])) 576 577 def test_nested_namespace(self): 578 nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 5)]) 579 result = nt.to_padded_tensor(4) 580 nested_namespace_result = torch.nested.to_padded_tensor(nt, 4) 581 self.assertEqual(result, nested_namespace_result) 582 583 def test_to(self): 584 ntensors = 4 585 nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) 586 587 def test_copy_behavior(t, non_blocking=False): 588 self.assertIs(t, t.to(t, non_blocking=non_blocking)) 589 self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking)) 590 self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking)) 591 self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=True)) 592 self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=True)) 593 self.assertIsNot( 594 t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True) 595 ) 596 597 devices = [t.device] 598 if t.device.type == "cuda": 599 if t.device.index == -1: 600 devices.append(f"cuda:{torch.cuda.current_device()}") 601 elif t.device.index == torch.cuda.current_device(): 602 devices.append("cuda") 603 for device in devices: 604 self.assertIs(t, t.to(device, non_blocking=non_blocking)) 605 self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking)) 606 self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=True)) 607 self.assertIsNot( 608 t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True) 609 ) 610 611 test_copy_behavior(nt) 612 self.assertEqual(nt.device, nt.to("cpu").device) 613 self.assertEqual(nt.device, nt.to("cpu", dtype=torch.float32).device) 614 self.assertIs(torch.float32, nt.to("cpu", dtype=torch.float32).dtype) 615 self.assertEqual(nt.device, nt.to(torch.float32).device) 616 self.assertIs(torch.float32, nt.to(dtype=torch.float32).dtype) 617 618 def test_data_ptr(getter): 619 self.assertEqual(getter(nt), getter(nt.to("cpu"))) 620 self.assertEqual( 621 getter(nt), getter(nt.to(dtype=nt.dtype, device=nt.device, copy=False)) 622 ) 623 self.assertEqual(getter(nt), getter(nt.to("cpu", copy=False))) 624 self.assertNotEqual(getter(nt), getter(nt.to("cpu", copy=True))) 625 626 test_data_ptr(lambda nt: nt.data_ptr()) 627 628 if torch.cuda.is_available(): 629 for non_blocking in [True, False]: 630 for cuda in [ 631 "cuda", 632 "cuda:0" if torch.cuda.device_count() == 1 else "cuda:1", 633 ]: 634 nt2 = random_nt(cuda, torch.float32, ntensors, (4, 4)) 635 test_copy_behavior(nt2, non_blocking) 636 self.assertEqual( 637 nt2.device, nt2.to(cuda, non_blocking=non_blocking).device 638 ) 639 self.assertEqual( 640 nt.device, nt2.to("cpu", 
non_blocking=non_blocking).device 641 ) 642 self.assertEqual( 643 nt2.device, nt.to(cuda, non_blocking=non_blocking).device 644 ) 645 self.assertIs( 646 torch.int32, 647 nt2.to( 648 "cpu", dtype=torch.int32, non_blocking=non_blocking 649 ).dtype, 650 ) 651 self.assertEqual( 652 nt.device, 653 nt2.to( 654 "cpu", dtype=torch.int32, non_blocking=non_blocking 655 ).device, 656 ) 657 self.assertIs(torch.int32, nt2.to(dtype=torch.int32).dtype) 658 self.assertEqual(nt2.device, nt2.to(dtype=torch.int32).device) 659 660 def test_copy_(self): 661 ntensors = 4 662 nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) 663 nt_copy = torch.empty_like(nt) 664 nt_copy.copy_(nt) 665 666 for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy): 667 self.assertEqual(nt_ub, nt_copy_ub) 668 669 nt_error = torch.nested.nested_tensor([torch.tensor([0, 0])]) 670 self.assertRaisesRegex( 671 RuntimeError, 672 "copy_ only supports tensors that are the same size for Nested implementations", 673 lambda: nt_error.copy_(nt), 674 ) 675 676 if torch.cuda.is_available(): 677 nt = random_nt(torch.device("cuda"), torch.float32, ntensors, (4, 4)) 678 nt_copy = torch.empty_like(nt, device=torch.device("cpu")) 679 nt_copy.copy_(nt, non_blocking=True) 680 torch.cuda.current_stream(torch.cuda.current_device()).synchronize() 681 for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy): 682 self.assertEqual(nt_ub, nt_copy_ub) 683 684 nt_copy = torch.empty_like(nt, device=torch.device("cpu")) 685 nt_copy.copy_(nt, non_blocking=False) 686 for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy): 687 self.assertEqual(nt_ub, nt_copy_ub) 688 689 def test_fill_(self): 690 ntensors = 4 691 nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) 692 nt.fill_(10.0) 693 for nt_ub in nt.unbind(): 694 t = torch.empty_like(nt_ub) 695 t.fill_(10.0) 696 self.assertEqual(nt_ub, t) 697 698 fill_tensor = torch.tensor([11.0]) 699 self.assertRaisesRegex( 700 RuntimeError, 701 "fill_ only supports 0-dimension value tensor", 702 lambda: nt.fill_(fill_tensor), 703 ) 704 705 nt.fill_(fill_tensor[0]) 706 for nt_ub in nt.unbind(): 707 t = torch.empty_like(nt_ub) 708 t.fill_(11.0) 709 self.assertEqual(nt_ub, t) 710 711 def test_zero_(self): 712 ntensors = 4 713 nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) 714 nt.zero_() 715 for nt_ub in nt.unbind(): 716 t = torch.empty_like(nt_ub) 717 t.fill_(0.0) 718 self.assertEqual(nt_ub, t) 719 720 @parametrize( 721 "func", 722 [torch.ones_like, torch.zeros_like, torch.randn_like], 723 name_fn=lambda f: f.__name__, 724 ) 725 def test_like_functions(self, func): 726 ntensors = 4 727 nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) 728 torch.manual_seed(1) 729 nt_like = func(nt) 730 731 torch.manual_seed(1) 732 for nt_ub in nt_like.unbind(): 733 t_like = func(nt_ub) 734 self.assertEqual(nt_ub, t_like) 735 736 def test_cat(self): 737 # dim=0 success case 738 # No constraints on ragged structures matching. 
739 x = random_nt_from_dims([5, None, 10]) 740 y = random_nt_from_dims([3, 4, None]) 741 output = torch.cat([x, y], dim=0) 742 for out_component, xy_component in zip( 743 output.unbind(), itertools.chain(x.unbind(), y.unbind()) 744 ): 745 self.assertEqual(out_component, xy_component) 746 747 # dim=-1 success case 748 # shape (B, *, D) 749 x = random_nt_from_dims([5, None, 10]) 750 # shape (B, *, D'); same structure as x but dim=-1 differs 751 y = random_nt_from_similar(x, dims=[-1, -1, 8]) 752 # should be shape (B, *, D + D') when supported 753 output = torch.cat([x, y], dim=-1) 754 for out_component, x_component, y_component in zip( 755 output.unbind(), x.unbind(), y.unbind() 756 ): 757 self.assertEqual( 758 out_component, torch.cat([x_component, y_component], dim=-1) 759 ) 760 761 # dim between 0 and -1 success case 762 x = random_nt_from_dims([5, None, 2, 3]) 763 # same structure as x but dim=2 differs 764 y = random_nt_from_similar(x, dims=[-1, -1, 4, -1]) 765 output = torch.cat([x, y], dim=2) 766 for out_component, x_component, y_component in zip( 767 output.unbind(), x.unbind(), y.unbind() 768 ): 769 self.assertEqual( 770 out_component, torch.cat([x_component, y_component], dim=1) 771 ) 772 773 # error case: mixed NT / dense inputs 774 x = random_nt_from_dims([5, None, 2]) 775 y = torch.randn(5, 3, 2) 776 with self.assertRaisesRegex( 777 RuntimeError, "expected each tensor in given list to be nested" 778 ): 779 torch.cat([x, y], dim=-1) 780 781 # error case: NTs with different dims 782 x = random_nt_from_dims([5, None, 2]) 783 y = random_nt_from_dims([5, None, 2, 3]) 784 with self.assertRaisesRegex( 785 RuntimeError, 786 "expected all nested tensors to have matching ragged structures outside of the concatenated dim", 787 ): 788 torch.cat([x, y], dim=-1) 789 790 # error case: non-contiguous NT 791 x, y = random_nt_noncontiguous_pair((2, 3, 4), dtype=torch.float32) 792 # transpose to put ragged dim next to batch dim 793 x, y = x.transpose(-2, -1), y.transpose(-2, -1) 794 with self.assertRaisesRegex( 795 RuntimeError, "only contiguous nested tensors are supported" 796 ): 797 torch.cat([x, y], dim=-1) 798 799 # error case: multiple ragged dims in inputs 800 x = random_nt_from_dims([5, None, None, 2]) 801 y = random_nt_from_similar(x) 802 with self.assertRaisesRegex( 803 RuntimeError, 804 "only nested tensors with a single ragged dim next to the batch dim are supported", 805 ): 806 torch.cat([x, y], dim=-1) 807 808 # error case: ragged dim not next to batch dim 809 x = random_nt_from_dims([5, 2, None]) 810 y = random_nt_from_similar(x) 811 with self.assertRaisesRegex( 812 RuntimeError, 813 "only nested tensors with a single ragged dim next to the batch dim are supported", 814 ): 815 torch.cat([x, y], dim=1) 816 817 # error case: NTs with different batch sizes 818 x = random_nt_from_dims([5, None, 2]) 819 y = random_nt_from_dims([3, None, 2]) 820 with self.assertRaisesRegex( 821 RuntimeError, 822 "expected all nested tensors to have matching ragged structures outside of the concatenated dim", 823 ): 824 torch.cat([x, y], dim=-1) 825 826 # error case: NTs with different ragged structures 827 x = torch.nested.nested_tensor( 828 [ 829 torch.randn(2, 6), 830 torch.randn(4, 6), 831 torch.randn(5, 6), 832 ] 833 ) 834 y = torch.nested.nested_tensor( 835 [ 836 torch.randn(5, 6), 837 torch.randn(4, 6), 838 torch.randn(2, 6), 839 ] 840 ) 841 with self.assertRaisesRegex( 842 RuntimeError, 843 "expected all nested tensors to have matching ragged structures outside of the concatenated dim", 844 ): 845 
torch.cat([x, y], dim=-1) 846 847 848@markDynamoStrictTest 849class TestNestedTensorDeviceType(NestedTensorTestCase): 850 # Helper function to generate a pair of random nested tensors 851 # the 2 nested tensors have same shapes 852 def random_nt_pair(self, device, dtype, num_tensors, max_dims): 853 ts1 = [] 854 ts2 = [] 855 for _ in range(num_tensors): 856 tensor_dims = tuple( 857 [ 858 torch.randint(low=0, high=max_dim, size=(1,)).item() 859 for max_dim in max_dims 860 ] 861 ) 862 t1 = torch.randn(tensor_dims, device=device, dtype=dtype) 863 t2 = torch.randn(tensor_dims, device=device, dtype=dtype) 864 ts1.append(t1) 865 ts2.append(t2) 866 return ( 867 torch.nested.nested_tensor(ts1, device=device, dtype=dtype), 868 torch.nested.nested_tensor(ts2, device=device, dtype=dtype), 869 ) 870 871 @dtypes(*floating_types_and_half()) 872 def test_detach(self, device, dtype): 873 a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=False) 874 b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=False) 875 x = torch.nested.nested_tensor([a, b], requires_grad=True) 876 877 x_detach = x.detach() 878 879 z = x_detach * 4 880 self.assertFalse(x_detach.requires_grad) 881 self.assertFalse(z.requires_grad) 882 883 a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=True) 884 b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=True) 885 x = torch.nested.as_nested_tensor([a, b]) 886 887 y = x * 2 888 y = y.detach() 889 self.assertFalse(y.requires_grad) 890 self.assertIsNone(y.grad_fn) 891 892 z = x + y 893 torch.nested.to_padded_tensor(z, 0).sum().backward() 894 # This is an incorrect gradient, but we assume that's what the user 895 # wanted. detach() is an advanced option. 896 self.assertEqual(a.grad, torch.ones(2, 4, device=device, dtype=dtype)) 897 self.assertEqual(b.grad, torch.ones(5, 4, device=device, dtype=dtype)) 898 899 @dtypes(torch.float, torch.float16, torch.double) 900 def test_unbind_noncontiguous(self, device, dtype): 901 nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( 902 (2, 3, 6, 7), device, dtype 903 ) 904 ub_contiguous = nt_contiguous.unbind() 905 ub_noncontiguous = nt_noncontiguous.unbind() 906 self.assertEqual(len(ub_contiguous), len(ub_noncontiguous)) 907 n = len(ub_contiguous) 908 for i in range(n): 909 self.assertEqual(ub_contiguous[i], ub_noncontiguous[i]) 910 911 @dtypes(torch.float) 912 @skipMeta 913 def test_to_then_from_padded_tensor_no_transform0213(self, device, dtype): 914 t = torch.randn(4, 4, 4, device=device, dtype=dtype) 915 ts = list(torch.unbind(t)) 916 ts[0] = ts[0][:-1] 917 nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) 918 padded = torch.nested.to_padded_tensor(nt, 0) 919 920 nt_to = torch._nested_from_padded_and_nested_example(padded, nt) 921 922 for t1, t2 in zip(nt.unbind(), nt_to.unbind()): 923 self.assertEqual(t1, t2) 924 self.assertEqual(nt.device, nt_to.device) 925 926 @dtypes(torch.float) 927 @dtypesIfCUDA(torch.float, torch.half) 928 @skipMeta 929 @torch.inference_mode() 930 def test_layer_norm(self, device, dtype): 931 def _test(size): 932 # Simple shapes test 933 t0 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False) 934 t1 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False) 935 ts = [t0, t1, t0, t1] 936 nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) 937 layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype) 938 nt_result = layer_norm(nt) 939 for nt_subresult, t in zip(nt_result.unbind(), ts): 940 t_result = 
layer_norm(t.reshape(1, -1, size).squeeze(0)) 941 self.assertEqual(nt_subresult, t_result) 942 943 # More complex nt test with different lengths for each tensor 944 t0 = torch.randn(4, size, device=device, dtype=dtype, requires_grad=False) 945 t1 = torch.randn(10, size, device=device, dtype=dtype, requires_grad=False) 946 t2 = torch.randn(7, size, device=device, dtype=dtype, requires_grad=False) 947 ts = [t0, t1, t2, t0, t2] 948 nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) 949 layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype) 950 nt_result = layer_norm(nt) 951 for nt_subresult, t in zip(nt_result.unbind(), ts): 952 t_result = layer_norm(t.reshape(1, -1, size).squeeze(0)) 953 self.assertEqual(nt_subresult, t_result) 954 955 if size <= 128: 956 # Test with multidimensional tensors after irregular dim 957 # (run only with smaller dimensions to ensure fast execution) 958 t0 = torch.randn( 959 4, size, size, 4, device=device, dtype=dtype, requires_grad=False 960 ) 961 t1 = torch.randn( 962 10, size, size, 4, device=device, dtype=dtype, requires_grad=False 963 ) 964 t2 = torch.randn( 965 7, size, size, 4, device=device, dtype=dtype, requires_grad=False 966 ) 967 ts = [t0, t1, t2, t0, t2] 968 nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) 969 layer_norm = torch.nn.LayerNorm( 970 (size, size, 4), device=device, dtype=dtype 971 ) 972 nt_result = layer_norm(nt) 973 for nt_subresult, t in zip(nt_result.unbind(), ts): 974 t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0)) 975 self.assertEqual(nt_subresult, t_result) 976 977 # Test where the normalizing dimensions are not all 978 layer_norm = torch.nn.LayerNorm((size, 4), device=device, dtype=dtype) 979 nt_result = layer_norm(nt) 980 for nt_subresult, t in zip(nt_result.unbind(), ts): 981 t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0)) 982 self.assertEqual(nt_subresult, t_result) 983 984 for size in (1024, 1023, 513, 512, 256, 128, 2, 4, 32): 985 _test(size) 986 987 @dtypes(torch.float) 988 @dtypesIfCUDA(torch.float, torch.half) 989 @skipMeta 990 @torch.inference_mode() 991 def test_layer_norm_breaking(self, device, dtype): 992 size = 128 993 t0 = torch.randn( 994 4, size, size, 4, device=device, dtype=dtype, requires_grad=False 995 ) 996 t1 = torch.randn( 997 10, size, size, 4, device=device, dtype=dtype, requires_grad=False 998 ) 999 t2 = torch.randn( 1000 7, size, size, 4, device=device, dtype=dtype, requires_grad=False 1001 ) 1002 ts = [t0, t1, t2, t0, t2] 1003 nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) 1004 layer_norm = torch.nn.LayerNorm((4, size, size, 4), device=device, dtype=dtype) 1005 self.assertRaisesRegex( 1006 RuntimeError, 1007 "normalized_shape extends into irregular dimensions for the nested tensor", 1008 lambda: layer_norm(nt), 1009 ) 1010 layer_norm = torch.nn.LayerNorm((size + 1, size, 4), device=device, dtype=dtype) 1011 self.assertRaisesRegex( 1012 RuntimeError, 1013 "The shape at dimension 0", 1014 lambda: layer_norm(nt), 1015 ) 1016 1017 @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name) 1018 def test_embedding(self, device, layout): 1019 inputs = [ 1020 torch.randint(100, (L,), device=device, dtype=torch.int64) 1021 for L in torch.randint(5, 50, (8,)) 1022 ] 1023 x = torch.nested.nested_tensor( 1024 inputs, device=device, dtype=torch.int64, layout=layout 1025 ) 1026 emb = torch.nn.Embedding(100, 8, device=device) 1027 y = emb(x) 1028 1029 @torch._dynamo.disable 1030 def check(inputs, y): 1031 ys = 
y.unbind() 1032 for i, inp in enumerate(inputs): 1033 self.assertEqual(emb(inp), ys[i]) 1034 1035 check(inputs, y) 1036 1037 @skipMeta 1038 @torch.inference_mode() 1039 @dtypes(*floating_types_and_half()) 1040 def test_masked_fill(self, device, dtype): 1041 # nested tensor * nested tensor 1042 (nt, mask) = self.random_nt_pair(device, dtype, 4, (4, 4)) 1043 mask = torch.nested.nested_tensor([m < 0 for m in mask.unbind()]) 1044 ref = torch.nested.nested_tensor( 1045 [t.masked_fill(m, 0) for (t, m) in zip(nt.unbind(), mask.unbind())] 1046 ) 1047 out = nt.masked_fill(mask, 0) 1048 self.assertEqual(ref, out) 1049 1050 @dtypes(torch.float, torch.float16) 1051 def test_to_padded_tensor_simple(self, device, dtype): 1052 t = torch.randn(4, 4, 4, device=device, dtype=dtype) 1053 ts = list(torch.unbind(t)) 1054 ts[0] = ts[0][:-1] 1055 nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) 1056 for padding_value in (0, 1): 1057 padded = torch.nested.to_padded_tensor(nt, padding_value) 1058 1059 correct_output = t.clone() 1060 if padding_value == 0: 1061 correct_output[0][-1] = torch.zeros_like(correct_output[0][-1]) 1062 else: 1063 correct_output[0][-1] = torch.ones_like(correct_output[0][-1]) 1064 1065 self.assertEqual(padded, correct_output) 1066 self.assertEqual(padded.device, torch.device(device)) 1067 self.assertEqual(padded.dtype, dtype) 1068 1069 @dtypes(torch.float, torch.float16) 1070 def test_to_padded_tensor_output_size(self, device, dtype): 1071 t = torch.randn(4, 4, 4, device=device, dtype=dtype) 1072 output_size = (4, 6, 5) 1073 ts = list(torch.unbind(t)) 1074 ts[0] = ts[0][:-1] 1075 nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) 1076 for padding_value in (0, 1): 1077 padded = torch.nested.to_padded_tensor( 1078 nt, padding_value, output_size=output_size 1079 ) 1080 correct_output = ( 1081 torch.ones(output_size, device=device, dtype=dtype) * padding_value 1082 ) 1083 correct_output[:4:, :4, :4] = t.clone() 1084 if padding_value == 0: 1085 correct_output[0][3] = torch.zeros_like(correct_output[0][3]) 1086 else: 1087 correct_output[0][3] = torch.ones_like(correct_output[0][3]) 1088 1089 self.assertEqual(padded, correct_output) 1090 self.assertEqual(padded.device, torch.device(device)) 1091 self.assertEqual(padded.dtype, dtype) 1092 1093 @dtypes(torch.float, torch.float16, torch.double) 1094 def test_to_padded_tensor_dim2(self, device, dtype): 1095 ts = [ 1096 torch.randn(160, device=device, dtype=dtype), 1097 torch.randn(1240, device=device, dtype=dtype), 1098 torch.randn(2400, device=device, dtype=dtype), 1099 ] 1100 nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) 1101 pad = 42 1102 correct_output = [] 1103 for t in ts: 1104 next_output = torch.ones_like(ts[2]) * pad 1105 correct_output.append(next_output) 1106 next_output[: t.size(0)].copy_(t) 1107 correct_output = torch.stack(correct_output) 1108 padded = torch.nested.to_padded_tensor(nt, pad) 1109 self.assertEqual(padded, correct_output) 1110 1111 @dtypes(torch.float, torch.float16, torch.double) 1112 def test_to_padded_tensor_dim3(self, device, dtype): 1113 ts = [ 1114 torch.randn(16, 21, device=device, dtype=dtype), 1115 torch.randn(24, 32, device=device, dtype=dtype), 1116 torch.randn(40, 53, device=device, dtype=dtype), 1117 ] 1118 nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) 1119 pad = 42 1120 correct_output = [] 1121 for t in ts: 1122 next_output = torch.ones_like(ts[2]) * pad 1123 correct_output.append(next_output) 1124 next_output[: t.size(0), : t.size(1)].copy_(t) 
1125 correct_output = torch.stack(correct_output) 1126 padded = torch.nested.to_padded_tensor(nt, pad) 1127 self.assertEqual(padded, correct_output) 1128 1129 @dtypes(torch.float, torch.float16, torch.double) 1130 def test_to_padded_tensor_dim4(self, device, dtype): 1131 ts = [ 1132 torch.randn(16, 21, 13, device=device, dtype=dtype), 1133 torch.randn(24, 32, 14, device=device, dtype=dtype), 1134 torch.randn(40, 53, 16, device=device, dtype=dtype), 1135 ] 1136 nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) 1137 pad = 42 1138 correct_output = [] 1139 for t in ts: 1140 next_output = torch.ones_like(ts[2]) * pad 1141 correct_output.append(next_output) 1142 next_output[: t.size(0), : t.size(1), : t.size(2)].copy_(t) 1143 correct_output = torch.stack(correct_output) 1144 padded = torch.nested.to_padded_tensor(nt, pad) 1145 self.assertEqual(padded, correct_output) 1146 1147 # TODO: test noncontiguous to_padded_tensor 1148 # For now this tests the functionality of noncontiguous_to_padded_tensor 1149 # and the error message of to_padded_tensor 1150 # since to_padded_tensor does not support noncontiguous buffer yet 1151 @dtypes(torch.float, torch.float16, torch.double) 1152 @torch.inference_mode() 1153 def test_to_padded_tensor_noncontiguous(self, device, dtype): 1154 nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( 1155 (2, 3, 6, 7), device, dtype 1156 ) 1157 # test noncontiguous_to_padded_tensor functionality 1158 self.assertEqual( 1159 torch.nested.to_padded_tensor(nt_contiguous, 0.0), 1160 noncontiguous_to_padded_tensor(nt_noncontiguous), 1161 ) 1162 # test to_padded_tensor error message 1163 self.assertRaisesRegex( 1164 RuntimeError, 1165 r"for now to_padded_tensor only supports contiguous nested tensor", 1166 lambda: torch.nested.to_padded_tensor(nt_noncontiguous, 0.0), 1167 ) 1168 1169 @skipMeta 1170 def test_device_checks(self, device): 1171 nt = torch.nested.nested_tensor([], device=device) 1172 is_cuda = "cuda" in str(device) 1173 self.assertEqual(nt.is_cuda, is_cuda) 1174 1175 @dtypes(torch.float, torch.float16, torch.double) 1176 def test_nested_tensor_indexing(self, device, dtype): 1177 # edge case: empty nested tensor 1178 nt0 = torch.nested.nested_tensor([]) 1179 self.assertRaises(IndexError, lambda: nt0[0]) 1180 # normal case 1181 x0 = torch.randn((2, 5), device=device, dtype=dtype) 1182 x1 = torch.randn((3, 4), device=device, dtype=dtype) 1183 nt = torch.nested.nested_tensor([x0, x1]) 1184 # single index: only support integer in the batch dimension 1185 self.assertEqual(nt[0], x0) 1186 self.assertEqual(nt[-1], x1) 1187 self.assertRaises(IndexError, lambda: nt[2]) 1188 self.assertRaises(IndexError, lambda: nt[-3]) 1189 self.assertRaises(NotImplementedError, lambda: nt[:]) 1190 self.assertEqual(nt[...], nt) 1191 # tuple of indices: only support integer in the batch dimension 1192 # + all possible indexing in the original tensor dimensions 1193 self.assertEqual(nt[0, 0, 0], x0[0, 0]) 1194 self.assertEqual(nt[0, 1, :], x0[1, :]) 1195 self.assertEqual(nt[1, ...], x1) 1196 self.assertRaises(IndexError, lambda: nt[1, 4, 2]) 1197 self.assertRaises(NotImplementedError, lambda: nt[:, 1, 1]) 1198 # test select on non-batch dimensions 1199 self.assertEqual(nt.select(1, 0)[0], x0.select(0, 0)) 1200 self.assertEqual(nt.select(1, 0)[1], x1.select(0, 0)) 1201 self.assertRaises(IndexError, lambda: nt.select(1, 3)) 1202 self.assertEqual(nt.select(2, 0)[0], x0.select(1, 0)) 1203 self.assertEqual(nt.select(2, 0)[1], x1.select(1, 0)) 1204 self.assertRaises(IndexError, 
lambda: nt.select(2, 5)) 1205 # make sure indexing returns a view 1206 nt[0].fill_(100.0) 1207 answer = torch.tensor(100.0, device=device, dtype=dtype).expand((2, 5)) 1208 self.assertEqual(nt[0], answer) 1209 nt[1, 1, :].fill_(200.0) 1210 answer = torch.tensor(200.0, device=device, dtype=dtype).expand(4) 1211 self.assertEqual(nt[1, 1, :], answer) 1212 1213 # Test that indexing works when requires_grad_(True) 1214 # previously this was failing because the backward kernel for select.int uses .sizes() 1215 nt = torch.nested.nested_tensor([x0, x1]).requires_grad_(True) 1216 self.assertEqual(nt[0], x0) 1217 self.assertEqual(nt[-1], x1) 1218 grad_x0 = torch.randn((2, 5), device=device, dtype=dtype) 1219 nt[0].backward(grad_x0) 1220 expected_grad = torch.nested.nested_tensor( 1221 [grad_x0, torch.zeros((3, 4), device=device, dtype=dtype)] 1222 ) 1223 self.assertEqual(nt.grad, expected_grad) 1224 1225 @parametrize( 1226 "func", 1227 [ 1228 subtest(torch.nn.functional.relu, name="relu"), 1229 subtest(torch.nn.functional.relu_, name="relu_"), 1230 subtest(torch.nn.functional.gelu, name="gelu"), 1231 subtest(torch._C._nn.gelu_, name="gelu_"), 1232 subtest(torch.tanh, name="tanh"), 1233 subtest(torch.tanh_, name="tanh_"), 1234 subtest(torch.neg, name="neg"), 1235 subtest(torch.nn.functional.silu, name="silu"), 1236 subtest(partial(torch.nn.functional.silu, inplace=True), name="silu_"), 1237 subtest(torch.abs, name="abs"), 1238 subtest(torch.abs_, name="abs_"), 1239 subtest(torch.sgn, name="sgn"), 1240 subtest(torch.logical_not, name="logical_not"), 1241 subtest(torch.sin, name="sin"), 1242 subtest(torch.cos, name="cos"), 1243 ], 1244 ) 1245 def test_activations(self, device, func): 1246 nt, nt_noncontiguous = random_nt_noncontiguous_pair( 1247 (2, 3, 6, 7), device=device, dtype=torch.float32 1248 ) 1249 nested_result = func(nt) 1250 self.assertTrue(nested_result.is_nested) 1251 for t, t_res in zip(nt.unbind(), nested_result.unbind()): 1252 self.assertEqual(func(t), t_res) 1253 self.assertRaisesRegex( 1254 RuntimeError, 1255 "NestedTensor must be contiguous to get buffer.", 1256 lambda: func(nt_noncontiguous), 1257 ) 1258 1259 @parametrize("func", [subtest(torch.ge, name="ge"), subtest(torch.eq, name="eq")]) 1260 def test_binary_ops_with_scalar(self, device, func): 1261 nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( 1262 (2, 3, 6, 7), device=device, dtype=torch.float32 1263 ) 1264 scalar = 0.0 1265 1266 # should work regardless of contiguity 1267 for nt in (nt_contiguous, nt_noncontiguous): 1268 nested_result = func(nt, scalar) 1269 self.assertTrue(nested_result.is_nested) 1270 for t, t_res in zip(nt.unbind(), nested_result.unbind()): 1271 self.assertEqual(func(t, scalar), t_res) 1272 1273 @dtypes(*floating_types_and_half()) 1274 def test_nested_tensor_chunk(self, device, dtype): 1275 # Transformer use case 1276 a = torch.randn(3, 3 * 4, device=device, dtype=dtype) 1277 b = torch.randn(2, 3 * 4, device=device, dtype=dtype) 1278 c = torch.randn(1, 3 * 4, device=device, dtype=dtype) 1279 a_chunks = a.chunk(3, dim=-1) 1280 b_chunks = b.chunk(3, dim=-1) 1281 c_chunks = c.chunk(3, dim=-1) 1282 1283 a_nt = [a_chunks[0], b_chunks[0], c_chunks[0]] 1284 b_nt = [a_chunks[1], b_chunks[1], c_chunks[1]] 1285 c_nt = [a_chunks[2], b_chunks[2], c_chunks[2]] 1286 1287 nt = torch.nested.nested_tensor([a, b, c]) 1288 chunked = nt.chunk(3, dim=-1) 1289 1290 self.assertEqual(chunked[0], torch.nested.nested_tensor(a_nt)) 1291 self.assertEqual(chunked[1], torch.nested.nested_tensor(b_nt)) 1292 
self.assertEqual(chunked[2], torch.nested.nested_tensor(c_nt)) 1293 1294 for chunk in chunked: 1295 self.assertFalse(chunk.is_contiguous()) 1296 1297 # Failure chunking on ragged dimensions 1298 self.assertRaisesRegex( 1299 RuntimeError, 1300 "Chunk for nested tensors is currently only supported for the last dimension.", 1301 lambda: torch.chunk(nt, 5, dim=1), 1302 ) 1303 self.assertRaisesRegex( 1304 RuntimeError, 1305 "Chunk for nested tensors is currently only supported for the last dimension.", 1306 lambda: torch.chunk(nt, 5, dim=0), 1307 ) 1308 1309 # Failure on non-contiguous nt 1310 _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype) 1311 self.assertRaisesRegex( 1312 RuntimeError, 1313 "chunk expects `self` to be contiguous.", 1314 lambda: torch.chunk(nt_noncontiguous, 5, dim=-1), 1315 ) 1316 1317 # Failure when calling non divisible n_chunks 1318 self.assertRaisesRegex( 1319 RuntimeError, 1320 "Chunk for nested tensors is only supported for " 1321 "nested tensors with trailing dimension divisible by chunks.", 1322 lambda: torch.chunk(nt, 5, dim=-1), 1323 ) 1324 1325 # Failure when calling backward on a chunk 1326 a = torch.randn(3, 3 * 4, device=device, dtype=dtype, requires_grad=True) 1327 b = torch.randn(2, 3 * 4, device=device, dtype=dtype, requires_grad=True) 1328 nt_grad = torch.nested.as_nested_tensor([a, b]) 1329 chunked = torch.chunk(nt_grad, 2, dim=-1) 1330 self.assertRaisesRegex( 1331 RuntimeError, 1332 "Nested Strided Tensor doesn't support chunk backward.", 1333 lambda: chunked[0].backward(chunked[0].clone()), 1334 ) 1335 1336 @dtypes(*floating_types_and_half()) 1337 def test_nested_tensor_split_with_sizes(self, device, dtype): 1338 a = torch.randn(3, 20, device=device, dtype=dtype) 1339 b = torch.randn(2, 20, device=device, dtype=dtype) 1340 c = torch.randn(1, 20, device=device, dtype=dtype) 1341 1342 split_sizes = [4, 6, 10] 1343 a_splits = a.split_with_sizes(split_sizes, dim=-1) 1344 b_splits = b.split_with_sizes(split_sizes, dim=-1) 1345 c_splits = c.split_with_sizes(split_sizes, dim=-1) 1346 1347 nt = torch.nested.nested_tensor([a, b, c]) 1348 nt_splits = nt.split_with_sizes(split_sizes, dim=-1) 1349 1350 for i, nt_split in enumerate(nt_splits): 1351 self.assertEqual( 1352 nt_split, 1353 torch.nested.nested_tensor([a_splits[i], b_splits[i], c_splits[i]]), 1354 ) 1355 dense_strides = torch.stack( 1356 [ 1357 torch.tensor(a_splits[i].stride()), 1358 torch.tensor(b_splits[i].stride()), 1359 torch.tensor(c_splits[i].stride()), 1360 ] 1361 ) 1362 self.assertEqual(nt_split._nested_tensor_strides(), dense_strides) 1363 self.assertFalse(nt_split.is_contiguous()) 1364 1365 # Failure calling on ragged dimensions 1366 self.assertRaisesRegex( 1367 RuntimeError, 1368 "split_with_sizes for nested tensors is currently only supported for the last dimension.", 1369 lambda: torch.split_with_sizes(nt, split_sizes, dim=1), 1370 ) 1371 1372 # Failure calling on non-last dimension 1373 self.assertRaisesRegex( 1374 RuntimeError, 1375 "split_with_sizes for nested tensors is currently only supported for the last dimension.", 1376 lambda: torch.split_with_sizes(nt, split_sizes, dim=0), 1377 ) 1378 1379 # Failure on non-contiguous nt 1380 _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype) 1381 self.assertRaisesRegex( 1382 RuntimeError, 1383 "split_with_sizes expects `self` to be contiguous.", 1384 lambda: torch.split_with_sizes(nt_noncontiguous, split_sizes, dim=-1), 1385 ) 1386 1387 # Failure when calling with split_sizes that don't cover the full 
dim size 1388 bad_split_sizes = [4, 6, 9] # don't add up to 20 1389 self.assertRaisesRegex( 1390 RuntimeError, 1391 "split_with_sizes expects split_sizes to sum exactly to 20", 1392 lambda: torch.split_with_sizes(nt, bad_split_sizes, dim=-1), 1393 ) 1394 1395 @dtypes(torch.float, torch.float16, torch.double) 1396 @torch.inference_mode() 1397 def test_nested_tensor_indexing_noncontiguous(self, device, dtype): 1398 nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( 1399 (2, 3, 6, 7), device, dtype 1400 ) 1401 self.assertEqual(nt_contiguous.size(0), nt_noncontiguous.size(0)) 1402 n = nt_contiguous.size(0) 1403 for i in range(n): 1404 self.assertEqual(nt_contiguous[i], nt_noncontiguous[i]) 1405 1406 @dtypes(torch.float, torch.float16) 1407 @skipMeta 1408 @torch.inference_mode() 1409 @parametrize("transpose", [True, False]) 1410 def test_nested_tensor_add(self, device, dtype, transpose): 1411 if transpose: 1412 a = torch.randn(2, 2, 2, device=device, dtype=dtype) 1413 b = torch.rand(2, 2, 2, device=device, dtype=dtype) 1414 c = a.transpose(-1, -2).contiguous() 1415 d = b.transpose(-1, -2).contiguous() 1416 nt1 = torch.nested.nested_tensor([a, b, a, b]) 1417 nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2) 1418 else: 1419 (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) 1420 ref = torch.nested.nested_tensor( 1421 [t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] 1422 ) 1423 out = nt1 + nt2 1424 self.assertEqual(ref, out) 1425 1426 @dtypes(torch.float, torch.float16) 1427 @skipMeta 1428 @torch.inference_mode() 1429 @parametrize("transpose", [True, False]) 1430 def test_nested_tensor_sub(self, device, dtype, transpose): 1431 if transpose: 1432 a = torch.randn(2, 2, 2, device=device, dtype=dtype) 1433 b = torch.rand(2, 2, 2, device=device, dtype=dtype) 1434 c = a.transpose(-1, -2).contiguous() 1435 d = b.transpose(-1, -2).contiguous() 1436 nt1 = torch.nested.nested_tensor([a, b, a, b]) 1437 nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2) 1438 else: 1439 (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) 1440 ref = torch.nested.nested_tensor( 1441 [t1 - t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] 1442 ) 1443 out = nt1 - nt2 1444 self.assertEqual(ref, out) 1445 1446 @onlyCUDA 1447 @dtypes(torch.float, torch.float16) 1448 @torch.inference_mode() 1449 @parametrize("embedding_dim", [8, 128, 256, 384]) 1450 def test_nested_tensor_dense_elementwise(self, device, dtype, embedding_dim): 1451 def _test_add_mul(nt, t): 1452 ref_add = torch.nested.nested_tensor( 1453 [t1 + t2 for (t1, t2) in zip(nt.unbind(), t.unbind())] 1454 ) 1455 ref_mul = torch.nested.nested_tensor( 1456 [t1 * t2 for (t1, t2) in zip(nt.unbind(), t.unbind())] 1457 ) 1458 self.assertEqual(nt.add(t), ref_add) 1459 self.assertEqual(nt.mul(t), ref_mul) 1460 1461 batch_size = 32 1462 seq_lens = torch.randint(low=0, high=10, size=(batch_size,)) 1463 1464 # [B, *, D], [B, 1, D] case 1465 ts = [torch.randn((seq_len, embedding_dim)) for seq_len in seq_lens] 1466 nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) 1467 t = torch.randn((batch_size, 1, embedding_dim), device=device, dtype=dtype) 1468 _test_add_mul(nt, t) 1469 1470 # [B, *], [B, 1] case 1471 ts = [torch.randn(seq_len) for seq_len in seq_lens] 1472 nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) 1473 t = torch.randn((batch_size, 1), device=device, dtype=dtype) 1474 _test_add_mul(nt, t) 1475 1476 @dtypes(torch.float, torch.float16) 1477 @skipMeta 1478 @torch.inference_mode() 1479 
def test_nested_tensor_mul(self, device, dtype): 1480 # nested tensor * nested tensor 1481 (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) 1482 ref = torch.nested.nested_tensor( 1483 [t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] 1484 ) 1485 out = nt1 * nt2 1486 self.assertEqual(ref, out) 1487 # nested tensor * scalar 1488 number = 10.0 1489 scalar = torch.tensor(number).to(dtype).to(device) 1490 ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()]) 1491 out_number0 = nt1 * number 1492 out_number1 = number * nt1 1493 out_scalar0 = nt1 * scalar 1494 out_scalar1 = scalar * nt1 1495 self.assertEqual(out_number0, ref) 1496 self.assertEqual(out_number1, ref) 1497 self.assertEqual(out_scalar0, ref) 1498 self.assertEqual(out_scalar1, ref) 1499 # error case: numel == 1 but dim > 0 1500 vector = torch.tensor([number]).to(dtype).to(device) 1501 self.assertRaisesRegex( 1502 RuntimeError, 1503 "Expected both self and other to be nested, but got a nested self and non-nested other", 1504 lambda: nt1.mul(vector), 1505 ) 1506 self.assertRaisesRegex( 1507 RuntimeError, 1508 "Expected both self and other to be nested, but got a non-nested self and nested other", 1509 lambda: vector.mul(nt1), 1510 ) 1511 1512 @dtypes(torch.float, torch.float16) 1513 @skipMeta 1514 @torch.inference_mode() 1515 def test_nested_tensor_div(self, device, dtype): 1516 nt, nt2 = self.random_nt_pair(device, dtype, 4, (4, 4)) 1517 scale = 4.0 1518 ref = torch.nested.nested_tensor([t / scale for t in nt.unbind()]) 1519 out = nt / 4.0 1520 self.assertEqual(ref, out) 1521 ref_transposed = ref.transpose(1, 2) 1522 out = nt.transpose(1, 2) / 4.0 1523 self.assertEqual(ref_transposed, out) 1524 1525 ref = torch.nested.nested_tensor( 1526 [t / t2 for (t, t2) in zip(nt.unbind(), nt2.unbind())] 1527 ) 1528 out = nt / nt2 1529 self.assertEqual(ref, out) 1530 1531 out = nt.transpose(1, 2) / nt2.transpose(1, 2) 1532 self.assertEqual(ref.transpose(1, 2), out) 1533 1534 nt_transpose_copy = torch.nested.nested_tensor( 1535 [t.transpose(0, 1) for t in nt.unbind()] 1536 ) 1537 1538 self.assertRaisesRegex( 1539 RuntimeError, 1540 "div requires strides to match when given NestedTensors", 1541 lambda: nt_transpose_copy.transpose(1, 2) / nt2, 1542 ) 1543 1544 nt = torch.nested.nested_tensor( 1545 [torch.randn(i, 4) for i in [3, 4, 5]], device=device, dtype=dtype 1546 ) 1547 nt_chunks = nt.chunk(2, -1) 1548 self.assertRaisesRegex( 1549 RuntimeError, 1550 "div requires offsets to match when given NestedTensors", 1551 lambda: nt_chunks[0] / nt_chunks[1], 1552 ) 1553 1554 @dtypes(torch.float, torch.float16) 1555 @skipMeta 1556 @torch.inference_mode() 1557 def test_nested_tensor_add_in_place(self, device, dtype): 1558 (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) 1559 ref = torch.nested.nested_tensor( 1560 [t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] 1561 ) 1562 nt1 += nt2 1563 self.assertEqual(ref, nt1) 1564 1565 @dtypes(torch.float, torch.float16) 1566 @skipMeta 1567 @torch.inference_mode() 1568 def test_nested_tensor_mul_in_place(self, device, dtype): 1569 # nested tensor * nested tensor 1570 (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) 1571 ref = torch.nested.nested_tensor( 1572 [t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] 1573 ) 1574 nt1 *= nt2 1575 self.assertEqual(ref, nt1) 1576 # nested tensor * scalar 1577 number = 10.0 1578 scalar = torch.tensor(number).to(dtype).to(device) 1579 ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()]) 1580 out_number 
= nt1.clone() 1581 out_number *= number 1582 out_scalar = nt1.clone() 1583 out_scalar *= scalar 1584 self.assertEqual(out_number, ref) 1585 self.assertEqual(out_scalar, ref) 1586 self.assertRaisesRegex( 1587 RuntimeError, 1588 r"output with shape \[.*\] doesn't match the broadcast shape \[.*\]", 1589 lambda: scalar.mul_(nt1), 1590 ) 1591 # error case: numel == 1 but dim > 0 1592 vector = torch.tensor([number]).to(dtype).to(device) 1593 self.assertRaisesRegex( 1594 RuntimeError, 1595 "Expected both self and other to be nested, but got a nested self and non-nested other", 1596 lambda: nt1.mul_(vector), 1597 ) 1598 self.assertRaisesRegex( 1599 RuntimeError, 1600 "Expected both self and other to be nested, but got a non-nested self and nested other", 1601 lambda: vector.mul_(nt1), 1602 ) 1603 1604 @onlyCPU 1605 @skipMeta 1606 @dtypes(torch.float) 1607 def test_nested_tensor_sum_dim(self, device, dtype): 1608 params = ((2, (1, 1)), ((4), (4, 4)), (10, (3, 5, 7))) 1609 1610 def test_sum(device, dtype, ntensors, max_sizes, dim, keepdim=True): 1611 nt = random_nt(device, dtype, ntensors, max_sizes, require_non_empty=False) 1612 nt2 = nt.clone() 1613 ub2 = nt2.unbind() 1614 nt.requires_grad_(True) 1615 [t.requires_grad_(True) for t in ub2] 1616 nt_sum = nt.sum(dim=dim, keepdim=keepdim) 1617 ub2_sum = [t.sum(-1, keepdim=keepdim) for t in ub2] 1618 self.assertEqual(nt_sum, torch.nested.nested_tensor(ub2_sum)) 1619 1620 # test backward 1621 # generate gradient tensor that has the same size as the output 1622 size = nt_sum._nested_tensor_size() 1623 gt2 = [] 1624 for i in range(ntensors): 1625 gt2.append(torch.randn(size[i].tolist(), device=device, dtype=dtype)) 1626 gt = torch.nested.nested_tensor(gt2).clone() 1627 nt_sum.backward(gt) 1628 for t2, g2 in zip(ub2_sum, gt2): 1629 t2.backward(g2) 1630 self.assertEqual(nt.grad, torch.nested.nested_tensor([t.grad for t in ub2])) 1631 return 1632 1633 for ntensors, max_sizes in params: 1634 test_sum(device, dtype, ntensors, max_sizes, len(max_sizes)) 1635 1636 # Test error inputs 1637 with self.assertRaisesRegex( 1638 RuntimeError, "NestedTensor can only be reduced across the last" 1639 ): 1640 torch.nested.nested_tensor( 1641 [torch.tensor([3, 4, 5]), torch.tensor([1, 2])] 1642 ).sum(0, keepdim=True) 1643 1644 with self.assertRaisesRegex( 1645 RuntimeError, "NestedTensor only allows reduction of a single" 1646 ): 1647 torch.nested.nested_tensor( 1648 [torch.tensor([[3, 4, 5]]), torch.tensor([[1, 2]])] 1649 ).sum([0, 1], keepdim=True) 1650 1651 with self.assertRaisesRegex( 1652 RuntimeError, "NestedTensor always requires keepdim=True for now." 1653 ): 1654 torch.nested.nested_tensor( 1655 [torch.tensor([3, 4, 5]), torch.tensor([1, 2])] 1656 ).sum(-1) 1657 1658 @dtypes(torch.float, torch.float16) 1659 def test_contiguous(self, device, dtype): 1660 # Since we don't have access to the buffer in python this is harder to show what 1661 # we are testing for. When we call chunk on a consistent dim of a NT 1662 # for chunk_size > 1 the resulting tensors are views of the original NT 1663 # whose numels is now less than the size of the buffer. Clone was 1664 # previously creating a new NT with a buffer that was the same size as the 1665 # original. 
1666 nt_contiguous = torch.nested.nested_tensor( 1667 [ 1668 torch.randn(2, 20, device=device, dtype=dtype), 1669 torch.randn(4, 20, device=device, dtype=dtype), 1670 ] 1671 ) 1672 # Split up the last dimension which has a consistent size of 20 into 5 chunks 1673 chunks = nt_contiguous.chunk(5, dim=-1) 1674 1675 # # Check chunks are contiguous after calling contiguous 1676 for chunk in chunks: 1677 self.assertFalse(chunk.is_contiguous()) 1678 self.assertTrue(chunk.contiguous().is_contiguous()) 1679 1680 @dtypes(torch.float, torch.float16) 1681 @skipMeta 1682 def test_clone(self, device, dtype): 1683 nt1 = random_nt(device, dtype, 4, (4, 4), (1, 1)) 1684 nt2 = nt1.clone() 1685 # Verify the values match 1686 self.assertEqual(nt1, nt2) 1687 # Verify modifying nt2 doesn't affect nt1 1688 nt2.mul_(nt1) 1689 ub1 = nt1.unbind() 1690 ub2 = nt2.unbind() 1691 for i in range(len(ub1)): 1692 self.assertNotEqual(ub1[i], ub2[i]) 1693 1694 nt1.clone(memory_format=torch.preserve_format) 1695 msg = "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ChannelsLast" 1696 with self.assertRaisesRegex(RuntimeError, msg): 1697 nt1.clone(memory_format=torch.channels_last) 1698 1699 # cannot test torch.float16 because: RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'Half' 1700 @decorateIf(xfailIfTorchDynamo, lambda params: params["layout"] == torch.jagged) 1701 @dtypes(torch.float, torch.double) 1702 @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name) 1703 def test_dropout(self, device, dtype, layout): 1704 # edge case: empty nested tensor 1705 # TODO: support empty NT in jagged layout 1706 if layout == torch.strided: 1707 nt0 = torch.nested.nested_tensor([], layout=layout) 1708 y = torch.nn.functional.dropout(nt0, 0.5) 1709 self.assertEqual(nt0, y) 1710 # normal nested tensor 1711 ntensors = 4 1712 if layout == torch.jagged: 1713 nt = random_nt(device, dtype, ntensors, (4, 4), (0, 3), layout=layout) 1714 else: 1715 nt = random_nt(device, dtype, ntensors, (4, 4), layout=layout) 1716 # edge case: invalid dropout 1717 self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1)) 1718 self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1)) 1719 self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, -0.1)) 1720 self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, 1.1)) 1721 # edge case: no dropout 1722 dropouter = torch.nn.Dropout(0.0) 1723 y0 = dropouter(nt) 1724 y1 = torch.nn.functional.dropout(nt, 0.0) 1725 self.assertEqual(nt, y0) 1726 self.assertEqual(nt, y1) 1727 # edge case: all dropout 1728 dropouter = torch.nn.Dropout(1.0) 1729 y0 = dropouter(nt) 1730 y1 = torch.nn.functional.dropout(nt, 1.0) 1731 nt0 = torch.zeros_like(nt) 1732 self.assertEqual(nt0, y0) 1733 self.assertEqual(nt0, y1) 1734 # normal case: normal dropout 1735 p = 0.2 1736 y = torch.nn.functional.dropout(nt, p) 1737 expect = nt.clone() 1738 if layout == torch.jagged: 1739 expect = torch.where(y == 0.0, y, nt) 1740 expect /= 1.0 - p 1741 self.assertEqual(y, expect) 1742 else: 1743 expect = nt.clone() 1744 for i in range(ntensors): 1745 actual_tensor = y[i].view(-1) 1746 expect_tensor = expect[i].view(-1) 1747 for j in range(actual_tensor.shape[0]): 1748 if actual_tensor[j].item() == 0.0: 1749 expect_tensor[j] = 0.0 1750 else: 1751 expect_tensor[j] /= 1.0 - p 1752 self.assertEqual(y, expect) 1753 with freeze_rng_state(): 1754 dropouter = torch.nn.Dropout(p) 1755 y0 = dropouter(nt) 1756 with freeze_rng_state(): 1757 y1 = 
torch.nn.functional.dropout(nt, p) 1758 self.assertEqual(y0, y1) 1759 1760 @dtypes(torch.float, torch.double) 1761 def test_dropout_noncontiguous(self, device, dtype): 1762 ntensors = 4 1763 nt0 = random_nt(device, dtype, ntensors, (4, 4)) 1764 nt1 = nt0.transpose(-1, -2) 1765 p = 0.3 1766 with freeze_rng_state(): 1767 dropouter = torch.nn.Dropout(p) 1768 y0 = dropouter(nt0) 1769 with freeze_rng_state(): 1770 y1 = torch.nn.functional.dropout(nt1, p).transpose(-1, -2) 1771 self.assertEqual(y0, y1) 1772 1773 # cannot test torch.float16 because: RuntimeError: "softmax_kernel_impl" not implemented for 'Half' 1774 @dtypes(torch.float, torch.double) 1775 def test_softmax(self, device, dtype): 1776 # normal nested tensor 1777 ntensors = 4 1778 nt = random_nt(device, dtype, ntensors, (4, 4)) 1779 # error case: softmax across nested dimension 1780 self.assertRaisesRegex( 1781 RuntimeError, 1782 "Cannot apply softmax across nested dimension 0", 1783 lambda: torch.nn.functional.softmax(nt, 0), 1784 ) 1785 self.assertRaisesRegex( 1786 RuntimeError, 1787 "Cannot apply softmax across nested dimension 0", 1788 lambda: torch.nn.functional.softmax(nt, -3), 1789 ) 1790 # error case: dimension out of range 1791 self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, 3)) 1792 self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, -4)) 1793 # normal case: should equal to padding -inf 1794 softmaxer = torch.nn.Softmax(1) 1795 y0 = softmaxer(nt) 1796 y1 = torch.nn.functional.softmax(nt, 1) 1797 self.assertEqual(y0, y1) 1798 pt = torch.nested.to_padded_tensor(nt, float("-inf")) 1799 # if an entire slice is padded, then softmax will return 0.0 / 0.0 = nan 1800 # however, physically speaking that should be 0.0 1801 expect = torch.nn.functional.softmax(pt, 1).nan_to_num_(0.0) 1802 self.assertEqual(torch.nested.to_padded_tensor(y0, 0.0), expect) 1803 # edge case: empty nested tensor 1804 nt0 = torch.nested.nested_tensor([]) 1805 y = torch.nn.functional.softmax(nt0, 1) 1806 self.assertEqual(nt0, y) 1807 # edge case: nesting scalars 1808 nt1 = torch.nested.nested_tensor([torch.tensor(0.0), torch.tensor(1.0)]) 1809 self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt1, 0)) 1810 self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt1, 1)) 1811 1812 @dtypes(torch.float, torch.double) 1813 @torch.inference_mode() 1814 def test_softmax_noncontiguous(self, device, dtype): 1815 nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( 1816 (2, 3, 6, 7), device, dtype 1817 ) 1818 self.assertEqual( 1819 torch.nn.functional.softmax(nt_contiguous, -1), 1820 torch.nn.functional.softmax(nt_noncontiguous, -1), 1821 ) 1822 1823 def _test_bmm(self, device, dtype): 1824 # error case: not 3D tensors 1825 nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype) 1826 nt1 = torch.nested.nested_tensor( 1827 [torch.randn(2), torch.randn(3)], device=device, dtype=dtype 1828 ) 1829 nt2 = torch.nested.nested_tensor( 1830 [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype 1831 ) 1832 self.assertRaisesRegex( 1833 RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt0) 1834 ) 1835 self.assertRaisesRegex( 1836 RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt1) 1837 ) 1838 self.assertRaisesRegex( 1839 RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt2) 1840 ) 1841 self.assertRaisesRegex( 1842 RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt0) 1843 ) 1844 self.assertRaisesRegex( 1845 RuntimeError, "batch1 
must be a 3D tensor", lambda: nt1.bmm(nt1) 1846 ) 1847 self.assertRaisesRegex( 1848 RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt2) 1849 ) 1850 self.assertRaisesRegex( 1851 RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt0) 1852 ) 1853 self.assertRaisesRegex( 1854 RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt1) 1855 ) 1856 # error case: incompatible batch size 1857 nt0 = torch.nested.nested_tensor( 1858 [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype 1859 ) 1860 nt1 = torch.nested.nested_tensor( 1861 [torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))], 1862 device=device, 1863 dtype=dtype, 1864 ) 1865 self.assertRaisesRegex( 1866 RuntimeError, 1867 "Expected size for the 1st dimension of batch2 tensor to be: 2 but got: 3.", 1868 lambda: nt0.bmm(nt1), 1869 ) 1870 self.assertRaisesRegex( 1871 RuntimeError, 1872 "Expected size for the 1st dimension of batch2 tensor to be: 3 but got: 2.", 1873 lambda: nt1.bmm(nt0), 1874 ) 1875 # error case: underlying matrices cannot be multiplied 1876 nt0 = torch.nested.nested_tensor( 1877 [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype 1878 ) 1879 self.assertRaisesRegex( 1880 RuntimeError, 1881 r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)", 1882 lambda: nt0.bmm(nt0), 1883 ) 1884 # normal nested tensor 1885 nt0 = torch.nested.nested_tensor( 1886 [torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype 1887 ) 1888 nt1 = torch.nested.nested_tensor( 1889 [torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype 1890 ) 1891 actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) 1892 expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm( 1893 torch.nested.to_padded_tensor(nt1, 0.0) 1894 ) 1895 if dtype == torch.float16: 1896 self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3) 1897 else: 1898 self.assertEqual(actual, expect) 1899 1900 # nested tensor bmm normal tensor 1901 nt0 = torch.nested.nested_tensor( 1902 [torch.randn((2, 7)), torch.randn((3, 7))], device=device, dtype=dtype 1903 ) 1904 nt1 = torch.rand(2, 7, 5, dtype=dtype, device=device) 1905 actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) 1906 expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(nt1) 1907 if dtype == torch.float16: 1908 self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3) 1909 else: 1910 self.assertEqual(actual, expect) 1911 1912 # nested tensor bmm normal tensor with non-contiguous view 1913 nt1 = torch.rand(2, 5, 7, dtype=dtype, device=device) 1914 nt1 = nt1.transpose(1, 2) 1915 actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) 1916 expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(nt1) 1917 if dtype == torch.float16: 1918 self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3) 1919 else: 1920 self.assertEqual(actual, expect) 1921 1922 # normal tensor bmm nested tensor 1923 nt0 = torch.rand(2, 5, 7, dtype=dtype, device=device) 1924 nt1 = torch.nested.nested_tensor( 1925 [torch.randn((7, 6)), torch.randn((7, 5))], device=device, dtype=dtype 1926 ) 1927 actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) 1928 expect = nt0.bmm(torch.nested.to_padded_tensor(nt1, 0.0)) 1929 if dtype == torch.float16: 1930 self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3) 1931 else: 1932 self.assertEqual(actual, expect) 1933 1934 # test tensorcore path 1935 nt0 = torch.nested.nested_tensor( 1936 [torch.randn((2, 8)), torch.randn((3, 16))], device=device, dtype=dtype 1937 ) 1938 nt1 = torch.nested.nested_tensor( 1939 
[torch.randn((8, 8)), torch.randn((16, 8))], device=device, dtype=dtype 1940 ) 1941 actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) 1942 expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm( 1943 torch.nested.to_padded_tensor(nt1, 0.0) 1944 ) 1945 if dtype == torch.float16: 1946 self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3) 1947 else: 1948 self.assertEqual(actual, expect) 1949 1950 @onlyCUDA 1951 @dtypes(torch.float, torch.double, torch.float16) 1952 def test_bmm_cuda(self, device, dtype): 1953 self._test_bmm(device, dtype) 1954 1955 @onlyCPU 1956 # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half' 1957 @dtypes(torch.float, torch.double) 1958 def test_bmm_cpu(self, device, dtype): 1959 self._test_bmm(device, dtype) 1960 1961 # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half' 1962 @dtypes(torch.float, torch.double) 1963 def test_bmm_noncontiguous(self, device, dtype): 1964 nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair( 1965 (2, 3), device, dtype 1966 ) 1967 nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair( 1968 (6, 7), device, dtype 1969 ) 1970 self.assertEqual( 1971 nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous), 1972 nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous), 1973 ) 1974 1975 @dtypes(torch.float, torch.double) 1976 def test_matmul_with_bmm_path(self, device, dtype): 1977 def unbind_rebind_matmul(nt1, nt2): 1978 t1s = nt1.unbind() 1979 t2s = nt2.unbind() 1980 out_ts = [t1.matmul(t2) for t1, t2 in zip(t1s, t2s)] 1981 return torch.nested.nested_tensor(out_ts) 1982 1983 # [N, n_head, *, head_dim], [N, n_head, head_dim, *] 1984 Ns = [1, 2, 5] 1985 n_heads = np.random.randint(2, 5) 1986 head_dim = 3 1987 t1s = [] 1988 t2s = [] 1989 for N in Ns: 1990 for _ in range(N): 1991 seq_len1 = np.random.randint(2, 5) 1992 seq_len2 = np.random.randint(2, 5) 1993 t1s.append(torch.randn(n_heads, seq_len1, head_dim)) 1994 t2s.append(torch.randn(n_heads, head_dim, seq_len2)) 1995 nt1 = torch.nested.nested_tensor(t1s, device=device, dtype=dtype) 1996 nt2 = torch.nested.nested_tensor(t2s, device=device, dtype=dtype) 1997 self.assertEqual(torch.matmul(nt1, nt2), unbind_rebind_matmul(nt1, nt2)) 1998 1999 # test with noncontiguous 2000 t3s = [] 2001 t4s = [] 2002 for _ in range(N): 2003 seq_len = np.random.randint(2, 5) 2004 t3s.append(torch.randn(seq_len, n_heads, head_dim)) 2005 t4s.append(torch.randn(seq_len, n_heads, head_dim)) 2006 nt3 = torch.nested.nested_tensor(t3s, device=device, dtype=dtype).transpose( 2007 1, 2 2008 ) 2009 nt4 = ( 2010 torch.nested.nested_tensor(t4s, device=device, dtype=dtype) 2011 .transpose(1, 2) 2012 .transpose(2, 3) 2013 ) 2014 self.assertEqual(torch.matmul(nt3, nt4), unbind_rebind_matmul(nt3, nt4)) 2015 2016 # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half' 2017 @dtypes(torch.float, torch.double) 2018 def test_matmul(self, device, dtype): 2019 # error case: one is nested but the other is not 2020 nt = torch.nested.nested_tensor( 2021 [torch.randn(2), torch.randn(3)], device=device, dtype=dtype 2022 ) 2023 t = torch.randn(4, device=device, dtype=dtype) 2024 self.assertRaisesRegex( 2025 RuntimeError, 2026 "Expected both to be nested, but got a nested self and non-nested other", 2027 lambda: torch.matmul(nt, t), 2028 ) 2029 self.assertRaisesRegex( 2030 RuntimeError, 2031 "Expected both to be nested, but got a non-nested self and nested other", 2032 lambda: torch.matmul(t, nt), 2033 ) 
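        # NB: mixing a nested tensor with a regular dense tensor is expected to raise here;
        # the NT @ dense-matrix broadcasting case is exercised separately in
        # test_matmul_nt_with_broadcasted_t below.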
2034 # error case: not 3+D tensors 2035 nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype) 2036 nt1 = torch.nested.nested_tensor( 2037 [torch.randn(2), torch.randn(3)], device=device, dtype=dtype 2038 ) 2039 nt2 = torch.nested.nested_tensor( 2040 [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype 2041 ) 2042 self.assertRaisesRegex( 2043 RuntimeError, 2044 r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", 2045 lambda: torch.matmul(nt0, nt0), 2046 ) 2047 self.assertRaisesRegex( 2048 RuntimeError, 2049 r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", 2050 lambda: torch.matmul(nt0, nt1), 2051 ) 2052 self.assertRaisesRegex( 2053 RuntimeError, 2054 r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", 2055 lambda: torch.matmul(nt0, nt2), 2056 ) 2057 self.assertRaisesRegex( 2058 RuntimeError, 2059 r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", 2060 lambda: torch.matmul(nt1, nt0), 2061 ) 2062 self.assertRaisesRegex( 2063 RuntimeError, 2064 r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", 2065 lambda: torch.matmul(nt1, nt1), 2066 ) 2067 self.assertRaisesRegex( 2068 RuntimeError, 2069 r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", 2070 lambda: torch.matmul(nt1, nt2), 2071 ) 2072 self.assertRaisesRegex( 2073 RuntimeError, 2074 r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+", 2075 lambda: torch.matmul(nt2, nt0), 2076 ) 2077 self.assertRaisesRegex( 2078 RuntimeError, 2079 r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+", 2080 lambda: torch.matmul(nt2, nt1), 2081 ) 2082 # error case: incompatible batch size 2083 nt0 = torch.nested.nested_tensor( 2084 [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype 2085 ) 2086 nt1 = torch.nested.nested_tensor( 2087 [torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))], 2088 device=device, 2089 dtype=dtype, 2090 ) 2091 self.assertRaisesRegex( 2092 RuntimeError, 2093 r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.", 2094 lambda: torch.matmul(nt0, nt1), 2095 ) 2096 self.assertRaisesRegex( 2097 RuntimeError, 2098 r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.", 2099 lambda: torch.matmul(nt1, nt0), 2100 ) 2101 # error case: incompatible (wrong) batch sizes that shouldn't even broadcast? 
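        # (per-component batch dims must match exactly; broadcasting across them is not implemented)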
2102 nt0 = torch.nested.nested_tensor( 2103 [torch.randn((2, 2, 4)), torch.randn((2, 3, 4))], device=device, dtype=dtype 2104 ) 2105 nt1 = torch.nested.nested_tensor( 2106 [torch.randn((3, 4, 6)), torch.randn((3, 4, 5))], device=device, dtype=dtype 2107 ) 2108 self.assertRaisesRegex( 2109 RuntimeError, 2110 "matmul(): For nested tensors, batch dimensions must have the same sizes,", 2111 lambda: torch.matmul(nt0, nt1), 2112 ) 2113 # error case: incompatible batch sizes that should technically broadcast 2114 nt0 = torch.nested.nested_tensor( 2115 [torch.randn((2, 2, 4)), torch.randn((1, 3, 4))], device=device, dtype=dtype 2116 ) 2117 nt1 = torch.nested.nested_tensor( 2118 [torch.randn((1, 4, 6)), torch.randn((3, 4, 5))], device=device, dtype=dtype 2119 ) 2120 self.assertRaisesRegex( 2121 RuntimeError, 2122 "matmul(): For nested tensors, batch dimensions must have the same sizes,", 2123 lambda: torch.matmul(nt0, nt1), 2124 ) 2125 # error case: underlying matrices cannot be multiplied 2126 nt0 = torch.nested.nested_tensor( 2127 [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype 2128 ) 2129 self.assertRaisesRegex( 2130 RuntimeError, 2131 "matmul(): Nested tensors cannot be matrix multiplied", 2132 lambda: torch.matmul(nt0, nt0), 2133 ) 2134 # normal nested tensor: 3D 2135 nt0 = torch.nested.nested_tensor( 2136 [torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype 2137 ) 2138 nt1 = torch.nested.nested_tensor( 2139 [torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype 2140 ) 2141 actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0) 2142 expect = torch.matmul( 2143 torch.nested.to_padded_tensor(nt0, 0.0), 2144 torch.nested.to_padded_tensor(nt1, 0.0), 2145 ) 2146 self.assertEqual(actual, expect) 2147 # normal nested tensor: 4D (with testing for batch_size=1) 2148 nt0 = torch.nested.nested_tensor( 2149 [torch.randn((1, 2, 4)), torch.randn((8, 3, 7))], device=device, dtype=dtype 2150 ) 2151 nt1 = torch.nested.nested_tensor( 2152 [torch.randn((1, 4, 6)), torch.randn((8, 7, 5))], device=device, dtype=dtype 2153 ) 2154 actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0) 2155 expect = torch.matmul( 2156 torch.nested.to_padded_tensor(nt0, 0.0), 2157 torch.nested.to_padded_tensor(nt1, 0.0), 2158 ) 2159 self.assertEqual(actual, expect) 2160 # normal nested tensor: 5D 2161 nt0 = torch.nested.nested_tensor( 2162 [torch.randn((8, 9, 2, 4)), torch.randn((8, 9, 3, 7))], 2163 device=device, 2164 dtype=dtype, 2165 ) 2166 nt1 = torch.nested.nested_tensor( 2167 [torch.randn((8, 9, 4, 6)), torch.randn((8, 9, 7, 5))], 2168 device=device, 2169 dtype=dtype, 2170 ) 2171 actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0) 2172 expect = torch.matmul( 2173 torch.nested.to_padded_tensor(nt0, 0.0), 2174 torch.nested.to_padded_tensor(nt1, 0.0), 2175 ) 2176 self.assertEqual(actual, expect) 2177 2178 # only supported on CUDA for now 2179 @dtypes(torch.float, torch.double) 2180 def test_matmul_nt_with_broadcasted_t(self, device, dtype): 2181 # NT (B, *, C, D) with T (D, E) broadcasting case 2182 nt = random_nt_from_dims([3, None, 4, 5], device=device, dtype=dtype) 2183 t = torch.randn(5, 6, device=device, dtype=dtype) 2184 output = torch.matmul(nt, t) 2185 2186 # should be equivalent to matmul-ing each component with the dense tensor 2187 self.assertEqual(nt.size(0), output.size(0)) 2188 for component, out_component in zip(nt, output): 2189 self.assertEqual(out_component, torch.matmul(component, t)) 2190 2191 # cannot test 
torch.float16 because: RuntimeError: "bmm" not implemented for 'Half' 2192 @dtypes(torch.float, torch.double) 2193 def test_matmul_noncontiguous(self, device, dtype): 2194 nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair( 2195 (2, 3), device, dtype 2196 ) 2197 nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair( 2198 (6, 7), device, dtype 2199 ) 2200 self.assertEqual( 2201 torch.matmul(nt0_contiguous.transpose(-1, -2), nt1_contiguous), 2202 torch.matmul(nt0_noncontiguous.transpose(-1, -2), nt1_noncontiguous), 2203 ) 2204 2205 @dtypes(torch.float, torch.double) 2206 def test_linear(self, device, dtype): 2207 a = torch.randn(1, 2, device=device, dtype=dtype) 2208 b = torch.randn(2, 2, device=device, dtype=dtype) 2209 c = torch.randn(3, 2, device=device, dtype=dtype) 2210 nt = torch.nested.nested_tensor([a, b, c]) 2211 2212 weight = torch.randn(2, 2, device=device, dtype=dtype) 2213 bias = torch.randn(2, device=device, dtype=dtype) 2214 # success case 2215 torch.functional.F.linear(nt, weight, bias) 2216 2217 # invalid nested tensor dimension 2218 msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 2. Dense tensor dim: 2" 2219 nt1 = torch.nested.nested_tensor( 2220 [ 2221 torch.randn(1, device=device, dtype=dtype), 2222 torch.randn(2, device=device, dtype=dtype), 2223 ] 2224 ) 2225 with self.assertRaisesRegex(RuntimeError, msg): 2226 torch.functional.F.linear(nt1, weight, bias) 2227 2228 # invalid weight shape 2229 msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 3. Dense tensor dim: 3" 2230 weight1 = torch.randn(2, 2, 3, device=device, dtype=dtype) 2231 with self.assertRaisesRegex(RuntimeError, msg): 2232 torch.functional.F.linear(nt, weight1, bias) 2233 2234 # inconsistent last dim of nested tensor 2235 msg = r"Expected all tensors in nested tensor to have the same trailing dimension, instead last dimension equals:" 2236 nt2 = torch.nested.nested_tensor( 2237 [ 2238 torch.randn(1, 2, device=device, dtype=dtype), 2239 torch.randn(2, 3, device=device, dtype=dtype), 2240 ] 2241 ) 2242 with self.assertRaisesRegex(RuntimeError, msg): 2243 torch.functional.F.linear(nt2, weight, bias) 2244 2245 # Mismatch of nested tensor last dim and weight dimension 2246 weight2 = torch.randn(2, 4, device=device, dtype=dtype) 2247 msg = ( 2248 r"Shape mismatch for NestedTensor Linear: Expected input's \(a nested tensor\) 'last_dim'" 2249 r" to equal 'weight.size\(1\), but got: last_dim = 2, and weight.size\(1\) = 4" 2250 ) 2251 with self.assertRaisesRegex(RuntimeError, msg): 2252 torch.functional.F.linear(nt, weight2, bias) 2253 2254 # Nested tensor input and nested weight 2255 nt_weight = nt.clone() 2256 msg = r"Linear does not support nested weight when input is a nested tensor." 
2257 with self.assertRaisesRegex(RuntimeError, msg): 2258 torch.functional.F.linear(nt, nt_weight, bias) 2259 2260 # TODO: test noncontiguous linear 2261 # For now this tests the error message of linear 2262 # since linear does not support noncontiguous buffer yet 2263 @dtypes(torch.float, torch.double) 2264 def test_linear_noncontiguous(self, device, dtype): 2265 nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( 2266 (2, 3, 6, 7), device, dtype 2267 ) 2268 weight = torch.randn((8, 5), device=device, dtype=dtype) 2269 self.assertRaisesRegex( 2270 RuntimeError, 2271 r"for now linear only supports contiguous nested tensor", 2272 lambda: torch.nn.functional.linear(nt_noncontiguous, weight), 2273 ) 2274 2275 @dtypes(torch.float, torch.float16, torch.double) 2276 def test_to_padded_tensor_zero_numel_errors(self, device, dtype): 2277 ts = [torch.ones(1, 0), torch.ones(0, 0)] 2278 nt = torch.nested.nested_tensor( 2279 ts, device=device, dtype=dtype, layout=torch.strided 2280 ) 2281 self.assertRaisesRegex( 2282 RuntimeError, 2283 r"at least one constituent tensor should have non-zero numel", 2284 lambda: torch.nested.to_padded_tensor(nt, 0.0), 2285 ) 2286 2287 @dtypes(torch.float, torch.float16, torch.double) 2288 def test_transpose(self, device, dtype): 2289 nt = random_nt(device, dtype, 4, (4, 4)) 2290 # error case: transpose nested dimension 2291 self.assertRaisesRegex( 2292 RuntimeError, 2293 "Nested tensor dimension 0 cannot be transposed", 2294 lambda: nt.transpose(0, 1), 2295 ) 2296 self.assertRaisesRegex( 2297 RuntimeError, 2298 "Nested tensor dimension 0 cannot be transposed", 2299 lambda: nt.transpose(1, -3), 2300 ) 2301 # error case: dimension out of range 2302 self.assertRaises(IndexError, lambda: nt.transpose(1, 3)) 2303 self.assertRaises(IndexError, lambda: nt.transpose(-4, -1)) 2304 # normal case 2305 ntT = nt.transpose(-1, -2) 2306 ptT_from_ntT = noncontiguous_to_padded_tensor(ntT) 2307 pt = torch.nested.to_padded_tensor(nt, 0.0) 2308 ptT = pt.transpose(-1, -2) 2309 self.assertEqual(ptT, ptT_from_ntT) 2310 2311 @dtypes(torch.float, torch.float16, torch.double) 2312 def test_squeeze_unsqueeze(self, device, dtype): 2313 a = torch.arange(6).reshape(2, 3) 2314 b = torch.arange(15).reshape(5, 3) 2315 nt = torch.nested.nested_tensor([a, b], device=device, dtype=dtype) 2316 # error case: squeeze no dimension 2317 self.assertRaisesRegex( 2318 RuntimeError, 2319 "For nested tensors, squeeze without the dim argument", 2320 lambda: nt.squeeze(), 2321 ) 2322 # error case: squeeze nested dimension 2323 self.assertRaisesRegex( 2324 RuntimeError, 2325 "For nested tensors, squeezing dimension 0", 2326 lambda: nt.squeeze(0), 2327 ) 2328 # error case: dimension out of range 2329 self.assertRaises(IndexError, lambda: nt.squeeze(3)) 2330 # error case: squeeze nested tensor of singleton tensors 2331 c = torch.ones(1) 2332 nt_singleton = torch.nested.nested_tensor([c, c], device=device, dtype=dtype) 2333 self.assertRaisesRegex( 2334 RuntimeError, 2335 "For nested tensors, squeezing a nested tensor of singleton", 2336 lambda: nt_singleton.squeeze(1), 2337 ) 2338 2339 # squeezing a dim which does not have size 1 should be a no-op 2340 nt2 = nt.squeeze(-1) 2341 self.assertEqual(nt, nt2) 2342 2343 # test cases that should work 2344 nt_sizes = nt._nested_tensor_size() 2345 nt_strides = nt._nested_tensor_strides() 2346 for i in range(-2, 4): 2347 if i == 0: 2348 # cannot unsqueeze batch dim 2349 continue 2350 nt_unsqueezed = nt.unsqueeze(i) 2351 # negative dim will correspond to unsqueeze() 
            # applied at dim = dim + nt.dim() + 1
            wrapped_i = i + nt.dim() + 1 if i < 0 else i
            # col_index into nt size tensor requires subtraction of 1 to ignore the batch dim
            size_idx = wrapped_i - 1
            self.assertEqual(
                nt_unsqueezed._nested_tensor_size()[:, size_idx],
                torch.ones(2, dtype=torch.long),
            )
            unsqueezed_stride = nt_unsqueezed._nested_tensor_strides()[:, size_idx]
            if i == nt.ndim or i == -1:
                self.assertEqual(unsqueezed_stride, torch.ones(2, dtype=torch.long))
            else:
                stride_col_after = nt_strides[:, size_idx]
                size_col_after = nt_sizes[:, size_idx]
                self.assertEqual(unsqueezed_stride, stride_col_after * size_col_after)
            nt_squeezed = nt_unsqueezed.squeeze(i)
            self.assertEqual(nt_squeezed, nt)
            self.assertEqual(nt_squeezed._nested_tensor_size(), nt_sizes)
            self.assertEqual(nt_squeezed._nested_tensor_strides(), nt_strides)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_transpose_inference_mode_interaction(self, device, dtype):
        nt = random_nt(device, dtype, 4, (4, 4))
        # Construct in default mode and transpose while in inference mode
        with torch.inference_mode():
            ntT = nt.transpose(-1, -2)
            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
            pt = torch.nested.to_padded_tensor(nt, 0.0)
            ptT = pt.transpose(-1, -2)
            self.assertEqual(ptT, ptT_from_ntT)

        # Construct and transpose while in inference mode
        with torch.inference_mode():
            nt = random_nt(device, dtype, 4, (4, 4))
            ntT = nt.transpose(-1, -2)
            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
            pt = torch.nested.to_padded_tensor(nt, 0.0)
            ptT = pt.transpose(-1, -2)
            self.assertEqual(ptT, ptT_from_ntT)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_view(self, device, dtype):
        nt = random_nt(device, dtype, 4, (4, 4))
        # error case: empty shape
        self.assertRaisesRegex(
            RuntimeError,
            r"shape '\[\]' is invalid for a nested tensor",
            lambda: nt.view(()),
        )
        # error case: empty nested tensor
        nt_empty = torch.nested.nested_tensor([])
        self.assertRaisesRegex(
            RuntimeError,
            "empty nested tensor cannot be reshaped",
            lambda: nt_empty.view(-1),
        )
        # error case: -1 for batch size
        self.assertRaisesRegex(
            RuntimeError,
            r"view: For now nested view cannot change or infer the implicit batch dimension",
            lambda: nt.view(-1, 2, 3),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"shape '\[.*\]' is invalid for input of size [0-9]+",
            lambda: nt.view(4, 2, 3),
        )
        # normal case
        x0 = torch.randn((2, 20), device=device, dtype=dtype)
        x1 = torch.randn((3, 20), device=device, dtype=dtype)
        nt = torch.nested.nested_tensor([x0, x1])
        pt = torch.nested.to_padded_tensor(nt, 0.0)
        # error case, trying to reshape batch dim to a legit shape
        self.assertRaisesRegex(
            RuntimeError,
            r"For now nested view cannot change or infer the implicit batch dimension",
            lambda: nt.transpose(-1, -2).view(40, -1),
        )
        # inherit only the ragged dimension
        # (2, 20) -> (2, 5, 4)
        # (3, 20) -> (3, 5, 4)
        nt1 = nt.view(2, -1, 5, 4)
        # (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4)
        pt1 = pt.view(2, -1, 5, 4)
        self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)

        # more than one -1 (even for "old" dims), should fail
        # this attempts to do # (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2)
        # but we ban
"inherit old behavior" for >1 dimension 2440 self.assertRaisesRegex( 2441 RuntimeError, 2442 r"only one dimension can be inferred", 2443 lambda: nt1.view(2, -1, -1, 2, 2), 2444 ) 2445 2446 @dtypes(torch.float, torch.float16, torch.double) 2447 def test_view_inference_mode_interaction(self, device, dtype): 2448 # Construct in default mode and view while in inference mode 2449 nt = torch.nested.nested_tensor( 2450 [torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype 2451 ) 2452 with torch.inference_mode(): 2453 ntT = nt.view(2, -1, 4, 5) 2454 ptT_from_ntT = noncontiguous_to_padded_tensor(ntT) 2455 pt = torch.nested.to_padded_tensor(nt, 0.0) 2456 ptT = pt.view(2, -1, 4, 5) 2457 self.assertEqual(ptT, ptT_from_ntT) 2458 # Construct and view while in inference mode 2459 with torch.inference_mode(): 2460 nt = torch.nested.nested_tensor( 2461 [torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype 2462 ) 2463 ntT = nt.view(2, -1, 4, 5) 2464 ptT_from_ntT = noncontiguous_to_padded_tensor(ntT) 2465 pt = torch.nested.to_padded_tensor(nt, 0.0) 2466 ptT = pt.view(2, -1, 4, 5) 2467 self.assertEqual(ptT, ptT_from_ntT) 2468 2469 @dtypes(torch.float, torch.float16, torch.double) 2470 def test_reshape(self, device, dtype): 2471 nt = random_nt(device, dtype, 4, (4, 4)) 2472 # error case: empty shape 2473 self.assertRaisesRegex( 2474 RuntimeError, 2475 r"shape '\[\]' is invalid for a nested tensor", 2476 lambda: nt.reshape(()), 2477 ) 2478 # error case: empty nested tensor 2479 nt_empty = torch.nested.nested_tensor([]) 2480 self.assertRaisesRegex( 2481 RuntimeError, 2482 "empty nested tensor cannot be reshaped", 2483 lambda: nt_empty.reshape(-1), 2484 ) 2485 # error case: -1 for batch size 2486 self.assertRaisesRegex( 2487 RuntimeError, 2488 r"reshape: For now nested reshape cannot change or infer the implicit batch dimension", 2489 lambda: nt.reshape(-1, 2, 3), 2490 ) 2491 self.assertRaisesRegex( 2492 RuntimeError, 2493 r"shape '\[.*\]' is invalid for input of size [0-9]+", 2494 lambda: nt.reshape(4, 2, 3), 2495 ) 2496 # normal case 2497 x0 = torch.randn((2, 20), device=device, dtype=dtype) 2498 x1 = torch.randn((3, 20), device=device, dtype=dtype) 2499 nt = torch.nested.nested_tensor([x0, x1]) # (2, (2, 3), 20) 2500 pt = torch.nested.to_padded_tensor(nt, 0.0) 2501 # error case, trying to reshape batch dim to a legit shape 2502 self.assertRaisesRegex( 2503 RuntimeError, 2504 r"reshape: For now nested reshape cannot change or infer the implicit batch dimension", 2505 lambda: nt.transpose(-1, -2).reshape(40, -1), 2506 ) 2507 # inherit only the ragged dimension 2508 # (2, 20) -> (2, 5, 4) 2509 # (3, 20) -> (3, 5, 4) 2510 nt1 = nt.reshape(2, -1, 5, 4) 2511 # (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4) 2512 pt1 = pt.reshape(2, -1, 5, 4) 2513 self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1) 2514 2515 # more than one -1 (even for "old" dims), should fail 2516 # this attempts to do # (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2) 2517 # but we ban "inherit old behavior" for >1 dimension 2518 self.assertRaisesRegex( 2519 RuntimeError, 2520 r"only one dimension can be inferred", 2521 lambda: nt1.reshape(2, -1, -1, 2, 2), 2522 ) 2523 2524 def test_nested_masked_select(self, device): 2525 t = torch.randn([3, 3], device=device) 2526 mask = torch.tensor([False], device=device) 2527 2528 njt = torch.nested.masked_select(t, mask) 2529 self.assertEqual(njt.values(), torch.tensor([], device=device)) 2530 self.assertEqual(njt.offsets(), torch.tensor([0, 0, 0, 0], device=device)) 2531 
2532 mask = torch.tensor([[False], [False], [True]], device=device) 2533 njt = torch.nested.masked_select(t, mask) 2534 self.assertEqual(njt.values(), t[-1], atol=0.1, rtol=0.1) 2535 self.assertEqual(njt.offsets(), torch.tensor([0, 0, 0, 3], device=device)) 2536 2537 mask = torch.tensor( 2538 [[False, False, True], [True, False, True], [False, False, True]], 2539 device=device, 2540 ) 2541 njt = torch.nested.masked_select(t, mask) 2542 self.assertEqual(njt.values(), t.masked_select(mask)) 2543 self.assertEqual(njt.offsets(), torch.tensor([0, 1, 3, 4], device=device)) 2544 2545 t = torch.randn([2, 3, 3, 1], device=device) 2546 mask = torch.tensor( 2547 [ 2548 [ 2549 [[True], [False], [True]], 2550 [[True], [False], [True]], 2551 [[True], [False], [True]], 2552 ], 2553 [ 2554 [[False], [True], [True]], 2555 [[False], [True], [True]], 2556 [[True], [True], [True]], 2557 ], 2558 ], 2559 device=device, 2560 ) 2561 njt = torch.nested.masked_select(t, mask) 2562 self.assertEqual(njt.values(), t.masked_select(mask)) 2563 self.assertEqual( 2564 njt.offsets(), 2565 torch.tensor( 2566 [0, 1, 1, 2, 3, 3, 4, 5, 5, 6, 6, 7, 8, 8, 9, 10, 11, 12, 13], 2567 device=device, 2568 ), 2569 ) 2570 2571 @dtypes(torch.float, torch.float16, torch.double) 2572 def test_narrow(self, device, dtype): 2573 nt = random_nt_from_dims([5, None, None, None], device=device, dtype=dtype) 2574 2575 # narrow on dim=0 from start to end 2576 bounds = [(0, 5), (0, 3), (1, 2), (1, 5), (2, 4)] 2577 for start, end in bounds: 2578 length = end - start 2579 narrowed = nt.narrow(dim=0, start=start, length=length) 2580 # ensure output is a view 2581 self.assertTrue(narrowed._base is nt) 2582 for nc, c in zip(narrowed.unbind(), nt.unbind()[start:end]): 2583 self.assertEqual(nc, c) 2584 2585 # dim != 0 is not supported 2586 for dim in range(1, nt.dim()): 2587 with self.assertRaisesRegex( 2588 RuntimeError, "only dim=0 supported for nested tensors" 2589 ): 2590 nt.narrow(dim=dim, start=0, length=1) 2591 2592 # error case: non-contiguous NT 2593 _, nt_noncont = random_nt_noncontiguous_pair((2, 3, 4)) 2594 with self.assertRaisesRegex( 2595 RuntimeError, "only contiguous nested tensors supported" 2596 ): 2597 nt_noncont.narrow(dim=0, start=0, length=1) 2598 2599 @parametrize("input_dim", [3, 4]) 2600 def test_scaled_dot_product_attention(self, device, input_dim): 2601 def rand_tensor(*shape): 2602 return torch.randn(shape, device=device) 2603 2604 E = 8 2605 if input_dim == 3: 2606 # Shape: (N, L, E); ragged L 2607 query = torch.nested.nested_tensor( 2608 [rand_tensor(2, E), rand_tensor(3, E), rand_tensor(4, E)] 2609 ) 2610 2611 # Shape: (N, S, E); ragged S 2612 key = torch.nested.nested_tensor( 2613 [rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)] 2614 ) 2615 value = torch.nested.nested_tensor( 2616 [rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)] 2617 ) 2618 elif input_dim == 4: 2619 # In the 4D case the L and S is ragged 2620 # Shape: (N, N', L, E); ragged N' and L 2621 query = torch.nested.nested_tensor( 2622 [rand_tensor(2, 2, E), rand_tensor(3, 3, E), rand_tensor(4, 4, E)] 2623 ) 2624 # Shape: (N, N', S, E); ragged N' and S 2625 key = torch.nested.nested_tensor( 2626 [rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)] 2627 ) 2628 value = torch.nested.nested_tensor( 2629 [rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)] 2630 ) 2631 else: 2632 self.fail(f"Invalid input_dim {input_dim} encountered in SDP test") 2633 2634 def rand_mask(size): 2635 return torch.randint(0, 2, size=size, 
dtype=torch.bool, device=device) 2636 2637 # Shape: (N, L, S); ragged L and S matching above 2638 attn_mask = torch.nested.nested_tensor( 2639 [rand_mask((2, 3)), rand_mask((3, 4)), rand_mask((4, 5))] 2640 ) 2641 2642 dropout_p = 0.0 # no dropout for reproducibility 2643 2644 # Success case: no attn_mask set and is_causal=False. 2645 actual = torch.nn.functional.scaled_dot_product_attention( 2646 query, key, value, attn_mask=None, is_causal=False, dropout_p=dropout_p 2647 ) 2648 2649 expected_outputs = [] 2650 for q, k, v in zip(query.unbind(), key.unbind(), value.unbind()): 2651 output = torch.nn.functional.scaled_dot_product_attention( 2652 q.unsqueeze(0), 2653 k.unsqueeze(0), 2654 v.unsqueeze(0), 2655 attn_mask=None, 2656 dropout_p=dropout_p, 2657 ) 2658 expected_outputs.append(output.squeeze(0)) 2659 expected_output_nested = torch.nested.nested_tensor(expected_outputs) 2660 self.assertEqual(actual, expected_output_nested) 2661 2662 # Error case: explicit attn_mask set. 2663 with self.assertRaisesRegex( 2664 RuntimeError, "not supported when an explicit attn_mask is set" 2665 ): 2666 torch.nn.functional.scaled_dot_product_attention( 2667 query, key, value, attn_mask=attn_mask, dropout_p=dropout_p 2668 ) 2669 2670 # Error case: is_causal=True. 2671 with self.assertRaisesRegex(RuntimeError, "not supported when is_causal=True"): 2672 torch.nn.functional.scaled_dot_product_attention( 2673 query, key, value, dropout_p=dropout_p, is_causal=True 2674 ) 2675 2676 @dtypes(torch.float, torch.float16, torch.double) 2677 def test_empty_like(self, device, dtype): 2678 ntensors = 4 2679 nt = random_nt(device, dtype, ntensors, (4, 4)) 2680 2681 # Create empty on same device as original nested tensor 2682 nt_empty = torch.empty_like(nt) 2683 assert nt.is_same_size(nt_empty) 2684 self.assertEqual(nt.dtype, nt_empty.dtype) 2685 self.assertEqual(nt.device, nt_empty.device) 2686 self.assertEqual(nt.layout, nt_empty.layout) 2687 2688 if torch.cuda.is_available(): 2689 if device == "cpu": 2690 nt_cuda = torch.empty_like(nt, device="cuda") 2691 self.assertEqual(torch.device("cuda").type, nt_cuda.device.type) 2692 else: 2693 nt_cpu = torch.empty_like(nt, device="cpu") 2694 self.assertEqual(torch.device("cpu").type, nt_cpu.device.type) 2695 2696 # Check changing dtype of empty_like nested tensor output 2697 dtype_set = {torch.float, torch.float16, torch.double} 2698 for other_dtype in dtype_set - {dtype}: 2699 nt_empty_other_dtype = torch.empty_like(nt, dtype=other_dtype) 2700 self.assertEqual(nt.dtype, dtype) 2701 self.assertEqual(nt_empty_other_dtype.dtype, other_dtype) 2702 self.assertEqual(nt.device, nt_empty.device) 2703 self.assertEqual(nt.layout, nt_empty.layout) 2704 2705 # Create tensor for autograd 2706 nt_empty_req_grad = torch.empty_like(nt, requires_grad=True) 2707 self.assertEqual(nt_empty_req_grad.requires_grad, True) 2708 2709 # Test noncontiguous tensor does not fail to copy 2710 nt_cont, nt_noncont = random_nt_noncontiguous_pair((2, 3, 6, 7)) 2711 nt_empty = torch.empty_like(nt_cont) 2712 assert nt_cont.is_same_size(nt_empty) 2713 nt_empty_non_contig = torch.empty_like(nt_noncont) 2714 assert nt_noncont.is_same_size(nt_empty_non_contig) 2715 2716 # Test the contiguous memory format option 2717 nt_empty_contig = torch.empty_like( 2718 nt_cont, memory_format=torch.contiguous_format 2719 ) 2720 assert nt_cont.is_same_size(nt_empty_contig) 2721 assert nt_empty_contig.is_contiguous() 2722 2723 nt_empty_non_contig = torch.empty_like( 2724 nt_noncont, memory_format=torch.contiguous_format 2725 ) 2726 
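        # requesting contiguous_format should produce a contiguous NT even from a noncontiguous input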
        assert nt_noncont.is_same_size(nt_empty_non_contig)
        assert nt_empty_non_contig.is_contiguous()

        # Test other memory formats fail
        self.assertRaises(
            RuntimeError,
            lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last),
        )
        self.assertRaises(
            RuntimeError,
            lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last),
        )
        self.assertRaises(
            RuntimeError,
            lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last_3d),
        )
        self.assertRaises(
            RuntimeError,
            lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last_3d),
        )


@markDynamoStrictTest
class TestNestedTensorAutograd(NestedTensorTestCase):
    # Note [Gradcheck args check_batched_grad=False]: the common_utils testing version of gradcheck
    # includes the default parameters used for testing ops with gradcheck. However, nested tensors
    # do not support the stack op, so check_batched_grad is turned off for these tests.
    def _create_leaf_nested_tensor_from_list(self, tensor_device, requires_grad=False):
        return torch.nested.nested_tensor(
            [torch.randn(1, 2), torch.randn(7, 8)],
            requires_grad=requires_grad,
            device=tensor_device,
        )

    def _create_nested_tensor_from_list(self, tensor_device, requires_grad=False):
        return torch.nested.as_nested_tensor(
            [
                torch.randn(1, 2, requires_grad=requires_grad),
                torch.randn(7, 8, requires_grad=requires_grad),
            ],
            device=tensor_device,
        )

    def _create_nested_tensor_from_mask(self, tensor_device, requires_grad=False):
        data = torch.randn(2, 3, 4, requires_grad=requires_grad, device=tensor_device)
        mask = torch.ones_like(data[:, :, 0]).bool()
        return torch._nested_tensor_from_mask(data, mask)

    def test_as_nested_tensor_propagates_gradients(self, device):
        a = torch.arange(3, dtype=torch.float, device=device)
        b = torch.arange(5, dtype=torch.float, device=device)
        nt = torch.nested.as_nested_tensor([a, b])
        # tensors with requires_grad=False are leaves
        self.assertTrue(nt.is_leaf)
        self.assertTrue(not nt.requires_grad)

        a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device)
        b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device)
        nt2 = torch.nested.as_nested_tensor([a, b])
        fake_grad = torch.nested.nested_tensor(
            [torch.ones_like(a), torch.zeros_like(b)], device=device
        )
        nt2.backward(fake_grad)
        self.assertEqual(a.grad, fake_grad[0])
        self.assertEqual(b.grad, fake_grad[1])

    def test_nested_tensor_generates_leaf(self, device):
        a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device)
        b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device)

        nt = torch.nested.nested_tensor([a, b], requires_grad=False)
        self.assertTrue(nt.is_leaf)
        self.assertTrue(not nt.requires_grad)

        nt2 = torch.nested.nested_tensor([a, b], requires_grad=True)
        self.assertTrue(nt2.is_leaf)
        self.assertTrue(nt2.requires_grad)

        fake_grad = torch.nested.nested_tensor(
            [torch.ones_like(a), torch.zeros_like(b)], device=device
        )
        nt2.backward(fake_grad)
        self.assertEqual(nt2.grad, fake_grad)
        self.assertEqual(a.grad, None)
        self.assertEqual(b.grad, None)

    def test_set_requires_grad_from_list(self, device):
        nt = self._create_nested_tensor_from_list(device)
nt.requires_grad_() 2815 assert nt.requires_grad 2816 2817 def test_set_requires_grad_from_mask(self, device): 2818 nt = self._create_nested_tensor_from_mask(device) 2819 nt.requires_grad_() 2820 assert nt.requires_grad 2821 2822 def test_backward_for_add_op(self, device): 2823 nt_1 = self._create_nested_tensor_from_mask(device) 2824 nt_2 = self._create_nested_tensor_from_mask(device) 2825 2826 nt_1.requires_grad_() 2827 c = nt_1 + nt_2 2828 2829 assert nt_1.requires_grad 2830 assert c.requires_grad 2831 grad_output = self._create_nested_tensor_from_mask(device) 2832 c.backward(grad_output) 2833 2834 # Grad check doesn't work with nested yet. 2835 # d/dnt_1 (nt + nt_1) = 1*grad_output 2836 self.assertEqual(nt_1.grad, grad_output) 2837 2838 def test_backward_for_sub_op(self, device): 2839 nt_1 = self._create_nested_tensor_from_mask(device) 2840 nt_2 = self._create_nested_tensor_from_mask(device) 2841 2842 nt_1.requires_grad_() 2843 nt_2.requires_grad_() 2844 c = nt_1 - nt_2 2845 2846 assert nt_1.requires_grad 2847 assert nt_2.requires_grad 2848 assert c.requires_grad 2849 grad_output = self._create_nested_tensor_from_mask(device) 2850 c.backward(grad_output) 2851 2852 self.assertEqual(nt_1.grad, grad_output) 2853 self.assertEqual(nt_2.grad, -1 * grad_output) 2854 2855 def test_backward_sub_strided(self, device): 2856 a = torch.nested.nested_tensor( 2857 [torch.randn(9, 2, 4), torch.randn(12, 2, 4)], 2858 requires_grad=True, 2859 device=device, 2860 ) 2861 b = torch.nested.nested_tensor( 2862 [torch.randn(9, 4, 2), torch.randn(12, 4, 2)], 2863 requires_grad=True, 2864 device=device, 2865 ) 2866 c = a - b.transpose(-1, -2) 2867 grad_output = c.clone() 2868 c.backward(grad_output) 2869 self.assertEqual(a.grad, grad_output) 2870 self.assertEqual(b.grad, -1 * grad_output.transpose(-1, -2)) 2871 2872 def test_backward_add_strided(self, device): 2873 a = torch.nested.nested_tensor( 2874 [torch.randn(9, 2, 4), torch.randn(12, 2, 4)], 2875 requires_grad=True, 2876 device=device, 2877 ) 2878 b = torch.nested.nested_tensor( 2879 [torch.randn(9, 4, 2), torch.randn(12, 4, 2)], 2880 requires_grad=True, 2881 device=device, 2882 ) 2883 c = a + b.transpose(-1, -2) 2884 grad_output = c.clone() 2885 c.backward(grad_output) 2886 self.assertEqual(a.grad, grad_output) 2887 self.assertEqual(b.grad, grad_output.transpose(-1, -2)) 2888 2889 # Test Factory Functions 2890 def test_nested_tensor_to_padded_tensor(self, device): 2891 for padding_val in [0, 1]: 2892 nt = self._create_leaf_nested_tensor_from_list( 2893 tensor_device=device, requires_grad=True 2894 ) 2895 2896 out = torch.nested.to_padded_tensor(nt, padding_val) 2897 grad_output = torch.ones(out.shape, device=device) 2898 out.backward(grad_output) 2899 2900 self.assertEqual( 2901 nt.grad, 2902 torch.nested.nested_tensor( 2903 [torch.ones(1, 2), torch.ones(7, 8)], device=device 2904 ), 2905 ) 2906 2907 def test_nested_tensor_from_mask_and_to_padded(self, device): 2908 N, L, D = 2, 4, 4 2909 mask = torch.ones(N, L, device=device) 2910 for i in range(1, N): 2911 end = torch.randint(1, L - 1, (1,), device=device) 2912 mask[i, end:] = 0 2913 2914 mask[0, :] = 1 2915 mask = mask.bool() 2916 2917 data = torch.randn( 2918 N, L, D, requires_grad=True, dtype=torch.float64, device=device 2919 ) 2920 2921 def grad_test_func(inpt): 2922 nt = torch._nested_tensor_from_mask(inpt, mask) 2923 # This implicitly tests to_padded_tensor grads 2924 return torch.nested.to_padded_tensor(nt, 0) 2925 2926 assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) 2927 
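
    # Illustrative sketch of the padded round trip that the gradcheck tests in this class
    # rely on; the helper name and shapes below are arbitrary examples, not fixtures used
    # elsewhere, and only public APIs already exercised in this file are called.
    def _example_padded_round_trip(self, device):
        a = torch.randn(1, 2, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, dtype=torch.float64, device=device)
        nt = torch.nested.as_nested_tensor([a, b])
        # padding to the largest component size yields a dense (2, 2, 2) tensor
        padded = torch.nested.to_padded_tensor(nt, 0.0)
        assert padded.shape == (2, 2, 2)
        # unbinding recovers the original, unpadded components
        for component, unbound in zip([a, b], nt.unbind()):
            assert torch.equal(component, unbound)
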
    def test_nested_tensor_from_padded(self, device):
        nested_size = torch.tensor([[1, 2], [2, 2]])
        padded_tensor = torch.randn(2, 2, 2, dtype=torch.float64, device=device)
        padded_tensor[0, 1, :] = 0
        padded_tensor.requires_grad_()

        def grad_test_func(tensor, nested_size):
            nt = torch._nested_from_padded(
                tensor, nested_size, fuse_transform_0213=False
            )
            # This implicitly tests to_padded_tensor grads
            return torch.nested.to_padded_tensor(nt, 0)

        data = (padded_tensor, nested_size)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_nested_tensor_from_padded_fused(self, device):
        nested_size = torch.tensor([[1, 8], [2, 8]])
        padded_tensor = torch.randn(2, 2, 2, 4, dtype=torch.float64, device=device)
        padded_tensor[0, 1, :] = 0
        padded_tensor.requires_grad_()

        def grad_test_func(tensor, nested_size):
            nt = torch._nested_from_padded(
                tensor, nested_size, fuse_transform_0213=True
            )
            # This implicitly tests to_padded_tensor grads
            return torch.nested.to_padded_tensor(nt, 0)

        data = (padded_tensor, nested_size)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_nested_tensor_from_list(self, device):
        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(10, 2, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            c = torch.nested.as_nested_tensor([a, b, c])
            # This implicitly tests to_padded_tensor grads
            return torch.nested.to_padded_tensor(c, 0)

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
    def test_dropout_backward(self, layout):
        if layout == torch.jagged:
            nt = torch.nested.nested_tensor(
                [torch.randn((2, 5)), torch.randn((3, 5))],
                requires_grad=True,
                layout=layout,
            )
        else:
            nt = torch.nested.nested_tensor(
                [torch.randn((2, 5)), torch.randn((3, 4))],
                requires_grad=True,
                layout=layout,
            )
        p = 0.2
        y = torch.nn.functional.dropout(nt, p)
        y.backward(nt.clone().detach())
        self.assertEqual(nt.grad, y)

    def test_nested_tensor_bmm_gradcheck(self, device):
        a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device)
        d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c, d):
            nt0 = torch.nested.as_nested_tensor([a, b])
            nt1 = torch.nested.as_nested_tensor([c, d])
            result = nt0.bmm(nt1)
            return torch.nested.to_padded_tensor(result, 0.0)

        data = (a, b, c, d)
        assert torch.autograd.gradcheck(grad_test_func, inputs=data)

    def test_nested_tensor_bmm_backward(self, device):
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 6)), torch.randn((3, 6))],
            requires_grad=True,
            device=device,
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((6, 4)), torch.randn((6, 5))],
            requires_grad=True,
            device=device,
        )
        with torch.no_grad():
            pt0 =
torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True) 3020 pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True) 3021 3022 ynt = nt0.bmm(nt1) 3023 ypt = pt0.bmm(pt1) 3024 ynt.backward(ynt.clone()) 3025 ypt.backward(ypt.clone()) 3026 3027 self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad) 3028 self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad) 3029 3030 def test_nested_tensor_matmul_gradcheck(self, device): 3031 a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device) 3032 b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device) 3033 c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device) 3034 d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device) 3035 3036 def grad_test_func(a, b, c, d): 3037 nt0 = torch.nested.as_nested_tensor([a, b]) 3038 nt1 = torch.nested.as_nested_tensor([c, d]) 3039 result = torch.matmul(nt0, nt1) 3040 return torch.nested.to_padded_tensor(result, 0.0) 3041 3042 data = (a, b, c, d) 3043 assert torch.autograd.gradcheck(grad_test_func, inputs=data) 3044 3045 def test_nested_tensor_matmul_backward(self, device): 3046 nt0 = torch.nested.nested_tensor( 3047 [torch.randn((7, 2, 6)), torch.randn((7, 3, 6))], 3048 requires_grad=True, 3049 device=device, 3050 ) 3051 nt1 = torch.nested.nested_tensor( 3052 [torch.randn((7, 6, 4)), torch.randn((7, 6, 5))], 3053 requires_grad=True, 3054 device=device, 3055 ) 3056 with torch.no_grad(): 3057 pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True) 3058 pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True) 3059 3060 ynt = torch.matmul(nt0, nt1) 3061 ypt = torch.matmul(pt0, pt1) 3062 ynt.backward(ynt.clone()) 3063 ypt.backward(ypt.clone()) 3064 3065 self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad) 3066 self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad) 3067 3068 def test_nested_tensor_transpose_gradcheck(self, device): 3069 a = torch.randn(2, 5, requires_grad=True, device=device) 3070 b = torch.randn(3, 4, requires_grad=True, device=device) 3071 3072 def grad_test_func(a, b): 3073 nt = torch.nested.as_nested_tensor([a, b]) 3074 result = nt.transpose(-2, -1).transpose(-2, -1) 3075 return torch.nested.to_padded_tensor(result, 0.0) 3076 3077 data = (a, b) 3078 assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3) 3079 3080 def test_nested_tensor_transpose_backward(self, device): 3081 nt = torch.nested.nested_tensor( 3082 [torch.randn((2, 5)), torch.randn((3, 4))], 3083 requires_grad=True, 3084 device=device, 3085 ) 3086 with torch.no_grad(): 3087 pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) 3088 3089 ynt = nt.transpose(-2, -1) 3090 ypt = pt.transpose(-2, -1) 3091 ynt.backward(ynt.clone()) 3092 ypt.backward(ypt.clone()) 3093 3094 self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) 3095 3096 def test_nested_tensor_reshape_gradcheck(self, device): 3097 a = torch.randn(2, 6, requires_grad=True, device=device) 3098 b = torch.randn(3, 6, requires_grad=True, device=device) 3099 3100 def grad_test_func(a, b): 3101 nt = torch.nested.as_nested_tensor([a, b]) 3102 result = nt.reshape(2, -1, 2, 3) 3103 return torch.nested.to_padded_tensor(result, 0.0) 3104 3105 data = (a, b) 3106 assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3) 3107 3108 def test_nested_tensor_reshape_backward(self): 3109 nt = torch.nested.nested_tensor( 3110 [torch.randn((2, 6)), 
torch.randn((3, 6))], requires_grad=True 3111 ) 3112 with torch.no_grad(): 3113 pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) 3114 3115 ynt = nt.reshape(2, -1, 2, 3) 3116 ypt = pt.reshape(2, -1, 2, 3) 3117 ynt.backward(ynt.clone()) 3118 ypt.backward(ypt.clone()) 3119 3120 self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) 3121 3122 def test_nested_tensor_squeeze_backward(self, device): 3123 nt = torch.nested.nested_tensor( 3124 [torch.randn((2, 6, 1)), torch.randn((3, 6, 1))], 3125 requires_grad=True, 3126 device=device, 3127 ) 3128 with torch.no_grad(): 3129 pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) 3130 3131 ynt = nt.squeeze(-1) 3132 ypt = pt.squeeze(-1) 3133 ynt.backward(ynt.clone()) 3134 ypt.backward(ypt.clone()) 3135 3136 self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) 3137 3138 def test_nested_tensor_squeeze_gradcheck(self, device): 3139 a = torch.randn( 3140 (2, 6, 1), dtype=torch.float64, requires_grad=True, device=device 3141 ) 3142 b = torch.randn( 3143 (3, 6, 1), dtype=torch.float64, requires_grad=True, device=device 3144 ) 3145 3146 def grad_test_func(a, b): 3147 nt = torch.nested.as_nested_tensor([a, b]) 3148 result = nt.squeeze(-1) 3149 return torch.nested.to_padded_tensor(result, 0.0) 3150 3151 assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3) 3152 3153 def test_nested_tensor_unsqueeze_backward(self, device): 3154 nt = torch.nested.nested_tensor( 3155 [torch.randn((2, 6)), torch.randn((3, 6))], 3156 requires_grad=True, 3157 device=device, 3158 ) 3159 with torch.no_grad(): 3160 pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) 3161 3162 ynt = nt.unsqueeze(2) 3163 ypt = pt.unsqueeze(2) 3164 ynt.backward(ynt.clone()) 3165 ypt.backward(ypt.clone()) 3166 3167 self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) 3168 3169 def test_nested_tensor_unsqueeze_gradcheck(self, device): 3170 a = torch.randn((2, 6), dtype=torch.float64, requires_grad=True, device=device) 3171 b = torch.randn((3, 6), dtype=torch.float64, requires_grad=True, device=device) 3172 3173 def grad_test_func(a, b): 3174 nt = torch.nested.as_nested_tensor([a, b]) 3175 result = nt.unsqueeze(-1) 3176 return torch.nested.to_padded_tensor(result, 0.0) 3177 3178 assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3) 3179 3180 def test_nested_tensor_linear(self, device): 3181 a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device) 3182 b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) 3183 c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device) 3184 3185 weight = torch.randn( 3186 2, 2, requires_grad=True, dtype=torch.float64, device=device 3187 ) 3188 bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device) 3189 3190 def grad_test_func(a, b, c, weight, bias=None): 3191 nt = torch.nested.as_nested_tensor([a, b, c]) 3192 # This implicitly tests to_padded_tensor grads 3193 d = torch.functional.F.linear(nt, weight, bias) 3194 return torch.nested.to_padded_tensor(d, 0) 3195 3196 data = (a, b, c, weight, bias) 3197 assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) 3198 3199 # Test linear with no bias added 3200 data = (a, b, c, weight) 3201 assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) 3202 3203 def test_nested_tensor_linear_plus_transpose(self, device): 3204 a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, 
device=device) 3205 b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) 3206 c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device) 3207 3208 weight = torch.randn( 3209 2, 2, requires_grad=True, dtype=torch.float64, device=device 3210 ) 3211 bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device) 3212 3213 def grad_test_func(a, b, c, weight, bias=None): 3214 nt = torch.nested.as_nested_tensor([a, b, c]) 3215 # This implicitly tests to_padded_tensor grads 3216 d = torch.functional.F.linear(nt, weight, bias) 3217 d = d.transpose(-1, -2).contiguous() 3218 return torch.nested.to_padded_tensor(d, 0) 3219 3220 data = (a, b, c, weight, bias) 3221 assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) 3222 3223 # Test linear with no bias added 3224 data = (a, b, c, weight) 3225 assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) 3226 3227 def test_nested_tensor_softmax(self, device): 3228 a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device) 3229 b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) 3230 c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device) 3231 3232 def grad_test_func(a, b, c, dim): 3233 nt = torch.nested.as_nested_tensor([a, b, c]) 3234 # This implicitly tests to_padded_tensor grads 3235 d = torch.functional.F.softmax(nt, dim=dim) 3236 return torch.nested.to_padded_tensor(d, 0) 3237 3238 # softmax over last dim 3239 data = (a, b, c, -1) 3240 assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) 3241 3242 def test_nested_tensor_linear_backward(self, device): 3243 a = torch.randn(1, 2, requires_grad=False, device=device) 3244 b = torch.randn(2, 2, requires_grad=False, device=device) 3245 c = torch.randn(3, 2, requires_grad=False, device=device) 3246 3247 weight = torch.randn(2, 2, requires_grad=True, device=device) 3248 bias = torch.randn(2, requires_grad=True, device=device) 3249 nt = torch.nested.as_nested_tensor([a, b, c], device=device) 3250 3251 out = torch.functional.F.linear(nt, weight, bias) 3252 3253 out.backward(out.clone()) 3254 3255 assert weight.grad is not None 3256 assert bias.grad is not None 3257 3258 assert a.grad is None 3259 assert b.grad is None 3260 assert c.grad is None 3261 3262 def test_values_grad_with_broadcast(self, device): 3263 a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3264 b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3265 c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3266 3267 def grad_test_func(a, b, c): 3268 nt = torch.nested.as_nested_tensor([a, b, c]) 3269 buffer = nt.values() 3270 return buffer.sum() 3271 3272 data = (a, b, c) 3273 assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) 3274 3275 def test_to_buffer_series_ops_grad_with_broadcast(self, device): 3276 a = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device) 3277 b = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device) 3278 c = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device) 3279 3280 def grad_test_func(a, b, c): 3281 nt = torch.nested.as_nested_tensor([a, b, c]) 3282 buffer = nt.values() 3283 buffer = buffer * 2 3284 return buffer.exp() 3285 3286 data = (a, b, c) 3287 assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) 3288 3289 def test_unbind_flow_through(self, 
device): 3290 a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3291 b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3292 c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3293 3294 def grad_test_func(a, b, c): 3295 nt = torch.nested.as_nested_tensor([a, b, c]) 3296 ntT = nt.transpose(-1, -2) 3297 unbound = ntT.unbind() 3298 d = unbound[0] 3299 d = torch.pow(d, 2) 3300 return d 3301 3302 data = (a, b, c) 3303 assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) 3304 3305 def test_split_with_sizes_flow_through(self, device): 3306 a = torch.randn(2, 5, requires_grad=True, dtype=torch.float64, device=device) 3307 b = torch.randn(3, 5, requires_grad=True, dtype=torch.float64, device=device) 3308 c = torch.randn(4, 5, requires_grad=True, dtype=torch.float64, device=device) 3309 3310 def grad_test_func(a, b, c): 3311 nt = torch.nested.as_nested_tensor([a, b, c]) 3312 splits = nt.split_with_sizes([2, 3], dim=-1) 3313 unbound = splits[1].unbind() 3314 d = unbound[0] 3315 d = torch.pow(d, 2) 3316 return d 3317 3318 data = (a, b, c) 3319 assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) 3320 3321 def test_indexing_backward(self, device): 3322 x0 = torch.randn((2, 5)) 3323 x1 = torch.randn((3, 4)) 3324 nt = torch.nested.nested_tensor([x0, x1], device=device, requires_grad=True) 3325 self.assertEqual(nt[0], x0) 3326 self.assertEqual(nt[-1], x1) 3327 grad_x0 = torch.randn((2, 5), device=device) 3328 nt[0].backward(grad_x0) 3329 expected_grad = torch.nested.nested_tensor( 3330 [grad_x0, torch.zeros((3, 4), device=device)] 3331 ) 3332 self.assertEqual(nt.grad, expected_grad) 3333 3334 def test_masked_fill_backward(self, device): 3335 a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3336 b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3337 c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3338 3339 def grad_test_func(a, b, c): 3340 nt = torch.nested.as_nested_tensor([a, b, c]) 3341 mask = nt.detach().clone().to(bool) 3342 out = nt.masked_fill(mask, 0) 3343 out = torch.nested.to_padded_tensor(out, 0) 3344 return out 3345 3346 data = (a, b, c) 3347 assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) 3348 3349 def test_gelu_backward(self, device): 3350 a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3351 b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3352 c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3353 3354 def grad_test_func(a, b, c): 3355 nt = torch.nested.as_nested_tensor([a, b, c]) 3356 nt_gelu = torch.nn.functional.gelu(nt) 3357 return torch.nested.to_padded_tensor(nt_gelu, 0) 3358 3359 data = (a, b, c) 3360 assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) 3361 3362 def test_relu_backward(self, device): 3363 a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3364 b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3365 c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3366 3367 def grad_test_func(a, b, c): 3368 nt = torch.nested.as_nested_tensor([a, b, c]) 3369 nt_relu = torch.nn.functional.relu(nt) 3370 return torch.nested.to_padded_tensor(nt_relu, 0) 3371 3372 data = (a, b, c) 3373 assert gradcheck(grad_test_func, inputs=data, 
check_batched_grad=False) 3374 3375 def test_selu_backward(self, device): 3376 a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3377 b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3378 c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3379 3380 def grad_test_func(a, b, c): 3381 nt = torch.nested.as_nested_tensor([a, b, c]) 3382 nt_relu = torch.nn.functional.silu(nt) 3383 return torch.nested.to_padded_tensor(nt_relu, 0) 3384 3385 data = (a, b, c) 3386 assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) 3387 3388 def test_abs_backward(self, device): 3389 a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3390 b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3391 c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) 3392 3393 def grad_test_func(a, b, c): 3394 nt = torch.nested.as_nested_tensor([a, b, c]) 3395 nt_abs = torch.abs(nt) 3396 return torch.nested.to_padded_tensor(nt_abs, 0) 3397 3398 data = (a, b, c) 3399 assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) 3400 3401 # Previously would error when input NT doesn't require grad 3402 # NotImplementedError: Cannot access storage of UndefinedTensorImpl 3403 def test_layer_norm_backward_edge_case(self, device): 3404 size = 4 3405 a = torch.randn( 3406 1, 2, size, requires_grad=False, dtype=torch.float64, device=device 3407 ) 3408 nt = torch.nested.nested_tensor([a]) 3409 nt_layer_norm = torch.nn.LayerNorm( 3410 nt.size(-1), device=device, dtype=torch.float64 3411 ) 3412 out = nt_layer_norm(nt) 3413 out.backward(out.clone()) 3414 3415 def test_accumulate_grad_different_strides(self, device): 3416 a = torch.rand(1, 4, 2, requires_grad=True, dtype=torch.float64, device=device) 3417 b = torch.rand(1, 8, 2, requires_grad=True, dtype=torch.float64, device=device) 3418 3419 def grad_test_func(a, b): 3420 nt_1 = torch.nested.as_nested_tensor([a, b]) 3421 nt_2 = nt_1.clone() 3422 out = torch.nn.functional.scaled_dot_product_attention(nt_1, nt_2, nt_2) 3423 return torch.nested.to_padded_tensor(out, 0) 3424 3425 data = (a, b) 3426 assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) 3427 3428 # https://github.com/pytorch/pytorch/issues/95562 3429 @skipIfSlowGradcheckEnv 3430 @parametrize("size", [1024, 1023, 513, 512, 256, 128, 32, 4, 2]) 3431 def test_layer_norm_backward(self, device, size): 3432 a = torch.randn( 3433 1, 2, size, requires_grad=True, dtype=torch.float64, device=device 3434 ) 3435 b = torch.randn( 3436 2, 2, size, requires_grad=True, dtype=torch.float64, device=device 3437 ) 3438 c = torch.randn( 3439 3, 2, size, requires_grad=True, dtype=torch.float64, device=device 3440 ) 3441 3442 def grad_test_func(a, b, c): 3443 nt = torch.nested.as_nested_tensor([a, b, c]) 3444 layer_norm = torch.nn.LayerNorm( 3445 nt.size(-1), device=device, dtype=torch.float64 3446 ) 3447 nt_layer_norm = layer_norm(nt) 3448 return torch.nested.to_padded_tensor(nt_layer_norm, 0) 3449 3450 data = (a, b, c) 3451 assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) 3452 3453 # https://github.com/pytorch/pytorch/issues/95562 3454 @skipIfSlowGradcheckEnv 3455 # Could either mark slow or reduce size 3456 @parametrize("size", [128, 32, 4, 2]) 3457 def test_layer_norm_backward_5d(self, device, size): 3458 a = torch.randn( 3459 4, size, size, 4, requires_grad=True, dtype=torch.float64, device=device 3460 ) 3461 
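        # b and c below share a's trailing dims (size, size, 4); the components
        # differ only in their leading dim (7 and 10 vs. 4).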
b = torch.randn( 3462 7, size, size, 4, requires_grad=True, dtype=torch.float64, device=device 3463 ) 3464 c = torch.randn( 3465 10, size, size, 4, requires_grad=True, dtype=torch.float64, device=device 3466 ) 3467 3468 def grad_test_func(a, b, c): 3469 nt = torch.nested.as_nested_tensor([a, b, c]) 3470 layer_norm = torch.nn.LayerNorm( 3471 (size, size, nt.size(-1)), device=device, dtype=torch.float64 3472 ) 3473 nt_layer_norm = layer_norm(nt) 3474 return torch.nested.to_padded_tensor(nt_layer_norm, 0) 3475 3476 data = (a, b, c) 3477 assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) 3478 3479 3480# Found in torch/testing/_comparison.py 3481default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5} 3482default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6} 3483 3484 3485def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float: 3486 deviation = true_value - computed_value 3487 deviation = torch.abs(deviation / true_value) 3488 # Fill in the nans with the default rtol 3489 torch.nan_to_num_(deviation, nan=default_rtol[computed_value.dtype]) 3490 return deviation.max().item() 3491 3492 3493def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float: 3494 deviation = true_value - computed_value 3495 atol = torch.abs(deviation).max().item() 3496 return atol 3497 3498 3499def get_tolerances( 3500 true_value: torch.Tensor, 3501 computed_value: torch.Tensor, 3502 fudge_factor: Optional[float] = None, 3503) -> Tuple[float, float]: 3504 """Returns the absolute and relative tolerances for comparing two tensors.""" 3505 fudge_factor = fudge_factor if fudge_factor is not None else 1.0 3506 atol = get_atol(true_value, computed_value) 3507 rtol = get_rtol(true_value, computed_value) 3508 3509 atol = fudge_factor * max(atol, default_atol[computed_value.dtype]) 3510 rtol = fudge_factor * max(rtol, default_rtol[computed_value.dtype]) 3511 # torch.isclose() has weird behavior around see: 3512 # https://github.com/pytorch/pytorch/issues/102400 3513 if rtol > 1e30: 3514 rtol = default_rtol[computed_value.dtype] 3515 return atol, rtol 3516 3517 3518# We can probably parametrizing existing tests instead of having a separate 3519# test class as we begin to support more ops. Also maybe rewrite with OpInfos. 3520@markDynamoStrictTest 3521class TestNestedTensorSubclass(NestedTensorTestCase): 3522 # TODO: consolidate with the below 3523 def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True): 3524 Ds = nested_size[1:] 3525 out = [] 3526 for s in nested_size[0]: 3527 out.append( 3528 torch.randn( 3529 s, 3530 *Ds, 3531 requires_grad=requires_grad, 3532 device=device, 3533 dtype=torch.float64, 3534 ) 3535 ) 3536 return out 3537 3538 def _get_example_tensor_lists( 3539 self, 3540 include_list_of_lists=True, 3541 include_requires_grad=True, 3542 include_inner_dim_size_1=False, 3543 include_2d_tensor=False, 3544 ): 3545 def _make_tensor( 3546 *shape, include_requires_grad=include_requires_grad, requires_grad=True 3547 ): 3548 return torch.randn( 3549 *shape, 3550 requires_grad=(requires_grad if include_requires_grad else False), 3551 ) 3552 3553 # Purposefully introduce mixed requires_grad settings for the components 3554 # when include_requires_grad=True. 
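        # Note that _make_tensor defaults to requires_grad=True, so the explicit
        # requires_grad=False entries below yield lists whose components have mixed
        # grad requirements; with include_requires_grad=False, every component is
        # created with requires_grad=False regardless of the per-call argument.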
3555 example_lists = [ 3556 # (B, *, D) with B=4 3557 [ 3558 _make_tensor(2, 5), 3559 _make_tensor(3, 5, requires_grad=False), 3560 _make_tensor(4, 5, requires_grad=False), 3561 _make_tensor(6, 5), 3562 ], 3563 # (B, *, D_0, D_1) with B=5 3564 [ 3565 _make_tensor(2, 5, 6), 3566 _make_tensor(3, 5, 6), 3567 _make_tensor(4, 5, 6, requires_grad=False), 3568 _make_tensor(5, 5, 6), 3569 _make_tensor(6, 5, 6), 3570 ], 3571 # (B, *, D_0, D_1, D_2) with B=6 3572 [ 3573 _make_tensor(2, 5, 6, 7), 3574 _make_tensor(3, 5, 6, 7), 3575 _make_tensor(4, 5, 6, 7, requires_grad=False), 3576 _make_tensor(5, 5, 6, 7), 3577 _make_tensor(6, 5, 6, 7), 3578 _make_tensor(7, 5, 6, 7), 3579 ], 3580 ] 3581 3582 if include_list_of_lists: 3583 example_lists.append( 3584 # (B, *, D) with B=3 in list form 3585 [ 3586 _make_tensor(2, 5, requires_grad=False).tolist(), 3587 _make_tensor(3, 5).tolist(), 3588 _make_tensor(4, 5).tolist(), 3589 ] 3590 ) 3591 3592 if include_inner_dim_size_1: 3593 example_lists.append( 3594 [ 3595 _make_tensor(2, 1), 3596 _make_tensor(3, 1, requires_grad=False), 3597 _make_tensor(4, 1, requires_grad=False), 3598 _make_tensor(6, 1), 3599 ] # (B, *, 1) 3600 ) 3601 example_lists.append( 3602 [ 3603 _make_tensor(2, 5, 1), 3604 _make_tensor(3, 5, 1, requires_grad=False), 3605 _make_tensor(4, 5, 1, requires_grad=False), 3606 _make_tensor(6, 5, 1), 3607 ] # (B, *, 5, 1) 3608 ) 3609 3610 if include_2d_tensor: 3611 example_lists.append( 3612 [ 3613 _make_tensor(2), 3614 _make_tensor(3, requires_grad=False), 3615 _make_tensor(4, requires_grad=False), 3616 _make_tensor(6), 3617 ] # (B, *) 3618 ) 3619 3620 return example_lists 3621 3622 def test_tensor_attributes(self, device): 3623 a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) 3624 b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) 3625 c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) 3626 nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) 3627 _offsets = nt.offsets() 3628 3629 for op in ( 3630 torch.ops.aten.is_non_overlapping_and_dense.default, 3631 torch.ops.aten.sym_size.default, 3632 torch.ops.aten.dim.default, 3633 torch.ops.aten.numel.default, 3634 torch.ops.aten.sym_numel.default, 3635 torch.ops.aten.sym_stride.default, 3636 torch.ops.aten.sym_storage_offset.default, 3637 ): 3638 op(nt) 3639 3640 with self.assertRaisesRegex( 3641 RuntimeError, "directly calling torch.ops.aten.size" 3642 ): 3643 torch.ops.aten.size.default(nt) 3644 3645 nested_int = torch.nested._internal.nested_tensor.get_tensor_symint( 3646 _offsets, coeff=1 3647 ) 3648 self.assertEqual(nt.size(), (3, nested_int, 3)) 3649 self.assertEqual(nt.shape, (3, nested_int, 3)) 3650 self.assertEqual(nt.dim(), 3) 3651 self.assertEqual(nt.numel(), 27) 3652 3653 @parametrize("nt_dim", [3, 4, 5]) 3654 def test_linear(self, device, nt_dim): 3655 if nt_dim == 3: 3656 fixed_shape = (3,) 3657 elif nt_dim == 4: 3658 fixed_shape = (4, 3) 3659 elif nt_dim == 5: 3660 fixed_shape = (5, 4, 3) 3661 3662 a = torch.randn( 3663 2, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device 3664 ) 3665 b = torch.randn( 3666 3, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device 3667 ) 3668 c = torch.randn( 3669 4, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device 3670 ) 3671 weight = torch.randn( 3672 4, 3, requires_grad=True, dtype=torch.float64, device=device 3673 ) 3674 3675 def grad_test_func(a, b, c, weight): 3676 nt = torch.nested.as_nested_tensor([a, b, c], 
layout=torch.jagged) 3677 out = torch.nn.functional.linear(nt, weight) 3678 return out.values() 3679 3680 gradcheck(grad_test_func, inputs=(a, b, c, weight), check_batched_grad=False) 3681 3682 def test_unary_pointwise(self, device): 3683 a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) 3684 b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) 3685 c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) 3686 3687 def grad_test_func(a, b, c): 3688 nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) 3689 out = torch.nn.functional.silu(nt.sin().cos()) 3690 return out.values() 3691 3692 gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False) 3693 3694 def test_unary_pointwise_transposed_inputs(self, device): 3695 a, b, c = ( 3696 torch.randn( 3697 i + 2, 5, requires_grad=True, dtype=torch.float64, device=device 3698 ) 3699 for i in range(3) 3700 ) 3701 3702 nt = torch.nested.nested_tensor( 3703 [a.detach(), b.detach(), c.detach()], layout=torch.jagged 3704 ) 3705 nt_t = nt.transpose(1, 2) 3706 self.assertFalse(nt_t.is_contiguous()) 3707 out = torch.nn.functional.silu(nt_t.sin().cos()) 3708 self.assertEqual( 3709 out.is_contiguous(), 3710 torch.nn.functional.silu(b.transpose(-1, -2).sin().cos()).is_contiguous(), 3711 ) 3712 3713 self.assertEqual(nt_t.shape, out.shape) 3714 3715 a, b, c = ( 3716 torch.randn( 3717 i + 2, 5, requires_grad=True, dtype=torch.float64, device=device 3718 ) 3719 for i in range(3) 3720 ) 3721 3722 def grad_test_func(a, b, c): 3723 nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) 3724 nt_t = nt.transpose(1, 2) 3725 out = torch.nn.functional.silu(nt_t.sin().cos()) 3726 return out.values() 3727 3728 gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False) 3729 3730 def test_binary_pointwise(self, device): 3731 a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) 3732 b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) 3733 c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) 3734 3735 # Incorrect usage: shape check will fail if the offsets tensor are not 3736 # the same exact tensor object 3737 nt1 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) 3738 nt2 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) 3739 3740 self.assertRaisesRegex( 3741 RuntimeError, 3742 "cannot call binary pointwise function .* with inputs of shapes", 3743 lambda: nt1 * nt2, 3744 ) 3745 3746 # Correct usage: chain the calls using the same offsets tensor object 3747 def grad_test_func(a, b, c): 3748 nt1 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) 3749 # TODO: Switch to public API that takes in (values, offsets) once it exists 3750 nt2, offsets = jagged_from_list([a, b, c], nt1.offsets()) 3751 out = nt1 * nt2 3752 return out.values() 3753 3754 gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False) 3755 3756 def test_binary_pointwise_transposed(self, device): 3757 a, b, c = ( 3758 torch.randn(i + 2, 5, dtype=torch.float64, device=device) for i in range(3) 3759 ) 3760 3761 nt1, offsets = jagged_from_list([a, b, c], None) 3762 nt2, offsets = jagged_from_list([a, b, c], offsets) 3763 3764 nt1_t = nt1.transpose(1, 2) 3765 nt2_t = nt2.transpose(1, 2) 3766 3767 # out = nt1_t * nt2_t 3768 # self.assertFalse(nt1_t.is_contiguous()) 3769 # self.assertEqual(out.is_contiguous(), (b.transpose(-1, -2) * b.transpose(-1, -2)).is_contiguous()) 
3770 # self.assertEqual(out.shape, nt1_t.shape) 3771 3772 self.assertRaisesRegex( 3773 RuntimeError, 3774 "cannot call binary pointwise function mul.Tensor with inputs of shapes", 3775 lambda: nt1 * nt2_t, 3776 ) 3777 3778 a, b, c = ( 3779 torch.randn( 3780 i + 2, 5, requires_grad=True, dtype=torch.float64, device=device 3781 ) 3782 for i in range(3) 3783 ) 3784 3785 # Correct usage: chain the calls using the same offsets tensor object 3786 def grad_test_func(a, b, c): 3787 nt1, offsets = jagged_from_list([a, b, c], None) 3788 nt2, offsets = jagged_from_list([a, b, c], offsets) 3789 nt1_t = nt1.transpose(1, 2) 3790 nt2_t = nt2.transpose(1, 2) 3791 out = nt1_t * nt2_t 3792 return out.values() 3793 3794 gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False) 3795 3796 def test_split(self, device): 3797 a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) 3798 b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) 3799 c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) 3800 3801 nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) 3802 out = torch.split(nt, 2, -1) 3803 self.assertEqual(len(out), 2) 3804 self.assertEqualIgnoringNestedInts( 3805 out[0], 3806 torch.nested.as_nested_tensor( 3807 [a[:, 0:2], b[:, 0:2], c[:, 0:2]], layout=torch.jagged 3808 ), 3809 ) 3810 self.assertEqualIgnoringNestedInts( 3811 out[1], 3812 torch.nested.as_nested_tensor( 3813 [a[:, 2:], b[:, 2:], c[:, 2:]], layout=torch.jagged 3814 ), 3815 ) 3816 3817 with self.assertRaisesRegex( 3818 RuntimeError, 3819 r"split\(\): not supported for NestedTensor on dim=1", 3820 ): 3821 torch.split(nt, 2, 1) 3822 3823 def test_split_with_sizes(self, device): 3824 a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) 3825 b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) 3826 c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) 3827 3828 nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) 3829 out = torch.split(nt, [1, 2], -1) 3830 self.assertEqual(len(out), 2) 3831 self.assertEqualIgnoringNestedInts( 3832 out[0], 3833 torch.nested.as_nested_tensor( 3834 [a[:, 0:1], b[:, 0:1], c[:, 0:1]], layout=torch.jagged 3835 ), 3836 ) 3837 self.assertEqualIgnoringNestedInts( 3838 out[1], 3839 torch.nested.as_nested_tensor( 3840 [a[:, 1:], b[:, 1:], c[:, 1:]], layout=torch.jagged 3841 ), 3842 ) 3843 with self.assertRaisesRegex( 3844 RuntimeError, 3845 r"split_with_sizes\(\): not supported for NestedTensor on dim=1", 3846 ): 3847 torch.split(nt, [1, 2], 1) 3848 3849 def test_softmax(self, device): 3850 nt = random_nt_from_dims( 3851 [3, None, 5], 3852 device=device, 3853 dtype=torch.float32, 3854 layout=torch.jagged, 3855 requires_grad=True, 3856 ) 3857 3858 # operate on dim=2 3859 output = nt.softmax(dim=2) 3860 3861 @torch._dynamo.disable 3862 def _compare_to_ref(nt, output, dim): 3863 for in_component, out_component in zip(nt.unbind(), output.unbind()): 3864 self.assertEqual(in_component.softmax(dim=dim), out_component) 3865 3866 # dim=2 -> dim=1 after unbind 3867 _compare_to_ref(nt, output, dim=1) 3868 3869 # operate on dim=-1 3870 output2 = nt.softmax(dim=-1) 3871 torch._dynamo.disable(self.assertEqual)(output, output2) 3872 _compare_to_ref(nt, output2, dim=-1) 3873 3874 def grad_test_func(a, b): 3875 nt = torch.nested.as_nested_tensor([a, b], layout=torch.jagged) 3876 out = nt.softmax(dim=-1) 3877 return out.values() 3878 3879 a = torch.rand(4, 5, 
requires_grad=True, dtype=torch.float64, device=device) 3880 b = torch.rand(8, 5, requires_grad=True, dtype=torch.float64, device=device) 3881 gradcheck(grad_test_func, inputs=(a, b), check_batched_grad=False) 3882 3883 def test_views_inherit_ragged_dim(self, device): 3884 # view 3885 nt = random_nt_from_dims( 3886 [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged 3887 ) 3888 # inherit ragged dim via -1 3889 view = nt.view(4, -1, 80) 3890 self.assertEqual(nt.shape[1], view.shape[1]) 3891 # inherit batch and ragged dims via -1 3892 view2 = nt.view(-1, -1, 80) 3893 self.assertEqual(nt.shape[:2], view2.shape[:2]) 3894 3895 # expand 3896 nt = random_nt_from_dims( 3897 [3, None, 1], device=device, dtype=torch.float32, layout=torch.jagged 3898 ) 3899 # inherit batch and ragged dims via -1 3900 view = nt.expand(-1, -1, 5) 3901 self.assertEqual(nt.shape[:2], view.shape[:2]) 3902 3903 def test_view_ragged_idx_not_one(self, device): 3904 nt = random_nt_from_dims( 3905 [2, None, 20], device=device, dtype=torch.float32, layout=torch.jagged 3906 ) 3907 3908 view_transposed = nt.transpose(1, 2).view(2, 20, nt.size(1)) 3909 self.assertEqual((2, 20, nt.size(1)), (view_transposed.size())) 3910 self.assertEqual(view_transposed._base, nt._base) 3911 3912 def test_unsafe_view(self, device): 3913 nt = random_nt_from_dims( 3914 [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged 3915 ) 3916 # basic view 3917 view1 = torch.ops.aten._unsafe_view(nt, (4, -1, 80)) 3918 self.assertEqual((4, nt.size(1), 80), tuple(view1.size())) 3919 # _unsafe_view differs from view in that the view information is not tracked 3920 self.assertTrue(view1._base is None) 3921 3922 # test an unsafe_view when ragged_idx != 1, currently only supports identity view 3923 nt_t = nt.transpose(1, 2) 3924 view2 = torch.ops.aten._unsafe_view(nt_t, (4, 8, nt.size(1), 10)) 3925 self.assertEqual((4, 8, nt.size(1), 10), tuple(view2.size())) 3926 self.assertTrue(view2._base is None) 3927 3928 @xfailIfTorchDynamo 3929 @parametrize("requires_grad", [False, True]) 3930 def test_reshape_decomp(self, device, requires_grad): 3931 # contiguous NT should result in view. 
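        # e.g. a contiguous jagged NT of shape (3, j1, 10) reshaped with (-1, -1, 5, 2)
        # keeps the batch and ragged dims and splits the trailing dim into (5, 2),
        # so the result aliases the original storage (view._base is nt) rather than copying.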
3932 nt = ( 3933 random_nt_from_dims( 3934 [3, None, 10], 3935 device=device, 3936 dtype=torch.float32, 3937 layout=torch.jagged, 3938 ) 3939 .detach() 3940 .requires_grad_(requires_grad) 3941 ) 3942 view = nt.reshape(-1, -1, 5, 2) 3943 self.assertEqual(view.shape[:2], nt.shape[:2]) 3944 self.assertTrue(view._is_view() and view._base is nt) 3945 # make sure gradients flow back 3946 if requires_grad: 3947 view.backward(torch.ones_like(view)) 3948 self.assertEqual(nt.grad, torch.ones_like(nt)) 3949 3950 # non-contiguous NT should result in contiguous copy 3951 nt = random_nt_from_dims( 3952 [3, None, 5, 2], 3953 device=device, 3954 dtype=torch.float32, 3955 layout=torch.jagged, 3956 requires_grad=requires_grad, 3957 ) 3958 nt_noncontig = nt.transpose(-1, -2) 3959 self.assertFalse(nt_noncontig.is_contiguous()) 3960 copy = nt_noncontig.reshape(-1, -1, 10) 3961 self.assertTrue(copy.is_contiguous()) 3962 self.assertEqual(copy.shape[:2], nt.shape[:2]) 3963 # make sure gradients flow back 3964 if requires_grad: 3965 copy.backward(torch.ones_like(copy)) 3966 self.assertEqual(nt.grad, torch.ones_like(nt)) 3967 3968 def test_flatten_decomp(self, device): 3969 nt = random_nt_from_dims( 3970 [3, None, 5, 2], device=device, dtype=torch.float32, layout=torch.jagged 3971 ) 3972 flattened = nt.flatten(-2, -1) 3973 self.assertEqual(flattened.shape, nt.view(3, -1, 10).shape) 3974 3975 nt = random_nt_from_dims( 3976 [3, None, 5, 2, 6], device=device, dtype=torch.float32, layout=torch.jagged 3977 ) 3978 flattened = nt.flatten(-3, -2) 3979 self.assertEqual(flattened.shape, nt.view(3, -1, 10, 6).shape) 3980 3981 def test_chunk(self, device): 3982 # none NJT case 3983 t = torch.randn(10, 4, 5, requires_grad=True) 3984 t_list = t.chunk(3, dim=0) 3985 loss = t_list[0].sum() + t_list[2].sum() 3986 loss.backward() 3987 3988 # normal case 3989 D = 30 3990 B = 8 3991 nt = random_nt_from_dims( 3992 [B, None, D], 3993 device=device, 3994 dtype=torch.float32, 3995 layout=torch.jagged, 3996 requires_grad=True, 3997 ) 3998 NUM_CHUNKS = 3 3999 chunks = nt.chunk(NUM_CHUNKS, dim=-1) 4000 self.assertEqual(len(chunks), NUM_CHUNKS) 4001 for i in range(NUM_CHUNKS): 4002 self.assertEqual(chunks[i].shape[-1], D // NUM_CHUNKS) 4003 4004 # test chunk_backward 4005 values = torch.randn( 4006 5, 11, dtype=torch.float64, device=device, requires_grad=True 4007 ) 4008 offsets = torch.tensor([0, 2, 3, 5], device=device) 4009 4010 def grad_test_func(values, offsets): 4011 nt = torch.nested.nested_tensor_from_jagged(values, offsets) 4012 chunks = nt.chunk(3, dim=-1) 4013 return chunks[0].values().sum() 4014 4015 assert gradcheck( 4016 grad_test_func, 4017 inputs=(values, offsets), 4018 check_batched_grad=False, 4019 ) 4020 4021 # chunk on batch dim 4022 chunks = nt.chunk(NUM_CHUNKS, dim=0) 4023 self.assertEqual(len(chunks), NUM_CHUNKS) 4024 chunk_size = math.ceil(B / NUM_CHUNKS) 4025 for i in range(NUM_CHUNKS): 4026 if i < NUM_CHUNKS - 1: 4027 self.assertEqual(chunks[i].shape[0], chunk_size) 4028 else: 4029 self.assertEqual(chunks[i].shape[0], B - chunk_size * (NUM_CHUNKS - 1)) 4030 offsets_expected = ( 4031 nt._offsets[i * chunk_size + 1 : (i + 1) * chunk_size + 1] 4032 - nt._offsets[i * chunk_size] 4033 ) 4034 self.assertEqual(chunks[i]._offsets[1:], offsets_expected) 4035 self.assertEqual(nt._values, torch.cat([x._values for x in chunks], dim=0)) 4036 4037 with self.assertRaisesRegex( 4038 RuntimeError, 4039 "dim != 0 INTERNAL ASSERT FAILED .* Nested Tensor doesn't support chunk backward on dim=0 yet.", 4040 ): 4041 # doesn't support 
backward for chunk (dim=0) yet 4042 loss = ( 4043 chunks[0].values().sum() 4044 + chunks[1].values().sum() 4045 + chunks[2].values().sum() 4046 ) 4047 loss.backward() 4048 4049 # chunk on ragged dim not supported 4050 with self.assertRaisesRegex( 4051 RuntimeError, "chunk.* not supported for NestedTensor on dim=1" 4052 ): 4053 nt.chunk(2, dim=1) 4054 4055 def test_squeeze(self, device): 4056 B = 4 4057 D = 6 4058 # squeeze middle dim 4059 nt = random_nt_from_dims( 4060 [B, None, 1, D], device=device, dtype=torch.float32, layout=torch.jagged 4061 ) 4062 j0 = nt.shape[1] 4063 4064 for dim_arg in [-2, 2]: 4065 out = nt.squeeze(dim_arg) 4066 self.assertEqual(out.shape, (B, j0, D)) 4067 self.assertEqual(out.unsqueeze(-2), nt) 4068 4069 # squeeze last dim 4070 nt = random_nt_from_dims( 4071 [B, None, 1], device=device, dtype=torch.float32, layout=torch.jagged 4072 ) 4073 j1 = nt.shape[1] 4074 4075 for dim_arg in [-1, 2]: 4076 out = nt.squeeze(dim_arg) 4077 self.assertEqual(out.shape, (B, j1)) 4078 self.assertEqual(out.unsqueeze(-1), nt) 4079 4080 # squeeze on batch dim not supported 4081 with self.assertRaisesRegex( 4082 RuntimeError, "squeeze.* not supported for NestedTensor on dim=0" 4083 ): 4084 nt.squeeze(0) 4085 4086 # squeeze on ragged dim not supported 4087 with self.assertRaisesRegex( 4088 RuntimeError, "squeeze.* not supported for NestedTensor on dim=1" 4089 ): 4090 nt.squeeze(1) 4091 4092 def test_binary_pointwise_broadcasting(self, device): 4093 # (B, j0, 3, 4) 4094 ts = self._get_list_for_jagged_tensor( 4095 ((2, 3, 4), 3, 4), device, requires_grad=True 4096 ) 4097 # (B, j0, ?, ?) + (?) -> (B, j0, ?, ?) 4098 # (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?) 4099 # (B, j0, ?, ?) + (1, ?, ?) -> (B, j0, ?, ?) 4100 # Unsupported: (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?) 4101 t_sizes = ( 4102 (4,), 4103 (1, 4), 4104 (3, 1), 4105 (1, 3, 1), 4106 (1, 1, 1, 4), 4107 # (1, 1, 1, 1, 4), (unsupported today) 4108 ) 4109 4110 def grad_test_func(t, *ts): 4111 nt = torch.nested.as_nested_tensor(list(ts), layout=torch.jagged) 4112 out = nt + t 4113 return out.values() 4114 4115 for t_size in t_sizes: 4116 t = torch.rand( 4117 t_size, requires_grad=True, device=device, dtype=torch.float64 4118 ) 4119 gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False) 4120 4121 def test_threshold_backward(self, device): 4122 ts1 = self._get_list_for_jagged_tensor( 4123 ((2, 3, 4), 16), device=device, requires_grad=False 4124 ) 4125 ts2 = self._get_list_for_jagged_tensor( 4126 ((2, 3, 4), 16), device=device, requires_grad=False 4127 ) 4128 4129 nt1, offsets = jagged_from_list(ts1, None) 4130 nt2, offsets = jagged_from_list(ts2, offsets) 4131 buf1 = nt1.values().detach().clone() 4132 buf2 = nt2.values().detach().clone() 4133 4134 res_nt = torch.ops.aten.threshold_backward(nt1, nt2, 0.0) 4135 res_dense = torch.ops.aten.threshold_backward(buf1, buf2, 0.0) 4136 4137 self.assertEqual(res_dense, res_nt.values()) 4138 4139 @dtypes(torch.float32) 4140 @parametrize( 4141 "func", 4142 [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim], 4143 name_fn=get_op_name, 4144 ) 4145 @parametrize("keepdim", [False, True]) 4146 @parametrize("requires_grad", [False, True]) 4147 @parametrize("components_require_grad", [False, True]) 4148 def test_jagged_op_different_output_shape_dim( 4149 self, device, dtype, keepdim, requires_grad, components_require_grad, func 4150 ): 4151 """ 4152 Operator passes when reducing on valid reduction dimensions. 
4153 This test is for operators which return an output tensor with a shape different from the input tensor.
4154 """
4155 if get_op_name(func) == "mean" and not keepdim:
4156 return
4157
4158 op_name = get_op_name(func)
4159
4160 ts = self._get_list_for_jagged_tensor(
4161 ((2, 3, 4), 3, 4), device=device, requires_grad=True
4162 ) # (B, j0, 3, 4)
4163
4164 # verify correctness of shapes (assuming that ragged_idx == 1)
4165 if op_name == "sum":
4166 reduce_dims = (
4167 ((0, 1), (3, 4), (1, 1, 3, 4), (0,)), # batch, ragged
4168 ((2, 3), (3, None), (3, None, 1, 1), (1, 2)), # non-batch, non-batch
4169 ((0, 1, 3), (3,), (1, 1, 3, 1), (0, 2)), # batch, ragged, non-batch
4170 ((0, 1, 2), (4,), (1, 1, 1, 4), (0, 1)), # batch, ragged, non-batch
4171 (
4172 (0, 1, 2, 3),
4173 (),
4174 (1, 1, 1, 1),
4175 (0, 1, 2),
4176 ), # batch, ragged, non-batch, non-batch
4177 ((2,), (3, None, 4), (3, None, 1, 4), (1,)), # non-batch
4178 ) # (dims, expected shape, expected keepdim shape, reduce_dim_expected), where j0 is represented as None
4179 elif op_name == "mean":
4180 reduce_dims = (
4181 ((2,), (3, None, 4), (3, None, 1, 4), (1,)),
4182 ((3,), (3, None, 3), (3, None, 3, 1), (2,)),
4183 )
4184
4185 for rd, ref_shape_no_keepdim, ref_shape_keepdim, _ in reduce_dims:
4186 nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
4187 out = func(nt, dim=rd, keepdim=keepdim)
4188 ref_shape = ref_shape_keepdim if keepdim else ref_shape_no_keepdim
4189 if not torch.compiler.is_compiling(): # if not using torch dynamo
4190 self.assertEqual(len(out.shape), len(ref_shape))
4191 for o, r in zip(out.shape, ref_shape):
4192 if r is not None:
4193 self.assertEqual(o, r)
4194 else:
4195 self.assertTrue(isinstance(o, torch.SymInt))
4196
4197 # verify correctness of values
4198 tensor_lists = self._get_example_tensor_lists(
4199 include_list_of_lists=False,
4200 include_requires_grad=components_require_grad,
4201 include_inner_dim_size_1=True,
4202 )
4203 for tensor_list, reduce_dim_tuple in itertools.product(
4204 tensor_lists, reduce_dims
4205 ):
4206 nt = torch.nested.nested_tensor(
4207 tensor_list,
4208 device=device,
4209 dtype=dtype,
4210 layout=torch.jagged,
4211 requires_grad=requires_grad,
4212 )
4213
4214 reduce_dim, _, _, reduce_dim_expected = reduce_dim_tuple
4215
4216 if nt.dim() > reduce_dim[-1]:
4217 out_actual = func(nt, dim=reduce_dim, keepdim=keepdim)
4218 if nt._ragged_idx in reduce_dim: # raggedness reduced away
4219 out_expected = func(
4220 nt.values(), dim=reduce_dim_expected, keepdim=keepdim
4221 )
4222 self.assertTrue(torch.allclose(out_actual, out_expected))
4223 else: # raggedness preserved
4224 out_expected = func(nt.values(), dim=reduce_dim_expected)
4225 self.assertTrue(
4226 torch.allclose(
4227 out_actual.values().view(-1), out_expected.view(-1)
4228 )
4229 )
4230
4231 @dtypes(torch.float32)
4232 @parametrize("requires_grad", [False, True])
4233 @parametrize("components_require_grad", [False, True])
4234 def test_softmax_dim(
4235 self,
4236 device,
4237 dtype,
4238 requires_grad,
4239 components_require_grad,
4240 ):
4241 """
4242 Softmax passes when reducing on valid reduction dimensions.
4243 """ 4244 ts = self._get_list_for_jagged_tensor( 4245 ((2, 3, 4), 3, 4), device=device, requires_grad=True 4246 ) # (B, j0, 3, 4) 4247 4248 output_shape = (3, None, 3, 4) 4249 4250 # verify correctness of shapes (assuming that ragged_idx == 1) 4251 reduce_dims = ( 4252 (2, 1), 4253 (3, 2), 4254 ) # (reduction dimension, effective reduction dimension for baseline) 4255 4256 for reduce_dim, _ in reduce_dims: 4257 nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged) 4258 out_actual = torch.nn.functional.softmax(nt, dim=reduce_dim) 4259 torch._dynamo.disable(self.assertEqual)( 4260 len(out_actual.shape), len(output_shape) 4261 ) # disable if running on dynamo 4262 for dim_actual, dim_expected in zip(out_actual.shape, output_shape): 4263 if dim_expected is not None: 4264 self.assertEqual(dim_actual, dim_expected) 4265 else: 4266 self.assertTrue(isinstance(dim_actual, torch.SymInt)) 4267 4268 # verify correctness of values 4269 tensor_lists = self._get_example_tensor_lists( 4270 include_list_of_lists=False, 4271 include_requires_grad=components_require_grad, 4272 include_inner_dim_size_1=True, 4273 ) 4274 for tensor_list, reduce_dim_tuple in itertools.product( 4275 tensor_lists, reduce_dims 4276 ): 4277 nt = torch.nested.nested_tensor( 4278 tensor_list, 4279 device=device, 4280 dtype=dtype, 4281 layout=torch.jagged, 4282 requires_grad=requires_grad, 4283 ) 4284 4285 reduce_dim, reduce_dim_expected = reduce_dim_tuple 4286 4287 if nt.dim() > reduce_dim: 4288 out_actual = torch.nn.functional.softmax( 4289 nt, dim=reduce_dim 4290 ) # nested tensor 4291 out_expected = torch.nn.functional.softmax( 4292 nt.values(), dim=reduce_dim_expected 4293 ) # dense tensor of dimensions 1 less than out_actual 4294 self.assertTrue( 4295 torch.allclose(out_actual.values().view(-1), out_expected.view(-1)) 4296 ) 4297 4298 @dtypes(torch.float32) 4299 @parametrize( 4300 "func", 4301 [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim], 4302 name_fn=get_op_name, 4303 ) 4304 @parametrize("keepdim", [False, True]) 4305 @parametrize("requires_grad", [False, True]) 4306 @parametrize("components_require_grad", [False, True]) 4307 def test_op_dim_reduce_ragged_idx_1_different_output_shape( 4308 self, device, dtype, keepdim, requires_grad, components_require_grad, func 4309 ): 4310 """ 4311 Operator on NestedTensor passes when trying to reduce across ragged dimension, where ragged_idx == 1. 4312 This test is for operators which return an output tensor with a shape different from the input tensor. 
4313 """ 4314 if get_op_name(func) == "mean" and not keepdim: 4315 return 4316 4317 op_name = get_op_name(func) 4318 4319 tensor_lists = self._get_example_tensor_lists( 4320 include_list_of_lists=False, 4321 include_requires_grad=components_require_grad, 4322 include_inner_dim_size_1=True, # (B, *, 1) 4323 ) 4324 reduce_dim = (1,) # ragged 4325 4326 for tensor_list in tensor_lists: 4327 nt = torch.nested.nested_tensor( 4328 tensor_list, 4329 device=device, 4330 dtype=dtype, 4331 layout=torch.jagged, 4332 requires_grad=requires_grad, 4333 ) 4334 4335 out_actual = func(nt, dim=reduce_dim, keepdim=keepdim) 4336 out_expected = torch.cat( 4337 [func(t, dim=(reduce_dim[0] - 1)).unsqueeze(0) for t in nt.unbind()] 4338 ) 4339 4340 self.assertFalse( 4341 out_actual.is_nested, 4342 f"{op_name}(): the result of reducing a nested tensor along the ragged dimension is a dense tensor", 4343 ) # output is a dense tensor 4344 self.assertTrue(torch.allclose(out_actual, out_expected)) 4345 4346 @dtypes(torch.float32) 4347 @parametrize("requires_grad", [False, True]) 4348 @parametrize("components_require_grad", [False, True]) 4349 def test_softmax_dim_reduce_ragged_idx_1( 4350 self, device, dtype, requires_grad, components_require_grad 4351 ): 4352 """ 4353 Softmax on NestedTensor passes when trying to reduce across ragged dimension, where ragged_idx == 1. 4354 """ 4355 tensor_lists = self._get_example_tensor_lists( 4356 include_list_of_lists=False, 4357 include_requires_grad=components_require_grad, 4358 include_inner_dim_size_1=True, # (B, *, 1) 4359 include_2d_tensor=True, # (B, *) 4360 ) 4361 reduce_dim = 1 # ragged 4362 4363 for tensor_list in tensor_lists: 4364 nt = torch.nested.nested_tensor( 4365 tensor_list, 4366 device=device, 4367 dtype=dtype, 4368 layout=torch.jagged, 4369 requires_grad=requires_grad, 4370 ) 4371 4372 out_actual = torch.nn.functional.softmax(nt, dim=reduce_dim) 4373 out_expected = torch.cat( 4374 [ 4375 torch.nn.functional.softmax(t, dim=reduce_dim - 1) 4376 for t in nt.unbind() 4377 ] 4378 ) 4379 4380 self.assertTrue( 4381 out_actual.is_nested, 4382 "softmax(): the result of reducing a nested tensor along the ragged dimension is a nested tensor", 4383 ) # output is a nested tensor 4384 self.assertTrue(torch.allclose(out_actual.values(), out_expected)) 4385 4386 @dtypes(torch.float32) 4387 @parametrize("requires_grad", [False, True]) 4388 @parametrize("components_require_grad", [False, True]) 4389 def test_softmax_reduce_batch_dim( 4390 self, device, dtype, requires_grad, components_require_grad 4391 ): 4392 """ 4393 Softmax on NestedTensor fails when trying to reduce across batch dimension. 
4394 """ 4395 tensor_lists = self._get_example_tensor_lists( 4396 include_list_of_lists=False, 4397 include_requires_grad=components_require_grad, 4398 include_inner_dim_size_1=True, # (B, *, 1) 4399 ) 4400 reduce_dim = 0 # batch 4401 4402 for tensor_list in tensor_lists: 4403 nt = torch.nested.nested_tensor( 4404 tensor_list, 4405 device=device, 4406 dtype=dtype, 4407 layout=torch.jagged, 4408 requires_grad=requires_grad, 4409 ) 4410 4411 with self.assertRaisesRegex( 4412 RuntimeError, 4413 "not supported when reducing across the batch dimension for NestedTensor", 4414 ): 4415 out = torch.nn.functional.softmax(nt, dim=reduce_dim) 4416 4417 @dtypes(torch.float32) 4418 @parametrize("requires_grad", [False, True]) 4419 @parametrize("components_require_grad", [False, True]) 4420 def test_layer_norm_reduce_ragged_idx_1( 4421 self, device, dtype, requires_grad, components_require_grad 4422 ): 4423 """ 4424 Layer normalization on NestedTensor passes when trying to normalize across ragged dimension, where ragged_idx == 1. 4425 """ 4426 4427 # requires_grad = False does not currently work with dynamo tests and throws this error: 4428 # AssertionError: SymInts must use SymNodeVariable. 4429 # If the underlying value is static, we will create a ConstantVariable and specialize. 4430 if torch._dynamo.is_compiling() and not requires_grad: 4431 return 4432 4433 tensor_lists = self._get_example_tensor_lists( 4434 include_list_of_lists=False, 4435 include_requires_grad=components_require_grad, 4436 include_inner_dim_size_1=True, # (B, *, 1) 4437 ) 4438 4439 for tensor_list in tensor_lists: 4440 nt = torch.nested.nested_tensor( 4441 tensor_list, 4442 device=device, 4443 dtype=dtype, 4444 layout=torch.jagged, 4445 requires_grad=requires_grad, 4446 ) 4447 4448 if ( 4449 nt.dim() >= 3 4450 ): # layer norm only works for tensors with 3 or more dimensions 4451 normalized_shape = nt.shape[nt._ragged_idx :] 4452 4453 out_actual = torch.nn.functional.layer_norm( 4454 nt, normalized_shape=normalized_shape 4455 ) 4456 out_expected = torch.cat( 4457 [ 4458 torch.nn.functional.layer_norm(t, normalized_shape=t.shape) 4459 for t in nt.unbind() 4460 ] 4461 ) # e.g. 
in 3D tensor (B, *, M), performs layer normalization on B 2D tensors (*, M) 4462 4463 self.assertTrue( 4464 out_actual.is_nested, 4465 "layer_norm(): the result of reducing a nested tensor along the ragged dimension is a nested tensor", 4466 ) # output is a nested tensor 4467 self.assertEqual(out_actual._values.shape, out_expected.shape) 4468 self.assertTrue(torch.allclose(out_actual.values(), out_expected)) 4469 4470 @dtypes(torch.float32) 4471 @parametrize("requires_grad", [False, True]) 4472 @parametrize("components_require_grad", [False, True]) 4473 def test_layer_norm_2d_input( 4474 self, 4475 device, 4476 dtype, 4477 requires_grad, 4478 components_require_grad, 4479 ): 4480 """ 4481 Layer normalization on NestedTensor fails when trying to operate on a 2-dimensional tensor 4482 """ 4483 tensor_lists = self._get_example_tensor_lists( 4484 include_list_of_lists=False, 4485 include_requires_grad=components_require_grad, 4486 include_inner_dim_size_1=True, # (B, *, 1) 4487 include_2d_tensor=True, # (B, *) 4488 ) 4489 4490 for tensor_list in tensor_lists: 4491 nt = torch.nested.nested_tensor( 4492 tensor_list, 4493 device=device, 4494 dtype=dtype, 4495 layout=torch.jagged, 4496 requires_grad=requires_grad, 4497 ) 4498 4499 if nt.dim() <= 2: 4500 with self.assertRaisesRegex( 4501 RuntimeError, 4502 "not supported for NestedTensor objects with 2 or fewer dimensions", 4503 ): 4504 out = torch.nn.functional.layer_norm( 4505 nt, normalized_shape=(nt.shape[nt._ragged_idx],) 4506 ) 4507 4508 @dtypes(torch.float32) 4509 @parametrize("requires_grad", [False, True]) 4510 @parametrize("components_require_grad", [False, True]) 4511 def test_layer_norm_operate_on_batch_dim( 4512 self, 4513 device, 4514 dtype, 4515 requires_grad, 4516 components_require_grad, 4517 ): 4518 """ 4519 Layer normalization on NestedTensor fails when trying to operate on the batch dimension 4520 """ 4521 tensor_lists = self._get_example_tensor_lists( 4522 include_list_of_lists=False, 4523 include_requires_grad=components_require_grad, 4524 include_inner_dim_size_1=True, # (B, *, 1) 4525 include_2d_tensor=True, # (B, *) 4526 ) 4527 4528 for tensor_list in tensor_lists: 4529 nt = torch.nested.nested_tensor( 4530 tensor_list, 4531 device=device, 4532 dtype=dtype, 4533 layout=torch.jagged, 4534 requires_grad=requires_grad, 4535 ) 4536 4537 if nt.dim() > 2: # cannot perform layer normalization on 2D tensors 4538 with self.assertRaisesRegex( 4539 RuntimeError, 4540 "not supported when normalizing over the batch dimension for NestedTensor", 4541 ): 4542 out = torch.nn.functional.layer_norm(nt, normalized_shape=nt.shape) 4543 4544 @dtypes(torch.float32) 4545 @parametrize( 4546 "func", 4547 [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim], 4548 name_fn=get_op_name, 4549 ) 4550 @parametrize( 4551 "transpose_offset", [1, 2] 4552 ) # [transpose consecutive dimensions, transpose nonconsecutive dimensions] 4553 @parametrize("keepdim", [False, True]) 4554 @parametrize("requires_grad", [False, True]) 4555 @parametrize("components_require_grad", [False, True]) 4556 def test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape( 4557 self, 4558 device, 4559 dtype, 4560 keepdim, 4561 requires_grad, 4562 components_require_grad, 4563 func, 4564 transpose_offset, 4565 ): 4566 """ 4567 Operator on NestedTensor passes when trying to reduce across a transposed ragged dimension, i.e. ragged_idx > 1 4568 This test is for operators which return an output tensor with a shape different from the input tensor. 
4569 """ 4570 if get_op_name(func) == "mean" and not keepdim: 4571 return 4572 4573 op_name = get_op_name(func) 4574 4575 tensor_lists = self._get_example_tensor_lists( 4576 include_list_of_lists=False, 4577 include_requires_grad=components_require_grad, 4578 include_inner_dim_size_1=True, # (B, *, 1) 4579 include_2d_tensor=True, # (B, *) 4580 ) 4581 4582 for tensor_list in tensor_lists: 4583 nt = torch.nested.nested_tensor( 4584 tensor_list, 4585 device=device, 4586 dtype=dtype, 4587 layout=torch.jagged, 4588 requires_grad=requires_grad, 4589 ) 4590 4591 if nt.dim() > nt._ragged_idx + transpose_offset: 4592 nt_transposed = nt.transpose( 4593 nt._ragged_idx, nt._ragged_idx + transpose_offset 4594 ) 4595 reduce_dim = (nt_transposed._ragged_idx,) # ragged 4596 4597 out_actual = func(nt_transposed, dim=reduce_dim, keepdim=keepdim) 4598 out_expected = torch.cat( 4599 [ 4600 func(t, dim=(reduce_dim[0] - 1)).unsqueeze(0) 4601 for t in nt_transposed.unbind() 4602 ] 4603 ) 4604 4605 self.assertFalse( 4606 out_actual.is_nested, 4607 f"{op_name}(): the result of reducing a nested tensor along the ragged dimension is a dense tensor", 4608 ) # output is a dense tensor 4609 self.assertTrue(torch.allclose(out_actual, out_expected, rtol=1e-4)) 4610 4611 @dtypes(torch.float32) 4612 @parametrize( 4613 "transpose_offset", [1, 2] 4614 ) # [transpose consecutive dimensions, transpose nonconsecutive dimensions] 4615 @parametrize("requires_grad", [False, True]) 4616 @parametrize("components_require_grad", [False, True]) 4617 def test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape( 4618 self, 4619 device, 4620 dtype, 4621 requires_grad, 4622 components_require_grad, 4623 transpose_offset, 4624 ): 4625 """ 4626 Softmax on NestedTensor fails when trying to reduce across a transposed ragged dimension, i.e. ragged_idx > 1 4627 This test is for operators which return an output tensor with the same shape as the input tensor. 4628 """ 4629 tensor_lists = self._get_example_tensor_lists( 4630 include_list_of_lists=False, 4631 include_requires_grad=components_require_grad, 4632 include_inner_dim_size_1=True, # (B, *, 1) 4633 ) 4634 4635 for tensor_list in tensor_lists: 4636 nt = torch.nested.nested_tensor( 4637 tensor_list, 4638 device=device, 4639 dtype=dtype, 4640 layout=torch.jagged, 4641 requires_grad=requires_grad, 4642 ) 4643 4644 if nt.dim() > nt._ragged_idx + transpose_offset: 4645 nt_transposed = nt.transpose( 4646 nt._ragged_idx, nt._ragged_idx + transpose_offset 4647 ) 4648 reduce_dim = nt_transposed._ragged_idx # ragged 4649 4650 with self.assertRaisesRegex( 4651 RuntimeError, 4652 "not supported when reducing along the ragged dimension for ragged_idx > 1 for NestedTensor", 4653 ): 4654 out = torch.nn.functional.softmax(nt_transposed, dim=reduce_dim) 4655 4656 @dtypes(torch.float32) 4657 @parametrize( 4658 "func", 4659 [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim], 4660 name_fn=get_op_name, 4661 ) 4662 @parametrize("keepdim", [False, True]) 4663 @parametrize("requires_grad", [False, True]) 4664 @parametrize("components_require_grad", [False, True]) 4665 def test_op_dim_transpose_non_ragged_dim_different_output_shape( 4666 self, device, dtype, keepdim, requires_grad, components_require_grad, func 4667 ): 4668 """ 4669 Operator passes when reducing transposed nested tensors on valid reduction dimensions. 4670 This test is for operators which return an output tensor with a shape different from the input tensor. 
4671 """ 4672 if get_op_name(func) == "mean" and not keepdim: 4673 return 4674 4675 # verify correctness of shapes (assuming that ragged_idx == 1) 4676 if get_op_name(func) == "sum": 4677 reduce_dims = ( 4678 ((0, 1), (3, 4), (1, 1, 3, 4), (0,)), # batch, ragged 4679 ((2, 3), (3, None), (3, None, 1, 1), (1, 2)), # non-batch, non-batch 4680 ((0, 1, 3), (3,), (1, 1, 3, 1), (0, 2)), # batch, ragged, non-batch 4681 ((0, 1, 2), (4,), (1, 1, 1, 4), (0, 1)), # batch, ragged, non-batch 4682 ( 4683 (0, 1, 2, 3), 4684 (), 4685 (1, 1, 1, 1), 4686 (0, 1, 2), 4687 ), # batch, ragged, non-batch, non-batch 4688 ((2,), (3, None, 4), (3, None, 1, 4), (1,)), # non-batch 4689 ) # (dims, expected shape, expected keepdim shape, reduce_dim_expected), where j0 is represented as None 4690 elif get_op_name(func) == "mean": 4691 reduce_dims = ( 4692 ((2,), (3, None, 4), (3, None, 1, 4), (1,)), 4693 ((3,), (3, None, 3), (3, None, 3, 1), (2,)), 4694 ) 4695 4696 # verify correctness of values 4697 tensor_lists = self._get_example_tensor_lists( 4698 include_list_of_lists=False, 4699 include_requires_grad=components_require_grad, 4700 ) 4701 for tensor_list, reduce_dim_tuple in itertools.product( 4702 tensor_lists, reduce_dims 4703 ): 4704 nt = torch.nested.nested_tensor( 4705 tensor_list, 4706 device=device, 4707 dtype=dtype, 4708 layout=torch.jagged, 4709 requires_grad=requires_grad, 4710 ).transpose(-1, -2) 4711 4712 reduce_dim, _, _, reduce_dim_expected = reduce_dim_tuple 4713 4714 if nt.dim() > max( 4715 reduce_dim[-1], nt._ragged_idx + 2 4716 ): # ensure that transposed dimensions are non-batch, non-ragged dimensions 4717 out_actual = func(nt, dim=reduce_dim, keepdim=keepdim) 4718 if nt._ragged_idx in reduce_dim: # raggedness reduced away 4719 out_expected = func( 4720 nt.values(), dim=reduce_dim_expected, keepdim=keepdim 4721 ) 4722 self.assertTrue(torch.allclose(out_actual, out_expected)) 4723 else: # raggedness preserved 4724 out_expected = func(nt.values(), dim=reduce_dim_expected) 4725 self.assertTrue( 4726 torch.allclose( 4727 out_actual.values().view(-1), out_expected.view(-1) 4728 ) 4729 ) 4730 4731 @dtypes(torch.float32) 4732 @parametrize("requires_grad", [False, True]) 4733 @parametrize("components_require_grad", [False, True]) 4734 def test_softmax_dim_transpose_non_ragged_dim( 4735 self, 4736 device, 4737 dtype, 4738 requires_grad, 4739 components_require_grad, 4740 ): 4741 """ 4742 Softmax passes when reducing transposed nested tensors on valid reduction dimensions. 4743 This test is for operators which return an output tensor with the same shape as the input tensor. 
4744 """ 4745 # verify correctness of shapes (assuming that ragged_idx == 1) 4746 reduce_dims = ( 4747 (2, 1), 4748 (3, 2), 4749 ) # (reduction dimension, effective reduction dimension for baseline) 4750 4751 # verify correctness of values 4752 tensor_lists = self._get_example_tensor_lists( 4753 include_list_of_lists=False, 4754 include_requires_grad=components_require_grad, 4755 include_inner_dim_size_1=True, # (B, *, 1) 4756 ) 4757 for tensor_list, reduce_dim_tuple in itertools.product( 4758 tensor_lists, reduce_dims 4759 ): 4760 nt = torch.nested.nested_tensor( 4761 tensor_list, 4762 device=device, 4763 dtype=dtype, 4764 layout=torch.jagged, 4765 requires_grad=requires_grad, 4766 ).transpose(-1, -2) 4767 4768 reduce_dim, reduce_dim_expected = reduce_dim_tuple 4769 4770 if nt.dim() > max(reduce_dim, nt._ragged_idx + 2): 4771 out_actual = torch.nn.functional.softmax( 4772 nt, dim=reduce_dim 4773 ) # nested tensor 4774 out_expected = torch.nn.functional.softmax( 4775 nt.values(), dim=reduce_dim_expected 4776 ) # dense tensor of dimensions 1 less than out_actual 4777 4778 self.assertTrue( 4779 torch.allclose(out_actual.values().view(-1), out_expected.view(-1)) 4780 ) 4781 4782 @dtypes(torch.float32) 4783 @parametrize("keepdim", [False, True]) 4784 @parametrize("requires_grad", [False, True]) 4785 @parametrize("components_require_grad", [False, True]) 4786 def test_sum_dim_reduce_ragged_and_non_batch( 4787 self, 4788 device, 4789 dtype, 4790 keepdim, 4791 requires_grad, 4792 components_require_grad, 4793 ): 4794 """ 4795 Sum on NestedTensor fails when trying to reduce across ragged and non-batch dimensions 4796 """ 4797 tensor_lists = self._get_example_tensor_lists( 4798 include_list_of_lists=False, include_requires_grad=components_require_grad 4799 ) 4800 reduce_dims = ( 4801 (1, 2), # ragged, non-batch 4802 (1, 3), # ragged, non-batch 4803 ) 4804 4805 for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims): 4806 nt = torch.nested.nested_tensor( 4807 tensor_list, 4808 device=device, 4809 dtype=dtype, 4810 layout=torch.jagged, 4811 requires_grad=requires_grad, 4812 ) 4813 4814 if nt.dim() > reduce_dim[-1]: 4815 with self.assertRaisesRegex( 4816 RuntimeError, 4817 "not supported along a ragged and non-batch dimension for NestedTensor", 4818 ): 4819 out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim) 4820 4821 @dtypes(torch.float32) 4822 @parametrize("keepdim", [False, True]) 4823 @parametrize("requires_grad", [False, True]) 4824 @parametrize("components_require_grad", [False, True]) 4825 def test_sum_dim_reduce_batch_and_non_batch( 4826 self, 4827 device, 4828 dtype, 4829 keepdim, 4830 requires_grad, 4831 components_require_grad, 4832 ): 4833 """ 4834 Sum on NestedTensor fails when trying to reduce across batch and non-batch dimensions 4835 """ 4836 tensor_lists = self._get_example_tensor_lists( 4837 include_list_of_lists=False, include_requires_grad=components_require_grad 4838 ) 4839 reduce_dims = ( 4840 (0, 2), # batch, non-batch 4841 (0, 3), # batch, non-batch 4842 ) 4843 4844 for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims): 4845 nt = torch.nested.nested_tensor( 4846 tensor_list, 4847 device=device, 4848 dtype=dtype, 4849 layout=torch.jagged, 4850 requires_grad=requires_grad, 4851 ) 4852 4853 if nt.dim() > reduce_dim[-1]: 4854 with self.assertRaisesRegex( 4855 RuntimeError, 4856 "not supported along the batch dimension but not the ragged dimension for NestedTensor", 4857 ): 4858 out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim) 4859 4860 
@dtypes(torch.float32) 4861 @parametrize( 4862 "func", 4863 [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim], 4864 name_fn=get_op_name, 4865 ) 4866 @parametrize("keepdim", [False, True]) 4867 @parametrize("requires_grad", [False, True]) 4868 @parametrize("components_require_grad", [False, True]) 4869 def test_op_dim_reduce_batch_only_different_output_shape( 4870 self, device, dtype, keepdim, requires_grad, components_require_grad, func 4871 ): 4872 """ 4873 Operator on NestedTensor fails when trying to reduce across batch dimension 4874 """ 4875 if get_op_name(func) == "mean" and not keepdim: 4876 return 4877 4878 tensor_lists = self._get_example_tensor_lists( 4879 include_list_of_lists=False, include_requires_grad=components_require_grad 4880 ) 4881 reduce_dim = (0,) # batch 4882 4883 for tensor_list in tensor_lists: 4884 nt = torch.nested.nested_tensor( 4885 tensor_list, 4886 device=device, 4887 dtype=dtype, 4888 layout=torch.jagged, 4889 requires_grad=requires_grad, 4890 ) 4891 4892 with self.assertRaisesRegex( 4893 RuntimeError, 4894 "not supported along the batch dimension but not the ragged dimension for NestedTensor", 4895 ): 4896 out = func(nt, dim=reduce_dim, keepdim=keepdim) 4897 4898 @dtypes(torch.float32) 4899 @parametrize( 4900 "func", 4901 [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim], 4902 name_fn=get_op_name, 4903 ) 4904 @parametrize("keepdim", [False, True]) 4905 @parametrize("requires_grad", [False, True]) 4906 @parametrize("components_require_grad", [False, True]) 4907 def test_op_dim_with_lengths_different_output_shape( 4908 self, 4909 device, 4910 dtype, 4911 keepdim, 4912 requires_grad, 4913 components_require_grad, 4914 func, 4915 ): 4916 """ 4917 Operator on NestedTensor fails when trying to reduce a nested tensor with lengths, 4918 i.e. a nested tensor with holes, if reducing on the ragged dimension. 4919 This test is for operators which return an output tensor with different shape than the input tensor. 4920 """ 4921 if get_op_name(func) == "mean" and not keepdim: 4922 return 4923 4924 reduce_dims = ((1,), (2,), (2, 3)) 4925 4926 lengths = torch.randint(5, 10, (20,), device=device) 4927 offsets = torch.zeros((21,), device=device, dtype=torch.int) 4928 torch.cumsum(lengths, dim=0, out=offsets[1:]) 4929 4930 values = torch.randn( 4931 (offsets[-1].item(), 20), 4932 device=device, 4933 dtype=dtype, 4934 requires_grad=requires_grad, 4935 ) 4936 4937 nt_with_holes = torch.nested.nested_tensor_from_jagged( 4938 values, 4939 offsets, 4940 lengths=offsets.diff() - 2, # arbitrary subtraction to create holes 4941 ) 4942 4943 for reduce_dim in reduce_dims: 4944 if nt_with_holes.dim() > reduce_dim[-1]: 4945 if nt_with_holes._ragged_idx in reduce_dim: 4946 with self.assertRaisesRegex( 4947 RuntimeError, 4948 "not supported where lengths is not None " 4949 + "if reducing across the ragged dimension for NestedTensor", 4950 ): 4951 out = func(nt_with_holes, dim=reduce_dim, keepdim=keepdim) 4952 else: 4953 out = func(nt_with_holes, dim=reduce_dim, keepdim=keepdim) 4954 4955 @dtypes(torch.float32) 4956 @parametrize("requires_grad", [False, True]) 4957 @parametrize("components_require_grad", [False, True]) 4958 def test_softmax_dim_with_lengths( 4959 self, 4960 device, 4961 dtype, 4962 requires_grad, 4963 components_require_grad, 4964 ): 4965 """ 4966 Softmax on NestedTensor fails when trying to reduce a nested tensor with lengths, 4967 i.e. a nested tensor with holes, if reducing on the ragged dimension. 
4968 """ 4969 reduce_dims = (1, 2, 3) 4970 4971 lengths = torch.randint(5, 10, (20,), device=device) 4972 offsets = torch.zeros((21,), device=device, dtype=torch.int) 4973 torch.cumsum(lengths, dim=0, out=offsets[1:]) 4974 4975 values = torch.randn( 4976 (offsets[-1].item(), 20), 4977 device=device, 4978 dtype=dtype, 4979 requires_grad=requires_grad, 4980 ) 4981 4982 nt_with_holes = torch.nested.nested_tensor_from_jagged( 4983 values, 4984 offsets, 4985 lengths=offsets.diff() - 2, # arbitrary subtraction to create holes 4986 ) 4987 4988 for reduce_dim in reduce_dims: 4989 if nt_with_holes.dim() > reduce_dim: 4990 if nt_with_holes._ragged_idx == reduce_dim: 4991 with self.assertRaisesRegex( 4992 RuntimeError, 4993 "not supported where lengths is not None " 4994 + "if reducing across the ragged dimension for NestedTensor", 4995 ): 4996 out = torch.nn.functional.softmax(nt_with_holes, dim=reduce_dim) 4997 else: 4998 out = torch.nn.functional.softmax(nt_with_holes, dim=reduce_dim) 4999 5000 @skipIfTorchDynamo( 5001 "ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx] does not currently work " 5002 + "with dynamo tests and throws this error: `AssertionError: SymInts must use SymNodeVariable. " 5003 + "If the underlying value is static, we will create a ConstantVariable and specialize.`" 5004 ) 5005 @dtypes(torch.float32) 5006 @parametrize("requires_grad", [False, True]) 5007 @parametrize("components_require_grad", [False, True]) 5008 def test_layer_norm_with_lengths( 5009 self, 5010 device, 5011 dtype, 5012 requires_grad, 5013 components_require_grad, 5014 ): 5015 """ 5016 Layer normalization on NestedTensor fails when trying to operate on a nested tensor with lengths, 5017 i.e. a nested tensor with holes, if operating on the ragged dimension. 5018 """ 5019 5020 # create components for nested tensor 5021 lengths = torch.randint(5, 10, (20,), device=device) 5022 offsets = torch.zeros((21,), device=device, dtype=torch.int) 5023 torch.cumsum(lengths, dim=0, out=offsets[1:]) 5024 values = torch.randn( 5025 (offsets[-1].item(), 10, 30), 5026 device=device, 5027 dtype=dtype, 5028 requires_grad=requires_grad, 5029 ) 5030 5031 nt_with_holes = torch.nested.nested_tensor_from_jagged( 5032 values, 5033 offsets, 5034 lengths=offsets.diff() - 2, # arbitrary subtraction to create holes 5035 ) 5036 5037 ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx] 5038 5039 normalized_shapes = ( 5040 (10, 30), # normalization on non-ragged dimension passes 5041 (ragged_size, 10, 30), # normalization on ragged dimension fails 5042 ) 5043 5044 for normalized_shape in normalized_shapes: 5045 if ragged_size in normalized_shape: 5046 with self.assertRaisesRegex( 5047 RuntimeError, 5048 "not supported where lengths is not None if operating on the ragged dimension for NestedTensor", 5049 ): 5050 out = torch.nn.functional.layer_norm( 5051 nt_with_holes, normalized_shape=normalized_shape 5052 ) 5053 else: 5054 out = torch.nn.functional.layer_norm( 5055 nt_with_holes, normalized_shape=normalized_shape 5056 ) 5057 5058 @dtypes(torch.float32) 5059 @parametrize("keepdim", [True]) 5060 @parametrize("requires_grad", [False, True]) 5061 @parametrize("components_require_grad", [False, True]) 5062 def test_mean_dim_reduce_multiple_dims( 5063 self, 5064 device, 5065 dtype, 5066 keepdim, 5067 requires_grad, 5068 components_require_grad, 5069 ): 5070 """ 5071 Mean on NestedTensor fails when trying to reduce across multiple dimensions 5072 """ 5073 tensor_lists = self._get_example_tensor_lists( 5074 
include_list_of_lists=False, include_requires_grad=components_require_grad 5075 ) 5076 reduce_dims = ((0, 1), (2, 3), (2, 3, 4)) 5077 5078 for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims): 5079 nt = torch.nested.nested_tensor( 5080 tensor_list, 5081 device=device, 5082 dtype=dtype, 5083 layout=torch.jagged, 5084 requires_grad=requires_grad, 5085 ) 5086 5087 if nt.dim() > reduce_dim[-1]: 5088 with self.assertRaisesRegex( 5089 RuntimeError, 5090 "not supported across multiple dimensions for NestedTensor", 5091 ): 5092 out = torch.mean(nt, dim=reduce_dim, keepdim=keepdim) 5093 5094 @dtypes(torch.float32) 5095 @parametrize("keepdim", [False, True]) 5096 @parametrize("requires_grad", [False, True]) 5097 @parametrize("components_require_grad", [False, True]) 5098 def test_mean_dim_keepdim_False( 5099 self, 5100 device, 5101 dtype, 5102 keepdim, 5103 requires_grad, 5104 components_require_grad, 5105 ): 5106 """ 5107 Mean on NestedTensor fails when keepdim=False 5108 """ 5109 tensor_lists = self._get_example_tensor_lists( 5110 include_list_of_lists=False, include_requires_grad=components_require_grad 5111 ) 5112 reduce_dims = ((1,), (2,), (3,)) 5113 5114 for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims): 5115 nt = torch.nested.nested_tensor( 5116 tensor_list, 5117 device=device, 5118 dtype=dtype, 5119 layout=torch.jagged, 5120 requires_grad=requires_grad, 5121 ) 5122 5123 if nt.dim() > reduce_dim[-1]: 5124 if not keepdim: 5125 with self.assertRaisesRegex( 5126 RuntimeError, 5127 "not supported when keepdim=False for NestedTensor", 5128 ): 5129 out = torch.mean(nt, dim=reduce_dim, keepdim=keepdim) 5130 else: 5131 out = torch.mean(nt, dim=reduce_dim, keepdim=keepdim) 5132 5133 @dtypes(torch.float, torch.double, torch.half) 5134 @parametrize("requires_grad", [False, True]) 5135 @parametrize("weights_only", [False, True]) 5136 def test_serialization(self, device, dtype, requires_grad, weights_only): 5137 def compare_metadata(nt1, nt2): 5138 self.assertEqual(nt1._nested_tensor_size(), nt2._nested_tensor_size()) 5139 self.assertEqual(nt1._nested_tensor_strides(), nt2._nested_tensor_strides()) 5140 self.assertEqual( 5141 nt1._nested_tensor_storage_offsets(), 5142 nt2._nested_tensor_storage_offsets(), 5143 ) 5144 5145 nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7)) 5146 for a in [nt_contiguous, nt_noncontiguous]: 5147 buffer = io.BytesIO() 5148 serialized = torch.save(a, buffer) 5149 buffer.seek(0) 5150 b = torch.load(buffer, weights_only=weights_only) 5151 # should be both conceptually equal and metadata equivalent 5152 self.assertEqual(a, b) 5153 compare_metadata(a, b) 5154 # should be conceptually equal but not necessarily metadata equivalent 5155 self.assertEqual(b, nt_contiguous) 5156 self.assertEqual(b, nt_noncontiguous) 5157 5158 @unittest.skipIf( 5159 PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property" 5160 ) 5161 @onlyCUDA 5162 def test_pin_memory(self, device): 5163 nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7)) 5164 for nt in [nt_contiguous, nt_noncontiguous]: 5165 self.assertFalse(nt.is_pinned()) 5166 pinned = nt.pin_memory(device) 5167 self.assertTrue(pinned.is_pinned()) 5168 self.assertEqual(nt, pinned) 5169 self.assertNotEqual(nt.data_ptr(), pinned.data_ptr()) 5170 # test that pin_memory on already pinned tensor has no effect 5171 self.assertIs(pinned, pinned.pin_memory()) 5172 self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr()) 5173 5174 
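    # Sketch helper (not collected as a test): the jagged-layout bookkeeping that
    # _validate_nt below leans on, spelled out on a tiny example. `values` is the
    # packed (total_length, D) buffer, `offsets` has batch_size + 1 entries, and
    # component i of the NT is values[offsets[i] : offsets[i + 1]].
    def _jagged_components_sketch(self, device="cpu"):
        values = torch.randn(10, 5, device=device)
        offsets = torch.tensor([0, 2, 4, 6, 10], device=device, dtype=torch.int64)
        nt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets)
        for i, component in enumerate(nt.unbind()):
            torch.testing.assert_close(
                component, values[offsets[i] : offsets[i + 1]]
            )
        return nt
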
@torch.compiler.disable 5175 def _validate_nt( 5176 self, 5177 nt, 5178 device, 5179 dtype, 5180 layout, 5181 requires_grad, 5182 dim, 5183 batch_size, 5184 contiguous, 5185 cached_min_seqlen=None, 5186 cached_max_seqlen=None, 5187 base=None, 5188 ref_nt=None, 5189 ): 5190 # Validate a bunch of properties after NT construction. 5191 device = torch.device(device) 5192 self.assertEqual(nt.dim(), dim) 5193 self.assertEqual(nt.device, device) 5194 self.assertEqual(nt.dtype, dtype) 5195 self.assertEqual(nt.layout, layout) 5196 self.assertEqual(nt.requires_grad, requires_grad) 5197 self.assertEqual(nt.is_contiguous(), contiguous) 5198 5199 if layout == torch.jagged: 5200 self.assertEqual(nt._values.device, device) 5201 self.assertEqual(nt._offsets.device, device) 5202 self.assertEqual(nt.shape[0], batch_size) 5203 self.assertTrue(isinstance(nt.shape[1], torch.SymInt)) 5204 5205 if base is not None: 5206 self.assertTrue(nt._is_view() and nt._base is base) 5207 replay_cache = nt._view_func(torch.randn_like(nt._base))._metadata_cache 5208 self.assertEqual( 5209 "min_seqlen" in replay_cache, cached_min_seqlen is not None 5210 ) 5211 self.assertEqual( 5212 "max_seqlen" in replay_cache, cached_max_seqlen is not None 5213 ) 5214 5215 self.assertEqual( 5216 "min_seqlen" in nt._metadata_cache, cached_min_seqlen is not None 5217 ) 5218 self.assertEqual( 5219 "max_seqlen" in nt._metadata_cache, cached_max_seqlen is not None 5220 ) 5221 5222 if cached_min_seqlen is not None: 5223 self.assertEqual(nt._min_seqlen, cached_min_seqlen) 5224 5225 if cached_max_seqlen is not None: 5226 self.assertEqual(nt._max_seqlen, cached_max_seqlen) 5227 5228 if ref_nt is not None: 5229 self.assertEqual(nt.size(0), ref_nt.size(0)) 5230 for n1, n2 in zip(nt.unbind(), ref_nt.unbind()): 5231 self.assertEqual(n1, n2) 5232 5233 @dtypes(torch.float, torch.double, torch.half) 5234 @parametrize("requires_grad", [False, True]) 5235 @parametrize("components_require_grad", [False, True]) 5236 def test_jagged_layout_construction_nested_tensor( 5237 self, device, dtype, requires_grad, components_require_grad 5238 ): 5239 for tensor_list in self._get_example_tensor_lists( 5240 include_list_of_lists=True, include_requires_grad=components_require_grad 5241 ): 5242 nt = torch.nested.nested_tensor( 5243 tensor_list, 5244 device=device, 5245 dtype=dtype, 5246 layout=torch.jagged, 5247 requires_grad=requires_grad, 5248 ) 5249 5250 expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1 5251 expected_batch_size = len(tensor_list) 5252 expected_contiguous = True 5253 expected_min_seqlen = min( 5254 (torch.tensor(t) if isinstance(t, list) else t).shape[0] 5255 for t in tensor_list 5256 ) 5257 expected_max_seqlen = max( 5258 (torch.tensor(t) if isinstance(t, list) else t).shape[0] 5259 for t in tensor_list 5260 ) 5261 self._validate_nt( 5262 nt, 5263 device, 5264 dtype, 5265 torch.jagged, 5266 requires_grad, 5267 expected_dim, 5268 expected_batch_size, 5269 expected_contiguous, 5270 expected_min_seqlen, 5271 expected_max_seqlen, 5272 ) 5273 5274 # Make sure grads -don't- flow back into original tensors for nested_tensor() 5275 if requires_grad: 5276 (nt * 2).backward(torch.ones_like(nt)) 5277 for t in tensor_list: 5278 t = t if isinstance(t, torch.Tensor) else torch.as_tensor(t) 5279 self.assertTrue(t.grad is None) 5280 5281 @dtypes(torch.float, torch.double, torch.half) 5282 @parametrize("components_require_grad", [False, True]) 5283 def test_jagged_layout_construction_as_nested_tensor( 5284 self, device, dtype, components_require_grad 5285 ): 5286 
# NB: as_nested_tensor(tensor_list) doesn't support lists of lists for tensor_list 5287 for tensor_list in self._get_example_tensor_lists( 5288 include_list_of_lists=False, include_requires_grad=components_require_grad 5289 ): 5290 nt = torch.nested.as_nested_tensor( 5291 tensor_list, device=device, dtype=dtype, layout=torch.jagged 5292 ) 5293 5294 # nt.requires_grad=True should be set if at least one component requires grad 5295 expected_dim = tensor_list[0].dim() + 1 5296 expected_batch_size = len(tensor_list) 5297 expected_contiguous = True 5298 expected_min_seqlen = min( 5299 (torch.tensor(t) if isinstance(t, list) else t).shape[0] 5300 for t in tensor_list 5301 ) 5302 expected_max_seqlen = max( 5303 (torch.tensor(t) if isinstance(t, list) else t).shape[0] 5304 for t in tensor_list 5305 ) 5306 self._validate_nt( 5307 nt, 5308 device, 5309 dtype, 5310 torch.jagged, 5311 components_require_grad, 5312 expected_dim, 5313 expected_batch_size, 5314 expected_contiguous, 5315 expected_min_seqlen, 5316 expected_max_seqlen, 5317 ) 5318 5319 # Make sure grads flow back into original tensors for as_nested_tensor() 5320 if components_require_grad: 5321 (nt * 2).backward(torch.ones_like(nt)) 5322 for t in tensor_list: 5323 if t.requires_grad: 5324 self.assertEqual(t.grad, torch.ones_like(t) * 2) 5325 else: 5326 self.assertTrue(t.grad is None) 5327 5328 @xfailIfTorchDynamo 5329 @unittest.skipIf( 5330 PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property" 5331 ) 5332 @onlyCUDA 5333 def test_jagged_layout_construction_with_pinned_memory(self, device): 5334 for tensor_list in self._get_example_tensor_lists(): 5335 nt = torch.nested.nested_tensor( 5336 tensor_list, layout=torch.jagged, device="cpu", pin_memory=True 5337 ) 5338 5339 expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1 5340 expected_batch_size = len(tensor_list) 5341 expected_min_seqlen = min( 5342 (torch.tensor(t) if isinstance(t, list) else t).shape[0] 5343 for t in tensor_list 5344 ) 5345 expected_max_seqlen = max( 5346 (torch.tensor(t) if isinstance(t, list) else t).shape[0] 5347 for t in tensor_list 5348 ) 5349 self._validate_nt( 5350 nt, 5351 device="cpu", 5352 dtype=torch.float32, 5353 layout=torch.jagged, 5354 requires_grad=False, 5355 dim=expected_dim, 5356 batch_size=expected_batch_size, 5357 contiguous=True, 5358 cached_min_seqlen=expected_min_seqlen, 5359 cached_max_seqlen=expected_max_seqlen, 5360 ) 5361 self.assertTrue(nt.is_pinned()) 5362 5363 @dtypes(torch.float, torch.double, torch.half) 5364 @parametrize("requires_grad", [False, True]) 5365 @parametrize("values_is_view", [False, True]) 5366 def test_jagged_view_from_values_offsets( 5367 self, device, dtype, requires_grad, values_is_view 5368 ): 5369 if values_is_view: 5370 # make values a view of base 5371 base = torch.randn( 5372 2, 3, 4, 5, 6, device=device, dtype=dtype, requires_grad=requires_grad 5373 ) 5374 values = base.flatten(0, -2) 5375 else: 5376 values = torch.randn( 5377 10, 5, device=device, dtype=dtype, requires_grad=requires_grad 5378 ) 5379 offsets = torch.tensor([0, 2, 4, 6, 10], device=device, dtype=torch.int64) 5380 5381 nt = nested_view_from_values_offsets(values, offsets) 5382 5383 expected_dim = values.dim() + 1 5384 expected_batch_size = offsets.shape[0] - 1 5385 expected_base = base if values_is_view else values 5386 lengths = offsets.diff() 5387 self._validate_nt( 5388 nt, 5389 device, 5390 dtype, 5391 torch.jagged, 5392 requires_grad, 5393 expected_dim, 5394 expected_batch_size, 5395 # ensure NT is a proper view 5396 
base=expected_base, 5397 contiguous=True, 5398 # if no min / max are passed, expect the metadata cache to be empty 5399 cached_min_seqlen=None, 5400 cached_max_seqlen=None, 5401 ) 5402 5403 if requires_grad: 5404 # Make sure grads flow back 5405 (nt * 2).backward(torch.ones_like(nt)) 5406 5407 @torch.compiler.disable 5408 def _check_grad(t): 5409 self.assertTrue(t.grad is not None) 5410 self.assertEqual(t.grad, torch.ones_like(t) * 2) 5411 5412 _check_grad(base if values_is_view else values) 5413 5414 @dtypes(torch.float) 5415 @parametrize("pass_min_max", [False, True]) 5416 def test_nested_tensor_from_jagged(self, device, dtype, pass_min_max): 5417 # === construct from (values, offsets) === 5418 values = torch.randn(10, 5, device=device, dtype=dtype) 5419 offsets = torch.tensor([0, 2, 4, 6, 10], device=device, dtype=torch.int64) 5420 5421 # compute min / max seqlen 5422 lengths = offsets.diff() 5423 min_seqlen = lengths.min().item() 5424 max_seqlen = lengths.max().item() 5425 5426 if pass_min_max: 5427 nt = torch.nested.nested_tensor_from_jagged( 5428 values, offsets=offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen 5429 ) 5430 else: 5431 nt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets) 5432 self._validate_nt( 5433 nt, 5434 device, 5435 dtype, 5436 torch.jagged, 5437 requires_grad=False, 5438 dim=3, 5439 batch_size=4, 5440 contiguous=True, 5441 cached_min_seqlen=(min_seqlen if pass_min_max else None), 5442 cached_max_seqlen=(max_seqlen if pass_min_max else None), 5443 base=values, 5444 ) 5445 5446 # === construct from (values, offsets, lengths) === 5447 lengths = torch.tensor([2, 1, 1, 2], device=device) 5448 5449 # compute min / max seqlen 5450 min_seqlen = lengths.min().item() 5451 max_seqlen = lengths.max().item() 5452 5453 if pass_min_max: 5454 nt = torch.nested.nested_tensor_from_jagged( 5455 values, 5456 offsets=offsets, 5457 lengths=lengths, 5458 min_seqlen=min_seqlen, 5459 max_seqlen=max_seqlen, 5460 ) 5461 else: 5462 nt = torch.nested.nested_tensor_from_jagged( 5463 values, offsets=offsets, lengths=lengths 5464 ) 5465 5466 # when both offsets / lengths are specified, expect non-contiguous 5467 self._validate_nt( 5468 nt, 5469 device, 5470 dtype, 5471 torch.jagged, 5472 requires_grad=False, 5473 dim=3, 5474 batch_size=4, 5475 contiguous=False, 5476 cached_min_seqlen=(min_seqlen if pass_min_max else None), 5477 cached_max_seqlen=(max_seqlen if pass_min_max else None), 5478 base=values, 5479 ) 5480 self.assertIs(nt.lengths(), lengths) 5481 5482 # === construct from (values, lengths) === 5483 values = torch.randn(14, 5, device=device, dtype=dtype) 5484 lengths = torch.tensor([2, 3, 4, 5], device=device) 5485 5486 # compute min / max seqlen 5487 min_seqlen = lengths.min().item() 5488 max_seqlen = lengths.max().item() 5489 5490 if pass_min_max: 5491 nt = torch.nested.nested_tensor_from_jagged( 5492 values, lengths=lengths, min_seqlen=min_seqlen, max_seqlen=max_seqlen 5493 ) 5494 else: 5495 nt = torch.nested.nested_tensor_from_jagged(values, lengths=lengths) 5496 5497 # for now, if only lengths is specified, convert to offsets to integrate best with the 5498 # existing kernels 5499 expected_offsets = torch.tensor([0, 2, 5, 9, 14], device=device) 5500 expected_nt = torch.nested.nested_tensor_from_jagged( 5501 values, offsets=expected_offsets 5502 ) 5503 self._validate_nt( 5504 nt, 5505 device, 5506 dtype, 5507 torch.jagged, 5508 requires_grad=False, 5509 dim=3, 5510 batch_size=4, 5511 contiguous=True, 5512 cached_min_seqlen=(min_seqlen if pass_min_max else None), 
5513 cached_max_seqlen=(max_seqlen if pass_min_max else None), 5514 base=values, 5515 ref_nt=expected_nt, 5516 ) 5517 5518 # error case: no offsets or lengths 5519 with self.assertRaisesRegex( 5520 RuntimeError, "At least one of offsets or lengths is required" 5521 ): 5522 torch.nested.nested_tensor_from_jagged(values, offsets=None, lengths=None) 5523 5524 @onlyCPU 5525 def test_nested_tensor_from_jagged_fx_trace(self, device): 5526 def fn(x, y): 5527 return torch.nested.nested_tensor_from_jagged(x, y) 5528 5529 def user_unwrapped(x, y): 5530 return fn(x, y) 5531 5532 with self.assertRaisesRegex( 5533 RuntimeError, 5534 "torch.nested.nested_tensor_from_jagged does not support tracing with fx.symbolic_trace", 5535 ): 5536 torch.fx.symbolic_trace(user_unwrapped) 5537 5538 @dtypes(torch.float, torch.double, torch.half) 5539 @parametrize("dim", range(5)) 5540 @parametrize( 5541 "layout", 5542 [torch.strided, torch.jagged], 5543 name_fn=lambda l: f"layout_{str(l).split('.')[1]}", 5544 ) 5545 @parametrize("requires_grad", [False, True]) 5546 @parametrize("contiguous", [False, True]) 5547 def test_as_nested_tensor_from_tensor( 5548 self, device, dtype, dim, layout, requires_grad, contiguous 5549 ): 5550 if dim == 0: 5551 t = torch.tensor(3.0, requires_grad=requires_grad) 5552 else: 5553 t = torch.randn(*(3 for _ in range(dim)), requires_grad=requires_grad) 5554 assert t.dim() == dim 5555 5556 if dim < 2: 5557 # 0-1 dim tensors can't be converted to NTs 5558 with self.assertRaisesRegex( 5559 RuntimeError, "Expected tensor argument to have dim" 5560 ): 5561 nt = torch.nested.as_nested_tensor( 5562 t, device=device, dtype=dtype, layout=layout 5563 ) 5564 return 5565 5566 orig_t = t 5567 if not contiguous: 5568 t = t.transpose(0, 1) 5569 5570 nt = torch.nested.as_nested_tensor(t, device=device, dtype=dtype, layout=layout) 5571 expected_dim = t.dim() 5572 expected_batch_size = t.size(0) 5573 expected_seqlen = t.size(1) if layout == torch.jagged else None 5574 self._validate_nt( 5575 nt, 5576 device, 5577 dtype, 5578 layout, 5579 requires_grad=requires_grad, 5580 dim=dim, 5581 batch_size=expected_batch_size, 5582 contiguous=True, 5583 cached_min_seqlen=expected_seqlen, 5584 cached_max_seqlen=expected_seqlen, 5585 ) 5586 5587 if torch.device(device) == t.device and dtype == t.dtype and contiguous: 5588 # should be the non-copying (view) case 5589 self.assertTrue(nt._is_view() and nt._base is t) 5590 5591 # should have equivalent components to construction from unbound tensor list 5592 nt_from_unbind = torch.nested.as_nested_tensor( 5593 list(t.unbind(0)), device=device, dtype=dtype, layout=layout 5594 ) 5595 self.assertEqualIgnoringNestedInts(nt, nt_from_unbind) 5596 5597 # ensure call on a NT with the same properties returns the NT directly 5598 nt2 = torch.nested.as_nested_tensor( 5599 nt, device=device, dtype=dtype, layout=layout 5600 ) 5601 self.assertTrue(nt is nt2) 5602 5603 # ensure call with device=None uses input tensor device 5604 nt3 = torch.nested.as_nested_tensor( 5605 t.to(device=device, dtype=dtype), 5606 device=None, 5607 dtype=None, 5608 layout=layout, 5609 ) 5610 self._validate_nt( 5611 nt3, 5612 device, 5613 dtype, 5614 layout, 5615 requires_grad=requires_grad, 5616 dim=dim, 5617 batch_size=expected_batch_size, 5618 contiguous=True, 5619 cached_min_seqlen=expected_seqlen, 5620 cached_max_seqlen=expected_seqlen, 5621 ) 5622 5623 # we don't support conversion between layouts this way atm 5624 other_layout = torch.strided if layout == torch.jagged else torch.jagged 5625 with 
self.assertRaisesRegex( 5626 RuntimeError, "Converting between nested tensor layouts is not supported" 5627 ): 5628 torch.nested.as_nested_tensor( 5629 nt, device=device, dtype=dtype, layout=other_layout 5630 ) 5631 5632 if requires_grad: 5633 # make sure gradients flow back into inputs 5634 (nt * 2).backward(torch.ones_like(nt)) 5635 self.assertEqual(orig_t.grad, torch.ones_like(orig_t) * 2) 5636 5637 @dtypes(torch.double, torch.half) 5638 @onlyCUDA 5639 def test_device_dtype_transfer_updates_offsets(self, device, dtype): 5640 for tensor_list in self._get_example_tensor_lists(): 5641 orig_device = torch.device("cpu") 5642 orig_dtype = torch.float32 5643 nt = torch.nested.nested_tensor( 5644 tensor_list, layout=torch.jagged, device=orig_device, dtype=orig_dtype 5645 ) 5646 5647 self.assertEqual(torch.int64, nt.offsets().dtype) 5648 nt = nt.to(device=device).to(dtype=dtype) 5649 5650 # offsets should still be int64 on the new device 5651 self.assertEqual(nt.values().device, nt.offsets().device) 5652 self.assertEqual(torch.int64, nt.offsets().dtype) 5653 5654 def test_unbind(self, device): 5655 for tensor_list in self._get_example_tensor_lists(): 5656 nt = torch.nested.nested_tensor( 5657 tensor_list, layout=torch.jagged, device=device 5658 ) # ragged_idx = 1 5659 out = nt.unbind() 5660 self.assertEqual(len(out), len(tensor_list)) 5661 for i, t in enumerate(out): 5662 self.assertEqual(t, tensor_list[i]) 5663 5664 @parametrize("ragged_idx", [2, 3]) 5665 def test_unbind_transpose(self, device, ragged_idx): 5666 for tensor_list in self._get_example_tensor_lists(): 5667 nt = torch.nested.nested_tensor( 5668 tensor_list, layout=torch.jagged, device=device 5669 ) 5670 if ragged_idx < nt.dim(): 5671 nt = nt.transpose(1, ragged_idx) # set ragged_idx 5672 out = nt.unbind() 5673 self.assertEqual(len(out), len(tensor_list)) 5674 for i, t in enumerate(out): 5675 self.assertEqual( 5676 t.transpose(0, ragged_idx - 1), tensor_list[i] 5677 ) # transpose back each element of result 5678 5679 def test_unbind_transpose_ragged_idx_last_dim(self, device): 5680 for tensor_list in self._get_example_tensor_lists(): 5681 nt = torch.nested.nested_tensor( 5682 tensor_list, layout=torch.jagged, device=device 5683 ).transpose(1, -1) # set ragged_idx = last dimension 5684 out = nt.unbind() 5685 self.assertEqual(len(out), len(tensor_list)) 5686 for i, t in enumerate(out): 5687 self.assertEqual( 5688 t.transpose(0, -1), tensor_list[i] 5689 ) # transpose back each element of result 5690 5691 def test_unbind_lengths(self, device): 5692 values = torch.randn(16, 128, device=device) 5693 offsets = torch.tensor([0, 8, 12, 13, 16], device=device) 5694 lengths = torch.tensor([6, 2, 1, 2], device=device) 5695 nt = torch.nested.nested_tensor_from_jagged( 5696 values, offsets=offsets, lengths=lengths 5697 ) # 3D nested tensor 5698 5699 tensor_list = [] 5700 for i in range(offsets.shape[0] - 1): 5701 tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i])]) 5702 5703 out = nt.unbind() 5704 self.assertEqual(len(out), len(tensor_list)) 5705 for i, t in enumerate(out): 5706 self.assertEqual(t, tensor_list[i]) 5707 5708 def test_unbind_lengths_ragged_idx_1(self, device): 5709 values = torch.randn(16, 8, 128, device=device) 5710 offsets = torch.tensor([0, 8, 12, 13, 16], device=device) 5711 lengths = torch.tensor([6, 2, 1, 2], device=device) 5712 ragged_idx = 1 5713 nt = torch.nested._internal.nested_tensor.NestedTensor( 5714 values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx 5715 ) # 4D nested tensor 5716 5717 
tensor_list = [] 5718 for i in range(offsets.shape[0] - 1): 5719 tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i]), :, :]) 5720 5721 out = nt.unbind() 5722 5723 self.assertEqual(len(out), len(tensor_list)) 5724 for i, t in enumerate(out): 5725 self.assertEqual(t, tensor_list[i]) 5726 5727 def test_unbind_lengths_ragged_idx_equals_2_bad_dim(self, device): 5728 values = torch.randn(16, 8, 128, device=device) 5729 offsets = torch.tensor([0, 8, 12, 13, 16], device=device) 5730 lengths = torch.tensor([6, 2, 1, 2], device=device) 5731 ragged_idx = 2 5732 nt = torch.nested._internal.nested_tensor.NestedTensor( 5733 values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx 5734 ) # 4D nested tensor 5735 5736 self.assertRaisesRegex( 5737 RuntimeError, 5738 r"unbind\(\): nested tensor offsets and lengths.*", 5739 lambda: nt.unbind(), 5740 ) 5741 5742 def test_unbind_lengths_ragged_idx_2(self, device): 5743 values = torch.randn(16, 8, 128, device=device) 5744 offsets = torch.tensor([0, 2, 4, 8], device=device) 5745 lengths = torch.tensor([2, 1, 3], device=device) 5746 ragged_idx = 2 5747 nt = torch.nested._internal.nested_tensor.NestedTensor( 5748 values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx 5749 ) # 4D nested tensor 5750 5751 tensor_list = [] 5752 for i in range(offsets.shape[0] - 1): 5753 tensor_list.append(values[:, offsets[i] : (offsets[i] + lengths[i]), :]) 5754 5755 out = nt.unbind() 5756 5757 self.assertEqual(len(out), len(tensor_list)) 5758 for i, t in enumerate(out): 5759 self.assertEqual(t, tensor_list[i]) 5760 5761 def test_unbind_lengths_ragged_idx_3(self, device): 5762 values = torch.randn(16, 8, 128, device=device) 5763 offsets = torch.tensor([0, 100, 128], device=device) 5764 lengths = torch.tensor([50, 28], device=device) 5765 ragged_idx = 3 5766 nt = torch.nested._internal.nested_tensor.NestedTensor( 5767 values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx 5768 ) # 4D nested tensor 5769 5770 tensor_list = [] 5771 for i in range(offsets.shape[0] - 1): 5772 tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])]) 5773 5774 out = nt.unbind() 5775 5776 self.assertEqual(len(out), len(tensor_list)) 5777 for i, t in enumerate(out): 5778 self.assertEqual(t, tensor_list[i]) 5779 5780 @skipIfTorchDynamo( 5781 "TorchDynamo raises an error for ragged_idx == 0 earlier than Torch" 5782 ) 5783 def test_unbind_lengths_ragged_idx_0(self, device): 5784 values = torch.randn(16, 8, 128, device=device) 5785 offsets = torch.tensor([0, 100, 128], device=device) 5786 lengths = torch.tensor([50, 28], device=device) 5787 ragged_idx = 0 5788 nt = torch.nested._internal.nested_tensor.NestedTensor( 5789 values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx 5790 ) # 4D nested tensor 5791 5792 tensor_list = [] 5793 for i in range(offsets.shape[0] - 1): 5794 tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])]) 5795 5796 self.assertRaisesRegex( 5797 RuntimeError, 5798 r"unbind\(\): nested tensor.*out of bounds", 5799 lambda: nt.unbind(), 5800 ) 5801 5802 def test_narrow(self, device): 5803 starts = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64) 5804 lengths = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64) 5805 buffer = ( 5806 torch.arange(0, 10, device=device, dtype=torch.int64) 5807 .unsqueeze(0) 5808 .expand(5, -1) 5809 .clone() 5810 .detach() 5811 ) 5812 nt = torch.nested.narrow(buffer, 1, starts, lengths, layout=torch.jagged) 5813 5814 self.assertTrue(nt._is_view() and nt._base 
is buffer) 5815 5816 # TODO: Use this approach when unbind is functional 5817 # unbinded_nt = nt.unbind() 5818 # for i in range(starts.shape[0]): 5819 # self.assertEqual(torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64), unbinded_nt[i]) 5820 for i in range(starts.shape[0]): 5821 self.assertEqual( 5822 torch.arange( 5823 starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64 5824 ), 5825 nt.values()[nt.offsets()[i] : (nt.offsets()[i] + nt.lengths()[i])], 5826 ) 5827 5828 def test_njt_cat(self, device): 5829 offsets = torch.tensor([0, 2, 3], device=device, dtype=torch.int64) 5830 values_1 = torch.randn( 5831 3, 2, dtype=torch.float64, device=device, requires_grad=True 5832 ) 5833 values_2 = torch.randn( 5834 3, 4, dtype=torch.float64, device=device, requires_grad=True 5835 ) 5836 5837 def grad_test_func(values_1, values_2, offsets): 5838 nt_1 = torch.nested.nested_tensor_from_jagged(values_1, offsets) 5839 nt_2 = torch.nested.nested_tensor_from_jagged(values_2, offsets) 5840 nt_3 = torch.cat([nt_1, nt_2], dim=-1) 5841 return nt_3.values() 5842 5843 assert gradcheck( 5844 grad_test_func, 5845 inputs=(values_1, values_2, offsets), 5846 check_batched_grad=False, 5847 ) 5848 5849 def test_is_contiguous(self, device): 5850 a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) 5851 b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) 5852 c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) 5853 nt_contiguous = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) 5854 5855 starts_nc = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64) 5856 lengths_nc = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64) 5857 narrow_base = ( 5858 torch.arange(0, 10, device=device, dtype=torch.int64) 5859 .unsqueeze(0) 5860 .expand(5, -1) 5861 .clone() 5862 ) 5863 nt_noncontiguous = torch.nested.narrow( 5864 narrow_base, 1, starts_nc, lengths_nc, layout=torch.jagged 5865 ) 5866 5867 starts_c = torch.tensor([1, 0, 0, 0, 0], device=device, dtype=torch.int64) 5868 lengths_c = torch.tensor([9, 10, 10, 10, 8], device=device, dtype=torch.int64) 5869 nt_contiguous_narrow = torch.nested.narrow( 5870 narrow_base, 1, starts_c, lengths_c, layout=torch.jagged 5871 ) 5872 5873 # Test contiguous case 5874 assert nt_contiguous.is_contiguous() 5875 5876 # Test narrow case 5877 assert not nt_noncontiguous.is_contiguous() 5878 assert nt_contiguous_narrow.is_contiguous() 5879 5880 # Test querying by memory_format 5881 self.assertTrue( 5882 nt_contiguous.is_contiguous(memory_format=torch.contiguous_format) 5883 ) 5884 self.assertTrue( 5885 not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format) 5886 ) 5887 self.assertTrue( 5888 nt_contiguous_narrow.is_contiguous(memory_format=torch.contiguous_format) 5889 ) 5890 5891 def test_layout_under_torch_dispatch_mode(self): 5892 from torch.testing._internal.logging_tensor import ( 5893 capture_logs_with_logging_tensor_mode, 5894 ) 5895 5896 nt = random_nt_from_dims( 5897 [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged 5898 ) 5899 5900 with capture_logs_with_logging_tensor_mode(): 5901 self.assertEqual(nt.layout, torch.jagged) 5902 5903 @skipIfTorchDynamo("Not a suitable test for TorchDynamo") 5904 @parametrize( 5905 "func", [torch.empty_like, torch.randn_like], name_fn=lambda f: f.__name__ 5906 ) 5907 def test_like_shape(self, func): 5908 nt = random_nt_from_dims( 5909 [2, None, 3], torch.device("cpu"), 
torch.float32, layout=torch.jagged 5910 ) 5911 nt_like = func(nt) 5912 5913 for nt_ub in nt_like.unbind(): 5914 t_like = func(nt_ub) 5915 self.assertEqual(nt_ub.shape, t_like.shape) 5916 5917 @skipIfTorchDynamo("Not a suitable test for TorchDynamo") 5918 @parametrize( 5919 "func", [torch.ones_like, torch.zeros_like], name_fn=lambda f: f.__name__ 5920 ) 5921 def test_like_value(self, func): 5922 nt = random_nt_from_dims( 5923 [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged 5924 ) 5925 nt_like = func(nt) 5926 5927 for nt_ub in nt_like.unbind(): 5928 t_like = func(nt_ub) 5929 self.assertEqual(nt_ub, t_like) 5930 5931 def test_noncontiguous_pointwise(self, device): 5932 a = torch.randn(2, 3, 4, requires_grad=True, dtype=torch.float64, device=device) 5933 b = torch.randn(3, 3, 4, requires_grad=True, dtype=torch.float64, device=device) 5934 c = torch.randn(4, 3, 4, requires_grad=True, dtype=torch.float64, device=device) 5935 nt = torch.nested.nested_tensor([a, b, c], layout=torch.jagged) 5936 # transpose ragged dim 5937 transposed = nt.transpose(1, 2) 5938 self.assertFalse(transposed.is_contiguous()) 5939 clone = transposed.clone() 5940 5941 def check_nt_equality(x, y): 5942 self.assertEqual(x.values(), y.values()) 5943 self.assertEqual(x.offsets(), y.offsets()) 5944 self.assertEqual(x._ragged_idx, y._ragged_idx) 5945 self.assertEqual(x.shape, y.shape) 5946 5947 self.assertFalse(clone.is_contiguous()) 5948 check_nt_equality(clone, transposed) 5949 5950 clone_contig = transposed.clone(memory_format=torch.contiguous_format) 5951 self.assertTrue(clone_contig.is_contiguous()) 5952 check_nt_equality(clone_contig, transposed) 5953 5954 detached = transposed.detach() 5955 self.assertFalse(clone.is_contiguous()) 5956 check_nt_equality(detached, transposed) 5957 5958 def test_permute(self, device): 5959 nt = random_nt_from_dims( 5960 [2, None, 3, 5], device, torch.float32, layout=torch.jagged 5961 ) 5962 nt_shape = nt.shape 5963 nt_inner_shape = nt.values().shape 5964 with self.assertRaisesRegex( 5965 ValueError, 5966 r"permute\(\): number of dimensions in the tensor input \(4\) " 5967 + r"does not match the length of the desired ordering of dimensions \(3\).", 5968 ): 5969 nt.permute(0, 2, 1) 5970 with self.assertRaisesRegex( 5971 ValueError, r"permute\(\): duplicate dims are not allowed." 
5972 ): 5973 nt.permute(0, 2, -2, 3) 5974 with self.assertRaisesRegex( 5975 ValueError, "Permute is not supported on the batch dimension for jagged NT" 5976 ): 5977 nt.permute(1, 0, 2, 3) 5978 nt_permute = nt.permute(0, 2, 1, -1) 5979 self.assertEqual( 5980 nt_permute.shape, (nt_shape[0], nt_shape[2], nt_shape[1], nt_shape[3]) 5981 ) 5982 self.assertEqual( 5983 nt_permute.values().shape, 5984 (nt_inner_shape[1], nt_inner_shape[0], nt_inner_shape[2]), 5985 ) 5986 self.assertEqual(nt_permute._ragged_idx, 2) 5987 self.assertEqual(nt_permute.permute(0, 2, 1, 3), nt) 5988 5989 def test_to_dtype(self, device): 5990 nt = random_nt_from_dims( 5991 [2, None, 3], device, torch.float32, layout=torch.jagged 5992 ) 5993 nt_after = nt.to(torch.float64) 5994 self.assertEqual(torch.float32, nt.dtype) 5995 self.assertEqual(torch.float64, nt_after.dtype) 5996 self.assertEqual(torch.float64, nt_after.values().dtype) 5997 self.assertEqual(torch.int64, nt_after.offsets().dtype) 5998 5999 noncontiguous_nt = nt.transpose(1, 2) 6000 noncontiguous_nt_after = noncontiguous_nt.to(torch.bfloat16) 6001 self.assertEqual(torch.bfloat16, noncontiguous_nt_after.dtype) 6002 self.assertEqual(torch.bfloat16, noncontiguous_nt_after.values().dtype) 6003 self.assertEqual(torch.int64, noncontiguous_nt_after.offsets().dtype) 6004 6005 def test_to_copy(self, device): 6006 nt = torch.nested.nested_tensor( 6007 [ 6008 torch.randn( 6009 i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device 6010 ) 6011 for i in range(3) 6012 ], 6013 layout=torch.jagged, 6014 ) 6015 6016 nt_copy_dtype = torch.ops.aten._to_copy(nt, dtype=torch.float16) 6017 self.assertEqual(torch.float16, nt_copy_dtype.dtype) 6018 6019 nt_t = nt.transpose(1, 2) 6020 nt_t_copy_dtype = torch.ops.aten._to_copy(nt_t, dtype=torch.float16) 6021 self.assertEqual(torch.float16, nt_t_copy_dtype.dtype) 6022 6023 def test_copy_(self, device): 6024 offsets = torch.tensor([0, 2, 4], device=device) 6025 a = torch.nested.nested_tensor_from_jagged( 6026 torch.zeros(4, 3, device=device), offsets 6027 ) 6028 b = torch.nested.nested_tensor_from_jagged( 6029 torch.ones(4, 3, device=device), offsets 6030 ) 6031 a.copy_(b) 6032 torch._dynamo.disable(self.assertEqual)(a, b) 6033 6034 offsets_2 = torch.tensor([0, 2, 4], device=device) 6035 c = torch.nested.nested_tensor_from_jagged( 6036 torch.ones(4, 3, device=device), offsets_2 6037 ) 6038 # fail when tensors have the same size but not the exact same offset tensor. 
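        # Note: offsets_2 holds the same values as offsets but is a distinct tensor
        # object; per the error below, that alone is enough for copy_ to reject it.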
6039 with self.assertRaisesRegex( 6040 RuntimeError, 6041 "copy_ only supports Nested Tensors that have same size and the exact same offset tensor.", 6042 ): 6043 a.copy_(c) 6044 6045 # fail when tensors have different sizes 6046 a = a.transpose(1, 2) 6047 with self.assertRaisesRegex( 6048 RuntimeError, 6049 "copy_ only supports Nested Tensors that have same size and the exact same offset tensor.", 6050 ): 6051 a.copy_(b) 6052 6053 @skipIfTorchDynamo("Dynamo doesn't know how to trace prof.events()") 6054 def test_profiler_sequence_nr(self): 6055 with torch.profiler.profile() as prof: 6056 values = torch.randn(4, 6, requires_grad=True) 6057 offsets = torch.tensor([0, 2, 4]) 6058 values = values * 2 6059 l = torch.nn.Linear(6, 8) 6060 nt = torch.nested.nested_tensor_from_jagged(values, offsets) 6061 6062 nt = l(nt) 6063 val = nt.values() 6064 6065 loss = val.sum() 6066 loss.backward() 6067 6068 fwd_seq_nrs = [] 6069 for evt in prof.events(): 6070 if ( 6071 "linear" in evt.name.lower() 6072 and "backward" not in evt.name.lower() 6073 and evt.sequence_nr != -1 6074 ): 6075 fwd_seq_nrs.append(evt.sequence_nr) 6076 6077 bwd_seq_nrs = [] 6078 for evt in prof.events(): 6079 if ( 6080 "linear" in evt.name.lower() 6081 and "backward" in evt.name.lower() 6082 and "evaluate_function" not in evt.name.lower() 6083 and evt.sequence_nr != -1 6084 ): 6085 bwd_seq_nrs.append(evt.sequence_nr) 6086 6087 # There should only be one such event with a sequence number: 6088 # the PythonTLSSnapshot event - but, note that it's not terrible if 6089 # we end up with multiple events with the same sequence number - so we 6090 # could relax this check if it becomes inconvenient to maintain this 6091 # property. 6092 self.assertEqual(len(fwd_seq_nrs), 1) 6093 self.assertEqual(len(bwd_seq_nrs), 1) 6094 self.assertEqual(fwd_seq_nrs[0], bwd_seq_nrs[0]) 6095 6096 def test_is_same_size(self, device): 6097 def get_3_tensors(): 6098 return [ 6099 torch.randn( 6100 i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device 6101 ) 6102 for i in range(3) 6103 ] 6104 6105 nt1, offsets1 = jagged_from_list(get_3_tensors(), None) 6106 nt2, offsets1 = jagged_from_list(get_3_tensors(), offsets1) 6107 6108 nt3, offsets2 = jagged_from_list(get_3_tensors(), None) 6109 nt4, offsets2 = jagged_from_list(get_3_tensors(), offsets2) 6110 6111 def check_size(nt1, nt2, nt3, nt4): 6112 self.assertTrue(torch.ops.aten.is_same_size(nt1, nt2)) 6113 self.assertTrue(torch.ops.aten.is_same_size(nt3, nt4)) 6114 self.assertFalse(torch.ops.aten.is_same_size(nt1, nt3)) 6115 6116 check_size(nt1, nt2, nt3, nt4) 6117 6118 nt1_t, nt2_t, nt3_t, nt4_t = (x.transpose(1, 2) for x in (nt1, nt2, nt3, nt4)) 6119 check_size(nt1_t, nt2_t, nt3_t, nt4_t) 6120 6121 @skipIfTorchDynamo("compiles internally") 6122 @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") 6123 @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") 6124 def test_specialize_dynamic_shape(self, device): 6125 values = torch.randn((18, 16), device=device) 6126 offsets = torch.tensor([0, 2, 3, 6, 15, 18], device=device) 6127 like_values = torch.randn_like(values) 6128 6129 # this marks values as dynamic 6130 nt = torch.nested.nested_tensor_from_jagged(values, offsets) 6131 6132 def fn(values, same_size): 6133 # here, the dynamic shape is specialized by same_size's shape 6134 # https://github.com/pytorch/pytorch/issues/127097 6135 # make sure this doesn't error out in torch.compile 6136 return values + same_size 6137 6138 self.assertEqual( 6139 fn(values, 
like_values), 6140 torch.compile(fn)(values, like_values), 6141 ) 6142 6143 @skipIfTorchDynamo("compiles internally") 6144 @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") 6145 @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") 6146 def test_specialize_dynamic_shape_recompile(self, device): 6147 def generate_inp(total_len): 6148 values = torch.randn((total_len, 16), device=device) 6149 offsets = torch.tensor([0, 2, 3, 6, 15, total_len], device=device) 6150 like_values = torch.randn_like(values) 6151 return values, offsets, like_values 6152 6153 def check_results(ref_fn, res_fn, args): 6154 values, offsets, like_values = args 6155 # this may add dynamic shape markings 6156 # goal of this test is to make sure that whatever markings are there, 6157 # we eventually stop recompiling as shape changes. 6158 nt = torch.nested.nested_tensor_from_jagged(values, offsets) 6159 6160 self.assertEqual(ref_fn(values, like_values), res_fn(values, like_values)) 6161 6162 def fn(values, same_size): 6163 return values + same_size 6164 6165 compile_counter = torch._dynamo.testing.CompileCounter() 6166 6167 compiled_fn = torch._dynamo.optimize(compile_counter, nopython=True)(fn) 6168 check_results(fn, compiled_fn, generate_inp(18)) 6169 self.assertEqual(compile_counter.frame_count, 1) 6170 6171 check_results(fn, compiled_fn, generate_inp(19)) 6172 # we'll probably recompile here with dynamic shapes - it's okay if not though. 6173 frame_count_2 = compile_counter.frame_count 6174 self.assertIn(frame_count_2, [1, 2]) 6175 6176 # make sure that by now we've already compiled with dynamic shapes, so additional 6177 # shapes should not trigger additional recompiles. 6178 check_results(fn, compiled_fn, generate_inp(20)) 6179 self.assertEqual(compile_counter.frame_count, frame_count_2) 6180 6181 # Note 1: Math fallback doesn't work with bfloat16 on CUDA 6182 # Note 2: ROCm doesn't support flash attention or mem_efficient attention for NT 6183 @unittest.skipIf( 6184 TEST_WITH_ROCM, 6185 "ROCm doesn't support flash attention or mem_efficient attention for NT", 6186 ) 6187 @dtypes( 6188 *( 6189 [torch.float16, torch.bfloat16, torch.float32] 6190 if SM80OrLater 6191 else [torch.float16, torch.float32] 6192 ) 6193 ) 6194 def test_sdpa(self, device, dtype): 6195 batch_size = 1 6196 emb_dims = 128 6197 n_heads = 8 6198 head_dims = emb_dims // n_heads 6199 6200 sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device) 6201 sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device) 6202 6203 query = torch.nn.Linear( 6204 emb_dims, emb_dims, bias=False, device=device, dtype=dtype 6205 ) 6206 key = torch.nn.Linear( 6207 emb_dims, emb_dims, bias=False, device=device, dtype=dtype 6208 ) 6209 value = torch.nn.Linear( 6210 emb_dims, emb_dims, bias=False, device=device, dtype=dtype 6211 ) 6212 6213 # Simplest case: 1 sentence, no batching 6214 x_d1 = sen1.unsqueeze(0) 6215 x_nt = torch.nested.as_nested_tensor([sen1], layout=torch.jagged) 6216 6217 # See note below for why we detach here. 
6218 q_d1 = ( 6219 query(x_d1) 6220 .view(batch_size, -1, n_heads, head_dims) 6221 .detach() 6222 .requires_grad_(True) 6223 ) 6224 q_d1_t = q_d1.transpose(1, 2) 6225 k_d1 = ( 6226 key(x_d1) 6227 .view(batch_size, -1, n_heads, head_dims) 6228 .detach() 6229 .requires_grad_(True) 6230 ) 6231 k_d1_t = k_d1.transpose(1, 2) 6232 v_d1 = ( 6233 value(x_d1) 6234 .view(batch_size, -1, n_heads, head_dims) 6235 .detach() 6236 .requires_grad_(True) 6237 ) 6238 v_d1_t = v_d1.transpose(1, 2) 6239 6240 q_nt = ( 6241 query(x_nt) 6242 .view(*x_nt.size()[0:2], n_heads, head_dims) 6243 .detach() 6244 .requires_grad_(True) 6245 ) 6246 q_nt_t = q_nt.transpose(1, 2) 6247 k_nt = ( 6248 key(x_nt) 6249 .view(*x_nt.size()[0:2], n_heads, head_dims) 6250 .detach() 6251 .requires_grad_(True) 6252 ) 6253 k_nt_t = k_nt.transpose(1, 2) 6254 v_nt = ( 6255 value(x_nt) 6256 .view(*x_nt.size()[0:2], n_heads, head_dims) 6257 .detach() 6258 .requires_grad_(True) 6259 ) 6260 v_nt_t = v_nt.transpose(1, 2) 6261 6262 # High Precision Math Reference 6263 q_d1_f32 = q_d1.to(torch.float32) 6264 k_d1_f32 = k_d1.to(torch.float32) 6265 v_d1_f32 = v_d1.to(torch.float32) 6266 q_d1_f32_t = q_d1_f32.transpose(1, 2) 6267 k_d1_f32_t = k_d1_f32.transpose(1, 2) 6268 v_d1_f32_t = v_d1_f32.transpose(1, 2) 6269 out_ref = torch.ops.aten._scaled_dot_product_attention_math( 6270 q_d1_f32_t, k_d1_f32_t, v_d1_f32_t 6271 )[0] 6272 grads_ref = torch.autograd.grad(out_ref.sum(), (q_d1_f32, k_d1_f32, v_d1_f32)) 6273 6274 # Low Precision Math Reference 6275 out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( 6276 q_d1_t, k_d1_t, v_d1_t 6277 )[0] 6278 grads_lp_ref = torch.autograd.grad(out_lp_ref.sum(), (q_d1, k_d1, v_d1)) 6279 6280 # Compute tolerances 6281 output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref) 6282 grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(grads_ref[0], grads_lp_ref[0]) 6283 grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(grads_ref[1], grads_lp_ref[1]) 6284 grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(grads_ref[2], grads_lp_ref[2]) 6285 grad_atols = [grad_q_ref_atol, grad_k_ref_atol, grad_v_ref_atol] 6286 grad_rtols = [grad_q_ref_rtol, grad_k_ref_rtol, grad_v_ref_rtol] 6287 6288 attn_d1 = torch.nn.functional.scaled_dot_product_attention( 6289 q_d1_t, k_d1_t, v_d1_t 6290 ).transpose(1, 2) 6291 attn_nt = torch.nn.functional.scaled_dot_product_attention( 6292 q_nt_t, k_nt_t, v_nt_t 6293 ).transpose(1, 2) 6294 6295 self.assertEqual( 6296 attn_d1, 6297 attn_nt.unbind()[0].unsqueeze(0), 6298 atol=output_ref_atol, 6299 rtol=output_ref_rtol, 6300 ) 6301 6302 # Simple case: 2 sentences, no extra params 6303 x_d2 = sen2.unsqueeze(0) 6304 x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged) 6305 6306 # NB: we make sure the leaf tensor we compute gradients for is the view-ed tensor before 6307 # it is transposed. This is because today we cannot backward through view or unbind a 6308 # transposed tensor. 
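        # Concretely, every projection below follows the pattern
        #   leaf = proj(x).view(B, -1, n_heads, head_dims).detach().requires_grad_(True)
        # and only the transposed alias (leaf.transpose(1, 2)) is fed to SDPA, so the
        # leaf we later differentiate w.r.t. is the pre-transpose view.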
6309 q_d2 = ( 6310 query(x_d2) 6311 .view(batch_size, -1, n_heads, head_dims) 6312 .detach() 6313 .requires_grad_(True) 6314 ) 6315 q_d2_t = q_d2.transpose(1, 2) 6316 k_d2 = ( 6317 key(x_d2) 6318 .view(batch_size, -1, n_heads, head_dims) 6319 .detach() 6320 .requires_grad_(True) 6321 ) 6322 k_d2_t = k_d2.transpose(1, 2) 6323 v_d2 = ( 6324 value(x_d2) 6325 .view(batch_size, -1, n_heads, head_dims) 6326 .detach() 6327 .requires_grad_(True) 6328 ) 6329 v_d2_t = v_d2.transpose(1, 2) 6330 6331 q_nt = ( 6332 query(x_nt) 6333 .view(*x_nt.size()[0:2], n_heads, head_dims) 6334 .detach() 6335 .requires_grad_(True) 6336 ) 6337 q_nt_t = q_nt.transpose(1, 2) 6338 k_nt = ( 6339 key(x_nt) 6340 .view(*x_nt.size()[0:2], n_heads, head_dims) 6341 .detach() 6342 .requires_grad_(True) 6343 ) 6344 k_nt_t = k_nt.transpose(1, 2) 6345 v_nt = ( 6346 value(x_nt) 6347 .view(*x_nt.size()[0:2], n_heads, head_dims) 6348 .detach() 6349 .requires_grad_(True) 6350 ) 6351 v_nt_t = v_nt.transpose(1, 2) 6352 6353 attn_d2 = torch.nn.functional.scaled_dot_product_attention( 6354 q_d2_t, k_d2_t, v_d2_t 6355 ).transpose(1, 2) 6356 d1_grads = torch.autograd.grad(attn_d1.sum(), (q_d1, k_d1, v_d1)) 6357 d2_grads = torch.autograd.grad(attn_d2.sum(), (q_d2, k_d2, v_d2)) 6358 6359 # Simple case 3: batch_size = 1, seq_len = 1 6360 q_3 = torch.randn(1, 8, 16, dtype=dtype, device=device) 6361 q_nt_3 = torch.nested.as_nested_tensor([q_3], layout=torch.jagged) 6362 q_nt_3 = q_nt_3.transpose(1, 2) 6363 attn_out = torch.nn.functional.scaled_dot_product_attention( 6364 q_nt_3, q_nt_3, q_nt_3 6365 ) 6366 self.assertEqual(attn_out.shape, q_nt_3.shape) 6367 6368 def check_forward_backward(): 6369 attn_nt = torch.nn.functional.scaled_dot_product_attention( 6370 q_nt_t, k_nt_t, v_nt_t 6371 ).transpose(1, 2) 6372 6373 attn_nts = attn_nt.unbind() 6374 self.assertEqual( 6375 attn_d1, 6376 attn_nts[0].unsqueeze(0), 6377 atol=output_ref_atol, 6378 rtol=output_ref_rtol, 6379 ) 6380 self.assertEqual( 6381 attn_d2, 6382 attn_nts[1].unsqueeze(0), 6383 atol=output_ref_atol, 6384 rtol=output_ref_rtol, 6385 ) 6386 6387 nt_grads = torch.autograd.grad(attn_nt.values().sum(), (q_nt, k_nt, v_nt)) 6388 for nt_grad, d1_grad, d2_grad, grad_atol, grad_rtol in zip( 6389 nt_grads, d1_grads, d2_grads, grad_atols, grad_rtols 6390 ): 6391 unbound_nt_grads = nt_grad.unbind() 6392 self.assertEqual( 6393 d1_grad, 6394 unbound_nt_grads[0].unsqueeze(0), 6395 atol=grad_atol, 6396 rtol=grad_rtol, 6397 ) 6398 self.assertEqual( 6399 d2_grad, 6400 unbound_nt_grads[1].unsqueeze(0), 6401 atol=grad_atol, 6402 rtol=grad_rtol, 6403 ) 6404 6405 # Default 6406 check_forward_backward() 6407 6408 # Test dispatcher works by calling only mem-effn and math (as they are safe for all devices) 6409 with torch.backends.cuda.sdp_kernel( 6410 enable_flash=False, enable_mem_efficient=True, enable_math=True 6411 ): 6412 check_forward_backward() 6413 6414 # Test math fallback 6415 with torch.backends.cuda.sdp_kernel( 6416 enable_flash=False, enable_mem_efficient=False, enable_math=True 6417 ): 6418 # Math fallback doesn't work with bfloat16 on CUDA because 6419 # "group_gemm_dispatch" not implemented for 'BFloat16' 6420 if not (str(device).startswith("cuda") and dtype == torch.bfloat16): 6421 check_forward_backward() 6422 6423 @skipIfTorchDynamo("SDPA test compiles internally") 6424 @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") 6425 @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") 6426 # Guarding with sqrt() doesn't work on ROCm? 
6427 @skipCUDAIfRocm 6428 @onlyCUDA 6429 @dtypes( 6430 *( 6431 [torch.float16, torch.bfloat16, torch.float32] 6432 if SM80OrLater 6433 else [torch.float16, torch.float32] 6434 ) 6435 ) 6436 def test_sdpa_compile(self, device, dtype): 6437 batch_size = 1 6438 emb_dims = 1024 6439 n_heads = 8 6440 head_dims = emb_dims // n_heads 6441 6442 sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device) 6443 sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device) 6444 6445 query = torch.nn.Linear( 6446 emb_dims, emb_dims, bias=False, device=device, dtype=dtype 6447 ) 6448 key = torch.nn.Linear( 6449 emb_dims, emb_dims, bias=False, device=device, dtype=dtype 6450 ) 6451 value = torch.nn.Linear( 6452 emb_dims, emb_dims, bias=False, device=device, dtype=dtype 6453 ) 6454 6455 # Simplest case: 1 sentence, no batching 6456 x_d1 = sen1.unsqueeze(0) 6457 x_d2 = sen2.unsqueeze(0) 6458 x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged) 6459 6460 q_d1 = query(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) 6461 k_d1 = key(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) 6462 v_d1 = value(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) 6463 q_d2 = query(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) 6464 k_d2 = key(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) 6465 v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) 6466 6467 q_nt = ( 6468 query(x_nt) 6469 .view(*x_nt.size()[0:2], n_heads, head_dims) 6470 .detach() 6471 .transpose(1, 2) 6472 ) 6473 k_nt = ( 6474 key(x_nt) 6475 .view(*x_nt.size()[0:2], n_heads, head_dims) 6476 .detach() 6477 .transpose(1, 2) 6478 ) 6479 v_nt = ( 6480 value(x_nt) 6481 .view(*x_nt.size()[0:2], n_heads, head_dims) 6482 .detach() 6483 .transpose(1, 2) 6484 ) 6485 6486 # High Precision Math Reference 6487 q_d1_f32 = q_d1.to(torch.float32) 6488 k_d1_f32 = k_d1.to(torch.float32) 6489 v_d1_f32 = v_d1.to(torch.float32) 6490 out_ref = torch.ops.aten._scaled_dot_product_attention_math( 6491 q_d1_f32, k_d1_f32, v_d1_f32 6492 )[0] 6493 # Low Precision Math Reference 6494 out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( 6495 q_d1, k_d1, v_d1 6496 )[0] 6497 output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref) 6498 6499 attn_d1 = torch.nn.functional.scaled_dot_product_attention( 6500 q_d1, k_d1, v_d1 6501 ).transpose(1, 2) 6502 attn_d2 = torch.nn.functional.scaled_dot_product_attention( 6503 q_d2, k_d2, v_d2 6504 ).transpose(1, 2) 6505 6506 compiled_sdpa = torch.compile(torch.nn.functional.scaled_dot_product_attention) 6507 attn_nt = compiled_sdpa(q_nt, k_nt, v_nt).transpose(1, 2) 6508 6509 attn_nts = attn_nt.unbind() 6510 self.assertEqual( 6511 attn_d1, 6512 attn_nts[0].unsqueeze(0), 6513 atol=output_ref_atol, 6514 rtol=output_ref_rtol, 6515 ) 6516 self.assertEqual( 6517 attn_d2, 6518 attn_nts[1].unsqueeze(0), 6519 atol=output_ref_atol, 6520 rtol=output_ref_rtol, 6521 ) 6522 6523 @dtypes(torch.float32, torch.double, torch.half) 6524 def test_sdpa_with_constant_sequence_length(self, device, dtype): 6525 # shape (B, P*, S, D) 6526 # B: batch size 6527 # P*: ragged number of prompts 6528 # S: (constant) sequence length 6529 # D: embedding size 6530 query = random_nt_from_dims( 6531 [4, None, 8, 10], 6532 device=device, 6533 dtype=dtype, 6534 layout=torch.jagged, 6535 requires_grad=True, 6536 ) 6537 key = random_nt_from_similar(query) 6538 value = random_nt_from_similar(query) 6539 output = F.scaled_dot_product_attention(query, key, 
value) 6540 self.assertTrue(isinstance(output, NestedTensor)) 6541 output.values().sum().backward() 6542 6543 query_dense = query.clone().detach().requires_grad_(True) 6544 # should be equivalent to just running the buffers through 6545 output_dense = F.scaled_dot_product_attention( 6546 query_dense.values(), key.values(), value.values() 6547 ) 6548 torch._dynamo.disable(self.assertEqual)(output._values, output_dense) 6549 output_dense.sum().backward() 6550 torch._dynamo.disable(self.assertEqual)(query.grad, query_dense.grad) 6551 6552 @onlyCUDA 6553 @unittest.skipIf( 6554 not PLATFORM_SUPPORTS_FUSED_ATTENTION, 6555 "Platform doesn't support flash or mem-efficient attention", 6556 ) 6557 @dtypes( 6558 *( 6559 [torch.float16, torch.bfloat16, torch.float32] 6560 if SM80OrLater 6561 else [torch.float16, torch.float32] 6562 ) 6563 ) 6564 def test_sdpa_with_packed_in_proj(self, device, dtype): 6565 # shape (B, *, D) 6566 input_packed = random_nt_from_dims( 6567 [5, None, 10], device=device, dtype=dtype, layout=torch.jagged 6568 ) 6569 6570 # Do input projection. 6571 num_heads = 2 6572 # should be multiple of 4 for efficient kernels (e.g. flash / mem-efficient) 6573 head_dim = 8 6574 qkv_linear = torch.nn.Linear(10, num_heads * head_dim * 3).to( 6575 device=device, dtype=dtype 6576 ) 6577 6578 def in_proj(input_packed, qkv_linear=qkv_linear): 6579 qkv_post_proj = qkv_linear(input_packed) 6580 # these are non-contiguous to trigger _is_safe_to_get_storage_as_tensor() 6581 q, k, v = qkv_post_proj.chunk(3, dim=-1) 6582 q = q.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3) 6583 k = k.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3) 6584 v = v.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3) 6585 return q, k, v 6586 6587 q, k, v = in_proj(input_packed) 6588 output = F.scaled_dot_product_attention(q, k, v, attn_mask=None) 6589 6590 # compare to individually running unbound components through 6591 for in_component, out_component in zip( 6592 input_packed.unbind(), output.transpose(-2, -3).unbind() 6593 ): 6594 q, k, v = in_proj(in_component) 6595 out = F.scaled_dot_product_attention(q, k, v).transpose(-2, -3) 6596 6597 # Low Precision Math Reference 6598 out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(q, k, v)[ 6599 0 6600 ].transpose(-2, -3) 6601 output_ref_atol, output_ref_rtol = get_tolerances( 6602 out, out_lp_ref, fudge_factor=2 6603 ) 6604 6605 self.assertEqual( 6606 out, out_component, atol=output_ref_atol, rtol=output_ref_rtol 6607 ) 6608 6609 @skipIfTorchDynamo("SDPA test compiles internally") 6610 @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") 6611 @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") 6612 # mha_varlen_fwd not supported on ROCm 6613 @skipCUDAIfRocm 6614 @onlyCUDA 6615 @dtypes( 6616 *( 6617 [torch.float16, torch.bfloat16, torch.float32] 6618 if SM80OrLater 6619 else [torch.float16, torch.float32] 6620 ) 6621 ) 6622 def test_sdpa_backwards(self, device, dtype): 6623 values = torch.randn(9, 3, 256, requires_grad=True, device=device, dtype=dtype) 6624 offsets = torch.tensor([0, 1, 3, 5, 9], device=device, dtype=torch.int64) 6625 6626 @torch.compile 6627 def f(values, offsets): 6628 nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4) 6629 nt = nt.transpose(-2, -3) 6630 # purposefully graph break to trigger view replay for subclass view input 6631 torch.tensor(1).item() 6632 output = F.scaled_dot_product_attention(nt, nt, nt).transpose(-2, -3) 6633 return convert_nt_to_jagged(output) 6634 6635 
        output = f(values, offsets)
        output.sum().backward()
        self.assertEqual(values.grad, torch.ones_like(values))

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
        "Platform doesn't support flash or mem-efficient attention",
    )
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    @onlyCUDA
    @skipIfTorchDynamo()
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    def test_sdpa_autocast(self, device):
        def fn_nt(values32, values16, offsets):
            nt32 = convert_jagged_to_nested_tensor(values32, offsets, max_length=16)
            nt16 = convert_jagged_to_nested_tensor(values16, offsets, max_length=16)
            nt32 = nt32.transpose(1, 2)
            nt16 = nt16.transpose(1, 2)
            return F.scaled_dot_product_attention(nt32, nt16, nt32)

        def fn_dense(x32, x16):
            x32 = x32.view(8, 16, 4, 16).transpose(1, 2)
            x16 = x16.view(8, 16, 4, 16).transpose(1, 2)
            return F.scaled_dot_product_attention(x32, x16, x32)

        values32 = torch.randn((8 * 16, 4, 16), device=device, dtype=torch.float32)
        values16 = torch.randn((8 * 16, 4, 16), device=device, dtype=torch.float16)
        offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32)

        x32 = values32.clone()
        x16 = values16.clone()

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            out_dense_eager = fn_dense(x32, x16)
            out_dense_compiled = torch.compile(fn_dense)(x32, x16)
            out_nt_eager = fn_nt(values32, values16, offsets)
            out_nt_compiled = torch.compile(fn_nt)(values32, values16, offsets)

        self.assertEqual(out_dense_eager, out_dense_compiled)
        self.assertEqual(
            out_dense_eager.transpose(1, 2),
            out_nt_eager.values().transpose(0, 1).view(8, 16, 4, 16),
        )
        self.assertEqual(
            out_dense_eager.transpose(1, 2),
            out_nt_compiled.values().transpose(0, 1).view(8, 16, 4, 16),
        )

        def get_values():
            return tuple(
                x.clone().detach().requires_grad_(True) for x in (values32, values16)
            )

        v32_dense_eager, v16_dense_eager = get_values()
        v32_dense_compile, v16_dense_compile = get_values()
        v32_nt_eager, v16_nt_eager = get_values()
        v32_nt_compile, v16_nt_compile = get_values()

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            loss_dense_eager = fn_dense(v32_dense_eager, v16_dense_eager).sum()
            loss_dense_compile = torch.compile(fn_dense)(
                v32_dense_compile, v16_dense_compile
            ).sum()
            loss_nt_eager = fn_nt(v32_nt_eager, v16_nt_eager, offsets).values().sum()
            loss_nt_compile = (
                torch.compile(fn_nt)(v32_nt_compile, v16_nt_compile, offsets)
                .values()
                .sum()
            )

        loss_dense_eager.backward()
        loss_dense_compile.backward()
        loss_nt_eager.backward()
        loss_nt_compile.backward()

        self.assertEqual(v32_dense_eager.grad, v32_dense_compile.grad)
        self.assertEqual(v32_dense_eager.grad, v32_nt_eager.grad)
        self.assertEqual(v32_dense_eager.grad, v32_nt_compile.grad)

        self.assertEqual(v16_dense_eager.grad, v16_dense_compile.grad)
        self.assertEqual(v16_dense_eager.grad, v16_nt_eager.grad)
        self.assertEqual(v16_dense_eager.grad, v16_nt_compile.grad)

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
        "Platform doesn't support flash or mem-efficient attention",
    )
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    @onlyCUDA
    @skipIfTorchDynamo()
    def test_sdpa_flop_counter(self, device):
        from torch.utils.flop_counter import FlopCounterMode

        def get_flops(nt):
            flop_counter = FlopCounterMode(display=False)
            with flop_counter:
                ret = torch.nn.functional.scaled_dot_product_attention(nt, nt, nt)
                ret.values().sum().backward()
            return flop_counter.get_total_flops()

        values = torch.randn(
            (8 * 16, 4, 16), requires_grad=True, device=device, dtype=torch.float16
        )
        offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32)
        nt = convert_jagged_to_nested_tensor(values, offsets, max_length=16)

        values_meta = torch.randn(
            (8 * 16, 4, 16), requires_grad=True, device="meta", dtype=torch.float16
        )
        offsets_meta = torch.arange(0, 8 * 16 + 1, 16, device="meta", dtype=torch.int32)
        nt_meta = convert_jagged_to_nested_tensor(
            values_meta, offsets_meta, max_length=16
        )

        self.assertEqual(get_flops(nt), get_flops(nt_meta))

    @skipIfTorchDynamo()
    def test_nested_tensor_activation_checkpoint(self, device):
        values = torch.randn(
            9, 3, 256, requires_grad=True, device=device, dtype=torch.float32
        )
        lengths = torch.tensor([1, 2, 3, 3], device=device, dtype=torch.int64)
        offsets = F.pad(lengths, pad=(1, 0)).cumsum(dim=0)

        def fn(values, offsets):
            nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4)
            return convert_nt_to_jagged(nt).sum()

        checkpoint(fn, values, offsets, use_reentrant=False).backward()
        self.assertIsNotNone(values.grad)

        context_fn = partial(
            create_selective_checkpoint_contexts, [torch.ops.aten.cumsum.default]
        )

        values.grad = None

        def fn(values, lengths):
            offsets = F.pad(lengths, pad=(1, 0)).cumsum(dim=0)
            nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4)
            return convert_nt_to_jagged(nt).sum()

        checkpoint(
            fn, values, lengths, use_reentrant=False, context_fn=context_fn
        ).backward()
        self.assertIsNotNone(values.grad)

    # Internally-defined NT use cases are lifted to here for maximum test realism.
    # TODO: Remove these when ViewNestedFromBuffer, etc. are deprecated.
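    # For context: the convert_* helpers exercised by the tests below are defined
    # earlier in this file. As a rough, illustrative sketch only (not the actual helper
    # implementations), the non-legacy path views a jagged-layout NT directly over the
    # packed (values, offsets) pair, e.g. roughly:
    #
    #     nt = nested_view_from_values_offsets(values, offsets)  # values: (sum(L_i), D)
    #     jagged = nt.values()  # shares the same underlying buffer
    #
    # with max_length additionally threaded through so SDPA kernels can read a cached
    # max sequence length, while the legacy path goes through ViewNestedFromBuffer.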
    @skipCUDAIfRocm  # not needed
    @skipIfTorchDynamo("compiles internally")
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @parametrize("use_legacy_api", [True, False])
    @skipCPUIf(True, "SDPA Math NT fallback causes failure: see issue #133644")
    def test_dummy_mha_with_nt(self, device, use_legacy_api):
        bs = 3
        d1 = 2
        d2 = 4
        d3 = 16
        n_heads = 2
        d_head = d3 // n_heads
        max_length_1 = 10
        max_length_2 = 20
        torch.manual_seed(0)

        class mha(torch.nn.Module):
            def __init__(self, use_legacy_api) -> None:
                super().__init__()
                torch.manual_seed(0)
                self.linear = torch.nn.Linear(d2, d3, device=device)
                self.use_legacy_api = use_legacy_api

            def forward(self, query, value, offsets):
                value = self.linear(value)
                if self.use_legacy_api:
                    key = convert_jagged_to_nested_tensor_legacy(
                        value, offsets, max_length_1
                    )
                    value = convert_jagged_to_nested_tensor_legacy(
                        value, offsets, max_length_2
                    )
                    query = convert_dense_to_nested_tensor_legacy(query)
                else:
                    key = convert_jagged_to_nested_tensor(value, offsets, max_length_1)
                    value = convert_jagged_to_nested_tensor(
                        value, offsets, max_length_2
                    )
                    query = convert_dense_to_nested_tensor(query)
                q = query.view(bs, -1, n_heads, d_head).transpose(1, 2)
                k = key.view(bs, -1, n_heads, d_head).transpose(1, 2)
                v = value.view(bs, -1, n_heads, d_head).transpose(1, 2)

                with torch.nn.attention.sdpa_kernel(
                    [
                        torch.nn.attention.SDPBackend.FLASH_ATTENTION,
                        torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
                    ]
                ):
                    attn_output = torch.nn.functional.scaled_dot_product_attention(
                        q,
                        k,
                        v,
                        attn_mask=None,
                        dropout_p=0.0,
                        is_causal=False,
                    )
                attn_output = attn_output.transpose(1, 2)
                if self.use_legacy_api:
                    attn_output = convert_nt_to_jagged_legacy(attn_output)
                else:
                    attn_output = convert_nt_to_jagged(attn_output)
                return attn_output, key._max_seqlen, value._max_seqlen

        query = torch.rand(bs, d1, d3, device=device)
        value = torch.rand(30, d2, requires_grad=True, device=device)
        # total_length must be > max_length, otherwise the flash_attn backward will fail
        offsets = torch.tensor([0, 2, 3, 30], device=device)

        m = mha(use_legacy_api)
        symbolic_traced: torch.fx.GraphModule = torch.fx.symbolic_trace(m)
        m = torch.compile(symbolic_traced)
        attn_output, cached_key_max_seqlen, cached_value_max_seqlen = m(
            query, value, offsets
        )
        loss = attn_output.sum()
        # Check that NT can be fx traced and torch.compiled, and that backward works
        loss.backward()

        # Check that value.requires_grad is not lost after tracing and compiling
        value_grad = value.grad  # save for comparison later
        self.assertIsNotNone(value_grad)
        # check that max_seqlen is cached properly
        self.assertEqual(cached_key_max_seqlen, max_length_1)
        self.assertEqual(cached_value_max_seqlen, max_length_2)

        # check that the output is numerically equivalent to eager mode
        m_eager = mha(use_legacy_api)

        value.grad = None
        attn_output_eager, _, _ = m_eager(query, value, offsets)
        attn_output_eager.sum().backward()
        self.assertTrue(torch.allclose(attn_output_eager, attn_output))
        self.assertTrue(torch.allclose(value_grad, value.grad))

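    # The next test relies on the (assumed here, and verified by the test itself) NJT
    # behavior that apply_() mutates the underlying values buffer in place and does not
    # participate in autograd; for an elementwise f it conceptually acts like:
    #
    #     nt._values.copy_(f(nt._values))  # illustrative only, not the real kernel
    #
    # and is CPU-only, raising a TypeError on other devices.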
    @dtypes(torch.float32)
    def test_apply_(self, device, dtype):
        nt = random_nt_from_dims(
            [5, None, 10],
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=True,
        )

        def f(x):
            return x * 2

        if device != "cpu":
            with self.assertRaisesRegex(
                TypeError, "apply_ is only implemented on CPU tensors"
            ):
                nt.apply_(f)
            return

        before = nt._values.clone().detach()

        nt.apply_(f)
        expected = f(before)
        self.assertEqual(expected, nt._values)
        # apply_ should swap values in-place without appending to the autograd graph
        self.assertIsNone(nt.grad)
        self.assertIsNone(nt._values.grad_fn)

    @dtypes(torch.float64, torch.float32, torch.half)
    def test_jagged_padded_dense_conversion_kernels(self, device, dtype):
        values = torch.randn(10, 5, device=device, dtype=dtype)
        offsets = torch.tensor([0, 1, 3, 8, 10], device=device, dtype=torch.int64)
        max_length = offsets.diff().max().item()
        padding_value = 1.3

        # convert jagged -> padded dense
        padded = torch.ops.aten._jagged_to_padded_dense_forward(
            values, [offsets], [max_length], padding_value
        )

        batch_size = offsets.shape[0] - 1
        expected_padded_shape = (batch_size, max_length, values.shape[-1])
        self.assertEqual(padded.shape, expected_padded_shape)

        # convert padded dense -> jagged
        total_L = values.shape[0]
        output_jagged = torch.ops.aten._padded_dense_to_jagged_forward(
            padded, [offsets], total_L
        )

        # should be equivalent to the original values
        self.assertEqual(values, output_jagged)

        # success case: truncate to max length as needed
        trunc_max_length = max_length - 1
        trunc_padded = torch.ops.aten._jagged_to_padded_dense_forward(
            values, [offsets], [trunc_max_length], padding_value
        )
        self.assertEqual(padded[:, :trunc_max_length, :], trunc_padded)

        # specific to CPU impls
        if device == "cpu":
            # error case: multiple offsets on CPU; the CPU kernels currently support
            # only a single jagged dim
            with self.assertRaisesRegex(
                RuntimeError, "only a single jagged dim is supported"
            ):
                torch.ops.aten._jagged_to_padded_dense_forward(
                    values, [offsets, offsets], [max_length, max_length], padding_value
                )

            with self.assertRaisesRegex(
                RuntimeError, "only a single jagged dim is supported"
            ):
                torch.ops.aten._padded_dense_to_jagged_forward(
                    padded, [offsets, offsets], total_L
                )

            # error case: > 1D offsets
            offsets2d = offsets.unsqueeze(-1)
            with self.assertRaisesRegex(RuntimeError, "expected 1D offsets"):
                torch.ops.aten._jagged_to_padded_dense_forward(
                    values, [offsets2d], [max_length], padding_value
                )

            with self.assertRaisesRegex(RuntimeError, "expected 1D offsets"):
                torch.ops.aten._padded_dense_to_jagged_forward(
                    padded, [offsets2d], total_L
                )

            # error case: final offset != total_L
            offsets_wrong = offsets.clone().detach()
            offsets_wrong[-1] = total_L + 1
            with self.assertRaisesRegex(
                RuntimeError, "final offset should match total_L value"
            ):
                torch.ops.aten._padded_dense_to_jagged_forward(
                    padded, [offsets_wrong], total_L
                )

            # error case: 1D padded input
            padded_wrong = padded.flatten().clone().detach()
            with self.assertRaisesRegex(RuntimeError, "expected padded dim >= 2"):
                torch.ops.aten._padded_dense_to_jagged_forward(
                    padded_wrong, [offsets], total_L
                )

            # error case: batch item has length > max length
            # max_length is 5 above; 7 here
            offsets_wrong = torch.tensor(
                [0, 1, 8, 9, 10], device=device, dtype=torch.int64
            )
            with self.assertRaisesRegex(RuntimeError, "found batch item of length"):
                torch.ops.aten._padded_dense_to_jagged_forward(
                    padded, [offsets_wrong], total_L
                )

    @dtypes(torch.float32)
    @skipIfTorchDynamo("Test compiles internally")
    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    def test_compile_preserves_metadata_cache(self, device, dtype):
        # shape (B, *, D)
        nt = random_nt_from_dims(
            [4, None, 3, 16],
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=True,
        )

        # expect min / max seqlen to be stored here
        cache = dict(nt._metadata_cache)

        @torch.compile
        def f(nt):
            q = nt.transpose(-3, -2)
            output = F.scaled_dot_product_attention(q, q, q).transpose(-3, -2)
            return output

        output = f(nt)
        output.backward(torch.ones_like(output))
        self.assertEqual(output._metadata_cache, cache)

    @dtypes(torch.float32)
    @skipIfTorchDynamo("Test compiles internally")
    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    def test_compile_with_dynamic_max_seq_len(self, device, dtype):
        # shape (B, *, D)
        # max seq len: 18
        nt = torch.nested.nested_tensor(
            [
                torch.randn(2, 5),
                torch.randn(3, 5),
                torch.randn(18, 5),
            ],
            layout=torch.jagged,
        )

        # max seq len: 19
        nt2 = torch.nested.nested_tensor(
            [
                torch.randn(2, 5),
                torch.randn(3, 5),
                torch.randn(19, 5),
            ],
            layout=torch.jagged,
        )

        def f(nt):
            # TODO: Replace with public API when we can use @properties
            return torch.ones_like(nt) * nt._get_max_seqlen()

        for dynamic in [False, True, None]:
            self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))

    @dtypes(torch.float32)
    @skipIfTorchDynamo("Test compiles internally")
    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    def test_compile_with_dynamic_min_seq_len(self, device, dtype):
        # shape (B, *, D)
        # min seq len: 7
        nt = torch.nested.nested_tensor(
            [
                torch.randn(7, 5),
                torch.randn(8, 5),
                torch.randn(9, 5),
            ],
            layout=torch.jagged,
        )

        # min seq len: 8
        nt2 = torch.nested.nested_tensor(
            [
                torch.randn(8, 5),
                torch.randn(9, 5),
                torch.randn(10, 5),
            ],
            layout=torch.jagged,
        )

        def f(nt):
            # TODO: Replace with public API when we can use @properties
            return torch.ones_like(nt) * nt._get_min_seqlen()

        for dynamic in [False, True, None]:
            self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))

    @dtypes(torch.float32)
    @skipIfTorchDynamo("Test compiles internally")
    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    def test_compile_with_propagated_dynamic_max_seq_len(self, device, dtype):
        # shape (B, *, D)
        # max seq len: 18
        nt = torch.nested.nested_tensor(
            [
                torch.randn(2, 5),
                torch.randn(3, 5),
                torch.randn(18, 5),
            ],
            layout=torch.jagged,
        )

        # max seq len: 19
        nt2 = torch.nested.nested_tensor(
            [
                torch.randn(2, 5),
                torch.randn(3, 5),
                torch.randn(19, 5),
            ],
            layout=torch.jagged,
        )

        def f(nt):
            nt2 = nt.sin() + 1
            # TODO: Replace with public API when we can use @properties
            return torch.ones_like(nt2) * nt2._get_max_seqlen()

        ref = f(nt)
        output = torch.compile(f, fullgraph=True, dynamic=False)(nt)
        self.assertEqual(ref, output)

        for dynamic in [False, True, None]:
            self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))

    @dtypes(torch.float32, torch.double, torch.half)
    def test_unbind_backward(self, device, dtype):
        nt = torch.nested.nested_tensor(
            [
                torch.randn(2, 4, device=device),
                torch.randn(5, 4, device=device),
                torch.randn(3, 4, device=device),
            ],
            layout=torch.jagged,
            requires_grad=True,
        )

        a, b, c = nt.unbind()
        b.sum().backward()

        @torch._dynamo.disable
        def check(nt):
            expected_grad = torch.zeros_like(nt)
            expected_grad.unbind()[1].add_(1.0)
            self.assertEqual(nt.grad, expected_grad)

        check(nt)


FORWARD_FAILURES = {
    # === BEGIN NotImplementedError SECTION ===
    # unary
    "nn.functional.celu",
    "nn.functional.elu",
    "nn.functional.hardshrink",
    "nn.functional.hardsigmoid",
    "nn.functional.hardtanh",
    "nn.functional.logsigmoid",
    "nn.functional.mish",
    "nn.functional.relu6",
    "nn.functional.rrelu",
    "nn.functional.selu",
    "nn.functional.softplus",
    "nn.functional.softshrink",
    "nn.functional.threshold",
    "rad2deg",
    # binary
    "__rsub__",
    "complex",
    "floor_divide",
    "polar",
    "rsub",
    # reduction
    "all",
    "amax",
    "amin",
    "any",
    "argmax",
    "argmin",
    "count_nonzero",
    "linalg.vector_norm",
    "nansum",
    "std",
    "std.unbiased",
    "var",
    "var.unbiased",
    # === BEGIN UNSUPPORTED SECTION ===
    # RuntimeError: mean(): not supported for NestedTensor on dim=1
    "mean",
    # ValueError: expects strided tensor (got torch.jagged tensor)
    "masked.amax",
    "masked.amin",
    "masked.argmax",
    "masked.argmin",
    "masked.logsumexp",
    "masked.mean",
    "masked.norm",
    "masked.prod",
    "masked.std",
    "masked.sum",
    "masked.var",
    # === BEGIN BUG SECTION ===
    # Returns a tuple of Tensors so it doesn't work with NJT's unary pointwise logic
    "frexp",
    # Need to adjust sample input func to pass the right thing
    "nn.functional.prelu",
    # TypeError: fill() received an invalid combination of arguments
    # got (NestedTensor), but expected one of:
    # * (Tensor input, Tensor value)
    # * (Tensor input, Number value)
    "fill",
    # RuntimeError: unsupported tensor layout: Jagged
    "jiterator_binary",
    "jiterator_binary_return_by_ref",
    "jiterator_unary",
    # Bug found: sum() with keepdim=True returns invalid shape
    "sum",
    # RuntimeError: prod(): keepdim=True must be set for NestedTensor
    "prod",
    # RuntimeError: "jagged_to_padded_dense" not implemented for 'Bool'
    "nanmean",
}

BACKWARD_FAILURES = {
    *FORWARD_FAILURES,
    # TODO: categorize these
    "__rpow__",
    "atanh",
    "cdouble",
    "cfloat",
    "chalf",
    "clamp_max",
    "clamp_min",
    "copysign",
    "float_power",
    "max.binary",
    "maximum",
    "min.binary",
    "minimum",
    "pow",
    "sgn",
    "sinc",
    "special.i1",
    "special.i1e",
    # clone() on a "non-contiguous with holes" NJT allocates new offsets -> new nested int
    # RuntimeError: Function CloneBackward0 returned an invalid gradient at index 0 -
    # got [3, j29, 5] but expected shape compatible with [3, j28, 5]
    "clone",
    # Calls into torch.ops.aten.size directly
    "masked_select",
}

COMPILE_FORWARD_FAILURES = {
    *FORWARD_FAILURES,
    # clone() on non-contiguous NJTs with holes currently uses unbind(), leading to a
    # data-dependent error in torch.compile
    "clone",
}

COMPARE_TENSOR_COMPONENT_EQUALITY = {
    # masked_select is expected to output a different shape
    "masked_select",
}


def withXFails(failure_list):
    return decorateIf(
        unittest.expectedFailure,
        lambda params: params["op"].full_name in failure_list,
    )


# OpInfo-based NJT tests. These tests utilize an NJT-specific op_db generated from the
# standard op_db. Note that certain tradeoffs were made wrt coverage vs. time spent
# running tests:
# * All tests run with dtype=torch.float32 only
class TestNestedTensorOpInfo(NestedTensorTestCase):
    # TODO: move this
    def _gen_grad_outputs(self, out_val):
        if isinstance(out_val, (list, tuple)):
            return tuple(torch.ones_like(c) for c in out_val)
        else:
            return (torch.ones_like(out_val),)

    @withXFails(FORWARD_FAILURES)
    @ops([op for op in njt_op_db if op.supports_njt], allowed_dtypes=(torch.float32,))
    def test_forward(self, device, dtype, op):
        for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=False):
            # compare to reference, but expect different nested int
            out = op.op(sample.input, *sample.args, **sample.kwargs)
            out_ref = op.ref(op, sample)
            self.assertEqualIgnoringNestedInts(out, out_ref)

    @withXFails(BACKWARD_FAILURES)
    @ops(
        [op for op in njt_op_db if op.supports_njt and op.supports_autograd],
        allowed_dtypes=(torch.float32,),
    )
    def test_backward(self, device, dtype, op):
        for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=True):
            # compare to reference, but expect different nested int
            out = op.op(sample.input, *sample.args, **sample.kwargs)
            out_ref = op.ref(op, sample)
            self.assertEqualIgnoringNestedInts(out, out_ref)

            inps, _ = tree_flatten((sample.input, sample.args, sample.kwargs))
            g_inps = [
                inp
                for inp in inps
                if isinstance(inp, torch.Tensor) and inp.requires_grad
            ]
            if len(g_inps) > 0:
                grads = torch.autograd.grad(
                    out, inputs=g_inps, grad_outputs=self._gen_grad_outputs(out)
                )

                grads_ref = torch.autograd.grad(
                    out_ref,
                    inputs=g_inps,
                    grad_outputs=self._gen_grad_outputs(out_ref),
                )

                self.assertEqual(grads, grads_ref)

    @withXFails(COMPILE_FORWARD_FAILURES)
    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    @ops([op for op in njt_op_db if op.supports_njt], allowed_dtypes=(torch.float32,))
    def test_compile_forward(self, device, dtype, op):
        for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=False):
            torch.compiler.reset()

            op_fn = op.op

            def f(*args, **kwargs):
                return op_fn(*args, **kwargs)

            compiled_f = torch.compile(
                f, fullgraph=True, backend="aot_eager_decomp_partition"
            )

            out_ref = f(sample.input, *sample.args, **sample.kwargs)
            out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)

            if op.full_name in COMPARE_TENSOR_COMPONENT_EQUALITY:
                self.assertEqualIgnoringNestedInts(out_compile, out_ref)
            else:
                self.assertEqual(out_compile, out_ref)

    @withXFails(BACKWARD_FAILURES)
    @ops(
        [op for op in njt_op_db if op.supports_njt and op.supports_autograd],
        allowed_dtypes=(torch.float32,),
    )
    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    def test_compile_backward(self, device, dtype, op):
        for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=True):
            torch.compiler.reset()

            op_fn = op.op

            def f(*args, **kwargs):
                return op_fn(*args, **kwargs)

            compiled_f = torch.compile(
                f, fullgraph=True, backend="aot_eager_decomp_partition"
            )

            out_ref = f(sample.input, *sample.args, **sample.kwargs)
            out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)

            self.assertEqual(out_compile, out_ref)

            inps, _ = tree_flatten((sample.input, sample.args, sample.kwargs))
            g_inps = [
                inp
                for inp in inps
                if isinstance(inp, torch.Tensor) and inp.requires_grad
            ]
            if len(g_inps) > 0:
                grads_compile = torch.autograd.grad(
                    out_compile,
                    inputs=g_inps,
                    grad_outputs=self._gen_grad_outputs(out_compile),
                )

                grads_ref = torch.autograd.grad(
                    out_ref, inputs=g_inps, grad_outputs=self._gen_grad_outputs(out_ref)
                )

                self.assertEqual(grads_compile, grads_ref)


instantiate_parametrized_tests(TestNestedTensor)
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
instantiate_device_type_tests(TestNestedTensorAutograd, globals())
instantiate_device_type_tests(TestNestedTensorSubclass, globals())
instantiate_device_type_tests(TestNestedTensorOpInfo, globals())

if __name__ == "__main__":
    run_tests()