# Owner(s): ["module: tests"]

import operator
import random
import unittest
import warnings
from functools import reduce

import numpy as np

import torch
from torch import tensor
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCPU,
    dtypesIfCUDA,
    instantiate_device_type_tests,
    onlyCUDA,
    onlyNativeDeviceTypes,
    skipXLA,
)
from torch.testing._internal.common_utils import (
    DeterministicGuard,
    run_tests,
    serialTest,
    skipIfTorchDynamo,
    TEST_CUDA,
    TestCase,
    xfailIfTorchDynamo,
)


class TestIndexing(TestCase):
    def test_index(self, device):
        def consec(size, start=1):
            sequence = torch.ones(torch.tensor(size).prod(0)).cumsum(0)
            sequence.add_(start - 1)
            return sequence.view(*size)

        reference = consec((3, 3, 3)).to(device)

        # empty tensor indexing
        self.assertEqual(
            reference[torch.LongTensor().to(device)], reference.new(0, 3, 3)
        )

        self.assertEqual(reference[0], consec((3, 3)), atol=0, rtol=0)
        self.assertEqual(reference[1], consec((3, 3), 10), atol=0, rtol=0)
        self.assertEqual(reference[2], consec((3, 3), 19), atol=0, rtol=0)
        self.assertEqual(reference[0, 1], consec((3,), 4), atol=0, rtol=0)
        self.assertEqual(reference[0:2], consec((2, 3, 3)), atol=0, rtol=0)
        self.assertEqual(reference[2, 2, 2], 27, atol=0, rtol=0)
        self.assertEqual(reference[:], consec((3, 3, 3)), atol=0, rtol=0)

        # indexing with Ellipsis
        self.assertEqual(
            reference[..., 2],
            torch.tensor([[3.0, 6.0, 9.0], [12.0, 15.0, 18.0], [21.0, 24.0, 27.0]]),
            atol=0,
            rtol=0,
        )
        self.assertEqual(
            reference[0, ..., 2], torch.tensor([3.0, 6.0, 9.0]), atol=0, rtol=0
        )
        self.assertEqual(reference[..., 2], reference[:, :, 2], atol=0, rtol=0)
        self.assertEqual(reference[0, ..., 2], reference[0, :, 2], atol=0, rtol=0)
        self.assertEqual(reference[0, 2, ...], reference[0, 2], atol=0, rtol=0)
        self.assertEqual(reference[..., 2, 2, 2], 27, atol=0, rtol=0)
        self.assertEqual(reference[2, ..., 2, 2], 27, atol=0, rtol=0)
        self.assertEqual(reference[2, 2, ..., 2], 27, atol=0, rtol=0)
        self.assertEqual(reference[2, 2, 2, ...], 27, atol=0, rtol=0)
        self.assertEqual(reference[...], reference, atol=0, rtol=0)

        reference_5d = consec((3, 3, 3, 3, 3)).to(device)
        self.assertEqual(
            reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0], atol=0, rtol=0
        )
        self.assertEqual(
            reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0], atol=0, rtol=0
        )
        self.assertEqual(
            reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1], atol=0, rtol=0
        )
        self.assertEqual(reference_5d[...], reference_5d, atol=0, rtol=0)

        # LongTensor indexing
        reference = consec((5, 5, 5)).to(device)
        idx = torch.LongTensor([2, 4]).to(device)
        self.assertEqual(reference[idx], torch.stack([reference[2], reference[4]]))
        # TODO: enable once indexing is implemented like in numpy
        # self.assertEqual(reference[2, idx], torch.stack([reference[2, 2], reference[2, 4]]))
        # self.assertEqual(reference[3, idx, 1], torch.stack([reference[3, 2], reference[3, 4]])[:, 1])

        # None indexing
        self.assertEqual(reference[2, None], reference[2].unsqueeze(0))
        self.assertEqual(
            reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0)
        )
        self.assertEqual(reference[2:4, None], reference[2:4].unsqueeze(1))
        self.assertEqual(
            reference[None, 2, None, None],
            reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0),
        )
        self.assertEqual(
            reference[None, 2:5, None, None],
            reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2),
        )

        # indexing 0-length slice
        self.assertEqual(torch.empty(0, 5, 5), reference[slice(0)])
        self.assertEqual(torch.empty(0, 5), reference[slice(0), 2])
        self.assertEqual(torch.empty(0, 5), reference[2, slice(0)])
        self.assertEqual(torch.tensor([]), reference[2, 1:1, 2])

        # indexing with step
        reference = consec((10, 10, 10)).to(device)
        self.assertEqual(reference[1:5:2], torch.stack([reference[1], reference[3]], 0))
        self.assertEqual(
            reference[1:6:2], torch.stack([reference[1], reference[3], reference[5]], 0)
        )
        self.assertEqual(reference[1:9:4], torch.stack([reference[1], reference[5]], 0))
        self.assertEqual(
            reference[2:4, 1:5:2],
            torch.stack([reference[2:4, 1], reference[2:4, 3]], 1),
        )
        self.assertEqual(
            reference[3, 1:6:2],
            torch.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0),
        )
        self.assertEqual(
            reference[None, 2, 1:9:4],
            torch.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0),
        )
        self.assertEqual(
            reference[:, 2, 1:6:2],
            torch.stack(
                [reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1
            ),
        )

        lst = [list(range(i, i + 10)) for i in range(0, 100, 10)]
        tensor = torch.DoubleTensor(lst).to(device)
        for _i in range(100):
            idx1_start = random.randrange(10)
            idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1)
            idx1_step = random.randrange(1, 8)
            idx1 = slice(idx1_start, idx1_end, idx1_step)
            if random.randrange(2) == 0:
                idx2_start = random.randrange(10)
                idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1)
                idx2_step = random.randrange(1, 8)
                idx2 = slice(idx2_start, idx2_end, idx2_step)
                lst_indexed = [l[idx2] for l in lst[idx1]]
                tensor_indexed = tensor[idx1, idx2]
            else:
                lst_indexed = lst[idx1]
                tensor_indexed = tensor[idx1]
            self.assertEqual(torch.DoubleTensor(lst_indexed), tensor_indexed)

        self.assertRaises(ValueError, lambda: reference[1:9:0])
        self.assertRaises(ValueError, lambda: reference[1:9:-1])

        self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1])
        self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1])
        self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3])

        self.assertRaises(IndexError, lambda: reference[0.0])
        self.assertRaises(TypeError, lambda: reference[0.0:2.0])
        self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0])
        self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0])
        self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0])
        self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0])

        def delitem():
            del reference[0]

        self.assertRaises(TypeError, delitem)

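    # Illustrative sketch: how `...` (Ellipsis) behaves in the assertions above.
    # Ellipsis expands to as many full slices as are needed, so `t[..., k]` and
    # `t[:, :, k]` address the same elements of a 3-d tensor. The helper name
    # `_sketch_ellipsis_expansion` is illustrative only and is not collected by
    # unittest (it does not start with "test").
    def _sketch_ellipsis_expansion(self, device):
        t = torch.arange(27, device=device).view(3, 3, 3)
        # `...` stands in for the two leading dimensions here.
        self.assertEqual(t[..., 2], t[:, :, 2])
        # With all leading indices given explicitly, `...` expands to nothing.
        self.assertEqual(t[0, 1, ...], t[0, 1])
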
    @onlyNativeDeviceTypes
    @dtypes(torch.half, torch.double)
    def test_advancedindex(self, device, dtype):
        # Tests for Integer Array Indexing, Part I - Purely integer array
        # indexing

        def consec(size, start=1):
            # Creates the sequence in float since CPU half doesn't support the
            # needed operations. Converts to dtype before returning.
            numel = reduce(operator.mul, size, 1)
            sequence = torch.ones(numel, dtype=torch.float, device=device).cumsum(0)
            sequence.add_(start - 1)
            return sequence.view(*size).to(dtype=dtype)

        # pick a random valid indexer type
        def ri(indices):
            choice = random.randint(0, 2)
            if choice == 0:
                return torch.LongTensor(indices).to(device)
            elif choice == 1:
                return list(indices)
            else:
                return tuple(indices)

        def validate_indexing(x):
            self.assertEqual(x[[0]], consec((1,)))
            self.assertEqual(x[ri([0]),], consec((1,)))
            self.assertEqual(x[ri([3]),], consec((1,), 4))
            self.assertEqual(x[[2, 3, 4]], consec((3,), 3))
            self.assertEqual(x[ri([2, 3, 4]),], consec((3,), 3))
            self.assertEqual(
                x[ri([0, 2, 4]),], torch.tensor([1, 3, 5], dtype=dtype, device=device)
            )

        def validate_setting(x):
            x[[0]] = -2
            self.assertEqual(x[[0]], torch.tensor([-2], dtype=dtype, device=device))
            x[[0]] = -1
            self.assertEqual(
                x[ri([0]),], torch.tensor([-1], dtype=dtype, device=device)
            )
            x[[2, 3, 4]] = 4
            self.assertEqual(
                x[[2, 3, 4]], torch.tensor([4, 4, 4], dtype=dtype, device=device)
            )
            x[ri([2, 3, 4]),] = 3
            self.assertEqual(
                x[ri([2, 3, 4]),], torch.tensor([3, 3, 3], dtype=dtype, device=device)
            )
            x[ri([0, 2, 4]),] = torch.tensor([5, 4, 3], dtype=dtype, device=device)
            self.assertEqual(
                x[ri([0, 2, 4]),], torch.tensor([5, 4, 3], dtype=dtype, device=device)
            )

        # Only validates indexing and setting for half precision
        if dtype == torch.half:
            reference = consec((10,))
            validate_indexing(reference)
            validate_setting(reference)
            return

        # Case 1: Purely Integer Array Indexing
        reference = consec((10,))
        validate_indexing(reference)

        # setting values
        validate_setting(reference)

        # Tensor with stride != 1
        # strided is [1, 3, 5, 7]
        reference = consec((10,))
        strided = torch.tensor((), dtype=dtype, device=device)
        strided.set_(
            reference.storage(), storage_offset=0, size=torch.Size([4]), stride=[2]
        )

        self.assertEqual(strided[[0]], torch.tensor([1], dtype=dtype, device=device))
        self.assertEqual(
            strided[ri([0]),], torch.tensor([1], dtype=dtype, device=device)
        )
        self.assertEqual(
            strided[ri([3]),], torch.tensor([7], dtype=dtype, device=device)
        )
        self.assertEqual(
            strided[[1, 2]], torch.tensor([3, 5], dtype=dtype, device=device)
        )
        self.assertEqual(
            strided[ri([1, 2]),], torch.tensor([3, 5], dtype=dtype, device=device)
        )
        self.assertEqual(
            strided[ri([[2, 1], [0, 3]]),],
            torch.tensor([[5, 3], [1, 7]], dtype=dtype, device=device),
        )

        # stride is [4, 8]
        strided = torch.tensor((), dtype=dtype, device=device)
        strided.set_(
            reference.storage(), storage_offset=4, size=torch.Size([2]), stride=[4]
        )
        self.assertEqual(strided[[0]], torch.tensor([5], dtype=dtype, device=device))
        self.assertEqual(
            strided[ri([0]),], torch.tensor([5], dtype=dtype, device=device)
        )
        self.assertEqual(
            strided[ri([1]),], torch.tensor([9], dtype=dtype, device=device)
        )
        self.assertEqual(
            strided[[0, 1]], torch.tensor([5, 9], dtype=dtype, device=device)
        )
        self.assertEqual(
            strided[ri([0, 1]),], torch.tensor([5, 9], dtype=dtype, device=device)
        )
        self.assertEqual(
            strided[ri([[0, 1], [1, 0]]),],
            torch.tensor([[5, 9], [9, 5]], dtype=dtype, device=device),
        )

        # reference is 1 2
        #              3 4
        #              5 6
        reference = consec((3, 2))
        self.assertEqual(
            reference[ri([0, 1, 2]), ri([0])],
            torch.tensor([1, 3, 5], dtype=dtype, device=device),
        )
        self.assertEqual(
            reference[ri([0, 1, 2]), ri([1])],
            torch.tensor([2, 4, 6], dtype=dtype, device=device),
        )
        self.assertEqual(reference[ri([0]), ri([0])], consec((1,)))
        self.assertEqual(reference[ri([2]), ri([1])], consec((1,), 6))
        self.assertEqual(
            reference[[ri([0, 0]), ri([0, 1])]],
            torch.tensor([1, 2], dtype=dtype, device=device),
        )
        self.assertEqual(
            reference[[ri([0, 1, 1, 0, 2]), ri([1])]],
            torch.tensor([2, 4, 4, 2, 6], dtype=dtype, device=device),
        )
        self.assertEqual(
            reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
            torch.tensor([1, 2, 3, 3], dtype=dtype, device=device),
        )

        rows = ri([[0, 0], [1, 2]])
        columns = ([0],)
        self.assertEqual(
            reference[rows, columns],
            torch.tensor([[1, 1], [3, 5]], dtype=dtype, device=device),
        )

        rows = ri([[0, 0], [1, 2]])
        columns = ri([1, 0])
        self.assertEqual(
            reference[rows, columns],
            torch.tensor([[2, 1], [4, 5]], dtype=dtype, device=device),
        )
        rows = ri([[0, 0], [1, 2]])
        columns = ri([[0, 1], [1, 0]])
        self.assertEqual(
            reference[rows, columns],
            torch.tensor([[1, 2], [4, 5]], dtype=dtype, device=device),
        )

        # setting values
        reference[ri([0]), ri([1])] = -1
        self.assertEqual(
            reference[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device)
        )
        reference[ri([0, 1, 2]), ri([0])] = torch.tensor(
            [-1, 2, -4], dtype=dtype, device=device
        )
        self.assertEqual(
            reference[ri([0, 1, 2]), ri([0])],
            torch.tensor([-1, 2, -4], dtype=dtype, device=device),
        )
        reference[rows, columns] = torch.tensor(
            [[4, 6], [2, 3]], dtype=dtype, device=device
        )
        self.assertEqual(
            reference[rows, columns],
            torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device),
        )

        # Verify still works with Transposed (i.e. non-contiguous) Tensors

        reference = torch.tensor(
            [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype, device=device
        ).t_()

        # Transposed: [[0, 4, 8],
        #              [1, 5, 9],
        #              [2, 6, 10],
        #              [3, 7, 11]]

        self.assertEqual(
            reference[ri([0, 1, 2]), ri([0])],
            torch.tensor([0, 1, 2], dtype=dtype, device=device),
        )
        self.assertEqual(
            reference[ri([0, 1, 2]), ri([1])],
            torch.tensor([4, 5, 6], dtype=dtype, device=device),
        )
        self.assertEqual(
            reference[ri([0]), ri([0])], torch.tensor([0], dtype=dtype, device=device)
        )
        self.assertEqual(
            reference[ri([2]), ri([1])], torch.tensor([6], dtype=dtype, device=device)
        )
        self.assertEqual(
            reference[[ri([0, 0]), ri([0, 1])]],
            torch.tensor([0, 4], dtype=dtype, device=device),
        )
        self.assertEqual(
            reference[[ri([0, 1, 1, 0, 3]), ri([1])]],
            torch.tensor([4, 5, 5, 4, 7], dtype=dtype, device=device),
        )
        self.assertEqual(
            reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
            torch.tensor([0, 4, 1, 1], dtype=dtype, device=device),
        )

        rows = ri([[0, 0], [1, 2]])
        columns = ([0],)
        self.assertEqual(
            reference[rows, columns],
            torch.tensor([[0, 0], [1, 2]], dtype=dtype, device=device),
        )

        rows = ri([[0, 0], [1, 2]])
        columns = ri([1, 0])
        self.assertEqual(
            reference[rows, columns],
            torch.tensor([[4, 0], [5, 2]], dtype=dtype, device=device),
        )
        rows = ri([[0, 0], [1, 3]])
        columns = ri([[0, 1], [1, 2]])
        self.assertEqual(
            reference[rows, columns],
            torch.tensor([[0, 4], [5, 11]], dtype=dtype, device=device),
        )

        # setting values
        reference[ri([0]), ri([1])] = -1
        self.assertEqual(
            reference[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device)
        )
        reference[ri([0, 1, 2]), ri([0])] = torch.tensor(
            [-1, 2, -4], dtype=dtype, device=device
        )
        self.assertEqual(
            reference[ri([0, 1, 2]), ri([0])],
            torch.tensor([-1, 2, -4], dtype=dtype, device=device),
        )
        reference[rows, columns] = torch.tensor(
            [[4, 6], [2, 3]], dtype=dtype, device=device
        )
        self.assertEqual(
            reference[rows, columns],
            torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device),
        )

        # stride != 1

        # strided is [[1 3 5 7],
        #             [9 11 13 15]]

        reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
        strided = torch.tensor((), dtype=dtype, device=device)
        strided.set_(reference.storage(), 1, size=torch.Size([2, 4]), stride=[8, 2])

        self.assertEqual(
            strided[ri([0, 1]), ri([0])],
            torch.tensor([1, 9], dtype=dtype, device=device),
        )
        self.assertEqual(
            strided[ri([0, 1]), ri([1])],
            torch.tensor([3, 11], dtype=dtype, device=device),
        )
        self.assertEqual(
            strided[ri([0]), ri([0])], torch.tensor([1], dtype=dtype, device=device)
        )
        self.assertEqual(
            strided[ri([1]), ri([3])], torch.tensor([15], dtype=dtype, device=device)
        )
        self.assertEqual(
            strided[[ri([0, 0]), ri([0, 3])]],
            torch.tensor([1, 7], dtype=dtype, device=device),
        )
        self.assertEqual(
            strided[[ri([1]), ri([0, 1, 1, 0, 3])]],
            torch.tensor([9, 11, 11, 9, 15], dtype=dtype, device=device),
        )
        self.assertEqual(
            strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
            torch.tensor([1, 3, 9, 9], dtype=dtype, device=device),
        )

        rows = ri([[0, 0], [1, 1]])
        columns = ([0],)
        self.assertEqual(
            strided[rows, columns],
            torch.tensor([[1, 1], [9, 9]], dtype=dtype, device=device),
        )

        rows = ri([[0, 1], [1, 0]])
        columns = ri([1, 2])
        self.assertEqual(
            strided[rows, columns],
            torch.tensor([[3, 13], [11, 5]], dtype=dtype, device=device),
        )
        rows = ri([[0, 0], [1, 1]])
        columns = ri([[0, 1], [1, 2]])
        self.assertEqual(
            strided[rows, columns],
            torch.tensor([[1, 3], [11, 13]], dtype=dtype, device=device),
        )

        # setting values

        # strided is [[10, 11],
        #             [17, 18]]

        reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
        strided = torch.tensor((), dtype=dtype, device=device)
        strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1])
        self.assertEqual(
            strided[ri([0]), ri([1])], torch.tensor([11], dtype=dtype, device=device)
        )
        strided[ri([0]), ri([1])] = -1
        self.assertEqual(
            strided[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device)
        )

        reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
        strided = torch.tensor((), dtype=dtype, device=device)
        strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1])
        self.assertEqual(
            strided[ri([0, 1]), ri([1, 0])],
            torch.tensor([11, 17], dtype=dtype, device=device),
        )
        strided[ri([0, 1]), ri([1, 0])] = torch.tensor(
            [-1, 2], dtype=dtype, device=device
        )
        self.assertEqual(
            strided[ri([0, 1]), ri([1, 0])],
            torch.tensor([-1, 2], dtype=dtype, device=device),
        )

        reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
        strided = torch.tensor((), dtype=dtype, device=device)
        strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1])

        rows = ri([[0], [1]])
        columns = ri([[0, 1], [0, 1]])
        self.assertEqual(
            strided[rows, columns],
            torch.tensor([[10, 11], [17, 18]], dtype=dtype, device=device),
        )
        strided[rows, columns] = torch.tensor(
            [[4, 6], [2, 3]], dtype=dtype, device=device
        )
        self.assertEqual(
            strided[rows, columns],
            torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device),
        )

        # Tests using less than the number of dims, and ellipsis

        # reference is 1 2
        #              3 4
        #              5 6
        reference = consec((3, 2))
        self.assertEqual(
            reference[ri([0, 2]),],
            torch.tensor([[1, 2], [5, 6]], dtype=dtype, device=device),
        )
        self.assertEqual(
            reference[ri([1]), ...], torch.tensor([[3, 4]], dtype=dtype, device=device)
        )
        self.assertEqual(
            reference[..., ri([1])],
            torch.tensor([[2], [4], [6]], dtype=dtype, device=device),
        )

        # verify too many indices fails
        with self.assertRaises(IndexError):
            reference[ri([1]), ri([0, 2]), ri([3])]

        # test invalid index fails
        reference = torch.empty(10, dtype=dtype, device=device)
        # can't test cuda because it is a device assert
        if not reference.is_cuda:
            for err_idx in (10, -11):
                with self.assertRaisesRegex(IndexError, r"out of"):
                    reference[err_idx]
                with self.assertRaisesRegex(IndexError, r"out of"):
                    reference[torch.LongTensor([err_idx]).to(device)]
                with self.assertRaisesRegex(IndexError, r"out of"):
                    reference[[err_idx]]

        def tensor_indices_to_np(tensor, indices):
            # convert the Torch Tensor to a numpy array
            tensor = tensor.to(device="cpu")
            npt = tensor.numpy()

            # convert indices
            idxs = tuple(
                i.tolist() if isinstance(i, torch.LongTensor) else i for i in indices
            )

            return npt, idxs

        def get_numpy(tensor, indices):
            npt, idxs = tensor_indices_to_np(tensor, indices)

            # index and return as a Torch Tensor
            return torch.tensor(npt[idxs], dtype=dtype, device=device)

        def set_numpy(tensor, indices, value):
            if not isinstance(value, int):
                if self.device_type != "cpu":
                    value = value.cpu()
                value = value.numpy()

            npt, idxs = tensor_indices_to_np(tensor, indices)
            npt[idxs] = value
            return npt

        def assert_get_eq(tensor, indexer):
            self.assertEqual(tensor[indexer], get_numpy(tensor, indexer))

        def assert_set_eq(tensor, indexer, val):
            pyt = tensor.clone()
            numt = tensor.clone()
            pyt[indexer] = val
            numt = torch.tensor(
                set_numpy(numt, indexer, val), dtype=dtype, device=device
            )
            self.assertEqual(pyt, numt)

        def assert_backward_eq(tensor, indexer):
            cpu = tensor.float().clone().detach().requires_grad_(True)
            outcpu = cpu[indexer]
            gOcpu = torch.rand_like(outcpu)
            outcpu.backward(gOcpu)
            dev = cpu.to(device).detach().requires_grad_(True)
            outdev = dev[indexer]
            outdev.backward(gOcpu.to(device))
            self.assertEqual(cpu.grad, dev.grad)

        def get_set_tensor(indexed, indexer):
            set_size = indexed[indexer].size()
            set_count = indexed[indexer].numel()
            set_tensor = torch.randperm(set_count).view(set_size).double().to(device)
            return set_tensor

        # Tensor is  0  1  2  3  4
        #            5  6  7  8  9
        #           10 11 12 13 14
        #           15 16 17 18 19
        reference = torch.arange(0.0, 20, dtype=dtype, device=device).view(4, 5)

        indices_to_test = [
            # grab the second, fourth columns
            [slice(None), [1, 3]],
            # first, third rows,
            [[0, 2], slice(None)],
            # weird shape
            [slice(None), [[0, 1], [2, 3]]],
            # negatives
            [[-1], [0]],
            [[0, 2], [-1]],
            [slice(None), [-1]],
        ]

        # only test dupes on gets
        get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]]

        for indexer in get_indices_to_test:
            assert_get_eq(reference, indexer)
            if self.device_type != "cpu":
                assert_backward_eq(reference, indexer)

        for indexer in indices_to_test:
            assert_set_eq(reference, indexer, 44)
            assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))

        reference = torch.arange(0.0, 160, dtype=dtype, device=device).view(4, 8, 5)

        indices_to_test = [
            [slice(None), slice(None), [0, 3, 4]],
            [slice(None), [2, 4, 5, 7], slice(None)],
            [[2, 3], slice(None), slice(None)],
            [slice(None), [0, 2, 3], [1, 3, 4]],
            [slice(None), [0], [1, 2, 4]],
            [slice(None), [0, 1, 3], [4]],
            [slice(None), [[0, 1], [1, 0]], [[2, 3]]],
            [slice(None), [[0, 1], [2, 3]], [[0]]],
            [slice(None), [[5, 6]], [[0, 3], [4, 4]]],
            [[0, 2, 3], [1, 3, 4], slice(None)],
            [[0], [1, 2, 4], slice(None)],
            [[0, 1, 3], [4], slice(None)],
            [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
            [[[0, 1], [1, 0]], [[2, 3]], slice(None)],
            [[[0, 1], [2, 3]], [[0]], slice(None)],
            [[[2, 1]], [[0, 3], [4, 4]], slice(None)],
            [[[2]], [[0, 3], [4, 1]], slice(None)],
            # non-contiguous indexing subspace
            [[0, 2, 3], slice(None), [1, 3, 4]],
            # [...]
            # less dim, ellipsis
            [[0, 2]],
            [[0, 2], slice(None)],
            [[0, 2], Ellipsis],
            [[0, 2], slice(None), Ellipsis],
            [[0, 2], Ellipsis, slice(None)],
            [[0, 2], [1, 3]],
            [[0, 2], [1, 3], Ellipsis],
            [Ellipsis, [1, 3], [2, 3]],
            [Ellipsis, [2, 3, 4]],
            [Ellipsis, slice(None), [2, 3, 4]],
            [slice(None), Ellipsis, [2, 3, 4]],
            # ellipsis counts for nothing
            [Ellipsis, slice(None), slice(None), [0, 3, 4]],
            [slice(None), Ellipsis, slice(None), [0, 3, 4]],
            [slice(None), slice(None), Ellipsis, [0, 3, 4]],
            [slice(None), slice(None), [0, 3, 4], Ellipsis],
            [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
            [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)],
            [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis],
        ]

        for indexer in indices_to_test:
            assert_get_eq(reference, indexer)
            assert_set_eq(reference, indexer, 212)
            assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
            if torch.cuda.is_available():
                assert_backward_eq(reference, indexer)

        reference = torch.arange(0.0, 1296, dtype=dtype, device=device).view(3, 9, 8, 6)

        indices_to_test = [
            [slice(None), slice(None), slice(None), [0, 3, 4]],
            [slice(None), slice(None), [2, 4, 5, 7], slice(None)],
            [slice(None), [2, 3], slice(None), slice(None)],
            [[1, 2], slice(None), slice(None), slice(None)],
            [slice(None), slice(None), [0, 2, 3], [1, 3, 4]],
            [slice(None), slice(None), [0], [1, 2, 4]],
            [slice(None), slice(None), [0, 1, 3], [4]],
            [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]],
            [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]],
            [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]],
            [slice(None), [0, 2, 3], [1, 3, 4], slice(None)],
            [slice(None), [0], [1, 2, 4], slice(None)],
            [slice(None), [0, 1, 3], [4], slice(None)],
            [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)],
            [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)],
            [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)],
            [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)],
            [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)],
            [[0, 1, 2], [1, 3, 4], slice(None), slice(None)],
            [[0], [1, 2, 4], slice(None), slice(None)],
            [[0, 1, 2], [4], slice(None), slice(None)],
            [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)],
            [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)],
            [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)],
            [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)],
            [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],
            [slice(None), [2, 3, 4], [1, 3, 4], [4]],
            [slice(None), [0, 1, 3], [4], [1, 3, 4]],
            [slice(None), [6], [0, 2, 3], [1, 3, 4]],
            [slice(None), [2, 3, 5], [3], [4]],
            [slice(None), [0], [4], [1, 3, 4]],
            [slice(None), [6], [0, 2, 3], [1]],
            [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],
            [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)],
            [[2, 0, 1], [1, 2, 3], [4], slice(None)],
            [[0, 1, 2], [4], [1, 3, 4], slice(None)],
            [[0], [0, 2, 3], [1, 3, 4], slice(None)],
            [[0, 2, 1], [3], [4], slice(None)],
            [[0], [4], [1, 3, 4], slice(None)],
            [[1], [0, 2, 3], [1], slice(None)],
            [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],
            # less dim, ellipsis
            [Ellipsis, [0, 3, 4]],
            [Ellipsis, slice(None), [0, 3, 4]],
            [Ellipsis, slice(None), slice(None), [0, 3, 4]],
            [slice(None), Ellipsis, [0, 3, 4]],
            [slice(None), slice(None), Ellipsis, [0, 3, 4]],
            [slice(None), [0, 2, 3], [1, 3, 4]],
            [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],
            [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],
            [[0], [1, 2, 4]],
            [[0], [1, 2, 4], slice(None)],
            [[0], [1, 2, 4], Ellipsis],
            [[0], [1, 2, 4], Ellipsis, slice(None)],
            [[1]],
            [[0, 2, 1], [3], [4]],
            [[0, 2, 1], [3], [4], slice(None)],
            [[0, 2, 1], [3], [4], Ellipsis],
            [Ellipsis, [0, 2, 1], [3], [4]],
        ]

        for indexer in indices_to_test:
            assert_get_eq(reference, indexer)
            assert_set_eq(reference, indexer, 1333)
            assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
        indices_to_test += [
            [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],
            [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]],
        ]
        for indexer in indices_to_test:
            assert_get_eq(reference, indexer)
            assert_set_eq(reference, indexer, 1333)
            if self.device_type != "cpu":
                assert_backward_eq(reference, indexer)

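    # Illustrative sketch: the integer-array ("advanced") indexing semantics
    # exercised above. Index tensors are broadcast against each other and the
    # result gathers one element per broadcast position, matching NumPy. The
    # helper name `_sketch_advanced_indexing_gather` is illustrative only and is
    # not collected by unittest.
    def _sketch_advanced_indexing_gather(self, device):
        t = torch.arange(6, device=device).view(3, 2)  # [[0, 1], [2, 3], [4, 5]]
        rows = torch.tensor([0, 2], device=device)
        cols = torch.tensor([1, 0], device=device)
        # Element-wise pairing of the index tensors picks t[0, 1] and t[2, 0].
        self.assertEqual(t[rows, cols], torch.tensor([1, 4], device=device))
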
    def test_advancedindex_big(self, device):
        reference = torch.arange(0, 123344, dtype=torch.int, device=device)

        self.assertEqual(
            reference[[0, 123, 44488, 68807, 123343],],
            torch.tensor([0, 123, 44488, 68807, 123343], dtype=torch.int),
        )

    def test_set_item_to_scalar_tensor(self, device):
        m = random.randint(1, 10)
        n = random.randint(1, 10)
        z = torch.randn([m, n], device=device)
        a = 1.0
        w = torch.tensor(a, requires_grad=True, device=device)
        z[:, 0] = w
        z.sum().backward()
        self.assertEqual(w.grad, m * a)

    def test_single_int(self, device):
        v = torch.randn(5, 7, 3, device=device)
        self.assertEqual(v[4].shape, (7, 3))

    def test_multiple_int(self, device):
        v = torch.randn(5, 7, 3, device=device)
        self.assertEqual(v[4].shape, (7, 3))
        self.assertEqual(v[4, :, 1].shape, (7,))

    def test_none(self, device):
        v = torch.randn(5, 7, 3, device=device)
        self.assertEqual(v[None].shape, (1, 5, 7, 3))
        self.assertEqual(v[:, None].shape, (5, 1, 7, 3))
        self.assertEqual(v[:, None, None].shape, (5, 1, 1, 7, 3))
        self.assertEqual(v[..., None].shape, (5, 7, 3, 1))

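    # Illustrative sketch: each `None` in an index inserts a size-1 dimension at
    # that slot, which is why the shapes checked in test_none match `unsqueeze`.
    # The helper name `_sketch_none_vs_unsqueeze` is illustrative only and is
    # not collected by unittest.
    def _sketch_none_vs_unsqueeze(self, device):
        t = torch.randn(5, 7, 3, device=device)
        self.assertEqual(t[None].shape, t.unsqueeze(0).shape)
        self.assertEqual(t[:, None].shape, t.unsqueeze(1).shape)
        self.assertEqual(t[..., None].shape, t.unsqueeze(-1).shape)
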
    def test_step(self, device):
        v = torch.arange(10, device=device)
        self.assertEqual(v[::1], v)
        self.assertEqual(v[::2].tolist(), [0, 2, 4, 6, 8])
        self.assertEqual(v[::3].tolist(), [0, 3, 6, 9])
        self.assertEqual(v[::11].tolist(), [0])
        self.assertEqual(v[1:6:2].tolist(), [1, 3, 5])

    def test_step_assignment(self, device):
        v = torch.zeros(4, 4, device=device)
        v[0, 1::2] = torch.tensor([3.0, 4.0], device=device)
        self.assertEqual(v[0].tolist(), [0, 3, 0, 4])
        self.assertEqual(v[1:].sum(), 0)

    def test_bool_indices(self, device):
        v = torch.randn(5, 7, 3, device=device)
        boolIndices = torch.tensor(
            [True, False, True, True, False], dtype=torch.bool, device=device
        )
        self.assertEqual(v[boolIndices].shape, (3, 7, 3))
        self.assertEqual(v[boolIndices], torch.stack([v[0], v[2], v[3]]))

        v = torch.tensor([True, False, True], dtype=torch.bool, device=device)
        boolIndices = torch.tensor(
            [True, False, False], dtype=torch.bool, device=device
        )
        uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8, device=device)
        with warnings.catch_warnings(record=True) as w:
            v1 = v[boolIndices]
            v2 = v[uint8Indices]
            self.assertEqual(v1.shape, v2.shape)
            self.assertEqual(v1, v2)
            self.assertEqual(
                v[boolIndices], tensor([True], dtype=torch.bool, device=device)
            )
            self.assertEqual(len(w), 1)

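    # Illustrative sketch: boolean-mask indexing is equivalent to integer
    # indexing with the mask's nonzero positions, which is the equivalence the
    # test above relies on. The helper name `_sketch_bool_mask_vs_nonzero` is
    # illustrative only and is not collected by unittest.
    def _sketch_bool_mask_vs_nonzero(self, device):
        v = torch.randn(5, 7, 3, device=device)
        mask = torch.tensor([True, False, True, True, False], device=device)
        # Selecting with the mask picks the rows where it is True, the same as
        # indexing with the mask's nonzero indices (here tensor([0, 2, 3])).
        self.assertEqual(v[mask], v[mask.nonzero(as_tuple=True)[0]])
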
    def test_bool_indices_accumulate(self, device):
        mask = torch.zeros(size=(10,), dtype=torch.bool, device=device)
        y = torch.ones(size=(10, 10), device=device)
        y.index_put_((mask,), y[mask], accumulate=True)
        self.assertEqual(y, torch.ones(size=(10, 10), device=device))

    def test_multiple_bool_indices(self, device):
        v = torch.randn(5, 7, 3, device=device)
        # note: these broadcast together and are transposed to the first dim
        mask1 = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool, device=device)
        mask2 = torch.tensor([1, 1, 1], dtype=torch.bool, device=device)
        self.assertEqual(v[mask1, :, mask2].shape, (3, 7))

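    # Illustrative sketch: why the shape above is (3, 7). Each bool mask is
    # converted to its nonzero indices (3 positions along dim 0, 3 along dim 2),
    # those index tensors broadcast elementwise to length 3, and because the
    # advanced indices are separated by a slice the indexed dimension is moved
    # to the front, ahead of the untouched middle dimension of size 7. The
    # helper name `_sketch_multiple_bool_masks` is illustrative only and is not
    # collected by unittest.
    def _sketch_multiple_bool_masks(self, device):
        v = torch.randn(5, 7, 3, device=device)
        mask1 = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool, device=device)
        mask2 = torch.tensor([1, 1, 1], dtype=torch.bool, device=device)
        idx1 = mask1.nonzero(as_tuple=True)[0]  # tensor([0, 2, 3])
        idx2 = mask2.nonzero(as_tuple=True)[0]  # tensor([0, 1, 2])
        self.assertEqual(v[mask1, :, mask2], v[idx1, :, idx2])
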
    def test_byte_mask(self, device):
        v = torch.randn(5, 7, 3, device=device)
        mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
        with warnings.catch_warnings(record=True) as w:
            res = v[mask]
            self.assertEqual(res.shape, (3, 7, 3))
            self.assertEqual(res, torch.stack([v[0], v[2], v[3]]))
            self.assertEqual(len(w), 1)

        v = torch.tensor([1.0], device=device)
        self.assertEqual(v[v == 0], torch.tensor([], device=device))

    def test_byte_mask_accumulate(self, device):
        mask = torch.zeros(size=(10,), dtype=torch.uint8, device=device)
        y = torch.ones(size=(10, 10), device=device)
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            y.index_put_((mask,), y[mask], accumulate=True)
            self.assertEqual(y, torch.ones(size=(10, 10), device=device))
            self.assertEqual(len(w), 2)

    @skipIfTorchDynamo(
        "This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472"
    )
    @serialTest(TEST_CUDA)
    def test_index_put_accumulate_large_tensor(self, device):
        # This test is for tensors with number of elements >= INT_MAX (2^31 - 1).
        N = (1 << 31) + 5
        dt = torch.int8
        a = torch.ones(N, dtype=dt, device=device)
        indices = torch.tensor(
            [-2, 0, -2, -1, 0, -1, 1], device=device, dtype=torch.long
        )
        values = torch.tensor([6, 5, 6, 6, 5, 7, 11], dtype=dt, device=device)

        a.index_put_((indices,), values, accumulate=True)

        self.assertEqual(a[0], 11)
        self.assertEqual(a[1], 12)
        self.assertEqual(a[2], 1)
        self.assertEqual(a[-3], 1)
        self.assertEqual(a[-2], 13)
        self.assertEqual(a[-1], 14)

        a = torch.ones((2, N), dtype=dt, device=device)
        indices0 = torch.tensor([0, -1, 0, 1], device=device, dtype=torch.long)
        indices1 = torch.tensor([-2, -1, 0, 1], device=device, dtype=torch.long)
        values = torch.tensor([12, 13, 10, 11], dtype=dt, device=device)

        a.index_put_((indices0, indices1), values, accumulate=True)

        self.assertEqual(a[0, 0], 11)
        self.assertEqual(a[0, 1], 1)
        self.assertEqual(a[1, 0], 1)
        self.assertEqual(a[1, 1], 12)
        self.assertEqual(a[:, 2], torch.ones(2, dtype=torch.int8))
        self.assertEqual(a[:, -3], torch.ones(2, dtype=torch.int8))
        self.assertEqual(a[0, -2], 13)
        self.assertEqual(a[1, -2], 1)
        self.assertEqual(a[-1, -1], 14)
        self.assertEqual(a[0, -1], 1)

    @onlyNativeDeviceTypes
    def test_index_put_accumulate_expanded_values(self, device):
        # checks the issue with cuda: https://github.com/pytorch/pytorch/issues/39227
        # and verifies consistency with CPU result
        t = torch.zeros((5, 2))
        t_dev = t.to(device)
        indices = [torch.tensor([0, 1, 2, 3]), torch.tensor([1])]
        indices_dev = [i.to(device) for i in indices]
        values0d = torch.tensor(1.0)
        values1d = torch.tensor([1.0])

        out_cuda = t_dev.index_put_(indices_dev, values0d.to(device), accumulate=True)
        out_cpu = t.index_put_(indices, values0d, accumulate=True)
        self.assertEqual(out_cuda.cpu(), out_cpu)

        out_cuda = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
        out_cpu = t.index_put_(indices, values1d, accumulate=True)
        self.assertEqual(out_cuda.cpu(), out_cpu)

        t = torch.zeros(4, 3, 2)
        t_dev = t.to(device)

        indices = [
            torch.tensor([0]),
            torch.arange(3)[:, None],
            torch.arange(2)[None, :],
        ]
        indices_dev = [i.to(device) for i in indices]
        values1d = torch.tensor([-1.0, -2.0])
        values2d = torch.tensor([[-1.0, -2.0]])

        out_cuda = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
        out_cpu = t.index_put_(indices, values1d, accumulate=True)
        self.assertEqual(out_cuda.cpu(), out_cpu)

        out_cuda = t_dev.index_put_(indices_dev, values2d.to(device), accumulate=True)
        out_cpu = t.index_put_(indices, values2d, accumulate=True)
        self.assertEqual(out_cuda.cpu(), out_cpu)

    @onlyCUDA
    def test_index_put_accumulate_non_contiguous(self, device):
        t = torch.zeros((5, 2, 2))
        t_dev = t.to(device)
        t1 = t_dev[:, 0, :]
        t2 = t[:, 0, :]
        self.assertTrue(not t1.is_contiguous())
        self.assertTrue(not t2.is_contiguous())

        indices = [torch.tensor([0, 1])]
        indices_dev = [i.to(device) for i in indices]
        value = torch.randn(2, 2)
        out_cuda = t1.index_put_(indices_dev, value.to(device), accumulate=True)
        out_cpu = t2.index_put_(indices, value, accumulate=True)
        self.assertTrue(not t1.is_contiguous())
        self.assertTrue(not t2.is_contiguous())

        self.assertEqual(out_cuda.cpu(), out_cpu)

    @onlyCUDA
    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_index_put_accumulate_with_optional_tensors(self, device):
        # TODO: replace with a better solution.
        # Currently we use TorchScript to put None into the indices. On the C++
        # side this gives the indices as a list of 2 optional tensors: the first
        # is null and the second is a valid tensor.
        @torch.jit.script
        def func(x, i, v):
            idx = [None, i]
            x.index_put_(idx, v, accumulate=True)
            return x

        n = 4
        t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
        t_dev = t.to(device)
        indices = torch.tensor([1, 0])
        indices_dev = indices.to(device)
        value0d = torch.tensor(10.0)
        value1d = torch.tensor([1.0, 2.0])

        out_cuda = func(t_dev, indices_dev, value0d.cuda())
        out_cpu = func(t, indices, value0d)
        self.assertEqual(out_cuda.cpu(), out_cpu)

        out_cuda = func(t_dev, indices_dev, value1d.cuda())
        out_cpu = func(t, indices, value1d)
        self.assertEqual(out_cuda.cpu(), out_cpu)

    @onlyNativeDeviceTypes
    def test_index_put_accumulate_duplicate_indices(self, device):
        for i in range(1, 512):
            # generate indices by random walk, this will create indices with
            # lots of duplicates interleaved with each other
            delta = torch.empty(i, dtype=torch.double, device=device).uniform_(-1, 1)
            indices = delta.cumsum(0).long()

            input = torch.randn(indices.abs().max() + 1, device=device)
            values = torch.randn(indices.size(0), device=device)
            output = input.index_put((indices,), values, accumulate=True)

            input_list = input.tolist()
            indices_list = indices.tolist()
            values_list = values.tolist()
            for i, v in zip(indices_list, values_list):
                input_list[i] += v

            self.assertEqual(output, input_list)

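    # Illustrative sketch: the `accumulate=True` semantics checked above. With
    # accumulation every occurrence of a duplicate index contributes its value,
    # instead of the last write winning. The helper name
    # `_sketch_index_put_accumulate` is illustrative only and is not collected
    # by unittest.
    def _sketch_index_put_accumulate(self, device):
        t = torch.zeros(3, device=device)
        idx = torch.tensor([0, 0, 2], device=device)
        vals = torch.tensor([1.0, 2.0, 3.0], device=device)
        # Index 0 appears twice, so its contributions are summed: 1.0 + 2.0.
        out = t.index_put((idx,), vals, accumulate=True)
        self.assertEqual(out, torch.tensor([3.0, 0.0, 3.0], device=device))
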
    @onlyNativeDeviceTypes
    def test_index_ind_dtype(self, device):
        x = torch.randn(4, 4, device=device)
        ind_long = torch.randint(4, (4,), dtype=torch.long, device=device)
        ind_int = ind_long.int()
        src = torch.randn(4, device=device)
        ref = x[ind_long, ind_long]
        res = x[ind_int, ind_int]
        self.assertEqual(ref, res)
        ref = x[ind_long, :]
        res = x[ind_int, :]
        self.assertEqual(ref, res)
        ref = x[:, ind_long]
        res = x[:, ind_int]
        self.assertEqual(ref, res)
        # no repeating indices for index_put
        ind_long = torch.arange(4, dtype=torch.long, device=device)
        ind_int = ind_long.int()
        for accum in (True, False):
            inp_ref = x.clone()
            inp_res = x.clone()
            torch.index_put_(inp_ref, (ind_long, ind_long), src, accum)
            torch.index_put_(inp_res, (ind_int, ind_int), src, accum)
            self.assertEqual(inp_ref, inp_res)

    @skipXLA
    def test_index_put_accumulate_empty(self, device):
        # Regression test for https://github.com/pytorch/pytorch/issues/94667
        input = torch.rand([], dtype=torch.float32, device=device)
        with self.assertRaises(RuntimeError):
            input.index_put([], torch.tensor([1.0], device=device), True)

    def test_multiple_byte_mask(self, device):
        v = torch.randn(5, 7, 3, device=device)
        # note: these broadcast together and are transposed to the first dim
        mask1 = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
        mask2 = torch.ByteTensor([1, 1, 1]).to(device)
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
            self.assertEqual(len(w), 2)

    def test_byte_mask2d(self, device):
        v = torch.randn(5, 7, 3, device=device)
        c = torch.randn(5, 7, device=device)
        num_ones = (c > 0).sum()
        r = v[c > 0]
        self.assertEqual(r.shape, (num_ones, 3))

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_jit_indexing(self, device):
        def fn1(x):
            x[x < 50] = 1.0
            return x

        def fn2(x):
            x[0:50] = 1.0
            return x

        scripted_fn1 = torch.jit.script(fn1)
        scripted_fn2 = torch.jit.script(fn2)
        data = torch.arange(100, device=device, dtype=torch.float)
        out = scripted_fn1(data.detach().clone())
        ref = torch.tensor(
            np.concatenate((np.ones(50), np.arange(50, 100))),
            device=device,
            dtype=torch.float,
        )
        self.assertEqual(out, ref)
        out = scripted_fn2(data.detach().clone())
        self.assertEqual(out, ref)

    def test_int_indices(self, device):
        v = torch.randn(5, 7, 3, device=device)
        self.assertEqual(v[[0, 4, 2]].shape, (3, 7, 3))
        self.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3))
        self.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))

    @dtypes(
        torch.cfloat, torch.cdouble, torch.float, torch.bfloat16, torch.long, torch.bool
    )
    @dtypesIfCPU(
        torch.cfloat, torch.cdouble, torch.float, torch.long, torch.bool, torch.bfloat16
    )
    @dtypesIfCUDA(
        torch.cfloat,
        torch.cdouble,
        torch.half,
        torch.long,
        torch.bool,
        torch.bfloat16,
        torch.float8_e5m2,
        torch.float8_e4m3fn,
    )
    def test_index_put_src_datatype(self, device, dtype):
        src = torch.ones(3, 2, 4, device=device, dtype=dtype)
        vals = torch.ones(3, 2, 4, device=device, dtype=dtype)
        indices = (torch.tensor([0, 2, 1]),)
        res = src.index_put_(indices, vals, accumulate=True)
        self.assertEqual(res.shape, src.shape)

    @dtypes(torch.float, torch.bfloat16, torch.long, torch.bool)
    @dtypesIfCPU(torch.float, torch.long, torch.bfloat16, torch.bool)
    @dtypesIfCUDA(torch.half, torch.long, torch.bfloat16, torch.bool)
    def test_index_src_datatype(self, device, dtype):
        src = torch.ones(3, 2, 4, device=device, dtype=dtype)
        # test index
        res = src[[0, 2, 1], :, :]
        self.assertEqual(res.shape, src.shape)
        # test index_put, no accum
        src[[0, 2, 1], :, :] = res
        self.assertEqual(res.shape, src.shape)

    def test_int_indices2d(self, device):
        # From the NumPy indexing example
        x = torch.arange(0, 12, device=device).view(4, 3)
        rows = torch.tensor([[0, 0], [3, 3]], device=device)
        columns = torch.tensor([[0, 2], [0, 2]], device=device)
        self.assertEqual(x[rows, columns].tolist(), [[0, 2], [9, 11]])

    def test_int_indices_broadcast(self, device):
        # From the NumPy indexing example
        x = torch.arange(0, 12, device=device).view(4, 3)
        rows = torch.tensor([0, 3], device=device)
        columns = torch.tensor([0, 2], device=device)
        result = x[rows[:, None], columns]
        self.assertEqual(result.tolist(), [[0, 2], [9, 11]])

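    # Illustrative sketch: the broadcast used in test_int_indices_broadcast.
    # `rows[:, None]` has shape (2, 1) and `columns` has shape (2,), so they
    # broadcast to (2, 2) and select every (row, column) combination. The helper
    # name `_sketch_index_broadcast` is illustrative only and is not collected
    # by unittest.
    def _sketch_index_broadcast(self, device):
        x = torch.arange(0, 12, device=device).view(4, 3)
        rows = torch.tensor([0, 3], device=device)
        columns = torch.tensor([0, 2], device=device)
        # Equivalent to spelling out the broadcast (2, 2) index grids by hand.
        expected = x[
            torch.tensor([[0, 0], [3, 3]], device=device),
            torch.tensor([[0, 2], [0, 2]], device=device),
        ]
        self.assertEqual(x[rows[:, None], columns], expected)
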
    def test_empty_index(self, device):
        x = torch.arange(0, 12, device=device).view(4, 3)
        idx = torch.tensor([], dtype=torch.long, device=device)
        self.assertEqual(x[idx].numel(), 0)

        # empty assignment should have no effect but not throw an exception
        y = x.clone()
        y[idx] = -1
        self.assertEqual(x, y)

        mask = torch.zeros(4, 3, device=device).bool()
        y[mask] = -1
        self.assertEqual(x, y)

    def test_empty_ndim_index(self, device):
        x = torch.randn(5, device=device)
        self.assertEqual(
            torch.empty(0, 2, device=device),
            x[torch.empty(0, 2, dtype=torch.int64, device=device)],
        )

        x = torch.randn(2, 3, 4, 5, device=device)
        self.assertEqual(
            torch.empty(2, 0, 6, 4, 5, device=device),
            x[:, torch.empty(0, 6, dtype=torch.int64, device=device)],
        )

        x = torch.empty(10, 0, device=device)
        self.assertEqual(x[[1, 2]].shape, (2, 0))
        self.assertEqual(x[[], []].shape, (0,))
        with self.assertRaisesRegex(IndexError, "for dimension with size 0"):
            x[:, [0, 1]]

    def test_empty_ndim_index_bool(self, device):
        x = torch.randn(5, device=device)
        self.assertRaises(
            IndexError, lambda: x[torch.empty(0, 2, dtype=torch.uint8, device=device)]
        )

    def test_empty_slice(self, device):
        x = torch.randn(2, 3, 4, 5, device=device)
        y = x[:, :, :, 1]
        z = y[:, 1:1, :]
        self.assertEqual((2, 0, 4), z.shape)
        # this isn't technically necessary, but matches NumPy stride calculations.
        self.assertEqual((60, 20, 5), z.stride())
        self.assertTrue(z.is_contiguous())

    def test_index_getitem_copy_bools_slices(self, device):
        true = torch.tensor(1, dtype=torch.uint8, device=device)
        false = torch.tensor(0, dtype=torch.uint8, device=device)

        tensors = [torch.randn(2, 3, device=device), torch.tensor(3.0, device=device)]

        for a in tensors:
            self.assertNotEqual(a.data_ptr(), a[True].data_ptr())
            self.assertEqual(torch.empty(0, *a.shape), a[False])
            self.assertNotEqual(a.data_ptr(), a[true].data_ptr())
            self.assertEqual(torch.empty(0, *a.shape), a[false])
            self.assertEqual(a.data_ptr(), a[None].data_ptr())
            self.assertEqual(a.data_ptr(), a[...].data_ptr())

    def test_index_setitem_bools_slices(self, device):
        true = torch.tensor(1, dtype=torch.uint8, device=device)
        false = torch.tensor(0, dtype=torch.uint8, device=device)

        tensors = [torch.randn(2, 3, device=device), torch.tensor(3, device=device)]

        for a in tensors:
            # prefix with 1, 1 to ensure we are compatible with numpy, which cuts off prefix 1s
1249            # (some of these ops already prefix a 1 to the size)
1250            neg_ones = torch.ones_like(a) * -1
1251            neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0)
1252            a[True] = neg_ones_expanded
1253            self.assertEqual(a, neg_ones)
1254            a[False] = 5
1255            self.assertEqual(a, neg_ones)
1256            a[true] = neg_ones_expanded * 2
1257            self.assertEqual(a, neg_ones * 2)
1258            a[false] = 5
1259            self.assertEqual(a, neg_ones * 2)
1260            a[None] = neg_ones_expanded * 3
1261            self.assertEqual(a, neg_ones * 3)
1262            a[...] = neg_ones_expanded * 4
1263            self.assertEqual(a, neg_ones * 4)
1264            if a.dim() == 0:
1265                with self.assertRaises(IndexError):
1266                    a[:] = neg_ones_expanded * 5
1267
1268    def test_index_scalar_with_bool_mask(self, device):
1269        a = torch.tensor(1, device=device)
1270        uintMask = torch.tensor(True, dtype=torch.uint8, device=device)
1271        boolMask = torch.tensor(True, dtype=torch.bool, device=device)
1272        self.assertEqual(a[uintMask], a[boolMask])
1273        self.assertEqual(a[uintMask].dtype, a[boolMask].dtype)
1274
1275        a = torch.tensor(True, dtype=torch.bool, device=device)
1276        self.assertEqual(a[uintMask], a[boolMask])
1277        self.assertEqual(a[uintMask].dtype, a[boolMask].dtype)
1278
1279    def test_setitem_expansion_error(self, device):
1280        true = torch.tensor(True, device=device)
1281        a = torch.randn(2, 3, device=device)
1282        # check prefix with  non-1s doesn't work
1283        a_expanded = a.expand(torch.Size([5, 1]) + a.size())
1284        # NumPy: ValueError
1285        with self.assertRaises(RuntimeError):
1286            a[True] = a_expanded
1287        with self.assertRaises(RuntimeError):
1288            a[true] = a_expanded
1289
1290    def test_getitem_scalars(self, device):
1291        zero = torch.tensor(0, dtype=torch.int64, device=device)
1292        one = torch.tensor(1, dtype=torch.int64, device=device)
1293
1294        # non-scalar indexed with scalars
1295        a = torch.randn(2, 3, device=device)
1296        self.assertEqual(a[0], a[zero])
1297        self.assertEqual(a[0][1], a[zero][one])
1298        self.assertEqual(a[0, 1], a[zero, one])
1299        self.assertEqual(a[0, one], a[zero, 1])
1300
1301        # indexing by a scalar should slice (not copy)
1302        self.assertEqual(a[0, 1].data_ptr(), a[zero, one].data_ptr())
1303        self.assertEqual(a[1].data_ptr(), a[one.int()].data_ptr())
1304        self.assertEqual(a[1].data_ptr(), a[one.short()].data_ptr())
1305
1306        # scalar indexed with scalar
1307        r = torch.randn((), device=device)
1308        with self.assertRaises(IndexError):
1309            r[:]
1310        with self.assertRaises(IndexError):
1311            r[zero]
1312        self.assertEqual(r, r[...])
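        # A 0-dim integer tensor index behaves like a Python int: it selects along
        # a dimension and returns a view sharing storage, which is why the
        # data_ptr() comparisons above hold.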
1313
1314    def test_setitem_scalars(self, device):
1315        zero = torch.tensor(0, dtype=torch.int64)
1316
1317        # non-scalar indexed with scalars
1318        a = torch.randn(2, 3, device=device)
1319        a_set_with_number = a.clone()
1320        a_set_with_scalar = a.clone()
1321        b = torch.randn(3, device=device)
1322
1323        a_set_with_number[0] = b
1324        a_set_with_scalar[zero] = b
1325        self.assertEqual(a_set_with_number, a_set_with_scalar)
1326        a[1, zero] = 7.7
1327        self.assertEqual(7.7, a[1, 0])
1328
1329        # scalar indexed with scalars
1330        r = torch.randn((), device=device)
1331        with self.assertRaises(IndexError):
1332            r[:] = 8.8
1333        with self.assertRaises(IndexError):
1334            r[zero] = 8.8
1335        r[...] = 9.9
1336        self.assertEqual(9.9, r)
1337
1338    def test_basic_advanced_combined(self, device):
1339        # From the NumPy indexing example
1340        x = torch.arange(0, 12, device=device).view(4, 3)
1341        self.assertEqual(x[1:2, 1:3], x[1:2, [1, 2]])
1342        self.assertEqual(x[1:2, 1:3].tolist(), [[4, 5]])
1343
1344        # Check that it is a copy
1345        unmodified = x.clone()
1346        x[1:2, [1, 2]].zero_()
1347        self.assertEqual(x, unmodified)
1348
1349        # But assignment should modify the original
1350        unmodified = x.clone()
1351        x[1:2, [1, 2]] = 0
1352        self.assertNotEqual(x, unmodified)
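        # Mixing basic (slice) and advanced (list/tensor) indexing returns a copy
        # when reading, e.g. roughly:
        #   y = x[1:2, [1, 2]]   # new tensor, zeroing y leaves x untouched
        #   x[1:2, [1, 2]] = 0   # __setitem__ writes through to x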
1353
1354    def test_int_assignment(self, device):
1355        x = torch.arange(0, 4, device=device).view(2, 2)
1356        x[1] = 5
1357        self.assertEqual(x.tolist(), [[0, 1], [5, 5]])
1358
1359        x = torch.arange(0, 4, device=device).view(2, 2)
1360        x[1] = torch.arange(5, 7, device=device)
1361        self.assertEqual(x.tolist(), [[0, 1], [5, 6]])
1362
1363    def test_byte_tensor_assignment(self, device):
1364        x = torch.arange(0.0, 16, device=device).view(4, 4)
1365        b = torch.ByteTensor([True, False, True, False]).to(device)
1366        value = torch.tensor([3.0, 4.0, 5.0, 6.0], device=device)
1367
1368        with warnings.catch_warnings(record=True) as w:
1369            x[b] = value
1370            self.assertEqual(len(w), 1)
1371
1372        self.assertEqual(x[0], value)
1373        self.assertEqual(x[1], torch.arange(4.0, 8, device=device))
1374        self.assertEqual(x[2], value)
1375        self.assertEqual(x[3], torch.arange(12.0, 16, device=device))
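        # Indexing with a uint8 (byte) mask is deprecated in favor of bool masks,
        # which is why exactly one warning is expected above; the masked rows are
        # overwritten with `value` while the others keep their arange contents.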
1376
1377    def test_variable_slicing(self, device):
1378        x = torch.arange(0, 16, device=device).view(4, 4)
1379        indices = torch.IntTensor([0, 1]).to(device)
1380        i, j = indices
1381        self.assertEqual(x[i:j], x[0:1])
1382
1383    def test_ellipsis_tensor(self, device):
1384        x = torch.arange(0, 9, device=device).view(3, 3)
1385        idx = torch.tensor([0, 2], device=device)
1386        self.assertEqual(x[..., idx].tolist(), [[0, 2], [3, 5], [6, 8]])
1387        self.assertEqual(x[idx, ...].tolist(), [[0, 1, 2], [6, 7, 8]])
1388
1389    def test_unravel_index_errors(self, device):
1390        with self.assertRaisesRegex(TypeError, r"expected 'indices' to be integer"):
1391            torch.unravel_index(torch.tensor(0.5, device=device), (2, 2))
1392
1393        with self.assertRaisesRegex(TypeError, r"expected 'indices' to be integer"):
1394            torch.unravel_index(torch.tensor([], device=device), (10, 3, 5))
1395
1396        with self.assertRaisesRegex(
1397            TypeError, r"expected 'shape' to be int or sequence"
1398        ):
1399            torch.unravel_index(
1400                torch.tensor([1], device=device, dtype=torch.int64),
1401                torch.tensor([1, 2, 3]),
1402            )
1403
1404        with self.assertRaisesRegex(
1405            TypeError, r"expected 'shape' sequence to only contain ints"
1406        ):
1407            torch.unravel_index(
1408                torch.tensor([1], device=device, dtype=torch.int64), (1, 2, 2.0)
1409            )
1410
1411        with self.assertRaisesRegex(
1412            ValueError, r"'shape' cannot have negative values, but got \(2, -3\)"
1413        ):
1414            torch.unravel_index(torch.tensor(0, device=device), (2, -3))
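        # For reference, a valid call returns a tuple of coordinate tensors,
        # e.g. roughly torch.unravel_index(torch.tensor(7), (3, 4)) gives
        # (tensor(1), tensor(3)), since 7 == 1 * 4 + 3.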
1415
1416    def test_invalid_index(self, device):
1417        x = torch.arange(0, 16, device=device).view(4, 4)
1418        self.assertRaisesRegex(TypeError, "slice indices", lambda: x["0":"1"])
1419
1420    def test_out_of_bound_index(self, device):
1421        x = torch.arange(0, 100, device=device).view(2, 5, 10)
1422        self.assertRaisesRegex(
1423            IndexError,
1424            "index 5 is out of bounds for dimension 1 with size 5",
1425            lambda: x[0, 5],
1426        )
1427        self.assertRaisesRegex(
1428            IndexError,
1429            "index 4 is out of bounds for dimension 0 with size 2",
1430            lambda: x[4, 5],
1431        )
1432        self.assertRaisesRegex(
1433            IndexError,
1434            "index 15 is out of bounds for dimension 2 with size 10",
1435            lambda: x[0, 1, 15],
1436        )
1437        self.assertRaisesRegex(
1438            IndexError,
1439            "index 12 is out of bounds for dimension 2 with size 10",
1440            lambda: x[:, :, 12],
1441        )
1442
1443    def test_zero_dim_index(self, device):
1444        x = torch.tensor(10, device=device)
1445        self.assertEqual(x, x.item())
1446
1447        def runner():
1448            print(x[0])
1449            return x[0]
1450
1451        self.assertRaisesRegex(IndexError, "invalid index", runner)
1452
1453    @onlyCUDA
1454    def test_invalid_device(self, device):
1455        idx = torch.tensor([0, 1])
1456        b = torch.zeros(5, device=device)
1457        c = torch.tensor([1.0, 2.0], device="cpu")
1458
1459        for accumulate in [True, False]:
1460            self.assertRaises(
1461                RuntimeError,
1462                lambda: torch.index_put_(b, (idx,), c, accumulate=accumulate),
1463            )
1464
1465    @onlyCUDA
1466    def test_cpu_indices(self, device):
1467        idx = torch.tensor([0, 1])
1468        b = torch.zeros(2, device=device)
1469        x = torch.ones(10, device=device)
1470        x[idx] = b  # index_put_
1471        ref = torch.ones(10, device=device)
1472        ref[:2] = 0
1473        self.assertEqual(x, ref, atol=0, rtol=0)
1474        out = x[idx]  # index
1475        self.assertEqual(out, torch.zeros(2, device=device), atol=0, rtol=0)
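        # Unlike gather, the generic indexing paths accept CPU index tensors for a
        # CUDA self tensor, so both the index_put_ and the read above work with
        # `idx` left on the CPU.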
1476
1477    @dtypes(torch.long, torch.float32)
1478    def test_take_along_dim(self, device, dtype):
1479        def _test_against_numpy(t, indices, dim):
1480            actual = torch.take_along_dim(t, indices, dim=dim)
1481            t_np = t.cpu().numpy()
1482            indices_np = indices.cpu().numpy()
1483            expected = np.take_along_axis(t_np, indices_np, axis=dim)
1484            self.assertEqual(actual, expected, atol=0, rtol=0)
1485
1486        for shape in [(3, 2), (2, 3, 5), (2, 4, 0), (2, 3, 1, 4)]:
1487            for noncontiguous in [True, False]:
1488                t = make_tensor(
1489                    shape, device=device, dtype=dtype, noncontiguous=noncontiguous
1490                )
1491                for dim in list(range(t.ndim)) + [None]:
1492                    if dim is None:
1493                        indices = torch.argsort(t.view(-1))
1494                    else:
1495                        indices = torch.argsort(t, dim=dim)
1496
1497                    _test_against_numpy(t, indices, dim)
1498
1499        # test broadcasting
1500        t = torch.ones((3, 4, 1), device=device)
1501        indices = torch.ones((1, 2, 5), dtype=torch.long, device=device)
1502
1503        _test_against_numpy(t, indices, 1)
1504
1505        # test empty indices
1506        t = torch.ones((3, 4, 5), device=device)
1507        indices = torch.ones((3, 0, 5), dtype=torch.long, device=device)
1508
1509        _test_against_numpy(t, indices, 1)
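        # take_along_dim mirrors np.take_along_axis: it picks one element per
        # position of `indices` along `dim`, e.g. roughly
        #   t = torch.tensor([[10, 30, 20]])
        #   torch.take_along_dim(t, torch.argsort(t, dim=1), dim=1)  # -> [[10, 20, 30]]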
1510
1511    @dtypes(torch.long, torch.float)
1512    def test_take_along_dim_invalid(self, device, dtype):
1513        shape = (2, 3, 1, 4)
1514        dim = 0
1515        t = make_tensor(shape, device=device, dtype=dtype)
1516        indices = torch.argsort(t, dim=dim)
1517
1518        # dim of `t` and `indices` does not match
1519        with self.assertRaisesRegex(
1520            RuntimeError, "input and indices should have the same number of dimensions"
1521        ):
1522            torch.take_along_dim(t, indices[0], dim=0)
1523
1524        # invalid `indices` dtype
1525        with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"):
1526            torch.take_along_dim(t, indices.to(torch.bool), dim=0)
1527
1528        with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"):
1529            torch.take_along_dim(t, indices.to(torch.float), dim=0)
1530
1531        with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"):
1532            torch.take_along_dim(t, indices.to(torch.int32), dim=0)
1533
1534        # invalid axis
1535        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
1536            torch.take_along_dim(t, indices, dim=-7)
1537
1538        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
1539            torch.take_along_dim(t, indices, dim=7)
1540
1541    @onlyCUDA
1542    @dtypes(torch.float)
1543    def test_gather_take_along_dim_cross_device(self, device, dtype):
1544        shape = (2, 3, 1, 4)
1545        dim = 0
1546        t = make_tensor(shape, device=device, dtype=dtype)
1547        indices = torch.argsort(t, dim=dim)
1548
1549        with self.assertRaisesRegex(
1550            RuntimeError, "Expected all tensors to be on the same device"
1551        ):
1552            torch.gather(t, 0, indices.cpu())
1553
1554        with self.assertRaisesRegex(
1555            RuntimeError,
1556            r"Expected tensor to have .* but got tensor with .* torch.take_along_dim()",
1557        ):
1558            torch.take_along_dim(t, indices.cpu(), dim=0)
1559
1560        with self.assertRaisesRegex(
1561            RuntimeError, "Expected all tensors to be on the same device"
1562        ):
1563            torch.gather(t.cpu(), 0, indices)
1564
1565        with self.assertRaisesRegex(
1566            RuntimeError,
1567            r"Expected tensor to have .* but got tensor with .* torch.take_along_dim()",
1568        ):
1569            torch.take_along_dim(t.cpu(), indices, dim=0)
1570
1571    @onlyCUDA
1572    def test_cuda_broadcast_index_use_deterministic_algorithms(self, device):
1573        with DeterministicGuard(True):
1574            idx1 = torch.tensor([0])
1575            idx2 = torch.tensor([2, 6])
1576            idx3 = torch.tensor([1, 5, 7])
1577
1578            tensor_a = torch.rand(13, 11, 12, 13, 12).cpu()
1579            tensor_b = tensor_a.to(device=device)
1580            tensor_a[idx1] = 1.0
1581            tensor_a[idx1, :, idx2, idx2, :] = 2.0
1582            tensor_a[:, idx1, idx3, :, idx3] = 3.0
1583            tensor_b[idx1] = 1.0
1584            tensor_b[idx1, :, idx2, idx2, :] = 2.0
1585            tensor_b[:, idx1, idx3, :, idx3] = 3.0
1586            self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0)
1587
1588            tensor_a = torch.rand(10, 11).cpu()
1589            tensor_b = tensor_a.to(device=device)
1590            tensor_a[idx3] = 1.0
1591            tensor_a[idx2, :] = 2.0
1592            tensor_a[:, idx2] = 3.0
1593            tensor_a[:, idx1] = 4.0
1594            tensor_b[idx3] = 1.0
1595            tensor_b[idx2, :] = 2.0
1596            tensor_b[:, idx2] = 3.0
1597            tensor_b[:, idx1] = 4.0
1598            self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0)
1599
1600            tensor_a = torch.rand(10, 10).cpu()
1601            tensor_b = tensor_a.to(device=device)
1602            tensor_a[[8]] = 1.0
1603            tensor_b[[8]] = 1.0
1604            self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0)
1605
1606            tensor_a = torch.rand(10).cpu()
1607            tensor_b = tensor_a.to(device=device)
1608            tensor_a[6] = 1.0
1609            tensor_b[6] = 1.0
1610            self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0)
1611
1612    def test_index_limits(self, device):
1613        # Regression test for https://github.com/pytorch/pytorch/issues/115415
1614        t = torch.tensor([], device=device)
1615        idx_min = torch.iinfo(torch.int64).min
1616        idx_max = torch.iinfo(torch.int64).max
1617        self.assertRaises(IndexError, lambda: t[idx_min])
1618        self.assertRaises(IndexError, lambda: t[idx_max])
1619
1620
1621# The tests below are from NumPy test_indexing.py, with some modifications to
1622# make them compatible with PyTorch. They are licensed under the BSD license below:
1623#
1624# Copyright (c) 2005-2017, NumPy Developers.
1625# All rights reserved.
1626#
1627# Redistribution and use in source and binary forms, with or without
1628# modification, are permitted provided that the following conditions are
1629# met:
1630#
1631#     * Redistributions of source code must retain the above copyright
1632#        notice, this list of conditions and the following disclaimer.
1633#
1634#     * Redistributions in binary form must reproduce the above
1635#        copyright notice, this list of conditions and the following
1636#        disclaimer in the documentation and/or other materials provided
1637#        with the distribution.
1638#
1639#     * Neither the name of the NumPy Developers nor the names of any
1640#        contributors may be used to endorse or promote products derived
1641#        from this software without specific prior written permission.
1642#
1643# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
1644# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
1645# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
1646# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
1647# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
1648# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
1649# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
1650# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
1651# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
1652# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
1653# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
1654
1655
1656class NumpyTests(TestCase):
1657    def test_index_no_floats(self, device):
1658        a = torch.tensor([[[5.0]]], device=device)
1659
1660        self.assertRaises(IndexError, lambda: a[0.0])
1661        self.assertRaises(IndexError, lambda: a[0, 0.0])
1662        self.assertRaises(IndexError, lambda: a[0.0, 0])
1663        self.assertRaises(IndexError, lambda: a[0.0, :])
1664        self.assertRaises(IndexError, lambda: a[:, 0.0])
1665        self.assertRaises(IndexError, lambda: a[:, 0.0, :])
1666        self.assertRaises(IndexError, lambda: a[0.0, :, :])
1667        self.assertRaises(IndexError, lambda: a[0, 0, 0.0])
1668        self.assertRaises(IndexError, lambda: a[0.0, 0, 0])
1669        self.assertRaises(IndexError, lambda: a[0, 0.0, 0])
1670        self.assertRaises(IndexError, lambda: a[-1.4])
1671        self.assertRaises(IndexError, lambda: a[0, -1.4])
1672        self.assertRaises(IndexError, lambda: a[-1.4, 0])
1673        self.assertRaises(IndexError, lambda: a[-1.4, :])
1674        self.assertRaises(IndexError, lambda: a[:, -1.4])
1675        self.assertRaises(IndexError, lambda: a[:, -1.4, :])
1676        self.assertRaises(IndexError, lambda: a[-1.4, :, :])
1677        self.assertRaises(IndexError, lambda: a[0, 0, -1.4])
1678        self.assertRaises(IndexError, lambda: a[-1.4, 0, 0])
1679        self.assertRaises(IndexError, lambda: a[0, -1.4, 0])
1680        # self.assertRaises(IndexError, lambda: a[0.0:, 0.0])
1681        # self.assertRaises(IndexError, lambda: a[0.0:, 0.0,:])
1682
1683    def test_none_index(self, device):
1684        # `None` index adds newaxis
1685        a = tensor([1, 2, 3], device=device)
1686        self.assertEqual(a[None].dim(), a.dim() + 1)
1687
1688    def test_empty_tuple_index(self, device):
1689        # Empty tuple index creates a view
1690        a = tensor([1, 2, 3], device=device)
1691        self.assertEqual(a[()], a)
1692        self.assertEqual(a[()].data_ptr(), a.data_ptr())
1693
1694    def test_empty_fancy_index(self, device):
1695        # Empty list index creates an empty array
1696        a = tensor([1, 2, 3], device=device)
1697        self.assertEqual(a[[]], torch.tensor([], dtype=torch.long, device=device))
1698
1699        b = tensor([], device=device).long()
1700        self.assertEqual(a[b], torch.tensor([], dtype=torch.long, device=device))
1701
1702        b = tensor([], device=device).float()
1703        self.assertRaises(IndexError, lambda: a[b])
1704
1705    def test_ellipsis_index(self, device):
1706        a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
1707        self.assertIsNot(a[...], a)
1708        self.assertEqual(a[...], a)
1709        # `a[...]` was `a` in numpy <1.9.
1710        self.assertEqual(a[...].data_ptr(), a.data_ptr())
1711
1712        # Slicing with ellipsis can skip an
1713        # arbitrary number of dimensions
1714        self.assertEqual(a[0, ...], a[0])
1715        self.assertEqual(a[0, ...], a[0, :])
1716        self.assertEqual(a[..., 0], a[:, 0])
1717
1718        # In NumPy, slicing with ellipsis results in a 0-dim array. In PyTorch
1719        # we don't have separate 0-dim arrays and scalars.
1720        self.assertEqual(a[0, ..., 1], torch.tensor(2, device=device))
1721
1722        # Assignment with `(Ellipsis,)` on 0-d arrays
1723        b = torch.tensor(1)
1724        b[(Ellipsis,)] = 2
1725        self.assertEqual(b, 2)
1726
1727    def test_single_int_index(self, device):
1728        # Single integer index selects one row
1729        a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
1730
1731        self.assertEqual(a[0], [1, 2, 3])
1732        self.assertEqual(a[-1], [7, 8, 9])
1733
1734        # Index out of bounds produces IndexError
1735        self.assertRaises(IndexError, a.__getitem__, 1 << 30)
1736        # Index overflow produces an Exception (NB: a different exception type than IndexError)
1737        self.assertRaises(Exception, a.__getitem__, 1 << 64)
1738
1739    def test_single_bool_index(self, device):
1740        # Single boolean index
1741        a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
1742
1743        self.assertEqual(a[True], a[None])
1744        self.assertEqual(a[False], a[None][0:0])
1745
1746    def test_boolean_shape_mismatch(self, device):
1747        arr = torch.ones((5, 4, 3), device=device)
1748
1749        index = tensor([True], device=device)
1750        self.assertRaisesRegex(IndexError, "mask", lambda: arr[index])
1751
1752        index = tensor([False] * 6, device=device)
1753        self.assertRaisesRegex(IndexError, "mask", lambda: arr[index])
1754
1755        index = torch.ByteTensor(4, 4).to(device).zero_()
1756        self.assertRaisesRegex(IndexError, "mask", lambda: arr[index])
1757        self.assertRaisesRegex(IndexError, "mask", lambda: arr[(slice(None), index)])
1758
1759    def test_boolean_indexing_onedim(self, device):
1760        # Indexing a 2-dimensional array with
1761        # a boolean array of length one
1762        a = tensor([[0.0, 0.0, 0.0]], device=device)
1763        b = tensor([True], device=device)
1764        self.assertEqual(a[b], a)
1765        # boolean assignment
1766        a[b] = 1.0
1767        self.assertEqual(a, tensor([[1.0, 1.0, 1.0]], device=device))
1768
1769    # https://github.com/pytorch/pytorch/issues/127003
1770    @xfailIfTorchDynamo
1771    def test_boolean_assignment_value_mismatch(self, device):
1772        # A boolean assignment should fail when the shape of the values
1773        # cannot be broadcast to the subscription (see also gh-3458).
1774        a = torch.arange(0, 4, device=device)
1775
1776        def f(a, v):
1777            a[a > -1] = tensor(v).to(device)
1778
1779        self.assertRaisesRegex(Exception, "shape mismatch", f, a, [])
1780        self.assertRaisesRegex(Exception, "shape mismatch", f, a, [1, 2, 3])
1781        self.assertRaisesRegex(Exception, "shape mismatch", f, a[:1], [1, 2, 3])
1782
1783    def test_boolean_indexing_twodim(self, device):
1784        # Indexing a 2-dimensional array with
1785        # 2-dimensional boolean array
1786        a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
1787        b = tensor(
1788            [[True, False, True], [False, True, False], [True, False, True]],
1789            device=device,
1790        )
1791        self.assertEqual(a[b], tensor([1, 3, 5, 7, 9], device=device))
1792        self.assertEqual(a[b[1]], tensor([[4, 5, 6]], device=device))
1793        self.assertEqual(a[b[0]], a[b[2]])
1794
1795        # boolean assignment
1796        a[b] = 0
1797        self.assertEqual(a, tensor([[0, 2, 0], [4, 0, 6], [0, 8, 0]], device=device))
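        # A 2-D bool mask of the same shape selects elements in row-major order,
        # which is why a[b] flattens to [1, 3, 5, 7, 9] before the assignment.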
1798
1799    def test_boolean_indexing_weirdness(self, device):
1800        # Weird boolean indexing things
1801        a = torch.ones((2, 3, 4), device=device)
1802        self.assertEqual((0, 2, 3, 4), a[False, True, ...].shape)
1803        self.assertEqual(
1804            torch.ones(1, 2, device=device), a[True, [0, 1], True, True, [1], [[2]]]
1805        )
1806        self.assertRaises(IndexError, lambda: a[False, [0, 1], ...])
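        # Scalar booleans behave like 0-d masks: together they contribute a single
        # leading dim of size 1 when all are True and size 0 when any is False,
        # hence the (0, 2, 3, 4) shape above.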
1807
1808    def test_boolean_indexing_weirdness_tensors(self, device):
1809        # Weird boolean indexing things
1810        false = torch.tensor(False, device=device)
1811        true = torch.tensor(True, device=device)
1812        a = torch.ones((2, 3, 4), device=device)
1813        self.assertEqual((0, 2, 3, 4), a[False, True, ...].shape)
1814        self.assertEqual(
1815            torch.ones(1, 2, device=device), a[true, [0, 1], true, true, [1], [[2]]]
1816        )
1817        self.assertRaises(IndexError, lambda: a[false, [0, 1], ...])
1818
1819    def test_boolean_indexing_alldims(self, device):
1820        true = torch.tensor(True, device=device)
1821        a = torch.ones((2, 3), device=device)
1822        self.assertEqual((1, 2, 3), a[True, True].shape)
1823        self.assertEqual((1, 2, 3), a[true, true].shape)
1824
1825    def test_boolean_list_indexing(self, device):
1826        # Indexing a 2-dimensional array with
1827        # boolean lists
1828        a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
1829        b = [True, False, False]
1830        c = [True, True, False]
1831        self.assertEqual(a[b], tensor([[1, 2, 3]], device=device))
1832        self.assertEqual(a[b, b], tensor([1], device=device))
1833        self.assertEqual(a[c], tensor([[1, 2, 3], [4, 5, 6]], device=device))
1834        self.assertEqual(a[c, c], tensor([1, 5], device=device))
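        # Boolean lists are treated as masks rather than integer indices, so
        # a[b, b] is roughly a[[0], [0]] and a[c, c] is roughly a[[0, 1], [0, 1]],
        # matching the element-wise pairs checked above.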
1835
1836    def test_everything_returns_views(self, device):
1837        # Previously, `...` would return `a` itself.
1838        a = tensor([5], device=device)
1839
1840        self.assertIsNot(a, a[()])
1841        self.assertIsNot(a, a[...])
1842        self.assertIsNot(a, a[:])
1843
1844    def test_broaderrors_indexing(self, device):
1845        a = torch.zeros(5, 5, device=device)
1846        self.assertRaisesRegex(
1847            IndexError, "shape mismatch", a.__getitem__, ([0, 1], [0, 1, 2])
1848        )
1849        self.assertRaisesRegex(
1850            IndexError, "shape mismatch", a.__setitem__, ([0, 1], [0, 1, 2]), 0
1851        )
1852
1853    def test_trivial_fancy_out_of_bounds(self, device):
1854        a = torch.zeros(5, device=device)
1855        ind = torch.ones(20, dtype=torch.int64, device=device)
1856        if a.is_cuda:
1857            raise unittest.SkipTest("CUDA asserts instead of raising an exception")
1858        ind[-1] = 10
1859        self.assertRaises(IndexError, a.__getitem__, ind)
1860        self.assertRaises(IndexError, a.__setitem__, ind, 0)
1861        ind = torch.ones(20, dtype=torch.int64, device=device)
1862        ind[0] = 11
1863        self.assertRaises(IndexError, a.__getitem__, ind)
1864        self.assertRaises(IndexError, a.__setitem__, ind, 0)
1865
1866    def test_index_is_larger(self, device):
1867        # Simple case of fancy index broadcasting of the index.
1868        a = torch.zeros((5, 5), device=device)
1869        a[[[0], [1], [2]], [0, 1, 2]] = tensor([2.0, 3.0, 4.0], device=device)
1870
1871        self.assertTrue((a[:3, :3] == tensor([2.0, 3.0, 4.0], device=device)).all())
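        # The column index [[0], [1], [2]] broadcasts against [0, 1, 2] to a 3x3
        # grid of positions, and the 3-element RHS broadcasts across each row of
        # that grid, so every row of a[:3, :3] equals [2., 3., 4.].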
1872
1873    def test_broadcast_subspace(self, device):
1874        a = torch.zeros((100, 100), device=device)
1875        v = torch.arange(0.0, 100, device=device)[:, None]
1876        b = torch.arange(99, -1, -1, device=device).long()
1877        a[b] = v
1878        expected = b.float().unsqueeze(1).expand(100, 100)
1879        self.assertEqual(a, expected)
1880
1881    def test_truncate_leading_1s(self, device):
1882        col_max = torch.randn(1, 4)
1883        kernel = col_max.T * col_max  # [4, 4] tensor
1884        kernel2 = kernel.clone()
1885        # Set the diagonal
1886        kernel[range(len(kernel)), range(len(kernel))] = torch.square(col_max)
1887        torch.diagonal(kernel2).copy_(torch.square(col_max.view(4)))
1888        self.assertEqual(kernel, kernel2)
1889
1890
1891instantiate_device_type_tests(TestIndexing, globals(), except_for="meta")
1892instantiate_device_type_tests(NumpyTests, globals(), except_for="meta")
1893
1894if __name__ == "__main__":
1895    run_tests()
1896