1# Owner(s): ["module: tests"] 2 3import operator 4import random 5import unittest 6import warnings 7from functools import reduce 8 9import numpy as np 10 11import torch 12from torch import tensor 13from torch.testing import make_tensor 14from torch.testing._internal.common_device_type import ( 15 dtypes, 16 dtypesIfCPU, 17 dtypesIfCUDA, 18 instantiate_device_type_tests, 19 onlyCUDA, 20 onlyNativeDeviceTypes, 21 skipXLA, 22) 23from torch.testing._internal.common_utils import ( 24 DeterministicGuard, 25 run_tests, 26 serialTest, 27 skipIfTorchDynamo, 28 TEST_CUDA, 29 TestCase, 30 xfailIfTorchDynamo, 31) 32 33 34class TestIndexing(TestCase): 35 def test_index(self, device): 36 def consec(size, start=1): 37 sequence = torch.ones(torch.tensor(size).prod(0)).cumsum(0) 38 sequence.add_(start - 1) 39 return sequence.view(*size) 40 41 reference = consec((3, 3, 3)).to(device) 42 43 # empty tensor indexing 44 self.assertEqual( 45 reference[torch.LongTensor().to(device)], reference.new(0, 3, 3) 46 ) 47 48 self.assertEqual(reference[0], consec((3, 3)), atol=0, rtol=0) 49 self.assertEqual(reference[1], consec((3, 3), 10), atol=0, rtol=0) 50 self.assertEqual(reference[2], consec((3, 3), 19), atol=0, rtol=0) 51 self.assertEqual(reference[0, 1], consec((3,), 4), atol=0, rtol=0) 52 self.assertEqual(reference[0:2], consec((2, 3, 3)), atol=0, rtol=0) 53 self.assertEqual(reference[2, 2, 2], 27, atol=0, rtol=0) 54 self.assertEqual(reference[:], consec((3, 3, 3)), atol=0, rtol=0) 55 56 # indexing with Ellipsis 57 self.assertEqual( 58 reference[..., 2], 59 torch.tensor([[3.0, 6.0, 9.0], [12.0, 15.0, 18.0], [21.0, 24.0, 27.0]]), 60 atol=0, 61 rtol=0, 62 ) 63 self.assertEqual( 64 reference[0, ..., 2], torch.tensor([3.0, 6.0, 9.0]), atol=0, rtol=0 65 ) 66 self.assertEqual(reference[..., 2], reference[:, :, 2], atol=0, rtol=0) 67 self.assertEqual(reference[0, ..., 2], reference[0, :, 2], atol=0, rtol=0) 68 self.assertEqual(reference[0, 2, ...], reference[0, 2], atol=0, rtol=0) 69 
self.assertEqual(reference[..., 2, 2, 2], 27, atol=0, rtol=0) 70 self.assertEqual(reference[2, ..., 2, 2], 27, atol=0, rtol=0) 71 self.assertEqual(reference[2, 2, ..., 2], 27, atol=0, rtol=0) 72 self.assertEqual(reference[2, 2, 2, ...], 27, atol=0, rtol=0) 73 self.assertEqual(reference[...], reference, atol=0, rtol=0) 74 75 reference_5d = consec((3, 3, 3, 3, 3)).to(device) 76 self.assertEqual( 77 reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0], atol=0, rtol=0 78 ) 79 self.assertEqual( 80 reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0], atol=0, rtol=0 81 ) 82 self.assertEqual( 83 reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1], atol=0, rtol=0 84 ) 85 self.assertEqual(reference_5d[...], reference_5d, atol=0, rtol=0) 86 87 # LongTensor indexing 88 reference = consec((5, 5, 5)).to(device) 89 idx = torch.LongTensor([2, 4]).to(device) 90 self.assertEqual(reference[idx], torch.stack([reference[2], reference[4]])) 91 # TODO: enable one indexing is implemented like in numpy 92 # self.assertEqual(reference[2, idx], torch.stack([reference[2, 2], reference[2, 4]])) 93 # self.assertEqual(reference[3, idx, 1], torch.stack([reference[3, 2], reference[3, 4]])[:, 1]) 94 95 # None indexing 96 self.assertEqual(reference[2, None], reference[2].unsqueeze(0)) 97 self.assertEqual( 98 reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0) 99 ) 100 self.assertEqual(reference[2:4, None], reference[2:4].unsqueeze(1)) 101 self.assertEqual( 102 reference[None, 2, None, None], 103 reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0), 104 ) 105 self.assertEqual( 106 reference[None, 2:5, None, None], 107 reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2), 108 ) 109 110 # indexing 0-length slice 111 self.assertEqual(torch.empty(0, 5, 5), reference[slice(0)]) 112 self.assertEqual(torch.empty(0, 5), reference[slice(0), 2]) 113 self.assertEqual(torch.empty(0, 5), reference[2, slice(0)]) 114 self.assertEqual(torch.tensor([]), reference[2, 1:1, 2]) 115 116 # 
indexing with step 117 reference = consec((10, 10, 10)).to(device) 118 self.assertEqual(reference[1:5:2], torch.stack([reference[1], reference[3]], 0)) 119 self.assertEqual( 120 reference[1:6:2], torch.stack([reference[1], reference[3], reference[5]], 0) 121 ) 122 self.assertEqual(reference[1:9:4], torch.stack([reference[1], reference[5]], 0)) 123 self.assertEqual( 124 reference[2:4, 1:5:2], 125 torch.stack([reference[2:4, 1], reference[2:4, 3]], 1), 126 ) 127 self.assertEqual( 128 reference[3, 1:6:2], 129 torch.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0), 130 ) 131 self.assertEqual( 132 reference[None, 2, 1:9:4], 133 torch.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0), 134 ) 135 self.assertEqual( 136 reference[:, 2, 1:6:2], 137 torch.stack( 138 [reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1 139 ), 140 ) 141 142 lst = [list(range(i, i + 10)) for i in range(0, 100, 10)] 143 tensor = torch.DoubleTensor(lst).to(device) 144 for _i in range(100): 145 idx1_start = random.randrange(10) 146 idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1) 147 idx1_step = random.randrange(1, 8) 148 idx1 = slice(idx1_start, idx1_end, idx1_step) 149 if random.randrange(2) == 0: 150 idx2_start = random.randrange(10) 151 idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1) 152 idx2_step = random.randrange(1, 8) 153 idx2 = slice(idx2_start, idx2_end, idx2_step) 154 lst_indexed = [l[idx2] for l in lst[idx1]] 155 tensor_indexed = tensor[idx1, idx2] 156 else: 157 lst_indexed = lst[idx1] 158 tensor_indexed = tensor[idx1] 159 self.assertEqual(torch.DoubleTensor(lst_indexed), tensor_indexed) 160 161 self.assertRaises(ValueError, lambda: reference[1:9:0]) 162 self.assertRaises(ValueError, lambda: reference[1:9:-1]) 163 164 self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1]) 165 self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1]) 166 self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3]) 167 
168 self.assertRaises(IndexError, lambda: reference[0.0]) 169 self.assertRaises(TypeError, lambda: reference[0.0:2.0]) 170 self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0]) 171 self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0]) 172 self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0]) 173 self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0]) 174 175 def delitem(): 176 del reference[0] 177 178 self.assertRaises(TypeError, delitem) 179 180 @onlyNativeDeviceTypes 181 @dtypes(torch.half, torch.double) 182 def test_advancedindex(self, device, dtype): 183 # Tests for Integer Array Indexing, Part I - Purely integer array 184 # indexing 185 186 def consec(size, start=1): 187 # Creates the sequence in float since CPU half doesn't support the 188 # needed operations. Converts to dtype before returning. 189 numel = reduce(operator.mul, size, 1) 190 sequence = torch.ones(numel, dtype=torch.float, device=device).cumsum(0) 191 sequence.add_(start - 1) 192 return sequence.view(*size).to(dtype=dtype) 193 194 # pick a random valid indexer type 195 def ri(indices): 196 choice = random.randint(0, 2) 197 if choice == 0: 198 return torch.LongTensor(indices).to(device) 199 elif choice == 1: 200 return list(indices) 201 else: 202 return tuple(indices) 203 204 def validate_indexing(x): 205 self.assertEqual(x[[0]], consec((1,))) 206 self.assertEqual(x[ri([0]),], consec((1,))) 207 self.assertEqual(x[ri([3]),], consec((1,), 4)) 208 self.assertEqual(x[[2, 3, 4]], consec((3,), 3)) 209 self.assertEqual(x[ri([2, 3, 4]),], consec((3,), 3)) 210 self.assertEqual( 211 x[ri([0, 2, 4]),], torch.tensor([1, 3, 5], dtype=dtype, device=device) 212 ) 213 214 def validate_setting(x): 215 x[[0]] = -2 216 self.assertEqual(x[[0]], torch.tensor([-2], dtype=dtype, device=device)) 217 x[[0]] = -1 218 self.assertEqual( 219 x[ri([0]),], torch.tensor([-1], dtype=dtype, device=device) 220 ) 221 x[[2, 3, 4]] = 4 222 self.assertEqual( 223 x[[2, 3, 4]], torch.tensor([4, 
4, 4], dtype=dtype, device=device) 224 ) 225 x[ri([2, 3, 4]),] = 3 226 self.assertEqual( 227 x[ri([2, 3, 4]),], torch.tensor([3, 3, 3], dtype=dtype, device=device) 228 ) 229 x[ri([0, 2, 4]),] = torch.tensor([5, 4, 3], dtype=dtype, device=device) 230 self.assertEqual( 231 x[ri([0, 2, 4]),], torch.tensor([5, 4, 3], dtype=dtype, device=device) 232 ) 233 234 # Only validates indexing and setting for halfs 235 if dtype == torch.half: 236 reference = consec((10,)) 237 validate_indexing(reference) 238 validate_setting(reference) 239 return 240 241 # Case 1: Purely Integer Array Indexing 242 reference = consec((10,)) 243 validate_indexing(reference) 244 245 # setting values 246 validate_setting(reference) 247 248 # Tensor with stride != 1 249 # strided is [1, 3, 5, 7] 250 reference = consec((10,)) 251 strided = torch.tensor((), dtype=dtype, device=device) 252 strided.set_( 253 reference.storage(), storage_offset=0, size=torch.Size([4]), stride=[2] 254 ) 255 256 self.assertEqual(strided[[0]], torch.tensor([1], dtype=dtype, device=device)) 257 self.assertEqual( 258 strided[ri([0]),], torch.tensor([1], dtype=dtype, device=device) 259 ) 260 self.assertEqual( 261 strided[ri([3]),], torch.tensor([7], dtype=dtype, device=device) 262 ) 263 self.assertEqual( 264 strided[[1, 2]], torch.tensor([3, 5], dtype=dtype, device=device) 265 ) 266 self.assertEqual( 267 strided[ri([1, 2]),], torch.tensor([3, 5], dtype=dtype, device=device) 268 ) 269 self.assertEqual( 270 strided[ri([[2, 1], [0, 3]]),], 271 torch.tensor([[5, 3], [1, 7]], dtype=dtype, device=device), 272 ) 273 274 # stride is [4, 8] 275 strided = torch.tensor((), dtype=dtype, device=device) 276 strided.set_( 277 reference.storage(), storage_offset=4, size=torch.Size([2]), stride=[4] 278 ) 279 self.assertEqual(strided[[0]], torch.tensor([5], dtype=dtype, device=device)) 280 self.assertEqual( 281 strided[ri([0]),], torch.tensor([5], dtype=dtype, device=device) 282 ) 283 self.assertEqual( 284 strided[ri([1]),], torch.tensor([9], 
dtype=dtype, device=device) 285 ) 286 self.assertEqual( 287 strided[[0, 1]], torch.tensor([5, 9], dtype=dtype, device=device) 288 ) 289 self.assertEqual( 290 strided[ri([0, 1]),], torch.tensor([5, 9], dtype=dtype, device=device) 291 ) 292 self.assertEqual( 293 strided[ri([[0, 1], [1, 0]]),], 294 torch.tensor([[5, 9], [9, 5]], dtype=dtype, device=device), 295 ) 296 297 # reference is 1 2 298 # 3 4 299 # 5 6 300 reference = consec((3, 2)) 301 self.assertEqual( 302 reference[ri([0, 1, 2]), ri([0])], 303 torch.tensor([1, 3, 5], dtype=dtype, device=device), 304 ) 305 self.assertEqual( 306 reference[ri([0, 1, 2]), ri([1])], 307 torch.tensor([2, 4, 6], dtype=dtype, device=device), 308 ) 309 self.assertEqual(reference[ri([0]), ri([0])], consec((1,))) 310 self.assertEqual(reference[ri([2]), ri([1])], consec((1,), 6)) 311 self.assertEqual( 312 reference[[ri([0, 0]), ri([0, 1])]], 313 torch.tensor([1, 2], dtype=dtype, device=device), 314 ) 315 self.assertEqual( 316 reference[[ri([0, 1, 1, 0, 2]), ri([1])]], 317 torch.tensor([2, 4, 4, 2, 6], dtype=dtype, device=device), 318 ) 319 self.assertEqual( 320 reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], 321 torch.tensor([1, 2, 3, 3], dtype=dtype, device=device), 322 ) 323 324 rows = ri([[0, 0], [1, 2]]) 325 columns = ([0],) 326 self.assertEqual( 327 reference[rows, columns], 328 torch.tensor([[1, 1], [3, 5]], dtype=dtype, device=device), 329 ) 330 331 rows = ri([[0, 0], [1, 2]]) 332 columns = ri([1, 0]) 333 self.assertEqual( 334 reference[rows, columns], 335 torch.tensor([[2, 1], [4, 5]], dtype=dtype, device=device), 336 ) 337 rows = ri([[0, 0], [1, 2]]) 338 columns = ri([[0, 1], [1, 0]]) 339 self.assertEqual( 340 reference[rows, columns], 341 torch.tensor([[1, 2], [4, 5]], dtype=dtype, device=device), 342 ) 343 344 # setting values 345 reference[ri([0]), ri([1])] = -1 346 self.assertEqual( 347 reference[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device) 348 ) 349 reference[ri([0, 1, 2]), ri([0])] = torch.tensor( 
350 [-1, 2, -4], dtype=dtype, device=device 351 ) 352 self.assertEqual( 353 reference[ri([0, 1, 2]), ri([0])], 354 torch.tensor([-1, 2, -4], dtype=dtype, device=device), 355 ) 356 reference[rows, columns] = torch.tensor( 357 [[4, 6], [2, 3]], dtype=dtype, device=device 358 ) 359 self.assertEqual( 360 reference[rows, columns], 361 torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device), 362 ) 363 364 # Verify still works with Transposed (i.e. non-contiguous) Tensors 365 366 reference = torch.tensor( 367 [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype, device=device 368 ).t_() 369 370 # Transposed: [[0, 4, 8], 371 # [1, 5, 9], 372 # [2, 6, 10], 373 # [3, 7, 11]] 374 375 self.assertEqual( 376 reference[ri([0, 1, 2]), ri([0])], 377 torch.tensor([0, 1, 2], dtype=dtype, device=device), 378 ) 379 self.assertEqual( 380 reference[ri([0, 1, 2]), ri([1])], 381 torch.tensor([4, 5, 6], dtype=dtype, device=device), 382 ) 383 self.assertEqual( 384 reference[ri([0]), ri([0])], torch.tensor([0], dtype=dtype, device=device) 385 ) 386 self.assertEqual( 387 reference[ri([2]), ri([1])], torch.tensor([6], dtype=dtype, device=device) 388 ) 389 self.assertEqual( 390 reference[[ri([0, 0]), ri([0, 1])]], 391 torch.tensor([0, 4], dtype=dtype, device=device), 392 ) 393 self.assertEqual( 394 reference[[ri([0, 1, 1, 0, 3]), ri([1])]], 395 torch.tensor([4, 5, 5, 4, 7], dtype=dtype, device=device), 396 ) 397 self.assertEqual( 398 reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], 399 torch.tensor([0, 4, 1, 1], dtype=dtype, device=device), 400 ) 401 402 rows = ri([[0, 0], [1, 2]]) 403 columns = ([0],) 404 self.assertEqual( 405 reference[rows, columns], 406 torch.tensor([[0, 0], [1, 2]], dtype=dtype, device=device), 407 ) 408 409 rows = ri([[0, 0], [1, 2]]) 410 columns = ri([1, 0]) 411 self.assertEqual( 412 reference[rows, columns], 413 torch.tensor([[4, 0], [5, 2]], dtype=dtype, device=device), 414 ) 415 rows = ri([[0, 0], [1, 3]]) 416 columns = ri([[0, 1], [1, 2]]) 417 
self.assertEqual( 418 reference[rows, columns], 419 torch.tensor([[0, 4], [5, 11]], dtype=dtype, device=device), 420 ) 421 422 # setting values 423 reference[ri([0]), ri([1])] = -1 424 self.assertEqual( 425 reference[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device) 426 ) 427 reference[ri([0, 1, 2]), ri([0])] = torch.tensor( 428 [-1, 2, -4], dtype=dtype, device=device 429 ) 430 self.assertEqual( 431 reference[ri([0, 1, 2]), ri([0])], 432 torch.tensor([-1, 2, -4], dtype=dtype, device=device), 433 ) 434 reference[rows, columns] = torch.tensor( 435 [[4, 6], [2, 3]], dtype=dtype, device=device 436 ) 437 self.assertEqual( 438 reference[rows, columns], 439 torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device), 440 ) 441 442 # stride != 1 443 444 # strided is [[1 3 5 7], 445 # [9 11 13 15]] 446 447 reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8) 448 strided = torch.tensor((), dtype=dtype, device=device) 449 strided.set_(reference.storage(), 1, size=torch.Size([2, 4]), stride=[8, 2]) 450 451 self.assertEqual( 452 strided[ri([0, 1]), ri([0])], 453 torch.tensor([1, 9], dtype=dtype, device=device), 454 ) 455 self.assertEqual( 456 strided[ri([0, 1]), ri([1])], 457 torch.tensor([3, 11], dtype=dtype, device=device), 458 ) 459 self.assertEqual( 460 strided[ri([0]), ri([0])], torch.tensor([1], dtype=dtype, device=device) 461 ) 462 self.assertEqual( 463 strided[ri([1]), ri([3])], torch.tensor([15], dtype=dtype, device=device) 464 ) 465 self.assertEqual( 466 strided[[ri([0, 0]), ri([0, 3])]], 467 torch.tensor([1, 7], dtype=dtype, device=device), 468 ) 469 self.assertEqual( 470 strided[[ri([1]), ri([0, 1, 1, 0, 3])]], 471 torch.tensor([9, 11, 11, 9, 15], dtype=dtype, device=device), 472 ) 473 self.assertEqual( 474 strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], 475 torch.tensor([1, 3, 9, 9], dtype=dtype, device=device), 476 ) 477 478 rows = ri([[0, 0], [1, 1]]) 479 columns = ([0],) 480 self.assertEqual( 481 strided[rows, columns], 482 
torch.tensor([[1, 1], [9, 9]], dtype=dtype, device=device), 483 ) 484 485 rows = ri([[0, 1], [1, 0]]) 486 columns = ri([1, 2]) 487 self.assertEqual( 488 strided[rows, columns], 489 torch.tensor([[3, 13], [11, 5]], dtype=dtype, device=device), 490 ) 491 rows = ri([[0, 0], [1, 1]]) 492 columns = ri([[0, 1], [1, 2]]) 493 self.assertEqual( 494 strided[rows, columns], 495 torch.tensor([[1, 3], [11, 13]], dtype=dtype, device=device), 496 ) 497 498 # setting values 499 500 # strided is [[10, 11], 501 # [17, 18]] 502 503 reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8) 504 strided = torch.tensor((), dtype=dtype, device=device) 505 strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]) 506 self.assertEqual( 507 strided[ri([0]), ri([1])], torch.tensor([11], dtype=dtype, device=device) 508 ) 509 strided[ri([0]), ri([1])] = -1 510 self.assertEqual( 511 strided[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device) 512 ) 513 514 reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8) 515 strided = torch.tensor((), dtype=dtype, device=device) 516 strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]) 517 self.assertEqual( 518 strided[ri([0, 1]), ri([1, 0])], 519 torch.tensor([11, 17], dtype=dtype, device=device), 520 ) 521 strided[ri([0, 1]), ri([1, 0])] = torch.tensor( 522 [-1, 2], dtype=dtype, device=device 523 ) 524 self.assertEqual( 525 strided[ri([0, 1]), ri([1, 0])], 526 torch.tensor([-1, 2], dtype=dtype, device=device), 527 ) 528 529 reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8) 530 strided = torch.tensor((), dtype=dtype, device=device) 531 strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]) 532 533 rows = ri([[0], [1]]) 534 columns = ri([[0, 1], [0, 1]]) 535 self.assertEqual( 536 strided[rows, columns], 537 torch.tensor([[10, 11], [17, 18]], dtype=dtype, device=device), 538 ) 539 strided[rows, columns] = torch.tensor( 540 
[[4, 6], [2, 3]], dtype=dtype, device=device 541 ) 542 self.assertEqual( 543 strided[rows, columns], 544 torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device), 545 ) 546 547 # Tests using less than the number of dims, and ellipsis 548 549 # reference is 1 2 550 # 3 4 551 # 5 6 552 reference = consec((3, 2)) 553 self.assertEqual( 554 reference[ri([0, 2]),], 555 torch.tensor([[1, 2], [5, 6]], dtype=dtype, device=device), 556 ) 557 self.assertEqual( 558 reference[ri([1]), ...], torch.tensor([[3, 4]], dtype=dtype, device=device) 559 ) 560 self.assertEqual( 561 reference[..., ri([1])], 562 torch.tensor([[2], [4], [6]], dtype=dtype, device=device), 563 ) 564 565 # verify too many indices fails 566 with self.assertRaises(IndexError): 567 reference[ri([1]), ri([0, 2]), ri([3])] 568 569 # test invalid index fails 570 reference = torch.empty(10, dtype=dtype, device=device) 571 # can't test cuda because it is a device assert 572 if not reference.is_cuda: 573 for err_idx in (10, -11): 574 with self.assertRaisesRegex(IndexError, r"out of"): 575 reference[err_idx] 576 with self.assertRaisesRegex(IndexError, r"out of"): 577 reference[torch.LongTensor([err_idx]).to(device)] 578 with self.assertRaisesRegex(IndexError, r"out of"): 579 reference[[err_idx]] 580 581 def tensor_indices_to_np(tensor, indices): 582 # convert the Torch Tensor to a numpy array 583 tensor = tensor.to(device="cpu") 584 npt = tensor.numpy() 585 586 # convert indices 587 idxs = tuple( 588 i.tolist() if isinstance(i, torch.LongTensor) else i for i in indices 589 ) 590 591 return npt, idxs 592 593 def get_numpy(tensor, indices): 594 npt, idxs = tensor_indices_to_np(tensor, indices) 595 596 # index and return as a Torch Tensor 597 return torch.tensor(npt[idxs], dtype=dtype, device=device) 598 599 def set_numpy(tensor, indices, value): 600 if not isinstance(value, int): 601 if self.device_type != "cpu": 602 value = value.cpu() 603 value = value.numpy() 604 605 npt, idxs = tensor_indices_to_np(tensor, indices) 
606 npt[idxs] = value 607 return npt 608 609 def assert_get_eq(tensor, indexer): 610 self.assertEqual(tensor[indexer], get_numpy(tensor, indexer)) 611 612 def assert_set_eq(tensor, indexer, val): 613 pyt = tensor.clone() 614 numt = tensor.clone() 615 pyt[indexer] = val 616 numt = torch.tensor( 617 set_numpy(numt, indexer, val), dtype=dtype, device=device 618 ) 619 self.assertEqual(pyt, numt) 620 621 def assert_backward_eq(tensor, indexer): 622 cpu = tensor.float().clone().detach().requires_grad_(True) 623 outcpu = cpu[indexer] 624 gOcpu = torch.rand_like(outcpu) 625 outcpu.backward(gOcpu) 626 dev = cpu.to(device).detach().requires_grad_(True) 627 outdev = dev[indexer] 628 outdev.backward(gOcpu.to(device)) 629 self.assertEqual(cpu.grad, dev.grad) 630 631 def get_set_tensor(indexed, indexer): 632 set_size = indexed[indexer].size() 633 set_count = indexed[indexer].numel() 634 set_tensor = torch.randperm(set_count).view(set_size).double().to(device) 635 return set_tensor 636 637 # Tensor is 0 1 2 3 4 638 # 5 6 7 8 9 639 # 10 11 12 13 14 640 # 15 16 17 18 19 641 reference = torch.arange(0.0, 20, dtype=dtype, device=device).view(4, 5) 642 643 indices_to_test = [ 644 # grab the second, fourth columns 645 [slice(None), [1, 3]], 646 # first, third rows, 647 [[0, 2], slice(None)], 648 # weird shape 649 [slice(None), [[0, 1], [2, 3]]], 650 # negatives 651 [[-1], [0]], 652 [[0, 2], [-1]], 653 [slice(None), [-1]], 654 ] 655 656 # only test dupes on gets 657 get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]] 658 659 for indexer in get_indices_to_test: 660 assert_get_eq(reference, indexer) 661 if self.device_type != "cpu": 662 assert_backward_eq(reference, indexer) 663 664 for indexer in indices_to_test: 665 assert_set_eq(reference, indexer, 44) 666 assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) 667 668 reference = torch.arange(0.0, 160, dtype=dtype, device=device).view(4, 8, 5) 669 670 indices_to_test = [ 671 [slice(None), 
slice(None), [0, 3, 4]], 672 [slice(None), [2, 4, 5, 7], slice(None)], 673 [[2, 3], slice(None), slice(None)], 674 [slice(None), [0, 2, 3], [1, 3, 4]], 675 [slice(None), [0], [1, 2, 4]], 676 [slice(None), [0, 1, 3], [4]], 677 [slice(None), [[0, 1], [1, 0]], [[2, 3]]], 678 [slice(None), [[0, 1], [2, 3]], [[0]]], 679 [slice(None), [[5, 6]], [[0, 3], [4, 4]]], 680 [[0, 2, 3], [1, 3, 4], slice(None)], 681 [[0], [1, 2, 4], slice(None)], 682 [[0, 1, 3], [4], slice(None)], 683 [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], 684 [[[0, 1], [1, 0]], [[2, 3]], slice(None)], 685 [[[0, 1], [2, 3]], [[0]], slice(None)], 686 [[[2, 1]], [[0, 3], [4, 4]], slice(None)], 687 [[[2]], [[0, 3], [4, 1]], slice(None)], 688 # non-contiguous indexing subspace 689 [[0, 2, 3], slice(None), [1, 3, 4]], 690 # [...] 691 # less dim, ellipsis 692 [[0, 2]], 693 [[0, 2], slice(None)], 694 [[0, 2], Ellipsis], 695 [[0, 2], slice(None), Ellipsis], 696 [[0, 2], Ellipsis, slice(None)], 697 [[0, 2], [1, 3]], 698 [[0, 2], [1, 3], Ellipsis], 699 [Ellipsis, [1, 3], [2, 3]], 700 [Ellipsis, [2, 3, 4]], 701 [Ellipsis, slice(None), [2, 3, 4]], 702 [slice(None), Ellipsis, [2, 3, 4]], 703 # ellipsis counts for nothing 704 [Ellipsis, slice(None), slice(None), [0, 3, 4]], 705 [slice(None), Ellipsis, slice(None), [0, 3, 4]], 706 [slice(None), slice(None), Ellipsis, [0, 3, 4]], 707 [slice(None), slice(None), [0, 3, 4], Ellipsis], 708 [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], 709 [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)], 710 [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis], 711 ] 712 713 for indexer in indices_to_test: 714 assert_get_eq(reference, indexer) 715 assert_set_eq(reference, indexer, 212) 716 assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) 717 if torch.cuda.is_available(): 718 assert_backward_eq(reference, indexer) 719 720 reference = torch.arange(0.0, 1296, dtype=dtype, device=device).view(3, 9, 8, 6) 721 722 indices_to_test = [ 723 
[slice(None), slice(None), slice(None), [0, 3, 4]], 724 [slice(None), slice(None), [2, 4, 5, 7], slice(None)], 725 [slice(None), [2, 3], slice(None), slice(None)], 726 [[1, 2], slice(None), slice(None), slice(None)], 727 [slice(None), slice(None), [0, 2, 3], [1, 3, 4]], 728 [slice(None), slice(None), [0], [1, 2, 4]], 729 [slice(None), slice(None), [0, 1, 3], [4]], 730 [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]], 731 [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]], 732 [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]], 733 [slice(None), [0, 2, 3], [1, 3, 4], slice(None)], 734 [slice(None), [0], [1, 2, 4], slice(None)], 735 [slice(None), [0, 1, 3], [4], slice(None)], 736 [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)], 737 [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)], 738 [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)], 739 [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)], 740 [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)], 741 [[0, 1, 2], [1, 3, 4], slice(None), slice(None)], 742 [[0], [1, 2, 4], slice(None), slice(None)], 743 [[0, 1, 2], [4], slice(None), slice(None)], 744 [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)], 745 [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)], 746 [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)], 747 [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)], 748 [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]], 749 [slice(None), [2, 3, 4], [1, 3, 4], [4]], 750 [slice(None), [0, 1, 3], [4], [1, 3, 4]], 751 [slice(None), [6], [0, 2, 3], [1, 3, 4]], 752 [slice(None), [2, 3, 5], [3], [4]], 753 [slice(None), [0], [4], [1, 3, 4]], 754 [slice(None), [6], [0, 2, 3], [1]], 755 [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]], 756 [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)], 757 [[2, 0, 1], [1, 2, 3], [4], slice(None)], 758 [[0, 1, 2], [4], [1, 3, 4], slice(None)], 759 [[0], [0, 2, 3], [1, 3, 4], slice(None)], 760 [[0, 2, 1], [3], [4], 
slice(None)], 761 [[0], [4], [1, 3, 4], slice(None)], 762 [[1], [0, 2, 3], [1], slice(None)], 763 [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)], 764 # less dim, ellipsis 765 [Ellipsis, [0, 3, 4]], 766 [Ellipsis, slice(None), [0, 3, 4]], 767 [Ellipsis, slice(None), slice(None), [0, 3, 4]], 768 [slice(None), Ellipsis, [0, 3, 4]], 769 [slice(None), slice(None), Ellipsis, [0, 3, 4]], 770 [slice(None), [0, 2, 3], [1, 3, 4]], 771 [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis], 772 [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)], 773 [[0], [1, 2, 4]], 774 [[0], [1, 2, 4], slice(None)], 775 [[0], [1, 2, 4], Ellipsis], 776 [[0], [1, 2, 4], Ellipsis, slice(None)], 777 [[1]], 778 [[0, 2, 1], [3], [4]], 779 [[0, 2, 1], [3], [4], slice(None)], 780 [[0, 2, 1], [3], [4], Ellipsis], 781 [Ellipsis, [0, 2, 1], [3], [4]], 782 ] 783 784 for indexer in indices_to_test: 785 assert_get_eq(reference, indexer) 786 assert_set_eq(reference, indexer, 1333) 787 assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) 788 indices_to_test += [ 789 [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]], 790 [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]], 791 ] 792 for indexer in indices_to_test: 793 assert_get_eq(reference, indexer) 794 assert_set_eq(reference, indexer, 1333) 795 if self.device_type != "cpu": 796 assert_backward_eq(reference, indexer) 797 798 def test_advancedindex_big(self, device): 799 reference = torch.arange(0, 123344, dtype=torch.int, device=device) 800 801 self.assertEqual( 802 reference[[0, 123, 44488, 68807, 123343],], 803 torch.tensor([0, 123, 44488, 68807, 123343], dtype=torch.int), 804 ) 805 806 def test_set_item_to_scalar_tensor(self, device): 807 m = random.randint(1, 10) 808 n = random.randint(1, 10) 809 z = torch.randn([m, n], device=device) 810 a = 1.0 811 w = torch.tensor(a, requires_grad=True, device=device) 812 z[:, 0] = w 813 z.sum().backward() 814 self.assertEqual(w.grad, m * a) 815 816 def test_single_int(self, 
device): 817 v = torch.randn(5, 7, 3, device=device) 818 self.assertEqual(v[4].shape, (7, 3)) 819 820 def test_multiple_int(self, device): 821 v = torch.randn(5, 7, 3, device=device) 822 self.assertEqual(v[4].shape, (7, 3)) 823 self.assertEqual(v[4, :, 1].shape, (7,)) 824 825 def test_none(self, device): 826 v = torch.randn(5, 7, 3, device=device) 827 self.assertEqual(v[None].shape, (1, 5, 7, 3)) 828 self.assertEqual(v[:, None].shape, (5, 1, 7, 3)) 829 self.assertEqual(v[:, None, None].shape, (5, 1, 1, 7, 3)) 830 self.assertEqual(v[..., None].shape, (5, 7, 3, 1)) 831 832 def test_step(self, device): 833 v = torch.arange(10, device=device) 834 self.assertEqual(v[::1], v) 835 self.assertEqual(v[::2].tolist(), [0, 2, 4, 6, 8]) 836 self.assertEqual(v[::3].tolist(), [0, 3, 6, 9]) 837 self.assertEqual(v[::11].tolist(), [0]) 838 self.assertEqual(v[1:6:2].tolist(), [1, 3, 5]) 839 840 def test_step_assignment(self, device): 841 v = torch.zeros(4, 4, device=device) 842 v[0, 1::2] = torch.tensor([3.0, 4.0], device=device) 843 self.assertEqual(v[0].tolist(), [0, 3, 0, 4]) 844 self.assertEqual(v[1:].sum(), 0) 845 846 def test_bool_indices(self, device): 847 v = torch.randn(5, 7, 3, device=device) 848 boolIndices = torch.tensor( 849 [True, False, True, True, False], dtype=torch.bool, device=device 850 ) 851 self.assertEqual(v[boolIndices].shape, (3, 7, 3)) 852 self.assertEqual(v[boolIndices], torch.stack([v[0], v[2], v[3]])) 853 854 v = torch.tensor([True, False, True], dtype=torch.bool, device=device) 855 boolIndices = torch.tensor( 856 [True, False, False], dtype=torch.bool, device=device 857 ) 858 uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8, device=device) 859 with warnings.catch_warnings(record=True) as w: 860 v1 = v[boolIndices] 861 v2 = v[uint8Indices] 862 self.assertEqual(v1.shape, v2.shape) 863 self.assertEqual(v1, v2) 864 self.assertEqual( 865 v[boolIndices], tensor([True], dtype=torch.bool, device=device) 866 ) 867 self.assertEqual(len(w), 1) 868 869 def 
test_bool_indices_accumulate(self, device): 870 mask = torch.zeros(size=(10,), dtype=torch.bool, device=device) 871 y = torch.ones(size=(10, 10), device=device) 872 y.index_put_((mask,), y[mask], accumulate=True) 873 self.assertEqual(y, torch.ones(size=(10, 10), device=device)) 874 875 def test_multiple_bool_indices(self, device): 876 v = torch.randn(5, 7, 3, device=device) 877 # note: these broadcast together and are transposed to the first dim 878 mask1 = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool, device=device) 879 mask2 = torch.tensor([1, 1, 1], dtype=torch.bool, device=device) 880 self.assertEqual(v[mask1, :, mask2].shape, (3, 7)) 881 882 def test_byte_mask(self, device): 883 v = torch.randn(5, 7, 3, device=device) 884 mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device) 885 with warnings.catch_warnings(record=True) as w: 886 res = v[mask] 887 self.assertEqual(res.shape, (3, 7, 3)) 888 self.assertEqual(res, torch.stack([v[0], v[2], v[3]])) 889 self.assertEqual(len(w), 1) 890 891 v = torch.tensor([1.0], device=device) 892 self.assertEqual(v[v == 0], torch.tensor([], device=device)) 893 894 def test_byte_mask_accumulate(self, device): 895 mask = torch.zeros(size=(10,), dtype=torch.uint8, device=device) 896 y = torch.ones(size=(10, 10), device=device) 897 with warnings.catch_warnings(record=True) as w: 898 warnings.simplefilter("always") 899 y.index_put_((mask,), y[mask], accumulate=True) 900 self.assertEqual(y, torch.ones(size=(10, 10), device=device)) 901 self.assertEqual(len(w), 2) 902 903 @skipIfTorchDynamo( 904 "This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472" 905 ) 906 @serialTest(TEST_CUDA) 907 def test_index_put_accumulate_large_tensor(self, device): 908 # This test is for tensors with number of elements >= INT_MAX (2^31 - 1). 
909 N = (1 << 31) + 5 910 dt = torch.int8 911 a = torch.ones(N, dtype=dt, device=device) 912 indices = torch.tensor( 913 [-2, 0, -2, -1, 0, -1, 1], device=device, dtype=torch.long 914 ) 915 values = torch.tensor([6, 5, 6, 6, 5, 7, 11], dtype=dt, device=device) 916 917 a.index_put_((indices,), values, accumulate=True) 918 919 self.assertEqual(a[0], 11) 920 self.assertEqual(a[1], 12) 921 self.assertEqual(a[2], 1) 922 self.assertEqual(a[-3], 1) 923 self.assertEqual(a[-2], 13) 924 self.assertEqual(a[-1], 14) 925 926 a = torch.ones((2, N), dtype=dt, device=device) 927 indices0 = torch.tensor([0, -1, 0, 1], device=device, dtype=torch.long) 928 indices1 = torch.tensor([-2, -1, 0, 1], device=device, dtype=torch.long) 929 values = torch.tensor([12, 13, 10, 11], dtype=dt, device=device) 930 931 a.index_put_((indices0, indices1), values, accumulate=True) 932 933 self.assertEqual(a[0, 0], 11) 934 self.assertEqual(a[0, 1], 1) 935 self.assertEqual(a[1, 0], 1) 936 self.assertEqual(a[1, 1], 12) 937 self.assertEqual(a[:, 2], torch.ones(2, dtype=torch.int8)) 938 self.assertEqual(a[:, -3], torch.ones(2, dtype=torch.int8)) 939 self.assertEqual(a[0, -2], 13) 940 self.assertEqual(a[1, -2], 1) 941 self.assertEqual(a[-1, -1], 14) 942 self.assertEqual(a[0, -1], 1) 943 944 @onlyNativeDeviceTypes 945 def test_index_put_accumulate_expanded_values(self, device): 946 # checks the issue with cuda: https://github.com/pytorch/pytorch/issues/39227 947 # and verifies consistency with CPU result 948 t = torch.zeros((5, 2)) 949 t_dev = t.to(device) 950 indices = [torch.tensor([0, 1, 2, 3]), torch.tensor([1])] 951 indices_dev = [i.to(device) for i in indices] 952 values0d = torch.tensor(1.0) 953 values1d = torch.tensor([1.0]) 954 955 out_cuda = t_dev.index_put_(indices_dev, values0d.to(device), accumulate=True) 956 out_cpu = t.index_put_(indices, values0d, accumulate=True) 957 self.assertEqual(out_cuda.cpu(), out_cpu) 958 959 out_cuda = t_dev.index_put_(indices_dev, values1d.to(device), 
accumulate=True) 960 out_cpu = t.index_put_(indices, values1d, accumulate=True) 961 self.assertEqual(out_cuda.cpu(), out_cpu) 962 963 t = torch.zeros(4, 3, 2) 964 t_dev = t.to(device) 965 966 indices = [ 967 torch.tensor([0]), 968 torch.arange(3)[:, None], 969 torch.arange(2)[None, :], 970 ] 971 indices_dev = [i.to(device) for i in indices] 972 values1d = torch.tensor([-1.0, -2.0]) 973 values2d = torch.tensor([[-1.0, -2.0]]) 974 975 out_cuda = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True) 976 out_cpu = t.index_put_(indices, values1d, accumulate=True) 977 self.assertEqual(out_cuda.cpu(), out_cpu) 978 979 out_cuda = t_dev.index_put_(indices_dev, values2d.to(device), accumulate=True) 980 out_cpu = t.index_put_(indices, values2d, accumulate=True) 981 self.assertEqual(out_cuda.cpu(), out_cpu) 982 983 @onlyCUDA 984 def test_index_put_accumulate_non_contiguous(self, device): 985 t = torch.zeros((5, 2, 2)) 986 t_dev = t.to(device) 987 t1 = t_dev[:, 0, :] 988 t2 = t[:, 0, :] 989 self.assertTrue(not t1.is_contiguous()) 990 self.assertTrue(not t2.is_contiguous()) 991 992 indices = [torch.tensor([0, 1])] 993 indices_dev = [i.to(device) for i in indices] 994 value = torch.randn(2, 2) 995 out_cuda = t1.index_put_(indices_dev, value.to(device), accumulate=True) 996 out_cpu = t2.index_put_(indices, value, accumulate=True) 997 self.assertTrue(not t1.is_contiguous()) 998 self.assertTrue(not t2.is_contiguous()) 999 1000 self.assertEqual(out_cuda.cpu(), out_cpu) 1001 1002 @onlyCUDA 1003 @skipIfTorchDynamo("Not a suitable test for TorchDynamo") 1004 def test_index_put_accumulate_with_optional_tensors(self, device): 1005 # TODO: replace with a better solution. 1006 # Currently, here using torchscript to put None into indices. 1007 # on C++ it gives indices as a list of 2 optional tensors: first is null and 1008 # the second is a valid tensor. 
1009 @torch.jit.script 1010 def func(x, i, v): 1011 idx = [None, i] 1012 x.index_put_(idx, v, accumulate=True) 1013 return x 1014 1015 n = 4 1016 t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2) 1017 t_dev = t.to(device) 1018 indices = torch.tensor([1, 0]) 1019 indices_dev = indices.to(device) 1020 value0d = torch.tensor(10.0) 1021 value1d = torch.tensor([1.0, 2.0]) 1022 1023 out_cuda = func(t_dev, indices_dev, value0d.cuda()) 1024 out_cpu = func(t, indices, value0d) 1025 self.assertEqual(out_cuda.cpu(), out_cpu) 1026 1027 out_cuda = func(t_dev, indices_dev, value1d.cuda()) 1028 out_cpu = func(t, indices, value1d) 1029 self.assertEqual(out_cuda.cpu(), out_cpu) 1030 1031 @onlyNativeDeviceTypes 1032 def test_index_put_accumulate_duplicate_indices(self, device): 1033 for i in range(1, 512): 1034 # generate indices by random walk, this will create indices with 1035 # lots of duplicates interleaved with each other 1036 delta = torch.empty(i, dtype=torch.double, device=device).uniform_(-1, 1) 1037 indices = delta.cumsum(0).long() 1038 1039 input = torch.randn(indices.abs().max() + 1, device=device) 1040 values = torch.randn(indices.size(0), device=device) 1041 output = input.index_put((indices,), values, accumulate=True) 1042 1043 input_list = input.tolist() 1044 indices_list = indices.tolist() 1045 values_list = values.tolist() 1046 for i, v in zip(indices_list, values_list): 1047 input_list[i] += v 1048 1049 self.assertEqual(output, input_list) 1050 1051 @onlyNativeDeviceTypes 1052 def test_index_ind_dtype(self, device): 1053 x = torch.randn(4, 4, device=device) 1054 ind_long = torch.randint(4, (4,), dtype=torch.long, device=device) 1055 ind_int = ind_long.int() 1056 src = torch.randn(4, device=device) 1057 ref = x[ind_long, ind_long] 1058 res = x[ind_int, ind_int] 1059 self.assertEqual(ref, res) 1060 ref = x[ind_long, :] 1061 res = x[ind_int, :] 1062 self.assertEqual(ref, res) 1063 ref = x[:, ind_long] 1064 res = x[:, ind_int] 1065 self.assertEqual(ref, res) 
1066 # no repeating indices for index_put 1067 ind_long = torch.arange(4, dtype=torch.long, device=device) 1068 ind_int = ind_long.int() 1069 for accum in (True, False): 1070 inp_ref = x.clone() 1071 inp_res = x.clone() 1072 torch.index_put_(inp_ref, (ind_long, ind_long), src, accum) 1073 torch.index_put_(inp_res, (ind_int, ind_int), src, accum) 1074 self.assertEqual(inp_ref, inp_res) 1075 1076 @skipXLA 1077 def test_index_put_accumulate_empty(self, device): 1078 # Regression test for https://github.com/pytorch/pytorch/issues/94667 1079 input = torch.rand([], dtype=torch.float32, device=device) 1080 with self.assertRaises(RuntimeError): 1081 input.index_put([], torch.tensor([1.0], device=device), True) 1082 1083 def test_multiple_byte_mask(self, device): 1084 v = torch.randn(5, 7, 3, device=device) 1085 # note: these broadcast together and are transposed to the first dim 1086 mask1 = torch.ByteTensor([1, 0, 1, 1, 0]).to(device) 1087 mask2 = torch.ByteTensor([1, 1, 1]).to(device) 1088 with warnings.catch_warnings(record=True) as w: 1089 warnings.simplefilter("always") 1090 self.assertEqual(v[mask1, :, mask2].shape, (3, 7)) 1091 self.assertEqual(len(w), 2) 1092 1093 def test_byte_mask2d(self, device): 1094 v = torch.randn(5, 7, 3, device=device) 1095 c = torch.randn(5, 7, device=device) 1096 num_ones = (c > 0).sum() 1097 r = v[c > 0] 1098 self.assertEqual(r.shape, (num_ones, 3)) 1099 1100 @skipIfTorchDynamo("Not a suitable test for TorchDynamo") 1101 def test_jit_indexing(self, device): 1102 def fn1(x): 1103 x[x < 50] = 1.0 1104 return x 1105 1106 def fn2(x): 1107 x[0:50] = 1.0 1108 return x 1109 1110 scripted_fn1 = torch.jit.script(fn1) 1111 scripted_fn2 = torch.jit.script(fn2) 1112 data = torch.arange(100, device=device, dtype=torch.float) 1113 out = scripted_fn1(data.detach().clone()) 1114 ref = torch.tensor( 1115 np.concatenate((np.ones(50), np.arange(50, 100))), 1116 device=device, 1117 dtype=torch.float, 1118 ) 1119 self.assertEqual(out, ref) 1120 out = 
scripted_fn2(data.detach().clone()) 1121 self.assertEqual(out, ref) 1122 1123 def test_int_indices(self, device): 1124 v = torch.randn(5, 7, 3, device=device) 1125 self.assertEqual(v[[0, 4, 2]].shape, (3, 7, 3)) 1126 self.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3)) 1127 self.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3)) 1128 1129 @dtypes( 1130 torch.cfloat, torch.cdouble, torch.float, torch.bfloat16, torch.long, torch.bool 1131 ) 1132 @dtypesIfCPU( 1133 torch.cfloat, torch.cdouble, torch.float, torch.long, torch.bool, torch.bfloat16 1134 ) 1135 @dtypesIfCUDA( 1136 torch.cfloat, 1137 torch.cdouble, 1138 torch.half, 1139 torch.long, 1140 torch.bool, 1141 torch.bfloat16, 1142 torch.float8_e5m2, 1143 torch.float8_e4m3fn, 1144 ) 1145 def test_index_put_src_datatype(self, device, dtype): 1146 src = torch.ones(3, 2, 4, device=device, dtype=dtype) 1147 vals = torch.ones(3, 2, 4, device=device, dtype=dtype) 1148 indices = (torch.tensor([0, 2, 1]),) 1149 res = src.index_put_(indices, vals, accumulate=True) 1150 self.assertEqual(res.shape, src.shape) 1151 1152 @dtypes(torch.float, torch.bfloat16, torch.long, torch.bool) 1153 @dtypesIfCPU(torch.float, torch.long, torch.bfloat16, torch.bool) 1154 @dtypesIfCUDA(torch.half, torch.long, torch.bfloat16, torch.bool) 1155 def test_index_src_datatype(self, device, dtype): 1156 src = torch.ones(3, 2, 4, device=device, dtype=dtype) 1157 # test index 1158 res = src[[0, 2, 1], :, :] 1159 self.assertEqual(res.shape, src.shape) 1160 # test index_put, no accum 1161 src[[0, 2, 1], :, :] = res 1162 self.assertEqual(res.shape, src.shape) 1163 1164 def test_int_indices2d(self, device): 1165 # From the NumPy indexing example 1166 x = torch.arange(0, 12, device=device).view(4, 3) 1167 rows = torch.tensor([[0, 0], [3, 3]], device=device) 1168 columns = torch.tensor([[0, 2], [0, 2]], device=device) 1169 self.assertEqual(x[rows, columns].tolist(), [[0, 2], [9, 11]]) 1170 1171 def test_int_indices_broadcast(self, device): 1172 # From the 
NumPy indexing example 1173 x = torch.arange(0, 12, device=device).view(4, 3) 1174 rows = torch.tensor([0, 3], device=device) 1175 columns = torch.tensor([0, 2], device=device) 1176 result = x[rows[:, None], columns] 1177 self.assertEqual(result.tolist(), [[0, 2], [9, 11]]) 1178 1179 def test_empty_index(self, device): 1180 x = torch.arange(0, 12, device=device).view(4, 3) 1181 idx = torch.tensor([], dtype=torch.long, device=device) 1182 self.assertEqual(x[idx].numel(), 0) 1183 1184 # empty assignment should have no effect but not throw an exception 1185 y = x.clone() 1186 y[idx] = -1 1187 self.assertEqual(x, y) 1188 1189 mask = torch.zeros(4, 3, device=device).bool() 1190 y[mask] = -1 1191 self.assertEqual(x, y) 1192 1193 def test_empty_ndim_index(self, device): 1194 x = torch.randn(5, device=device) 1195 self.assertEqual( 1196 torch.empty(0, 2, device=device), 1197 x[torch.empty(0, 2, dtype=torch.int64, device=device)], 1198 ) 1199 1200 x = torch.randn(2, 3, 4, 5, device=device) 1201 self.assertEqual( 1202 torch.empty(2, 0, 6, 4, 5, device=device), 1203 x[:, torch.empty(0, 6, dtype=torch.int64, device=device)], 1204 ) 1205 1206 x = torch.empty(10, 0, device=device) 1207 self.assertEqual(x[[1, 2]].shape, (2, 0)) 1208 self.assertEqual(x[[], []].shape, (0,)) 1209 with self.assertRaisesRegex(IndexError, "for dimension with size 0"): 1210 x[:, [0, 1]] 1211 1212 def test_empty_ndim_index_bool(self, device): 1213 x = torch.randn(5, device=device) 1214 self.assertRaises( 1215 IndexError, lambda: x[torch.empty(0, 2, dtype=torch.uint8, device=device)] 1216 ) 1217 1218 def test_empty_slice(self, device): 1219 x = torch.randn(2, 3, 4, 5, device=device) 1220 y = x[:, :, :, 1] 1221 z = y[:, 1:1, :] 1222 self.assertEqual((2, 0, 4), z.shape) 1223 # this isn't technically necessary, but matches NumPy stride calculations. 
1224 self.assertEqual((60, 20, 5), z.stride()) 1225 self.assertTrue(z.is_contiguous()) 1226 1227 def test_index_getitem_copy_bools_slices(self, device): 1228 true = torch.tensor(1, dtype=torch.uint8, device=device) 1229 false = torch.tensor(0, dtype=torch.uint8, device=device) 1230 1231 tensors = [torch.randn(2, 3, device=device), torch.tensor(3.0, device=device)] 1232 1233 for a in tensors: 1234 self.assertNotEqual(a.data_ptr(), a[True].data_ptr()) 1235 self.assertEqual(torch.empty(0, *a.shape), a[False]) 1236 self.assertNotEqual(a.data_ptr(), a[true].data_ptr()) 1237 self.assertEqual(torch.empty(0, *a.shape), a[false]) 1238 self.assertEqual(a.data_ptr(), a[None].data_ptr()) 1239 self.assertEqual(a.data_ptr(), a[...].data_ptr()) 1240 1241 def test_index_setitem_bools_slices(self, device): 1242 true = torch.tensor(1, dtype=torch.uint8, device=device) 1243 false = torch.tensor(0, dtype=torch.uint8, device=device) 1244 1245 tensors = [torch.randn(2, 3, device=device), torch.tensor(3, device=device)] 1246 1247 for a in tensors: 1248 # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s 1249 # (some of these ops already prefix a 1 to the size) 1250 neg_ones = torch.ones_like(a) * -1 1251 neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0) 1252 a[True] = neg_ones_expanded 1253 self.assertEqual(a, neg_ones) 1254 a[False] = 5 1255 self.assertEqual(a, neg_ones) 1256 a[true] = neg_ones_expanded * 2 1257 self.assertEqual(a, neg_ones * 2) 1258 a[false] = 5 1259 self.assertEqual(a, neg_ones * 2) 1260 a[None] = neg_ones_expanded * 3 1261 self.assertEqual(a, neg_ones * 3) 1262 a[...] 
= neg_ones_expanded * 4 1263 self.assertEqual(a, neg_ones * 4) 1264 if a.dim() == 0: 1265 with self.assertRaises(IndexError): 1266 a[:] = neg_ones_expanded * 5 1267 1268 def test_index_scalar_with_bool_mask(self, device): 1269 a = torch.tensor(1, device=device) 1270 uintMask = torch.tensor(True, dtype=torch.uint8, device=device) 1271 boolMask = torch.tensor(True, dtype=torch.bool, device=device) 1272 self.assertEqual(a[uintMask], a[boolMask]) 1273 self.assertEqual(a[uintMask].dtype, a[boolMask].dtype) 1274 1275 a = torch.tensor(True, dtype=torch.bool, device=device) 1276 self.assertEqual(a[uintMask], a[boolMask]) 1277 self.assertEqual(a[uintMask].dtype, a[boolMask].dtype) 1278 1279 def test_setitem_expansion_error(self, device): 1280 true = torch.tensor(True, device=device) 1281 a = torch.randn(2, 3, device=device) 1282 # check prefix with non-1s doesn't work 1283 a_expanded = a.expand(torch.Size([5, 1]) + a.size()) 1284 # NumPy: ValueError 1285 with self.assertRaises(RuntimeError): 1286 a[True] = a_expanded 1287 with self.assertRaises(RuntimeError): 1288 a[true] = a_expanded 1289 1290 def test_getitem_scalars(self, device): 1291 zero = torch.tensor(0, dtype=torch.int64, device=device) 1292 one = torch.tensor(1, dtype=torch.int64, device=device) 1293 1294 # non-scalar indexed with scalars 1295 a = torch.randn(2, 3, device=device) 1296 self.assertEqual(a[0], a[zero]) 1297 self.assertEqual(a[0][1], a[zero][one]) 1298 self.assertEqual(a[0, 1], a[zero, one]) 1299 self.assertEqual(a[0, one], a[zero, 1]) 1300 1301 # indexing by a scalar should slice (not copy) 1302 self.assertEqual(a[0, 1].data_ptr(), a[zero, one].data_ptr()) 1303 self.assertEqual(a[1].data_ptr(), a[one.int()].data_ptr()) 1304 self.assertEqual(a[1].data_ptr(), a[one.short()].data_ptr()) 1305 1306 # scalar indexed with scalar 1307 r = torch.randn((), device=device) 1308 with self.assertRaises(IndexError): 1309 r[:] 1310 with self.assertRaises(IndexError): 1311 r[zero] 1312 self.assertEqual(r, r[...]) 1313 
1314 def test_setitem_scalars(self, device): 1315 zero = torch.tensor(0, dtype=torch.int64) 1316 1317 # non-scalar indexed with scalars 1318 a = torch.randn(2, 3, device=device) 1319 a_set_with_number = a.clone() 1320 a_set_with_scalar = a.clone() 1321 b = torch.randn(3, device=device) 1322 1323 a_set_with_number[0] = b 1324 a_set_with_scalar[zero] = b 1325 self.assertEqual(a_set_with_number, a_set_with_scalar) 1326 a[1, zero] = 7.7 1327 self.assertEqual(7.7, a[1, 0]) 1328 1329 # scalar indexed with scalars 1330 r = torch.randn((), device=device) 1331 with self.assertRaises(IndexError): 1332 r[:] = 8.8 1333 with self.assertRaises(IndexError): 1334 r[zero] = 8.8 1335 r[...] = 9.9 1336 self.assertEqual(9.9, r) 1337 1338 def test_basic_advanced_combined(self, device): 1339 # From the NumPy indexing example 1340 x = torch.arange(0, 12, device=device).view(4, 3) 1341 self.assertEqual(x[1:2, 1:3], x[1:2, [1, 2]]) 1342 self.assertEqual(x[1:2, 1:3].tolist(), [[4, 5]]) 1343 1344 # Check that it is a copy 1345 unmodified = x.clone() 1346 x[1:2, [1, 2]].zero_() 1347 self.assertEqual(x, unmodified) 1348 1349 # But assignment should modify the original 1350 unmodified = x.clone() 1351 x[1:2, [1, 2]] = 0 1352 self.assertNotEqual(x, unmodified) 1353 1354 def test_int_assignment(self, device): 1355 x = torch.arange(0, 4, device=device).view(2, 2) 1356 x[1] = 5 1357 self.assertEqual(x.tolist(), [[0, 1], [5, 5]]) 1358 1359 x = torch.arange(0, 4, device=device).view(2, 2) 1360 x[1] = torch.arange(5, 7, device=device) 1361 self.assertEqual(x.tolist(), [[0, 1], [5, 6]]) 1362 1363 def test_byte_tensor_assignment(self, device): 1364 x = torch.arange(0.0, 16, device=device).view(4, 4) 1365 b = torch.ByteTensor([True, False, True, False]).to(device) 1366 value = torch.tensor([3.0, 4.0, 5.0, 6.0], device=device) 1367 1368 with warnings.catch_warnings(record=True) as w: 1369 x[b] = value 1370 self.assertEqual(len(w), 1) 1371 1372 self.assertEqual(x[0], value) 1373 self.assertEqual(x[1], 
torch.arange(4.0, 8, device=device)) 1374 self.assertEqual(x[2], value) 1375 self.assertEqual(x[3], torch.arange(12.0, 16, device=device)) 1376 1377 def test_variable_slicing(self, device): 1378 x = torch.arange(0, 16, device=device).view(4, 4) 1379 indices = torch.IntTensor([0, 1]).to(device) 1380 i, j = indices 1381 self.assertEqual(x[i:j], x[0:1]) 1382 1383 def test_ellipsis_tensor(self, device): 1384 x = torch.arange(0, 9, device=device).view(3, 3) 1385 idx = torch.tensor([0, 2], device=device) 1386 self.assertEqual(x[..., idx].tolist(), [[0, 2], [3, 5], [6, 8]]) 1387 self.assertEqual(x[idx, ...].tolist(), [[0, 1, 2], [6, 7, 8]]) 1388 1389 def test_unravel_index_errors(self, device): 1390 with self.assertRaisesRegex(TypeError, r"expected 'indices' to be integer"): 1391 torch.unravel_index(torch.tensor(0.5, device=device), (2, 2)) 1392 1393 with self.assertRaisesRegex(TypeError, r"expected 'indices' to be integer"): 1394 torch.unravel_index(torch.tensor([], device=device), (10, 3, 5)) 1395 1396 with self.assertRaisesRegex( 1397 TypeError, r"expected 'shape' to be int or sequence" 1398 ): 1399 torch.unravel_index( 1400 torch.tensor([1], device=device, dtype=torch.int64), 1401 torch.tensor([1, 2, 3]), 1402 ) 1403 1404 with self.assertRaisesRegex( 1405 TypeError, r"expected 'shape' sequence to only contain ints" 1406 ): 1407 torch.unravel_index( 1408 torch.tensor([1], device=device, dtype=torch.int64), (1, 2, 2.0) 1409 ) 1410 1411 with self.assertRaisesRegex( 1412 ValueError, r"'shape' cannot have negative values, but got \(2, -3\)" 1413 ): 1414 torch.unravel_index(torch.tensor(0, device=device), (2, -3)) 1415 1416 def test_invalid_index(self, device): 1417 x = torch.arange(0, 16, device=device).view(4, 4) 1418 self.assertRaisesRegex(TypeError, "slice indices", lambda: x["0":"1"]) 1419 1420 def test_out_of_bound_index(self, device): 1421 x = torch.arange(0, 100, device=device).view(2, 5, 10) 1422 self.assertRaisesRegex( 1423 IndexError, 1424 "index 5 is out of 
bounds for dimension 1 with size 5", 1425 lambda: x[0, 5], 1426 ) 1427 self.assertRaisesRegex( 1428 IndexError, 1429 "index 4 is out of bounds for dimension 0 with size 2", 1430 lambda: x[4, 5], 1431 ) 1432 self.assertRaisesRegex( 1433 IndexError, 1434 "index 15 is out of bounds for dimension 2 with size 10", 1435 lambda: x[0, 1, 15], 1436 ) 1437 self.assertRaisesRegex( 1438 IndexError, 1439 "index 12 is out of bounds for dimension 2 with size 10", 1440 lambda: x[:, :, 12], 1441 ) 1442 1443 def test_zero_dim_index(self, device): 1444 x = torch.tensor(10, device=device) 1445 self.assertEqual(x, x.item()) 1446 1447 def runner(): 1448 print(x[0]) 1449 return x[0] 1450 1451 self.assertRaisesRegex(IndexError, "invalid index", runner) 1452 1453 @onlyCUDA 1454 def test_invalid_device(self, device): 1455 idx = torch.tensor([0, 1]) 1456 b = torch.zeros(5, device=device) 1457 c = torch.tensor([1.0, 2.0], device="cpu") 1458 1459 for accumulate in [True, False]: 1460 self.assertRaises( 1461 RuntimeError, 1462 lambda: torch.index_put_(b, (idx,), c, accumulate=accumulate), 1463 ) 1464 1465 @onlyCUDA 1466 def test_cpu_indices(self, device): 1467 idx = torch.tensor([0, 1]) 1468 b = torch.zeros(2, device=device) 1469 x = torch.ones(10, device=device) 1470 x[idx] = b # index_put_ 1471 ref = torch.ones(10, device=device) 1472 ref[:2] = 0 1473 self.assertEqual(x, ref, atol=0, rtol=0) 1474 out = x[idx] # index 1475 self.assertEqual(out, torch.zeros(2, device=device), atol=0, rtol=0) 1476 1477 @dtypes(torch.long, torch.float32) 1478 def test_take_along_dim(self, device, dtype): 1479 def _test_against_numpy(t, indices, dim): 1480 actual = torch.take_along_dim(t, indices, dim=dim) 1481 t_np = t.cpu().numpy() 1482 indices_np = indices.cpu().numpy() 1483 expected = np.take_along_axis(t_np, indices_np, axis=dim) 1484 self.assertEqual(actual, expected, atol=0, rtol=0) 1485 1486 for shape in [(3, 2), (2, 3, 5), (2, 4, 0), (2, 3, 1, 4)]: 1487 for noncontiguous in [True, False]: 1488 t = 
make_tensor( 1489 shape, device=device, dtype=dtype, noncontiguous=noncontiguous 1490 ) 1491 for dim in list(range(t.ndim)) + [None]: 1492 if dim is None: 1493 indices = torch.argsort(t.view(-1)) 1494 else: 1495 indices = torch.argsort(t, dim=dim) 1496 1497 _test_against_numpy(t, indices, dim) 1498 1499 # test broadcasting 1500 t = torch.ones((3, 4, 1), device=device) 1501 indices = torch.ones((1, 2, 5), dtype=torch.long, device=device) 1502 1503 _test_against_numpy(t, indices, 1) 1504 1505 # test empty indices 1506 t = torch.ones((3, 4, 5), device=device) 1507 indices = torch.ones((3, 0, 5), dtype=torch.long, device=device) 1508 1509 _test_against_numpy(t, indices, 1) 1510 1511 @dtypes(torch.long, torch.float) 1512 def test_take_along_dim_invalid(self, device, dtype): 1513 shape = (2, 3, 1, 4) 1514 dim = 0 1515 t = make_tensor(shape, device=device, dtype=dtype) 1516 indices = torch.argsort(t, dim=dim) 1517 1518 # dim of `t` and `indices` does not match 1519 with self.assertRaisesRegex( 1520 RuntimeError, "input and indices should have the same number of dimensions" 1521 ): 1522 torch.take_along_dim(t, indices[0], dim=0) 1523 1524 # invalid `indices` dtype 1525 with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"): 1526 torch.take_along_dim(t, indices.to(torch.bool), dim=0) 1527 1528 with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"): 1529 torch.take_along_dim(t, indices.to(torch.float), dim=0) 1530 1531 with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"): 1532 torch.take_along_dim(t, indices.to(torch.int32), dim=0) 1533 1534 # invalid axis 1535 with self.assertRaisesRegex(IndexError, "Dimension out of range"): 1536 torch.take_along_dim(t, indices, dim=-7) 1537 1538 with self.assertRaisesRegex(IndexError, "Dimension out of range"): 1539 torch.take_along_dim(t, indices, dim=7) 1540 1541 @onlyCUDA 1542 @dtypes(torch.float) 1543 def test_gather_take_along_dim_cross_device(self, device, 
dtype): 1544 shape = (2, 3, 1, 4) 1545 dim = 0 1546 t = make_tensor(shape, device=device, dtype=dtype) 1547 indices = torch.argsort(t, dim=dim) 1548 1549 with self.assertRaisesRegex( 1550 RuntimeError, "Expected all tensors to be on the same device" 1551 ): 1552 torch.gather(t, 0, indices.cpu()) 1553 1554 with self.assertRaisesRegex( 1555 RuntimeError, 1556 r"Expected tensor to have .* but got tensor with .* torch.take_along_dim()", 1557 ): 1558 torch.take_along_dim(t, indices.cpu(), dim=0) 1559 1560 with self.assertRaisesRegex( 1561 RuntimeError, "Expected all tensors to be on the same device" 1562 ): 1563 torch.gather(t.cpu(), 0, indices) 1564 1565 with self.assertRaisesRegex( 1566 RuntimeError, 1567 r"Expected tensor to have .* but got tensor with .* torch.take_along_dim()", 1568 ): 1569 torch.take_along_dim(t.cpu(), indices, dim=0) 1570 1571 @onlyCUDA 1572 def test_cuda_broadcast_index_use_deterministic_algorithms(self, device): 1573 with DeterministicGuard(True): 1574 idx1 = torch.tensor([0]) 1575 idx2 = torch.tensor([2, 6]) 1576 idx3 = torch.tensor([1, 5, 7]) 1577 1578 tensor_a = torch.rand(13, 11, 12, 13, 12).cpu() 1579 tensor_b = tensor_a.to(device=device) 1580 tensor_a[idx1] = 1.0 1581 tensor_a[idx1, :, idx2, idx2, :] = 2.0 1582 tensor_a[:, idx1, idx3, :, idx3] = 3.0 1583 tensor_b[idx1] = 1.0 1584 tensor_b[idx1, :, idx2, idx2, :] = 2.0 1585 tensor_b[:, idx1, idx3, :, idx3] = 3.0 1586 self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0) 1587 1588 tensor_a = torch.rand(10, 11).cpu() 1589 tensor_b = tensor_a.to(device=device) 1590 tensor_a[idx3] = 1.0 1591 tensor_a[idx2, :] = 2.0 1592 tensor_a[:, idx2] = 3.0 1593 tensor_a[:, idx1] = 4.0 1594 tensor_b[idx3] = 1.0 1595 tensor_b[idx2, :] = 2.0 1596 tensor_b[:, idx2] = 3.0 1597 tensor_b[:, idx1] = 4.0 1598 self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0) 1599 1600 tensor_a = torch.rand(10, 10).cpu() 1601 tensor_b = tensor_a.to(device=device) 1602 tensor_a[[8]] = 1.0 1603 tensor_b[[8]] = 1.0 1604 
self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0) 1605 1606 tensor_a = torch.rand(10).cpu() 1607 tensor_b = tensor_a.to(device=device) 1608 tensor_a[6] = 1.0 1609 tensor_b[6] = 1.0 1610 self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0) 1611 1612 def test_index_limits(self, device): 1613 # Regression test for https://github.com/pytorch/pytorch/issues/115415 1614 t = torch.tensor([], device=device) 1615 idx_min = torch.iinfo(torch.int64).min 1616 idx_max = torch.iinfo(torch.int64).max 1617 self.assertRaises(IndexError, lambda: t[idx_min]) 1618 self.assertRaises(IndexError, lambda: t[idx_max]) 1619 1620 1621# The tests below are from NumPy test_indexing.py with some modifications to 1622# make them compatible with PyTorch. It's licensed under the BDS license below: 1623# 1624# Copyright (c) 2005-2017, NumPy Developers. 1625# All rights reserved. 1626# 1627# Redistribution and use in source and binary forms, with or without 1628# modification, are permitted provided that the following conditions are 1629# met: 1630# 1631# * Redistributions of source code must retain the above copyright 1632# notice, this list of conditions and the following disclaimer. 1633# 1634# * Redistributions in binary form must reproduce the above 1635# copyright notice, this list of conditions and the following 1636# disclaimer in the documentation and/or other materials provided 1637# with the distribution. 1638# 1639# * Neither the name of the NumPy Developers nor the names of any 1640# contributors may be used to endorse or promote products derived 1641# from this software without specific prior written permission. 1642# 1643# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 1644# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 1645# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 1646# A PARTICULAR PURPOSE ARE DISCLAIMED. 
# IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


