1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: tests"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport operator 4*da0073e9SAndroid Build Coastguard Workerimport random 5*da0073e9SAndroid Build Coastguard Workerimport unittest 6*da0073e9SAndroid Build Coastguard Workerimport warnings 7*da0073e9SAndroid Build Coastguard Workerfrom functools import reduce 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerimport numpy as np 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Workerimport torch 12*da0073e9SAndroid Build Coastguard Workerfrom torch import tensor 13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor 14*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ( 15*da0073e9SAndroid Build Coastguard Worker dtypes, 16*da0073e9SAndroid Build Coastguard Worker dtypesIfCPU, 17*da0073e9SAndroid Build Coastguard Worker dtypesIfCUDA, 18*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests, 19*da0073e9SAndroid Build Coastguard Worker onlyCUDA, 20*da0073e9SAndroid Build Coastguard Worker onlyNativeDeviceTypes, 21*da0073e9SAndroid Build Coastguard Worker skipXLA, 22*da0073e9SAndroid Build Coastguard Worker) 23*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 24*da0073e9SAndroid Build Coastguard Worker DeterministicGuard, 25*da0073e9SAndroid Build Coastguard Worker run_tests, 26*da0073e9SAndroid Build Coastguard Worker serialTest, 27*da0073e9SAndroid Build Coastguard Worker skipIfTorchDynamo, 28*da0073e9SAndroid Build Coastguard Worker TEST_CUDA, 29*da0073e9SAndroid Build Coastguard Worker TestCase, 30*da0073e9SAndroid Build Coastguard Worker xfailIfTorchDynamo, 31*da0073e9SAndroid Build Coastguard Worker) 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Worker 34*da0073e9SAndroid Build 
Coastguard Workerclass TestIndexing(TestCase): 35*da0073e9SAndroid Build Coastguard Worker def test_index(self, device): 36*da0073e9SAndroid Build Coastguard Worker def consec(size, start=1): 37*da0073e9SAndroid Build Coastguard Worker sequence = torch.ones(torch.tensor(size).prod(0)).cumsum(0) 38*da0073e9SAndroid Build Coastguard Worker sequence.add_(start - 1) 39*da0073e9SAndroid Build Coastguard Worker return sequence.view(*size) 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Worker reference = consec((3, 3, 3)).to(device) 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker # empty tensor indexing 44*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 45*da0073e9SAndroid Build Coastguard Worker reference[torch.LongTensor().to(device)], reference.new(0, 3, 3) 46*da0073e9SAndroid Build Coastguard Worker ) 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[0], consec((3, 3)), atol=0, rtol=0) 49*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[1], consec((3, 3), 10), atol=0, rtol=0) 50*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[2], consec((3, 3), 19), atol=0, rtol=0) 51*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[0, 1], consec((3,), 4), atol=0, rtol=0) 52*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[0:2], consec((2, 3, 3)), atol=0, rtol=0) 53*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[2, 2, 2], 27, atol=0, rtol=0) 54*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[:], consec((3, 3, 3)), atol=0, rtol=0) 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Worker # indexing with Ellipsis 57*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 58*da0073e9SAndroid Build Coastguard Worker reference[..., 2], 59*da0073e9SAndroid Build Coastguard Worker 
torch.tensor([[3.0, 6.0, 9.0], [12.0, 15.0, 18.0], [21.0, 24.0, 27.0]]), 60*da0073e9SAndroid Build Coastguard Worker atol=0, 61*da0073e9SAndroid Build Coastguard Worker rtol=0, 62*da0073e9SAndroid Build Coastguard Worker ) 63*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 64*da0073e9SAndroid Build Coastguard Worker reference[0, ..., 2], torch.tensor([3.0, 6.0, 9.0]), atol=0, rtol=0 65*da0073e9SAndroid Build Coastguard Worker ) 66*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[..., 2], reference[:, :, 2], atol=0, rtol=0) 67*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[0, ..., 2], reference[0, :, 2], atol=0, rtol=0) 68*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[0, 2, ...], reference[0, 2], atol=0, rtol=0) 69*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[..., 2, 2, 2], 27, atol=0, rtol=0) 70*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[2, ..., 2, 2], 27, atol=0, rtol=0) 71*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[2, 2, ..., 2], 27, atol=0, rtol=0) 72*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[2, 2, 2, ...], 27, atol=0, rtol=0) 73*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[...], reference, atol=0, rtol=0) 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker reference_5d = consec((3, 3, 3, 3, 3)).to(device) 76*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 77*da0073e9SAndroid Build Coastguard Worker reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0], atol=0, rtol=0 78*da0073e9SAndroid Build Coastguard Worker ) 79*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 80*da0073e9SAndroid Build Coastguard Worker reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0], atol=0, rtol=0 81*da0073e9SAndroid Build Coastguard Worker ) 82*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 83*da0073e9SAndroid Build 
Coastguard Worker reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1], atol=0, rtol=0 84*da0073e9SAndroid Build Coastguard Worker ) 85*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference_5d[...], reference_5d, atol=0, rtol=0) 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker # LongTensor indexing 88*da0073e9SAndroid Build Coastguard Worker reference = consec((5, 5, 5)).to(device) 89*da0073e9SAndroid Build Coastguard Worker idx = torch.LongTensor([2, 4]).to(device) 90*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[idx], torch.stack([reference[2], reference[4]])) 91*da0073e9SAndroid Build Coastguard Worker # TODO: enable one indexing is implemented like in numpy 92*da0073e9SAndroid Build Coastguard Worker # self.assertEqual(reference[2, idx], torch.stack([reference[2, 2], reference[2, 4]])) 93*da0073e9SAndroid Build Coastguard Worker # self.assertEqual(reference[3, idx, 1], torch.stack([reference[3, 2], reference[3, 4]])[:, 1]) 94*da0073e9SAndroid Build Coastguard Worker 95*da0073e9SAndroid Build Coastguard Worker # None indexing 96*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[2, None], reference[2].unsqueeze(0)) 97*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 98*da0073e9SAndroid Build Coastguard Worker reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0) 99*da0073e9SAndroid Build Coastguard Worker ) 100*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[2:4, None], reference[2:4].unsqueeze(1)) 101*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 102*da0073e9SAndroid Build Coastguard Worker reference[None, 2, None, None], 103*da0073e9SAndroid Build Coastguard Worker reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0), 104*da0073e9SAndroid Build Coastguard Worker ) 105*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 106*da0073e9SAndroid Build Coastguard Worker reference[None, 2:5, None, None], 
107*da0073e9SAndroid Build Coastguard Worker reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2), 108*da0073e9SAndroid Build Coastguard Worker ) 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker # indexing 0-length slice 111*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.empty(0, 5, 5), reference[slice(0)]) 112*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.empty(0, 5), reference[slice(0), 2]) 113*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.empty(0, 5), reference[2, slice(0)]) 114*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor([]), reference[2, 1:1, 2]) 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker # indexing with step 117*da0073e9SAndroid Build Coastguard Worker reference = consec((10, 10, 10)).to(device) 118*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[1:5:2], torch.stack([reference[1], reference[3]], 0)) 119*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 120*da0073e9SAndroid Build Coastguard Worker reference[1:6:2], torch.stack([reference[1], reference[3], reference[5]], 0) 121*da0073e9SAndroid Build Coastguard Worker ) 122*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[1:9:4], torch.stack([reference[1], reference[5]], 0)) 123*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 124*da0073e9SAndroid Build Coastguard Worker reference[2:4, 1:5:2], 125*da0073e9SAndroid Build Coastguard Worker torch.stack([reference[2:4, 1], reference[2:4, 3]], 1), 126*da0073e9SAndroid Build Coastguard Worker ) 127*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 128*da0073e9SAndroid Build Coastguard Worker reference[3, 1:6:2], 129*da0073e9SAndroid Build Coastguard Worker torch.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0), 130*da0073e9SAndroid Build Coastguard Worker ) 131*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 
132*da0073e9SAndroid Build Coastguard Worker reference[None, 2, 1:9:4], 133*da0073e9SAndroid Build Coastguard Worker torch.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0), 134*da0073e9SAndroid Build Coastguard Worker ) 135*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 136*da0073e9SAndroid Build Coastguard Worker reference[:, 2, 1:6:2], 137*da0073e9SAndroid Build Coastguard Worker torch.stack( 138*da0073e9SAndroid Build Coastguard Worker [reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1 139*da0073e9SAndroid Build Coastguard Worker ), 140*da0073e9SAndroid Build Coastguard Worker ) 141*da0073e9SAndroid Build Coastguard Worker 142*da0073e9SAndroid Build Coastguard Worker lst = [list(range(i, i + 10)) for i in range(0, 100, 10)] 143*da0073e9SAndroid Build Coastguard Worker tensor = torch.DoubleTensor(lst).to(device) 144*da0073e9SAndroid Build Coastguard Worker for _i in range(100): 145*da0073e9SAndroid Build Coastguard Worker idx1_start = random.randrange(10) 146*da0073e9SAndroid Build Coastguard Worker idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1) 147*da0073e9SAndroid Build Coastguard Worker idx1_step = random.randrange(1, 8) 148*da0073e9SAndroid Build Coastguard Worker idx1 = slice(idx1_start, idx1_end, idx1_step) 149*da0073e9SAndroid Build Coastguard Worker if random.randrange(2) == 0: 150*da0073e9SAndroid Build Coastguard Worker idx2_start = random.randrange(10) 151*da0073e9SAndroid Build Coastguard Worker idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1) 152*da0073e9SAndroid Build Coastguard Worker idx2_step = random.randrange(1, 8) 153*da0073e9SAndroid Build Coastguard Worker idx2 = slice(idx2_start, idx2_end, idx2_step) 154*da0073e9SAndroid Build Coastguard Worker lst_indexed = [l[idx2] for l in lst[idx1]] 155*da0073e9SAndroid Build Coastguard Worker tensor_indexed = tensor[idx1, idx2] 156*da0073e9SAndroid Build Coastguard Worker else: 157*da0073e9SAndroid Build Coastguard Worker lst_indexed = 
lst[idx1] 158*da0073e9SAndroid Build Coastguard Worker tensor_indexed = tensor[idx1] 159*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.DoubleTensor(lst_indexed), tensor_indexed) 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker self.assertRaises(ValueError, lambda: reference[1:9:0]) 162*da0073e9SAndroid Build Coastguard Worker self.assertRaises(ValueError, lambda: reference[1:9:-1]) 163*da0073e9SAndroid Build Coastguard Worker 164*da0073e9SAndroid Build Coastguard Worker self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1]) 165*da0073e9SAndroid Build Coastguard Worker self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1]) 166*da0073e9SAndroid Build Coastguard Worker self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3]) 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Worker self.assertRaises(IndexError, lambda: reference[0.0]) 169*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: reference[0.0:2.0]) 170*da0073e9SAndroid Build Coastguard Worker self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0]) 171*da0073e9SAndroid Build Coastguard Worker self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0]) 172*da0073e9SAndroid Build Coastguard Worker self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0]) 173*da0073e9SAndroid Build Coastguard Worker self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0]) 174*da0073e9SAndroid Build Coastguard Worker 175*da0073e9SAndroid Build Coastguard Worker def delitem(): 176*da0073e9SAndroid Build Coastguard Worker del reference[0] 177*da0073e9SAndroid Build Coastguard Worker 178*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, delitem) 179*da0073e9SAndroid Build Coastguard Worker 180*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 181*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half, torch.double) 
182*da0073e9SAndroid Build Coastguard Worker def test_advancedindex(self, device, dtype): 183*da0073e9SAndroid Build Coastguard Worker # Tests for Integer Array Indexing, Part I - Purely integer array 184*da0073e9SAndroid Build Coastguard Worker # indexing 185*da0073e9SAndroid Build Coastguard Worker 186*da0073e9SAndroid Build Coastguard Worker def consec(size, start=1): 187*da0073e9SAndroid Build Coastguard Worker # Creates the sequence in float since CPU half doesn't support the 188*da0073e9SAndroid Build Coastguard Worker # needed operations. Converts to dtype before returning. 189*da0073e9SAndroid Build Coastguard Worker numel = reduce(operator.mul, size, 1) 190*da0073e9SAndroid Build Coastguard Worker sequence = torch.ones(numel, dtype=torch.float, device=device).cumsum(0) 191*da0073e9SAndroid Build Coastguard Worker sequence.add_(start - 1) 192*da0073e9SAndroid Build Coastguard Worker return sequence.view(*size).to(dtype=dtype) 193*da0073e9SAndroid Build Coastguard Worker 194*da0073e9SAndroid Build Coastguard Worker # pick a random valid indexer type 195*da0073e9SAndroid Build Coastguard Worker def ri(indices): 196*da0073e9SAndroid Build Coastguard Worker choice = random.randint(0, 2) 197*da0073e9SAndroid Build Coastguard Worker if choice == 0: 198*da0073e9SAndroid Build Coastguard Worker return torch.LongTensor(indices).to(device) 199*da0073e9SAndroid Build Coastguard Worker elif choice == 1: 200*da0073e9SAndroid Build Coastguard Worker return list(indices) 201*da0073e9SAndroid Build Coastguard Worker else: 202*da0073e9SAndroid Build Coastguard Worker return tuple(indices) 203*da0073e9SAndroid Build Coastguard Worker 204*da0073e9SAndroid Build Coastguard Worker def validate_indexing(x): 205*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[[0]], consec((1,))) 206*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[ri([0]),], consec((1,))) 207*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[ri([3]),], consec((1,), 4)) 
208*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[[2, 3, 4]], consec((3,), 3)) 209*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[ri([2, 3, 4]),], consec((3,), 3)) 210*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 211*da0073e9SAndroid Build Coastguard Worker x[ri([0, 2, 4]),], torch.tensor([1, 3, 5], dtype=dtype, device=device) 212*da0073e9SAndroid Build Coastguard Worker ) 213*da0073e9SAndroid Build Coastguard Worker 214*da0073e9SAndroid Build Coastguard Worker def validate_setting(x): 215*da0073e9SAndroid Build Coastguard Worker x[[0]] = -2 216*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[[0]], torch.tensor([-2], dtype=dtype, device=device)) 217*da0073e9SAndroid Build Coastguard Worker x[[0]] = -1 218*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 219*da0073e9SAndroid Build Coastguard Worker x[ri([0]),], torch.tensor([-1], dtype=dtype, device=device) 220*da0073e9SAndroid Build Coastguard Worker ) 221*da0073e9SAndroid Build Coastguard Worker x[[2, 3, 4]] = 4 222*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 223*da0073e9SAndroid Build Coastguard Worker x[[2, 3, 4]], torch.tensor([4, 4, 4], dtype=dtype, device=device) 224*da0073e9SAndroid Build Coastguard Worker ) 225*da0073e9SAndroid Build Coastguard Worker x[ri([2, 3, 4]),] = 3 226*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 227*da0073e9SAndroid Build Coastguard Worker x[ri([2, 3, 4]),], torch.tensor([3, 3, 3], dtype=dtype, device=device) 228*da0073e9SAndroid Build Coastguard Worker ) 229*da0073e9SAndroid Build Coastguard Worker x[ri([0, 2, 4]),] = torch.tensor([5, 4, 3], dtype=dtype, device=device) 230*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 231*da0073e9SAndroid Build Coastguard Worker x[ri([0, 2, 4]),], torch.tensor([5, 4, 3], dtype=dtype, device=device) 232*da0073e9SAndroid Build Coastguard Worker ) 233*da0073e9SAndroid Build Coastguard Worker 234*da0073e9SAndroid Build Coastguard Worker # Only 
validates indexing and setting for halfs 235*da0073e9SAndroid Build Coastguard Worker if dtype == torch.half: 236*da0073e9SAndroid Build Coastguard Worker reference = consec((10,)) 237*da0073e9SAndroid Build Coastguard Worker validate_indexing(reference) 238*da0073e9SAndroid Build Coastguard Worker validate_setting(reference) 239*da0073e9SAndroid Build Coastguard Worker return 240*da0073e9SAndroid Build Coastguard Worker 241*da0073e9SAndroid Build Coastguard Worker # Case 1: Purely Integer Array Indexing 242*da0073e9SAndroid Build Coastguard Worker reference = consec((10,)) 243*da0073e9SAndroid Build Coastguard Worker validate_indexing(reference) 244*da0073e9SAndroid Build Coastguard Worker 245*da0073e9SAndroid Build Coastguard Worker # setting values 246*da0073e9SAndroid Build Coastguard Worker validate_setting(reference) 247*da0073e9SAndroid Build Coastguard Worker 248*da0073e9SAndroid Build Coastguard Worker # Tensor with stride != 1 249*da0073e9SAndroid Build Coastguard Worker # strided is [1, 3, 5, 7] 250*da0073e9SAndroid Build Coastguard Worker reference = consec((10,)) 251*da0073e9SAndroid Build Coastguard Worker strided = torch.tensor((), dtype=dtype, device=device) 252*da0073e9SAndroid Build Coastguard Worker strided.set_( 253*da0073e9SAndroid Build Coastguard Worker reference.storage(), storage_offset=0, size=torch.Size([4]), stride=[2] 254*da0073e9SAndroid Build Coastguard Worker ) 255*da0073e9SAndroid Build Coastguard Worker 256*da0073e9SAndroid Build Coastguard Worker self.assertEqual(strided[[0]], torch.tensor([1], dtype=dtype, device=device)) 257*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 258*da0073e9SAndroid Build Coastguard Worker strided[ri([0]),], torch.tensor([1], dtype=dtype, device=device) 259*da0073e9SAndroid Build Coastguard Worker ) 260*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 261*da0073e9SAndroid Build Coastguard Worker strided[ri([3]),], torch.tensor([7], dtype=dtype, device=device) 262*da0073e9SAndroid 
Build Coastguard Worker ) 263*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 264*da0073e9SAndroid Build Coastguard Worker strided[[1, 2]], torch.tensor([3, 5], dtype=dtype, device=device) 265*da0073e9SAndroid Build Coastguard Worker ) 266*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 267*da0073e9SAndroid Build Coastguard Worker strided[ri([1, 2]),], torch.tensor([3, 5], dtype=dtype, device=device) 268*da0073e9SAndroid Build Coastguard Worker ) 269*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 270*da0073e9SAndroid Build Coastguard Worker strided[ri([[2, 1], [0, 3]]),], 271*da0073e9SAndroid Build Coastguard Worker torch.tensor([[5, 3], [1, 7]], dtype=dtype, device=device), 272*da0073e9SAndroid Build Coastguard Worker ) 273*da0073e9SAndroid Build Coastguard Worker 274*da0073e9SAndroid Build Coastguard Worker # stride is [4, 8] 275*da0073e9SAndroid Build Coastguard Worker strided = torch.tensor((), dtype=dtype, device=device) 276*da0073e9SAndroid Build Coastguard Worker strided.set_( 277*da0073e9SAndroid Build Coastguard Worker reference.storage(), storage_offset=4, size=torch.Size([2]), stride=[4] 278*da0073e9SAndroid Build Coastguard Worker ) 279*da0073e9SAndroid Build Coastguard Worker self.assertEqual(strided[[0]], torch.tensor([5], dtype=dtype, device=device)) 280*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 281*da0073e9SAndroid Build Coastguard Worker strided[ri([0]),], torch.tensor([5], dtype=dtype, device=device) 282*da0073e9SAndroid Build Coastguard Worker ) 283*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 284*da0073e9SAndroid Build Coastguard Worker strided[ri([1]),], torch.tensor([9], dtype=dtype, device=device) 285*da0073e9SAndroid Build Coastguard Worker ) 286*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 287*da0073e9SAndroid Build Coastguard Worker strided[[0, 1]], torch.tensor([5, 9], dtype=dtype, device=device) 288*da0073e9SAndroid Build Coastguard Worker ) 
289*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 290*da0073e9SAndroid Build Coastguard Worker strided[ri([0, 1]),], torch.tensor([5, 9], dtype=dtype, device=device) 291*da0073e9SAndroid Build Coastguard Worker ) 292*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 293*da0073e9SAndroid Build Coastguard Worker strided[ri([[0, 1], [1, 0]]),], 294*da0073e9SAndroid Build Coastguard Worker torch.tensor([[5, 9], [9, 5]], dtype=dtype, device=device), 295*da0073e9SAndroid Build Coastguard Worker ) 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker # reference is 1 2 298*da0073e9SAndroid Build Coastguard Worker # 3 4 299*da0073e9SAndroid Build Coastguard Worker # 5 6 300*da0073e9SAndroid Build Coastguard Worker reference = consec((3, 2)) 301*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 302*da0073e9SAndroid Build Coastguard Worker reference[ri([0, 1, 2]), ri([0])], 303*da0073e9SAndroid Build Coastguard Worker torch.tensor([1, 3, 5], dtype=dtype, device=device), 304*da0073e9SAndroid Build Coastguard Worker ) 305*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 306*da0073e9SAndroid Build Coastguard Worker reference[ri([0, 1, 2]), ri([1])], 307*da0073e9SAndroid Build Coastguard Worker torch.tensor([2, 4, 6], dtype=dtype, device=device), 308*da0073e9SAndroid Build Coastguard Worker ) 309*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[ri([0]), ri([0])], consec((1,))) 310*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[ri([2]), ri([1])], consec((1,), 6)) 311*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 312*da0073e9SAndroid Build Coastguard Worker reference[[ri([0, 0]), ri([0, 1])]], 313*da0073e9SAndroid Build Coastguard Worker torch.tensor([1, 2], dtype=dtype, device=device), 314*da0073e9SAndroid Build Coastguard Worker ) 315*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 316*da0073e9SAndroid Build Coastguard Worker 
reference[[ri([0, 1, 1, 0, 2]), ri([1])]], 317*da0073e9SAndroid Build Coastguard Worker torch.tensor([2, 4, 4, 2, 6], dtype=dtype, device=device), 318*da0073e9SAndroid Build Coastguard Worker ) 319*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 320*da0073e9SAndroid Build Coastguard Worker reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], 321*da0073e9SAndroid Build Coastguard Worker torch.tensor([1, 2, 3, 3], dtype=dtype, device=device), 322*da0073e9SAndroid Build Coastguard Worker ) 323*da0073e9SAndroid Build Coastguard Worker 324*da0073e9SAndroid Build Coastguard Worker rows = ri([[0, 0], [1, 2]]) 325*da0073e9SAndroid Build Coastguard Worker columns = ([0],) 326*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 327*da0073e9SAndroid Build Coastguard Worker reference[rows, columns], 328*da0073e9SAndroid Build Coastguard Worker torch.tensor([[1, 1], [3, 5]], dtype=dtype, device=device), 329*da0073e9SAndroid Build Coastguard Worker ) 330*da0073e9SAndroid Build Coastguard Worker 331*da0073e9SAndroid Build Coastguard Worker rows = ri([[0, 0], [1, 2]]) 332*da0073e9SAndroid Build Coastguard Worker columns = ri([1, 0]) 333*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 334*da0073e9SAndroid Build Coastguard Worker reference[rows, columns], 335*da0073e9SAndroid Build Coastguard Worker torch.tensor([[2, 1], [4, 5]], dtype=dtype, device=device), 336*da0073e9SAndroid Build Coastguard Worker ) 337*da0073e9SAndroid Build Coastguard Worker rows = ri([[0, 0], [1, 2]]) 338*da0073e9SAndroid Build Coastguard Worker columns = ri([[0, 1], [1, 0]]) 339*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 340*da0073e9SAndroid Build Coastguard Worker reference[rows, columns], 341*da0073e9SAndroid Build Coastguard Worker torch.tensor([[1, 2], [4, 5]], dtype=dtype, device=device), 342*da0073e9SAndroid Build Coastguard Worker ) 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker # setting values 345*da0073e9SAndroid 
Build Coastguard Worker reference[ri([0]), ri([1])] = -1 346*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 347*da0073e9SAndroid Build Coastguard Worker reference[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device) 348*da0073e9SAndroid Build Coastguard Worker ) 349*da0073e9SAndroid Build Coastguard Worker reference[ri([0, 1, 2]), ri([0])] = torch.tensor( 350*da0073e9SAndroid Build Coastguard Worker [-1, 2, -4], dtype=dtype, device=device 351*da0073e9SAndroid Build Coastguard Worker ) 352*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 353*da0073e9SAndroid Build Coastguard Worker reference[ri([0, 1, 2]), ri([0])], 354*da0073e9SAndroid Build Coastguard Worker torch.tensor([-1, 2, -4], dtype=dtype, device=device), 355*da0073e9SAndroid Build Coastguard Worker ) 356*da0073e9SAndroid Build Coastguard Worker reference[rows, columns] = torch.tensor( 357*da0073e9SAndroid Build Coastguard Worker [[4, 6], [2, 3]], dtype=dtype, device=device 358*da0073e9SAndroid Build Coastguard Worker ) 359*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 360*da0073e9SAndroid Build Coastguard Worker reference[rows, columns], 361*da0073e9SAndroid Build Coastguard Worker torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device), 362*da0073e9SAndroid Build Coastguard Worker ) 363*da0073e9SAndroid Build Coastguard Worker 364*da0073e9SAndroid Build Coastguard Worker # Verify still works with Transposed (i.e. 
non-contiguous) Tensors 365*da0073e9SAndroid Build Coastguard Worker 366*da0073e9SAndroid Build Coastguard Worker reference = torch.tensor( 367*da0073e9SAndroid Build Coastguard Worker [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype, device=device 368*da0073e9SAndroid Build Coastguard Worker ).t_() 369*da0073e9SAndroid Build Coastguard Worker 370*da0073e9SAndroid Build Coastguard Worker # Transposed: [[0, 4, 8], 371*da0073e9SAndroid Build Coastguard Worker # [1, 5, 9], 372*da0073e9SAndroid Build Coastguard Worker # [2, 6, 10], 373*da0073e9SAndroid Build Coastguard Worker # [3, 7, 11]] 374*da0073e9SAndroid Build Coastguard Worker 375*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 376*da0073e9SAndroid Build Coastguard Worker reference[ri([0, 1, 2]), ri([0])], 377*da0073e9SAndroid Build Coastguard Worker torch.tensor([0, 1, 2], dtype=dtype, device=device), 378*da0073e9SAndroid Build Coastguard Worker ) 379*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 380*da0073e9SAndroid Build Coastguard Worker reference[ri([0, 1, 2]), ri([1])], 381*da0073e9SAndroid Build Coastguard Worker torch.tensor([4, 5, 6], dtype=dtype, device=device), 382*da0073e9SAndroid Build Coastguard Worker ) 383*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 384*da0073e9SAndroid Build Coastguard Worker reference[ri([0]), ri([0])], torch.tensor([0], dtype=dtype, device=device) 385*da0073e9SAndroid Build Coastguard Worker ) 386*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 387*da0073e9SAndroid Build Coastguard Worker reference[ri([2]), ri([1])], torch.tensor([6], dtype=dtype, device=device) 388*da0073e9SAndroid Build Coastguard Worker ) 389*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 390*da0073e9SAndroid Build Coastguard Worker reference[[ri([0, 0]), ri([0, 1])]], 391*da0073e9SAndroid Build Coastguard Worker torch.tensor([0, 4], dtype=dtype, device=device), 392*da0073e9SAndroid Build Coastguard Worker ) 393*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual( 394*da0073e9SAndroid Build Coastguard Worker reference[[ri([0, 1, 1, 0, 3]), ri([1])]], 395*da0073e9SAndroid Build Coastguard Worker torch.tensor([4, 5, 5, 4, 7], dtype=dtype, device=device), 396*da0073e9SAndroid Build Coastguard Worker ) 397*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 398*da0073e9SAndroid Build Coastguard Worker reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], 399*da0073e9SAndroid Build Coastguard Worker torch.tensor([0, 4, 1, 1], dtype=dtype, device=device), 400*da0073e9SAndroid Build Coastguard Worker ) 401*da0073e9SAndroid Build Coastguard Worker 402*da0073e9SAndroid Build Coastguard Worker rows = ri([[0, 0], [1, 2]]) 403*da0073e9SAndroid Build Coastguard Worker columns = ([0],) 404*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 405*da0073e9SAndroid Build Coastguard Worker reference[rows, columns], 406*da0073e9SAndroid Build Coastguard Worker torch.tensor([[0, 0], [1, 2]], dtype=dtype, device=device), 407*da0073e9SAndroid Build Coastguard Worker ) 408*da0073e9SAndroid Build Coastguard Worker 409*da0073e9SAndroid Build Coastguard Worker rows = ri([[0, 0], [1, 2]]) 410*da0073e9SAndroid Build Coastguard Worker columns = ri([1, 0]) 411*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 412*da0073e9SAndroid Build Coastguard Worker reference[rows, columns], 413*da0073e9SAndroid Build Coastguard Worker torch.tensor([[4, 0], [5, 2]], dtype=dtype, device=device), 414*da0073e9SAndroid Build Coastguard Worker ) 415*da0073e9SAndroid Build Coastguard Worker rows = ri([[0, 0], [1, 3]]) 416*da0073e9SAndroid Build Coastguard Worker columns = ri([[0, 1], [1, 2]]) 417*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 418*da0073e9SAndroid Build Coastguard Worker reference[rows, columns], 419*da0073e9SAndroid Build Coastguard Worker torch.tensor([[0, 4], [5, 11]], dtype=dtype, device=device), 420*da0073e9SAndroid Build Coastguard Worker ) 421*da0073e9SAndroid Build Coastguard Worker 
422*da0073e9SAndroid Build Coastguard Worker # setting values 423*da0073e9SAndroid Build Coastguard Worker reference[ri([0]), ri([1])] = -1 424*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 425*da0073e9SAndroid Build Coastguard Worker reference[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device) 426*da0073e9SAndroid Build Coastguard Worker ) 427*da0073e9SAndroid Build Coastguard Worker reference[ri([0, 1, 2]), ri([0])] = torch.tensor( 428*da0073e9SAndroid Build Coastguard Worker [-1, 2, -4], dtype=dtype, device=device 429*da0073e9SAndroid Build Coastguard Worker ) 430*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 431*da0073e9SAndroid Build Coastguard Worker reference[ri([0, 1, 2]), ri([0])], 432*da0073e9SAndroid Build Coastguard Worker torch.tensor([-1, 2, -4], dtype=dtype, device=device), 433*da0073e9SAndroid Build Coastguard Worker ) 434*da0073e9SAndroid Build Coastguard Worker reference[rows, columns] = torch.tensor( 435*da0073e9SAndroid Build Coastguard Worker [[4, 6], [2, 3]], dtype=dtype, device=device 436*da0073e9SAndroid Build Coastguard Worker ) 437*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 438*da0073e9SAndroid Build Coastguard Worker reference[rows, columns], 439*da0073e9SAndroid Build Coastguard Worker torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device), 440*da0073e9SAndroid Build Coastguard Worker ) 441*da0073e9SAndroid Build Coastguard Worker 442*da0073e9SAndroid Build Coastguard Worker # stride != 1 443*da0073e9SAndroid Build Coastguard Worker 444*da0073e9SAndroid Build Coastguard Worker # strided is [[1 3 5 7], 445*da0073e9SAndroid Build Coastguard Worker # [9 11 13 15]] 446*da0073e9SAndroid Build Coastguard Worker 447*da0073e9SAndroid Build Coastguard Worker reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8) 448*da0073e9SAndroid Build Coastguard Worker strided = torch.tensor((), dtype=dtype, device=device) 449*da0073e9SAndroid Build Coastguard Worker 
strided.set_(reference.storage(), 1, size=torch.Size([2, 4]), stride=[8, 2]) 450*da0073e9SAndroid Build Coastguard Worker 451*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 452*da0073e9SAndroid Build Coastguard Worker strided[ri([0, 1]), ri([0])], 453*da0073e9SAndroid Build Coastguard Worker torch.tensor([1, 9], dtype=dtype, device=device), 454*da0073e9SAndroid Build Coastguard Worker ) 455*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 456*da0073e9SAndroid Build Coastguard Worker strided[ri([0, 1]), ri([1])], 457*da0073e9SAndroid Build Coastguard Worker torch.tensor([3, 11], dtype=dtype, device=device), 458*da0073e9SAndroid Build Coastguard Worker ) 459*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 460*da0073e9SAndroid Build Coastguard Worker strided[ri([0]), ri([0])], torch.tensor([1], dtype=dtype, device=device) 461*da0073e9SAndroid Build Coastguard Worker ) 462*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 463*da0073e9SAndroid Build Coastguard Worker strided[ri([1]), ri([3])], torch.tensor([15], dtype=dtype, device=device) 464*da0073e9SAndroid Build Coastguard Worker ) 465*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 466*da0073e9SAndroid Build Coastguard Worker strided[[ri([0, 0]), ri([0, 3])]], 467*da0073e9SAndroid Build Coastguard Worker torch.tensor([1, 7], dtype=dtype, device=device), 468*da0073e9SAndroid Build Coastguard Worker ) 469*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 470*da0073e9SAndroid Build Coastguard Worker strided[[ri([1]), ri([0, 1, 1, 0, 3])]], 471*da0073e9SAndroid Build Coastguard Worker torch.tensor([9, 11, 11, 9, 15], dtype=dtype, device=device), 472*da0073e9SAndroid Build Coastguard Worker ) 473*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 474*da0073e9SAndroid Build Coastguard Worker strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], 475*da0073e9SAndroid Build Coastguard Worker torch.tensor([1, 3, 9, 9], dtype=dtype, device=device), 
476*da0073e9SAndroid Build Coastguard Worker ) 477*da0073e9SAndroid Build Coastguard Worker 478*da0073e9SAndroid Build Coastguard Worker rows = ri([[0, 0], [1, 1]]) 479*da0073e9SAndroid Build Coastguard Worker columns = ([0],) 480*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 481*da0073e9SAndroid Build Coastguard Worker strided[rows, columns], 482*da0073e9SAndroid Build Coastguard Worker torch.tensor([[1, 1], [9, 9]], dtype=dtype, device=device), 483*da0073e9SAndroid Build Coastguard Worker ) 484*da0073e9SAndroid Build Coastguard Worker 485*da0073e9SAndroid Build Coastguard Worker rows = ri([[0, 1], [1, 0]]) 486*da0073e9SAndroid Build Coastguard Worker columns = ri([1, 2]) 487*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 488*da0073e9SAndroid Build Coastguard Worker strided[rows, columns], 489*da0073e9SAndroid Build Coastguard Worker torch.tensor([[3, 13], [11, 5]], dtype=dtype, device=device), 490*da0073e9SAndroid Build Coastguard Worker ) 491*da0073e9SAndroid Build Coastguard Worker rows = ri([[0, 0], [1, 1]]) 492*da0073e9SAndroid Build Coastguard Worker columns = ri([[0, 1], [1, 2]]) 493*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 494*da0073e9SAndroid Build Coastguard Worker strided[rows, columns], 495*da0073e9SAndroid Build Coastguard Worker torch.tensor([[1, 3], [11, 13]], dtype=dtype, device=device), 496*da0073e9SAndroid Build Coastguard Worker ) 497*da0073e9SAndroid Build Coastguard Worker 498*da0073e9SAndroid Build Coastguard Worker # setting values 499*da0073e9SAndroid Build Coastguard Worker 500*da0073e9SAndroid Build Coastguard Worker # strided is [[10, 11], 501*da0073e9SAndroid Build Coastguard Worker # [17, 18]] 502*da0073e9SAndroid Build Coastguard Worker 503*da0073e9SAndroid Build Coastguard Worker reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8) 504*da0073e9SAndroid Build Coastguard Worker strided = torch.tensor((), dtype=dtype, device=device) 505*da0073e9SAndroid Build Coastguard 
Worker strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]) 506*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 507*da0073e9SAndroid Build Coastguard Worker strided[ri([0]), ri([1])], torch.tensor([11], dtype=dtype, device=device) 508*da0073e9SAndroid Build Coastguard Worker ) 509*da0073e9SAndroid Build Coastguard Worker strided[ri([0]), ri([1])] = -1 510*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 511*da0073e9SAndroid Build Coastguard Worker strided[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device) 512*da0073e9SAndroid Build Coastguard Worker ) 513*da0073e9SAndroid Build Coastguard Worker 514*da0073e9SAndroid Build Coastguard Worker reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8) 515*da0073e9SAndroid Build Coastguard Worker strided = torch.tensor((), dtype=dtype, device=device) 516*da0073e9SAndroid Build Coastguard Worker strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]) 517*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 518*da0073e9SAndroid Build Coastguard Worker strided[ri([0, 1]), ri([1, 0])], 519*da0073e9SAndroid Build Coastguard Worker torch.tensor([11, 17], dtype=dtype, device=device), 520*da0073e9SAndroid Build Coastguard Worker ) 521*da0073e9SAndroid Build Coastguard Worker strided[ri([0, 1]), ri([1, 0])] = torch.tensor( 522*da0073e9SAndroid Build Coastguard Worker [-1, 2], dtype=dtype, device=device 523*da0073e9SAndroid Build Coastguard Worker ) 524*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 525*da0073e9SAndroid Build Coastguard Worker strided[ri([0, 1]), ri([1, 0])], 526*da0073e9SAndroid Build Coastguard Worker torch.tensor([-1, 2], dtype=dtype, device=device), 527*da0073e9SAndroid Build Coastguard Worker ) 528*da0073e9SAndroid Build Coastguard Worker 529*da0073e9SAndroid Build Coastguard Worker reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8) 530*da0073e9SAndroid Build Coastguard 
Worker strided = torch.tensor((), dtype=dtype, device=device) 531*da0073e9SAndroid Build Coastguard Worker strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]) 532*da0073e9SAndroid Build Coastguard Worker 533*da0073e9SAndroid Build Coastguard Worker rows = ri([[0], [1]]) 534*da0073e9SAndroid Build Coastguard Worker columns = ri([[0, 1], [0, 1]]) 535*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 536*da0073e9SAndroid Build Coastguard Worker strided[rows, columns], 537*da0073e9SAndroid Build Coastguard Worker torch.tensor([[10, 11], [17, 18]], dtype=dtype, device=device), 538*da0073e9SAndroid Build Coastguard Worker ) 539*da0073e9SAndroid Build Coastguard Worker strided[rows, columns] = torch.tensor( 540*da0073e9SAndroid Build Coastguard Worker [[4, 6], [2, 3]], dtype=dtype, device=device 541*da0073e9SAndroid Build Coastguard Worker ) 542*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 543*da0073e9SAndroid Build Coastguard Worker strided[rows, columns], 544*da0073e9SAndroid Build Coastguard Worker torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device), 545*da0073e9SAndroid Build Coastguard Worker ) 546*da0073e9SAndroid Build Coastguard Worker 547*da0073e9SAndroid Build Coastguard Worker # Tests using less than the number of dims, and ellipsis 548*da0073e9SAndroid Build Coastguard Worker 549*da0073e9SAndroid Build Coastguard Worker # reference is 1 2 550*da0073e9SAndroid Build Coastguard Worker # 3 4 551*da0073e9SAndroid Build Coastguard Worker # 5 6 552*da0073e9SAndroid Build Coastguard Worker reference = consec((3, 2)) 553*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 554*da0073e9SAndroid Build Coastguard Worker reference[ri([0, 2]),], 555*da0073e9SAndroid Build Coastguard Worker torch.tensor([[1, 2], [5, 6]], dtype=dtype, device=device), 556*da0073e9SAndroid Build Coastguard Worker ) 557*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 558*da0073e9SAndroid Build Coastguard Worker 
reference[ri([1]), ...], torch.tensor([[3, 4]], dtype=dtype, device=device) 559*da0073e9SAndroid Build Coastguard Worker ) 560*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 561*da0073e9SAndroid Build Coastguard Worker reference[..., ri([1])], 562*da0073e9SAndroid Build Coastguard Worker torch.tensor([[2], [4], [6]], dtype=dtype, device=device), 563*da0073e9SAndroid Build Coastguard Worker ) 564*da0073e9SAndroid Build Coastguard Worker 565*da0073e9SAndroid Build Coastguard Worker # verify too many indices fails 566*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(IndexError): 567*da0073e9SAndroid Build Coastguard Worker reference[ri([1]), ri([0, 2]), ri([3])] 568*da0073e9SAndroid Build Coastguard Worker 569*da0073e9SAndroid Build Coastguard Worker # test invalid index fails 570*da0073e9SAndroid Build Coastguard Worker reference = torch.empty(10, dtype=dtype, device=device) 571*da0073e9SAndroid Build Coastguard Worker # can't test cuda because it is a device assert 572*da0073e9SAndroid Build Coastguard Worker if not reference.is_cuda: 573*da0073e9SAndroid Build Coastguard Worker for err_idx in (10, -11): 574*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(IndexError, r"out of"): 575*da0073e9SAndroid Build Coastguard Worker reference[err_idx] 576*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(IndexError, r"out of"): 577*da0073e9SAndroid Build Coastguard Worker reference[torch.LongTensor([err_idx]).to(device)] 578*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(IndexError, r"out of"): 579*da0073e9SAndroid Build Coastguard Worker reference[[err_idx]] 580*da0073e9SAndroid Build Coastguard Worker 581*da0073e9SAndroid Build Coastguard Worker def tensor_indices_to_np(tensor, indices): 582*da0073e9SAndroid Build Coastguard Worker # convert the Torch Tensor to a numpy array 583*da0073e9SAndroid Build Coastguard Worker tensor = tensor.to(device="cpu") 584*da0073e9SAndroid Build 
Coastguard Worker npt = tensor.numpy() 585*da0073e9SAndroid Build Coastguard Worker 586*da0073e9SAndroid Build Coastguard Worker # convert indices 587*da0073e9SAndroid Build Coastguard Worker idxs = tuple( 588*da0073e9SAndroid Build Coastguard Worker i.tolist() if isinstance(i, torch.LongTensor) else i for i in indices 589*da0073e9SAndroid Build Coastguard Worker ) 590*da0073e9SAndroid Build Coastguard Worker 591*da0073e9SAndroid Build Coastguard Worker return npt, idxs 592*da0073e9SAndroid Build Coastguard Worker 593*da0073e9SAndroid Build Coastguard Worker def get_numpy(tensor, indices): 594*da0073e9SAndroid Build Coastguard Worker npt, idxs = tensor_indices_to_np(tensor, indices) 595*da0073e9SAndroid Build Coastguard Worker 596*da0073e9SAndroid Build Coastguard Worker # index and return as a Torch Tensor 597*da0073e9SAndroid Build Coastguard Worker return torch.tensor(npt[idxs], dtype=dtype, device=device) 598*da0073e9SAndroid Build Coastguard Worker 599*da0073e9SAndroid Build Coastguard Worker def set_numpy(tensor, indices, value): 600*da0073e9SAndroid Build Coastguard Worker if not isinstance(value, int): 601*da0073e9SAndroid Build Coastguard Worker if self.device_type != "cpu": 602*da0073e9SAndroid Build Coastguard Worker value = value.cpu() 603*da0073e9SAndroid Build Coastguard Worker value = value.numpy() 604*da0073e9SAndroid Build Coastguard Worker 605*da0073e9SAndroid Build Coastguard Worker npt, idxs = tensor_indices_to_np(tensor, indices) 606*da0073e9SAndroid Build Coastguard Worker npt[idxs] = value 607*da0073e9SAndroid Build Coastguard Worker return npt 608*da0073e9SAndroid Build Coastguard Worker 609*da0073e9SAndroid Build Coastguard Worker def assert_get_eq(tensor, indexer): 610*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor[indexer], get_numpy(tensor, indexer)) 611*da0073e9SAndroid Build Coastguard Worker 612*da0073e9SAndroid Build Coastguard Worker def assert_set_eq(tensor, indexer, val): 613*da0073e9SAndroid Build Coastguard 
Worker pyt = tensor.clone() 614*da0073e9SAndroid Build Coastguard Worker numt = tensor.clone() 615*da0073e9SAndroid Build Coastguard Worker pyt[indexer] = val 616*da0073e9SAndroid Build Coastguard Worker numt = torch.tensor( 617*da0073e9SAndroid Build Coastguard Worker set_numpy(numt, indexer, val), dtype=dtype, device=device 618*da0073e9SAndroid Build Coastguard Worker ) 619*da0073e9SAndroid Build Coastguard Worker self.assertEqual(pyt, numt) 620*da0073e9SAndroid Build Coastguard Worker 621*da0073e9SAndroid Build Coastguard Worker def assert_backward_eq(tensor, indexer): 622*da0073e9SAndroid Build Coastguard Worker cpu = tensor.float().clone().detach().requires_grad_(True) 623*da0073e9SAndroid Build Coastguard Worker outcpu = cpu[indexer] 624*da0073e9SAndroid Build Coastguard Worker gOcpu = torch.rand_like(outcpu) 625*da0073e9SAndroid Build Coastguard Worker outcpu.backward(gOcpu) 626*da0073e9SAndroid Build Coastguard Worker dev = cpu.to(device).detach().requires_grad_(True) 627*da0073e9SAndroid Build Coastguard Worker outdev = dev[indexer] 628*da0073e9SAndroid Build Coastguard Worker outdev.backward(gOcpu.to(device)) 629*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cpu.grad, dev.grad) 630*da0073e9SAndroid Build Coastguard Worker 631*da0073e9SAndroid Build Coastguard Worker def get_set_tensor(indexed, indexer): 632*da0073e9SAndroid Build Coastguard Worker set_size = indexed[indexer].size() 633*da0073e9SAndroid Build Coastguard Worker set_count = indexed[indexer].numel() 634*da0073e9SAndroid Build Coastguard Worker set_tensor = torch.randperm(set_count).view(set_size).double().to(device) 635*da0073e9SAndroid Build Coastguard Worker return set_tensor 636*da0073e9SAndroid Build Coastguard Worker 637*da0073e9SAndroid Build Coastguard Worker # Tensor is 0 1 2 3 4 638*da0073e9SAndroid Build Coastguard Worker # 5 6 7 8 9 639*da0073e9SAndroid Build Coastguard Worker # 10 11 12 13 14 640*da0073e9SAndroid Build Coastguard Worker # 15 16 17 18 19 
641*da0073e9SAndroid Build Coastguard Worker reference = torch.arange(0.0, 20, dtype=dtype, device=device).view(4, 5) 642*da0073e9SAndroid Build Coastguard Worker 643*da0073e9SAndroid Build Coastguard Worker indices_to_test = [ 644*da0073e9SAndroid Build Coastguard Worker # grab the second, fourth columns 645*da0073e9SAndroid Build Coastguard Worker [slice(None), [1, 3]], 646*da0073e9SAndroid Build Coastguard Worker # first, third rows, 647*da0073e9SAndroid Build Coastguard Worker [[0, 2], slice(None)], 648*da0073e9SAndroid Build Coastguard Worker # weird shape 649*da0073e9SAndroid Build Coastguard Worker [slice(None), [[0, 1], [2, 3]]], 650*da0073e9SAndroid Build Coastguard Worker # negatives 651*da0073e9SAndroid Build Coastguard Worker [[-1], [0]], 652*da0073e9SAndroid Build Coastguard Worker [[0, 2], [-1]], 653*da0073e9SAndroid Build Coastguard Worker [slice(None), [-1]], 654*da0073e9SAndroid Build Coastguard Worker ] 655*da0073e9SAndroid Build Coastguard Worker 656*da0073e9SAndroid Build Coastguard Worker # only test dupes on gets 657*da0073e9SAndroid Build Coastguard Worker get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]] 658*da0073e9SAndroid Build Coastguard Worker 659*da0073e9SAndroid Build Coastguard Worker for indexer in get_indices_to_test: 660*da0073e9SAndroid Build Coastguard Worker assert_get_eq(reference, indexer) 661*da0073e9SAndroid Build Coastguard Worker if self.device_type != "cpu": 662*da0073e9SAndroid Build Coastguard Worker assert_backward_eq(reference, indexer) 663*da0073e9SAndroid Build Coastguard Worker 664*da0073e9SAndroid Build Coastguard Worker for indexer in indices_to_test: 665*da0073e9SAndroid Build Coastguard Worker assert_set_eq(reference, indexer, 44) 666*da0073e9SAndroid Build Coastguard Worker assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) 667*da0073e9SAndroid Build Coastguard Worker 668*da0073e9SAndroid Build Coastguard Worker reference = torch.arange(0.0, 160, dtype=dtype, 
device=device).view(4, 8, 5) 669*da0073e9SAndroid Build Coastguard Worker 670*da0073e9SAndroid Build Coastguard Worker indices_to_test = [ 671*da0073e9SAndroid Build Coastguard Worker [slice(None), slice(None), [0, 3, 4]], 672*da0073e9SAndroid Build Coastguard Worker [slice(None), [2, 4, 5, 7], slice(None)], 673*da0073e9SAndroid Build Coastguard Worker [[2, 3], slice(None), slice(None)], 674*da0073e9SAndroid Build Coastguard Worker [slice(None), [0, 2, 3], [1, 3, 4]], 675*da0073e9SAndroid Build Coastguard Worker [slice(None), [0], [1, 2, 4]], 676*da0073e9SAndroid Build Coastguard Worker [slice(None), [0, 1, 3], [4]], 677*da0073e9SAndroid Build Coastguard Worker [slice(None), [[0, 1], [1, 0]], [[2, 3]]], 678*da0073e9SAndroid Build Coastguard Worker [slice(None), [[0, 1], [2, 3]], [[0]]], 679*da0073e9SAndroid Build Coastguard Worker [slice(None), [[5, 6]], [[0, 3], [4, 4]]], 680*da0073e9SAndroid Build Coastguard Worker [[0, 2, 3], [1, 3, 4], slice(None)], 681*da0073e9SAndroid Build Coastguard Worker [[0], [1, 2, 4], slice(None)], 682*da0073e9SAndroid Build Coastguard Worker [[0, 1, 3], [4], slice(None)], 683*da0073e9SAndroid Build Coastguard Worker [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], 684*da0073e9SAndroid Build Coastguard Worker [[[0, 1], [1, 0]], [[2, 3]], slice(None)], 685*da0073e9SAndroid Build Coastguard Worker [[[0, 1], [2, 3]], [[0]], slice(None)], 686*da0073e9SAndroid Build Coastguard Worker [[[2, 1]], [[0, 3], [4, 4]], slice(None)], 687*da0073e9SAndroid Build Coastguard Worker [[[2]], [[0, 3], [4, 1]], slice(None)], 688*da0073e9SAndroid Build Coastguard Worker # non-contiguous indexing subspace 689*da0073e9SAndroid Build Coastguard Worker [[0, 2, 3], slice(None), [1, 3, 4]], 690*da0073e9SAndroid Build Coastguard Worker # [...] 
691*da0073e9SAndroid Build Coastguard Worker # less dim, ellipsis 692*da0073e9SAndroid Build Coastguard Worker [[0, 2]], 693*da0073e9SAndroid Build Coastguard Worker [[0, 2], slice(None)], 694*da0073e9SAndroid Build Coastguard Worker [[0, 2], Ellipsis], 695*da0073e9SAndroid Build Coastguard Worker [[0, 2], slice(None), Ellipsis], 696*da0073e9SAndroid Build Coastguard Worker [[0, 2], Ellipsis, slice(None)], 697*da0073e9SAndroid Build Coastguard Worker [[0, 2], [1, 3]], 698*da0073e9SAndroid Build Coastguard Worker [[0, 2], [1, 3], Ellipsis], 699*da0073e9SAndroid Build Coastguard Worker [Ellipsis, [1, 3], [2, 3]], 700*da0073e9SAndroid Build Coastguard Worker [Ellipsis, [2, 3, 4]], 701*da0073e9SAndroid Build Coastguard Worker [Ellipsis, slice(None), [2, 3, 4]], 702*da0073e9SAndroid Build Coastguard Worker [slice(None), Ellipsis, [2, 3, 4]], 703*da0073e9SAndroid Build Coastguard Worker # ellipsis counts for nothing 704*da0073e9SAndroid Build Coastguard Worker [Ellipsis, slice(None), slice(None), [0, 3, 4]], 705*da0073e9SAndroid Build Coastguard Worker [slice(None), Ellipsis, slice(None), [0, 3, 4]], 706*da0073e9SAndroid Build Coastguard Worker [slice(None), slice(None), Ellipsis, [0, 3, 4]], 707*da0073e9SAndroid Build Coastguard Worker [slice(None), slice(None), [0, 3, 4], Ellipsis], 708*da0073e9SAndroid Build Coastguard Worker [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], 709*da0073e9SAndroid Build Coastguard Worker [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)], 710*da0073e9SAndroid Build Coastguard Worker [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis], 711*da0073e9SAndroid Build Coastguard Worker ] 712*da0073e9SAndroid Build Coastguard Worker 713*da0073e9SAndroid Build Coastguard Worker for indexer in indices_to_test: 714*da0073e9SAndroid Build Coastguard Worker assert_get_eq(reference, indexer) 715*da0073e9SAndroid Build Coastguard Worker assert_set_eq(reference, indexer, 212) 716*da0073e9SAndroid Build Coastguard Worker 
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) 717*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 718*da0073e9SAndroid Build Coastguard Worker assert_backward_eq(reference, indexer) 719*da0073e9SAndroid Build Coastguard Worker 720*da0073e9SAndroid Build Coastguard Worker reference = torch.arange(0.0, 1296, dtype=dtype, device=device).view(3, 9, 8, 6) 721*da0073e9SAndroid Build Coastguard Worker 722*da0073e9SAndroid Build Coastguard Worker indices_to_test = [ 723*da0073e9SAndroid Build Coastguard Worker [slice(None), slice(None), slice(None), [0, 3, 4]], 724*da0073e9SAndroid Build Coastguard Worker [slice(None), slice(None), [2, 4, 5, 7], slice(None)], 725*da0073e9SAndroid Build Coastguard Worker [slice(None), [2, 3], slice(None), slice(None)], 726*da0073e9SAndroid Build Coastguard Worker [[1, 2], slice(None), slice(None), slice(None)], 727*da0073e9SAndroid Build Coastguard Worker [slice(None), slice(None), [0, 2, 3], [1, 3, 4]], 728*da0073e9SAndroid Build Coastguard Worker [slice(None), slice(None), [0], [1, 2, 4]], 729*da0073e9SAndroid Build Coastguard Worker [slice(None), slice(None), [0, 1, 3], [4]], 730*da0073e9SAndroid Build Coastguard Worker [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]], 731*da0073e9SAndroid Build Coastguard Worker [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]], 732*da0073e9SAndroid Build Coastguard Worker [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]], 733*da0073e9SAndroid Build Coastguard Worker [slice(None), [0, 2, 3], [1, 3, 4], slice(None)], 734*da0073e9SAndroid Build Coastguard Worker [slice(None), [0], [1, 2, 4], slice(None)], 735*da0073e9SAndroid Build Coastguard Worker [slice(None), [0, 1, 3], [4], slice(None)], 736*da0073e9SAndroid Build Coastguard Worker [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)], 737*da0073e9SAndroid Build Coastguard Worker [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)], 738*da0073e9SAndroid Build Coastguard Worker 
[slice(None), [[0, 1], [3, 2]], [[0]], slice(None)], 739*da0073e9SAndroid Build Coastguard Worker [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)], 740*da0073e9SAndroid Build Coastguard Worker [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)], 741*da0073e9SAndroid Build Coastguard Worker [[0, 1, 2], [1, 3, 4], slice(None), slice(None)], 742*da0073e9SAndroid Build Coastguard Worker [[0], [1, 2, 4], slice(None), slice(None)], 743*da0073e9SAndroid Build Coastguard Worker [[0, 1, 2], [4], slice(None), slice(None)], 744*da0073e9SAndroid Build Coastguard Worker [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)], 745*da0073e9SAndroid Build Coastguard Worker [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)], 746*da0073e9SAndroid Build Coastguard Worker [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)], 747*da0073e9SAndroid Build Coastguard Worker [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)], 748*da0073e9SAndroid Build Coastguard Worker [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]], 749*da0073e9SAndroid Build Coastguard Worker [slice(None), [2, 3, 4], [1, 3, 4], [4]], 750*da0073e9SAndroid Build Coastguard Worker [slice(None), [0, 1, 3], [4], [1, 3, 4]], 751*da0073e9SAndroid Build Coastguard Worker [slice(None), [6], [0, 2, 3], [1, 3, 4]], 752*da0073e9SAndroid Build Coastguard Worker [slice(None), [2, 3, 5], [3], [4]], 753*da0073e9SAndroid Build Coastguard Worker [slice(None), [0], [4], [1, 3, 4]], 754*da0073e9SAndroid Build Coastguard Worker [slice(None), [6], [0, 2, 3], [1]], 755*da0073e9SAndroid Build Coastguard Worker [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]], 756*da0073e9SAndroid Build Coastguard Worker [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)], 757*da0073e9SAndroid Build Coastguard Worker [[2, 0, 1], [1, 2, 3], [4], slice(None)], 758*da0073e9SAndroid Build Coastguard Worker [[0, 1, 2], [4], [1, 3, 4], slice(None)], 759*da0073e9SAndroid Build Coastguard Worker [[0], [0, 2, 3], [1, 3, 4], slice(None)], 
760*da0073e9SAndroid Build Coastguard Worker [[0, 2, 1], [3], [4], slice(None)], 761*da0073e9SAndroid Build Coastguard Worker [[0], [4], [1, 3, 4], slice(None)], 762*da0073e9SAndroid Build Coastguard Worker [[1], [0, 2, 3], [1], slice(None)], 763*da0073e9SAndroid Build Coastguard Worker [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)], 764*da0073e9SAndroid Build Coastguard Worker # less dim, ellipsis 765*da0073e9SAndroid Build Coastguard Worker [Ellipsis, [0, 3, 4]], 766*da0073e9SAndroid Build Coastguard Worker [Ellipsis, slice(None), [0, 3, 4]], 767*da0073e9SAndroid Build Coastguard Worker [Ellipsis, slice(None), slice(None), [0, 3, 4]], 768*da0073e9SAndroid Build Coastguard Worker [slice(None), Ellipsis, [0, 3, 4]], 769*da0073e9SAndroid Build Coastguard Worker [slice(None), slice(None), Ellipsis, [0, 3, 4]], 770*da0073e9SAndroid Build Coastguard Worker [slice(None), [0, 2, 3], [1, 3, 4]], 771*da0073e9SAndroid Build Coastguard Worker [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis], 772*da0073e9SAndroid Build Coastguard Worker [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)], 773*da0073e9SAndroid Build Coastguard Worker [[0], [1, 2, 4]], 774*da0073e9SAndroid Build Coastguard Worker [[0], [1, 2, 4], slice(None)], 775*da0073e9SAndroid Build Coastguard Worker [[0], [1, 2, 4], Ellipsis], 776*da0073e9SAndroid Build Coastguard Worker [[0], [1, 2, 4], Ellipsis, slice(None)], 777*da0073e9SAndroid Build Coastguard Worker [[1]], 778*da0073e9SAndroid Build Coastguard Worker [[0, 2, 1], [3], [4]], 779*da0073e9SAndroid Build Coastguard Worker [[0, 2, 1], [3], [4], slice(None)], 780*da0073e9SAndroid Build Coastguard Worker [[0, 2, 1], [3], [4], Ellipsis], 781*da0073e9SAndroid Build Coastguard Worker [Ellipsis, [0, 2, 1], [3], [4]], 782*da0073e9SAndroid Build Coastguard Worker ] 783*da0073e9SAndroid Build Coastguard Worker 784*da0073e9SAndroid Build Coastguard Worker for indexer in indices_to_test: 785*da0073e9SAndroid Build Coastguard Worker 
assert_get_eq(reference, indexer) 786*da0073e9SAndroid Build Coastguard Worker assert_set_eq(reference, indexer, 1333) 787*da0073e9SAndroid Build Coastguard Worker assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) 788*da0073e9SAndroid Build Coastguard Worker indices_to_test += [ 789*da0073e9SAndroid Build Coastguard Worker [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]], 790*da0073e9SAndroid Build Coastguard Worker [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]], 791*da0073e9SAndroid Build Coastguard Worker ] 792*da0073e9SAndroid Build Coastguard Worker for indexer in indices_to_test: 793*da0073e9SAndroid Build Coastguard Worker assert_get_eq(reference, indexer) 794*da0073e9SAndroid Build Coastguard Worker assert_set_eq(reference, indexer, 1333) 795*da0073e9SAndroid Build Coastguard Worker if self.device_type != "cpu": 796*da0073e9SAndroid Build Coastguard Worker assert_backward_eq(reference, indexer) 797*da0073e9SAndroid Build Coastguard Worker 798*da0073e9SAndroid Build Coastguard Worker def test_advancedindex_big(self, device): 799*da0073e9SAndroid Build Coastguard Worker reference = torch.arange(0, 123344, dtype=torch.int, device=device) 800*da0073e9SAndroid Build Coastguard Worker 801*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 802*da0073e9SAndroid Build Coastguard Worker reference[[0, 123, 44488, 68807, 123343],], 803*da0073e9SAndroid Build Coastguard Worker torch.tensor([0, 123, 44488, 68807, 123343], dtype=torch.int), 804*da0073e9SAndroid Build Coastguard Worker ) 805*da0073e9SAndroid Build Coastguard Worker 806*da0073e9SAndroid Build Coastguard Worker def test_set_item_to_scalar_tensor(self, device): 807*da0073e9SAndroid Build Coastguard Worker m = random.randint(1, 10) 808*da0073e9SAndroid Build Coastguard Worker n = random.randint(1, 10) 809*da0073e9SAndroid Build Coastguard Worker z = torch.randn([m, n], device=device) 810*da0073e9SAndroid Build Coastguard Worker a = 1.0 811*da0073e9SAndroid Build 
Coastguard Worker w = torch.tensor(a, requires_grad=True, device=device) 812*da0073e9SAndroid Build Coastguard Worker z[:, 0] = w 813*da0073e9SAndroid Build Coastguard Worker z.sum().backward() 814*da0073e9SAndroid Build Coastguard Worker self.assertEqual(w.grad, m * a) 815*da0073e9SAndroid Build Coastguard Worker 816*da0073e9SAndroid Build Coastguard Worker def test_single_int(self, device): 817*da0073e9SAndroid Build Coastguard Worker v = torch.randn(5, 7, 3, device=device) 818*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[4].shape, (7, 3)) 819*da0073e9SAndroid Build Coastguard Worker 820*da0073e9SAndroid Build Coastguard Worker def test_multiple_int(self, device): 821*da0073e9SAndroid Build Coastguard Worker v = torch.randn(5, 7, 3, device=device) 822*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[4].shape, (7, 3)) 823*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[4, :, 1].shape, (7,)) 824*da0073e9SAndroid Build Coastguard Worker 825*da0073e9SAndroid Build Coastguard Worker def test_none(self, device): 826*da0073e9SAndroid Build Coastguard Worker v = torch.randn(5, 7, 3, device=device) 827*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[None].shape, (1, 5, 7, 3)) 828*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[:, None].shape, (5, 1, 7, 3)) 829*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[:, None, None].shape, (5, 1, 1, 7, 3)) 830*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[..., None].shape, (5, 7, 3, 1)) 831*da0073e9SAndroid Build Coastguard Worker 832*da0073e9SAndroid Build Coastguard Worker def test_step(self, device): 833*da0073e9SAndroid Build Coastguard Worker v = torch.arange(10, device=device) 834*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[::1], v) 835*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[::2].tolist(), [0, 2, 4, 6, 8]) 836*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[::3].tolist(), [0, 3, 6, 9]) 
837*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[::11].tolist(), [0]) 838*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[1:6:2].tolist(), [1, 3, 5]) 839*da0073e9SAndroid Build Coastguard Worker 840*da0073e9SAndroid Build Coastguard Worker def test_step_assignment(self, device): 841*da0073e9SAndroid Build Coastguard Worker v = torch.zeros(4, 4, device=device) 842*da0073e9SAndroid Build Coastguard Worker v[0, 1::2] = torch.tensor([3.0, 4.0], device=device) 843*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[0].tolist(), [0, 3, 0, 4]) 844*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[1:].sum(), 0) 845*da0073e9SAndroid Build Coastguard Worker 846*da0073e9SAndroid Build Coastguard Worker def test_bool_indices(self, device): 847*da0073e9SAndroid Build Coastguard Worker v = torch.randn(5, 7, 3, device=device) 848*da0073e9SAndroid Build Coastguard Worker boolIndices = torch.tensor( 849*da0073e9SAndroid Build Coastguard Worker [True, False, True, True, False], dtype=torch.bool, device=device 850*da0073e9SAndroid Build Coastguard Worker ) 851*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[boolIndices].shape, (3, 7, 3)) 852*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[boolIndices], torch.stack([v[0], v[2], v[3]])) 853*da0073e9SAndroid Build Coastguard Worker 854*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([True, False, True], dtype=torch.bool, device=device) 855*da0073e9SAndroid Build Coastguard Worker boolIndices = torch.tensor( 856*da0073e9SAndroid Build Coastguard Worker [True, False, False], dtype=torch.bool, device=device 857*da0073e9SAndroid Build Coastguard Worker ) 858*da0073e9SAndroid Build Coastguard Worker uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8, device=device) 859*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 860*da0073e9SAndroid Build Coastguard Worker v1 = v[boolIndices] 861*da0073e9SAndroid Build Coastguard 
# NOTE(review): the statements down to the first blank line are the tail of a
# test whose `def` line precedes this chunk. The collapsed source carries no
# indentation, so they are reproduced flush left and otherwise unchanged;
# `v`, `v1`, `boolIndices`, `uint8Indices` and `w` are presumably bound
# earlier in that test -- confirm when re-indenting.
v2 = v[uint8Indices]
self.assertEqual(v1.shape, v2.shape)
self.assertEqual(v1, v2)
self.assertEqual(
    v[boolIndices], tensor([True], dtype=torch.bool, device=device)
)
self.assertEqual(len(w), 1)


