xref: /aosp_15_r20/external/pytorch/test/test_indexing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: tests"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport operator
4*da0073e9SAndroid Build Coastguard Workerimport random
5*da0073e9SAndroid Build Coastguard Workerimport unittest
6*da0073e9SAndroid Build Coastguard Workerimport warnings
7*da0073e9SAndroid Build Coastguard Workerfrom functools import reduce
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport numpy as np
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Workerimport torch
12*da0073e9SAndroid Build Coastguard Workerfrom torch import tensor
13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor
14*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import (
15*da0073e9SAndroid Build Coastguard Worker    dtypes,
16*da0073e9SAndroid Build Coastguard Worker    dtypesIfCPU,
17*da0073e9SAndroid Build Coastguard Worker    dtypesIfCUDA,
18*da0073e9SAndroid Build Coastguard Worker    instantiate_device_type_tests,
19*da0073e9SAndroid Build Coastguard Worker    onlyCUDA,
20*da0073e9SAndroid Build Coastguard Worker    onlyNativeDeviceTypes,
21*da0073e9SAndroid Build Coastguard Worker    skipXLA,
22*da0073e9SAndroid Build Coastguard Worker)
23*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
24*da0073e9SAndroid Build Coastguard Worker    DeterministicGuard,
25*da0073e9SAndroid Build Coastguard Worker    run_tests,
26*da0073e9SAndroid Build Coastguard Worker    serialTest,
27*da0073e9SAndroid Build Coastguard Worker    skipIfTorchDynamo,
28*da0073e9SAndroid Build Coastguard Worker    TEST_CUDA,
29*da0073e9SAndroid Build Coastguard Worker    TestCase,
30*da0073e9SAndroid Build Coastguard Worker    xfailIfTorchDynamo,
31*da0073e9SAndroid Build Coastguard Worker)
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Workerclass TestIndexing(TestCase):
35*da0073e9SAndroid Build Coastguard Worker    def test_index(self, device):
36*da0073e9SAndroid Build Coastguard Worker        def consec(size, start=1):
37*da0073e9SAndroid Build Coastguard Worker            sequence = torch.ones(torch.tensor(size).prod(0)).cumsum(0)
38*da0073e9SAndroid Build Coastguard Worker            sequence.add_(start - 1)
39*da0073e9SAndroid Build Coastguard Worker            return sequence.view(*size)
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Worker        reference = consec((3, 3, 3)).to(device)
42*da0073e9SAndroid Build Coastguard Worker
43*da0073e9SAndroid Build Coastguard Worker        # empty tensor indexing
44*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
45*da0073e9SAndroid Build Coastguard Worker            reference[torch.LongTensor().to(device)], reference.new(0, 3, 3)
46*da0073e9SAndroid Build Coastguard Worker        )
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[0], consec((3, 3)), atol=0, rtol=0)
49*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[1], consec((3, 3), 10), atol=0, rtol=0)
50*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[2], consec((3, 3), 19), atol=0, rtol=0)
51*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[0, 1], consec((3,), 4), atol=0, rtol=0)
52*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[0:2], consec((2, 3, 3)), atol=0, rtol=0)
53*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[2, 2, 2], 27, atol=0, rtol=0)
54*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[:], consec((3, 3, 3)), atol=0, rtol=0)
55*da0073e9SAndroid Build Coastguard Worker
56*da0073e9SAndroid Build Coastguard Worker        # indexing with Ellipsis
57*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
58*da0073e9SAndroid Build Coastguard Worker            reference[..., 2],
59*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[3.0, 6.0, 9.0], [12.0, 15.0, 18.0], [21.0, 24.0, 27.0]]),
60*da0073e9SAndroid Build Coastguard Worker            atol=0,
61*da0073e9SAndroid Build Coastguard Worker            rtol=0,
62*da0073e9SAndroid Build Coastguard Worker        )
63*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
64*da0073e9SAndroid Build Coastguard Worker            reference[0, ..., 2], torch.tensor([3.0, 6.0, 9.0]), atol=0, rtol=0
65*da0073e9SAndroid Build Coastguard Worker        )
66*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[..., 2], reference[:, :, 2], atol=0, rtol=0)
67*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[0, ..., 2], reference[0, :, 2], atol=0, rtol=0)
68*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[0, 2, ...], reference[0, 2], atol=0, rtol=0)
69*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[..., 2, 2, 2], 27, atol=0, rtol=0)
70*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[2, ..., 2, 2], 27, atol=0, rtol=0)
71*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[2, 2, ..., 2], 27, atol=0, rtol=0)
72*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[2, 2, 2, ...], 27, atol=0, rtol=0)
73*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[...], reference, atol=0, rtol=0)
74*da0073e9SAndroid Build Coastguard Worker
75*da0073e9SAndroid Build Coastguard Worker        reference_5d = consec((3, 3, 3, 3, 3)).to(device)
76*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
77*da0073e9SAndroid Build Coastguard Worker            reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0], atol=0, rtol=0
78*da0073e9SAndroid Build Coastguard Worker        )
79*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
80*da0073e9SAndroid Build Coastguard Worker            reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0], atol=0, rtol=0
81*da0073e9SAndroid Build Coastguard Worker        )
82*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
83*da0073e9SAndroid Build Coastguard Worker            reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1], atol=0, rtol=0
84*da0073e9SAndroid Build Coastguard Worker        )
85*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference_5d[...], reference_5d, atol=0, rtol=0)
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker        # LongTensor indexing
88*da0073e9SAndroid Build Coastguard Worker        reference = consec((5, 5, 5)).to(device)
89*da0073e9SAndroid Build Coastguard Worker        idx = torch.LongTensor([2, 4]).to(device)
90*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[idx], torch.stack([reference[2], reference[4]]))
91*da0073e9SAndroid Build Coastguard Worker        # TODO: enable one indexing is implemented like in numpy
92*da0073e9SAndroid Build Coastguard Worker        # self.assertEqual(reference[2, idx], torch.stack([reference[2, 2], reference[2, 4]]))
93*da0073e9SAndroid Build Coastguard Worker        # self.assertEqual(reference[3, idx, 1], torch.stack([reference[3, 2], reference[3, 4]])[:, 1])
94*da0073e9SAndroid Build Coastguard Worker
95*da0073e9SAndroid Build Coastguard Worker        # None indexing
96*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[2, None], reference[2].unsqueeze(0))
97*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
98*da0073e9SAndroid Build Coastguard Worker            reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0)
99*da0073e9SAndroid Build Coastguard Worker        )
100*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[2:4, None], reference[2:4].unsqueeze(1))
101*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
102*da0073e9SAndroid Build Coastguard Worker            reference[None, 2, None, None],
103*da0073e9SAndroid Build Coastguard Worker            reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0),
104*da0073e9SAndroid Build Coastguard Worker        )
105*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
106*da0073e9SAndroid Build Coastguard Worker            reference[None, 2:5, None, None],
107*da0073e9SAndroid Build Coastguard Worker            reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2),
108*da0073e9SAndroid Build Coastguard Worker        )
109*da0073e9SAndroid Build Coastguard Worker
110*da0073e9SAndroid Build Coastguard Worker        # indexing 0-length slice
111*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.empty(0, 5, 5), reference[slice(0)])
112*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.empty(0, 5), reference[slice(0), 2])
113*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.empty(0, 5), reference[2, slice(0)])
114*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor([]), reference[2, 1:1, 2])
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker        # indexing with step
117*da0073e9SAndroid Build Coastguard Worker        reference = consec((10, 10, 10)).to(device)
118*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[1:5:2], torch.stack([reference[1], reference[3]], 0))
119*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
120*da0073e9SAndroid Build Coastguard Worker            reference[1:6:2], torch.stack([reference[1], reference[3], reference[5]], 0)
121*da0073e9SAndroid Build Coastguard Worker        )
122*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[1:9:4], torch.stack([reference[1], reference[5]], 0))
123*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
124*da0073e9SAndroid Build Coastguard Worker            reference[2:4, 1:5:2],
125*da0073e9SAndroid Build Coastguard Worker            torch.stack([reference[2:4, 1], reference[2:4, 3]], 1),
126*da0073e9SAndroid Build Coastguard Worker        )
127*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
128*da0073e9SAndroid Build Coastguard Worker            reference[3, 1:6:2],
129*da0073e9SAndroid Build Coastguard Worker            torch.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0),
130*da0073e9SAndroid Build Coastguard Worker        )
131*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
132*da0073e9SAndroid Build Coastguard Worker            reference[None, 2, 1:9:4],
133*da0073e9SAndroid Build Coastguard Worker            torch.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0),
134*da0073e9SAndroid Build Coastguard Worker        )
135*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
136*da0073e9SAndroid Build Coastguard Worker            reference[:, 2, 1:6:2],
137*da0073e9SAndroid Build Coastguard Worker            torch.stack(
138*da0073e9SAndroid Build Coastguard Worker                [reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1
139*da0073e9SAndroid Build Coastguard Worker            ),
140*da0073e9SAndroid Build Coastguard Worker        )
141*da0073e9SAndroid Build Coastguard Worker
142*da0073e9SAndroid Build Coastguard Worker        lst = [list(range(i, i + 10)) for i in range(0, 100, 10)]
143*da0073e9SAndroid Build Coastguard Worker        tensor = torch.DoubleTensor(lst).to(device)
144*da0073e9SAndroid Build Coastguard Worker        for _i in range(100):
145*da0073e9SAndroid Build Coastguard Worker            idx1_start = random.randrange(10)
146*da0073e9SAndroid Build Coastguard Worker            idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1)
147*da0073e9SAndroid Build Coastguard Worker            idx1_step = random.randrange(1, 8)
148*da0073e9SAndroid Build Coastguard Worker            idx1 = slice(idx1_start, idx1_end, idx1_step)
149*da0073e9SAndroid Build Coastguard Worker            if random.randrange(2) == 0:
150*da0073e9SAndroid Build Coastguard Worker                idx2_start = random.randrange(10)
151*da0073e9SAndroid Build Coastguard Worker                idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1)
152*da0073e9SAndroid Build Coastguard Worker                idx2_step = random.randrange(1, 8)
153*da0073e9SAndroid Build Coastguard Worker                idx2 = slice(idx2_start, idx2_end, idx2_step)
154*da0073e9SAndroid Build Coastguard Worker                lst_indexed = [l[idx2] for l in lst[idx1]]
155*da0073e9SAndroid Build Coastguard Worker                tensor_indexed = tensor[idx1, idx2]
156*da0073e9SAndroid Build Coastguard Worker            else:
157*da0073e9SAndroid Build Coastguard Worker                lst_indexed = lst[idx1]
158*da0073e9SAndroid Build Coastguard Worker                tensor_indexed = tensor[idx1]
159*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.DoubleTensor(lst_indexed), tensor_indexed)
160*da0073e9SAndroid Build Coastguard Worker
161*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, lambda: reference[1:9:0])
162*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, lambda: reference[1:9:-1])
163*da0073e9SAndroid Build Coastguard Worker
164*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1])
165*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1])
166*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3])
167*da0073e9SAndroid Build Coastguard Worker
168*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: reference[0.0])
169*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, lambda: reference[0.0:2.0])
170*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0])
171*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0])
172*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0])
173*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0])
174*da0073e9SAndroid Build Coastguard Worker
175*da0073e9SAndroid Build Coastguard Worker        def delitem():
176*da0073e9SAndroid Build Coastguard Worker            del reference[0]
177*da0073e9SAndroid Build Coastguard Worker
178*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, delitem)
179*da0073e9SAndroid Build Coastguard Worker
180*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
181*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half, torch.double)
182*da0073e9SAndroid Build Coastguard Worker    def test_advancedindex(self, device, dtype):
183*da0073e9SAndroid Build Coastguard Worker        # Tests for Integer Array Indexing, Part I - Purely integer array
184*da0073e9SAndroid Build Coastguard Worker        # indexing
185*da0073e9SAndroid Build Coastguard Worker
186*da0073e9SAndroid Build Coastguard Worker        def consec(size, start=1):
187*da0073e9SAndroid Build Coastguard Worker            # Creates the sequence in float since CPU half doesn't support the
188*da0073e9SAndroid Build Coastguard Worker            # needed operations. Converts to dtype before returning.
189*da0073e9SAndroid Build Coastguard Worker            numel = reduce(operator.mul, size, 1)
190*da0073e9SAndroid Build Coastguard Worker            sequence = torch.ones(numel, dtype=torch.float, device=device).cumsum(0)
191*da0073e9SAndroid Build Coastguard Worker            sequence.add_(start - 1)
192*da0073e9SAndroid Build Coastguard Worker            return sequence.view(*size).to(dtype=dtype)
193*da0073e9SAndroid Build Coastguard Worker
194*da0073e9SAndroid Build Coastguard Worker        # pick a random valid indexer type
195*da0073e9SAndroid Build Coastguard Worker        def ri(indices):
196*da0073e9SAndroid Build Coastguard Worker            choice = random.randint(0, 2)
197*da0073e9SAndroid Build Coastguard Worker            if choice == 0:
198*da0073e9SAndroid Build Coastguard Worker                return torch.LongTensor(indices).to(device)
199*da0073e9SAndroid Build Coastguard Worker            elif choice == 1:
200*da0073e9SAndroid Build Coastguard Worker                return list(indices)
201*da0073e9SAndroid Build Coastguard Worker            else:
202*da0073e9SAndroid Build Coastguard Worker                return tuple(indices)
203*da0073e9SAndroid Build Coastguard Worker
204*da0073e9SAndroid Build Coastguard Worker        def validate_indexing(x):
205*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x[[0]], consec((1,)))
206*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x[ri([0]),], consec((1,)))
207*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x[ri([3]),], consec((1,), 4))
208*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x[[2, 3, 4]], consec((3,), 3))
209*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x[ri([2, 3, 4]),], consec((3,), 3))
210*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
211*da0073e9SAndroid Build Coastguard Worker                x[ri([0, 2, 4]),], torch.tensor([1, 3, 5], dtype=dtype, device=device)
212*da0073e9SAndroid Build Coastguard Worker            )
213*da0073e9SAndroid Build Coastguard Worker
214*da0073e9SAndroid Build Coastguard Worker        def validate_setting(x):
215*da0073e9SAndroid Build Coastguard Worker            x[[0]] = -2
216*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x[[0]], torch.tensor([-2], dtype=dtype, device=device))
217*da0073e9SAndroid Build Coastguard Worker            x[[0]] = -1
218*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
219*da0073e9SAndroid Build Coastguard Worker                x[ri([0]),], torch.tensor([-1], dtype=dtype, device=device)
220*da0073e9SAndroid Build Coastguard Worker            )
221*da0073e9SAndroid Build Coastguard Worker            x[[2, 3, 4]] = 4
222*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
223*da0073e9SAndroid Build Coastguard Worker                x[[2, 3, 4]], torch.tensor([4, 4, 4], dtype=dtype, device=device)
224*da0073e9SAndroid Build Coastguard Worker            )
225*da0073e9SAndroid Build Coastguard Worker            x[ri([2, 3, 4]),] = 3
226*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
227*da0073e9SAndroid Build Coastguard Worker                x[ri([2, 3, 4]),], torch.tensor([3, 3, 3], dtype=dtype, device=device)
228*da0073e9SAndroid Build Coastguard Worker            )
229*da0073e9SAndroid Build Coastguard Worker            x[ri([0, 2, 4]),] = torch.tensor([5, 4, 3], dtype=dtype, device=device)
230*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
231*da0073e9SAndroid Build Coastguard Worker                x[ri([0, 2, 4]),], torch.tensor([5, 4, 3], dtype=dtype, device=device)
232*da0073e9SAndroid Build Coastguard Worker            )
233*da0073e9SAndroid Build Coastguard Worker
234*da0073e9SAndroid Build Coastguard Worker        # Only validates indexing and setting for halfs
235*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.half:
236*da0073e9SAndroid Build Coastguard Worker            reference = consec((10,))
237*da0073e9SAndroid Build Coastguard Worker            validate_indexing(reference)
238*da0073e9SAndroid Build Coastguard Worker            validate_setting(reference)
239*da0073e9SAndroid Build Coastguard Worker            return
240*da0073e9SAndroid Build Coastguard Worker
241*da0073e9SAndroid Build Coastguard Worker        # Case 1: Purely Integer Array Indexing
242*da0073e9SAndroid Build Coastguard Worker        reference = consec((10,))
243*da0073e9SAndroid Build Coastguard Worker        validate_indexing(reference)
244*da0073e9SAndroid Build Coastguard Worker
245*da0073e9SAndroid Build Coastguard Worker        # setting values
246*da0073e9SAndroid Build Coastguard Worker        validate_setting(reference)
247*da0073e9SAndroid Build Coastguard Worker
248*da0073e9SAndroid Build Coastguard Worker        # Tensor with stride != 1
249*da0073e9SAndroid Build Coastguard Worker        # strided is [1, 3, 5, 7]
250*da0073e9SAndroid Build Coastguard Worker        reference = consec((10,))
251*da0073e9SAndroid Build Coastguard Worker        strided = torch.tensor((), dtype=dtype, device=device)
252*da0073e9SAndroid Build Coastguard Worker        strided.set_(
253*da0073e9SAndroid Build Coastguard Worker            reference.storage(), storage_offset=0, size=torch.Size([4]), stride=[2]
254*da0073e9SAndroid Build Coastguard Worker        )
255*da0073e9SAndroid Build Coastguard Worker
256*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(strided[[0]], torch.tensor([1], dtype=dtype, device=device))
257*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
258*da0073e9SAndroid Build Coastguard Worker            strided[ri([0]),], torch.tensor([1], dtype=dtype, device=device)
259*da0073e9SAndroid Build Coastguard Worker        )
260*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
261*da0073e9SAndroid Build Coastguard Worker            strided[ri([3]),], torch.tensor([7], dtype=dtype, device=device)
262*da0073e9SAndroid Build Coastguard Worker        )
263*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
264*da0073e9SAndroid Build Coastguard Worker            strided[[1, 2]], torch.tensor([3, 5], dtype=dtype, device=device)
265*da0073e9SAndroid Build Coastguard Worker        )
266*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
267*da0073e9SAndroid Build Coastguard Worker            strided[ri([1, 2]),], torch.tensor([3, 5], dtype=dtype, device=device)
268*da0073e9SAndroid Build Coastguard Worker        )
269*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
270*da0073e9SAndroid Build Coastguard Worker            strided[ri([[2, 1], [0, 3]]),],
271*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[5, 3], [1, 7]], dtype=dtype, device=device),
272*da0073e9SAndroid Build Coastguard Worker        )
273*da0073e9SAndroid Build Coastguard Worker
274*da0073e9SAndroid Build Coastguard Worker        # stride is [4, 8]
275*da0073e9SAndroid Build Coastguard Worker        strided = torch.tensor((), dtype=dtype, device=device)
276*da0073e9SAndroid Build Coastguard Worker        strided.set_(
277*da0073e9SAndroid Build Coastguard Worker            reference.storage(), storage_offset=4, size=torch.Size([2]), stride=[4]
278*da0073e9SAndroid Build Coastguard Worker        )
279*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(strided[[0]], torch.tensor([5], dtype=dtype, device=device))
280*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
281*da0073e9SAndroid Build Coastguard Worker            strided[ri([0]),], torch.tensor([5], dtype=dtype, device=device)
282*da0073e9SAndroid Build Coastguard Worker        )
283*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
284*da0073e9SAndroid Build Coastguard Worker            strided[ri([1]),], torch.tensor([9], dtype=dtype, device=device)
285*da0073e9SAndroid Build Coastguard Worker        )
286*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
287*da0073e9SAndroid Build Coastguard Worker            strided[[0, 1]], torch.tensor([5, 9], dtype=dtype, device=device)
288*da0073e9SAndroid Build Coastguard Worker        )
289*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
290*da0073e9SAndroid Build Coastguard Worker            strided[ri([0, 1]),], torch.tensor([5, 9], dtype=dtype, device=device)
291*da0073e9SAndroid Build Coastguard Worker        )
292*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
293*da0073e9SAndroid Build Coastguard Worker            strided[ri([[0, 1], [1, 0]]),],
294*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[5, 9], [9, 5]], dtype=dtype, device=device),
295*da0073e9SAndroid Build Coastguard Worker        )
296*da0073e9SAndroid Build Coastguard Worker
297*da0073e9SAndroid Build Coastguard Worker        # reference is 1 2
298*da0073e9SAndroid Build Coastguard Worker        #              3 4
299*da0073e9SAndroid Build Coastguard Worker        #              5 6
300*da0073e9SAndroid Build Coastguard Worker        reference = consec((3, 2))
301*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
302*da0073e9SAndroid Build Coastguard Worker            reference[ri([0, 1, 2]), ri([0])],
303*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 3, 5], dtype=dtype, device=device),
304*da0073e9SAndroid Build Coastguard Worker        )
305*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
306*da0073e9SAndroid Build Coastguard Worker            reference[ri([0, 1, 2]), ri([1])],
307*da0073e9SAndroid Build Coastguard Worker            torch.tensor([2, 4, 6], dtype=dtype, device=device),
308*da0073e9SAndroid Build Coastguard Worker        )
309*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[ri([0]), ri([0])], consec((1,)))
310*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[ri([2]), ri([1])], consec((1,), 6))
311*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
312*da0073e9SAndroid Build Coastguard Worker            reference[[ri([0, 0]), ri([0, 1])]],
313*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 2], dtype=dtype, device=device),
314*da0073e9SAndroid Build Coastguard Worker        )
315*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
316*da0073e9SAndroid Build Coastguard Worker            reference[[ri([0, 1, 1, 0, 2]), ri([1])]],
317*da0073e9SAndroid Build Coastguard Worker            torch.tensor([2, 4, 4, 2, 6], dtype=dtype, device=device),
318*da0073e9SAndroid Build Coastguard Worker        )
319*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
320*da0073e9SAndroid Build Coastguard Worker            reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
321*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 2, 3, 3], dtype=dtype, device=device),
322*da0073e9SAndroid Build Coastguard Worker        )
323*da0073e9SAndroid Build Coastguard Worker
324*da0073e9SAndroid Build Coastguard Worker        rows = ri([[0, 0], [1, 2]])
325*da0073e9SAndroid Build Coastguard Worker        columns = ([0],)
326*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
327*da0073e9SAndroid Build Coastguard Worker            reference[rows, columns],
328*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[1, 1], [3, 5]], dtype=dtype, device=device),
329*da0073e9SAndroid Build Coastguard Worker        )
330*da0073e9SAndroid Build Coastguard Worker
331*da0073e9SAndroid Build Coastguard Worker        rows = ri([[0, 0], [1, 2]])
332*da0073e9SAndroid Build Coastguard Worker        columns = ri([1, 0])
333*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
334*da0073e9SAndroid Build Coastguard Worker            reference[rows, columns],
335*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[2, 1], [4, 5]], dtype=dtype, device=device),
336*da0073e9SAndroid Build Coastguard Worker        )
337*da0073e9SAndroid Build Coastguard Worker        rows = ri([[0, 0], [1, 2]])
338*da0073e9SAndroid Build Coastguard Worker        columns = ri([[0, 1], [1, 0]])
339*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
340*da0073e9SAndroid Build Coastguard Worker            reference[rows, columns],
341*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[1, 2], [4, 5]], dtype=dtype, device=device),
342*da0073e9SAndroid Build Coastguard Worker        )
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker        # setting values
345*da0073e9SAndroid Build Coastguard Worker        reference[ri([0]), ri([1])] = -1
346*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
347*da0073e9SAndroid Build Coastguard Worker            reference[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device)
348*da0073e9SAndroid Build Coastguard Worker        )
349*da0073e9SAndroid Build Coastguard Worker        reference[ri([0, 1, 2]), ri([0])] = torch.tensor(
350*da0073e9SAndroid Build Coastguard Worker            [-1, 2, -4], dtype=dtype, device=device
351*da0073e9SAndroid Build Coastguard Worker        )
352*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
353*da0073e9SAndroid Build Coastguard Worker            reference[ri([0, 1, 2]), ri([0])],
354*da0073e9SAndroid Build Coastguard Worker            torch.tensor([-1, 2, -4], dtype=dtype, device=device),
355*da0073e9SAndroid Build Coastguard Worker        )
356*da0073e9SAndroid Build Coastguard Worker        reference[rows, columns] = torch.tensor(
357*da0073e9SAndroid Build Coastguard Worker            [[4, 6], [2, 3]], dtype=dtype, device=device
358*da0073e9SAndroid Build Coastguard Worker        )
359*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
360*da0073e9SAndroid Build Coastguard Worker            reference[rows, columns],
361*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device),
362*da0073e9SAndroid Build Coastguard Worker        )
363*da0073e9SAndroid Build Coastguard Worker
364*da0073e9SAndroid Build Coastguard Worker        # Verify still works with Transposed (i.e. non-contiguous) Tensors
365*da0073e9SAndroid Build Coastguard Worker
366*da0073e9SAndroid Build Coastguard Worker        reference = torch.tensor(
367*da0073e9SAndroid Build Coastguard Worker            [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype, device=device
368*da0073e9SAndroid Build Coastguard Worker        ).t_()
369*da0073e9SAndroid Build Coastguard Worker
370*da0073e9SAndroid Build Coastguard Worker        # Transposed: [[0, 4, 8],
371*da0073e9SAndroid Build Coastguard Worker        #              [1, 5, 9],
372*da0073e9SAndroid Build Coastguard Worker        #              [2, 6, 10],
373*da0073e9SAndroid Build Coastguard Worker        #              [3, 7, 11]]
374*da0073e9SAndroid Build Coastguard Worker
375*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
376*da0073e9SAndroid Build Coastguard Worker            reference[ri([0, 1, 2]), ri([0])],
377*da0073e9SAndroid Build Coastguard Worker            torch.tensor([0, 1, 2], dtype=dtype, device=device),
378*da0073e9SAndroid Build Coastguard Worker        )
379*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
380*da0073e9SAndroid Build Coastguard Worker            reference[ri([0, 1, 2]), ri([1])],
381*da0073e9SAndroid Build Coastguard Worker            torch.tensor([4, 5, 6], dtype=dtype, device=device),
382*da0073e9SAndroid Build Coastguard Worker        )
383*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
384*da0073e9SAndroid Build Coastguard Worker            reference[ri([0]), ri([0])], torch.tensor([0], dtype=dtype, device=device)
385*da0073e9SAndroid Build Coastguard Worker        )
386*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
387*da0073e9SAndroid Build Coastguard Worker            reference[ri([2]), ri([1])], torch.tensor([6], dtype=dtype, device=device)
388*da0073e9SAndroid Build Coastguard Worker        )
389*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
390*da0073e9SAndroid Build Coastguard Worker            reference[[ri([0, 0]), ri([0, 1])]],
391*da0073e9SAndroid Build Coastguard Worker            torch.tensor([0, 4], dtype=dtype, device=device),
392*da0073e9SAndroid Build Coastguard Worker        )
393*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
394*da0073e9SAndroid Build Coastguard Worker            reference[[ri([0, 1, 1, 0, 3]), ri([1])]],
395*da0073e9SAndroid Build Coastguard Worker            torch.tensor([4, 5, 5, 4, 7], dtype=dtype, device=device),
396*da0073e9SAndroid Build Coastguard Worker        )
397*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
398*da0073e9SAndroid Build Coastguard Worker            reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
399*da0073e9SAndroid Build Coastguard Worker            torch.tensor([0, 4, 1, 1], dtype=dtype, device=device),
400*da0073e9SAndroid Build Coastguard Worker        )
401*da0073e9SAndroid Build Coastguard Worker
402*da0073e9SAndroid Build Coastguard Worker        rows = ri([[0, 0], [1, 2]])
403*da0073e9SAndroid Build Coastguard Worker        columns = ([0],)
404*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
405*da0073e9SAndroid Build Coastguard Worker            reference[rows, columns],
406*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[0, 0], [1, 2]], dtype=dtype, device=device),
407*da0073e9SAndroid Build Coastguard Worker        )
408*da0073e9SAndroid Build Coastguard Worker
409*da0073e9SAndroid Build Coastguard Worker        rows = ri([[0, 0], [1, 2]])
410*da0073e9SAndroid Build Coastguard Worker        columns = ri([1, 0])
411*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
412*da0073e9SAndroid Build Coastguard Worker            reference[rows, columns],
413*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[4, 0], [5, 2]], dtype=dtype, device=device),
414*da0073e9SAndroid Build Coastguard Worker        )
415*da0073e9SAndroid Build Coastguard Worker        rows = ri([[0, 0], [1, 3]])
416*da0073e9SAndroid Build Coastguard Worker        columns = ri([[0, 1], [1, 2]])
417*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
418*da0073e9SAndroid Build Coastguard Worker            reference[rows, columns],
419*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[0, 4], [5, 11]], dtype=dtype, device=device),
420*da0073e9SAndroid Build Coastguard Worker        )
421*da0073e9SAndroid Build Coastguard Worker
422*da0073e9SAndroid Build Coastguard Worker        # setting values
423*da0073e9SAndroid Build Coastguard Worker        reference[ri([0]), ri([1])] = -1
424*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
425*da0073e9SAndroid Build Coastguard Worker            reference[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device)
426*da0073e9SAndroid Build Coastguard Worker        )
427*da0073e9SAndroid Build Coastguard Worker        reference[ri([0, 1, 2]), ri([0])] = torch.tensor(
428*da0073e9SAndroid Build Coastguard Worker            [-1, 2, -4], dtype=dtype, device=device
429*da0073e9SAndroid Build Coastguard Worker        )
430*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
431*da0073e9SAndroid Build Coastguard Worker            reference[ri([0, 1, 2]), ri([0])],
432*da0073e9SAndroid Build Coastguard Worker            torch.tensor([-1, 2, -4], dtype=dtype, device=device),
433*da0073e9SAndroid Build Coastguard Worker        )
434*da0073e9SAndroid Build Coastguard Worker        reference[rows, columns] = torch.tensor(
435*da0073e9SAndroid Build Coastguard Worker            [[4, 6], [2, 3]], dtype=dtype, device=device
436*da0073e9SAndroid Build Coastguard Worker        )
437*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
438*da0073e9SAndroid Build Coastguard Worker            reference[rows, columns],
439*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device),
440*da0073e9SAndroid Build Coastguard Worker        )
441*da0073e9SAndroid Build Coastguard Worker
442*da0073e9SAndroid Build Coastguard Worker        # stride != 1
443*da0073e9SAndroid Build Coastguard Worker
444*da0073e9SAndroid Build Coastguard Worker        # strided is [[1 3 5 7],
445*da0073e9SAndroid Build Coastguard Worker        #             [9 11 13 15]]
446*da0073e9SAndroid Build Coastguard Worker
447*da0073e9SAndroid Build Coastguard Worker        reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
448*da0073e9SAndroid Build Coastguard Worker        strided = torch.tensor((), dtype=dtype, device=device)
449*da0073e9SAndroid Build Coastguard Worker        strided.set_(reference.storage(), 1, size=torch.Size([2, 4]), stride=[8, 2])
450*da0073e9SAndroid Build Coastguard Worker
451*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
452*da0073e9SAndroid Build Coastguard Worker            strided[ri([0, 1]), ri([0])],
453*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 9], dtype=dtype, device=device),
454*da0073e9SAndroid Build Coastguard Worker        )
455*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
456*da0073e9SAndroid Build Coastguard Worker            strided[ri([0, 1]), ri([1])],
457*da0073e9SAndroid Build Coastguard Worker            torch.tensor([3, 11], dtype=dtype, device=device),
458*da0073e9SAndroid Build Coastguard Worker        )
459*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
460*da0073e9SAndroid Build Coastguard Worker            strided[ri([0]), ri([0])], torch.tensor([1], dtype=dtype, device=device)
461*da0073e9SAndroid Build Coastguard Worker        )
462*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
463*da0073e9SAndroid Build Coastguard Worker            strided[ri([1]), ri([3])], torch.tensor([15], dtype=dtype, device=device)
464*da0073e9SAndroid Build Coastguard Worker        )
465*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
466*da0073e9SAndroid Build Coastguard Worker            strided[[ri([0, 0]), ri([0, 3])]],
467*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 7], dtype=dtype, device=device),
468*da0073e9SAndroid Build Coastguard Worker        )
469*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
470*da0073e9SAndroid Build Coastguard Worker            strided[[ri([1]), ri([0, 1, 1, 0, 3])]],
471*da0073e9SAndroid Build Coastguard Worker            torch.tensor([9, 11, 11, 9, 15], dtype=dtype, device=device),
472*da0073e9SAndroid Build Coastguard Worker        )
473*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
474*da0073e9SAndroid Build Coastguard Worker            strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
475*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 3, 9, 9], dtype=dtype, device=device),
476*da0073e9SAndroid Build Coastguard Worker        )
477*da0073e9SAndroid Build Coastguard Worker
478*da0073e9SAndroid Build Coastguard Worker        rows = ri([[0, 0], [1, 1]])
479*da0073e9SAndroid Build Coastguard Worker        columns = ([0],)
480*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
481*da0073e9SAndroid Build Coastguard Worker            strided[rows, columns],
482*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[1, 1], [9, 9]], dtype=dtype, device=device),
483*da0073e9SAndroid Build Coastguard Worker        )
484*da0073e9SAndroid Build Coastguard Worker
485*da0073e9SAndroid Build Coastguard Worker        rows = ri([[0, 1], [1, 0]])
486*da0073e9SAndroid Build Coastguard Worker        columns = ri([1, 2])
487*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
488*da0073e9SAndroid Build Coastguard Worker            strided[rows, columns],
489*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[3, 13], [11, 5]], dtype=dtype, device=device),
490*da0073e9SAndroid Build Coastguard Worker        )
491*da0073e9SAndroid Build Coastguard Worker        rows = ri([[0, 0], [1, 1]])
492*da0073e9SAndroid Build Coastguard Worker        columns = ri([[0, 1], [1, 2]])
493*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
494*da0073e9SAndroid Build Coastguard Worker            strided[rows, columns],
495*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[1, 3], [11, 13]], dtype=dtype, device=device),
496*da0073e9SAndroid Build Coastguard Worker        )
497*da0073e9SAndroid Build Coastguard Worker
498*da0073e9SAndroid Build Coastguard Worker        # setting values
499*da0073e9SAndroid Build Coastguard Worker
500*da0073e9SAndroid Build Coastguard Worker        # strided is [[10, 11],
501*da0073e9SAndroid Build Coastguard Worker        #             [17, 18]]
502*da0073e9SAndroid Build Coastguard Worker
503*da0073e9SAndroid Build Coastguard Worker        reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
504*da0073e9SAndroid Build Coastguard Worker        strided = torch.tensor((), dtype=dtype, device=device)
505*da0073e9SAndroid Build Coastguard Worker        strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1])
506*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
507*da0073e9SAndroid Build Coastguard Worker            strided[ri([0]), ri([1])], torch.tensor([11], dtype=dtype, device=device)
508*da0073e9SAndroid Build Coastguard Worker        )
509*da0073e9SAndroid Build Coastguard Worker        strided[ri([0]), ri([1])] = -1
510*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
511*da0073e9SAndroid Build Coastguard Worker            strided[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device)
512*da0073e9SAndroid Build Coastguard Worker        )
513*da0073e9SAndroid Build Coastguard Worker
514*da0073e9SAndroid Build Coastguard Worker        reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
515*da0073e9SAndroid Build Coastguard Worker        strided = torch.tensor((), dtype=dtype, device=device)
516*da0073e9SAndroid Build Coastguard Worker        strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1])
517*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
518*da0073e9SAndroid Build Coastguard Worker            strided[ri([0, 1]), ri([1, 0])],
519*da0073e9SAndroid Build Coastguard Worker            torch.tensor([11, 17], dtype=dtype, device=device),
520*da0073e9SAndroid Build Coastguard Worker        )
521*da0073e9SAndroid Build Coastguard Worker        strided[ri([0, 1]), ri([1, 0])] = torch.tensor(
522*da0073e9SAndroid Build Coastguard Worker            [-1, 2], dtype=dtype, device=device
523*da0073e9SAndroid Build Coastguard Worker        )
524*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
525*da0073e9SAndroid Build Coastguard Worker            strided[ri([0, 1]), ri([1, 0])],
526*da0073e9SAndroid Build Coastguard Worker            torch.tensor([-1, 2], dtype=dtype, device=device),
527*da0073e9SAndroid Build Coastguard Worker        )
528*da0073e9SAndroid Build Coastguard Worker
529*da0073e9SAndroid Build Coastguard Worker        reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
530*da0073e9SAndroid Build Coastguard Worker        strided = torch.tensor((), dtype=dtype, device=device)
531*da0073e9SAndroid Build Coastguard Worker        strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1])
532*da0073e9SAndroid Build Coastguard Worker
533*da0073e9SAndroid Build Coastguard Worker        rows = ri([[0], [1]])
534*da0073e9SAndroid Build Coastguard Worker        columns = ri([[0, 1], [0, 1]])
535*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
536*da0073e9SAndroid Build Coastguard Worker            strided[rows, columns],
537*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[10, 11], [17, 18]], dtype=dtype, device=device),
538*da0073e9SAndroid Build Coastguard Worker        )
539*da0073e9SAndroid Build Coastguard Worker        strided[rows, columns] = torch.tensor(
540*da0073e9SAndroid Build Coastguard Worker            [[4, 6], [2, 3]], dtype=dtype, device=device
541*da0073e9SAndroid Build Coastguard Worker        )
542*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
543*da0073e9SAndroid Build Coastguard Worker            strided[rows, columns],
544*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device),
545*da0073e9SAndroid Build Coastguard Worker        )
546*da0073e9SAndroid Build Coastguard Worker
547*da0073e9SAndroid Build Coastguard Worker        # Tests using less than the number of dims, and ellipsis
548*da0073e9SAndroid Build Coastguard Worker
549*da0073e9SAndroid Build Coastguard Worker        # reference is 1 2
550*da0073e9SAndroid Build Coastguard Worker        #              3 4
551*da0073e9SAndroid Build Coastguard Worker        #              5 6
552*da0073e9SAndroid Build Coastguard Worker        reference = consec((3, 2))
553*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
554*da0073e9SAndroid Build Coastguard Worker            reference[ri([0, 2]),],
555*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[1, 2], [5, 6]], dtype=dtype, device=device),
556*da0073e9SAndroid Build Coastguard Worker        )
557*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
558*da0073e9SAndroid Build Coastguard Worker            reference[ri([1]), ...], torch.tensor([[3, 4]], dtype=dtype, device=device)
559*da0073e9SAndroid Build Coastguard Worker        )
560*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
561*da0073e9SAndroid Build Coastguard Worker            reference[..., ri([1])],
562*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[2], [4], [6]], dtype=dtype, device=device),
563*da0073e9SAndroid Build Coastguard Worker        )
564*da0073e9SAndroid Build Coastguard Worker
565*da0073e9SAndroid Build Coastguard Worker        # verify too many indices fails
566*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
567*da0073e9SAndroid Build Coastguard Worker            reference[ri([1]), ri([0, 2]), ri([3])]
568*da0073e9SAndroid Build Coastguard Worker
569*da0073e9SAndroid Build Coastguard Worker        # test invalid index fails
570*da0073e9SAndroid Build Coastguard Worker        reference = torch.empty(10, dtype=dtype, device=device)
571*da0073e9SAndroid Build Coastguard Worker        # can't test cuda because it is a device assert
572*da0073e9SAndroid Build Coastguard Worker        if not reference.is_cuda:
573*da0073e9SAndroid Build Coastguard Worker            for err_idx in (10, -11):
574*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(IndexError, r"out of"):
575*da0073e9SAndroid Build Coastguard Worker                    reference[err_idx]
576*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(IndexError, r"out of"):
577*da0073e9SAndroid Build Coastguard Worker                    reference[torch.LongTensor([err_idx]).to(device)]
578*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(IndexError, r"out of"):
579*da0073e9SAndroid Build Coastguard Worker                    reference[[err_idx]]
580*da0073e9SAndroid Build Coastguard Worker
581*da0073e9SAndroid Build Coastguard Worker        def tensor_indices_to_np(tensor, indices):
582*da0073e9SAndroid Build Coastguard Worker            # convert the Torch Tensor to a numpy array
583*da0073e9SAndroid Build Coastguard Worker            tensor = tensor.to(device="cpu")
584*da0073e9SAndroid Build Coastguard Worker            npt = tensor.numpy()
585*da0073e9SAndroid Build Coastguard Worker
586*da0073e9SAndroid Build Coastguard Worker            # convert indices
587*da0073e9SAndroid Build Coastguard Worker            idxs = tuple(
588*da0073e9SAndroid Build Coastguard Worker                i.tolist() if isinstance(i, torch.LongTensor) else i for i in indices
589*da0073e9SAndroid Build Coastguard Worker            )
590*da0073e9SAndroid Build Coastguard Worker
591*da0073e9SAndroid Build Coastguard Worker            return npt, idxs
592*da0073e9SAndroid Build Coastguard Worker
593*da0073e9SAndroid Build Coastguard Worker        def get_numpy(tensor, indices):
594*da0073e9SAndroid Build Coastguard Worker            npt, idxs = tensor_indices_to_np(tensor, indices)
595*da0073e9SAndroid Build Coastguard Worker
596*da0073e9SAndroid Build Coastguard Worker            # index and return as a Torch Tensor
597*da0073e9SAndroid Build Coastguard Worker            return torch.tensor(npt[idxs], dtype=dtype, device=device)
598*da0073e9SAndroid Build Coastguard Worker
599*da0073e9SAndroid Build Coastguard Worker        def set_numpy(tensor, indices, value):
600*da0073e9SAndroid Build Coastguard Worker            if not isinstance(value, int):
601*da0073e9SAndroid Build Coastguard Worker                if self.device_type != "cpu":
602*da0073e9SAndroid Build Coastguard Worker                    value = value.cpu()
603*da0073e9SAndroid Build Coastguard Worker                value = value.numpy()
604*da0073e9SAndroid Build Coastguard Worker
605*da0073e9SAndroid Build Coastguard Worker            npt, idxs = tensor_indices_to_np(tensor, indices)
606*da0073e9SAndroid Build Coastguard Worker            npt[idxs] = value
607*da0073e9SAndroid Build Coastguard Worker            return npt
608*da0073e9SAndroid Build Coastguard Worker
609*da0073e9SAndroid Build Coastguard Worker        def assert_get_eq(tensor, indexer):
610*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor[indexer], get_numpy(tensor, indexer))
611*da0073e9SAndroid Build Coastguard Worker
612*da0073e9SAndroid Build Coastguard Worker        def assert_set_eq(tensor, indexer, val):
613*da0073e9SAndroid Build Coastguard Worker            pyt = tensor.clone()
614*da0073e9SAndroid Build Coastguard Worker            numt = tensor.clone()
615*da0073e9SAndroid Build Coastguard Worker            pyt[indexer] = val
616*da0073e9SAndroid Build Coastguard Worker            numt = torch.tensor(
617*da0073e9SAndroid Build Coastguard Worker                set_numpy(numt, indexer, val), dtype=dtype, device=device
618*da0073e9SAndroid Build Coastguard Worker            )
619*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(pyt, numt)
620*da0073e9SAndroid Build Coastguard Worker
621*da0073e9SAndroid Build Coastguard Worker        def assert_backward_eq(tensor, indexer):
622*da0073e9SAndroid Build Coastguard Worker            cpu = tensor.float().clone().detach().requires_grad_(True)
623*da0073e9SAndroid Build Coastguard Worker            outcpu = cpu[indexer]
624*da0073e9SAndroid Build Coastguard Worker            gOcpu = torch.rand_like(outcpu)
625*da0073e9SAndroid Build Coastguard Worker            outcpu.backward(gOcpu)
626*da0073e9SAndroid Build Coastguard Worker            dev = cpu.to(device).detach().requires_grad_(True)
627*da0073e9SAndroid Build Coastguard Worker            outdev = dev[indexer]
628*da0073e9SAndroid Build Coastguard Worker            outdev.backward(gOcpu.to(device))
629*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu.grad, dev.grad)
630*da0073e9SAndroid Build Coastguard Worker
631*da0073e9SAndroid Build Coastguard Worker        def get_set_tensor(indexed, indexer):
632*da0073e9SAndroid Build Coastguard Worker            set_size = indexed[indexer].size()
633*da0073e9SAndroid Build Coastguard Worker            set_count = indexed[indexer].numel()
634*da0073e9SAndroid Build Coastguard Worker            set_tensor = torch.randperm(set_count).view(set_size).double().to(device)
635*da0073e9SAndroid Build Coastguard Worker            return set_tensor
636*da0073e9SAndroid Build Coastguard Worker
637*da0073e9SAndroid Build Coastguard Worker        # Tensor is  0  1  2  3  4
638*da0073e9SAndroid Build Coastguard Worker        #            5  6  7  8  9
639*da0073e9SAndroid Build Coastguard Worker        #           10 11 12 13 14
640*da0073e9SAndroid Build Coastguard Worker        #           15 16 17 18 19
641*da0073e9SAndroid Build Coastguard Worker        reference = torch.arange(0.0, 20, dtype=dtype, device=device).view(4, 5)
642*da0073e9SAndroid Build Coastguard Worker
643*da0073e9SAndroid Build Coastguard Worker        indices_to_test = [
644*da0073e9SAndroid Build Coastguard Worker            # grab the second, fourth columns
645*da0073e9SAndroid Build Coastguard Worker            [slice(None), [1, 3]],
646*da0073e9SAndroid Build Coastguard Worker            # first, third rows,
647*da0073e9SAndroid Build Coastguard Worker            [[0, 2], slice(None)],
648*da0073e9SAndroid Build Coastguard Worker            # weird shape
649*da0073e9SAndroid Build Coastguard Worker            [slice(None), [[0, 1], [2, 3]]],
650*da0073e9SAndroid Build Coastguard Worker            # negatives
651*da0073e9SAndroid Build Coastguard Worker            [[-1], [0]],
652*da0073e9SAndroid Build Coastguard Worker            [[0, 2], [-1]],
653*da0073e9SAndroid Build Coastguard Worker            [slice(None), [-1]],
654*da0073e9SAndroid Build Coastguard Worker        ]
655*da0073e9SAndroid Build Coastguard Worker
656*da0073e9SAndroid Build Coastguard Worker        # only test dupes on gets
657*da0073e9SAndroid Build Coastguard Worker        get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]]
658*da0073e9SAndroid Build Coastguard Worker
659*da0073e9SAndroid Build Coastguard Worker        for indexer in get_indices_to_test:
660*da0073e9SAndroid Build Coastguard Worker            assert_get_eq(reference, indexer)
661*da0073e9SAndroid Build Coastguard Worker            if self.device_type != "cpu":
662*da0073e9SAndroid Build Coastguard Worker                assert_backward_eq(reference, indexer)
663*da0073e9SAndroid Build Coastguard Worker
664*da0073e9SAndroid Build Coastguard Worker        for indexer in indices_to_test:
665*da0073e9SAndroid Build Coastguard Worker            assert_set_eq(reference, indexer, 44)
666*da0073e9SAndroid Build Coastguard Worker            assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
667*da0073e9SAndroid Build Coastguard Worker
668*da0073e9SAndroid Build Coastguard Worker        reference = torch.arange(0.0, 160, dtype=dtype, device=device).view(4, 8, 5)
669*da0073e9SAndroid Build Coastguard Worker
670*da0073e9SAndroid Build Coastguard Worker        indices_to_test = [
671*da0073e9SAndroid Build Coastguard Worker            [slice(None), slice(None), [0, 3, 4]],
672*da0073e9SAndroid Build Coastguard Worker            [slice(None), [2, 4, 5, 7], slice(None)],
673*da0073e9SAndroid Build Coastguard Worker            [[2, 3], slice(None), slice(None)],
674*da0073e9SAndroid Build Coastguard Worker            [slice(None), [0, 2, 3], [1, 3, 4]],
675*da0073e9SAndroid Build Coastguard Worker            [slice(None), [0], [1, 2, 4]],
676*da0073e9SAndroid Build Coastguard Worker            [slice(None), [0, 1, 3], [4]],
677*da0073e9SAndroid Build Coastguard Worker            [slice(None), [[0, 1], [1, 0]], [[2, 3]]],
678*da0073e9SAndroid Build Coastguard Worker            [slice(None), [[0, 1], [2, 3]], [[0]]],
679*da0073e9SAndroid Build Coastguard Worker            [slice(None), [[5, 6]], [[0, 3], [4, 4]]],
680*da0073e9SAndroid Build Coastguard Worker            [[0, 2, 3], [1, 3, 4], slice(None)],
681*da0073e9SAndroid Build Coastguard Worker            [[0], [1, 2, 4], slice(None)],
682*da0073e9SAndroid Build Coastguard Worker            [[0, 1, 3], [4], slice(None)],
683*da0073e9SAndroid Build Coastguard Worker            [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
684*da0073e9SAndroid Build Coastguard Worker            [[[0, 1], [1, 0]], [[2, 3]], slice(None)],
685*da0073e9SAndroid Build Coastguard Worker            [[[0, 1], [2, 3]], [[0]], slice(None)],
686*da0073e9SAndroid Build Coastguard Worker            [[[2, 1]], [[0, 3], [4, 4]], slice(None)],
687*da0073e9SAndroid Build Coastguard Worker            [[[2]], [[0, 3], [4, 1]], slice(None)],
688*da0073e9SAndroid Build Coastguard Worker            # non-contiguous indexing subspace
689*da0073e9SAndroid Build Coastguard Worker            [[0, 2, 3], slice(None), [1, 3, 4]],
690*da0073e9SAndroid Build Coastguard Worker            # [...]
691*da0073e9SAndroid Build Coastguard Worker            # less dim, ellipsis
692*da0073e9SAndroid Build Coastguard Worker            [[0, 2]],
693*da0073e9SAndroid Build Coastguard Worker            [[0, 2], slice(None)],
694*da0073e9SAndroid Build Coastguard Worker            [[0, 2], Ellipsis],
695*da0073e9SAndroid Build Coastguard Worker            [[0, 2], slice(None), Ellipsis],
696*da0073e9SAndroid Build Coastguard Worker            [[0, 2], Ellipsis, slice(None)],
697*da0073e9SAndroid Build Coastguard Worker            [[0, 2], [1, 3]],
698*da0073e9SAndroid Build Coastguard Worker            [[0, 2], [1, 3], Ellipsis],
699*da0073e9SAndroid Build Coastguard Worker            [Ellipsis, [1, 3], [2, 3]],
700*da0073e9SAndroid Build Coastguard Worker            [Ellipsis, [2, 3, 4]],
701*da0073e9SAndroid Build Coastguard Worker            [Ellipsis, slice(None), [2, 3, 4]],
702*da0073e9SAndroid Build Coastguard Worker            [slice(None), Ellipsis, [2, 3, 4]],
703*da0073e9SAndroid Build Coastguard Worker            # ellipsis counts for nothing
704*da0073e9SAndroid Build Coastguard Worker            [Ellipsis, slice(None), slice(None), [0, 3, 4]],
705*da0073e9SAndroid Build Coastguard Worker            [slice(None), Ellipsis, slice(None), [0, 3, 4]],
706*da0073e9SAndroid Build Coastguard Worker            [slice(None), slice(None), Ellipsis, [0, 3, 4]],
707*da0073e9SAndroid Build Coastguard Worker            [slice(None), slice(None), [0, 3, 4], Ellipsis],
708*da0073e9SAndroid Build Coastguard Worker            [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
709*da0073e9SAndroid Build Coastguard Worker            [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)],
710*da0073e9SAndroid Build Coastguard Worker            [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis],
711*da0073e9SAndroid Build Coastguard Worker        ]
712*da0073e9SAndroid Build Coastguard Worker
713*da0073e9SAndroid Build Coastguard Worker        for indexer in indices_to_test:
714*da0073e9SAndroid Build Coastguard Worker            assert_get_eq(reference, indexer)
715*da0073e9SAndroid Build Coastguard Worker            assert_set_eq(reference, indexer, 212)
716*da0073e9SAndroid Build Coastguard Worker            assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
717*da0073e9SAndroid Build Coastguard Worker            if torch.cuda.is_available():
718*da0073e9SAndroid Build Coastguard Worker                assert_backward_eq(reference, indexer)
719*da0073e9SAndroid Build Coastguard Worker
720*da0073e9SAndroid Build Coastguard Worker        reference = torch.arange(0.0, 1296, dtype=dtype, device=device).view(3, 9, 8, 6)
721*da0073e9SAndroid Build Coastguard Worker
722*da0073e9SAndroid Build Coastguard Worker        indices_to_test = [
723*da0073e9SAndroid Build Coastguard Worker            [slice(None), slice(None), slice(None), [0, 3, 4]],
724*da0073e9SAndroid Build Coastguard Worker            [slice(None), slice(None), [2, 4, 5, 7], slice(None)],
725*da0073e9SAndroid Build Coastguard Worker            [slice(None), [2, 3], slice(None), slice(None)],
726*da0073e9SAndroid Build Coastguard Worker            [[1, 2], slice(None), slice(None), slice(None)],
727*da0073e9SAndroid Build Coastguard Worker            [slice(None), slice(None), [0, 2, 3], [1, 3, 4]],
728*da0073e9SAndroid Build Coastguard Worker            [slice(None), slice(None), [0], [1, 2, 4]],
729*da0073e9SAndroid Build Coastguard Worker            [slice(None), slice(None), [0, 1, 3], [4]],
730*da0073e9SAndroid Build Coastguard Worker            [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]],
731*da0073e9SAndroid Build Coastguard Worker            [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]],
732*da0073e9SAndroid Build Coastguard Worker            [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]],
733*da0073e9SAndroid Build Coastguard Worker            [slice(None), [0, 2, 3], [1, 3, 4], slice(None)],
734*da0073e9SAndroid Build Coastguard Worker            [slice(None), [0], [1, 2, 4], slice(None)],
735*da0073e9SAndroid Build Coastguard Worker            [slice(None), [0, 1, 3], [4], slice(None)],
736*da0073e9SAndroid Build Coastguard Worker            [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)],
737*da0073e9SAndroid Build Coastguard Worker            [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)],
738*da0073e9SAndroid Build Coastguard Worker            [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)],
739*da0073e9SAndroid Build Coastguard Worker            [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)],
740*da0073e9SAndroid Build Coastguard Worker            [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)],
741*da0073e9SAndroid Build Coastguard Worker            [[0, 1, 2], [1, 3, 4], slice(None), slice(None)],
742*da0073e9SAndroid Build Coastguard Worker            [[0], [1, 2, 4], slice(None), slice(None)],
743*da0073e9SAndroid Build Coastguard Worker            [[0, 1, 2], [4], slice(None), slice(None)],
744*da0073e9SAndroid Build Coastguard Worker            [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)],
745*da0073e9SAndroid Build Coastguard Worker            [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)],
746*da0073e9SAndroid Build Coastguard Worker            [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)],
747*da0073e9SAndroid Build Coastguard Worker            [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)],
748*da0073e9SAndroid Build Coastguard Worker            [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],
749*da0073e9SAndroid Build Coastguard Worker            [slice(None), [2, 3, 4], [1, 3, 4], [4]],
750*da0073e9SAndroid Build Coastguard Worker            [slice(None), [0, 1, 3], [4], [1, 3, 4]],
751*da0073e9SAndroid Build Coastguard Worker            [slice(None), [6], [0, 2, 3], [1, 3, 4]],
752*da0073e9SAndroid Build Coastguard Worker            [slice(None), [2, 3, 5], [3], [4]],
753*da0073e9SAndroid Build Coastguard Worker            [slice(None), [0], [4], [1, 3, 4]],
754*da0073e9SAndroid Build Coastguard Worker            [slice(None), [6], [0, 2, 3], [1]],
755*da0073e9SAndroid Build Coastguard Worker            [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],
756*da0073e9SAndroid Build Coastguard Worker            [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)],
757*da0073e9SAndroid Build Coastguard Worker            [[2, 0, 1], [1, 2, 3], [4], slice(None)],
758*da0073e9SAndroid Build Coastguard Worker            [[0, 1, 2], [4], [1, 3, 4], slice(None)],
759*da0073e9SAndroid Build Coastguard Worker            [[0], [0, 2, 3], [1, 3, 4], slice(None)],
760*da0073e9SAndroid Build Coastguard Worker            [[0, 2, 1], [3], [4], slice(None)],
761*da0073e9SAndroid Build Coastguard Worker            [[0], [4], [1, 3, 4], slice(None)],
762*da0073e9SAndroid Build Coastguard Worker            [[1], [0, 2, 3], [1], slice(None)],
763*da0073e9SAndroid Build Coastguard Worker            [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],
764*da0073e9SAndroid Build Coastguard Worker            # less dim, ellipsis
765*da0073e9SAndroid Build Coastguard Worker            [Ellipsis, [0, 3, 4]],
766*da0073e9SAndroid Build Coastguard Worker            [Ellipsis, slice(None), [0, 3, 4]],
767*da0073e9SAndroid Build Coastguard Worker            [Ellipsis, slice(None), slice(None), [0, 3, 4]],
768*da0073e9SAndroid Build Coastguard Worker            [slice(None), Ellipsis, [0, 3, 4]],
769*da0073e9SAndroid Build Coastguard Worker            [slice(None), slice(None), Ellipsis, [0, 3, 4]],
770*da0073e9SAndroid Build Coastguard Worker            [slice(None), [0, 2, 3], [1, 3, 4]],
771*da0073e9SAndroid Build Coastguard Worker            [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],
772*da0073e9SAndroid Build Coastguard Worker            [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],
773*da0073e9SAndroid Build Coastguard Worker            [[0], [1, 2, 4]],
774*da0073e9SAndroid Build Coastguard Worker            [[0], [1, 2, 4], slice(None)],
775*da0073e9SAndroid Build Coastguard Worker            [[0], [1, 2, 4], Ellipsis],
776*da0073e9SAndroid Build Coastguard Worker            [[0], [1, 2, 4], Ellipsis, slice(None)],
777*da0073e9SAndroid Build Coastguard Worker            [[1]],
778*da0073e9SAndroid Build Coastguard Worker            [[0, 2, 1], [3], [4]],
779*da0073e9SAndroid Build Coastguard Worker            [[0, 2, 1], [3], [4], slice(None)],
780*da0073e9SAndroid Build Coastguard Worker            [[0, 2, 1], [3], [4], Ellipsis],
781*da0073e9SAndroid Build Coastguard Worker            [Ellipsis, [0, 2, 1], [3], [4]],
782*da0073e9SAndroid Build Coastguard Worker        ]
783*da0073e9SAndroid Build Coastguard Worker
784*da0073e9SAndroid Build Coastguard Worker        for indexer in indices_to_test:
785*da0073e9SAndroid Build Coastguard Worker            assert_get_eq(reference, indexer)
786*da0073e9SAndroid Build Coastguard Worker            assert_set_eq(reference, indexer, 1333)
787*da0073e9SAndroid Build Coastguard Worker            assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
788*da0073e9SAndroid Build Coastguard Worker        indices_to_test += [
789*da0073e9SAndroid Build Coastguard Worker            [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],
790*da0073e9SAndroid Build Coastguard Worker            [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]],
791*da0073e9SAndroid Build Coastguard Worker        ]
792*da0073e9SAndroid Build Coastguard Worker        for indexer in indices_to_test:
793*da0073e9SAndroid Build Coastguard Worker            assert_get_eq(reference, indexer)
794*da0073e9SAndroid Build Coastguard Worker            assert_set_eq(reference, indexer, 1333)
795*da0073e9SAndroid Build Coastguard Worker            if self.device_type != "cpu":
796*da0073e9SAndroid Build Coastguard Worker                assert_backward_eq(reference, indexer)
797*da0073e9SAndroid Build Coastguard Worker
798*da0073e9SAndroid Build Coastguard Worker    def test_advancedindex_big(self, device):
799*da0073e9SAndroid Build Coastguard Worker        reference = torch.arange(0, 123344, dtype=torch.int, device=device)
800*da0073e9SAndroid Build Coastguard Worker
801*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
802*da0073e9SAndroid Build Coastguard Worker            reference[[0, 123, 44488, 68807, 123343],],
803*da0073e9SAndroid Build Coastguard Worker            torch.tensor([0, 123, 44488, 68807, 123343], dtype=torch.int),
804*da0073e9SAndroid Build Coastguard Worker        )
805*da0073e9SAndroid Build Coastguard Worker
806*da0073e9SAndroid Build Coastguard Worker    def test_set_item_to_scalar_tensor(self, device):
807*da0073e9SAndroid Build Coastguard Worker        m = random.randint(1, 10)
808*da0073e9SAndroid Build Coastguard Worker        n = random.randint(1, 10)
809*da0073e9SAndroid Build Coastguard Worker        z = torch.randn([m, n], device=device)
810*da0073e9SAndroid Build Coastguard Worker        a = 1.0
811*da0073e9SAndroid Build Coastguard Worker        w = torch.tensor(a, requires_grad=True, device=device)
812*da0073e9SAndroid Build Coastguard Worker        z[:, 0] = w
813*da0073e9SAndroid Build Coastguard Worker        z.sum().backward()
814*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(w.grad, m * a)
815*da0073e9SAndroid Build Coastguard Worker
816*da0073e9SAndroid Build Coastguard Worker    def test_single_int(self, device):
817*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, 7, 3, device=device)
818*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[4].shape, (7, 3))
819*da0073e9SAndroid Build Coastguard Worker
820*da0073e9SAndroid Build Coastguard Worker    def test_multiple_int(self, device):
821*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, 7, 3, device=device)
822*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[4].shape, (7, 3))
823*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[4, :, 1].shape, (7,))
824*da0073e9SAndroid Build Coastguard Worker
825*da0073e9SAndroid Build Coastguard Worker    def test_none(self, device):
826*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, 7, 3, device=device)
827*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[None].shape, (1, 5, 7, 3))
828*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[:, None].shape, (5, 1, 7, 3))
829*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[:, None, None].shape, (5, 1, 1, 7, 3))
830*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[..., None].shape, (5, 7, 3, 1))
831*da0073e9SAndroid Build Coastguard Worker
832*da0073e9SAndroid Build Coastguard Worker    def test_step(self, device):
833*da0073e9SAndroid Build Coastguard Worker        v = torch.arange(10, device=device)
834*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[::1], v)
835*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[::2].tolist(), [0, 2, 4, 6, 8])
836*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[::3].tolist(), [0, 3, 6, 9])
837*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[::11].tolist(), [0])
838*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[1:6:2].tolist(), [1, 3, 5])
839*da0073e9SAndroid Build Coastguard Worker
840*da0073e9SAndroid Build Coastguard Worker    def test_step_assignment(self, device):
841*da0073e9SAndroid Build Coastguard Worker        v = torch.zeros(4, 4, device=device)
842*da0073e9SAndroid Build Coastguard Worker        v[0, 1::2] = torch.tensor([3.0, 4.0], device=device)
843*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[0].tolist(), [0, 3, 0, 4])
844*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[1:].sum(), 0)
845*da0073e9SAndroid Build Coastguard Worker
846*da0073e9SAndroid Build Coastguard Worker    def test_bool_indices(self, device):
847*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, 7, 3, device=device)
848*da0073e9SAndroid Build Coastguard Worker        boolIndices = torch.tensor(
849*da0073e9SAndroid Build Coastguard Worker            [True, False, True, True, False], dtype=torch.bool, device=device
850*da0073e9SAndroid Build Coastguard Worker        )
851*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[boolIndices].shape, (3, 7, 3))
852*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[boolIndices], torch.stack([v[0], v[2], v[3]]))
853*da0073e9SAndroid Build Coastguard Worker
854*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([True, False, True], dtype=torch.bool, device=device)
855*da0073e9SAndroid Build Coastguard Worker        boolIndices = torch.tensor(
856*da0073e9SAndroid Build Coastguard Worker            [True, False, False], dtype=torch.bool, device=device
857*da0073e9SAndroid Build Coastguard Worker        )
858*da0073e9SAndroid Build Coastguard Worker        uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8, device=device)
859*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
860*da0073e9SAndroid Build Coastguard Worker            v1 = v[boolIndices]
861*da0073e9SAndroid Build Coastguard Worker            v2 = v[uint8Indices]
862*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(v1.shape, v2.shape)
863*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(v1, v2)
864*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
865*da0073e9SAndroid Build Coastguard Worker                v[boolIndices], tensor([True], dtype=torch.bool, device=device)
866*da0073e9SAndroid Build Coastguard Worker            )
867*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
868*da0073e9SAndroid Build Coastguard Worker
869*da0073e9SAndroid Build Coastguard Worker    def test_bool_indices_accumulate(self, device):
870*da0073e9SAndroid Build Coastguard Worker        mask = torch.zeros(size=(10,), dtype=torch.bool, device=device)
871*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(size=(10, 10), device=device)
872*da0073e9SAndroid Build Coastguard Worker        y.index_put_((mask,), y[mask], accumulate=True)
873*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, torch.ones(size=(10, 10), device=device))
874*da0073e9SAndroid Build Coastguard Worker
875*da0073e9SAndroid Build Coastguard Worker    def test_multiple_bool_indices(self, device):
876*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, 7, 3, device=device)
877*da0073e9SAndroid Build Coastguard Worker        # note: these broadcast together and are transposed to the first dim
878*da0073e9SAndroid Build Coastguard Worker        mask1 = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool, device=device)
879*da0073e9SAndroid Build Coastguard Worker        mask2 = torch.tensor([1, 1, 1], dtype=torch.bool, device=device)
880*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
881*da0073e9SAndroid Build Coastguard Worker
882*da0073e9SAndroid Build Coastguard Worker    def test_byte_mask(self, device):
883*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, 7, 3, device=device)
884*da0073e9SAndroid Build Coastguard Worker        mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
885*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
886*da0073e9SAndroid Build Coastguard Worker            res = v[mask]
887*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.shape, (3, 7, 3))
888*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, torch.stack([v[0], v[2], v[3]]))
889*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
890*da0073e9SAndroid Build Coastguard Worker
891*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([1.0], device=device)
892*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[v == 0], torch.tensor([], device=device))
893*da0073e9SAndroid Build Coastguard Worker
894*da0073e9SAndroid Build Coastguard Worker    def test_byte_mask_accumulate(self, device):
895*da0073e9SAndroid Build Coastguard Worker        mask = torch.zeros(size=(10,), dtype=torch.uint8, device=device)
896*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(size=(10, 10), device=device)
897*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
898*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")
899*da0073e9SAndroid Build Coastguard Worker            y.index_put_((mask,), y[mask], accumulate=True)
900*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, torch.ones(size=(10, 10), device=device))
901*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 2)
902*da0073e9SAndroid Build Coastguard Worker
903*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo(
904*da0073e9SAndroid Build Coastguard Worker        "This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472"
905*da0073e9SAndroid Build Coastguard Worker    )
906*da0073e9SAndroid Build Coastguard Worker    @serialTest(TEST_CUDA)
907*da0073e9SAndroid Build Coastguard Worker    def test_index_put_accumulate_large_tensor(self, device):
908*da0073e9SAndroid Build Coastguard Worker        # This test is for tensors with number of elements >= INT_MAX (2^31 - 1).
909*da0073e9SAndroid Build Coastguard Worker        N = (1 << 31) + 5
910*da0073e9SAndroid Build Coastguard Worker        dt = torch.int8
911*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(N, dtype=dt, device=device)
912*da0073e9SAndroid Build Coastguard Worker        indices = torch.tensor(
913*da0073e9SAndroid Build Coastguard Worker            [-2, 0, -2, -1, 0, -1, 1], device=device, dtype=torch.long
914*da0073e9SAndroid Build Coastguard Worker        )
915*da0073e9SAndroid Build Coastguard Worker        values = torch.tensor([6, 5, 6, 6, 5, 7, 11], dtype=dt, device=device)
916*da0073e9SAndroid Build Coastguard Worker
917*da0073e9SAndroid Build Coastguard Worker        a.index_put_((indices,), values, accumulate=True)
918*da0073e9SAndroid Build Coastguard Worker
919*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[0], 11)
920*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[1], 12)
921*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[2], 1)
922*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[-3], 1)
923*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[-2], 13)
924*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[-1], 14)
925*da0073e9SAndroid Build Coastguard Worker
926*da0073e9SAndroid Build Coastguard Worker        a = torch.ones((2, N), dtype=dt, device=device)
927*da0073e9SAndroid Build Coastguard Worker        indices0 = torch.tensor([0, -1, 0, 1], device=device, dtype=torch.long)
928*da0073e9SAndroid Build Coastguard Worker        indices1 = torch.tensor([-2, -1, 0, 1], device=device, dtype=torch.long)
929*da0073e9SAndroid Build Coastguard Worker        values = torch.tensor([12, 13, 10, 11], dtype=dt, device=device)
930*da0073e9SAndroid Build Coastguard Worker
931*da0073e9SAndroid Build Coastguard Worker        a.index_put_((indices0, indices1), values, accumulate=True)
932*da0073e9SAndroid Build Coastguard Worker
933*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[0, 0], 11)
934*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[0, 1], 1)
935*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[1, 0], 1)
936*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[1, 1], 12)
937*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[:, 2], torch.ones(2, dtype=torch.int8))
938*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[:, -3], torch.ones(2, dtype=torch.int8))
939*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[0, -2], 13)
940*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[1, -2], 1)
941*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[-1, -1], 14)
942*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[0, -1], 1)
943*da0073e9SAndroid Build Coastguard Worker
944*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
945*da0073e9SAndroid Build Coastguard Worker    def test_index_put_accumulate_expanded_values(self, device):
946*da0073e9SAndroid Build Coastguard Worker        # checks the issue with cuda: https://github.com/pytorch/pytorch/issues/39227
947*da0073e9SAndroid Build Coastguard Worker        # and verifies consistency with CPU result
948*da0073e9SAndroid Build Coastguard Worker        t = torch.zeros((5, 2))
949*da0073e9SAndroid Build Coastguard Worker        t_dev = t.to(device)
950*da0073e9SAndroid Build Coastguard Worker        indices = [torch.tensor([0, 1, 2, 3]), torch.tensor([1])]
951*da0073e9SAndroid Build Coastguard Worker        indices_dev = [i.to(device) for i in indices]
952*da0073e9SAndroid Build Coastguard Worker        values0d = torch.tensor(1.0)
953*da0073e9SAndroid Build Coastguard Worker        values1d = torch.tensor([1.0])
954*da0073e9SAndroid Build Coastguard Worker
955*da0073e9SAndroid Build Coastguard Worker        out_cuda = t_dev.index_put_(indices_dev, values0d.to(device), accumulate=True)
956*da0073e9SAndroid Build Coastguard Worker        out_cpu = t.index_put_(indices, values0d, accumulate=True)
957*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_cuda.cpu(), out_cpu)
958*da0073e9SAndroid Build Coastguard Worker
959*da0073e9SAndroid Build Coastguard Worker        out_cuda = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
960*da0073e9SAndroid Build Coastguard Worker        out_cpu = t.index_put_(indices, values1d, accumulate=True)
961*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_cuda.cpu(), out_cpu)
962*da0073e9SAndroid Build Coastguard Worker
963*da0073e9SAndroid Build Coastguard Worker        t = torch.zeros(4, 3, 2)
964*da0073e9SAndroid Build Coastguard Worker        t_dev = t.to(device)
965*da0073e9SAndroid Build Coastguard Worker
966*da0073e9SAndroid Build Coastguard Worker        indices = [
967*da0073e9SAndroid Build Coastguard Worker            torch.tensor([0]),
968*da0073e9SAndroid Build Coastguard Worker            torch.arange(3)[:, None],
969*da0073e9SAndroid Build Coastguard Worker            torch.arange(2)[None, :],
970*da0073e9SAndroid Build Coastguard Worker        ]
971*da0073e9SAndroid Build Coastguard Worker        indices_dev = [i.to(device) for i in indices]
972*da0073e9SAndroid Build Coastguard Worker        values1d = torch.tensor([-1.0, -2.0])
973*da0073e9SAndroid Build Coastguard Worker        values2d = torch.tensor([[-1.0, -2.0]])
974*da0073e9SAndroid Build Coastguard Worker
975*da0073e9SAndroid Build Coastguard Worker        out_cuda = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
976*da0073e9SAndroid Build Coastguard Worker        out_cpu = t.index_put_(indices, values1d, accumulate=True)
977*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_cuda.cpu(), out_cpu)
978*da0073e9SAndroid Build Coastguard Worker
979*da0073e9SAndroid Build Coastguard Worker        out_cuda = t_dev.index_put_(indices_dev, values2d.to(device), accumulate=True)
980*da0073e9SAndroid Build Coastguard Worker        out_cpu = t.index_put_(indices, values2d, accumulate=True)
981*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_cuda.cpu(), out_cpu)
982*da0073e9SAndroid Build Coastguard Worker
983*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
984*da0073e9SAndroid Build Coastguard Worker    def test_index_put_accumulate_non_contiguous(self, device):
985*da0073e9SAndroid Build Coastguard Worker        t = torch.zeros((5, 2, 2))
986*da0073e9SAndroid Build Coastguard Worker        t_dev = t.to(device)
987*da0073e9SAndroid Build Coastguard Worker        t1 = t_dev[:, 0, :]
988*da0073e9SAndroid Build Coastguard Worker        t2 = t[:, 0, :]
989*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(not t1.is_contiguous())
990*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(not t2.is_contiguous())
991*da0073e9SAndroid Build Coastguard Worker
992*da0073e9SAndroid Build Coastguard Worker        indices = [torch.tensor([0, 1])]
993*da0073e9SAndroid Build Coastguard Worker        indices_dev = [i.to(device) for i in indices]
994*da0073e9SAndroid Build Coastguard Worker        value = torch.randn(2, 2)
995*da0073e9SAndroid Build Coastguard Worker        out_cuda = t1.index_put_(indices_dev, value.to(device), accumulate=True)
996*da0073e9SAndroid Build Coastguard Worker        out_cpu = t2.index_put_(indices, value, accumulate=True)
997*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(not t1.is_contiguous())
998*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(not t2.is_contiguous())
999*da0073e9SAndroid Build Coastguard Worker
1000*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_cuda.cpu(), out_cpu)
1001*da0073e9SAndroid Build Coastguard Worker
1002*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1003*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
1004*da0073e9SAndroid Build Coastguard Worker    def test_index_put_accumulate_with_optional_tensors(self, device):
1005*da0073e9SAndroid Build Coastguard Worker        # TODO: replace with a better solution.
1006*da0073e9SAndroid Build Coastguard Worker        # Currently, here using torchscript to put None into indices.
1007*da0073e9SAndroid Build Coastguard Worker        # on C++ it gives indices as a list of 2 optional tensors: first is null and
1008*da0073e9SAndroid Build Coastguard Worker        # the second is a valid tensor.
1009*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
1010*da0073e9SAndroid Build Coastguard Worker        def func(x, i, v):
1011*da0073e9SAndroid Build Coastguard Worker            idx = [None, i]
1012*da0073e9SAndroid Build Coastguard Worker            x.index_put_(idx, v, accumulate=True)
1013*da0073e9SAndroid Build Coastguard Worker            return x
1014*da0073e9SAndroid Build Coastguard Worker
1015*da0073e9SAndroid Build Coastguard Worker        n = 4
1016*da0073e9SAndroid Build Coastguard Worker        t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
1017*da0073e9SAndroid Build Coastguard Worker        t_dev = t.to(device)
1018*da0073e9SAndroid Build Coastguard Worker        indices = torch.tensor([1, 0])
1019*da0073e9SAndroid Build Coastguard Worker        indices_dev = indices.to(device)
1020*da0073e9SAndroid Build Coastguard Worker        value0d = torch.tensor(10.0)
1021*da0073e9SAndroid Build Coastguard Worker        value1d = torch.tensor([1.0, 2.0])
1022*da0073e9SAndroid Build Coastguard Worker
1023*da0073e9SAndroid Build Coastguard Worker        out_cuda = func(t_dev, indices_dev, value0d.cuda())
1024*da0073e9SAndroid Build Coastguard Worker        out_cpu = func(t, indices, value0d)
1025*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_cuda.cpu(), out_cpu)
1026*da0073e9SAndroid Build Coastguard Worker
1027*da0073e9SAndroid Build Coastguard Worker        out_cuda = func(t_dev, indices_dev, value1d.cuda())
1028*da0073e9SAndroid Build Coastguard Worker        out_cpu = func(t, indices, value1d)
1029*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_cuda.cpu(), out_cpu)
1030*da0073e9SAndroid Build Coastguard Worker
1031*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1032*da0073e9SAndroid Build Coastguard Worker    def test_index_put_accumulate_duplicate_indices(self, device):
1033*da0073e9SAndroid Build Coastguard Worker        for i in range(1, 512):
1034*da0073e9SAndroid Build Coastguard Worker            # generate indices by random walk, this will create indices with
1035*da0073e9SAndroid Build Coastguard Worker            # lots of duplicates interleaved with each other
1036*da0073e9SAndroid Build Coastguard Worker            delta = torch.empty(i, dtype=torch.double, device=device).uniform_(-1, 1)
1037*da0073e9SAndroid Build Coastguard Worker            indices = delta.cumsum(0).long()
1038*da0073e9SAndroid Build Coastguard Worker
1039*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(indices.abs().max() + 1, device=device)
1040*da0073e9SAndroid Build Coastguard Worker            values = torch.randn(indices.size(0), device=device)
1041*da0073e9SAndroid Build Coastguard Worker            output = input.index_put((indices,), values, accumulate=True)
1042*da0073e9SAndroid Build Coastguard Worker
1043*da0073e9SAndroid Build Coastguard Worker            input_list = input.tolist()
1044*da0073e9SAndroid Build Coastguard Worker            indices_list = indices.tolist()
1045*da0073e9SAndroid Build Coastguard Worker            values_list = values.tolist()
1046*da0073e9SAndroid Build Coastguard Worker            for i, v in zip(indices_list, values_list):
1047*da0073e9SAndroid Build Coastguard Worker                input_list[i] += v
1048*da0073e9SAndroid Build Coastguard Worker
1049*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output, input_list)
1050*da0073e9SAndroid Build Coastguard Worker
1051*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1052*da0073e9SAndroid Build Coastguard Worker    def test_index_ind_dtype(self, device):
1053*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 4, device=device)
1054*da0073e9SAndroid Build Coastguard Worker        ind_long = torch.randint(4, (4,), dtype=torch.long, device=device)
1055*da0073e9SAndroid Build Coastguard Worker        ind_int = ind_long.int()
1056*da0073e9SAndroid Build Coastguard Worker        src = torch.randn(4, device=device)
1057*da0073e9SAndroid Build Coastguard Worker        ref = x[ind_long, ind_long]
1058*da0073e9SAndroid Build Coastguard Worker        res = x[ind_int, ind_int]
1059*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
1060*da0073e9SAndroid Build Coastguard Worker        ref = x[ind_long, :]
1061*da0073e9SAndroid Build Coastguard Worker        res = x[ind_int, :]
1062*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
1063*da0073e9SAndroid Build Coastguard Worker        ref = x[:, ind_long]
1064*da0073e9SAndroid Build Coastguard Worker        res = x[:, ind_int]
1065*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
1066*da0073e9SAndroid Build Coastguard Worker        # no repeating indices for index_put
1067*da0073e9SAndroid Build Coastguard Worker        ind_long = torch.arange(4, dtype=torch.long, device=device)
1068*da0073e9SAndroid Build Coastguard Worker        ind_int = ind_long.int()
1069*da0073e9SAndroid Build Coastguard Worker        for accum in (True, False):
1070*da0073e9SAndroid Build Coastguard Worker            inp_ref = x.clone()
1071*da0073e9SAndroid Build Coastguard Worker            inp_res = x.clone()
1072*da0073e9SAndroid Build Coastguard Worker            torch.index_put_(inp_ref, (ind_long, ind_long), src, accum)
1073*da0073e9SAndroid Build Coastguard Worker            torch.index_put_(inp_res, (ind_int, ind_int), src, accum)
1074*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(inp_ref, inp_res)
1075*da0073e9SAndroid Build Coastguard Worker
1076*da0073e9SAndroid Build Coastguard Worker    @skipXLA
1077*da0073e9SAndroid Build Coastguard Worker    def test_index_put_accumulate_empty(self, device):
1078*da0073e9SAndroid Build Coastguard Worker        # Regression test for https://github.com/pytorch/pytorch/issues/94667
1079*da0073e9SAndroid Build Coastguard Worker        input = torch.rand([], dtype=torch.float32, device=device)
1080*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
1081*da0073e9SAndroid Build Coastguard Worker            input.index_put([], torch.tensor([1.0], device=device), True)
1082*da0073e9SAndroid Build Coastguard Worker
1083*da0073e9SAndroid Build Coastguard Worker    def test_multiple_byte_mask(self, device):
1084*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, 7, 3, device=device)
1085*da0073e9SAndroid Build Coastguard Worker        # note: these broadcast together and are transposed to the first dim
1086*da0073e9SAndroid Build Coastguard Worker        mask1 = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
1087*da0073e9SAndroid Build Coastguard Worker        mask2 = torch.ByteTensor([1, 1, 1]).to(device)
1088*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
1089*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")
1090*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
1091*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 2)
1092*da0073e9SAndroid Build Coastguard Worker
1093*da0073e9SAndroid Build Coastguard Worker    def test_byte_mask2d(self, device):
1094*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, 7, 3, device=device)
1095*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(5, 7, device=device)
1096*da0073e9SAndroid Build Coastguard Worker        num_ones = (c > 0).sum()
1097*da0073e9SAndroid Build Coastguard Worker        r = v[c > 0]
1098*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.shape, (num_ones, 3))
1099*da0073e9SAndroid Build Coastguard Worker
1100*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
1101*da0073e9SAndroid Build Coastguard Worker    def test_jit_indexing(self, device):
1102*da0073e9SAndroid Build Coastguard Worker        def fn1(x):
1103*da0073e9SAndroid Build Coastguard Worker            x[x < 50] = 1.0
1104*da0073e9SAndroid Build Coastguard Worker            return x
1105*da0073e9SAndroid Build Coastguard Worker
1106*da0073e9SAndroid Build Coastguard Worker        def fn2(x):
1107*da0073e9SAndroid Build Coastguard Worker            x[0:50] = 1.0
1108*da0073e9SAndroid Build Coastguard Worker            return x
1109*da0073e9SAndroid Build Coastguard Worker
1110*da0073e9SAndroid Build Coastguard Worker        scripted_fn1 = torch.jit.script(fn1)
1111*da0073e9SAndroid Build Coastguard Worker        scripted_fn2 = torch.jit.script(fn2)
1112*da0073e9SAndroid Build Coastguard Worker        data = torch.arange(100, device=device, dtype=torch.float)
1113*da0073e9SAndroid Build Coastguard Worker        out = scripted_fn1(data.detach().clone())
1114*da0073e9SAndroid Build Coastguard Worker        ref = torch.tensor(
1115*da0073e9SAndroid Build Coastguard Worker            np.concatenate((np.ones(50), np.arange(50, 100))),
1116*da0073e9SAndroid Build Coastguard Worker            device=device,
1117*da0073e9SAndroid Build Coastguard Worker            dtype=torch.float,
1118*da0073e9SAndroid Build Coastguard Worker        )
1119*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, ref)
1120*da0073e9SAndroid Build Coastguard Worker        out = scripted_fn2(data.detach().clone())
1121*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, ref)
1122*da0073e9SAndroid Build Coastguard Worker
1123*da0073e9SAndroid Build Coastguard Worker    def test_int_indices(self, device):
1124*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, 7, 3, device=device)
1125*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[[0, 4, 2]].shape, (3, 7, 3))
1126*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3))
1127*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))
1128*da0073e9SAndroid Build Coastguard Worker
1129*da0073e9SAndroid Build Coastguard Worker    @dtypes(
1130*da0073e9SAndroid Build Coastguard Worker        torch.cfloat, torch.cdouble, torch.float, torch.bfloat16, torch.long, torch.bool
1131*da0073e9SAndroid Build Coastguard Worker    )
1132*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCPU(
1133*da0073e9SAndroid Build Coastguard Worker        torch.cfloat, torch.cdouble, torch.float, torch.long, torch.bool, torch.bfloat16
1134*da0073e9SAndroid Build Coastguard Worker    )
1135*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(
1136*da0073e9SAndroid Build Coastguard Worker        torch.cfloat,
1137*da0073e9SAndroid Build Coastguard Worker        torch.cdouble,
1138*da0073e9SAndroid Build Coastguard Worker        torch.half,
1139*da0073e9SAndroid Build Coastguard Worker        torch.long,
1140*da0073e9SAndroid Build Coastguard Worker        torch.bool,
1141*da0073e9SAndroid Build Coastguard Worker        torch.bfloat16,
1142*da0073e9SAndroid Build Coastguard Worker        torch.float8_e5m2,
1143*da0073e9SAndroid Build Coastguard Worker        torch.float8_e4m3fn,
1144*da0073e9SAndroid Build Coastguard Worker    )
1145*da0073e9SAndroid Build Coastguard Worker    def test_index_put_src_datatype(self, device, dtype):
1146*da0073e9SAndroid Build Coastguard Worker        src = torch.ones(3, 2, 4, device=device, dtype=dtype)
1147*da0073e9SAndroid Build Coastguard Worker        vals = torch.ones(3, 2, 4, device=device, dtype=dtype)
1148*da0073e9SAndroid Build Coastguard Worker        indices = (torch.tensor([0, 2, 1]),)
1149*da0073e9SAndroid Build Coastguard Worker        res = src.index_put_(indices, vals, accumulate=True)
1150*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res.shape, src.shape)
1151*da0073e9SAndroid Build Coastguard Worker
1152*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.bfloat16, torch.long, torch.bool)
1153*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCPU(torch.float, torch.long, torch.bfloat16, torch.bool)
1154*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.half, torch.long, torch.bfloat16, torch.bool)
1155*da0073e9SAndroid Build Coastguard Worker    def test_index_src_datatype(self, device, dtype):
1156*da0073e9SAndroid Build Coastguard Worker        src = torch.ones(3, 2, 4, device=device, dtype=dtype)
1157*da0073e9SAndroid Build Coastguard Worker        # test index
1158*da0073e9SAndroid Build Coastguard Worker        res = src[[0, 2, 1], :, :]
1159*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res.shape, src.shape)
1160*da0073e9SAndroid Build Coastguard Worker        # test index_put, no accum
1161*da0073e9SAndroid Build Coastguard Worker        src[[0, 2, 1], :, :] = res
1162*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res.shape, src.shape)
1163*da0073e9SAndroid Build Coastguard Worker
1164*da0073e9SAndroid Build Coastguard Worker    def test_int_indices2d(self, device):
1165*da0073e9SAndroid Build Coastguard Worker        # From the NumPy indexing example
1166*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 12, device=device).view(4, 3)
1167*da0073e9SAndroid Build Coastguard Worker        rows = torch.tensor([[0, 0], [3, 3]], device=device)
1168*da0073e9SAndroid Build Coastguard Worker        columns = torch.tensor([[0, 2], [0, 2]], device=device)
1169*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[rows, columns].tolist(), [[0, 2], [9, 11]])
1170*da0073e9SAndroid Build Coastguard Worker
1171*da0073e9SAndroid Build Coastguard Worker    def test_int_indices_broadcast(self, device):
1172*da0073e9SAndroid Build Coastguard Worker        # From the NumPy indexing example
1173*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 12, device=device).view(4, 3)
1174*da0073e9SAndroid Build Coastguard Worker        rows = torch.tensor([0, 3], device=device)
1175*da0073e9SAndroid Build Coastguard Worker        columns = torch.tensor([0, 2], device=device)
1176*da0073e9SAndroid Build Coastguard Worker        result = x[rows[:, None], columns]
1177*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.tolist(), [[0, 2], [9, 11]])
1178*da0073e9SAndroid Build Coastguard Worker
1179*da0073e9SAndroid Build Coastguard Worker    def test_empty_index(self, device):
1180*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 12, device=device).view(4, 3)
1181*da0073e9SAndroid Build Coastguard Worker        idx = torch.tensor([], dtype=torch.long, device=device)
1182*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[idx].numel(), 0)
1183*da0073e9SAndroid Build Coastguard Worker
1184*da0073e9SAndroid Build Coastguard Worker        # empty assignment should have no effect but not throw an exception
1185*da0073e9SAndroid Build Coastguard Worker        y = x.clone()
1186*da0073e9SAndroid Build Coastguard Worker        y[idx] = -1
1187*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, y)
1188*da0073e9SAndroid Build Coastguard Worker
1189*da0073e9SAndroid Build Coastguard Worker        mask = torch.zeros(4, 3, device=device).bool()
1190*da0073e9SAndroid Build Coastguard Worker        y[mask] = -1
1191*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, y)
1192*da0073e9SAndroid Build Coastguard Worker
1193*da0073e9SAndroid Build Coastguard Worker    def test_empty_ndim_index(self, device):
1194*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, device=device)
1195*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1196*da0073e9SAndroid Build Coastguard Worker            torch.empty(0, 2, device=device),
1197*da0073e9SAndroid Build Coastguard Worker            x[torch.empty(0, 2, dtype=torch.int64, device=device)],
1198*da0073e9SAndroid Build Coastguard Worker        )
1199*da0073e9SAndroid Build Coastguard Worker
1200*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, 4, 5, device=device)
1201*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1202*da0073e9SAndroid Build Coastguard Worker            torch.empty(2, 0, 6, 4, 5, device=device),
1203*da0073e9SAndroid Build Coastguard Worker            x[:, torch.empty(0, 6, dtype=torch.int64, device=device)],
1204*da0073e9SAndroid Build Coastguard Worker        )
1205*da0073e9SAndroid Build Coastguard Worker
1206*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(10, 0, device=device)
1207*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[[1, 2]].shape, (2, 0))
1208*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[[], []].shape, (0,))
1209*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(IndexError, "for dimension with size 0"):
1210*da0073e9SAndroid Build Coastguard Worker            x[:, [0, 1]]
1211*da0073e9SAndroid Build Coastguard Worker
1212*da0073e9SAndroid Build Coastguard Worker    def test_empty_ndim_index_bool(self, device):
1213*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, device=device)
1214*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(
1215*da0073e9SAndroid Build Coastguard Worker            IndexError, lambda: x[torch.empty(0, 2, dtype=torch.uint8, device=device)]
1216*da0073e9SAndroid Build Coastguard Worker        )
1217*da0073e9SAndroid Build Coastguard Worker
1218*da0073e9SAndroid Build Coastguard Worker    def test_empty_slice(self, device):
1219*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, 4, 5, device=device)
1220*da0073e9SAndroid Build Coastguard Worker        y = x[:, :, :, 1]
1221*da0073e9SAndroid Build Coastguard Worker        z = y[:, 1:1, :]
1222*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((2, 0, 4), z.shape)
1223*da0073e9SAndroid Build Coastguard Worker        # this isn't technically necessary, but matches NumPy stride calculations.
1224*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((60, 20, 5), z.stride())
1225*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.is_contiguous())
1226*da0073e9SAndroid Build Coastguard Worker
1227*da0073e9SAndroid Build Coastguard Worker    def test_index_getitem_copy_bools_slices(self, device):
1228*da0073e9SAndroid Build Coastguard Worker        true = torch.tensor(1, dtype=torch.uint8, device=device)
1229*da0073e9SAndroid Build Coastguard Worker        false = torch.tensor(0, dtype=torch.uint8, device=device)
1230*da0073e9SAndroid Build Coastguard Worker
1231*da0073e9SAndroid Build Coastguard Worker        tensors = [torch.randn(2, 3, device=device), torch.tensor(3.0, device=device)]
1232*da0073e9SAndroid Build Coastguard Worker
1233*da0073e9SAndroid Build Coastguard Worker        for a in tensors:
1234*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(a.data_ptr(), a[True].data_ptr())
1235*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.empty(0, *a.shape), a[False])
1236*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(a.data_ptr(), a[true].data_ptr())
1237*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.empty(0, *a.shape), a[false])
1238*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.data_ptr(), a[None].data_ptr())
1239*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.data_ptr(), a[...].data_ptr())
1240*da0073e9SAndroid Build Coastguard Worker
1241*da0073e9SAndroid Build Coastguard Worker    def test_index_setitem_bools_slices(self, device):
1242*da0073e9SAndroid Build Coastguard Worker        true = torch.tensor(1, dtype=torch.uint8, device=device)
1243*da0073e9SAndroid Build Coastguard Worker        false = torch.tensor(0, dtype=torch.uint8, device=device)
1244*da0073e9SAndroid Build Coastguard Worker
1245*da0073e9SAndroid Build Coastguard Worker        tensors = [torch.randn(2, 3, device=device), torch.tensor(3, device=device)]
1246*da0073e9SAndroid Build Coastguard Worker
1247*da0073e9SAndroid Build Coastguard Worker        for a in tensors:
1248*da0073e9SAndroid Build Coastguard Worker            # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s
1249*da0073e9SAndroid Build Coastguard Worker            # (some of these ops already prefix a 1 to the size)
1250*da0073e9SAndroid Build Coastguard Worker            neg_ones = torch.ones_like(a) * -1
1251*da0073e9SAndroid Build Coastguard Worker            neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0)
1252*da0073e9SAndroid Build Coastguard Worker            a[True] = neg_ones_expanded
1253*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, neg_ones)
1254*da0073e9SAndroid Build Coastguard Worker            a[False] = 5
1255*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, neg_ones)
1256*da0073e9SAndroid Build Coastguard Worker            a[true] = neg_ones_expanded * 2
1257*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, neg_ones * 2)
1258*da0073e9SAndroid Build Coastguard Worker            a[false] = 5
1259*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, neg_ones * 2)
1260*da0073e9SAndroid Build Coastguard Worker            a[None] = neg_ones_expanded * 3
1261*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, neg_ones * 3)
1262*da0073e9SAndroid Build Coastguard Worker            a[...] = neg_ones_expanded * 4
1263*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, neg_ones * 4)
1264*da0073e9SAndroid Build Coastguard Worker            if a.dim() == 0:
1265*da0073e9SAndroid Build Coastguard Worker                with self.assertRaises(IndexError):
1266*da0073e9SAndroid Build Coastguard Worker                    a[:] = neg_ones_expanded * 5
1267*da0073e9SAndroid Build Coastguard Worker
1268*da0073e9SAndroid Build Coastguard Worker    def test_index_scalar_with_bool_mask(self, device):
1269*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1, device=device)
1270*da0073e9SAndroid Build Coastguard Worker        uintMask = torch.tensor(True, dtype=torch.uint8, device=device)
1271*da0073e9SAndroid Build Coastguard Worker        boolMask = torch.tensor(True, dtype=torch.bool, device=device)
1272*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[uintMask], a[boolMask])
1273*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[uintMask].dtype, a[boolMask].dtype)
1274*da0073e9SAndroid Build Coastguard Worker
1275*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(True, dtype=torch.bool, device=device)
1276*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[uintMask], a[boolMask])
1277*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[uintMask].dtype, a[boolMask].dtype)
1278*da0073e9SAndroid Build Coastguard Worker
1279*da0073e9SAndroid Build Coastguard Worker    def test_setitem_expansion_error(self, device):
1280*da0073e9SAndroid Build Coastguard Worker        true = torch.tensor(True, device=device)
1281*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, device=device)
1282*da0073e9SAndroid Build Coastguard Worker        # check prefix with  non-1s doesn't work
1283*da0073e9SAndroid Build Coastguard Worker        a_expanded = a.expand(torch.Size([5, 1]) + a.size())
1284*da0073e9SAndroid Build Coastguard Worker        # NumPy: ValueError
1285*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
1286*da0073e9SAndroid Build Coastguard Worker            a[True] = a_expanded
1287*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
1288*da0073e9SAndroid Build Coastguard Worker            a[true] = a_expanded
1289*da0073e9SAndroid Build Coastguard Worker
1290*da0073e9SAndroid Build Coastguard Worker    def test_getitem_scalars(self, device):
1291*da0073e9SAndroid Build Coastguard Worker        zero = torch.tensor(0, dtype=torch.int64, device=device)
1292*da0073e9SAndroid Build Coastguard Worker        one = torch.tensor(1, dtype=torch.int64, device=device)
1293*da0073e9SAndroid Build Coastguard Worker
1294*da0073e9SAndroid Build Coastguard Worker        # non-scalar indexed with scalars
1295*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, device=device)
1296*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[0], a[zero])
1297*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[0][1], a[zero][one])
1298*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[0, 1], a[zero, one])
1299*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[0, one], a[zero, 1])
1300*da0073e9SAndroid Build Coastguard Worker
1301*da0073e9SAndroid Build Coastguard Worker        # indexing by a scalar should slice (not copy)
1302*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[0, 1].data_ptr(), a[zero, one].data_ptr())
1303*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[1].data_ptr(), a[one.int()].data_ptr())
1304*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[1].data_ptr(), a[one.short()].data_ptr())
1305*da0073e9SAndroid Build Coastguard Worker
1306*da0073e9SAndroid Build Coastguard Worker        # scalar indexed with scalar
1307*da0073e9SAndroid Build Coastguard Worker        r = torch.randn((), device=device)
1308*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
1309*da0073e9SAndroid Build Coastguard Worker            r[:]
1310*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
1311*da0073e9SAndroid Build Coastguard Worker            r[zero]
1312*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, r[...])
1313*da0073e9SAndroid Build Coastguard Worker
1314*da0073e9SAndroid Build Coastguard Worker    def test_setitem_scalars(self, device):
1315*da0073e9SAndroid Build Coastguard Worker        zero = torch.tensor(0, dtype=torch.int64)
1316*da0073e9SAndroid Build Coastguard Worker
1317*da0073e9SAndroid Build Coastguard Worker        # non-scalar indexed with scalars
1318*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, device=device)
1319*da0073e9SAndroid Build Coastguard Worker        a_set_with_number = a.clone()
1320*da0073e9SAndroid Build Coastguard Worker        a_set_with_scalar = a.clone()
1321*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, device=device)
1322*da0073e9SAndroid Build Coastguard Worker
1323*da0073e9SAndroid Build Coastguard Worker        a_set_with_number[0] = b
1324*da0073e9SAndroid Build Coastguard Worker        a_set_with_scalar[zero] = b
1325*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a_set_with_number, a_set_with_scalar)
1326*da0073e9SAndroid Build Coastguard Worker        a[1, zero] = 7.7
1327*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(7.7, a[1, 0])
1328*da0073e9SAndroid Build Coastguard Worker
1329*da0073e9SAndroid Build Coastguard Worker        # scalar indexed with scalars
1330*da0073e9SAndroid Build Coastguard Worker        r = torch.randn((), device=device)
1331*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
1332*da0073e9SAndroid Build Coastguard Worker            r[:] = 8.8
1333*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
1334*da0073e9SAndroid Build Coastguard Worker            r[zero] = 8.8
1335*da0073e9SAndroid Build Coastguard Worker        r[...] = 9.9
1336*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(9.9, r)
1337*da0073e9SAndroid Build Coastguard Worker
1338*da0073e9SAndroid Build Coastguard Worker    def test_basic_advanced_combined(self, device):
1339*da0073e9SAndroid Build Coastguard Worker        # From the NumPy indexing example
1340*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 12, device=device).view(4, 3)
1341*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[1:2, 1:3], x[1:2, [1, 2]])
1342*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[1:2, 1:3].tolist(), [[4, 5]])
1343*da0073e9SAndroid Build Coastguard Worker
1344*da0073e9SAndroid Build Coastguard Worker        # Check that it is a copy
1345*da0073e9SAndroid Build Coastguard Worker        unmodified = x.clone()
1346*da0073e9SAndroid Build Coastguard Worker        x[1:2, [1, 2]].zero_()
1347*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, unmodified)
1348*da0073e9SAndroid Build Coastguard Worker
1349*da0073e9SAndroid Build Coastguard Worker        # But assignment should modify the original
1350*da0073e9SAndroid Build Coastguard Worker        unmodified = x.clone()
1351*da0073e9SAndroid Build Coastguard Worker        x[1:2, [1, 2]] = 0
1352*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(x, unmodified)
1353*da0073e9SAndroid Build Coastguard Worker
1354*da0073e9SAndroid Build Coastguard Worker    def test_int_assignment(self, device):
1355*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 4, device=device).view(2, 2)
1356*da0073e9SAndroid Build Coastguard Worker        x[1] = 5
1357*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.tolist(), [[0, 1], [5, 5]])
1358*da0073e9SAndroid Build Coastguard Worker
1359*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 4, device=device).view(2, 2)
1360*da0073e9SAndroid Build Coastguard Worker        x[1] = torch.arange(5, 7, device=device)
1361*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.tolist(), [[0, 1], [5, 6]])
1362*da0073e9SAndroid Build Coastguard Worker
1363*da0073e9SAndroid Build Coastguard Worker    def test_byte_tensor_assignment(self, device):
1364*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0.0, 16, device=device).view(4, 4)
1365*da0073e9SAndroid Build Coastguard Worker        b = torch.ByteTensor([True, False, True, False]).to(device)
1366*da0073e9SAndroid Build Coastguard Worker        value = torch.tensor([3.0, 4.0, 5.0, 6.0], device=device)
1367*da0073e9SAndroid Build Coastguard Worker
1368*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
1369*da0073e9SAndroid Build Coastguard Worker            x[b] = value
1370*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
1371*da0073e9SAndroid Build Coastguard Worker
1372*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[0], value)
1373*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[1], torch.arange(4.0, 8, device=device))
1374*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[2], value)
1375*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[3], torch.arange(12.0, 16, device=device))
1376*da0073e9SAndroid Build Coastguard Worker
1377*da0073e9SAndroid Build Coastguard Worker    def test_variable_slicing(self, device):
1378*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 16, device=device).view(4, 4)
1379*da0073e9SAndroid Build Coastguard Worker        indices = torch.IntTensor([0, 1]).to(device)
1380*da0073e9SAndroid Build Coastguard Worker        i, j = indices
1381*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[i:j], x[0:1])
1382*da0073e9SAndroid Build Coastguard Worker
1383*da0073e9SAndroid Build Coastguard Worker    def test_ellipsis_tensor(self, device):
1384*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 9, device=device).view(3, 3)
1385*da0073e9SAndroid Build Coastguard Worker        idx = torch.tensor([0, 2], device=device)
1386*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[..., idx].tolist(), [[0, 2], [3, 5], [6, 8]])
1387*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[idx, ...].tolist(), [[0, 1, 2], [6, 7, 8]])
1388*da0073e9SAndroid Build Coastguard Worker
1389*da0073e9SAndroid Build Coastguard Worker    def test_unravel_index_errors(self, device):
1390*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, r"expected 'indices' to be integer"):
1391*da0073e9SAndroid Build Coastguard Worker            torch.unravel_index(torch.tensor(0.5, device=device), (2, 2))
1392*da0073e9SAndroid Build Coastguard Worker
1393*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, r"expected 'indices' to be integer"):
1394*da0073e9SAndroid Build Coastguard Worker            torch.unravel_index(torch.tensor([], device=device), (10, 3, 5))
1395*da0073e9SAndroid Build Coastguard Worker
1396*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1397*da0073e9SAndroid Build Coastguard Worker            TypeError, r"expected 'shape' to be int or sequence"
1398*da0073e9SAndroid Build Coastguard Worker        ):
1399*da0073e9SAndroid Build Coastguard Worker            torch.unravel_index(
1400*da0073e9SAndroid Build Coastguard Worker                torch.tensor([1], device=device, dtype=torch.int64),
1401*da0073e9SAndroid Build Coastguard Worker                torch.tensor([1, 2, 3]),
1402*da0073e9SAndroid Build Coastguard Worker            )
1403*da0073e9SAndroid Build Coastguard Worker
1404*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1405*da0073e9SAndroid Build Coastguard Worker            TypeError, r"expected 'shape' sequence to only contain ints"
1406*da0073e9SAndroid Build Coastguard Worker        ):
1407*da0073e9SAndroid Build Coastguard Worker            torch.unravel_index(
1408*da0073e9SAndroid Build Coastguard Worker                torch.tensor([1], device=device, dtype=torch.int64), (1, 2, 2.0)
1409*da0073e9SAndroid Build Coastguard Worker            )
1410*da0073e9SAndroid Build Coastguard Worker
1411*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1412*da0073e9SAndroid Build Coastguard Worker            ValueError, r"'shape' cannot have negative values, but got \(2, -3\)"
1413*da0073e9SAndroid Build Coastguard Worker        ):
1414*da0073e9SAndroid Build Coastguard Worker            torch.unravel_index(torch.tensor(0, device=device), (2, -3))
1415*da0073e9SAndroid Build Coastguard Worker
1416*da0073e9SAndroid Build Coastguard Worker    def test_invalid_index(self, device):
1417*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 16, device=device).view(4, 4)
1418*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(TypeError, "slice indices", lambda: x["0":"1"])
1419*da0073e9SAndroid Build Coastguard Worker
1420*da0073e9SAndroid Build Coastguard Worker    def test_out_of_bound_index(self, device):
1421*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 100, device=device).view(2, 5, 10)
1422*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1423*da0073e9SAndroid Build Coastguard Worker            IndexError,
1424*da0073e9SAndroid Build Coastguard Worker            "index 5 is out of bounds for dimension 1 with size 5",
1425*da0073e9SAndroid Build Coastguard Worker            lambda: x[0, 5],
1426*da0073e9SAndroid Build Coastguard Worker        )
1427*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1428*da0073e9SAndroid Build Coastguard Worker            IndexError,
1429*da0073e9SAndroid Build Coastguard Worker            "index 4 is out of bounds for dimension 0 with size 2",
1430*da0073e9SAndroid Build Coastguard Worker            lambda: x[4, 5],
1431*da0073e9SAndroid Build Coastguard Worker        )
1432*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1433*da0073e9SAndroid Build Coastguard Worker            IndexError,
1434*da0073e9SAndroid Build Coastguard Worker            "index 15 is out of bounds for dimension 2 with size 10",
1435*da0073e9SAndroid Build Coastguard Worker            lambda: x[0, 1, 15],
1436*da0073e9SAndroid Build Coastguard Worker        )
1437*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1438*da0073e9SAndroid Build Coastguard Worker            IndexError,
1439*da0073e9SAndroid Build Coastguard Worker            "index 12 is out of bounds for dimension 2 with size 10",
1440*da0073e9SAndroid Build Coastguard Worker            lambda: x[:, :, 12],
1441*da0073e9SAndroid Build Coastguard Worker        )
1442*da0073e9SAndroid Build Coastguard Worker
1443*da0073e9SAndroid Build Coastguard Worker    def test_zero_dim_index(self, device):
1444*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(10, device=device)
1445*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x.item())
1446*da0073e9SAndroid Build Coastguard Worker
1447*da0073e9SAndroid Build Coastguard Worker        def runner():
1448*da0073e9SAndroid Build Coastguard Worker            print(x[0])
1449*da0073e9SAndroid Build Coastguard Worker            return x[0]
1450*da0073e9SAndroid Build Coastguard Worker
1451*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(IndexError, "invalid index", runner)
1452*da0073e9SAndroid Build Coastguard Worker
1453*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1454*da0073e9SAndroid Build Coastguard Worker    def test_invalid_device(self, device):
1455*da0073e9SAndroid Build Coastguard Worker        idx = torch.tensor([0, 1])
1456*da0073e9SAndroid Build Coastguard Worker        b = torch.zeros(5, device=device)
1457*da0073e9SAndroid Build Coastguard Worker        c = torch.tensor([1.0, 2.0], device="cpu")
1458*da0073e9SAndroid Build Coastguard Worker
1459*da0073e9SAndroid Build Coastguard Worker        for accumulate in [True, False]:
1460*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(
1461*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
1462*da0073e9SAndroid Build Coastguard Worker                lambda: torch.index_put_(b, (idx,), c, accumulate=accumulate),
1463*da0073e9SAndroid Build Coastguard Worker            )
1464*da0073e9SAndroid Build Coastguard Worker
1465*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1466*da0073e9SAndroid Build Coastguard Worker    def test_cpu_indices(self, device):
1467*da0073e9SAndroid Build Coastguard Worker        idx = torch.tensor([0, 1])
1468*da0073e9SAndroid Build Coastguard Worker        b = torch.zeros(2, device=device)
1469*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(10, device=device)
1470*da0073e9SAndroid Build Coastguard Worker        x[idx] = b  # index_put_
1471*da0073e9SAndroid Build Coastguard Worker        ref = torch.ones(10, device=device)
1472*da0073e9SAndroid Build Coastguard Worker        ref[:2] = 0
1473*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, ref, atol=0, rtol=0)
1474*da0073e9SAndroid Build Coastguard Worker        out = x[idx]  # index
1475*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, torch.zeros(2, device=device), atol=0, rtol=0)
1476*da0073e9SAndroid Build Coastguard Worker
1477*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.long, torch.float32)
1478*da0073e9SAndroid Build Coastguard Worker    def test_take_along_dim(self, device, dtype):
1479*da0073e9SAndroid Build Coastguard Worker        def _test_against_numpy(t, indices, dim):
1480*da0073e9SAndroid Build Coastguard Worker            actual = torch.take_along_dim(t, indices, dim=dim)
1481*da0073e9SAndroid Build Coastguard Worker            t_np = t.cpu().numpy()
1482*da0073e9SAndroid Build Coastguard Worker            indices_np = indices.cpu().numpy()
1483*da0073e9SAndroid Build Coastguard Worker            expected = np.take_along_axis(t_np, indices_np, axis=dim)
1484*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected, atol=0, rtol=0)
1485*da0073e9SAndroid Build Coastguard Worker
1486*da0073e9SAndroid Build Coastguard Worker        for shape in [(3, 2), (2, 3, 5), (2, 4, 0), (2, 3, 1, 4)]:
1487*da0073e9SAndroid Build Coastguard Worker            for noncontiguous in [True, False]:
1488*da0073e9SAndroid Build Coastguard Worker                t = make_tensor(
1489*da0073e9SAndroid Build Coastguard Worker                    shape, device=device, dtype=dtype, noncontiguous=noncontiguous
1490*da0073e9SAndroid Build Coastguard Worker                )
1491*da0073e9SAndroid Build Coastguard Worker                for dim in list(range(t.ndim)) + [None]:
1492*da0073e9SAndroid Build Coastguard Worker                    if dim is None:
1493*da0073e9SAndroid Build Coastguard Worker                        indices = torch.argsort(t.view(-1))
1494*da0073e9SAndroid Build Coastguard Worker                    else:
1495*da0073e9SAndroid Build Coastguard Worker                        indices = torch.argsort(t, dim=dim)
1496*da0073e9SAndroid Build Coastguard Worker
1497*da0073e9SAndroid Build Coastguard Worker                _test_against_numpy(t, indices, dim)
1498*da0073e9SAndroid Build Coastguard Worker
1499*da0073e9SAndroid Build Coastguard Worker        # test broadcasting
1500*da0073e9SAndroid Build Coastguard Worker        t = torch.ones((3, 4, 1), device=device)
1501*da0073e9SAndroid Build Coastguard Worker        indices = torch.ones((1, 2, 5), dtype=torch.long, device=device)
1502*da0073e9SAndroid Build Coastguard Worker
1503*da0073e9SAndroid Build Coastguard Worker        _test_against_numpy(t, indices, 1)
1504*da0073e9SAndroid Build Coastguard Worker
1505*da0073e9SAndroid Build Coastguard Worker        # test empty indices
1506*da0073e9SAndroid Build Coastguard Worker        t = torch.ones((3, 4, 5), device=device)
1507*da0073e9SAndroid Build Coastguard Worker        indices = torch.ones((3, 0, 5), dtype=torch.long, device=device)
1508*da0073e9SAndroid Build Coastguard Worker
1509*da0073e9SAndroid Build Coastguard Worker        _test_against_numpy(t, indices, 1)
1510*da0073e9SAndroid Build Coastguard Worker
1511*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.long, torch.float)
1512*da0073e9SAndroid Build Coastguard Worker    def test_take_along_dim_invalid(self, device, dtype):
1513*da0073e9SAndroid Build Coastguard Worker        shape = (2, 3, 1, 4)
1514*da0073e9SAndroid Build Coastguard Worker        dim = 0
1515*da0073e9SAndroid Build Coastguard Worker        t = make_tensor(shape, device=device, dtype=dtype)
1516*da0073e9SAndroid Build Coastguard Worker        indices = torch.argsort(t, dim=dim)
1517*da0073e9SAndroid Build Coastguard Worker
1518*da0073e9SAndroid Build Coastguard Worker        # dim of `t` and `indices` does not match
1519*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1520*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "input and indices should have the same number of dimensions"
1521*da0073e9SAndroid Build Coastguard Worker        ):
1522*da0073e9SAndroid Build Coastguard Worker            torch.take_along_dim(t, indices[0], dim=0)
1523*da0073e9SAndroid Build Coastguard Worker
1524*da0073e9SAndroid Build Coastguard Worker        # invalid `indices` dtype
1525*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"):
1526*da0073e9SAndroid Build Coastguard Worker            torch.take_along_dim(t, indices.to(torch.bool), dim=0)
1527*da0073e9SAndroid Build Coastguard Worker
1528*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"):
1529*da0073e9SAndroid Build Coastguard Worker            torch.take_along_dim(t, indices.to(torch.float), dim=0)
1530*da0073e9SAndroid Build Coastguard Worker
1531*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"):
1532*da0073e9SAndroid Build Coastguard Worker            torch.take_along_dim(t, indices.to(torch.int32), dim=0)
1533*da0073e9SAndroid Build Coastguard Worker
1534*da0073e9SAndroid Build Coastguard Worker        # invalid axis
1535*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
1536*da0073e9SAndroid Build Coastguard Worker            torch.take_along_dim(t, indices, dim=-7)
1537*da0073e9SAndroid Build Coastguard Worker
1538*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
1539*da0073e9SAndroid Build Coastguard Worker            torch.take_along_dim(t, indices, dim=7)
1540*da0073e9SAndroid Build Coastguard Worker
1541*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1542*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
1543*da0073e9SAndroid Build Coastguard Worker    def test_gather_take_along_dim_cross_device(self, device, dtype):
1544*da0073e9SAndroid Build Coastguard Worker        shape = (2, 3, 1, 4)
1545*da0073e9SAndroid Build Coastguard Worker        dim = 0
1546*da0073e9SAndroid Build Coastguard Worker        t = make_tensor(shape, device=device, dtype=dtype)
1547*da0073e9SAndroid Build Coastguard Worker        indices = torch.argsort(t, dim=dim)
1548*da0073e9SAndroid Build Coastguard Worker
1549*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1550*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Expected all tensors to be on the same device"
1551*da0073e9SAndroid Build Coastguard Worker        ):
1552*da0073e9SAndroid Build Coastguard Worker            torch.gather(t, 0, indices.cpu())
1553*da0073e9SAndroid Build Coastguard Worker
1554*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1555*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1556*da0073e9SAndroid Build Coastguard Worker            r"Expected tensor to have .* but got tensor with .* torch.take_along_dim()",
1557*da0073e9SAndroid Build Coastguard Worker        ):
1558*da0073e9SAndroid Build Coastguard Worker            torch.take_along_dim(t, indices.cpu(), dim=0)
1559*da0073e9SAndroid Build Coastguard Worker
1560*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1561*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Expected all tensors to be on the same device"
1562*da0073e9SAndroid Build Coastguard Worker        ):
1563*da0073e9SAndroid Build Coastguard Worker            torch.gather(t.cpu(), 0, indices)
1564*da0073e9SAndroid Build Coastguard Worker
1565*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1566*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1567*da0073e9SAndroid Build Coastguard Worker            r"Expected tensor to have .* but got tensor with .* torch.take_along_dim()",
1568*da0073e9SAndroid Build Coastguard Worker        ):
1569*da0073e9SAndroid Build Coastguard Worker            torch.take_along_dim(t.cpu(), indices, dim=0)
1570*da0073e9SAndroid Build Coastguard Worker
1571*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1572*da0073e9SAndroid Build Coastguard Worker    def test_cuda_broadcast_index_use_deterministic_algorithms(self, device):
1573*da0073e9SAndroid Build Coastguard Worker        with DeterministicGuard(True):
1574*da0073e9SAndroid Build Coastguard Worker            idx1 = torch.tensor([0])
1575*da0073e9SAndroid Build Coastguard Worker            idx2 = torch.tensor([2, 6])
1576*da0073e9SAndroid Build Coastguard Worker            idx3 = torch.tensor([1, 5, 7])
1577*da0073e9SAndroid Build Coastguard Worker
1578*da0073e9SAndroid Build Coastguard Worker            tensor_a = torch.rand(13, 11, 12, 13, 12).cpu()
1579*da0073e9SAndroid Build Coastguard Worker            tensor_b = tensor_a.to(device=device)
1580*da0073e9SAndroid Build Coastguard Worker            tensor_a[idx1] = 1.0
1581*da0073e9SAndroid Build Coastguard Worker            tensor_a[idx1, :, idx2, idx2, :] = 2.0
1582*da0073e9SAndroid Build Coastguard Worker            tensor_a[:, idx1, idx3, :, idx3] = 3.0
1583*da0073e9SAndroid Build Coastguard Worker            tensor_b[idx1] = 1.0
1584*da0073e9SAndroid Build Coastguard Worker            tensor_b[idx1, :, idx2, idx2, :] = 2.0
1585*da0073e9SAndroid Build Coastguard Worker            tensor_b[:, idx1, idx3, :, idx3] = 3.0
1586*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0)
1587*da0073e9SAndroid Build Coastguard Worker
1588*da0073e9SAndroid Build Coastguard Worker            tensor_a = torch.rand(10, 11).cpu()
1589*da0073e9SAndroid Build Coastguard Worker            tensor_b = tensor_a.to(device=device)
1590*da0073e9SAndroid Build Coastguard Worker            tensor_a[idx3] = 1.0
1591*da0073e9SAndroid Build Coastguard Worker            tensor_a[idx2, :] = 2.0
1592*da0073e9SAndroid Build Coastguard Worker            tensor_a[:, idx2] = 3.0
1593*da0073e9SAndroid Build Coastguard Worker            tensor_a[:, idx1] = 4.0
1594*da0073e9SAndroid Build Coastguard Worker            tensor_b[idx3] = 1.0
1595*da0073e9SAndroid Build Coastguard Worker            tensor_b[idx2, :] = 2.0
1596*da0073e9SAndroid Build Coastguard Worker            tensor_b[:, idx2] = 3.0
1597*da0073e9SAndroid Build Coastguard Worker            tensor_b[:, idx1] = 4.0
1598*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0)
1599*da0073e9SAndroid Build Coastguard Worker
1600*da0073e9SAndroid Build Coastguard Worker            tensor_a = torch.rand(10, 10).cpu()
1601*da0073e9SAndroid Build Coastguard Worker            tensor_b = tensor_a.to(device=device)
1602*da0073e9SAndroid Build Coastguard Worker            tensor_a[[8]] = 1.0
1603*da0073e9SAndroid Build Coastguard Worker            tensor_b[[8]] = 1.0
1604*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0)
1605*da0073e9SAndroid Build Coastguard Worker
1606*da0073e9SAndroid Build Coastguard Worker            tensor_a = torch.rand(10).cpu()
1607*da0073e9SAndroid Build Coastguard Worker            tensor_b = tensor_a.to(device=device)
1608*da0073e9SAndroid Build Coastguard Worker            tensor_a[6] = 1.0
1609*da0073e9SAndroid Build Coastguard Worker            tensor_b[6] = 1.0
1610*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0)
1611*da0073e9SAndroid Build Coastguard Worker
1612*da0073e9SAndroid Build Coastguard Worker    def test_index_limits(self, device):
1613*da0073e9SAndroid Build Coastguard Worker        #  Regression test for https://github.com/pytorch/pytorch/issues/115415
1614*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([], device=device)
1615*da0073e9SAndroid Build Coastguard Worker        idx_min = torch.iinfo(torch.int64).min
1616*da0073e9SAndroid Build Coastguard Worker        idx_max = torch.iinfo(torch.int64).max
1617*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: t[idx_min])
1618*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: t[idx_max])
1619*da0073e9SAndroid Build Coastguard Worker
1620*da0073e9SAndroid Build Coastguard Worker
1621*da0073e9SAndroid Build Coastguard Worker# The tests below are from NumPy test_indexing.py with some modifications to
1622*da0073e9SAndroid Build Coastguard Worker# make them compatible with PyTorch. It's licensed under the BSD license below:
1623*da0073e9SAndroid Build Coastguard Worker#
1624*da0073e9SAndroid Build Coastguard Worker# Copyright (c) 2005-2017, NumPy Developers.
1625*da0073e9SAndroid Build Coastguard Worker# All rights reserved.
1626*da0073e9SAndroid Build Coastguard Worker#
1627*da0073e9SAndroid Build Coastguard Worker# Redistribution and use in source and binary forms, with or without
1628*da0073e9SAndroid Build Coastguard Worker# modification, are permitted provided that the following conditions are
1629*da0073e9SAndroid Build Coastguard Worker# met:
1630*da0073e9SAndroid Build Coastguard Worker#
1631*da0073e9SAndroid Build Coastguard Worker#     * Redistributions of source code must retain the above copyright
1632*da0073e9SAndroid Build Coastguard Worker#        notice, this list of conditions and the following disclaimer.
1633*da0073e9SAndroid Build Coastguard Worker#
1634*da0073e9SAndroid Build Coastguard Worker#     * Redistributions in binary form must reproduce the above
1635*da0073e9SAndroid Build Coastguard Worker#        copyright notice, this list of conditions and the following
1636*da0073e9SAndroid Build Coastguard Worker#        disclaimer in the documentation and/or other materials provided
1637*da0073e9SAndroid Build Coastguard Worker#        with the distribution.
1638*da0073e9SAndroid Build Coastguard Worker#
1639*da0073e9SAndroid Build Coastguard Worker#     * Neither the name of the NumPy Developers nor the names of any
1640*da0073e9SAndroid Build Coastguard Worker#        contributors may be used to endorse or promote products derived
1641*da0073e9SAndroid Build Coastguard Worker#        from this software without specific prior written permission.
1642*da0073e9SAndroid Build Coastguard Worker#
1643*da0073e9SAndroid Build Coastguard Worker# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
1644*da0073e9SAndroid Build Coastguard Worker# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
1645*da0073e9SAndroid Build Coastguard Worker# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
1646*da0073e9SAndroid Build Coastguard Worker# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
1647*da0073e9SAndroid Build Coastguard Worker# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
1648*da0073e9SAndroid Build Coastguard Worker# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
1649*da0073e9SAndroid Build Coastguard Worker# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
1650*da0073e9SAndroid Build Coastguard Worker# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
1651*da0073e9SAndroid Build Coastguard Worker# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
1652*da0073e9SAndroid Build Coastguard Worker# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
1653*da0073e9SAndroid Build Coastguard Worker# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
1654*da0073e9SAndroid Build Coastguard Worker
1655*da0073e9SAndroid Build Coastguard Worker
class NumpyTests(TestCase):
    """Indexing tests ported from NumPy's ``test_indexing.py``.

    Every test takes a ``device`` argument and creates its tensors there; the
    class is instantiated per device type by ``instantiate_device_type_tests``
    at the bottom of the module. Comments point out where PyTorch's behavior
    differs from NumPy's.
    """

    def test_index_no_floats(self, device):
        # Float scalars are rejected as indices in every position, both when
        # their value is integral (0.0) and when it is not (-1.4).
        a = torch.tensor([[[5.0]]], device=device)

        self.assertRaises(IndexError, lambda: a[0.0])
        self.assertRaises(IndexError, lambda: a[0, 0.0])
        self.assertRaises(IndexError, lambda: a[0.0, 0])
        self.assertRaises(IndexError, lambda: a[0.0, :])
        self.assertRaises(IndexError, lambda: a[:, 0.0])
        self.assertRaises(IndexError, lambda: a[:, 0.0, :])
        self.assertRaises(IndexError, lambda: a[0.0, :, :])
        self.assertRaises(IndexError, lambda: a[0, 0, 0.0])
        self.assertRaises(IndexError, lambda: a[0.0, 0, 0])
        self.assertRaises(IndexError, lambda: a[0, 0.0, 0])
        self.assertRaises(IndexError, lambda: a[-1.4])
        self.assertRaises(IndexError, lambda: a[0, -1.4])
        self.assertRaises(IndexError, lambda: a[-1.4, 0])
        self.assertRaises(IndexError, lambda: a[-1.4, :])
        self.assertRaises(IndexError, lambda: a[:, -1.4])
        self.assertRaises(IndexError, lambda: a[:, -1.4, :])
        self.assertRaises(IndexError, lambda: a[-1.4, :, :])
        self.assertRaises(IndexError, lambda: a[0, 0, -1.4])
        self.assertRaises(IndexError, lambda: a[-1.4, 0, 0])
        self.assertRaises(IndexError, lambda: a[0, -1.4, 0])
        # NOTE: float slice bounds (`a[0.0:, 0.0]`, `a[0.0:, 0.0, :]`) were
        # commented out in this port and are not checked.

    def test_none_index(self, device):
        # `None` index adds newaxis
        a = tensor([1, 2, 3], device=device)
        self.assertEqual(a[None].dim(), a.dim() + 1)

    def test_empty_tuple_index(self, device):
        # Empty tuple index creates a view that shares storage with `a`.
        a = tensor([1, 2, 3], device=device)
        self.assertEqual(a[()], a)
        self.assertEqual(a[()].data_ptr(), a.data_ptr())

    def test_empty_fancy_index(self, device):
        # An empty list or an empty integer tensor selects nothing and yields
        # an empty tensor; an empty float tensor is not a valid index.
        a = tensor([1, 2, 3], device=device)
        empty_long = torch.tensor([], dtype=torch.long, device=device)
        self.assertEqual(a[[]], empty_long)

        b = tensor([], device=device).long()
        # Index with `b` itself; previously this re-checked `a[[]]` and `b`
        # was never used.
        self.assertEqual(a[b], empty_long)

        b = tensor([], device=device).float()
        self.assertRaises(IndexError, lambda: a[b])

    def test_ellipsis_index(self, device):
        a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
        self.assertIsNot(a[...], a)
        self.assertEqual(a[...], a)
        # `a[...]` was `a` in numpy <1.9.
        self.assertEqual(a[...].data_ptr(), a.data_ptr())

        # Slicing with ellipsis can skip an
        # arbitrary number of dimensions
        self.assertEqual(a[0, ...], a[0])
        self.assertEqual(a[0, ...], a[0, :])
        self.assertEqual(a[..., 0], a[:, 0])

        # In NumPy, slicing with ellipsis results in a 0-dim array. In PyTorch
        # we don't have separate 0-dim arrays and scalars.
        self.assertEqual(a[0, ..., 1], torch.tensor(2, device=device))

        # Assignment with `(Ellipsis,)` on 0-d arrays; created on `device` so
        # the assignment is exercised on the device under test.
        b = torch.tensor(1, device=device)
        b[(Ellipsis,)] = 2
        self.assertEqual(b, 2)

    def test_single_int_index(self, device):
        # Single integer index selects one row
        a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)

        self.assertEqual(a[0], [1, 2, 3])
        self.assertEqual(a[-1], [7, 8, 9])

        # Index out of bounds produces IndexError
        self.assertRaises(IndexError, a.__getitem__, 1 << 30)
        # Index overflow produces Exception  NB: different exception type
        self.assertRaises(Exception, a.__getitem__, 1 << 64)

    def test_single_bool_index(self, device):
        # A single Python bool adds a leading axis of size 1 (True) or 0 (False).
        a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)

        self.assertEqual(a[True], a[None])
        self.assertEqual(a[False], a[None][0:0])

    def test_boolean_shape_mismatch(self, device):
        # Boolean masks whose shape doesn't match the indexed dims are rejected.
        arr = torch.ones((5, 4, 3), device=device)

        index = tensor([True], device=device)
        self.assertRaisesRegex(IndexError, "mask", lambda: arr[index])

        index = tensor([False] * 6, device=device)
        self.assertRaisesRegex(IndexError, "mask", lambda: arr[index])

        # A (4, 4) uint8 mask matches neither the leading dims nor dims 1-2.
        index = torch.zeros((4, 4), dtype=torch.uint8, device=device)
        self.assertRaisesRegex(IndexError, "mask", lambda: arr[index])
        self.assertRaisesRegex(IndexError, "mask", lambda: arr[(slice(None), index)])

    def test_boolean_indexing_onedim(self, device):
        # Indexing a 2-dimensional array with
        # boolean array of length one
        a = tensor([[0.0, 0.0, 0.0]], device=device)
        b = tensor([True], device=device)
        self.assertEqual(a[b], a)
        # boolean assignment
        a[b] = 1.0
        self.assertEqual(a, tensor([[1.0, 1.0, 1.0]], device=device))

    # https://github.com/pytorch/pytorch/issues/127003
    @xfailIfTorchDynamo
    def test_boolean_assignment_value_mismatch(self, device):
        # A boolean assignment should fail when the shape of the values
        # cannot be broadcast to the subscription. (see also gh-3458)
        a = torch.arange(0, 4, device=device)

        def f(a, v):
            a[a > -1] = tensor(v).to(device)

        self.assertRaisesRegex(Exception, "shape mismatch", f, a, [])
        self.assertRaisesRegex(Exception, "shape mismatch", f, a, [1, 2, 3])
        self.assertRaisesRegex(Exception, "shape mismatch", f, a[:1], [1, 2, 3])

    def test_boolean_indexing_twodim(self, device):
        # Indexing a 2-dimensional array with
        # 2-dimensional boolean array
        a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
        b = tensor(
            [[True, False, True], [False, True, False], [True, False, True]],
            device=device,
        )
        self.assertEqual(a[b], tensor([1, 3, 5, 7, 9], device=device))
        self.assertEqual(a[b[1]], tensor([[4, 5, 6]], device=device))
        self.assertEqual(a[b[0]], a[b[2]])

        # boolean assignment
        a[b] = 0
        self.assertEqual(a, tensor([[0, 2, 0], [4, 0, 6], [0, 8, 0]], device=device))

    def test_boolean_indexing_weirdness(self, device):
        # Python bools mixed with other indices: a False index yields an
        # empty leading dim, and combining it with a length-2 list index is an
        # IndexError.
        a = torch.ones((2, 3, 4), device=device)
        self.assertEqual((0, 2, 3, 4), a[False, True, ...].shape)
        self.assertEqual(
            torch.ones(1, 2, device=device), a[True, [0, 1], True, True, [1], [[2]]]
        )
        self.assertRaises(IndexError, lambda: a[False, [0, 1], ...])

    def test_boolean_indexing_weirdness_tensors(self, device):
        # Same cases as test_boolean_indexing_weirdness, but with 0-d bool
        # tensors in place of Python bools.
        false = torch.tensor(False, device=device)
        true = torch.tensor(True, device=device)
        a = torch.ones((2, 3, 4), device=device)
        # Use the tensor bools here; the Python-bool form is already covered
        # by test_boolean_indexing_weirdness.
        self.assertEqual((0, 2, 3, 4), a[false, true, ...].shape)
        self.assertEqual(
            torch.ones(1, 2, device=device), a[true, [0, 1], true, true, [1], [[2]]]
        )
        self.assertRaises(IndexError, lambda: a[false, [0, 1], ...])

    def test_boolean_indexing_alldims(self, device):
        # One bool (Python or 0-d tensor) per dim adds a single leading axis.
        true = torch.tensor(True, device=device)
        a = torch.ones((2, 3), device=device)
        self.assertEqual((1, 2, 3), a[True, True].shape)
        self.assertEqual((1, 2, 3), a[true, true].shape)

    def test_boolean_list_indexing(self, device):
        # Indexing a 2-dimensional array with
        # boolean lists
        a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
        b = [True, False, False]
        c = [True, True, False]
        self.assertEqual(a[b], tensor([[1, 2, 3]], device=device))
        self.assertEqual(a[b, b], tensor([1], device=device))
        self.assertEqual(a[c], tensor([[1, 2, 3], [4, 5, 6]], device=device))
        self.assertEqual(a[c, c], tensor([1, 5], device=device))

    def test_everything_returns_views(self, device):
        # Before `...` would return a itself.
        a = tensor([5], device=device)

        self.assertIsNot(a, a[()])
        self.assertIsNot(a, a[...])
        self.assertIsNot(a, a[:])

    def test_broaderrors_indexing(self, device):
        # Index lists that cannot be broadcast together are rejected for both
        # reads and writes.
        a = torch.zeros(5, 5, device=device)
        self.assertRaisesRegex(
            IndexError, "shape mismatch", a.__getitem__, ([0, 1], [0, 1, 2])
        )
        self.assertRaisesRegex(
            IndexError, "shape mismatch", a.__setitem__, ([0, 1], [0, 1, 2]), 0
        )

    def test_trivial_fancy_out_of_bounds(self, device):
        a = torch.zeros(5, device=device)
        if a.is_cuda:
            # Decide to skip before building any index tensors.
            raise unittest.SkipTest("CUDA asserts instead of raising an exception")
        # Out-of-bounds entry at the end of the index tensor...
        ind = torch.ones(20, dtype=torch.int64, device=device)
        ind[-1] = 10
        self.assertRaises(IndexError, a.__getitem__, ind)
        self.assertRaises(IndexError, a.__setitem__, ind, 0)
        # ...and at the start.
        ind = torch.ones(20, dtype=torch.int64, device=device)
        ind[0] = 11
        self.assertRaises(IndexError, a.__getitem__, ind)
        self.assertRaises(IndexError, a.__setitem__, ind, 0)

    def test_index_is_larger(self, device):
        # Simple case of fancy index broadcasting of the index.
        a = torch.zeros((5, 5), device=device)
        a[[[0], [1], [2]], [0, 1, 2]] = tensor([2.0, 3.0, 4.0], device=device)

        self.assertTrue((a[:3, :3] == tensor([2.0, 3.0, 4.0], device=device)).all())

    def test_broadcast_subspace(self, device):
        # A (100, 1) value broadcasts across each selected row, rows taken in
        # reverse order.
        a = torch.zeros((100, 100), device=device)
        v = torch.arange(0.0, 100, device=device)[:, None]
        b = torch.arange(99, -1, -1, device=device).long()
        a[b] = v
        expected = b.float().unsqueeze(1).expand(100, 100)
        self.assertEqual(a, expected)

    def test_truncate_leading_1s(self, device):
        # Assigning a (1, 4) value through a 4-element advanced index must
        # drop the value's leading size-1 dim; compare with a diagonal copy.
        col_max = torch.randn(1, 4, device=device)
        kernel = col_max.T * col_max  # [4, 4] tensor
        kernel2 = kernel.clone()
        # Set the diagonal
        kernel[range(len(kernel)), range(len(kernel))] = torch.square(col_max)
        torch.diagonal(kernel2).copy_(torch.square(col_max.view(4)))
        self.assertEqual(kernel, kernel2)
1890*da0073e9SAndroid Build Coastguard Worker
# Generate one concrete test class per device type (meta excluded) from each
# generic, device-parameterized class.
for _generic_test_class in (TestIndexing, NumpyTests):
    instantiate_device_type_tests(_generic_test_class, globals(), except_for="meta")
# Drop the loop variable so no module-level name keeps the generic (device-less)
# class visible to test discovery.
del _generic_test_class

if __name__ == "__main__":
    run_tests()
1896