class NumpyTests(TestCase):
    """Indexing tests adapted from NumPy's test_indexing.py."""

    def test_index_no_floats(self, device):
        """Float scalars are rejected as indices in every position."""
        a = torch.tensor([[[5.0]]], device=device)

        self.assertRaises(IndexError, lambda: a[0.0])
        self.assertRaises(IndexError, lambda: a[0, 0.0])
        self.assertRaises(IndexError, lambda: a[0.0, 0])
        self.assertRaises(IndexError, lambda: a[0.0, :])
        self.assertRaises(IndexError, lambda: a[:, 0.0])
        self.assertRaises(IndexError, lambda: a[:, 0.0, :])
        self.assertRaises(IndexError, lambda: a[0.0, :, :])
        self.assertRaises(IndexError, lambda: a[0, 0, 0.0])
        self.assertRaises(IndexError, lambda: a[0.0, 0, 0])
        self.assertRaises(IndexError, lambda: a[0, 0.0, 0])
        self.assertRaises(IndexError, lambda: a[-1.4])
        self.assertRaises(IndexError, lambda: a[0, -1.4])
        self.assertRaises(IndexError, lambda: a[-1.4, 0])
        self.assertRaises(IndexError, lambda: a[-1.4, :])
        self.assertRaises(IndexError, lambda: a[:, -1.4])
        self.assertRaises(IndexError, lambda: a[:, -1.4, :])
        self.assertRaises(IndexError, lambda: a[-1.4, :, :])
        self.assertRaises(IndexError, lambda: a[0, 0, -1.4])
        self.assertRaises(IndexError, lambda: a[-1.4, 0, 0])
        self.assertRaises(IndexError, lambda: a[0, -1.4, 0])
        # self.assertRaises(IndexError, lambda: a[0.0:, 0.0])
        # self.assertRaises(IndexError, lambda: a[0.0:, 0.0,:])

    def test_none_index(self, device):
        """A None index adds one dimension."""
        # `None` index adds newaxis
        a = tensor([1, 2, 3], device=device)
        self.assertEqual(a[None].dim(), a.dim() + 1)

    def test_empty_tuple_index(self, device):
        """An empty tuple index returns a view over the same storage."""
        # Empty tuple index creates a view
        a = tensor([1, 2, 3], device=device)
        self.assertEqual(a[()], a)
        self.assertEqual(a[()].data_ptr(), a.data_ptr())

    def test_empty_fancy_index(self, device):
        """Empty list/long-tensor indices select nothing; empty float indices fail."""
        # Empty list index creates an empty array
        a = tensor([1, 2, 3], device=device)
        self.assertEqual(a[[]], torch.tensor([], dtype=torch.long, device=device))

        # An empty long tensor index must behave like the empty list.
        b = tensor([], device=device).long()
        self.assertEqual(a[b], torch.tensor([], dtype=torch.long, device=device))

        b = tensor([], device=device).float()
        self.assertRaises(IndexError, lambda: a[b])

    def test_ellipsis_index(self, device):
        """Ellipsis indexing returns views and can skip any number of dims."""
        a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
        self.assertIsNot(a[...], a)
        self.assertEqual(a[...], a)
        # `a[...]` was `a` in numpy <1.9.
        self.assertEqual(a[...].data_ptr(), a.data_ptr())

        # Slicing with ellipsis can skip an
        # arbitrary number of dimensions
        self.assertEqual(a[0, ...], a[0])
        self.assertEqual(a[0, ...], a[0, :])
        self.assertEqual(a[..., 0], a[:, 0])

        # In NumPy, slicing with ellipsis results in a 0-dim array. In PyTorch
        # we don't have separate 0-dim arrays and scalars.
        self.assertEqual(a[0, ..., 1], torch.tensor(2, device=device))

        # Assignment with `(Ellipsis,)` on 0-d arrays
        b = torch.tensor(1, device=device)
        b[(Ellipsis,)] = 2
        self.assertEqual(b, 2)

    def test_single_int_index(self, device):
        """A single int selects one row; out-of-range ints raise."""
        # Single integer index selects one row
        a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)

        self.assertEqual(a[0], [1, 2, 3])
        self.assertEqual(a[-1], [7, 8, 9])

        # Index out of bounds produces IndexError
        self.assertRaises(IndexError, a.__getitem__, 1 << 30)
        # Index overflow produces Exception NB: different exception type
        self.assertRaises(Exception, a.__getitem__, 1 << 64)

    def test_single_bool_index(self, device):
        """A single Python bool behaves like None (True) or an empty None (False)."""
        # Single boolean index
        a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)

        self.assertEqual(a[True], a[None])
        self.assertEqual(a[False], a[None][0:0])

    def test_boolean_shape_mismatch(self, device):
        """Bool masks whose shape does not match the indexed dims raise IndexError."""
        arr = torch.ones((5, 4, 3), device=device)

        index = tensor([True], device=device)
        self.assertRaisesRegex(IndexError, "mask", lambda: arr[index])

        index = tensor([False] * 6, device=device)
        self.assertRaisesRegex(IndexError, "mask", lambda: arr[index])

        index = torch.ByteTensor(4, 4).to(device).zero_()
        self.assertRaisesRegex(IndexError, "mask", lambda: arr[index])
        self.assertRaisesRegex(IndexError, "mask", lambda: arr[(slice(None), index)])

    def test_boolean_indexing_onedim(self, device):
        """A length-1 bool mask selects and assigns the single row."""
        # Indexing a 2-dimensional array with
        # boolean array of length one
        a = tensor([[0.0, 0.0, 0.0]], device=device)
        b = tensor([True], device=device)
        self.assertEqual(a[b], a)
        # boolean assignment
        a[b] = 1.0
        self.assertEqual(a, tensor([[1.0, 1.0, 1.0]], device=device))

    # https://github.com/pytorch/pytorch/issues/127003
    @xfailIfTorchDynamo
    def test_boolean_assignment_value_mismatch(self, device):
        """Mask assignment with a non-broadcastable value raises a shape mismatch."""
        # A boolean assignment should fail when the shape of the values
        # cannot be broadcast to the subscription. (see also gh-3458)
        a = torch.arange(0, 4, device=device)

        def f(a, v):
            a[a > -1] = tensor(v).to(device)

        self.assertRaisesRegex(Exception, "shape mismatch", f, a, [])
        self.assertRaisesRegex(Exception, "shape mismatch", f, a, [1, 2, 3])
        self.assertRaisesRegex(Exception, "shape mismatch", f, a[:1], [1, 2, 3])

    def test_boolean_indexing_twodim(self, device):
        """2-D bool masks select/assign elements; 1-D masks select rows."""
        # Indexing a 2-dimensional array with
        # 2-dimensional boolean array
        a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
        b = tensor(
            [[True, False, True], [False, True, False], [True, False, True]],
            device=device,
        )
        self.assertEqual(a[b], tensor([1, 3, 5, 7, 9], device=device))
        self.assertEqual(a[b[1]], tensor([[4, 5, 6]], device=device))
        self.assertEqual(a[b[0]], a[b[2]])

        # boolean assignment
        a[b] = 0
        self.assertEqual(a, tensor([[0, 2, 0], [4, 0, 6], [0, 8, 0]], device=device))

    def test_boolean_indexing_weirdness(self, device):
        """Python-bool scalars mixed with other indices follow NumPy semantics."""
        # Weird boolean indexing things
        a = torch.ones((2, 3, 4), device=device)
        self.assertEqual((0, 2, 3, 4), a[False, True, ...].shape)
        self.assertEqual(
            torch.ones(1, 2, device=device), a[True, [0, 1], True, True, [1], [[2]]]
        )
        self.assertRaises(IndexError, lambda: a[False, [0, 1], ...])

    def test_boolean_indexing_weirdness_tensors(self, device):
        """0-d bool tensors behave like Python bools when mixed with other indices."""
        # Weird boolean indexing things
        false = torch.tensor(False, device=device)
        true = torch.tensor(True, device=device)
        a = torch.ones((2, 3, 4), device=device)
        # Use the 0-d tensor versions here; the Python-bool variant is covered
        # by test_boolean_indexing_weirdness.
        self.assertEqual((0, 2, 3, 4), a[false, true, ...].shape)
        self.assertEqual(
            torch.ones(1, 2, device=device), a[true, [0, 1], true, true, [1], [[2]]]
        )
        self.assertRaises(IndexError, lambda: a[false, [0, 1], ...])

    def test_boolean_indexing_alldims(self, device):
        """Two True scalars (Python or tensor) add a single leading dim of size 1."""
        true = torch.tensor(True, device=device)
        a = torch.ones((2, 3), device=device)
        self.assertEqual((1, 2, 3), a[True, True].shape)
        self.assertEqual((1, 2, 3), a[true, true].shape)

    def test_boolean_list_indexing(self, device):
        """Python lists of bools act as masks."""
        # Indexing a 2-dimensional array with
        # boolean lists
        a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
        b = [True, False, False]
        c = [True, True, False]
        self.assertEqual(a[b], tensor([[1, 2, 3]], device=device))
        self.assertEqual(a[b, b], tensor([1], device=device))
        self.assertEqual(a[c], tensor([[1, 2, 3], [4, 5, 6]], device=device))
        self.assertEqual(a[c, c], tensor([1, 5], device=device))

    def test_everything_returns_views(self, device):
        """(), ... and : indexing return new tensor objects, never `a` itself."""
        # Before `...` would return a itself.
        a = tensor([5], device=device)

        self.assertIsNot(a, a[()])
        self.assertIsNot(a, a[...])
        self.assertIsNot(a, a[:])

    def test_broaderrors_indexing(self, device):
        """Non-broadcastable index lists raise a shape mismatch on get and set."""
        a = torch.zeros(5, 5, device=device)
        self.assertRaisesRegex(
            IndexError, "shape mismatch", a.__getitem__, ([0, 1], [0, 1, 2])
        )
        self.assertRaisesRegex(
            IndexError, "shape mismatch", a.__setitem__, ([0, 1], [0, 1, 2]), 0
        )

    def test_trivial_fancy_out_of_bounds(self, device):
        """An out-of-range entry in an index tensor raises IndexError (non-CUDA)."""
        a = torch.zeros(5, device=device)
        # Decide on skipping before building any index tensors.
        if a.is_cuda:
            raise unittest.SkipTest("CUDA asserts instead of raising an exception")
        ind = torch.ones(20, dtype=torch.int64, device=device)
        ind[-1] = 10
        self.assertRaises(IndexError, a.__getitem__, ind)
        self.assertRaises(IndexError, a.__setitem__, ind, 0)
        ind = torch.ones(20, dtype=torch.int64, device=device)
        ind[0] = 11
        self.assertRaises(IndexError, a.__getitem__, ind)
        self.assertRaises(IndexError, a.__setitem__, ind, 0)

    def test_index_is_larger(self, device):
        """Index lists broadcast to a shape larger than the assigned value."""
        # Simple case of fancy index broadcasting of the index.
        a = torch.zeros((5, 5), device=device)
        a[[[0], [1], [2]], [0, 1, 2]] = tensor([2.0, 3.0, 4.0], device=device)

        self.assertTrue((a[:3, :3] == tensor([2.0, 3.0, 4.0], device=device)).all())

    def test_broadcast_subspace(self, device):
        """A column value broadcasts across the subspace selected by an index."""
        a = torch.zeros((100, 100), device=device)
        v = torch.arange(0.0, 100, device=device)[:, None]
        b = torch.arange(99, -1, -1, device=device).long()
        a[b] = v
        expected = b.float().unsqueeze(1).expand(100, 100)
        self.assertEqual(a, expected)

    def test_truncate_leading_1s(self, device):
        """Leading size-1 dims of the assigned value are dropped to fit the target."""
        col_max = torch.randn(1, 4, device=device)
        kernel = col_max.T * col_max  # [4, 4] tensor
        kernel2 = kernel.clone()
        # Set the diagonal
        kernel[range(len(kernel)), range(len(kernel))] = torch.square(col_max)
        torch.diagonal(kernel2).copy_(torch.square(col_max.view(4)))
        self.assertEqual(kernel, kernel2)


instantiate_device_type_tests(TestIndexing, globals(), except_for="meta")
instantiate_device_type_tests(NumpyTests, globals(), except_for="meta")

if __name__ == "__main__":
    run_tests()