def test_bool_indices_accumulate(self, device):
    # An all-False bool mask selects nothing, so an accumulating index_put_
    # must leave the all-ones tensor untouched.
    mask = torch.zeros(size=(10,), dtype=torch.bool, device=device)
    y = torch.ones(size=(10, 10), device=device)
    y.index_put_((mask,), y[mask], accumulate=True)
    self.assertEqual(y, torch.ones(size=(10, 10), device=device))


def test_multiple_bool_indices(self, device):
    # Two bool masks separated by a full slice: only the result shape is checked.
    v = torch.randn(5, 7, 3, device=device)
    # note: these broadcast together and are transposed to the first dim
    mask1 = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool, device=device)
    mask2 = torch.tensor([1, 1, 1], dtype=torch.bool, device=device)
    self.assertEqual(v[mask1, :, mask2].shape, (3, 7))


def test_byte_mask(self, device):
    # Indexing with a uint8 mask selects rows 0, 2 and 3 and is expected to
    # record exactly one warning.
    v = torch.randn(5, 7, 3, device=device)
    mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
    with warnings.catch_warnings(record=True) as w:
        # Record every warning regardless of the ambient filters, as the
        # other byte-mask tests in this file do; without this the count
        # below depends on filter/registry state left by earlier code.
        warnings.simplefilter("always")
        res = v[mask]
        self.assertEqual(res.shape, (3, 7, 3))
        self.assertEqual(res, torch.stack([v[0], v[2], v[3]]))
        self.assertEqual(len(w), 1)

    # A comparison mask that matches nothing yields an empty tensor.
    v = torch.tensor([1.0], device=device)
    self.assertEqual(v[v == 0], torch.tensor([], device=device))


def test_byte_mask_accumulate(self, device):
    # Same as the bool-mask accumulate test but with an all-zero uint8 mask;
    # the read (y[mask]) and the write (index_put_) are expected to record
    # two warnings in total.
    mask = torch.zeros(size=(10,), dtype=torch.uint8, device=device)
    y = torch.ones(size=(10, 10), device=device)
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        y.index_put_((mask,), y[mask], accumulate=True)
        self.assertEqual(y, torch.ones(size=(10, 10), device=device))
        self.assertEqual(len(w), 2)


@skipIfTorchDynamo(
    "This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472"
)
@serialTest(TEST_CUDA)
def test_index_put_accumulate_large_tensor(self, device):
    # This test is for tensors with number of elements >= INT_MAX (2^31 - 1).
    N = (1 << 31) + 5
    dt = torch.int8
    a = torch.ones(N, dtype=dt, device=device)
    indices = torch.tensor(
        [-2, 0, -2, -1, 0, -1, 1], device=device, dtype=torch.long
    )
    values = torch.tensor([6, 5, 6, 6, 5, 7, 11], dtype=dt, device=device)

    a.index_put_((indices,), values, accumulate=True)

    # Each expected value is the initial 1 plus every value routed to that
    # slot (duplicates accumulate); untouched slots stay 1.
    self.assertEqual(a[0], 11)
    self.assertEqual(a[1], 12)
    self.assertEqual(a[2], 1)
    self.assertEqual(a[-3], 1)
    self.assertEqual(a[-2], 13)
    self.assertEqual(a[-1], 14)

    # Two-dimensional variant: a pair of index tensors addresses (row, col).
    a = torch.ones((2, N), dtype=dt, device=device)
    indices0 = torch.tensor([0, -1, 0, 1], device=device, dtype=torch.long)
    indices1 = torch.tensor([-2, -1, 0, 1], device=device, dtype=torch.long)
    values = torch.tensor([12, 13, 10, 11], dtype=dt, device=device)

    a.index_put_((indices0, indices1), values, accumulate=True)

    self.assertEqual(a[0, 0], 11)
    self.assertEqual(a[0, 1], 1)
    self.assertEqual(a[1, 0], 1)
    self.assertEqual(a[1, 1], 12)
    # Build the expected columns with the dtype/device under test rather
    # than a hard-coded CPU int8 tensor.
    self.assertEqual(a[:, 2], torch.ones(2, dtype=dt, device=device))
    self.assertEqual(a[:, -3], torch.ones(2, dtype=dt, device=device))
    self.assertEqual(a[0, -2], 13)
    self.assertEqual(a[1, -2], 1)
    self.assertEqual(a[-1, -1], 14)
    self.assertEqual(a[0, -1], 1)


@onlyNativeDeviceTypes
def test_index_put_accumulate_expanded_values(self, device):
    # checks the issue with cuda: https://github.com/pytorch/pytorch/issues/39227
    # and verifies consistency with CPU result
    t = torch.zeros((5, 2))
    t_dev = t.to(device)
    indices = [torch.tensor([0, 1, 2, 3]), torch.tensor([1])]
    indices_dev = [i.to(device) for i in indices]
    values0d = torch.tensor(1.0)
    values1d = torch.tensor([1.0])

    # 0-d and 1-element values must both broadcast against the indices.
    out_cuda = t_dev.index_put_(indices_dev, values0d.to(device), accumulate=True)
    out_cpu = t.index_put_(indices, values0d, accumulate=True)
    self.assertEqual(out_cuda.cpu(), out_cpu)

    out_cuda = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
    out_cpu = t.index_put_(indices, values1d, accumulate=True)
    self.assertEqual(out_cuda.cpu(), out_cpu)

    t = torch.zeros(4, 3, 2)
    t_dev = t.to(device)

    # Index tensors of shapes (1,), (3, 1) and (1, 2).
    indices = [
        torch.tensor([0]),
        torch.arange(3)[:, None],
        torch.arange(2)[None, :],
    ]
    indices_dev = [i.to(device) for i in indices]
    values1d = torch.tensor([-1.0, -2.0])
    values2d = torch.tensor([[-1.0, -2.0]])

    out_cuda = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
    out_cpu = t.index_put_(indices, values1d, accumulate=True)
    self.assertEqual(out_cuda.cpu(), out_cpu)

    out_cuda = t_dev.index_put_(indices_dev, values2d.to(device), accumulate=True)
    out_cpu = t.index_put_(indices, values2d, accumulate=True)
    self.assertEqual(out_cuda.cpu(), out_cpu)


@onlyCUDA
def test_index_put_accumulate_non_contiguous(self, device):
    # Accumulating into a non-contiguous view must match the CPU result and
    # must leave the view non-contiguous.
    t = torch.zeros((5, 2, 2))
    t_dev = t.to(device)
    t1 = t_dev[:, 0, :]
    t2 = t[:, 0, :]
    self.assertTrue(not t1.is_contiguous())
    self.assertTrue(not t2.is_contiguous())

    indices = [torch.tensor([0, 1])]
    indices_dev = [i.to(device) for i in indices]
    value = torch.randn(2, 2)
    out_cuda = t1.index_put_(indices_dev, value.to(device), accumulate=True)
    out_cpu = t2.index_put_(indices, value, accumulate=True)
    self.assertTrue(not t1.is_contiguous())
    self.assertTrue(not t2.is_contiguous())

    self.assertEqual(out_cuda.cpu(), out_cpu)


@onlyCUDA
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_index_put_accumulate_with_optional_tensors(self, device):
    # TODO: replace with a better solution.
    # Currently, here using torchscript to put None into indices.
    # on C++ it gives indices as a list of 2 optional tensors: first is null and
    # the second is a valid tensor.
    @torch.jit.script
    def func(x, i, v):
        idx = [None, i]
        x.index_put_(idx, v, accumulate=True)
        return x

    n = 4
    t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
    t_dev = t.to(device)
    indices = torch.tensor([1, 0])
    indices_dev = indices.to(device)
    value0d = torch.tensor(10.0)
    value1d = torch.tensor([1.0, 2.0])

    # Move the values to the device under test (not the default CUDA device)
    # so they live alongside t_dev and indices_dev.
    out_cuda = func(t_dev, indices_dev, value0d.to(device))
    out_cpu = func(t, indices, value0d)
    self.assertEqual(out_cuda.cpu(), out_cpu)

    out_cuda = func(t_dev, indices_dev, value1d.to(device))
    out_cpu = func(t, indices, value1d)
    self.assertEqual(out_cuda.cpu(), out_cpu)


@onlyNativeDeviceTypes
def test_index_put_accumulate_duplicate_indices(self, device):
    # Compare accumulate=True against a plain Python accumulation for index
    # vectors of every length from 1 to 511.
    for length in range(1, 512):
        # generate indices by random walk, this will create indices with
        # lots of duplicates interleaved with each other
        delta = torch.empty(length, dtype=torch.double, device=device).uniform_(-1, 1)
        indices = delta.cumsum(0).long()

        # Size the target by the largest absolute index so that negative
        # indices stay in range as well.
        target = torch.randn(indices.abs().max() + 1, device=device)
        values = torch.randn(indices.size(0), device=device)
        output = target.index_put((indices,), values, accumulate=True)

        # Reference result: add each value at its index, one at a time.
        expected = target.tolist()
        for idx, val in zip(indices.tolist(), values.tolist()):
            expected[idx] += val

        self.assertEqual(output, expected)


@onlyNativeDeviceTypes
def test_index_ind_dtype(self, device):
    # int32 index tensors must behave exactly like int64 ones, both for
    # reads and for index_put_.
    x = torch.randn(4, 4, device=device)
    ind_long = torch.randint(4, (4,), dtype=torch.long, device=device)
    ind_int = ind_long.int()
    src = torch.randn(4, device=device)
    ref = x[ind_long, ind_long]
    res = x[ind_int, ind_int]
    self.assertEqual(ref, res)
    ref = x[ind_long, :]
    res = x[ind_int, :]
    self.assertEqual(ref, res)
    ref = x[:, ind_long]
    res = x[:, ind_int]
    self.assertEqual(ref, res)
    # no repeating indices for index_put
    ind_long = torch.arange(4, dtype=torch.long, device=device)
    ind_int = ind_long.int()
    for accum in (True, False):
        inp_ref = x.clone()
        inp_res = x.clone()
        torch.index_put_(inp_ref, (ind_long, ind_long), src, accum)
        torch.index_put_(inp_res, (ind_int, ind_int), src, accum)
        self.assertEqual(inp_ref, inp_res)


@skipXLA
def test_index_put_accumulate_empty(self, device):
    # Regression test for https://github.com/pytorch/pytorch/issues/94667
    # index_put on a 0-d tensor with an empty index list must raise.
    input = torch.rand([], dtype=torch.float32, device=device)
    with self.assertRaises(RuntimeError):
        input.index_put([], torch.tensor([1.0], device=device), True)

def test_multiple_byte_mask(self, device):
    # Indexing with two uint8 masks around a full slice must give a (3, 7)
    # result and record exactly two warnings.
    data = torch.randn(5, 7, 3, device=device)
    # note: these broadcast together and are transposed to the first dim
    first_mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
    second_mask = torch.ByteTensor([1, 1, 1]).to(device)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        selected = data[first_mask, :, second_mask]
        self.assertEqual(selected.shape, (3, 7))
        self.assertEqual(len(caught), 2)


def test_byte_mask2d(self, device):
    # A 2-d comparison mask keeps one trailing row of length 3 per True entry.
    data = torch.randn(5, 7, 3, device=device)
    selector = torch.randn(5, 7, device=device)
    positive = selector > 0
    kept = positive.sum()
    rows = data[positive]
    self.assertEqual(rows.shape, (kept, 3))


@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_jit_indexing(self, device):
    # Mask assignment and slice assignment, both compiled with TorchScript,
    # must each overwrite the first fifty entries of 0..99 with 1.0.
    def assign_by_mask(t):
        t[t < 50] = 1.0
        return t

    def assign_by_slice(t):
        t[0:50] = 1.0
        return t

    scripted_mask = torch.jit.script(assign_by_mask)
    scripted_slice = torch.jit.script(assign_by_slice)
    source = torch.arange(100, device=device, dtype=torch.float)
    expected_values = np.concatenate((np.ones(50), np.arange(50, 100)))
    expected = torch.tensor(expected_values, device=device, dtype=torch.float)

    # Each scripted function gets its own copy because it writes in place.
    for scripted in (scripted_mask, scripted_slice):
        self.assertEqual(scripted(source.detach().clone()), expected)


def test_int_indices(self, device):
    # Integer-list indices: result shapes for a leading list, a list after a
    # slice, and a nested (2-d) list after a slice.
    data = torch.randn(5, 7, 3, device=device)
    by_first_dim = data[[0, 4, 2]]
    self.assertEqual(by_first_dim.shape, (3, 7, 3))
    by_second_dim = data[:, [0, 4, 2]]
    self.assertEqual(by_second_dim.shape, (5, 3, 3))
    by_nested_list = data[:, [[0, 1], [4, 3]]]
    self.assertEqual(by_nested_list.shape, (5, 2, 2, 3))


@dtypes(
    torch.cfloat, torch.cdouble, torch.float, torch.bfloat16, torch.long, torch.bool
)
@dtypesIfCPU(
    torch.cfloat, torch.cdouble, torch.float, torch.long, torch.bool, torch.bfloat16
)
@dtypesIfCUDA(
    torch.cfloat,
    torch.cdouble,
    torch.half,
    torch.long,
    torch.bool,
    torch.bfloat16,
    torch.float8_e5m2,
    torch.float8_e4m3fn,
)
def test_index_put_src_datatype(self, device, dtype):
    # Accumulating index_put_ must run for each listed dtype and keep the
    # shape of the tensor it writes into.
    shape = (3, 2, 4)
    target = torch.ones(*shape, device=device, dtype=dtype)
    updates = torch.ones(*shape, device=device, dtype=dtype)
    index = (torch.tensor([0, 2, 1]),)
    result = target.index_put_(index, updates, accumulate=True)
    self.assertEqual(result.shape, target.shape)


@dtypes(torch.float, torch.bfloat16, torch.long, torch.bool)
@dtypesIfCPU(torch.float, torch.long, torch.bfloat16, torch.bool)
@dtypesIfCUDA(torch.half, torch.long, torch.bfloat16, torch.bool)
def test_index_src_datatype(self, device, dtype):
    # Advanced-index read followed by the matching write (no accumulation)
    # for each listed dtype; only shapes are compared.
    base = torch.ones(3, 2, 4, device=device, dtype=dtype)
    # test index
    gathered = base[[0, 2, 1], :, :]
    self.assertEqual(gathered.shape, base.shape)
    # test index_put, no accum
    base[[0, 2, 1], :, :] = gathered
    self.assertEqual(gathered.shape, base.shape)


def test_int_indices2d(self, device):
    # From the NumPy indexing example
    # Picks the four corner elements of a 4x3 grid with paired 2-d indices.
    grid = torch.arange(0, 12, device=device).view(4, 3)
    row_idx = torch.tensor([[0, 0], [3, 3]], device=device)
    col_idx = torch.tensor([[0, 2], [0, 2]], device=device)
    corners = grid[row_idx, col_idx]
    self.assertEqual(corners.tolist(), [[0, 2], [9, 11]])


def test_int_indices_broadcast(self, device):
    # From the NumPy indexing example
    # Same corners as the 2-d case, obtained by broadcasting a column of
    # row indices against a vector of column indices.
    grid = torch.arange(0, 12, device=device).view(4, 3)
    row_idx = torch.tensor([0, 3], device=device)
    col_idx = torch.tensor([0, 2], device=device)
    corners = grid[row_idx[:, None], col_idx]
    self.assertEqual(corners.tolist(), [[0, 2], [9, 11]])


def test_empty_index(self, device):
    # An empty long index selects nothing, and assigning through an empty
    # index or an all-False mask changes nothing.
    original = torch.arange(0, 12, device=device).view(4, 3)
    no_rows = torch.tensor([], dtype=torch.long, device=device)
    self.assertEqual(original[no_rows].numel(), 0)

    # empty assignment should have no effect but not throw an exception
    copy = original.clone()
    copy[no_rows] = -1
    self.assertEqual(original, copy)

    nothing_selected = torch.zeros(4, 3, device=device).bool()
    copy[nothing_selected] = -1
    self.assertEqual(original, copy)


def test_empty_ndim_index(self, device):
    # Multi-dimensional index tensors with a zero-length dimension.
    vec = torch.randn(5, device=device)
    empty_2d = torch.empty(0, 2, dtype=torch.int64, device=device)
    self.assertEqual(torch.empty(0, 2, device=device), vec[empty_2d])

    cube = torch.randn(2, 3, 4, 5, device=device)
    empty_0x6 = torch.empty(0, 6, dtype=torch.int64, device=device)
    self.assertEqual(torch.empty(2, 0, 6, 4, 5, device=device), cube[:, empty_0x6])

    # Indexing a tensor that itself has a zero-length dimension.
    hollow = torch.empty(10, 0, device=device)
    self.assertEqual(hollow[[1, 2]].shape, (2, 0))
    self.assertEqual(hollow[[], []].shape, (0,))
    with self.assertRaisesRegex(IndexError, "for dimension with size 0"):
        hollow[:, [0, 1]]


def test_empty_ndim_index_bool(self, device):
    # A (0, 2) uint8 index applied to a 1-d tensor must raise IndexError.
    vec = torch.randn(5, device=device)
    with self.assertRaises(IndexError):
        vec[torch.empty(0, 2, dtype=torch.uint8, device=device)]


def test_empty_slice(self, device):
    # An empty slice of a strided view keeps the expected shape and strides
    # and is reported as contiguous.
    full = torch.randn(2, 3, 4, 5, device=device)
    strided = full[:, :, :, 1]
    emptied = strided[:, 1:1, :]
    self.assertEqual((2, 0, 4), emptied.shape)
    # this isn't technically necessary, but matches NumPy stride calculations.
    self.assertEqual((60, 20, 5), emptied.stride())
    self.assertTrue(emptied.is_contiguous())


def test_index_getitem_copy_bools_slices(self, device):
    # Boolean-like indices (Python bools and 0-d uint8 tensors) return
    # storage with a different data pointer (or an empty result), while
    # None and Ellipsis keep the original data pointer.
    uint8_one = torch.tensor(1, dtype=torch.uint8, device=device)
    uint8_zero = torch.tensor(0, dtype=torch.uint8, device=device)
    flag_pairs = ((True, False), (uint8_one, uint8_zero))

    candidates = [torch.randn(2, 3, device=device), torch.tensor(3.0, device=device)]

    for a in candidates:
        for select_all, select_none in flag_pairs:
            self.assertNotEqual(a.data_ptr(), a[select_all].data_ptr())
            self.assertEqual(torch.empty(0, *a.shape), a[select_none])
        for view_index in (None, ...):
            self.assertEqual(a.data_ptr(), a[view_index].data_ptr())


def test_index_setitem_bools_slices(self, device):
    # Assignment through boolean-like, None and Ellipsis indices; each step
    # builds on the state left by the previous one.
    uint8_one = torch.tensor(1, dtype=torch.uint8, device=device)
    uint8_zero = torch.tensor(0, dtype=torch.uint8, device=device)

    candidates = [torch.randn(2, 3, device=device), torch.tensor(3, device=device)]

    for a in candidates:
        # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s
        # (some of these ops already prefix a 1 to the size)
        minus_one = torch.ones_like(a) * -1
        minus_one_padded = minus_one.unsqueeze(0).unsqueeze(0)

        a[True] = minus_one_padded
        self.assertEqual(a, minus_one)
        # A false index selects nothing, so the value is not written.
        a[False] = 5
        self.assertEqual(a, minus_one)

        a[uint8_one] = minus_one_padded * 2
        self.assertEqual(a, minus_one * 2)
        a[uint8_zero] = 5
        self.assertEqual(a, minus_one * 2)

        a[None] = minus_one_padded * 3
        self.assertEqual(a, minus_one * 3)
        a[...] = minus_one_padded * 4
        self.assertEqual(a, minus_one * 4)

        # A 0-d tensor cannot be sliced with `:`.
        if a.dim() == 0:
            with self.assertRaises(IndexError):
                a[:] = minus_one_padded * 5


def test_index_scalar_with_bool_mask(self, device):
    # A 0-d uint8 mask and a 0-d bool mask must give the same values and
    # dtype, for an integer scalar and for a bool scalar.
    uint8_mask = torch.tensor(True, dtype=torch.uint8, device=device)
    bool_mask = torch.tensor(True, dtype=torch.bool, device=device)
    scalars = (
        torch.tensor(1, device=device),
        torch.tensor(True, dtype=torch.bool, device=device),
    )
    for a in scalars:
        self.assertEqual(a[uint8_mask], a[bool_mask])
        self.assertEqual(a[uint8_mask].dtype, a[bool_mask].dtype)


def test_setitem_expansion_error(self, device):
    # Assigning a value with a non-1 leading dimension through a True index
    # must fail for both a Python bool and a 0-d bool tensor.
    tensor_true = torch.tensor(True, device=device)
    a = torch.randn(2, 3, device=device)
    # check prefix with non-1s doesn't work
    a_expanded = a.expand(torch.Size([5, 1]) + a.size())
    # NumPy: ValueError
    for true_index in (True, tensor_true):
        with self.assertRaises(RuntimeError):
            a[true_index] = a_expanded


def test_getitem_scalars(self, device):
    # 0-d integer tensors must index exactly like the Python ints they hold.
    idx0 = torch.tensor(0, dtype=torch.int64, device=device)
    idx1 = torch.tensor(1, dtype=torch.int64, device=device)

    # non-scalar indexed with scalars
    a = torch.randn(2, 3, device=device)
    self.assertEqual(a[0], a[idx0])
    self.assertEqual(a[0][1], a[idx0][idx1])
    self.assertEqual(a[0, 1], a[idx0, idx1])
    self.assertEqual(a[0, idx1], a[idx0, 1])

    # indexing by a scalar should slice (not copy)
    self.assertEqual(a[0, 1].data_ptr(), a[idx0, idx1].data_ptr())
    for narrowed in (idx1.int(), idx1.short()):
        self.assertEqual(a[1].data_ptr(), a[narrowed].data_ptr())

    # scalar indexed with scalar
    r = torch.randn((), device=device)
    with self.assertRaises(IndexError):
        r[:]
    with self.assertRaises(IndexError):
        r[idx0]
    self.assertEqual(r, r[...])


def test_setitem_scalars(self, device):
    # Assignment through a 0-d integer tensor must match assignment through
    # the equivalent Python int.
    idx0 = torch.tensor(0, dtype=torch.int64)

    # non-scalar indexed with scalars
    a = torch.randn(2, 3, device=device)
    via_int = a.clone()
    via_tensor = a.clone()
    row = torch.randn(3, device=device)

    via_int[0] = row
    via_tensor[idx0] = row
    self.assertEqual(via_int, via_tensor)
    a[1, idx0] = 7.7
    self.assertEqual(7.7, a[1, 0])

    # scalar indexed with scalars
    r = torch.randn((), device=device)
    with self.assertRaises(IndexError):
        r[:] = 8.8
    with self.assertRaises(IndexError):
        r[idx0] = 8.8
    r[...] = 9.9
    self.assertEqual(9.9, r)


def test_basic_advanced_combined(self, device):
    # From the NumPy indexing example
    grid = torch.arange(0, 12, device=device).view(4, 3)
    self.assertEqual(grid[1:2, 1:3], grid[1:2, [1, 2]])
    self.assertEqual(grid[1:2, 1:3].tolist(), [[4, 5]])

    # Check that it is a copy
    snapshot = grid.clone()
    grid[1:2, [1, 2]].zero_()
    self.assertEqual(grid, snapshot)

    # But assignment should modify the original
    snapshot = grid.clone()
    grid[1:2, [1, 2]] = 0
    self.assertNotEqual(grid, snapshot)


def test_int_assignment(self, device):
    # Assigning a Python int or a 1-d tensor to a row selected by an int.
    cases = (
        (5, [[0, 1], [5, 5]]),
        (torch.arange(5, 7, device=device), [[0, 1], [5, 6]]),
    )
    for new_row, expected in cases:
        x = torch.arange(0, 4, device=device).view(2, 2)
        x[1] = new_row
        self.assertEqual(x.tolist(), expected)

Worker 1363*da0073e9SAndroid Build Coastguard Worker def test_byte_tensor_assignment(self, device): 1364*da0073e9SAndroid Build Coastguard Worker x = torch.arange(0.0, 16, device=device).view(4, 4) 1365*da0073e9SAndroid Build Coastguard Worker b = torch.ByteTensor([True, False, True, False]).to(device) 1366*da0073e9SAndroid Build Coastguard Worker value = torch.tensor([3.0, 4.0, 5.0, 6.0], device=device) 1367*da0073e9SAndroid Build Coastguard Worker 1368*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 1369*da0073e9SAndroid Build Coastguard Worker x[b] = value 1370*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 1) 1371*da0073e9SAndroid Build Coastguard Worker 1372*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[0], value) 1373*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[1], torch.arange(4.0, 8, device=device)) 1374*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[2], value) 1375*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[3], torch.arange(12.0, 16, device=device)) 1376*da0073e9SAndroid Build Coastguard Worker 1377*da0073e9SAndroid Build Coastguard Worker def test_variable_slicing(self, device): 1378*da0073e9SAndroid Build Coastguard Worker x = torch.arange(0, 16, device=device).view(4, 4) 1379*da0073e9SAndroid Build Coastguard Worker indices = torch.IntTensor([0, 1]).to(device) 1380*da0073e9SAndroid Build Coastguard Worker i, j = indices 1381*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[i:j], x[0:1]) 1382*da0073e9SAndroid Build Coastguard Worker 1383*da0073e9SAndroid Build Coastguard Worker def test_ellipsis_tensor(self, device): 1384*da0073e9SAndroid Build Coastguard Worker x = torch.arange(0, 9, device=device).view(3, 3) 1385*da0073e9SAndroid Build Coastguard Worker idx = torch.tensor([0, 2], device=device) 1386*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[..., idx].tolist(), [[0, 2], [3, 5], [6, 8]]) 
1387*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[idx, ...].tolist(), [[0, 1, 2], [6, 7, 8]]) 1388*da0073e9SAndroid Build Coastguard Worker 1389*da0073e9SAndroid Build Coastguard Worker def test_unravel_index_errors(self, device): 1390*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"expected 'indices' to be integer"): 1391*da0073e9SAndroid Build Coastguard Worker torch.unravel_index(torch.tensor(0.5, device=device), (2, 2)) 1392*da0073e9SAndroid Build Coastguard Worker 1393*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"expected 'indices' to be integer"): 1394*da0073e9SAndroid Build Coastguard Worker torch.unravel_index(torch.tensor([], device=device), (10, 3, 5)) 1395*da0073e9SAndroid Build Coastguard Worker 1396*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1397*da0073e9SAndroid Build Coastguard Worker TypeError, r"expected 'shape' to be int or sequence" 1398*da0073e9SAndroid Build Coastguard Worker ): 1399*da0073e9SAndroid Build Coastguard Worker torch.unravel_index( 1400*da0073e9SAndroid Build Coastguard Worker torch.tensor([1], device=device, dtype=torch.int64), 1401*da0073e9SAndroid Build Coastguard Worker torch.tensor([1, 2, 3]), 1402*da0073e9SAndroid Build Coastguard Worker ) 1403*da0073e9SAndroid Build Coastguard Worker 1404*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1405*da0073e9SAndroid Build Coastguard Worker TypeError, r"expected 'shape' sequence to only contain ints" 1406*da0073e9SAndroid Build Coastguard Worker ): 1407*da0073e9SAndroid Build Coastguard Worker torch.unravel_index( 1408*da0073e9SAndroid Build Coastguard Worker torch.tensor([1], device=device, dtype=torch.int64), (1, 2, 2.0) 1409*da0073e9SAndroid Build Coastguard Worker ) 1410*da0073e9SAndroid Build Coastguard Worker 1411*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1412*da0073e9SAndroid Build Coastguard Worker ValueError, 
r"'shape' cannot have negative values, but got \(2, -3\)" 1413*da0073e9SAndroid Build Coastguard Worker ): 1414*da0073e9SAndroid Build Coastguard Worker torch.unravel_index(torch.tensor(0, device=device), (2, -3)) 1415*da0073e9SAndroid Build Coastguard Worker 1416*da0073e9SAndroid Build Coastguard Worker def test_invalid_index(self, device): 1417*da0073e9SAndroid Build Coastguard Worker x = torch.arange(0, 16, device=device).view(4, 4) 1418*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(TypeError, "slice indices", lambda: x["0":"1"]) 1419*da0073e9SAndroid Build Coastguard Worker 1420*da0073e9SAndroid Build Coastguard Worker def test_out_of_bound_index(self, device): 1421*da0073e9SAndroid Build Coastguard Worker x = torch.arange(0, 100, device=device).view(2, 5, 10) 1422*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 1423*da0073e9SAndroid Build Coastguard Worker IndexError, 1424*da0073e9SAndroid Build Coastguard Worker "index 5 is out of bounds for dimension 1 with size 5", 1425*da0073e9SAndroid Build Coastguard Worker lambda: x[0, 5], 1426*da0073e9SAndroid Build Coastguard Worker ) 1427*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 1428*da0073e9SAndroid Build Coastguard Worker IndexError, 1429*da0073e9SAndroid Build Coastguard Worker "index 4 is out of bounds for dimension 0 with size 2", 1430*da0073e9SAndroid Build Coastguard Worker lambda: x[4, 5], 1431*da0073e9SAndroid Build Coastguard Worker ) 1432*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 1433*da0073e9SAndroid Build Coastguard Worker IndexError, 1434*da0073e9SAndroid Build Coastguard Worker "index 15 is out of bounds for dimension 2 with size 10", 1435*da0073e9SAndroid Build Coastguard Worker lambda: x[0, 1, 15], 1436*da0073e9SAndroid Build Coastguard Worker ) 1437*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 1438*da0073e9SAndroid Build Coastguard Worker IndexError, 1439*da0073e9SAndroid Build Coastguard Worker 
"index 12 is out of bounds for dimension 2 with size 10", 1440*da0073e9SAndroid Build Coastguard Worker lambda: x[:, :, 12], 1441*da0073e9SAndroid Build Coastguard Worker ) 1442*da0073e9SAndroid Build Coastguard Worker 1443*da0073e9SAndroid Build Coastguard Worker def test_zero_dim_index(self, device): 1444*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(10, device=device) 1445*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, x.item()) 1446*da0073e9SAndroid Build Coastguard Worker 1447*da0073e9SAndroid Build Coastguard Worker def runner(): 1448*da0073e9SAndroid Build Coastguard Worker print(x[0]) 1449*da0073e9SAndroid Build Coastguard Worker return x[0] 1450*da0073e9SAndroid Build Coastguard Worker 1451*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(IndexError, "invalid index", runner) 1452*da0073e9SAndroid Build Coastguard Worker 1453*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1454*da0073e9SAndroid Build Coastguard Worker def test_invalid_device(self, device): 1455*da0073e9SAndroid Build Coastguard Worker idx = torch.tensor([0, 1]) 1456*da0073e9SAndroid Build Coastguard Worker b = torch.zeros(5, device=device) 1457*da0073e9SAndroid Build Coastguard Worker c = torch.tensor([1.0, 2.0], device="cpu") 1458*da0073e9SAndroid Build Coastguard Worker 1459*da0073e9SAndroid Build Coastguard Worker for accumulate in [True, False]: 1460*da0073e9SAndroid Build Coastguard Worker self.assertRaises( 1461*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1462*da0073e9SAndroid Build Coastguard Worker lambda: torch.index_put_(b, (idx,), c, accumulate=accumulate), 1463*da0073e9SAndroid Build Coastguard Worker ) 1464*da0073e9SAndroid Build Coastguard Worker 1465*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1466*da0073e9SAndroid Build Coastguard Worker def test_cpu_indices(self, device): 1467*da0073e9SAndroid Build Coastguard Worker idx = torch.tensor([0, 1]) 1468*da0073e9SAndroid Build Coastguard Worker b = torch.zeros(2, 
device=device) 1469*da0073e9SAndroid Build Coastguard Worker x = torch.ones(10, device=device) 1470*da0073e9SAndroid Build Coastguard Worker x[idx] = b # index_put_ 1471*da0073e9SAndroid Build Coastguard Worker ref = torch.ones(10, device=device) 1472*da0073e9SAndroid Build Coastguard Worker ref[:2] = 0 1473*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, ref, atol=0, rtol=0) 1474*da0073e9SAndroid Build Coastguard Worker out = x[idx] # index 1475*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, torch.zeros(2, device=device), atol=0, rtol=0) 1476*da0073e9SAndroid Build Coastguard Worker 1477*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.long, torch.float32) 1478*da0073e9SAndroid Build Coastguard Worker def test_take_along_dim(self, device, dtype): 1479*da0073e9SAndroid Build Coastguard Worker def _test_against_numpy(t, indices, dim): 1480*da0073e9SAndroid Build Coastguard Worker actual = torch.take_along_dim(t, indices, dim=dim) 1481*da0073e9SAndroid Build Coastguard Worker t_np = t.cpu().numpy() 1482*da0073e9SAndroid Build Coastguard Worker indices_np = indices.cpu().numpy() 1483*da0073e9SAndroid Build Coastguard Worker expected = np.take_along_axis(t_np, indices_np, axis=dim) 1484*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, atol=0, rtol=0) 1485*da0073e9SAndroid Build Coastguard Worker 1486*da0073e9SAndroid Build Coastguard Worker for shape in [(3, 2), (2, 3, 5), (2, 4, 0), (2, 3, 1, 4)]: 1487*da0073e9SAndroid Build Coastguard Worker for noncontiguous in [True, False]: 1488*da0073e9SAndroid Build Coastguard Worker t = make_tensor( 1489*da0073e9SAndroid Build Coastguard Worker shape, device=device, dtype=dtype, noncontiguous=noncontiguous 1490*da0073e9SAndroid Build Coastguard Worker ) 1491*da0073e9SAndroid Build Coastguard Worker for dim in list(range(t.ndim)) + [None]: 1492*da0073e9SAndroid Build Coastguard Worker if dim is None: 1493*da0073e9SAndroid Build Coastguard Worker indices = 
torch.argsort(t.view(-1)) 1494*da0073e9SAndroid Build Coastguard Worker else: 1495*da0073e9SAndroid Build Coastguard Worker indices = torch.argsort(t, dim=dim) 1496*da0073e9SAndroid Build Coastguard Worker 1497*da0073e9SAndroid Build Coastguard Worker _test_against_numpy(t, indices, dim) 1498*da0073e9SAndroid Build Coastguard Worker 1499*da0073e9SAndroid Build Coastguard Worker # test broadcasting 1500*da0073e9SAndroid Build Coastguard Worker t = torch.ones((3, 4, 1), device=device) 1501*da0073e9SAndroid Build Coastguard Worker indices = torch.ones((1, 2, 5), dtype=torch.long, device=device) 1502*da0073e9SAndroid Build Coastguard Worker 1503*da0073e9SAndroid Build Coastguard Worker _test_against_numpy(t, indices, 1) 1504*da0073e9SAndroid Build Coastguard Worker 1505*da0073e9SAndroid Build Coastguard Worker # test empty indices 1506*da0073e9SAndroid Build Coastguard Worker t = torch.ones((3, 4, 5), device=device) 1507*da0073e9SAndroid Build Coastguard Worker indices = torch.ones((3, 0, 5), dtype=torch.long, device=device) 1508*da0073e9SAndroid Build Coastguard Worker 1509*da0073e9SAndroid Build Coastguard Worker _test_against_numpy(t, indices, 1) 1510*da0073e9SAndroid Build Coastguard Worker 1511*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.long, torch.float) 1512*da0073e9SAndroid Build Coastguard Worker def test_take_along_dim_invalid(self, device, dtype): 1513*da0073e9SAndroid Build Coastguard Worker shape = (2, 3, 1, 4) 1514*da0073e9SAndroid Build Coastguard Worker dim = 0 1515*da0073e9SAndroid Build Coastguard Worker t = make_tensor(shape, device=device, dtype=dtype) 1516*da0073e9SAndroid Build Coastguard Worker indices = torch.argsort(t, dim=dim) 1517*da0073e9SAndroid Build Coastguard Worker 1518*da0073e9SAndroid Build Coastguard Worker # dim of `t` and `indices` does not match 1519*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1520*da0073e9SAndroid Build Coastguard Worker RuntimeError, "input and indices should have the same 
number of dimensions" 1521*da0073e9SAndroid Build Coastguard Worker ): 1522*da0073e9SAndroid Build Coastguard Worker torch.take_along_dim(t, indices[0], dim=0) 1523*da0073e9SAndroid Build Coastguard Worker 1524*da0073e9SAndroid Build Coastguard Worker # invalid `indices` dtype 1525*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"): 1526*da0073e9SAndroid Build Coastguard Worker torch.take_along_dim(t, indices.to(torch.bool), dim=0) 1527*da0073e9SAndroid Build Coastguard Worker 1528*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"): 1529*da0073e9SAndroid Build Coastguard Worker torch.take_along_dim(t, indices.to(torch.float), dim=0) 1530*da0073e9SAndroid Build Coastguard Worker 1531*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"): 1532*da0073e9SAndroid Build Coastguard Worker torch.take_along_dim(t, indices.to(torch.int32), dim=0) 1533*da0073e9SAndroid Build Coastguard Worker 1534*da0073e9SAndroid Build Coastguard Worker # invalid axis 1535*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(IndexError, "Dimension out of range"): 1536*da0073e9SAndroid Build Coastguard Worker torch.take_along_dim(t, indices, dim=-7) 1537*da0073e9SAndroid Build Coastguard Worker 1538*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(IndexError, "Dimension out of range"): 1539*da0073e9SAndroid Build Coastguard Worker torch.take_along_dim(t, indices, dim=7) 1540*da0073e9SAndroid Build Coastguard Worker 1541*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1542*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float) 1543*da0073e9SAndroid Build Coastguard Worker def test_gather_take_along_dim_cross_device(self, device, dtype): 1544*da0073e9SAndroid Build Coastguard Worker shape = (2, 3, 1, 4) 1545*da0073e9SAndroid Build Coastguard Worker dim 
= 0 1546*da0073e9SAndroid Build Coastguard Worker t = make_tensor(shape, device=device, dtype=dtype) 1547*da0073e9SAndroid Build Coastguard Worker indices = torch.argsort(t, dim=dim) 1548*da0073e9SAndroid Build Coastguard Worker 1549*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1550*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Expected all tensors to be on the same device" 1551*da0073e9SAndroid Build Coastguard Worker ): 1552*da0073e9SAndroid Build Coastguard Worker torch.gather(t, 0, indices.cpu()) 1553*da0073e9SAndroid Build Coastguard Worker 1554*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1555*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1556*da0073e9SAndroid Build Coastguard Worker r"Expected tensor to have .* but got tensor with .* torch.take_along_dim()", 1557*da0073e9SAndroid Build Coastguard Worker ): 1558*da0073e9SAndroid Build Coastguard Worker torch.take_along_dim(t, indices.cpu(), dim=0) 1559*da0073e9SAndroid Build Coastguard Worker 1560*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1561*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Expected all tensors to be on the same device" 1562*da0073e9SAndroid Build Coastguard Worker ): 1563*da0073e9SAndroid Build Coastguard Worker torch.gather(t.cpu(), 0, indices) 1564*da0073e9SAndroid Build Coastguard Worker 1565*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1566*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1567*da0073e9SAndroid Build Coastguard Worker r"Expected tensor to have .* but got tensor with .* torch.take_along_dim()", 1568*da0073e9SAndroid Build Coastguard Worker ): 1569*da0073e9SAndroid Build Coastguard Worker torch.take_along_dim(t.cpu(), indices, dim=0) 1570*da0073e9SAndroid Build Coastguard Worker 1571*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1572*da0073e9SAndroid Build Coastguard Worker def 
test_cuda_broadcast_index_use_deterministic_algorithms(self, device): 1573*da0073e9SAndroid Build Coastguard Worker with DeterministicGuard(True): 1574*da0073e9SAndroid Build Coastguard Worker idx1 = torch.tensor([0]) 1575*da0073e9SAndroid Build Coastguard Worker idx2 = torch.tensor([2, 6]) 1576*da0073e9SAndroid Build Coastguard Worker idx3 = torch.tensor([1, 5, 7]) 1577*da0073e9SAndroid Build Coastguard Worker 1578*da0073e9SAndroid Build Coastguard Worker tensor_a = torch.rand(13, 11, 12, 13, 12).cpu() 1579*da0073e9SAndroid Build Coastguard Worker tensor_b = tensor_a.to(device=device) 1580*da0073e9SAndroid Build Coastguard Worker tensor_a[idx1] = 1.0 1581*da0073e9SAndroid Build Coastguard Worker tensor_a[idx1, :, idx2, idx2, :] = 2.0 1582*da0073e9SAndroid Build Coastguard Worker tensor_a[:, idx1, idx3, :, idx3] = 3.0 1583*da0073e9SAndroid Build Coastguard Worker tensor_b[idx1] = 1.0 1584*da0073e9SAndroid Build Coastguard Worker tensor_b[idx1, :, idx2, idx2, :] = 2.0 1585*da0073e9SAndroid Build Coastguard Worker tensor_b[:, idx1, idx3, :, idx3] = 3.0 1586*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0) 1587*da0073e9SAndroid Build Coastguard Worker 1588*da0073e9SAndroid Build Coastguard Worker tensor_a = torch.rand(10, 11).cpu() 1589*da0073e9SAndroid Build Coastguard Worker tensor_b = tensor_a.to(device=device) 1590*da0073e9SAndroid Build Coastguard Worker tensor_a[idx3] = 1.0 1591*da0073e9SAndroid Build Coastguard Worker tensor_a[idx2, :] = 2.0 1592*da0073e9SAndroid Build Coastguard Worker tensor_a[:, idx2] = 3.0 1593*da0073e9SAndroid Build Coastguard Worker tensor_a[:, idx1] = 4.0 1594*da0073e9SAndroid Build Coastguard Worker tensor_b[idx3] = 1.0 1595*da0073e9SAndroid Build Coastguard Worker tensor_b[idx2, :] = 2.0 1596*da0073e9SAndroid Build Coastguard Worker tensor_b[:, idx2] = 3.0 1597*da0073e9SAndroid Build Coastguard Worker tensor_b[:, idx1] = 4.0 1598*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0) 1599*da0073e9SAndroid Build Coastguard Worker 1600*da0073e9SAndroid Build Coastguard Worker tensor_a = torch.rand(10, 10).cpu() 1601*da0073e9SAndroid Build Coastguard Worker tensor_b = tensor_a.to(device=device) 1602*da0073e9SAndroid Build Coastguard Worker tensor_a[[8]] = 1.0 1603*da0073e9SAndroid Build Coastguard Worker tensor_b[[8]] = 1.0 1604*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0) 1605*da0073e9SAndroid Build Coastguard Worker 1606*da0073e9SAndroid Build Coastguard Worker tensor_a = torch.rand(10).cpu() 1607*da0073e9SAndroid Build Coastguard Worker tensor_b = tensor_a.to(device=device) 1608*da0073e9SAndroid Build Coastguard Worker tensor_a[6] = 1.0 1609*da0073e9SAndroid Build Coastguard Worker tensor_b[6] = 1.0 1610*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0) 1611*da0073e9SAndroid Build Coastguard Worker 1612*da0073e9SAndroid Build Coastguard Worker def test_index_limits(self, device): 1613*da0073e9SAndroid Build Coastguard Worker # Regression test for https://github.com/pytorch/pytorch/issues/115415 1614*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([], device=device) 1615*da0073e9SAndroid Build Coastguard Worker idx_min = torch.iinfo(torch.int64).min 1616*da0073e9SAndroid Build Coastguard Worker idx_max = torch.iinfo(torch.int64).max 1617*da0073e9SAndroid Build Coastguard Worker self.assertRaises(IndexError, lambda: t[idx_min]) 1618*da0073e9SAndroid Build Coastguard Worker self.assertRaises(IndexError, lambda: t[idx_max]) 1619*da0073e9SAndroid Build Coastguard Worker 1620*da0073e9SAndroid Build Coastguard Worker 1621*da0073e9SAndroid Build Coastguard Worker# The tests below are from NumPy test_indexing.py with some modifications to 1622*da0073e9SAndroid Build Coastguard Worker# make them compatible with PyTorch. 
It's licensed under the BDS license below: 1623*da0073e9SAndroid Build Coastguard Worker# 1624*da0073e9SAndroid Build Coastguard Worker# Copyright (c) 2005-2017, NumPy Developers. 1625*da0073e9SAndroid Build Coastguard Worker# All rights reserved. 1626*da0073e9SAndroid Build Coastguard Worker# 1627*da0073e9SAndroid Build Coastguard Worker# Redistribution and use in source and binary forms, with or without 1628*da0073e9SAndroid Build Coastguard Worker# modification, are permitted provided that the following conditions are 1629*da0073e9SAndroid Build Coastguard Worker# met: 1630*da0073e9SAndroid Build Coastguard Worker# 1631*da0073e9SAndroid Build Coastguard Worker# * Redistributions of source code must retain the above copyright 1632*da0073e9SAndroid Build Coastguard Worker# notice, this list of conditions and the following disclaimer. 1633*da0073e9SAndroid Build Coastguard Worker# 1634*da0073e9SAndroid Build Coastguard Worker# * Redistributions in binary form must reproduce the above 1635*da0073e9SAndroid Build Coastguard Worker# copyright notice, this list of conditions and the following 1636*da0073e9SAndroid Build Coastguard Worker# disclaimer in the documentation and/or other materials provided 1637*da0073e9SAndroid Build Coastguard Worker# with the distribution. 1638*da0073e9SAndroid Build Coastguard Worker# 1639*da0073e9SAndroid Build Coastguard Worker# * Neither the name of the NumPy Developers nor the names of any 1640*da0073e9SAndroid Build Coastguard Worker# contributors may be used to endorse or promote products derived 1641*da0073e9SAndroid Build Coastguard Worker# from this software without specific prior written permission. 
1642*da0073e9SAndroid Build Coastguard Worker# 1643*da0073e9SAndroid Build Coastguard Worker# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 1644*da0073e9SAndroid Build Coastguard Worker# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 1645*da0073e9SAndroid Build Coastguard Worker# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 1646*da0073e9SAndroid Build Coastguard Worker# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 1647*da0073e9SAndroid Build Coastguard Worker# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 1648*da0073e9SAndroid Build Coastguard Worker# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 1649*da0073e9SAndroid Build Coastguard Worker# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 1650*da0073e9SAndroid Build Coastguard Worker# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 1651*da0073e9SAndroid Build Coastguard Worker# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 1652*da0073e9SAndroid Build Coastguard Worker# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 1653*da0073e9SAndroid Build Coastguard Worker# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 


class NumpyTests(TestCase):
    def test_index_no_floats(self, device):
        # Python floats are never valid indices, whatever their position in the
        # index tuple and whether or not they are integral-valued.
        a = torch.tensor([[[5.0]]], device=device)

        self.assertRaises(IndexError, lambda: a[0.0])
        self.assertRaises(IndexError, lambda: a[0, 0.0])
        self.assertRaises(IndexError, lambda: a[0.0, 0])
        self.assertRaises(IndexError, lambda: a[0.0, :])
        self.assertRaises(IndexError, lambda: a[:, 0.0])
        self.assertRaises(IndexError, lambda: a[:, 0.0, :])
        self.assertRaises(IndexError, lambda: a[0.0, :, :])
        self.assertRaises(IndexError, lambda: a[0, 0, 0.0])
        self.assertRaises(IndexError, lambda: a[0.0, 0, 0])
        self.assertRaises(IndexError, lambda: a[0, 0.0, 0])
        self.assertRaises(IndexError, lambda: a[-1.4])
        self.assertRaises(IndexError, lambda: a[0, -1.4])
        self.assertRaises(IndexError, lambda: a[-1.4, 0])
        self.assertRaises(IndexError, lambda: a[-1.4, :])
        self.assertRaises(IndexError, lambda: a[:, -1.4])
        self.assertRaises(IndexError, lambda: a[:, -1.4, :])
        self.assertRaises(IndexError, lambda: a[-1.4, :, :])
        self.assertRaises(IndexError, lambda: a[0, 0, -1.4])
        self.assertRaises(IndexError, lambda: a[-1.4, 0, 0])
        self.assertRaises(IndexError, lambda: a[0, -1.4, 0])
        # self.assertRaises(IndexError, lambda: a[0.0:, 0.0])
        # self.assertRaises(IndexError, lambda: a[0.0:, 0.0,:])

    def test_none_index(self, device):
        # `None` index adds newaxis
        a = tensor([1, 2, 3], device=device)
        self.assertEqual(a[None].dim(), a.dim() + 1)

    def test_empty_tuple_index(self, device):
        # Empty tuple index creates a view
        a = tensor([1, 2, 3], device=device)
        self.assertEqual(a[()], a)
        # Same storage pointer: the result aliases `a` rather than copying it.
        self.assertEqual(a[()].data_ptr(), a.data_ptr())

    def test_empty_fancy_index(self, device):
        # Empty list index creates an empty array
        a = tensor([1, 2, 3], device=device)
        self.assertEqual(a[[]], torch.tensor([], dtype=torch.long, device=device))

        # An empty integer (long) index tensor behaves like the empty list.
        # Index with `b` itself so that this path is actually exercised.
        b = tensor([], device=device).long()
        self.assertEqual(a[b], torch.tensor([], dtype=torch.long, device=device))

        # An empty floating-point index tensor is rejected.
        b = tensor([], device=device).float()
        self.assertRaises(IndexError, lambda: a[b])

    def test_ellipsis_index(self, device):
        a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
        # `a[...]` is a new tensor object that shares `a`'s data.
        self.assertIsNot(a[...], a)
        self.assertEqual(a[...], a)
        # `a[...]` was `a` in numpy <1.9.
        self.assertEqual(a[...].data_ptr(), a.data_ptr())

        # Slicing with ellipsis can skip an
        # arbitrary number of dimensions
        self.assertEqual(a[0, ...], a[0])
        self.assertEqual(a[0, ...], a[0, :])
        self.assertEqual(a[..., 0], a[:, 0])

        # In NumPy, slicing with ellipsis results in a 0-dim array. In PyTorch
        # we don't have separate 0-dim arrays and scalars.
        self.assertEqual(a[0, ..., 1], torch.tensor(2, device=device))

        # Assignment with `(Ellipsis,)` on 0-d arrays
        # Created on `device` so the assignment is checked on the device under
        # test, like every other tensor in this class.
        b = torch.tensor(1, device=device)
        b[(Ellipsis,)] = 2
        self.assertEqual(b, 2)

    def test_single_int_index(self, device):
        # Single integer index selects one row
        a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)

        self.assertEqual(a[0], [1, 2, 3])
        self.assertEqual(a[-1], [7, 8, 9])

        # Index out of bounds produces IndexError
        self.assertRaises(IndexError, a.__getitem__, 1 << 30)
        # Index overflow produces Exception NB: different exception type
        self.assertRaises(Exception, a.__getitem__, 1 << 64)

    def test_single_bool_index(self, device):
        # Single boolean index
        a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)

        # True behaves like `None` (adds a leading axis); False yields the same
        # shape with that leading axis empty.
        self.assertEqual(a[True], a[None])
        self.assertEqual(a[False], a[None][0:0])

    def test_boolean_shape_mismatch(self, device):
        # Masks whose shape does not match the indexed dimensions raise an
        # IndexError mentioning "mask".
        arr = torch.ones((5, 4, 3), device=device)

        index = tensor([True], device=device)
        self.assertRaisesRegex(IndexError, "mask", lambda: arr[index])

        index = tensor([False] * 6, device=device)
        self.assertRaisesRegex(IndexError, "mask", lambda: arr[index])

        # uint8 mask of shape (4, 4), applied at dim 0 and then at dim 1.
        index = torch.ByteTensor(4, 4).to(device).zero_()
        self.assertRaisesRegex(IndexError, "mask", lambda: arr[index])
        self.assertRaisesRegex(IndexError, "mask", lambda: arr[(slice(None), index)])

    def test_boolean_indexing_onedim(self, device):
        # Indexing a 2-dimensional array with
        # boolean array of length one
        a = tensor([[0.0, 0.0, 0.0]], device=device)
        b = tensor([True], device=device)
        self.assertEqual(a[b], a)
        # boolean assignment
        a[b] = 1.0
        self.assertEqual(a, tensor([[1.0, 1.0, 1.0]], device=device))

    # https://github.com/pytorch/pytorch/issues/127003
Worker @xfailIfTorchDynamo 1771*da0073e9SAndroid Build Coastguard Worker def test_boolean_assignment_value_mismatch(self, device): 1772*da0073e9SAndroid Build Coastguard Worker # A boolean assignment should fail when the shape of the values 1773*da0073e9SAndroid Build Coastguard Worker # cannot be broadcast to the subscription. (see also gh-3458) 1774*da0073e9SAndroid Build Coastguard Worker a = torch.arange(0, 4, device=device) 1775*da0073e9SAndroid Build Coastguard Worker 1776*da0073e9SAndroid Build Coastguard Worker def f(a, v): 1777*da0073e9SAndroid Build Coastguard Worker a[a > -1] = tensor(v).to(device) 1778*da0073e9SAndroid Build Coastguard Worker 1779*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(Exception, "shape mismatch", f, a, []) 1780*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(Exception, "shape mismatch", f, a, [1, 2, 3]) 1781*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(Exception, "shape mismatch", f, a[:1], [1, 2, 3]) 1782*da0073e9SAndroid Build Coastguard Worker 1783*da0073e9SAndroid Build Coastguard Worker def test_boolean_indexing_twodim(self, device): 1784*da0073e9SAndroid Build Coastguard Worker # Indexing a 2-dimensional array with 1785*da0073e9SAndroid Build Coastguard Worker # 2-dimensional boolean array 1786*da0073e9SAndroid Build Coastguard Worker a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device) 1787*da0073e9SAndroid Build Coastguard Worker b = tensor( 1788*da0073e9SAndroid Build Coastguard Worker [[True, False, True], [False, True, False], [True, False, True]], 1789*da0073e9SAndroid Build Coastguard Worker device=device, 1790*da0073e9SAndroid Build Coastguard Worker ) 1791*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a[b], tensor([1, 3, 5, 7, 9], device=device)) 1792*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a[b[1]], tensor([[4, 5, 6]], device=device)) 1793*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a[b[0]], a[b[2]]) 
1794*da0073e9SAndroid Build Coastguard Worker 1795*da0073e9SAndroid Build Coastguard Worker # boolean assignment 1796*da0073e9SAndroid Build Coastguard Worker a[b] = 0 1797*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, tensor([[0, 2, 0], [4, 0, 6], [0, 8, 0]], device=device)) 1798*da0073e9SAndroid Build Coastguard Worker 1799*da0073e9SAndroid Build Coastguard Worker def test_boolean_indexing_weirdness(self, device): 1800*da0073e9SAndroid Build Coastguard Worker # Weird boolean indexing things 1801*da0073e9SAndroid Build Coastguard Worker a = torch.ones((2, 3, 4), device=device) 1802*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 2, 3, 4), a[False, True, ...].shape) 1803*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1804*da0073e9SAndroid Build Coastguard Worker torch.ones(1, 2, device=device), a[True, [0, 1], True, True, [1], [[2]]] 1805*da0073e9SAndroid Build Coastguard Worker ) 1806*da0073e9SAndroid Build Coastguard Worker self.assertRaises(IndexError, lambda: a[False, [0, 1], ...]) 1807*da0073e9SAndroid Build Coastguard Worker 1808*da0073e9SAndroid Build Coastguard Worker def test_boolean_indexing_weirdness_tensors(self, device): 1809*da0073e9SAndroid Build Coastguard Worker # Weird boolean indexing things 1810*da0073e9SAndroid Build Coastguard Worker false = torch.tensor(False, device=device) 1811*da0073e9SAndroid Build Coastguard Worker true = torch.tensor(True, device=device) 1812*da0073e9SAndroid Build Coastguard Worker a = torch.ones((2, 3, 4), device=device) 1813*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 2, 3, 4), a[False, True, ...].shape) 1814*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1815*da0073e9SAndroid Build Coastguard Worker torch.ones(1, 2, device=device), a[true, [0, 1], true, true, [1], [[2]]] 1816*da0073e9SAndroid Build Coastguard Worker ) 1817*da0073e9SAndroid Build Coastguard Worker self.assertRaises(IndexError, lambda: a[false, [0, 1], ...]) 1818*da0073e9SAndroid 
Build Coastguard Worker 1819*da0073e9SAndroid Build Coastguard Worker def test_boolean_indexing_alldims(self, device): 1820*da0073e9SAndroid Build Coastguard Worker true = torch.tensor(True, device=device) 1821*da0073e9SAndroid Build Coastguard Worker a = torch.ones((2, 3), device=device) 1822*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1, 2, 3), a[True, True].shape) 1823*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1, 2, 3), a[true, true].shape) 1824*da0073e9SAndroid Build Coastguard Worker 1825*da0073e9SAndroid Build Coastguard Worker def test_boolean_list_indexing(self, device): 1826*da0073e9SAndroid Build Coastguard Worker # Indexing a 2-dimensional array with 1827*da0073e9SAndroid Build Coastguard Worker # boolean lists 1828*da0073e9SAndroid Build Coastguard Worker a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device) 1829*da0073e9SAndroid Build Coastguard Worker b = [True, False, False] 1830*da0073e9SAndroid Build Coastguard Worker c = [True, True, False] 1831*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a[b], tensor([[1, 2, 3]], device=device)) 1832*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a[b, b], tensor([1], device=device)) 1833*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a[c], tensor([[1, 2, 3], [4, 5, 6]], device=device)) 1834*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a[c, c], tensor([1, 5], device=device)) 1835*da0073e9SAndroid Build Coastguard Worker 1836*da0073e9SAndroid Build Coastguard Worker def test_everything_returns_views(self, device): 1837*da0073e9SAndroid Build Coastguard Worker # Before `...` would return a itself. 
1838*da0073e9SAndroid Build Coastguard Worker a = tensor([5], device=device) 1839*da0073e9SAndroid Build Coastguard Worker 1840*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(a, a[()]) 1841*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(a, a[...]) 1842*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(a, a[:]) 1843*da0073e9SAndroid Build Coastguard Worker 1844*da0073e9SAndroid Build Coastguard Worker def test_broaderrors_indexing(self, device): 1845*da0073e9SAndroid Build Coastguard Worker a = torch.zeros(5, 5, device=device) 1846*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 1847*da0073e9SAndroid Build Coastguard Worker IndexError, "shape mismatch", a.__getitem__, ([0, 1], [0, 1, 2]) 1848*da0073e9SAndroid Build Coastguard Worker ) 1849*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 1850*da0073e9SAndroid Build Coastguard Worker IndexError, "shape mismatch", a.__setitem__, ([0, 1], [0, 1, 2]), 0 1851*da0073e9SAndroid Build Coastguard Worker ) 1852*da0073e9SAndroid Build Coastguard Worker 1853*da0073e9SAndroid Build Coastguard Worker def test_trivial_fancy_out_of_bounds(self, device): 1854*da0073e9SAndroid Build Coastguard Worker a = torch.zeros(5, device=device) 1855*da0073e9SAndroid Build Coastguard Worker ind = torch.ones(20, dtype=torch.int64, device=device) 1856*da0073e9SAndroid Build Coastguard Worker if a.is_cuda: 1857*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest("CUDA asserts instead of raising an exception") 1858*da0073e9SAndroid Build Coastguard Worker ind[-1] = 10 1859*da0073e9SAndroid Build Coastguard Worker self.assertRaises(IndexError, a.__getitem__, ind) 1860*da0073e9SAndroid Build Coastguard Worker self.assertRaises(IndexError, a.__setitem__, ind, 0) 1861*da0073e9SAndroid Build Coastguard Worker ind = torch.ones(20, dtype=torch.int64, device=device) 1862*da0073e9SAndroid Build Coastguard Worker ind[0] = 11 1863*da0073e9SAndroid Build Coastguard Worker 
self.assertRaises(IndexError, a.__getitem__, ind) 1864*da0073e9SAndroid Build Coastguard Worker self.assertRaises(IndexError, a.__setitem__, ind, 0) 1865*da0073e9SAndroid Build Coastguard Worker 1866*da0073e9SAndroid Build Coastguard Worker def test_index_is_larger(self, device): 1867*da0073e9SAndroid Build Coastguard Worker # Simple case of fancy index broadcasting of the index. 1868*da0073e9SAndroid Build Coastguard Worker a = torch.zeros((5, 5), device=device) 1869*da0073e9SAndroid Build Coastguard Worker a[[[0], [1], [2]], [0, 1, 2]] = tensor([2.0, 3.0, 4.0], device=device) 1870*da0073e9SAndroid Build Coastguard Worker 1871*da0073e9SAndroid Build Coastguard Worker self.assertTrue((a[:3, :3] == tensor([2.0, 3.0, 4.0], device=device)).all()) 1872*da0073e9SAndroid Build Coastguard Worker 1873*da0073e9SAndroid Build Coastguard Worker def test_broadcast_subspace(self, device): 1874*da0073e9SAndroid Build Coastguard Worker a = torch.zeros((100, 100), device=device) 1875*da0073e9SAndroid Build Coastguard Worker v = torch.arange(0.0, 100, device=device)[:, None] 1876*da0073e9SAndroid Build Coastguard Worker b = torch.arange(99, -1, -1, device=device).long() 1877*da0073e9SAndroid Build Coastguard Worker a[b] = v 1878*da0073e9SAndroid Build Coastguard Worker expected = b.float().unsqueeze(1).expand(100, 100) 1879*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, expected) 1880*da0073e9SAndroid Build Coastguard Worker 1881*da0073e9SAndroid Build Coastguard Worker def test_truncate_leading_1s(self, device): 1882*da0073e9SAndroid Build Coastguard Worker col_max = torch.randn(1, 4) 1883*da0073e9SAndroid Build Coastguard Worker kernel = col_max.T * col_max # [4, 4] tensor 1884*da0073e9SAndroid Build Coastguard Worker kernel2 = kernel.clone() 1885*da0073e9SAndroid Build Coastguard Worker # Set the diagonal 1886*da0073e9SAndroid Build Coastguard Worker kernel[range(len(kernel)), range(len(kernel))] = torch.square(col_max) 1887*da0073e9SAndroid Build Coastguard Worker 
torch.diagonal(kernel2).copy_(torch.square(col_max.view(4))) 1888*da0073e9SAndroid Build Coastguard Worker self.assertEqual(kernel, kernel2) 1889*da0073e9SAndroid Build Coastguard Worker 1890*da0073e9SAndroid Build Coastguard Worker 1891*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestIndexing, globals(), except_for="meta") 1892*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(NumpyTests, globals(), except_for="meta") 1893*da0073e9SAndroid Build Coastguard Worker 1894*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 1895*da0073e9SAndroid Build Coastguard Worker run_tests() 1896