# Owner(s): ["module: tensor creation"]

import torch
import numpy as np

import sys
import math
import warnings
import unittest
from itertools import product, combinations, combinations_with_replacement, permutations
import random
import tempfile
from typing import Any, Dict, List, Tuple

from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
    TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
    torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, slowTest,
    set_default_dtype, set_default_tensor_type,
    TEST_SCIPY, IS_MACOS, IS_PPC, IS_JETSON, IS_WINDOWS, parametrize, skipIfTorchDynamo,
    xfailIfTorchDynamo)
from torch.testing._internal.common_device_type import (
    expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes,
    onlyCPU, largeTensorTest, precisionOverride, dtypes,
    onlyCUDA, skipCPUIf, dtypesIfCUDA, dtypesIfCPU, skipMeta)
from torch.testing._internal.common_dtype import (
    all_types_and_complex, all_types_and_complex_and, all_types_and, floating_and_complex_types, complex_types,
    floating_types, floating_and_complex_types_and, integral_types, integral_types_and, get_all_dtypes,
    float_to_corresponding_complex_type_map
)

from torch.utils.dlpack import to_dlpack

# TODO: replace with make_tensor
def _generate_input(shape, dtype, device, with_extremal):
    if shape == ():
        x = torch.tensor((), dtype=dtype, device=device)
    else:
        if dtype.is_floating_point or dtype.is_complex:
            # work around torch.randn not being implemented for bfloat16
            if dtype == torch.bfloat16:
                x = torch.randn(*shape, device=device) * random.randint(30, 100)
                x = x.to(torch.bfloat16)
            else:
                x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100)
            x[torch.randn(*shape) > 0.5] = 0
            if with_extremal and dtype.is_floating_point:
                # Use extremal values
                x[torch.randn(*shape) > 0.5] = float('nan')
                x[torch.randn(*shape) > 0.5] = float('inf')
                x[torch.randn(*shape) > 0.5] = float('-inf')
            elif with_extremal and dtype.is_complex:
                x[torch.randn(*shape) > 0.5] = complex('nan')
                x[torch.randn(*shape) > 0.5] = complex('inf')
                x[torch.randn(*shape) > 0.5] = complex('-inf')
        elif dtype == torch.bool:
            x = torch.zeros(shape, dtype=dtype, device=device)
            x[torch.randn(*shape) > 0.5] = True
        else:
            x = torch.randint(15, 100, shape, dtype=dtype, device=device)

    return x


# TODO: replace with make_tensor
def _rand_shape(dim, min_size, max_size):
    shape = []
    for i in range(dim):
        shape.append(random.randint(min_size, max_size))
    return tuple(shape)

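# Illustrative sketch (not exercised by the suite): how the two helpers above are
# typically combined by the tests below. _example_random_input is a hypothetical
# name added here for documentation only; it relies solely on the helpers and
# torch APIs already imported in this module.
def _example_random_input():
    shape = _rand_shape(dim=3, min_size=1, max_size=5)  # e.g. (2, 4, 3)
    t = _generate_input(shape, torch.float32, 'cpu', with_extremal=True)
    return t  # random float tensor sprinkled with zeros, nan and +/-inf
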
# Test suite for tensor creation ops
#
# Includes creation functions like torch.eye, random creation functions like
#   torch.rand, and *like functions like torch.ones_like.
# DOES NOT INCLUDE view ops, which are tested in TestViewOps (currently in
#   test_torch.py) OR numpy interop (which is also still tested in test_torch.py)
#
# See https://pytorch.org/docs/main/torch.html#creation-ops

class TestTensorCreation(TestCase):
    exact_dtype = True

    @onlyCPU
    @dtypes(torch.float)
    def test_diag_embed(self, device, dtype):
        x = torch.arange(3 * 4, dtype=dtype, device=device).view(3, 4)
        result = torch.diag_embed(x)
        expected = torch.stack([torch.diag(r) for r in x], 0)
        self.assertEqual(result, expected)

        result = torch.diag_embed(x, offset=1, dim1=0, dim2=2)
        expected = torch.stack([torch.diag(r, 1) for r in x], 1)
        self.assertEqual(result, expected)

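    # Illustrative sketch (not collected as a test): torch.diag_embed takes the last
    # dimension of its input and lays it out on a diagonal of a new trailing 2-D plane,
    # so a (3,)-vector becomes a (3, 3) diagonal matrix. Helper name is hypothetical.
    def _diag_embed_sketch(self):
        v = torch.tensor([1., 2., 3.])
        m = torch.diag_embed(v)                       # shape (3, 3)
        assert torch.equal(torch.diagonal(m), v)      # the vector sits on the main diagonal
        assert torch.equal(m, torch.diag(v))          # same result as torch.diag for 1-D input
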
    def test_cat_mem_overlap(self, device):
        x = torch.rand((1, 3), device=device).expand((6, 3))
        y = torch.rand((3, 3), device=device)
        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
            torch.cat([y, y], out=x)

    @onlyNativeDeviceTypes
    def test_vander(self, device):
        x = torch.tensor([1, 2, 3, 5], device=device)

        self.assertEqual((0, 0), torch.vander(torch.tensor([]), 0).shape)

        with self.assertRaisesRegex(RuntimeError, "N must be non-negative."):
            torch.vander(x, N=-1)

        with self.assertRaisesRegex(RuntimeError, "x must be a one-dimensional tensor."):
            torch.vander(torch.stack((x, x)))

    @onlyNativeDeviceTypes
    @dtypes(torch.bool, torch.uint8, torch.int8, torch.short, torch.int, torch.long,
            torch.float, torch.double,
            torch.cfloat, torch.cdouble)
    def test_vander_types(self, device, dtype):
        if dtype is torch.uint8:
            # Note: no negative uint8 values
            X = [[1, 2, 3, 5], [0, 1 / 3, 1, math.pi, 3 / 7]]
        elif dtype is torch.bool:
            # Note: see https://github.com/pytorch/pytorch/issues/37398
            # for why this is necessary.
            X = [[True, True, True, True], [False, True, True, True, True]]
        elif dtype in [torch.cfloat, torch.cdouble]:
            X = [[1 + 1j, 1 + 0j, 0 + 1j, 0 + 0j],
                 [2 + 2j, 3 + 2j, 4 + 3j, 5 + 4j]]
        else:
            X = [[1, 2, 3, 5], [-math.pi, 0, 1 / 3, 1, math.pi, 3 / 7]]

        N = [None, 0, 1, 3]
        increasing = [False, True]

        for x, n, inc in product(X, N, increasing):
            numpy_dtype = torch_to_numpy_dtype_dict[dtype]
            pt_x = torch.tensor(x, device=device, dtype=dtype)
            np_x = np.array(x, dtype=numpy_dtype)

            pt_res = torch.vander(pt_x, increasing=inc) if n is None else torch.vander(pt_x, n, inc)
            np_res = np.vander(np_x, n, inc)

            self.assertEqual(
                pt_res,
                torch.from_numpy(np_res),
                atol=1e-3,
                rtol=0,
                exact_dtype=False)

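    # Illustrative sketch (not collected as a test): a Vandermonde matrix has rows of
    # decreasing powers of the input by default, i.e. row i is [x_i**(N-1), ..., x_i, 1].
    # Helper name is hypothetical; only torch.vander from the tests above is used.
    def _vander_sketch(self):
        x = torch.tensor([1., 2., 3.])
        v = torch.vander(x, N=3)                        # columns: x**2, x**1, x**0
        assert torch.equal(v[:, -1], torch.ones(3))     # last column is all ones
        assert torch.equal(v[:, 0], x ** 2)
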
    def test_cat_all_dtypes_and_devices(self, device):
        for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.chalf):
            x = torch.tensor([[1, 2], [3, 4]], dtype=dt, device=device)

            expected1 = torch.tensor([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=dt, device=device)
            self.assertEqual(torch.cat((x, x), 0), expected1)

            expected2 = torch.tensor([[1, 2, 1, 2], [3, 4, 3, 4]], dtype=dt, device=device)
            self.assertEqual(torch.cat((x, x), 1), expected2)

    def test_fill_all_dtypes_and_devices(self, device):
        for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.chalf):
            for x in [torch.tensor((10, 10), dtype=dt, device=device),
                      torch.empty(10000, dtype=dt, device=device)]:  # large tensor
                numel = x.numel()
                bound = 100 if dt in (torch.uint8, torch.int8) else 2000
                for n in range(-bound, bound, bound // 10):
                    x.fill_(n)
                    self.assertEqual(x, torch.tensor([n] * numel, dtype=dt, device=device))
                    self.assertEqual(dt, x.dtype)

    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    def test_roll(self, device):
        numbers = torch.arange(1, 9, device=device)

        single_roll = numbers.roll(1, 0)
        expected = torch.tensor([8, 1, 2, 3, 4, 5, 6, 7], device=device)
        self.assertEqual(single_roll, expected, msg=f"{single_roll} did not equal expected result")

        roll_backwards = numbers.roll(-2, 0)
        expected = torch.tensor([3, 4, 5, 6, 7, 8, 1, 2], device=device)
        self.assertEqual(roll_backwards, expected, msg=f"{roll_backwards} did not equal expected result")

        data = numbers.view(2, 2, 2)
        rolled = data.roll(1, 0)
        expected = torch.tensor([5, 6, 7, 8, 1, 2, 3, 4], device=device).view(2, 2, 2)
        self.assertEqual(expected, rolled, msg=f"{rolled} did not equal expected result: {expected}")

        data = data.view(2, 4)
        # roll in a loop until we are back where we started
        loop_rolled = data.roll(2, 0).roll(4, 1)
        self.assertEqual(data, loop_rolled, msg=f"{loop_rolled} did not equal the original: {data}")
        # multiple inverse loops
        self.assertEqual(data, data.roll(-20, 0).roll(-40, 1))
        self.assertEqual(torch.tensor([8, 1, 2, 3, 4, 5, 6, 7], device=device), numbers.roll(1, 0))

        # test non-contiguous
        # strided equivalent to numbers.as_strided(size=(4, 2), stride=(1, 4))
        strided = numbers.view(2, 4).transpose(0, 1)
        self.assertFalse(strided.is_contiguous(), "this test needs a non-contiguous tensor")
        expected = torch.tensor([4, 8, 1, 5, 2, 6, 3, 7]).view(4, 2)
        rolled = strided.roll(1, 0)
        self.assertEqual(expected, rolled,
                         msg=f"non contiguous tensor rolled to {rolled} instead of {expected} ")

        # test roll with no dimension specified
        expected = numbers.roll(1, 0).view(2, 4)
        self.assertEqual(expected, data.roll(1), msg="roll with no dims should flatten and roll.")
        self.assertEqual(expected, data.roll(1, dims=None), msg="roll with no dims should flatten and roll.")

        # test roll over multiple dimensions
        expected = torch.tensor([[7, 8, 5, 6], [3, 4, 1, 2]], device=device)
        double_rolled = data.roll(shifts=(2, -1), dims=(1, 0))
        self.assertEqual(double_rolled, expected,
                         msg=f"should be able to roll over two dimensions, got {double_rolled}")

        self.assertRaisesRegex(RuntimeError, "required", lambda: data.roll(shifts=(), dims=()))
        self.assertRaisesRegex(RuntimeError, "required", lambda: data.roll(shifts=(), dims=1))
        # shifts/dims should align
        self.assertRaisesRegex(RuntimeError, "align", lambda: data.roll(shifts=(1, 2), dims=(1,)))
        self.assertRaisesRegex(RuntimeError, "align", lambda: data.roll(shifts=(1,), dims=(1, 2)))

        # test bool tensor
        t = torch.zeros(6, dtype=torch.bool, device=device)
        t[0] = True
        t[3] = True
        self.assertEqual(torch.tensor([False, True, False, False, True, False]), t.roll(1, 0))

        # test complex tensor
        t = torch.tensor([1, 2 + 1j, 3.5, 4. + 2j, 5j, 6.], device=device)
        t[0] = 1 + 0.5j
        t[3] = 4.
        expected = torch.tensor([6., 1 + 0.5j, 2 + 1j, 3.5, 4., 5j], device=device)
        self.assertEqual(expected, t.roll(1, 0))

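    # Illustrative sketch (not collected as a test): torch.roll shifts elements
    # circularly; with no dims it flattens, rolls, then restores the shape.
    # Helper name is hypothetical.
    def _roll_sketch(self):
        t = torch.arange(1, 7).view(2, 3)              # [[1, 2, 3], [4, 5, 6]]
        assert torch.equal(t.roll(1, dims=1),
                           torch.tensor([[3, 1, 2], [6, 4, 5]]))
        assert torch.equal(t.roll(1),                   # flattened roll
                           torch.tensor([[6, 1, 2], [3, 4, 5]]))
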
    def test_diagflat(self, device):
        dtype = torch.float32
        # Basic sanity test
        x = torch.randn((100,), dtype=dtype, device=device)
        result = torch.diagflat(x)
        expected = torch.diag(x)
        self.assertEqual(result, expected)

        # Test offset
        x = torch.randn((100,), dtype=dtype, device=device)
        result = torch.diagflat(x, 17)
        expected = torch.diag(x, 17)
        self.assertEqual(result, expected)

        # Test where input has more than one dimension
        x = torch.randn((2, 3, 4), dtype=dtype, device=device)
        result = torch.diagflat(x)
        expected = torch.diag(x.contiguous().view(-1))
        self.assertEqual(result, expected)

        # Noncontig input
        x = torch.randn((2, 3, 4), dtype=dtype, device=device).transpose(2, 0)
        self.assertFalse(x.is_contiguous())
        result = torch.diagflat(x)
        expected = torch.diag(x.contiguous().view(-1))
        self.assertEqual(result, expected)

        # Complex number support
        result = torch.diagflat(torch.ones(4, dtype=torch.complex128))
        expected = torch.eye(4, dtype=torch.complex128)
        self.assertEqual(result, expected)

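    # Illustrative sketch (not collected as a test): torch.diagflat differs from
    # torch.diag in that it always flattens its input first, so an (m, n) input
    # yields an (m*n, m*n) diagonal matrix. Helper name is hypothetical.
    def _diagflat_sketch(self):
        x = torch.tensor([[1., 2.], [3., 4.]])
        d = torch.diagflat(x)                            # shape (4, 4)
        assert d.shape == (4, 4)
        assert torch.equal(torch.diagonal(d), x.reshape(-1))
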
    def test_block_diag(self, device):
        def block_diag_workaround(*arrs):
            arrs_expanded = []
            for a in arrs:
                if a.dim() == 2:
                    arrs_expanded.append(a)
                elif a.dim() == 1:
                    arrs_expanded.append(a.expand(1, a.size(0)))
                elif a.dim() == 0:
                    arrs_expanded.append(a.expand(1, 1))
            shapes = torch.tensor([a.shape for a in arrs_expanded], device=device)
            out = torch.zeros(
                torch.sum(shapes, dim=0).tolist(),
                dtype=arrs_expanded[0].dtype,
                device=device
            )
            r, c = 0, 0
            for i, (rr, cc) in enumerate(shapes):
                out[r:r + rr, c:c + cc] = arrs_expanded[i]
                r += rr
                c += cc
            return out

        tensors = [
            torch.rand((2, 2), device=device),
            torch.rand((2, 3), device=device),
            torch.rand(10, device=device),
            torch.rand((8, 1), device=device),
            torch.rand(1, device=device)[0]
        ]
        result = torch.block_diag(*tensors)
        result_check = block_diag_workaround(*tensors)
        self.assertEqual(result, result_check)

        tensor = torch.rand(1, device=device)[0]
        result = torch.block_diag(tensor)
        result_check = tensor.expand(1, 1)
        self.assertEqual(result, result_check)

        tensor = torch.rand(10, device=device)
        result = torch.block_diag(tensor)
        result_check = tensor.expand(1, tensor.size(0))
        self.assertEqual(result, result_check)

        result = torch.block_diag()
        result_check = torch.empty(1, 0, device=device)
        self.assertEqual(result, result_check)
        self.assertEqual(result.device.type, 'cpu')

        test_dtypes = [
            torch.uint8,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.float32,
            torch.float64,
            torch.complex64,
            torch.complex128
        ]
        # Test pairs of different dtypes
        for dtype1 in test_dtypes:
            for dtype2 in test_dtypes:
                a = torch.tensor(1, device=device, dtype=dtype1)
                b = torch.tensor(2, device=device, dtype=dtype2)
                result = torch.block_diag(a, b)
                result_dtype = torch.result_type(a, b)
                result_check = torch.tensor([[1, 0], [0, 2]], device=device, dtype=result_dtype)
                self.assertEqual(result, result_check)

        with self.assertRaisesRegex(
            RuntimeError,
            "torch.block_diag: Input tensors must have 2 or fewer dimensions. Input 1 has 3 dimensions"
        ):
            torch.block_diag(torch.tensor(5), torch.tensor([[[6]]]))

        with self.assertRaisesRegex(
            RuntimeError,
            "torch.block_diag: Input tensors must have 2 or fewer dimensions. Input 0 has 4 dimensions"
        ):
            torch.block_diag(torch.tensor([[[[6]]]]))

        if device != 'cpu':
            with self.assertRaisesRegex(
                RuntimeError,
                (
                    "torch.block_diag: input tensors must all be on the same device."
                    " Input 0 is on device cpu and input 1 is on device "
                )
            ):
                torch.block_diag(torch.ones(2, 2).cpu(), torch.ones(2, 2, device=device))

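    # Illustrative sketch (not collected as a test): torch.block_diag places each
    # input (0-D and 1-D inputs are promoted to 2-D first) on the diagonal of a
    # zero-filled matrix. Helper name is hypothetical.
    def _block_diag_sketch(self):
        a = torch.ones(2, 2)
        b = torch.full((1, 3), 2.)
        out = torch.block_diag(a, b)                  # shape (3, 5)
        assert out.shape == (3, 5)
        assert torch.equal(out[:2, :2], a)
        assert torch.equal(out[2:, 2:], b)
        assert torch.equal(out[:2, 2:], torch.zeros(2, 3))
        assert torch.equal(out[2:, :2], torch.zeros(1, 2))
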
    @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
    def test_block_diag_scipy(self, device):
        import scipy.linalg
        scipy_tensors_list = [
            [
                1,
                [2],
                [],
                [3, 4, 5],
                [[], []],
                [[6], [7.3]]
            ],
            [
                [[1, 2], [3, 4]],
                [1]
            ],
            [
                [[4, 9], [7, 10]],
                [4.6, 9.12],
                [1j + 3]
            ],
            []
        ]

        expected_torch_types = [
            torch.float32,
            torch.int64,
            torch.complex64,
            torch.float32
        ]

        expected_scipy_types = [
            torch.float64,
            # windows scipy block_diag returns int32 types
            torch.int32 if IS_WINDOWS else torch.int64,
            torch.complex128,
            torch.float64
        ]

        for scipy_tensors, torch_type, scipy_type in zip(scipy_tensors_list, expected_torch_types, expected_scipy_types):
            torch_tensors = [torch.tensor(t, device=device) for t in scipy_tensors]
            torch_result = torch.block_diag(*torch_tensors)
            self.assertEqual(torch_result.dtype, torch_type)

            scipy_result = torch.tensor(
                scipy.linalg.block_diag(*scipy_tensors),
                device=device
            )
            self.assertEqual(scipy_result.dtype, scipy_type)
            scipy_result = scipy_result.to(torch_type)

            self.assertEqual(torch_result, scipy_result)

    @onlyNativeDeviceTypes
    @dtypes(torch.half, torch.float32, torch.float64)
    def test_torch_complex(self, device, dtype):
        real = torch.tensor([1, 2], device=device, dtype=dtype)
        imag = torch.tensor([3, 4], device=device, dtype=dtype)
        z = torch.complex(real, imag)
        complex_dtype = float_to_corresponding_complex_type_map[dtype]
        self.assertEqual(torch.tensor([1.0 + 3.0j, 2.0 + 4.0j], dtype=complex_dtype), z)

    @onlyNativeDeviceTypes
    @dtypes(torch.float32, torch.float64)
    def test_torch_polar(self, device, dtype):
        abs = torch.tensor([1, 2, -3, -4.5, 1, 1], device=device, dtype=dtype)
        angle = torch.tensor([math.pi / 2, 5 * math.pi / 4, 0, -11 * math.pi / 6, math.pi, -math.pi],
                             device=device, dtype=dtype)
        z = torch.polar(abs, angle)
        complex_dtype = torch.complex64 if dtype == torch.float32 else torch.complex128
        self.assertEqual(torch.tensor([1j, -1.41421356237 - 1.41421356237j, -3,
                                       -3.89711431703 - 2.25j, -1, -1],
                                      dtype=complex_dtype),
                         z, atol=1e-5, rtol=1e-5)

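    # Illustrative sketch (not collected as a test): torch.complex builds a complex
    # tensor from Cartesian parts, while torch.polar builds one from magnitude and
    # phase, i.e. abs * (cos(angle) + i*sin(angle)). Helper name is hypothetical.
    def _complex_polar_sketch(self):
        re = torch.tensor([1., 0.])
        im = torch.tensor([0., 1.])
        z_cart = torch.complex(re, im)                        # [1+0j, 0+1j]
        z_polar = torch.polar(torch.ones(2),
                              torch.tensor([0., math.pi / 2]))
        assert torch.allclose(z_cart, z_polar)
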
    @onlyNativeDeviceTypes
    @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
            torch.complex64, torch.complex128, torch.bool)
    def test_torch_complex_floating_dtype_error(self, device, dtype):
        for op in (torch.complex, torch.polar):
            a = torch.tensor([1, 2], device=device, dtype=dtype)
            b = torch.tensor([3, 4], device=device, dtype=dtype)
            error = r"Expected both inputs to be Half, Float or Double tensors but " \
                    r"got [A-Za-z]+ and [A-Za-z]+"
            with self.assertRaisesRegex(RuntimeError, error):
                op(a, b)

    @onlyNativeDeviceTypes
    @dtypes(torch.float32, torch.float64)
    def test_torch_complex_same_dtype_error(self, device, dtype):

        def dtype_name(dtype):
            return 'Float' if dtype == torch.float32 else 'Double'

        for op in (torch.complex, torch.polar):
            other_dtype = torch.float64 if dtype == torch.float32 else torch.float32
            a = torch.tensor([1, 2], device=device, dtype=dtype)
            b = torch.tensor([3, 4], device=device, dtype=other_dtype)
            error = f"Expected object of scalar type {dtype_name(dtype)} but got scalar type " \
                    f"{dtype_name(other_dtype)} for second argument"
            with self.assertRaisesRegex(RuntimeError, error):
                op(a, b)

    @onlyNativeDeviceTypes
    @dtypes(torch.float32, torch.float64)
    def test_torch_complex_out_dtype_error(self, device, dtype):

        def dtype_name(dtype):
            return 'Float' if dtype == torch.float32 else 'Double'

        def complex_dtype_name(dtype):
            return 'ComplexFloat' if dtype == torch.complex64 else 'ComplexDouble'

        for op in (torch.complex, torch.polar):
            a = torch.tensor([1, 2], device=device, dtype=dtype)
            b = torch.tensor([3, 4], device=device, dtype=dtype)
            out = torch.zeros(2, device=device, dtype=dtype)
            expected_dtype = torch.complex64 if dtype == torch.float32 else torch.complex128
            error = f"Expected object of scalar type {complex_dtype_name(expected_dtype)} but got scalar type " \
                    f"{dtype_name(dtype)} for argument 'out'"
            with self.assertRaisesRegex(RuntimeError, error):
                op(a, b, out=out)

    def test_cat_empty_legacy(self, device):
        # FIXME: this is legacy behavior and should be removed
        # when we support empty tensors with arbitrary sizes
        dtype = torch.float32

        x = torch.randn((4, 3, 32, 32), dtype=dtype, device=device)
        empty = torch.randn((0,), dtype=dtype, device=device)

        res1 = torch.cat([x, empty], dim=1)
        res2 = torch.cat([empty, x], dim=1)
        self.assertEqual(res1, res2)

        res1 = torch.cat([empty, empty], dim=1)
        self.assertEqual(res1, empty)

    def test_cat_empty(self, device):
        dtype = torch.float32

        x = torch.randn((4, 3, 32, 32), dtype=dtype, device=device)
        empty = torch.randn((4, 0, 32, 32), dtype=dtype, device=device)

        res1 = torch.cat([x, empty], dim=1)
        res2 = torch.cat([empty, x], dim=1)
        self.assertEqual(res1, res2)

        res1 = torch.cat([empty, empty], dim=1)
        self.assertEqual(res1, empty)

    def test_cat_out(self, device):
        x = torch.zeros((0), device=device)
        y = torch.randn((4, 6), device=device)

        w = y.view(-1).clone()
        a = torch.cat([w[:2], w[4:6]])
        b = torch.cat([w[:2], w[4:6]], out=w[6:10])
        self.assertEqual(a, b)
        self.assertEqual(a, w[6:10])
        self.assertEqual(w[:6], y.view(-1)[:6])

        # Case:
        # Reference: https://github.com/pytorch/pytorch/issues/49878
        for dim in [0, 1]:
            x = torch.zeros((10, 5, 2), device=device)

            random_length = random.randint(1, 4)
            y = x.narrow(dim, 0, x.shape[dim] - random_length)
            val = torch.full_like(y[0], 3., device=device)

            if dim == 0:
                self.assertTrue(y.is_contiguous())
            else:
                self.assertFalse(y.is_contiguous())

            torch.cat((val[None],) * y.shape[0], dim=0, out=y)

            expected_y = torch.cat((val[None],) * y.shape[0], dim=0)
            expected_x = torch.zeros((10, 5, 2), device=device)
            if dim == 0:
                expected_x[:x.shape[dim] - random_length, :, :] = expected_y
            elif dim == 1:
                expected_x[:, :x.shape[dim] - random_length, :] = expected_y

            self.assertEqual(y, expected_y)
            self.assertEqual(x, expected_x)

    @dtypes(*all_types_and_complex(), torch.uint16, torch.uint32, torch.uint64)
    def test_cat_out_fast_path_dim0_dim1(self, device, dtype):
        int_types = integral_types_and(torch.uint16, torch.uint32, torch.uint64)
        x = torch.zeros((0), device=device, dtype=dtype)
        if dtype in int_types:
            y = torch.randint(low=0, high=100, size=(4, 6), device=device, dtype=dtype)
        else:
            y = torch.randn((4, 6), device=device, dtype=dtype)
        # Test concat on dimension 0
        w = y.view(-1).clone()
        a = torch.cat([w[:2], w[4:6]])
        b = torch.cat([w[:2], w[4:6]], out=w[6:10])
        # Note that there is no guarantee that slicing here will result in
        # contiguous tensors
        self.assertEqual(a, b)
        self.assertEqual(a, w[6:10])
        self.assertEqual(w[:6], y.view(-1)[:6])
        # If inputs are contiguous tensors, then fast concat paths will be invoked
        a_fastcat = torch.cat([w[:2].contiguous(), w[4:6].contiguous()])
        self.assertEqual(a_fastcat, a)
        # Test concat on dimension 1
        w = y.clone()
        w_slices = torch.tensor_split(w, (2, 4), dim=1)
        # Note that the tensor in w_slices[] here may not be a contiguous
        # tensor and we need to make sure this is not broken by fast concat
        b = torch.cat([w_slices[0], w_slices[1]], dim=1)
        expected_b = torch.index_select(w, 1, torch.tensor([0, 1, 2, 3], device=device))
        self.assertEqual(b, expected_b)
        # If inputs are contiguous tensors, then fast concat paths will be invoked
        b_fastcat = torch.cat([w_slices[0].contiguous(), w_slices[1].contiguous()], dim=1)
        self.assertEqual(b_fastcat, expected_b)
        # Finally, we need to make sure backward is not broken
        # Integral types will not have grad
        if dtype not in int_types:
            a = torch.randn((4, 3), device=device, dtype=dtype, requires_grad=True)
            b = torch.randn((2, 3), device=device, dtype=dtype, requires_grad=True)
            c = torch.randn((5, 3), device=device, dtype=dtype, requires_grad=True)
            d = torch.randn((5, 2), device=device, dtype=dtype, requires_grad=True)
            expected_a_grad = torch.ones((4, 3), device=device, dtype=dtype)
            expected_b_grad = torch.ones((2, 3), device=device, dtype=dtype)
            expected_c_grad = torch.ones((5, 3), device=device, dtype=dtype)
            expected_d_grad = torch.ones((5, 2), device=device, dtype=dtype)
            # All the new tensors should be contiguous here. Let us make sure
            # to explicitly set them contiguous to enforce fast cat
            dim0_cat = torch.cat([a.contiguous(), b.contiguous()], dim=0)
            if dtype in complex_types():
                dim0_cat.sum().abs().backward()
                self.assertEqual(a.grad.abs(), expected_a_grad.abs())
                self.assertEqual(b.grad.abs(), expected_b_grad.abs())
            else:
                dim0_cat.sum().backward()
                self.assertEqual(a.grad, expected_a_grad)
                self.assertEqual(b.grad, expected_b_grad)
            dim1_cat = torch.cat([c.contiguous(), d.contiguous()], dim=1)
            if dtype in complex_types():
                dim1_cat.sum().abs().backward()
                self.assertEqual(c.grad.abs(), expected_c_grad.abs())
                self.assertEqual(d.grad.abs(), expected_d_grad.abs())
            else:
                dim1_cat.sum().backward()
                self.assertEqual(c.grad, expected_c_grad)
                self.assertEqual(d.grad, expected_d_grad)

    def test_cat_out_channels_last(self, device):
        x = torch.randn((4, 3, 8, 8))
        y = torch.randn(x.shape)
        res1 = torch.cat((x, y))
        z = res1.clone().contiguous(memory_format=torch.channels_last)
        res2 = torch.cat((x, y), out=z)
        self.assertEqual(res1, res2)

    @onlyNativeDeviceTypes
    def test_cat_in_channels_last(self, device):
        for dim in range(4):
            x = torch.randn((4, 15, 8, 8), device=device)
            y = torch.randn(x.shape, device=device)
            res1 = torch.cat((x, y), dim=dim)
            x = x.clone().contiguous(memory_format=torch.channels_last)
            y = y.clone().contiguous(memory_format=torch.channels_last)
            res2 = torch.cat((x, y), dim=dim)
            self.assertTrue(res2.is_contiguous(memory_format=torch.channels_last))
            self.assertEqual(res1, res2)

            # Size larger than grain size.
            x = torch.randn((4, 15, 256, 256), device=device)
            y = torch.randn(x.shape, device=device)
            res1 = torch.cat((x, y), dim=dim)
            x = x.clone().contiguous(memory_format=torch.channels_last)
            y = y.clone().contiguous(memory_format=torch.channels_last)
            res2 = torch.cat((x, y), dim=dim)
            self.assertTrue(res2.is_contiguous(memory_format=torch.channels_last))
            self.assertEqual(res1, res2)

    @onlyNativeDeviceTypes
    def test_cat_preserve_channels_last(self, device):
        x = torch.randn((4, 3, 8, 8), device=device)
        y = torch.randn(x.shape, device=device)
        res1 = torch.cat((x, y))
        res2 = torch.cat((x.contiguous(memory_format=torch.channels_last), y.contiguous(memory_format=torch.channels_last)))
        self.assertEqual(res1, res2)
        self.assertTrue(res2.is_contiguous(memory_format=torch.channels_last))
        # discontiguous channels-last inputs
        x = torch.arange(24, dtype=torch.float, device=device).reshape(2, 2, 3, 2).to(memory_format=torch.channels_last)
        x1 = x[:, :, :2]
        x2 = x[:, :, 1:]
        res1 = torch.cat((x1, x2), dim=-1)
        res2 = torch.cat((x1.contiguous(), x2.contiguous()), dim=-1)
        self.assertEqual(res1, res2)
        self.assertTrue(res1.is_contiguous(memory_format=torch.channels_last))

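    # Illustrative sketch (not collected as a test): torch.channels_last keeps the
    # logical NCHW shape but stores data with strides ordered as NHWC, which is the
    # property the channels-last cat tests above assert is preserved. Helper name
    # is hypothetical.
    def _channels_last_sketch(self):
        x = torch.randn(2, 3, 4, 5).contiguous(memory_format=torch.channels_last)
        assert x.shape == (2, 3, 4, 5)                        # logical shape unchanged
        assert x.stride() == (60, 1, 15, 3)                   # NHWC-ordered strides
        assert x.is_contiguous(memory_format=torch.channels_last)
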
    @onlyCUDA
    def test_cat_out_memory_format(self, device):
        inp_size = (4, 4, 4, 4)
        expected_size = (8, 4, 4, 4)
        a_cuda = torch.randn(inp_size, device=device).contiguous(memory_format=torch.channels_last)
        a_cpu = torch.randn(inp_size, device='cpu').contiguous(memory_format=torch.channels_last)
        b_cuda = torch.randn(inp_size, device=device).contiguous(memory_format=torch.contiguous_format)
        b_cpu = torch.randn(inp_size, device='cpu').contiguous(memory_format=torch.contiguous_format)
        c_cuda = torch.randn(inp_size, device=device).contiguous(memory_format=torch.channels_last)

        # Case 1: if out= is the correct shape then the memory format of out= is respected

        out_cuda = torch.empty(expected_size, device=device).contiguous(memory_format=torch.contiguous_format)
        res1_cuda = torch.cat((a_cuda, b_cuda), out=out_cuda)

        out_cpu = torch.empty(expected_size, device='cpu').contiguous(memory_format=torch.contiguous_format)
        res1_cpu = torch.cat((a_cpu, b_cpu), out=out_cpu)

        self.assertTrue(res1_cuda.is_contiguous(memory_format=torch.contiguous_format))
        self.assertTrue(res1_cpu.is_contiguous(memory_format=torch.contiguous_format))

        # Case 2: if out= is not the correct shape then the output is resized internally
        # - For both CPU and CUDA variants, it only propagates memory format if all the tensors have
        #   the same memory format, otherwise it just uses contiguous_format as a default

        out_cuda = torch.empty((0), device=device).contiguous(memory_format=torch.contiguous_format)
        # a_cuda and b_cuda have different memory_format
        res2_cuda = torch.cat((a_cuda, b_cuda), out=out_cuda)

        out_cpu = torch.empty((0), device='cpu').contiguous(memory_format=torch.contiguous_format)
        res2_cpu = torch.cat((a_cpu, b_cpu), out=out_cpu)

        self.assertTrue(res2_cuda.is_contiguous(memory_format=torch.contiguous_format))
        self.assertTrue(res2_cpu.is_contiguous(memory_format=torch.contiguous_format))

        out_cuda = torch.empty((0), device=device).contiguous(memory_format=torch.contiguous_format)
        # a_cuda and c_cuda have same memory_format
        res3_cuda = torch.cat((a_cuda, c_cuda), out=out_cuda)

        self.assertTrue(res3_cuda.is_contiguous(memory_format=torch.channels_last))

    @onlyCUDA
    def test_cat_stack_cross_devices(self, device):
        cuda = torch.randn((3, 3), device=device)
        cpu = torch.randn((3, 3), device='cpu')

        # Stack
        with self.assertRaisesRegex(RuntimeError,
                                    "Expected all tensors to be on the same device"):
            torch.stack((cuda, cpu))
        with self.assertRaisesRegex(RuntimeError,
                                    "Expected all tensors to be on the same device"):
            torch.stack((cpu, cuda))

    # TODO: reconcile with other cat tests
    # TODO: Compare with a NumPy reference instead of CPU
    @onlyCUDA
    def test_cat(self, device):
        SIZE = 10
        for dim in range(-3, 3):
            pos_dim = dim if dim >= 0 else 3 + dim
            x = torch.rand(13, SIZE, SIZE, device=device).transpose(0, pos_dim)
            y = torch.rand(17, SIZE, SIZE, device=device).transpose(0, pos_dim)
            z = torch.rand(19, SIZE, SIZE, device=device).transpose(0, pos_dim)

            res1 = torch.cat((x, y, z), dim)
            self.assertEqual(res1.narrow(pos_dim, 0, 13), x, atol=0, rtol=0)
            self.assertEqual(res1.narrow(pos_dim, 13, 17), y, atol=0, rtol=0)
            self.assertEqual(res1.narrow(pos_dim, 30, 19), z, atol=0, rtol=0)

        x = torch.randn(20, SIZE, SIZE, device=device)
        self.assertEqual(torch.cat(torch.split(x, 7)), x)
        self.assertEqual(torch.cat(torch.chunk(x, 7)), x)

        y = torch.randn(1, SIZE, SIZE, device=device)
        z = torch.cat([x, y])
        self.assertEqual(z.size(), (21, SIZE, SIZE))

    # TODO: update this test to compare against NumPy instead of CPU
    @onlyCUDA
    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @dtypes(torch.float, torch.double)
    def test_device_rounding(self, device, dtype):
        # test half-to-even
        a = [-5.8, -3.5, -2.3, -1.5, -0.5, 0.5, 1.5, 2.3, 3.5, 5.8]
        res = [-6., -4., -2., -2., 0., 0., 2., 2., 4., 6.]

        a_tensor = torch.tensor(a, device=device).round()
        res_tensor = torch.tensor(res, device='cpu')
        self.assertEqual(a_tensor, res_tensor)

    # Note: This test fails on XLA because its test cases are created with empty_strided,
    #       which doesn't support overlapping sizes/strides in the XLA implementation.
    @skipIfTorchDynamo("TorchDynamo fails on this test for unknown reasons")
    @onlyNativeDeviceTypes
    def test_like_fn_stride_proparation_vs_tensoriterator_unary_op(self, device):
        # Test *_like functions against a TensorIterator-based unary operator (exp) to
        # make sure the tensor returned by a like function follows the same stride
        # propagation rule as TensorIterator does for unary operators. The like function's
        # output strides are always computed on the CPU side, so there is no need to test GPU here.

        def compare_helper_(like_fn, t):
            te = torch.exp(t)
            tl = like_fn(t)
            self.assertEqual(te.stride(), tl.stride())
            self.assertEqual(te.size(), tl.size())

        like_fns = [
            lambda t, **kwargs: torch.zeros_like(t, **kwargs),
            lambda t, **kwargs: torch.ones_like(t, **kwargs),
            lambda t, **kwargs: torch.randint_like(t, 10, 100, **kwargs),
            lambda t, **kwargs: torch.randint_like(t, 100, **kwargs),
            lambda t, **kwargs: torch.randn_like(t, **kwargs),
            lambda t, **kwargs: torch.rand_like(t, **kwargs),
            lambda t, **kwargs: torch.full_like(t, 7, **kwargs),
            lambda t, **kwargs: torch.empty_like(t, **kwargs)]

        # dense non-overlapping tensor,
        # non-dense non-overlapping sliced tensor
        # non-dense non-overlapping gapped tensor
        # non-dense non-overlapping 0 strided tensor
        # non-dense overlapping general tensor
        # non-dense overlapping sliced tensor
        # non-dense overlapping gapped tensor
        # non-dense overlapping 0 strided tensor
        # non-dense overlapping equal strides
        tset = (
            torch.randn(4, 3, 2, device=device),
            torch.randn(4, 3, 2, device=device)[:, :, ::2],
            torch.empty_strided((4, 3, 2), (10, 3, 1), device=device).fill_(1.0),
            torch.empty_strided((4, 3, 2), (10, 0, 3), device=device).fill_(1.0),
            torch.empty_strided((4, 3, 2), (10, 1, 2), device=device).fill_(1.0),
            torch.empty_strided((4, 3, 2), (4, 2, 1), device=device)[:, :, ::2].fill_(1.0),
            torch.empty_strided((4, 3, 2), (10, 1, 1), device=device).fill_(1.0),
            torch.empty_strided((4, 1, 1, 2), (10, 0, 0, 2), device=device).fill_(1.0),
            torch.empty_strided((4, 2, 3), (10, 3, 3), device=device).fill_(1.0))

        for like_fn in like_fns:
            for t in tset:
                for p in permutations(range(t.dim())):
                    tp = t.permute(p)
                    compare_helper_(like_fn, tp)

    def _hvd_split_helper(self, torch_fn, np_fn, op_name, inputs, device, dtype, dim):
        dimension_error_message = op_name + " requires a tensor with at least "
        divisibility_error_message = op_name + " attempted to split along dimension "

        for shape, arg in inputs:
            direction = dim - (len(shape) == 1 and dim == 1)
            bound = dim + 2 * (dim == 0) + (dim == 2)
            error_expected = len(shape) < bound or (not isinstance(arg, list) and shape[direction] % arg != 0)

            t = make_tensor(shape, dtype=dtype, device=device)
            t_np = t.cpu().numpy()

            if not error_expected:
                self.assertEqual(torch_fn(t, arg), np_fn(t_np, arg))
            else:
                self.assertRaises(RuntimeError, lambda: torch_fn(t, arg))
                self.assertRaises(ValueError, lambda: np_fn(t, arg))
                expected_error_message = dimension_error_message if len(shape) < bound else divisibility_error_message
                self.assertRaisesRegex(RuntimeError, expected_error_message, lambda: torch_fn(t, arg))

    @onlyNativeDeviceTypes
    @dtypes(torch.long, torch.float32, torch.complex64)
    def test_hsplit(self, device, dtype):
        inputs = (
            ((), 3),
            ((), [2, 4, 6]),
            ((6,), 2),
            ((6,), 4),
            ((6,), [2, 5]),
            ((6,), [7, 9]),
            ((3, 8), 4),
            ((3, 8), 5),
            ((3, 8), [1, 5]),
            ((3, 8), [3, 8]),
            ((5, 5, 5), 2),
            ((5, 5, 5), [1, 4]),
            ((5, 0, 5), 3),
            ((5, 5, 0), [2, 6]),
        )
        self._hvd_split_helper(torch.hsplit, np.hsplit, "torch.hsplit", inputs, device, dtype, 1)

    @onlyNativeDeviceTypes
    @dtypes(torch.long, torch.float32, torch.complex64)
    def test_vsplit(self, device, dtype):
        inputs = (
            ((6,), 2),
            ((6,), 4),
            ((6, 5), 2),
            ((6, 5), 4),
            ((6, 5), [1, 2, 3]),
            ((6, 5), [1, 5, 9]),
            ((6, 5, 5), 2),
            ((6, 0, 5), 2),
            ((5, 0, 5), [1, 5]),
        )
        self._hvd_split_helper(torch.vsplit, np.vsplit, "torch.vsplit", inputs, device, dtype, 0)

    @onlyNativeDeviceTypes
    @dtypes(torch.long, torch.float32, torch.complex64)
    def test_dsplit(self, device, dtype):
        inputs = (
            ((6,), 4),
            ((6, 6), 3),
            ((5, 5, 6), 2),
            ((5, 5, 6), 4),
            ((5, 5, 6), [1, 2, 3]),
            ((5, 5, 6), [1, 5, 9]),
            ((5, 5, 0), 2),
            ((5, 0, 6), 4),
            ((5, 0, 6), [1, 2, 3]),
            ((5, 5, 6), [1, 5, 9]),
        )
        self._hvd_split_helper(torch.dsplit, np.dsplit, "torch.dsplit", inputs, device, dtype, 2)

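    # Illustrative sketch (not collected as a test): hsplit/vsplit/dsplit divide a
    # tensor along dim 1/0/2 respectively; an int argument must divide that dimension
    # evenly, while a list gives explicit split points. Helper name is hypothetical.
    def _hvd_split_sketch(self):
        t = torch.arange(24).reshape(2, 3, 4)
        left, right = torch.dsplit(t, 2)              # split dim 2 into two (2, 3, 2) halves
        assert left.shape == right.shape == (2, 3, 2)
        top, bottom = torch.vsplit(t, [1])            # split point at index 1 along dim 0
        assert top.shape == (1, 3, 4) and bottom.shape == (1, 3, 4)
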
    def _test_special_stacks(self, dim, at_least_dim, torch_fn, np_fn, device, dtype):
        # Test error for non-tuple argument
        t = torch.randn(10)
        with self.assertRaisesRegex(TypeError, "must be tuple of Tensors, not Tensor"):
            torch_fn(t)
        # Test error for a single array
        with self.assertRaisesRegex(TypeError, "must be tuple of Tensors, not Tensor"):
            torch_fn(t)

        # Test 0-D
        num_tensors = random.randint(1, 5)
        input_t = [torch.tensor(random.uniform(0, 10), device=device, dtype=dtype) for i in range(num_tensors)]
        actual = torch_fn(input_t)
        expected = np_fn([input.cpu().numpy() for input in input_t])
        self.assertEqual(actual, expected)

        for ndims in range(1, 5):
            base_shape = list(_rand_shape(ndims, min_size=1, max_size=5))
            for i in range(ndims):
                shape = list(base_shape)
                num_tensors = random.randint(1, 5)
                torch_input = []
                # Create tensors with shape being different along one axis only
                for param in range(num_tensors):
                    shape[i] = random.randint(1, 5)
                    torch_input.append(_generate_input(tuple(shape), dtype, device, with_extremal=False))

                # Determine if input tensors have valid dimensions.
                valid_dim = True
                for k in range(len(torch_input) - 1):
                    for tdim in range(ndims):
                        # Test whether all tensors have the same shape except in concatenating dimension
                        # Unless the number of dimensions is less than the corresponding at_least function dimension
                        # Since the original concatenating dimension would shift after applying at_least and would no
                        # longer be the concatenating dimension
                        if (ndims < at_least_dim or tdim != dim) and torch_input[k].size()[tdim] != torch_input[k + 1].size()[tdim]:
                            valid_dim = False

                # Special case for hstack is needed since hstack works differently when ndims is 1
                if valid_dim or (torch_fn is torch.hstack and ndims == 1):
                    # Valid dimensions, test against numpy
                    np_input = [input.cpu().numpy() for input in torch_input]
                    actual = torch_fn(torch_input)
                    expected = np_fn(np_input)
                    self.assertEqual(actual, expected)
                else:
                    # Invalid dimensions, test for error
                    with self.assertRaisesRegex(RuntimeError, "Sizes of tensors must match except in dimension"):
                        torch_fn(torch_input)
                    with self.assertRaises(ValueError):
                        np_input = [input.cpu().numpy() for input in torch_input]
                        np_fn(np_input)

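    # Illustrative sketch (not collected as a test): hstack/vstack/dstack first promote
    # their inputs to at least 1/2/3 dimensions and then concatenate along dim 1/0/2
    # (hstack uses dim 0 for 1-D inputs), which is the behavior the helper above mirrors
    # against NumPy. Helper name is hypothetical.
    def _special_stacks_sketch(self):
        a, b = torch.tensor([1, 2]), torch.tensor([3, 4])
        assert torch.equal(torch.hstack((a, b)), torch.tensor([1, 2, 3, 4]))
        assert torch.vstack((a, b)).shape == (2, 2)      # 1-D inputs become rows
        assert torch.dstack((a, b)).shape == (1, 2, 2)   # promoted to (1, 2, 1), then stacked
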
    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half))
    def test_hstack_column_stack(self, device, dtype):
        ops = ((torch.hstack, np.hstack), (torch.column_stack, np.column_stack))
        for torch_op, np_op in ops:
            self._test_special_stacks(1, 1, torch_op, np_op, device, dtype)

        # Test torch.column_stack with combinations of 1D and 2D tensors input
        one_dim_tensor = torch.arange(0, 10).to(dtype=dtype, device=device)
        two_dim_tensor = torch.arange(0, 100).to(dtype=dtype, device=device).reshape(10, 10)
        inputs = two_dim_tensor, one_dim_tensor, two_dim_tensor, one_dim_tensor
        torch_result = torch.column_stack(inputs)

        np_inputs = [input.cpu().numpy() for input in inputs]
        np_result = np.column_stack(np_inputs)

        self.assertEqual(np_result,
                         torch_result)

    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half))
    def test_vstack_row_stack(self, device, dtype):
        ops = ((torch.vstack, np.vstack), (torch.row_stack, np.vstack))
        for torch_op, np_op in ops:
            self._test_special_stacks(0, 2, torch_op, np_op, device, dtype)
            for i in range(5):
                # Test dimension change for 1D tensor of size (N) and 2D tensor of size (1, N)
                n = random.randint(1, 10)
                input_a = _generate_input((n,), dtype, device, with_extremal=False)
                input_b = _generate_input((1, n), dtype, device, with_extremal=False)
                torch_input = [input_a, input_b]
                np_input = [input.cpu().numpy() for input in torch_input]
                actual = torch_op(torch_input)
                expected = np_op(np_input)
                self.assertEqual(actual, expected)

    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half))
    def test_dstack(self, device, dtype):
        self._test_special_stacks(2, 3, torch.dstack, np.dstack, device, dtype)
        for i in range(5):
            # Test dimension change for 1D tensor of size (N), 2D tensor of size (1, N), and 3D tensor of size (1, N, 1)
            n = random.randint(1, 10)
            input_a = _generate_input((n,), dtype, device, with_extremal=False)
            input_b = _generate_input((1, n), dtype, device, with_extremal=False)
            input_c = _generate_input((1, n, 1), dtype, device, with_extremal=False)
            torch_input = [input_a, input_b, input_c]
            np_input = [input.cpu().numpy() for input in torch_input]
            actual = torch.dstack(torch_input)
            expected = np.dstack(np_input)
            self.assertEqual(actual, expected)

            # Test dimension change for 2D tensor of size (M, N) and 3D tensor of size (M, N, 1)
            m = random.randint(1, 10)
            n = random.randint(1, 10)
            input_a = _generate_input((m, n), dtype, device, with_extremal=False)
            input_b = _generate_input((m, n, 1), dtype, device, with_extremal=False)
            torch_input = [input_a, input_b]
            np_input = [input.cpu().numpy() for input in torch_input]
            actual = torch.dstack(torch_input)
            expected = np.dstack(np_input)
            self.assertEqual(actual, expected)

    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    @dtypes(torch.int32, torch.int64)
    def test_large_linspace(self, device, dtype):
        start = torch.iinfo(dtype).min
        end = torch.iinfo(dtype).max & ~0xfff
        steps = 15
        x = torch.linspace(start, end, steps, dtype=dtype, device=device)
        self.assertGreater(x[1] - x[0], (end - start) / steps)

    @dtypes(torch.float32, torch.float64)
    def test_unpack_double(self, device, dtype):
        # Reference: https://github.com/pytorch/pytorch/issues/33111
        vals = (2 ** 24 + 1, 2 ** 53 + 1,
                np.iinfo(np.int64).max, np.iinfo(np.uint64).max, np.iinfo(np.uint64).max + 1,
                -1e500, 1e500)
        for val in vals:
            t = torch.tensor(val, dtype=dtype, device=device)
            a = np.array(val, dtype=torch_to_numpy_dtype_dict[dtype])
            self.assertEqual(t, torch.from_numpy(a))

    def _float_to_int_conversion_helper(self, vals, device, dtype, refs=None):
        if refs is None:
            a = np.array(vals, dtype=np.float32).astype(torch_to_numpy_dtype_dict[dtype])
            refs = torch.from_numpy(a)
        t = torch.tensor(vals, device=device, dtype=torch.float).to(dtype)
        self.assertEqual(refs, t.cpu())

    # Checks that float->integer casts don't produce undefined behavior errors.
    # Note: In C++, casting from a floating value to an integral dtype
    # is undefined if the floating point value is not within the integral
    # dtype's dynamic range. This can (and should) cause undefined behavior
    # errors with UBSAN. These casts are deliberate in PyTorch, however, and
    # NumPy may have the same behavior.
    @onlyNativeDeviceTypes
    @unittest.skipIf(IS_MACOS or IS_JETSON, "Test is broken on MacOS and Jetson, \
        see https://github.com/pytorch/pytorch/issues/38752")
    @unittest.skipIf(IS_PPC, "Test is broken on PowerPC, see https://github.com/pytorch/pytorch/issues/39671")
    @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    def test_float_to_int_conversion_finite(self, device, dtype):
        min = torch.finfo(torch.float).min
        max = torch.finfo(torch.float).max

        # Note: CUDA max float -> integer conversion is divergent on some dtypes
        vals = (min, -2, -1.5, -.5, 0, .5, 1.5, 2, max)
        refs = None
        if self.device_type == 'cuda':
            if torch.version.hip:
                # HIP min float -> int64 conversion is divergent
                vals = (-2, -1.5, -.5, 0, .5, 1.5, 2)
            else:
                vals = (min, -2, -1.5, -.5, 0, .5, 1.5, 2)
        elif dtype == torch.uint8:
            # Note: CPU max float -> uint8 conversion is divergent
            vals = (min, -2, -1.5, -.5, 0, .5, 1.5, 2)
            # Note: numpy -2.0 or -1.5 -> uint8 conversion is undefined
            #       see https://github.com/pytorch/pytorch/issues/97794
            refs = (0, 254, 255, 0, 0, 0, 1, 2)

        self._float_to_int_conversion_helper(vals, device, dtype, refs)

    # Note: CUDA will fail this test on most dtypes, often dramatically.
    # NB: torch.uint16, torch.uint32, torch.uint64 excluded as this
    # nondeterministically fails, warning "invalid value encountered in cast"
    @onlyCPU
    @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    def test_float_to_int_conversion_nonfinite(self, device, dtype):
        vals = (float('-inf'), float('inf'), float('nan'))

        self._float_to_int_conversion_helper(vals, device, dtype)


    @onlyNativeDeviceTypes
    def test_complex_type_conversions(self, device):
        dtypes = [torch.float, torch.complex64, torch.complex128]
        for from_type in dtypes:
            for to_type in dtypes:
                from_tensor = torch.randn(4, dtype=from_type, device=device)
                to_tensor = from_tensor.to(to_type)
                if from_type.is_complex and not to_type.is_complex:
                    self.assertEqual(torch.real(from_tensor), to_tensor, exact_dtype=False)
                elif not from_type.is_complex and to_type.is_complex:
                    self.assertEqual(from_tensor, torch.real(to_tensor), exact_dtype=False)
                    self.assertEqual(torch.zeros_like(torch.imag(to_tensor)), torch.imag(to_tensor), exact_dtype=False)
                else:
                    self.assertEqual(from_tensor, to_tensor, exact_dtype=False)

    @slowTest
    @onlyCPU
    def test_cat_big(self, device):
        SIZE1 = 6500
        SIZE2 = 4500
        concat_list = []
        concat_list.append(torch.ones((SIZE1, 1024 * 512), dtype=torch.uint8, device=device))
        concat_list.append(torch.ones((SIZE2, 1024 * 512), dtype=torch.uint8, device=device))
        result = torch.cat(concat_list)
        self.assertEqual(result.size(0), SIZE1 + SIZE2)

    @onlyCPU
    @dtypes(torch.half, torch.double, torch.int)
    def test_cat2(self, device, dtype):
        SIZE = 10
        for dim in range(-3, 3):
            pos_dim = dim if dim >= 0 else 3 + dim
            x = torch.randint(low=-100, high=100, size=(13, SIZE, SIZE), device=device).to(dtype).transpose(0, pos_dim)
            y = torch.randint(low=-100, high=100, size=(17, SIZE, SIZE), device=device).to(dtype).transpose(0, pos_dim)
            z = torch.randint(low=-100, high=100, size=(19, SIZE, SIZE), device=device).to(dtype).transpose(0, pos_dim)

            res1 = torch.cat((x, y, z), dim)
            self.assertEqual(res1.narrow(pos_dim, 0, 13), x, atol=0, rtol=0)
            self.assertEqual(res1.narrow(pos_dim, 13, 17), y, atol=0, rtol=0)
            self.assertEqual(res1.narrow(pos_dim, 30, 19), z, atol=0, rtol=0)

        x = torch.randint(low=-100, high=100, size=(20, SIZE, SIZE), device=device).to(dtype)
        self.assertEqual(torch.cat(torch.split(x, 7)), x)
        self.assertEqual(torch.cat(torch.chunk(x, 7)), x)

        y = torch.randint(low=-100, high=100, size=(1, SIZE, SIZE), device=device).to(dtype)
        z = torch.cat([x, y])
        self.assertEqual(z.size(), (21, SIZE, SIZE))

    # FIXME: Create an OpInfo-based tensor creation method test that verifies this for all tensor
    #   creation methods and verify all dtypes and layouts
    @dtypes(torch.bool, torch.uint8, torch.int16, torch.int64, torch.float16, torch.float32, torch.complex64)
    def test_zeros_dtype_layout_device_match(self, device, dtype):
        layout = torch.strided
        t = torch.zeros((2, 3), device=device, dtype=dtype, layout=layout)
        self.assertIs(dtype, t.dtype)
        self.assertIs(layout, t.layout)
        self.assertEqual(torch.device(device), t.device)

    # TODO: update to work on CUDA, too
    @onlyCPU
    def test_stack(self, device):
        for dtype in (torch.half, torch.double, torch.int):
            x = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
            y = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
            z = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
            for dim in range(4):
                res = torch.stack((x, y, z), dim)
                res_neg = torch.stack((x, y, z), dim - 4)
                expected_size = x.size()[:dim] + (3,) + x.size()[dim:]
                self.assertEqual(res, res_neg)
                self.assertEqual(res.size(), expected_size)
                self.assertEqual(res.select(dim, 0), x, atol=0, rtol=0)
                self.assertEqual(res.select(dim, 1), y, atol=0, rtol=0)
                self.assertEqual(res.select(dim, 2), z, atol=0, rtol=0)

    # TODO: update to work on CUDA, too
    @onlyCPU
    def test_stack_out(self, device):
        for dtype in (torch.half, torch.double, torch.int):
            x = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
            y = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
            z = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
            for dim in range(4):
                expected_size = x.size()[:dim] + (3,) + x.size()[dim:]
                res_out = x.new(expected_size)
                res_neg_out = x.new(expected_size)
                res_out_dp = res_out.data_ptr()
                res_out_neg_dp = res_neg_out.data_ptr()
                torch.stack((x, y, z), dim, out=res_out)
                torch.stack((x, y, z), dim - 4, out=res_neg_out)
                self.assertEqual(res_out, res_neg_out)
                self.assertEqual(res_out.size(), expected_size)
                self.assertEqual(res_out_dp, res_out.data_ptr())
                self.assertEqual(res_out_neg_dp, res_neg_out.data_ptr())
                self.assertEqual(res_out.select(dim, 0), x, atol=0, rtol=0)
                self.assertEqual(res_out.select(dim, 1), y, atol=0, rtol=0)
                self.assertEqual(res_out.select(dim, 2), z, atol=0, rtol=0)

1159    def test_repeat_interleave(self, device):
1160        x = torch.tensor([0, 1, 2, 3], device=device)
1161        expected = torch.tensor([1, 2, 2, 3, 3, 3], device=device)
1162        self.assertEqual(torch.repeat_interleave(x), expected)
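        # with no repeats argument the input values themselves are the per-index repeat counts; a quick
        # cross-check against the equivalent explicit arange form (as described in the docs):
        self.assertEqual(torch.repeat_interleave(x),
                         torch.repeat_interleave(torch.arange(len(x), device=device), x))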
1163
1164        with self.assertRaises(RuntimeError):
1165            torch.repeat_interleave(torch.arange(4, device=device).reshape(2, 2))
1166
1167        with self.assertRaises(RuntimeError):
1168            torch.repeat_interleave(torch.arange(4.0, device=device))
1169
1170        with self.assertRaises(RuntimeError):
1171            torch.repeat_interleave(torch.tensor([1, 2, -1, 3, 4], device=device))
1172
1173        y = torch.tensor([[1, 2], [3, 4]], device=device)
1174
1175        y1_v1 = torch.repeat_interleave(y, 2)
1176        y1_v2 = torch.repeat_interleave(y, torch.tensor(2, device=device))
1177        y1_v3 = torch.repeat_interleave(y, torch.tensor([2], device=device))
1178        y1_expect = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4], device=device)
1179        self.assertEqual(y1_v1, y1_expect)
1180        self.assertEqual(y1_v2, y1_expect)
1181        self.assertEqual(y1_v3, y1_expect)
1182
1183        y2 = torch.repeat_interleave(y, 3, dim=1)
1184        y2_expect = torch.tensor([[1, 1, 1, 2, 2, 2],
1185                                  [3, 3, 3, 4, 4, 4]], device=device)
1186        self.assertEqual(y2, y2_expect)
1187
1188        y3 = torch.repeat_interleave(y, torch.tensor([1, 2], device=device), dim=0)
1189        y3_expect = torch.tensor([[1, 2],
1190                                  [3, 4],
1191                                  [3, 4]], device=device)
1192        self.assertEqual(y3, y3_expect)
1193
1194        with self.assertRaises(RuntimeError):
1195            torch.repeat_interleave(y, torch.tensor([1, 2, 3], device=device), dim=0)
1196
1197        with self.assertRaises(RuntimeError):
1198            torch.repeat_interleave(y, torch.arange(9, device=device).reshape(3, 3), dim=0)
1199
1200        # test zero sized dimension
1201        x = torch.zeros((5, 0), device=device)
1202        y = torch.repeat_interleave(x, repeats=3, dim=1)
1203        self.assertEqual(y, x.new_zeros(5, 0, device=device))
1204
1205        x = torch.tensor([], dtype=torch.int64, device=device)
1206        y = torch.repeat_interleave(x, x)
1207        self.assertEqual(y, x)
1208
1209    # TODO: update to work on CUDA, too
1210    @onlyCPU
1211    def test_new_methods_requires_grad(self, device):
1212        size = (10,)
1213        test_cases = [
1214            # method name, args
1215            ('new_full', [size, 1]),
1216            ('new_empty', [size]),
1217            ('new_zeros', [size]),
1218            ('new_ones', [size]),
1219        ]
1220        for method_name, args in test_cases:
1221            x = torch.randn(size)
1222            for requires_grad in [True, False]:
1223                x_new = x.__getattribute__(method_name)(*args, requires_grad=requires_grad)
1224                self.assertEqual(x_new.requires_grad, requires_grad)
1225            x = torch.randint(10, size)
1226            with self.assertRaisesRegex(
1227                    RuntimeError,
1228                    r'Only Tensors of floating point and complex dtype can require gradients'):
1229                x_new = x.__getattribute__(method_name)(*args, requires_grad=True)
1230
1231    # TODO: update to work on CUDA, too?
1232    @onlyCPU
1233    def test_tensor_from_sequence(self, device):
1234        class MockSequence:
1235            def __init__(self, lst):
1236                self.lst = lst
1237
1238            def __len__(self):
1239                return len(self.lst)
1240
1241            def __getitem__(self, item):
1242                raise TypeError
1243
1244        class GoodMockSequence(MockSequence):
1245            def __getitem__(self, item):
1246                return self.lst[item]
1247
1248        bad_mock_seq = MockSequence([1.0, 2.0, 3.0])
1249        good_mock_seq = GoodMockSequence([1.0, 2.0, 3.0])
1250        with self.assertRaisesRegex(ValueError, 'could not determine the shape'):
1251            torch.tensor(bad_mock_seq)
1252        self.assertEqual(torch.tensor([1.0, 2.0, 3.0]), torch.tensor(good_mock_seq))
1253
1254    # TODO: update to work on CUDA, too?
1255    @onlyCPU
1256    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
1257    def test_simple_scalar_cast(self, device):
1258        ok = [torch.tensor([1.5]), torch.zeros(1, 1, 1, 1)]
1259        ok_values = [1.5, 0]
1260
1261        not_ok = map(torch.Tensor, [[], [1, 2], [[1, 2], [3, 4]]])
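        # int(), float() and complex() are only defined for tensors with exactly one element;
        # empty and multi-element tensors should raise ValueError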
1262
1263        for tensor, value in zip(ok, ok_values):
1264            self.assertEqual(int(tensor), int(value))
1265            self.assertEqual(float(tensor), float(value))
1266            self.assertEqual(complex(tensor), complex(value))
1267
1268        self.assertEqual(complex(torch.tensor(1.5j)), 1.5j)
1269
1270        for tensor in not_ok:
1271            self.assertRaises(ValueError, lambda: int(tensor))
1272            self.assertRaises(ValueError, lambda: float(tensor))
1273            self.assertRaises(ValueError, lambda: complex(tensor))
1274
1275        self.assertRaises(RuntimeError, lambda: float(torch.tensor(1.5j)))
1276        self.assertRaises(RuntimeError, lambda: int(torch.tensor(1.5j)))
1277
1278    # TODO: update to work on CUDA, too?
1279    @onlyCPU
1280    def test_offset_scalar_cast(self, device):
1281        x = torch.tensor([1., 2., 3.])
1282        y = x[2:]
1283        self.assertEqual(int(y), 3)
1284
1285    def test_meshgrid_empty(self):
1286        with self.assertRaisesRegex(RuntimeError,
1287                                    'expects a non-empty TensorList'):
1288            torch.meshgrid()
1289
1290    def test_meshgrid_unsupported_indexing(self):
1291        with self.assertRaisesRegex(RuntimeError,
1292                                    'indexing must be one of "xy" or "ij"'):
1293            torch.meshgrid(torch.tensor([1, 2]), indexing='')
1294
1295    def test_meshgrid_non_1d_tensor(self):
1296        with self.assertRaisesRegex(RuntimeError,
1297                                    'Expected 0D or 1D tensor'):
1298            torch.meshgrid(torch.tensor([[1, 2], [3, 4]]))
1299
1300    def test_meshgrid_inconsistent_dtype(self):
1301        with self.assertRaisesRegex(
1302                RuntimeError, 'expects all tensors to have the same dtype'):
1303            torch.meshgrid(torch.tensor([1], dtype=torch.int),
1304                           torch.tensor([2], dtype=torch.float))
1305
1306    def test_meshgrid_inconsistent_device(self):
1307        with self.assertRaisesRegex(
1308                RuntimeError, 'expects all tensors to have the same device'):
1309            torch.meshgrid(torch.tensor([1], device='cpu'),
1310                           torch.tensor([2], device='meta'))
1311
1312    def test_meshgrid_warns_if_no_indexing(self):
1313        with self.assertWarnsOnceRegex(
1314                UserWarning, '.*will be required to pass the indexing arg.*'):
1315            torch.meshgrid(torch.tensor([1, 2]))
1316
1317    def test_meshgrid_default_indexing(self, device):
1318        a = torch.tensor(1, device=device)
1319        b = torch.tensor([1, 2, 3], device=device)
1320        c = torch.tensor([1, 2], device=device)
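        # with the default ('ij'-style) indexing each grid has one dimension per input, so the
        # expected shape is (1, 3, 2); the 0-d input contributes a size-1 dimension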
1321        grid_a, grid_b, grid_c = torch.meshgrid([a, b, c])
1322        self.assertEqual(grid_a.shape, torch.Size([1, 3, 2]))
1323        self.assertEqual(grid_b.shape, torch.Size([1, 3, 2]))
1324        self.assertEqual(grid_c.shape, torch.Size([1, 3, 2]))
1325        grid_a2, grid_b2, grid_c2 = torch.meshgrid(a, b, c)
1326        self.assertEqual(grid_a2.shape, torch.Size([1, 3, 2]))
1327        self.assertEqual(grid_b2.shape, torch.Size([1, 3, 2]))
1328        self.assertEqual(grid_c2.shape, torch.Size([1, 3, 2]))
1329        expected_grid_a = torch.ones(1, 3, 2, dtype=torch.int64, device=device)
1330        expected_grid_b = torch.tensor([[[1, 1],
1331                                         [2, 2],
1332                                         [3, 3]]], device=device)
1333        expected_grid_c = torch.tensor([[[1, 2],
1334                                         [1, 2],
1335                                         [1, 2]]], device=device)
1336        self.assertTrue(grid_a.equal(expected_grid_a))
1337        self.assertTrue(grid_b.equal(expected_grid_b))
1338        self.assertTrue(grid_c.equal(expected_grid_c))
1339        self.assertTrue(grid_a2.equal(expected_grid_a))
1340        self.assertTrue(grid_b2.equal(expected_grid_b))
1341        self.assertTrue(grid_c2.equal(expected_grid_c))
1342
1343    def test_meshgrid_xy_indexing(self, device):
1344        a = torch.tensor(1, device=device)
1345        b = torch.tensor([1, 2, 3], device=device)
1346        c = torch.tensor([1, 2], device=device)
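        # 'xy' indexing swaps the first two output dimensions relative to 'ij', hence shape (3, 1, 2)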
1347        grid_a, grid_b, grid_c = torch.meshgrid([a, b, c], indexing='xy')
1348        self.assertEqual(grid_a.shape, torch.Size([3, 1, 2]))
1349        self.assertEqual(grid_b.shape, torch.Size([3, 1, 2]))
1350        self.assertEqual(grid_c.shape, torch.Size([3, 1, 2]))
1351        grid_a2, grid_b2, grid_c2 = torch.meshgrid(a, b, c, indexing='xy')
1352        self.assertEqual(grid_a2.shape, torch.Size([3, 1, 2]))
1353        self.assertEqual(grid_b2.shape, torch.Size([3, 1, 2]))
1354        self.assertEqual(grid_c2.shape, torch.Size([3, 1, 2]))
1355        expected_grid_a = torch.ones(3, 1, 2, dtype=torch.int64, device=device)
1356        expected_grid_b = torch.tensor([[[1, 1]],
1357                                        [[2, 2]],
1358                                        [[3, 3]]], device=device)
1359        expected_grid_c = torch.tensor([[[1, 2]],
1360                                        [[1, 2]],
1361                                        [[1, 2]]], device=device)
1362        self.assertTrue(grid_a.equal(expected_grid_a))
1363        self.assertTrue(grid_b.equal(expected_grid_b))
1364        self.assertTrue(grid_c.equal(expected_grid_c))
1365        self.assertTrue(grid_a2.equal(expected_grid_a))
1366        self.assertTrue(grid_b2.equal(expected_grid_b))
1367        self.assertTrue(grid_c2.equal(expected_grid_c))
1368
1369    def test_meshgrid_ij_indexing(self, device):
1370        a = torch.tensor(1, device=device)
1371        b = torch.tensor([1, 2, 3], device=device)
1372        c = torch.tensor([1, 2], device=device)
1373        grid_a, grid_b, grid_c = torch.meshgrid([a, b, c], indexing='ij')
1374        self.assertEqual(grid_a.shape, torch.Size([1, 3, 2]))
1375        self.assertEqual(grid_b.shape, torch.Size([1, 3, 2]))
1376        self.assertEqual(grid_c.shape, torch.Size([1, 3, 2]))
1377        grid_a2, grid_b2, grid_c2 = torch.meshgrid(a, b, c, indexing='ij')
1378        self.assertEqual(grid_a2.shape, torch.Size([1, 3, 2]))
1379        self.assertEqual(grid_b2.shape, torch.Size([1, 3, 2]))
1380        self.assertEqual(grid_c2.shape, torch.Size([1, 3, 2]))
1381        expected_grid_a = torch.ones(1, 3, 2, dtype=torch.int64, device=device)
1382        expected_grid_b = torch.tensor([[[1, 1],
1383                                         [2, 2],
1384                                         [3, 3]]], device=device)
1385        expected_grid_c = torch.tensor([[[1, 2],
1386                                         [1, 2],
1387                                         [1, 2]]], device=device)
1388        self.assertTrue(grid_a.equal(expected_grid_a))
1389        self.assertTrue(grid_b.equal(expected_grid_b))
1390        self.assertTrue(grid_c.equal(expected_grid_c))
1391        self.assertTrue(grid_a2.equal(expected_grid_a))
1392        self.assertTrue(grid_b2.equal(expected_grid_b))
1393        self.assertTrue(grid_c2.equal(expected_grid_c))
1394
1395    def test_meshgrid_ij_indexing_is_default(self, device):
1396        a = torch.tensor(1, device=device)
1397        b = torch.tensor([1, 2, 3], device=device)
1398        c = torch.tensor([1, 2], device=device)
1399        grid_a, grid_b, grid_c = torch.meshgrid(a, b, c, indexing='ij')
1400        grid_a2, grid_b2, grid_c2 = torch.meshgrid(a, b, c)
1401        self.assertTrue(grid_a.equal(grid_a2))
1402        self.assertTrue(grid_b.equal(grid_b2))
1403        self.assertTrue(grid_c.equal(grid_c2))
1404
1405    @skipMeta
1406    def test_meshgrid_vs_numpy(self, device):
1407        # Shapes of the random tensors. Each line is a test case, and
1408        # each list within that line is the shape of a single
1409        # tensor. The shapes are restricted to 0D (represented by [])
1410        # and 1D tensors.
1411        cases = [
1412            [[]],
1413            [[1], [1], [1]],
1414            [[], [], []],
1415            [[3], [5], [7]],
1416            [[3], [], [7]],
1417            [[11], [13]],
1418            [[15]],
1419        ]
1420
1421        # We also need to test the different indexing modes. We can't
1422        # just enumerate them because we don't presently support the
1423        # same modes as numpy.meshgrid, nor does our default
1424        # correspond to their default.
1425        #
1426        # TODO Eliminate this and replace it with a list of all
1427        # supported indexing modes when we have full compatibility.
1428        indexing_correspondence = [
1429            # No indexing in PyTorch corresponds to "ij" indexing in
1430            # NumPy.
1431            ({}, {'indexing': 'ij'}),
1432
1433            # No indexing in NumPy corresponds to "xy" indexing in
1434            # PyTorch.
1435            ({'indexing': 'xy'}, {}),
1436
1437            # "ij" and "xy" are implemented identically in both.
1438            ({'indexing': 'ij'}, {'indexing': 'ij'}),
1439            ({'indexing': 'xy'}, {'indexing': 'xy'}),
1440        ]
1441        for shapes, (torch_kwargs, numpy_kwargs) in product(cases, indexing_correspondence):
1442            with self.subTest(shapes=shapes, torch_kwargs=torch_kwargs, numpy_kwargs=numpy_kwargs):
1443                tensors = [make_tensor(shape, device=device, dtype=torch.int) for shape in shapes]
1444                torch_grids = torch.meshgrid(*tensors, **torch_kwargs)
1445                numpy_grids = np.meshgrid(*(tensor.cpu().numpy() for tensor in tensors), **numpy_kwargs)
1446                self.assertEqual(torch_grids, numpy_grids)
1447
1448
1449    def test_cartesian_prod(self, device):
1450        a = torch.tensor([1], device=device)
1451        b = torch.tensor([1, 2, 3], device=device)
1452        c = torch.tensor([1, 2], device=device)
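        # cartesian_prod should return one row per combination, in the same order as itertools.product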
1453        prod = torch.cartesian_prod(a, b, c)
1454        expected = torch.tensor(list(product([a], b, c)), device=device)
1455        self.assertEqual(expected, prod)
1456
1457        # test 0 size input
1458        d = torch.empty(0, dtype=b.dtype, device=device)
1459        prod = torch.cartesian_prod(a, b, c, d)
1460        expected = torch.empty(0, 4, dtype=b.dtype, device=device)
1461        self.assertEqual(expected, prod)
1462
1463        # test single input
1464        prod = torch.cartesian_prod(b)
1465        self.assertEqual(b, prod)
1466
1467    def test_combinations(self, device):
1468        a = torch.tensor([1, 2, 3], device=device)
1469
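        # torch.combinations mirrors itertools.combinations / combinations_with_replacement;
        # r defaults to 2, and r=0 or r > len(a) produces an empty result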
1470        c = torch.combinations(a, r=0)
1471        expected = torch.empty(0, dtype=a.dtype, device=device)
1472        self.assertEqual(c, expected)
1473
1474        c = torch.combinations(a, r=1)
1475        expected = torch.tensor(list(combinations(a, r=1)), device=device)
1476        self.assertEqual(c, expected)
1477
1478        c = torch.combinations(a, r=1, with_replacement=True)
1479        expected = torch.tensor(list(combinations_with_replacement(a, r=1)), device=device)
1480        self.assertEqual(c, expected)
1481
1482        c = torch.combinations(a)
1483        expected = torch.tensor(list(combinations(a, r=2)), device=device)
1484        self.assertEqual(c, expected)
1485
1486        c = torch.combinations(a, with_replacement=True)
1487        expected = torch.tensor(list(combinations_with_replacement(a, r=2)), device=device)
1488        self.assertEqual(c, expected)
1489
1490        c = torch.combinations(a, r=3)
1491        expected = torch.tensor(list(combinations(a, r=3)), device=device)
1492        self.assertEqual(c, expected)
1493
1494        c = torch.combinations(a, r=4)
1495        expected = torch.empty(0, 4, dtype=a.dtype, device=device)
1496        self.assertEqual(c, expected)
1497
1498        c = torch.combinations(a, r=5)
1499        expected = torch.empty(0, 5, dtype=a.dtype, device=device)
1500        self.assertEqual(c, expected)
1501
1502        # test empty input
1503        a = torch.empty(0, device=device)
1504        c1 = torch.combinations(a)
1505        c2 = torch.combinations(a, with_replacement=True)
1506        expected = torch.empty(0, 2, dtype=a.dtype, device=device)
1507        self.assertEqual(c1, expected)
1508        self.assertEqual(c2, expected)
1509
1510    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
1511    @skipMeta
1512    def test_linlogspace_mem_overlap(self, device):
1513        x = torch.rand(1, device=device).expand(10)
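        # expand() yields a view with internally overlapping memory (stride 0), which the out=
        # overloads of linspace/logspace must reject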
1514        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
1515            torch.linspace(1, 10, 10, out=x)
1516
1517        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
1518            torch.logspace(1, 10, 10, out=x)
1519
1520    def test_ctor_with_numpy_array(self, device):
1521        correct_dtypes = [
1522            np.double,
1523            float,
1524            np.float16,
1525            np.int64,
1526            np.int32,
1527            np.int16,
1528            np.int8,
1529            np.uint8,
1530            bool,
1531        ]
1532
1533        incorrect_byteorder = '>' if sys.byteorder == 'little' else '<'
1534        incorrect_dtypes = [incorrect_byteorder + t for t in ['d', 'f']]
1535
1536        for dtype in correct_dtypes:
1537            array = np.array([1, 2, 3, 4], dtype=dtype)
1538
1539            # Upcast
1540            tensor = torch.DoubleTensor(array).to(device)
1541            for i in range(len(array)):
1542                self.assertEqual(tensor[i], array[i])
1543
1544            # Downcast (sometimes)
1545            tensor = torch.FloatTensor(array).to(device)
1546            for i in range(len(array)):
1547                self.assertEqual(tensor[i], array[i])
1548
1549            tensor = torch.HalfTensor(array).to(device)
1550            for i in range(len(array)):
1551                self.assertEqual(tensor[i], array[i])
1552
1553    @dtypes(torch.float, torch.double, torch.int8, torch.int16, torch.int32, torch.int64)
1554    def test_random(self, device, dtype):
1555        # This test is flaky with p<=(2/(ub-lb))^200=6e-36
1556        t = torch.empty(200, dtype=dtype, device=device)
1557        lb = 1
1558        ub = 4
1559
1560        t.fill_(-1)
1561        t.random_(lb, ub)
1562        self.assertEqual(t.min(), lb)
1563        self.assertEqual(t.max(), ub - 1)
1564
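        # the single-argument form random_(to) samples from [0, to)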
1565        t.fill_(-1)
1566        t.random_(ub)
1567        self.assertEqual(t.min(), 0)
1568        self.assertEqual(t.max(), ub - 1)
1569
1570    def test_random_bool(self, device):
1571        size = 2000
1572        t = torch.empty(size, dtype=torch.bool, device=device)
1573
1574        t.fill_(False)
1575        t.random_()
1576        self.assertEqual(t.min(), False)
1577        self.assertEqual(t.max(), True)
1578        self.assertTrue(0.4 < (t.eq(True)).to(torch.int).sum().item() / size < 0.6)
1579
1580        t.fill_(True)
1581        t.random_()
1582        self.assertEqual(t.min(), False)
1583        self.assertEqual(t.max(), True)
1584        self.assertTrue(0.4 < (t.eq(True)).to(torch.int).sum().item() / size < 0.6)
1585
1586    # https://github.com/pytorch/pytorch/issues/126834
1587    @xfailIfTorchDynamo
1588    def test_random_from_to_bool(self, device):
1589        size = 2000
1590
1591        int64_min_val = torch.iinfo(torch.int64).min
1592        int64_max_val = torch.iinfo(torch.int64).max
1593
1594        min_val = 0
1595        max_val = 1
1596
1597        froms = [int64_min_val, -42, min_val - 1, min_val, max_val, max_val + 1, 42]
1598        tos = [-42, min_val - 1, min_val, max_val, max_val + 1, 42, int64_max_val]
1599
1600        for from_ in froms:
1601            for to_ in tos:
1602                t = torch.empty(size, dtype=torch.bool, device=device)
1603                if to_ > from_:
1604                    if not (min_val <= from_ <= max_val):
1605                        self.assertRaisesRegex(
1606                            RuntimeError,
1607                            "from is out of bounds",
1608                            lambda: t.random_(from_, to_)
1609                        )
1610                    elif not (min_val <= (to_ - 1) <= max_val):
1611                        self.assertRaisesRegex(
1612                            RuntimeError,
1613                            "to - 1 is out of bounds",
1614                            lambda: t.random_(from_, to_)
1615                        )
1616                    else:
1617                        t.random_(from_, to_)
1618                        range_ = to_ - from_
1619                        delta = 1
1620                        self.assertTrue(from_ <= t.to(torch.int).min() < (from_ + delta))
1621                        self.assertTrue((to_ - delta) <= t.to(torch.int).max() < to_)
1622                else:
1623                    self.assertRaisesRegex(
1624                        RuntimeError,
1625                        "random_ expects 'from' to be less than 'to', but got from=" + str(from_) + " >= to=" + str(to_),
1626                        lambda: t.random_(from_, to_)
1627                    )
1628
1629    # NB: uint64 is broken because its max value is not representable in
1630    # int64_t, but this is what random expects
1631    @dtypes(*all_types_and(torch.bfloat16, torch.half, torch.uint16, torch.uint32))
1632    def test_random_full_range(self, device, dtype):
1633        size = 2000
1634        alpha = 0.1
1635
1636        int64_min_val = torch.iinfo(torch.int64).min
1637        int64_max_val = torch.iinfo(torch.int64).max
1638
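        # fp_limit bounds the range of consecutive integers a floating dtype can represent exactly
        # (2 ** (stored mantissa bits + 1)); values outside it would lose precision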
1639        if dtype == torch.double:
1640            fp_limit = 2**53
1641        elif dtype == torch.float:
1642            fp_limit = 2**24
1643        elif dtype == torch.half:
1644            fp_limit = 2**11
1645        elif dtype == torch.bfloat16:
1646            fp_limit = 2**8
1647        else:
1648            fp_limit = 0
1649
1650        t = torch.empty(size, dtype=dtype, device=device)
1651
1652        if dtype in [torch.float, torch.double, torch.half, torch.bfloat16]:
1653            from_ = int(max(-fp_limit, int64_min_val))
1654            to_inc_ = int(min(fp_limit, int64_max_val))
1655        else:
1656            from_ = int(max(torch.iinfo(dtype).min, int64_min_val))
1657            to_inc_ = int(min(torch.iinfo(dtype).max, int64_max_val))
1658        range_ = to_inc_ - from_ + 1
1659
1660        t.random_(from_, None)
1661        delta = max(1, alpha * range_)
1662        self.assertTrue(from_ <= t.to(torch.double).min() < (from_ + delta))
1663        self.assertTrue((to_inc_ - delta) < t.to(torch.double).max() <= to_inc_)
1664
1665    # NB: uint64 is broken because its max value is not representable in
1666    # int64_t, but this is what random expects
1667    # https://github.com/pytorch/pytorch/issues/126834
1668    @xfailIfTorchDynamo
1669    @dtypes(*all_types_and(torch.bfloat16, torch.half, torch.uint16, torch.uint32))
1670    def test_random_from_to(self, device, dtype):
1671        size = 2000
1672        alpha = 0.1
1673
1674        int64_min_val = torch.iinfo(torch.int64).min
1675        int64_max_val = torch.iinfo(torch.int64).max
1676
1677        if dtype in [torch.float, torch.double, torch.half]:
1678            min_val = int(max(torch.finfo(dtype).min, int64_min_val))
1679            max_val = int(min(torch.finfo(dtype).max, int64_max_val))
1680            froms = [min_val, -42, 0, 42]
1681            tos = [-42, 0, 42, max_val >> 1]
1682        elif dtype == torch.bfloat16:
1683            min_val = int64_min_val
1684            max_val = int64_max_val
1685            froms = [min_val, -42, 0, 42]
1686            tos = [-42, 0, 42, max_val >> 1]
1687        elif dtype == torch.uint8:
1688            min_val = torch.iinfo(dtype).min
1689            max_val = torch.iinfo(dtype).max
1690            froms = [int64_min_val, -42, min_val - 1, min_val, 42, max_val, max_val + 1]
1691            tos = [-42, min_val - 1, min_val, 42, max_val, max_val + 1, int64_max_val]
1692        elif dtype == torch.int64:
1693            min_val = int64_min_val
1694            max_val = int64_max_val
1695            froms = [min_val, -42, 0, 42]
1696            tos = [-42, 0, 42, max_val]
1697        else:
1698            min_val = torch.iinfo(dtype).min
1699            max_val = torch.iinfo(dtype).max
1700            froms = [int64_min_val, min_val - 1, min_val, -42, 0, 42, max_val, max_val + 1]
1701            tos = [min_val - 1, min_val, -42, 0, 42, max_val, max_val + 1, int64_max_val]
1702
1703        if dtype == torch.double:
1704            fp_limit = 2**53
1705        elif dtype == torch.float:
1706            fp_limit = 2**24
1707        elif dtype == torch.half:
1708            fp_limit = 2**11
1709        elif dtype == torch.bfloat16:
1710            fp_limit = 2**8
1711        else:
1712            fp_limit = 0
1713
1714        for from_ in froms:
1715            for to_ in tos:
1716                t = torch.empty(size, dtype=dtype, device=device)
1717                if to_ > from_:
1718                    if not (min_val <= from_ <= max_val):
1719                        self.assertRaisesRegex(
1720                            RuntimeError,
1721                            "from is out of bounds",
1722                            lambda: t.random_(from_, to_)
1723                        )
1724                    elif not (min_val <= (to_ - 1) <= max_val):
1725                        self.assertRaisesRegex(
1726                            RuntimeError,
1727                            "to - 1 is out of bounds",
1728                            lambda: t.random_(from_, to_)
1729                        )
1730                    else:
1731                        if dtype.is_floating_point and (
1732                                not (-fp_limit <= from_ <= fp_limit) or not (-fp_limit <= (to_ - 1) <= fp_limit)):
1733                            if not (-fp_limit <= from_ <= fp_limit):
1734                                self.assertWarnsRegex(UserWarning, "from is out of bounds",
1735                                                      lambda: t.random_(from_, to_))
1736                            if not (-fp_limit <= (to_ - 1) <= fp_limit):
1737                                self.assertWarnsRegex(UserWarning, "to - 1 is out of bounds",
1738                                                      lambda: t.random_(from_, to_))
1739                        else:
1740                            t.random_(from_, to_)
1741                            range_ = to_ - from_
1742                            delta = max(1, alpha * range_)
1743                            if dtype == torch.bfloat16:
1744                                # Less strict checks because of rounding errors
1745                                # TODO investigate rounding errors
1746                                self.assertTrue(from_ <= t.to(torch.double).min() < (from_ + delta))
1747                                self.assertTrue((to_ - delta) < t.to(torch.double).max() <= to_)
1748                            else:
1749                                self.assertTrue(from_ <= t.to(torch.double).min() < (from_ + delta))
1750                                self.assertTrue((to_ - delta) <= t.to(torch.double).max() < to_)
1751                else:
1752                    self.assertRaisesRegex(
1753                        RuntimeError,
1754                        "random_ expects 'from' to be less than 'to', but got from=" + str(from_) + " >= to=" + str(to_),
1755                        lambda: t.random_(from_, to_)
1756                    )
1757
1758    # https://github.com/pytorch/pytorch/issues/126834
1759    @xfailIfTorchDynamo
1760    @dtypes(*all_types_and(torch.bfloat16, torch.half, torch.uint16, torch.uint32))
1761    def test_random_to(self, device, dtype):
1762        size = 2000
1763        alpha = 0.1
1764
1765        int64_min_val = torch.iinfo(torch.int64).min
1766        int64_max_val = torch.iinfo(torch.int64).max
1767
1768        if dtype in [torch.float, torch.double, torch.half]:
1769            min_val = int(max(torch.finfo(dtype).min, int64_min_val))
1770            max_val = int(min(torch.finfo(dtype).max, int64_max_val))
1771            tos = [-42, 0, 42, max_val >> 1]
1772        elif dtype == torch.bfloat16:
1773            min_val = int64_min_val
1774            max_val = int64_max_val
1775            tos = [-42, 0, 42, max_val >> 1]
1776        elif dtype == torch.uint8:
1777            min_val = torch.iinfo(dtype).min
1778            max_val = torch.iinfo(dtype).max
1779            tos = [-42, min_val - 1, min_val, 42, max_val, max_val + 1, int64_max_val]
1780        elif dtype == torch.int64:
1781            min_val = int64_min_val
1782            max_val = int64_max_val
1783            tos = [-42, 0, 42, max_val]
1784        else:
1785            min_val = torch.iinfo(dtype).min
1786            max_val = torch.iinfo(dtype).max
1787            tos = [min_val - 1, min_val, -42, 0, 42, max_val, max_val + 1, int64_max_val]
1788
1789        from_ = 0
1790        for to_ in tos:
1791            t = torch.empty(size, dtype=dtype, device=device)
1792            if to_ > from_:
1793                if not (min_val <= (to_ - 1) <= max_val):
1794                    self.assertRaisesRegex(
1795                        RuntimeError,
1796                        "to - 1 is out of bounds",
1797                        lambda: t.random_(from_, to_)
1798                    )
1799                else:
1800                    t.random_(to_)
1801                    range_ = to_ - from_
1802                    delta = max(1, alpha * range_)
1803                    if dtype == torch.bfloat16:
1804                        # Less strict checks because of rounding errors
1805                        # TODO investigate rounding errors
1806                        self.assertTrue(from_ <= t.to(torch.double).min() < (from_ + delta))
1807                        self.assertTrue((to_ - delta) < t.to(torch.double).max() <= to_)
1808                    else:
1809                        self.assertTrue(from_ <= t.to(torch.double).min() < (from_ + delta))
1810                        self.assertTrue((to_ - delta) <= t.to(torch.double).max() < to_)
1811            else:
1812                self.assertRaisesRegex(
1813                    RuntimeError,
1814                    "random_ expects 'from' to be less than 'to', but got from=" + str(from_) + " >= to=" + str(to_),
1815                    lambda: t.random_(from_, to_)
1816                )
1817
1818    @dtypes(*all_types_and(torch.bfloat16, torch.half))
1819    def test_random_default(self, device, dtype):
1820        size = 2000
1821        alpha = 0.1
1822
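        # to_inc mirrors the default upper bound of random_(): the largest consecutively
        # representable integer for floating dtypes, iinfo(dtype).max otherwise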
1823        if dtype == torch.float:
1824            to_inc = 1 << 24
1825        elif dtype == torch.double:
1826            to_inc = 1 << 53
1827        elif dtype == torch.half:
1828            to_inc = 1 << 11
1829        elif dtype == torch.bfloat16:
1830            to_inc = 1 << 8
1831        else:
1832            to_inc = torch.iinfo(dtype).max
1833
1834        t = torch.empty(size, dtype=dtype, device=device)
1835        t.random_()
1836        self.assertTrue(0 <= t.to(torch.double).min() < alpha * to_inc)
1837        self.assertTrue((to_inc - alpha * to_inc) < t.to(torch.double).max() <= to_inc)
1838
1839    # TODO: this test should be updated
1840    @onlyNativeDeviceTypes
1841    def test_empty_full(self, device):
1842        torch_device = torch.device(device)
1843        device_type = torch_device.type
1844
1845        dtypes = get_all_dtypes(include_half=False, include_bfloat16=False, include_complex32=True)
1846        if device_type == 'cpu':
1847            do_test_empty_full(self, dtypes, torch.strided, torch_device)
1848        if device_type == 'cuda':
1849            do_test_empty_full(self, dtypes, torch.strided, None)
1850            do_test_empty_full(self, dtypes, torch.strided, torch_device)
1851
1852    # TODO: this test should be updated
1853    @suppress_warnings
1854    @onlyNativeDeviceTypes
1855    @deviceCountAtLeast(1)
1856    def test_tensor_device(self, devices):
1857        device_type = torch.device(devices[0]).type
1858        if device_type == 'cpu':
1859            self.assertEqual('cpu', torch.tensor(5).device.type)
1860            self.assertEqual('cpu',
1861                             torch.ones((2, 3), dtype=torch.float32, device='cpu').device.type)
1862            self.assertEqual('cpu',
1863                             torch.ones((2, 3), dtype=torch.float32, device='cpu:0').device.type)
1864            self.assertEqual('cpu',
1865                             torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cpu:0').device.type)
1866            self.assertEqual('cpu', torch.tensor(np.random.randn(2, 3), device='cpu').device.type)
1867        if device_type == 'cuda':
1868            self.assertEqual('cuda:0', str(torch.tensor(5).cuda(0).device))
1869            self.assertEqual('cuda:0', str(torch.tensor(5).cuda('cuda:0').device))
1870            self.assertEqual('cuda:0',
1871                             str(torch.tensor(5, dtype=torch.int64, device=0).device))
1872            self.assertEqual('cuda:0',
1873                             str(torch.tensor(5, dtype=torch.int64, device='cuda:0').device))
1874            self.assertEqual('cuda:0',
1875                             str(torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cuda:0').device))
1876
1877            self.assertEqual('cuda:0', str(torch.tensor(np.random.randn(2, 3), device='cuda:0').device))
1878
1879            for device in devices:
1880                with torch.cuda.device(device):
1881                    device_string = 'cuda:' + str(torch.cuda.current_device())
1882                    self.assertEqual(device_string,
1883                                     str(torch.tensor(5, dtype=torch.int64, device='cuda').device))
1884
1885            with self.assertRaises(RuntimeError):
1886                torch.tensor(5).cuda('cpu')
1887            with self.assertRaises(RuntimeError):
1888                torch.tensor(5).cuda('cpu:0')
1889
1890            if len(devices) > 1:
1891                self.assertEqual('cuda:1', str(torch.tensor(5).cuda(1).device))
1892                self.assertEqual('cuda:1', str(torch.tensor(5).cuda('cuda:1').device))
1893                self.assertEqual('cuda:1',
1894                                 str(torch.tensor(5, dtype=torch.int64, device=1).device))
1895                self.assertEqual('cuda:1',
1896                                 str(torch.tensor(5, dtype=torch.int64, device='cuda:1').device))
1897                self.assertEqual('cuda:1',
1898                                 str(torch.tensor(torch.ones((2, 3), dtype=torch.float32),
1899                                     device='cuda:1').device))
1900
1901                self.assertEqual('cuda:1',
1902                                 str(torch.tensor(np.random.randn(2, 3), device='cuda:1').device))
1903
1904    # TODO: this test should be updated
1905    @onlyNativeDeviceTypes
1906    def test_as_strided_neg(self, device):
1907        error = r'as_strided: Negative strides are not supported at the ' \
1908                r'moment, got strides: \[-?[0-9]+(, -?[0-9]+)*\]'
1909        with self.assertRaisesRegex(RuntimeError, error):
1910            torch.as_strided(torch.ones(3, 3, device=device), (1, 1), (2, -1))
1911        with self.assertRaisesRegex(RuntimeError, error):
1912            torch.as_strided(torch.ones(14, device=device), (2,), (-11,))
1913
1914    # TODO: this test should be updated
1915    def test_zeros(self, device):
1916        res1 = torch.zeros(100, 100, device=device)
1917        res2 = torch.tensor((), device=device)
1918        torch.zeros(100, 100, device=device, out=res2)
1919
1920        self.assertEqual(res1, res2)
1921
1922        boolTensor = torch.zeros(2, 2, device=device, dtype=torch.bool)
1923        expected = torch.tensor([[False, False], [False, False]],
1924                                device=device, dtype=torch.bool)
1925        self.assertEqual(boolTensor, expected)
1926
1927        halfTensor = torch.zeros(1, 1, device=device, dtype=torch.half)
1928        expected = torch.tensor([[0.]], device=device, dtype=torch.float16)
1929        self.assertEqual(halfTensor, expected)
1930
1931        bfloat16Tensor = torch.zeros(1, 1, device=device, dtype=torch.bfloat16)
1932        expected = torch.tensor([[0.]], device=device, dtype=torch.bfloat16)
1933        self.assertEqual(bfloat16Tensor, expected)
1934
1935        complexTensor = torch.zeros(2, 2, device=device, dtype=torch.complex64)
1936        expected = torch.tensor([[0., 0.], [0., 0.]], device=device, dtype=torch.complex64)
1937        self.assertEqual(complexTensor, expected)
1938
1939        complexHalfTensor = torch.zeros(2, 2, device=device, dtype=torch.complex32)
1940        expected = torch.tensor([[0., 0.], [0., 0.]], device=device, dtype=torch.complex32)
1941        self.assertEqual(complexHalfTensor, expected)
1942
1943    # TODO: this test should be updated
1944    def test_zeros_out(self, device):
1945        shape = (3, 4)
1946        out = torch.zeros(shape, device=device)
1947        torch.zeros(shape, device=device, out=out)
1948
1949        # change the dtype, layout, device
1950        with self.assertRaises(RuntimeError):
1951            torch.zeros(shape, device=device, dtype=torch.int64, out=out)
1952        with self.assertRaises(RuntimeError):
1953            torch.zeros(shape, device=device, layout=torch.sparse_coo, out=out)
1954
1955        # leave them the same
1956        self.assertEqual(torch.zeros(shape, device=device),
1957                         torch.zeros(shape, device=device, dtype=out.dtype, out=out))
1958        self.assertEqual(torch.zeros(shape, device=device),
1959                         torch.zeros(shape, device=device, layout=torch.strided, out=out))
1960        self.assertEqual(torch.zeros(shape, device=device),
1961                         torch.zeros(shape, device=device, out=out))
1962
1963    # TODO: this test should be updated
1964    def test_ones(self, device):
1965        res1 = torch.ones(100, 100, device=device)
1966        res2 = torch.tensor((), device=device)
1967        torch.ones(100, 100, device=device, out=res2)
1968        self.assertEqual(res1, res2)
1969
1970        # test boolean tensor
1971        res1 = torch.ones(1, 2, device=device, dtype=torch.bool)
1972        expected = torch.tensor([[True, True]], device=device, dtype=torch.bool)
1973        self.assertEqual(res1, expected)
1974
1975        # test chalf
1976        self.assertEqual(torch.ones(100, 100, device=device, dtype=torch.chalf),
1977                         torch.ones(100, 100, device=device, dtype=torch.cfloat), exact_dtype=False)
1978
1979    # TODO: this test should be updated
1980    @onlyCPU
1981    def test_constructor_dtypes(self, device):
1982        self.assertIs(torch.tensor([]).dtype, torch.get_default_dtype())
1983
1984        self.assertIs(torch.uint8, torch.ByteTensor.dtype)
1985        self.assertIs(torch.float32, torch.FloatTensor.dtype)
1986        self.assertIs(torch.float64, torch.DoubleTensor.dtype)
1987
1988        with set_default_tensor_type('torch.FloatTensor'):
1989            self.assertIs(torch.float32, torch.get_default_dtype())
1990            self.assertIs(torch.FloatStorage, torch.Storage)
1991
1992        # only floating-point types are supported as the default type
1993        self.assertRaises(TypeError, lambda: torch.set_default_tensor_type('torch.IntTensor'))
1994
1995        with set_default_dtype(torch.float64):
1996            self.assertIs(torch.float64, torch.get_default_dtype())
1997            self.assertIs(torch.DoubleStorage, torch.Storage)
1998
1999        with set_default_tensor_type(torch.FloatTensor):
2000            self.assertIs(torch.float32, torch.get_default_dtype())
2001            self.assertIs(torch.FloatStorage, torch.Storage)
2002
2003        if torch.cuda.is_available():
2004            with set_default_tensor_type(torch.cuda.FloatTensor):
2005                self.assertIs(torch.float32, torch.get_default_dtype())
2006                self.assertIs(torch.float32, torch.cuda.FloatTensor.dtype)
2007                self.assertIs(torch.cuda.FloatStorage, torch.Storage)
2008
2009                with set_default_dtype(torch.float64):
2010                    self.assertIs(torch.float64, torch.get_default_dtype())
2011                    self.assertIs(torch.cuda.DoubleStorage, torch.Storage)
2012
2013        # don't allow passing dtype to set_default_tensor_type
2014        self.assertRaises(TypeError, lambda: torch.set_default_tensor_type(torch.float32))
2015
2016        # set_default_dtype only accepts floating-point dtypes
2017        for t in all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.qint8):
2018            # only floating-point types are supported as the default type
2019            if t in (
2020                    torch.half,
2021                    torch.float,
2022                    torch.double,
2023                    torch.bfloat16):
2024                with set_default_dtype(t):
2025                    pass
2026            else:
2027                self.assertRaises(TypeError, lambda: torch.set_default_dtype(t))
2028
2029    # TODO: this test should be updated
2030    @onlyCPU
2031    def test_constructor_device_legacy(self, device):
2032        self.assertRaises(RuntimeError, lambda: torch.FloatTensor(device='cuda'))
2033        self.assertRaises(RuntimeError, lambda: torch.FloatTensor(torch.Size([2, 3, 4]), device='cuda'))
2034        self.assertRaises(RuntimeError, lambda: torch.FloatTensor((2.0, 3.0), device='cuda'))
2035
2036        self.assertRaises(RuntimeError, lambda: torch.Tensor(device='cuda'))
2037        self.assertRaises(RuntimeError, lambda: torch.Tensor(torch.Size([2, 3, 4]), device='cuda'))
2038        self.assertRaises(RuntimeError, lambda: torch.Tensor((2.0, 3.0), device='cuda'))
2039
2040        # Tensor constructor/new with Tensor argument shouldn't work with device specified
2041        i = torch.tensor([1], device='cpu')
2042        self.assertRaises(RuntimeError, lambda: torch.Tensor(i, device='cpu'))
2043        self.assertRaises(RuntimeError, lambda: i.new(i, device='cpu'))
2044        self.assertRaises(RuntimeError, lambda: torch.Tensor(i, device='cuda'))
2045        self.assertRaises(RuntimeError, lambda: i.new(i, device='cuda'))
2046
2047        x = torch.randn((3,), device='cpu')
2048        self.assertRaises(RuntimeError, lambda: x.new(device='cuda'))
2049        self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cuda'))
2050        self.assertRaises(RuntimeError, lambda: x.new((2.0, 3.0), device='cuda'))
2051
2052        if torch.cuda.is_available():
2053            self.assertRaises(RuntimeError, lambda: torch.cuda.FloatTensor(device='cpu'))
2054            self.assertRaises(RuntimeError, lambda: torch.cuda.FloatTensor(torch.Size([2, 3, 4]), device='cpu'))
2055            self.assertRaises(RuntimeError, lambda: torch.cuda.FloatTensor((2.0, 3.0), device='cpu'))
2056
2057            # Tensor constructor/new with Tensor argument shouldn't work with device specified
2058            i = torch.tensor([1], device='cuda')
2059            self.assertRaises(RuntimeError, lambda: torch.Tensor(i, device='cuda'))
2060            self.assertRaises(RuntimeError, lambda: i.new(i, device='cuda'))
2061            self.assertRaises(RuntimeError, lambda: torch.Tensor(i, device='cpu'))
2062            self.assertRaises(RuntimeError, lambda: i.new(i, device='cpu'))
2063
2064            with set_default_tensor_type(torch.cuda.FloatTensor):
2065                self.assertRaises(RuntimeError, lambda: torch.Tensor(device='cpu'))
2066                self.assertRaises(RuntimeError, lambda: torch.Tensor(torch.Size([2, 3, 4]), device='cpu'))
2067                self.assertRaises(RuntimeError, lambda: torch.Tensor((2.0, 3.0), device='cpu'))
2068            x = torch.randn((3,), device='cuda')
2069            self.assertRaises(RuntimeError, lambda: x.new(device='cpu'))
2070            self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cpu'))
2071            self.assertRaises(RuntimeError, lambda: x.new((2.0, 3.0), device='cpu'))
2072
2073    # TODO: this test should be updated
2074    @suppress_warnings
2075    @onlyCPU
2076    def test_tensor_factory(self, device):
2077        # TODO: This test probably doesn't make too much sense now that
2078        # torch.tensor has been established for a while; it makes more
2079        # sense to test the legacy behavior in terms of the new behavior
2080        expected = torch.Tensor([1, 1])
2081        # test data
2082        res1 = torch.tensor([1, 1])
2083        self.assertEqual(res1, expected, exact_dtype=False)
2084
2085        res1 = torch.tensor([1, 1], dtype=torch.int)
2086        self.assertEqual(res1, expected, exact_dtype=False)
2087        self.assertIs(torch.int, res1.dtype)
2088
2089        # test copy
2090        res2 = torch.tensor(expected)
2091        self.assertEqual(res2, expected)
2092        res2[1] = 2
2093        self.assertEqual(expected, torch.ones_like(expected))
2094
2095        res2 = torch.tensor(expected, dtype=torch.int)
2096        self.assertEqual(res1, expected, exact_dtype=False)
2097        self.assertIs(torch.int, res1.dtype)
2098
2099        # test copy with numpy
2100        for dtype in [np.float64, np.int64, np.int8, np.uint8]:
2101            a = np.array([5.]).astype(dtype)
2102            res1 = torch.tensor(a)
2103            self.assertEqual(5., res1[0].item())
2104            a[0] = 7.
2105            self.assertEqual(5., res1[0].item())
2106
2107        # test boolean tensor
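        # bool conversion follows Python truthiness: exact zero is False, any other value
        # (including nan, inf and values too large for int64/float64) is True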
2108        a = torch.tensor([True, True, False, True, True], dtype=torch.bool)
2109        b = torch.tensor([-1, -1.1, 0, 1, 1.1], dtype=torch.bool)
2110        self.assertEqual(a, b)
2111        c = torch.tensor([-0.1, -1.1, 0, 1, 0.1], dtype=torch.bool)
2112        self.assertEqual(a, c)
2113        d = torch.tensor((-.3, 0, .3, 1, 3 / 7), dtype=torch.bool)
2114        e = torch.tensor((True, False, True, True, True), dtype=torch.bool)
2115        self.assertEqual(e, d)
2116        f = torch.tensor((-1, 0, -1.1, 1, 1.1), dtype=torch.bool)
2117        self.assertEqual(e, f)
2118
2119        int64_max = torch.iinfo(torch.int64).max
2120        int64_min = torch.iinfo(torch.int64).min
2121        float64_max = torch.finfo(torch.float64).max
2122        float64_min = torch.finfo(torch.float64).min
2123        g_1 = torch.tensor((float('nan'), 0, int64_min, int64_max, int64_min - 1), dtype=torch.bool)
2124        self.assertEqual(e, g_1)
2125        g_2 = torch.tensor((int64_max + 1, 0, (int64_max + 1) * 2, (int64_max + 1) * 2 + 1, float64_min), dtype=torch.bool)
2126        self.assertEqual(e, g_2)
2127        g_3 = torch.tensor((float64_max, 0, float64_max + 1, float64_min - 1, float64_max + 1e291), dtype=torch.bool)
2128        self.assertEqual(e, g_3)
2129
2130        h = torch.tensor([True, False, False, True, False, True, True], dtype=torch.bool)
2131        i = torch.tensor([1e-323, 1e-324, 0j, 1e-323j, 1e-324j, 1 + 2j, -1j], dtype=torch.bool)
2132        self.assertEqual(h, i)
2133        j = torch.tensor((True, True, True, True), dtype=torch.bool)
2134        k = torch.tensor((1e323, -1e323, float('inf'), -float('inf')), dtype=torch.bool)
2135        self.assertEqual(j, k)
2136
2137    # TODO: this test should be updated
2138    @suppress_warnings
2139    @onlyCPU
2140    def test_tensor_factory_copy_var(self, device):
2141        def check_copy(copy, is_leaf, requires_grad, data_ptr=None):
2142            if data_ptr is None:
2143                data_ptr = copy.data_ptr
2144            self.assertEqual(copy, source, exact_dtype=False)
2145            self.assertTrue(copy.is_leaf == is_leaf)
2146            self.assertTrue(copy.requires_grad == requires_grad)
2147            self.assertTrue(copy.data_ptr == data_ptr)
2148
2149        source = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
2150        # test torch.tensor()
2151        check_copy(torch.tensor(source), True, False)
2152        check_copy(torch.tensor(source, requires_grad=False), True, False)
2153        check_copy(torch.tensor(source, requires_grad=True), True, True)
2154
2155        # test tensor.new_tensor()
2156        copy = torch.randn(1)
2157        check_copy(copy.new_tensor(source), True, False)
2158        check_copy(copy.new_tensor(source, requires_grad=False), True, False)
2159        check_copy(copy.new_tensor(source, requires_grad=True), True, True)
2160
2161        # test torch.as_tensor()
2162        check_copy(torch.as_tensor(source), source.is_leaf, source.requires_grad, source.data_ptr)  # not copy
2163        check_copy(torch.as_tensor(source, dtype=torch.float), False, True)  # copy and keep the graph
2164
2165    # TODO: this test should be updated
2166    @onlyCPU
2167    def test_tensor_factory_type_inference(self, device):
2168        def test_inference(default_dtype):
2169            default_complex_dtype = torch.complex64 if default_dtype == torch.float32 else torch.complex128
2170            self.assertIs(default_dtype, torch.tensor(()).dtype)
2171            self.assertIs(default_dtype, torch.tensor(5.).dtype)
2172            self.assertIs(torch.int64, torch.tensor(5).dtype)
2173            self.assertIs(torch.bool, torch.tensor(True).dtype)
2174            self.assertIs(torch.int32, torch.tensor(5, dtype=torch.int32).dtype)
2175            self.assertIs(default_dtype, torch.tensor(((7, 5), (9, 5.))).dtype)
2176            self.assertIs(default_dtype, torch.tensor(((5., 5), (3, 5))).dtype)
2177            self.assertIs(torch.int64, torch.tensor(((5, 3), (3, 5))).dtype)
2178            self.assertIs(default_complex_dtype, torch.tensor(((5, 3 + 2j), (3, 5 + 4j))).dtype)
2179
2180            self.assertIs(torch.float64, torch.tensor(np.array(())).dtype)
2181            self.assertIs(torch.float64, torch.tensor(np.array(5.)).dtype)
2182            if np.array(5).dtype == np.int64:  # np long, which can be 4 bytes (e.g. on windows)
2183                self.assertIs(torch.int64, torch.tensor(np.array(5)).dtype)
2184            else:
2185                self.assertIs(torch.int32, torch.tensor(np.array(5)).dtype)
2186            self.assertIs(torch.uint8, torch.tensor(np.array(3, dtype=np.uint8)).dtype)
2187            self.assertIs(default_dtype, torch.tensor(((7, np.array(5)), (np.array(9), 5.))).dtype)
2188            self.assertIs(torch.float64, torch.tensor(((7, 5), (9, np.array(5.)))).dtype)
2189            self.assertIs(torch.int64, torch.tensor(((5, np.array(3)), (np.array(3), 5))).dtype)
2190
2191        for dtype in [torch.float64, torch.float32]:
2192            with set_default_dtype(dtype):
2193                test_inference(dtype)
2194
2195    # TODO: this test should be updated
2196    @suppress_warnings
2197    @onlyCPU
2198    def test_new_tensor(self, device):
2199        expected = torch.autograd.Variable(torch.ByteTensor([1, 1]))
2200        # test data
2201        res1 = expected.new_tensor([1, 1])
2202        self.assertEqual(res1, expected)
2203        res1 = expected.new_tensor([1, 1], dtype=torch.int)
2204        self.assertEqual(res1, expected, exact_dtype=False)
2205        self.assertIs(torch.int, res1.dtype)
2206
2207        # test copy
2208        res2 = expected.new_tensor(expected)
2209        self.assertEqual(res2, expected)
2210        res2[1] = 2
2211        self.assertEqual(expected, torch.ones_like(expected))
2212        res2 = expected.new_tensor(expected, dtype=torch.int)
2213        self.assertEqual(res2, expected, exact_dtype=False)
2214        self.assertIs(torch.int, res2.dtype)
2215
2216        # test copy with numpy
2217        a = np.array([5.])
2218        res1 = torch.tensor(a)
2219        res1 = res1.new_tensor(a)
2220        self.assertEqual(5., res1[0].item())
2221        a[0] = 7.
2222        self.assertEqual(5., res1[0].item())
2223
2224        if torch.cuda.device_count() >= 2:
2225            expected = expected.cuda(1)
2226            res1 = expected.new_tensor([1, 1])
2227            self.assertEqual(res1.get_device(), expected.get_device())
2228            res1 = expected.new_tensor([1, 1], dtype=torch.int)
2229            self.assertIs(torch.int, res1.dtype)
2230            self.assertEqual(res1.get_device(), expected.get_device())
2231
2232            res2 = expected.new_tensor(expected)
2233            self.assertEqual(res2.get_device(), expected.get_device())
2234            res2 = expected.new_tensor(expected, dtype=torch.int)
2235            self.assertIs(torch.int, res2.dtype)
2236            self.assertEqual(res2.get_device(), expected.get_device())
2237            res2 = expected.new_tensor(expected, dtype=torch.int, device=0)
2238            self.assertIs(torch.int, res2.dtype)
2239            self.assertEqual(res2.get_device(), 0)
2240
2241            res1 = expected.new_tensor(1)
2242            self.assertEqual(res1.get_device(), expected.get_device())
2243            res1 = expected.new_tensor(1, dtype=torch.int)
2244            self.assertIs(torch.int, res1.dtype)
2245            self.assertEqual(res1.get_device(), expected.get_device())
2246
2247    # TODO: this test should be updated
2248    @onlyCPU
2249    def test_as_tensor(self, device):
2250        # from python data
2251        x = [[0, 1], [2, 3]]
2252        self.assertEqual(torch.tensor(x), torch.as_tensor(x))
2253        self.assertEqual(torch.tensor(x, dtype=torch.float32), torch.as_tensor(x, dtype=torch.float32))
2254
2255        # python data with heterogeneous types
2256        z = [0, 'torch']
2257        with self.assertRaisesRegex(TypeError, "invalid data type"):
2258            torch.tensor(z)
2259            torch.as_tensor(z)
2260
2261        # python data with self-referential lists
2262        z = [0]
2263        z += [z]
2264        with self.assertRaisesRegex(TypeError, "self-referential lists are incompatible"):
2265            torch.tensor(z)
2266            torch.as_tensor(z)
2267
2268        z = [[1, 2], z]
2269        with self.assertRaisesRegex(TypeError, "self-referential lists are incompatible"):
2270            torch.tensor(z)
2271            torch.as_tensor(z)
2272
2273        # from tensor (doesn't copy unless type is different)
2274        y = torch.tensor(x)
2275        self.assertIs(y, torch.as_tensor(y))
2276        self.assertIsNot(y, torch.as_tensor(y, dtype=torch.float32))
2277        if torch.cuda.is_available():
2278            self.assertIsNot(y, torch.as_tensor(y, device='cuda'))
2279            y_cuda = y.to('cuda')
2280            self.assertIs(y_cuda, torch.as_tensor(y_cuda))
2281            self.assertIs(y_cuda, torch.as_tensor(y_cuda, device='cuda'))
2282
2283        # doesn't copy
2284        for dtype in [np.float64, np.int64, np.int8, np.uint8]:
2285            n = np.random.rand(5, 6).astype(dtype)
2286            n_astensor = torch.as_tensor(n)
2287            self.assertEqual(torch.tensor(n), n_astensor)
2288            n_astensor[0][0] = 25.7
2289            self.assertEqual(torch.tensor(n), n_astensor)
2290
2291        # changing dtype causes copy
2292        n = np.random.rand(5, 6).astype(np.float32)
2293        n_astensor = torch.as_tensor(n, dtype=torch.float64)
2294        self.assertEqual(torch.tensor(n, dtype=torch.float64), n_astensor)
2295        n_astensor[0][1] = 250.8
2296        self.assertNotEqual(torch.tensor(n, dtype=torch.float64), n_astensor)
2297
2298        # changing device causes copy
2299        if torch.cuda.is_available():
2300            n = np.random.randn(5, 6)
2301            n_astensor = torch.as_tensor(n, device='cuda')
2302            self.assertEqual(torch.tensor(n, device='cuda'), n_astensor)
2303            n_astensor[0][2] = 250.9
2304            self.assertNotEqual(torch.tensor(n, device='cuda'), n_astensor)
2305
2306    # TODO: this test should be updated
2307    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
2308    @suppress_warnings
2309    @dtypesIfCPU(torch.float, torch.bfloat16, torch.float16)
2310    @dtypes(torch.float)
2311    def test_range(self, device, dtype):
2312        res1 = torch.range(0, 1, device=device, dtype=dtype)
2313        res2 = torch.tensor((), device=device, dtype=dtype)
2314        torch.range(0, 1, device=device, dtype=dtype, out=res2)
2315        self.assertEqual(res1, res2, atol=0, rtol=0)
2316
2317        # Check range for non-contiguous tensors.
2318        x = torch.zeros(2, 3, device=device, dtype=dtype)
2319        torch.range(0, 3, device=device, dtype=dtype, out=x.narrow(1, 1, 2))
2320        res2 = torch.tensor(((0, 0, 1), (0, 2, 3)), device=device, dtype=dtype)
2321        self.assertEqual(x, res2, atol=1e-16, rtol=0)
2322
2323        # Check negative
2324        res1 = torch.tensor((1, 0), device=device, dtype=dtype)
2325        res2 = torch.tensor((), device=device, dtype=dtype)
2326        torch.range(1, 0, -1, device=device, dtype=dtype, out=res2)
2327        self.assertEqual(res1, res2, atol=0, rtol=0)
2328
2329        # Equal bounds
2330        res1 = torch.ones(1, device=device, dtype=dtype)
2331        res2 = torch.tensor((), device=device, dtype=dtype)
2332        torch.range(1, 1, -1, device=device, dtype=dtype, out=res2)
2333        self.assertEqual(res1, res2, atol=0, rtol=0)
2334        torch.range(1, 1, 1, device=device, dtype=dtype, out=res2)
2335        self.assertEqual(res1, res2, atol=0, rtol=0)
2336
2337    # TODO: this test should be updated
2338    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
2339    def test_range_warning(self, device):
2340        with warnings.catch_warnings(record=True) as w:
2341            torch.range(0, 10, device=device)
2342            self.assertEqual(len(w), 1)
2343
2344    # TODO: this test should be updated
2345    def test_arange(self, device):
2346        res = torch.tensor(range(10000), device=device)
2347        res1 = torch.arange(0, 10000, device=device)  # Use a larger number so vectorized code can be triggered
2348        res2 = torch.tensor([], dtype=torch.int64, device=device)
2349        torch.arange(0, 10000, out=res2)
2350        self.assertEqual(res, res1, atol=0, rtol=0)
2351        self.assertEqual(res, res2, atol=0, rtol=0)
2352
2353        # Vectorization on non-contiguous tensors
2354        res = torch.rand(3, 3, 300000, device=device).to(torch.int64)
2355        res = res.permute(2, 0, 1)
2356        torch.arange(0, 300000 * 3 * 3, out=res)
2357        self.assertEqual(res.flatten(), torch.arange(0, 300000 * 3 * 3, device=device))
2358
2359        # Check arange with only one argument
2360        res1 = torch.arange(10, device=device)
2361        res2 = torch.arange(0, 10, device=device)
2362        self.assertEqual(res1, res2, atol=0, rtol=0)
2363
2364        # Check arange for non-contiguous tensors.
2365        x = torch.zeros(2, 3, device=device)
2366        torch.arange(0, 4, out=x.narrow(1, 1, 2))
2367        res2 = torch.tensor(((0., 0., 1.), (0., 2., 3.)), device=device)
2368        self.assertEqual(x, res2, atol=1e-16, rtol=0)
2369
2370        # Check negative
2371        res1 = torch.tensor((1., 0.), device=device)
2372        res2 = torch.tensor([], device=device)
2373        torch.arange(1, -1, -1, out=res2)
2374        self.assertEqual(res1, res2, atol=0, rtol=0)
2375
2376        # Equal bounds
2377        res1 = torch.ones(1, device=device)
2378        res2 = torch.tensor([], device=device)
2379        torch.arange(1, 0, -1, out=res2)
2380        self.assertEqual(res1, res2, atol=0, rtol=0)
2381        torch.arange(1, 2, 1, out=res2)
2382        self.assertEqual(res1, res2, atol=0, rtol=0)
2383
2384        # FloatTensor
2385        out = torch.tensor([], dtype=torch.float, device=device)
2386        res1 = torch.arange(0.6, 0.89, 0.1, out=out)
2387        self.assertEqual(res1, [0.6, 0.7, 0.8])
2388        out = torch.tensor([], dtype=torch.float, device=device)
2389        res1 = torch.arange(1, 10, 0.3, out=out)
2390        self.assertEqual(res1.size(0), 30)
2391        self.assertEqual(res1[0], 1)
2392        self.assertEqual(res1[29], 9.7)
2393
2394        # DoubleTensor
2395        out = torch.tensor([], dtype=torch.double, device=device)
2396        res1 = torch.arange(0.6, 0.89, 0.1, out=out)
2397        self.assertEqual(res1, [0.6, 0.7, 0.8])
2398        out = torch.tensor([], dtype=torch.double, device=device)
2399        res1 = torch.arange(1, 10, 0.3, out=out)
2400        self.assertEqual(res1.size(0), 30)
2401        self.assertEqual(res1[0], 1)
2402        self.assertEqual(res1[29], 9.7)
2403
2404        # Bool Input matching numpy semantics
2405        r = torch.arange(True, device=device)
2406        self.assertEqual(r[0], 0)
2407        r2 = torch.arange(False, device=device)
2408        self.assertEqual(len(r2), 0)
2409        self.assertEqual(r.dtype, torch.int64)
2410        self.assertEqual(r2.dtype, torch.int64)
2411
2412        # Check that it's exclusive
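        # e.g. torch.arange(0, 5) yields [0, 1, 2, 3, 4]; the end value itself is never included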
2413        r = torch.arange(0, 5, device=device)
2414        self.assertEqual(r.min(), 0)
2415        self.assertEqual(r.max(), 4)
2416        self.assertEqual(r.numel(), 5)
2417
2418        r = torch.arange(0, 6, 3, device=device)
2419        self.assertEqual(r.min(), 0)
2420        self.assertEqual(r.max(), 3)
2421        self.assertEqual(r.numel(), 2)
2422
2423        r = torch.arange(0, 5, 2, device=device)
2424        self.assertEqual(r.min(), 0)
2425        self.assertEqual(r.max(), 4)
2426        self.assertEqual(r.numel(), 3)
2427
2428        r = torch.arange(0, -5, -2, device=device)
2429        self.assertEqual(r.min(), -4)
2430        self.assertEqual(r.max(), 0)
2431        self.assertEqual(r.numel(), 3)
2432
2433        r1 = torch.arange(0, 5 + 1e-6, device=device)
2434        # NB: without the dtype, we'll infer output type to be int64
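        # (r1 and r3 are inferred as the default floating point dtype because their
        # end values are Python floats)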
2435        r2 = torch.arange(0, 5, dtype=torch.float32, device=device)
2436        r3 = torch.arange(0, 5 - 1e-6, device=device)
2437        self.assertEqual(r1[:-1], r2, atol=0, rtol=0)
2438        self.assertEqual(r2, r3, atol=0, rtol=0)
2439
2440        r1 = torch.arange(10, -1 + 1e-6, -1, device=device)
2441        # NB: without the dtype, we'll infer output type to be int64
2442        r2 = torch.arange(10, -1, -1, dtype=torch.float32, device=device)
2443        r3 = torch.arange(10, -1 - 1e-6, -1, device=device)
2444        self.assertEqual(r1, r2, atol=0, rtol=0)
2445        self.assertEqual(r2, r3[:-1], atol=0, rtol=0)
2446
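        # a very large int64 step; the check below verifies the element count (100)
        # is computed exactly, without overflow or floating-point rounding error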
2447        w = 1449629115440469
2448        r = torch.arange(0, 100 * w, w, device=device)
2449        self.assertEqual(r.numel(), 100)
2450
2451        # Test Rounding Errors
2452        line = torch.zeros(size=(1, 49), device=device)
2453        self.assertWarnsRegex(UserWarning, 'The out tensor will be resized',
2454                              lambda: torch.arange(-1, 1, 2. / 49, dtype=torch.float32, out=line))
2455        self.assertEqual(line.shape, [50])
2456
2457        x = torch.empty(1).expand(10)
2458        self.assertRaises(RuntimeError, lambda: torch.arange(10, out=x))
2459
2460        msg = "unsupported range"
2461        self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(-5, float('nan'), device=device))
2462        # check with step size
2463        self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('-inf'), -1, device=device))
2464        self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('inf'), device=device))
2465        self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('-inf'), 10, device=device))
2466        self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('nan'), 10, device=device))
2467        self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('inf'), device=device))
2468        self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('nan'), device=device))
2469
2470        self.assertRaisesRegex(
2471            RuntimeError, "overflow",
2472            lambda: torch.arange(1.175494351e-38, 3.402823466e+38, device=device))
2473
        # check that the output shape stays consistent for step sizes that are
        # prone to floating-point rounding error
2475        d = torch.arange(-4.0, 4.0, 0.01, dtype=torch.float32, device=device)
2476        self.assertEqual(d.shape[0], 800)
2477
2478    # TODO: this test should be updated
2479    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
2480    @onlyCPU
2481    def test_arange_inference(self, device):
2482        # end only
2483        self.assertIs(torch.float32, torch.arange(1.).dtype)
2484        self.assertIs(torch.float32, torch.arange(torch.tensor(1.)).dtype)
2485        self.assertIs(torch.float32, torch.arange(torch.tensor(1., dtype=torch.float64)).dtype)
2486
2487        self.assertIs(torch.int64, torch.arange(1).dtype)
2488        self.assertIs(torch.int64, torch.arange(torch.tensor(1)).dtype)
2489        self.assertIs(torch.int64, torch.arange(torch.tensor(1, dtype=torch.int16)).dtype)
2490
2491        # start, end, [step]
2492        self.assertIs(torch.float32, torch.arange(1., 3).dtype)
2493        self.assertIs(torch.float32, torch.arange(torch.tensor(1., dtype=torch.float64), 3).dtype)
2494        self.assertIs(torch.float32, torch.arange(1, 3.).dtype)
2495        self.assertIs(torch.float32, torch.arange(torch.tensor(1, dtype=torch.int16), torch.tensor(3.)).dtype)
2496        self.assertIs(torch.float32, torch.arange(1, 3, 1.).dtype)
2497        self.assertIs(torch.float32,
2498                      torch.arange(torch.tensor(1),
2499                                   torch.tensor(3, dtype=torch.int16),
2500                                   torch.tensor(1., dtype=torch.float64)).dtype)
2501
2502        self.assertIs(torch.int64, torch.arange(1, 3).dtype)
2503        self.assertIs(torch.int64, torch.arange(torch.tensor(1), 3).dtype)
2504        self.assertIs(torch.int64, torch.arange(torch.tensor(1), torch.tensor(3, dtype=torch.int16)).dtype)
2505        self.assertIs(torch.int64, torch.arange(1, 3, 1).dtype)
2506        self.assertIs(torch.int64,
2507                      torch.arange(torch.tensor(1),
2508                                   torch.tensor(3),
2509                                   torch.tensor(1, dtype=torch.int16)).dtype)
2510
2511    # cannot call storage() on meta tensor
2512    @skipMeta
2513    def test_empty_strided(self, device):
2514        for shape in [(2, 3, 4), (0, 2, 0)]:
2515            # some of these cases are pretty strange, just verifying that if as_strided
2516            # allows them then empty_strided can as well.
2517            for strides in [(12, 4, 1), (2, 4, 6), (0, 0, 0)]:
2518                empty_strided = torch.empty_strided(shape, strides, device=device)
2519                # as_strided checks the storage size is big enough to support such a strided tensor;
2520                # instead of repeating this calculation, we just use empty_strided which does the same
2521                # calculation when setting the storage size.
2522                as_strided = torch.empty(empty_strided.storage().size(),
2523                                         device=device).as_strided(shape, strides)
2524                self.assertEqual(empty_strided.shape, as_strided.shape)
2525                self.assertEqual(empty_strided.stride(), as_strided.stride())
2526
2527    def test_new_empty_strided(self, device):
2528        def _test(sizes, strides, dtype):
2529            x = torch.zeros(5, 5, dtype=dtype, device=device)
2530            result = x.new_empty_strided(sizes, strides)
2531            expected = torch.empty_strided(sizes, strides, dtype=x.dtype, device=x.device)
2532            self.assertEqual(result.shape, expected.shape)
2533            self.assertEqual(result.stride(), expected.stride())
2534            self.assertEqual(result.dtype, expected.dtype)
2535            self.assertEqual(result.device, expected.device)
2536
2537        _test([2, 3], [3, 1], torch.float)
2538        _test([5, 3], [0, 1], torch.int)
2539        _test([], [], torch.float)
2540
2541        # Some really weird cases
2542        for shape in [(2, 3, 4), (0, 2, 0)]:
2543            for strides in [(12, 4, 1), (2, 4, 6), (0, 0, 0)]:
2544                _test(shape, strides, torch.float)
2545
2546        # Make sure sizes and strides have the same length
2547        # https://github.com/pytorch/pytorch/issues/82416
2548        with self.assertRaisesRegex(
2549                RuntimeError,
2550                r"dimensionality of sizes \(1\) must match dimensionality of strides \(0\)"):
2551            dtype = torch.float64
2552            x = torch.tensor(-4.8270, dtype=dtype, device=device)
2553            size = (2,)
2554            stride = ()
2555            x.new_empty_strided(size, stride, dtype=dtype, device=device)
2556
2557    def test_strided_mismatched_stride_shape(self, device):
2558        for shape, strides in [((1, ), ()), ((1, 2), (1, ))]:
2559            with self.assertRaisesRegex(RuntimeError, "mismatch in length of strides and shape"):
2560                torch.tensor(0.42, device=device).as_strided(shape, strides)
2561
2562            with self.assertRaisesRegex(RuntimeError, "mismatch in length of strides and shape"):
2563                torch.tensor(0.42, device=device).as_strided_(shape, strides)
2564
2565    def test_empty_tensor_props(self, device):
2566        sizes = [(0,), (0, 3), (5, 0), (5, 0, 3, 0, 2), (0, 3, 0, 2), (0, 5, 0, 2, 0)]
2567        for size in sizes:
2568            x = torch.empty(tuple(size), device=device)
2569            self.assertEqual(size, x.shape)
2570            self.assertTrue(x.is_contiguous())
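            # the test expects an empty tensor to report the same strides as the tensor
            # obtained by replacing every zero-sized dimension with size 1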
            size_ones_instead_of_zeros = (s if s != 0 else 1 for s in size)
2572            y = torch.empty(tuple(size_ones_instead_of_zeros), device=device)
2573            self.assertEqual(x.stride(), y.stride())
2574
2575    @onlyNativeDeviceTypes
2576    def test_empty_overflow(self, device):
2577        with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'):
2578            torch.empty([2, 4, 2**29, 2**29], dtype=torch.float64)
2579        with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'):
2580            torch.empty([8, 8, 2**29, 2**29], dtype=torch.float64)
2581        with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'):
2582            torch.empty_strided([8, 8], [2**61, 1], dtype=torch.float64)
2583        with self.assertRaisesRegex(RuntimeError, 'Stride calculation overflowed'):
2584            torch.empty([0, 4, 2305843009213693952], dtype=torch.float32)
2585
2586    def test_eye(self, device):
2587        for dtype in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
2588            if dtype == torch.bfloat16:
2589                continue
2590            # Test the RuntimeError is raised when either m or n is a negative number
            # Test that a RuntimeError is raised when either n or m is negative
2592                with self.assertRaisesRegex(RuntimeError, 'must be greater or equal to'):
2593                    torch.eye(n, m, device=device, dtype=dtype)
2594
2595            # Test when the `m` parameter is not provided
2596            for n in (3, 5, 7):
2597                res1 = torch.eye(n, device=device, dtype=dtype)
2598                naive_eye = torch.zeros(n, n, dtype=dtype, device=device)
2599                naive_eye.diagonal(dim1=-2, dim2=-1).fill_(1)
2600                self.assertEqual(naive_eye, res1)
2601
2602                # Check eye_out outputs
2603                res2 = torch.empty(0, device=device, dtype=dtype)
2604                torch.eye(n, out=res2)
2605                self.assertEqual(res1, res2)
2606
2607            for n, m in product([3, 5, 7], repeat=2):
2608                # Construct identity using diagonal and fill
2609                res1 = torch.eye(n, m, device=device, dtype=dtype)
2610                naive_eye = torch.zeros(n, m, dtype=dtype, device=device)
2611                naive_eye.diagonal(dim1=-2, dim2=-1).fill_(1)
2612                self.assertEqual(naive_eye, res1)
2613
2614                # Check eye_out outputs
2615                res2 = torch.empty(0, device=device, dtype=dtype)
2616                torch.eye(n, m, out=res2)
2617                self.assertEqual(res1, res2)
2618
2619    @precisionOverride({torch.float: 1e-8, torch.double: 1e-10})
2620    @dtypes(*floating_and_complex_types())
2621    def test_linspace_vs_numpy(self, device, dtype):
2622        start = -0.0316082797944545745849609375 + (0.8888888888j if dtype.is_complex else 0)
2623        end = .0315315723419189453125 + (0.444444444444j if dtype.is_complex else 0)
2624
2625        for steps in [1, 2, 3, 5, 11, 256, 257, 2**22]:
2626            t = torch.linspace(start, end, steps, device=device, dtype=dtype)
2627            a = np.linspace(start, end, steps, dtype=torch_to_numpy_dtype_dict[dtype])
2628            t = t.cpu()
2629            self.assertEqual(t, torch.from_numpy(a))
2630            self.assertTrue(t[0].item() == a[0])
2631            self.assertTrue(t[steps - 1].item() == a[steps - 1])
2632
2633    @dtypes(*integral_types())
2634    def test_linspace_vs_numpy_integral(self, device, dtype):
2635        start = 1
2636        end = 127
2637
2638        for steps in [25, 50]:
2639            t = torch.linspace(start, end, steps, device=device, dtype=dtype)
2640            a = np.linspace(start, end, steps, dtype=torch_to_numpy_dtype_dict[dtype])
2641            t = t.cpu()
2642            self.assertEqual(t, torch.from_numpy(a))
2643            self.assertTrue(t[0].item() == a[0])
2644            self.assertTrue(t[steps - 1].item() == a[steps - 1])
2645
2646    def _test_linspace_logspace_complex_helper(self, torch_fn, np_fn, device, dtype):
2647        start = torch.randn(1, dtype=dtype).item()
2648        end = (start + torch.randn(1, dtype=dtype) + random.randint(5, 15)).item()
2649
2650        def test_fn(torch_fn, numpy_fn, steps):
2651            t = torch_fn(start, end, steps, device=device)
2652            a = numpy_fn(start, end, steps, dtype=torch_to_numpy_dtype_dict[dtype])
2653            t = t.cpu()
2654            self.assertEqual(t, torch.from_numpy(a))
2655
2656        for steps in [1, 2, 3, 5, 11, 256, 257, 2**22]:
2657            test_fn(torch.linspace, np.linspace, steps)
2658
2659    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
2660    @dtypes(torch.complex64)
2661    def test_linspace_vs_numpy_complex(self, device, dtype):
2662        self._test_linspace_logspace_complex_helper(torch.linspace, np.linspace,
2663                                                    device, dtype)
2664
2665    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
2666    @dtypes(torch.complex64)
2667    def test_logspace_vs_numpy_complex(self, device, dtype):
2668        self._test_linspace_logspace_complex_helper(torch.logspace, np.logspace,
2669                                                    device, dtype)
2670
2671    @precisionOverride({torch.float: 1e-6, torch.double: 1e-10})
2672    @dtypes(*floating_types())
2673    def test_logspace_vs_numpy(self, device, dtype):
2674        start = -0.0316082797944545745849609375
2675        end = .0315315723419189453125
2676
2677        for steps in [1, 2, 3, 5, 11, 256, 257, 2**22]:
2678            t = torch.logspace(start, end, steps, device=device, dtype=dtype)
2679            a = np.logspace(start, end, steps, dtype=torch_to_numpy_dtype_dict[dtype])
2680            t = t.cpu()
2681            self.assertEqual(t, torch.from_numpy(a))
2682            self.assertEqual(t[0], a[0])
2683            self.assertEqual(t[steps - 1], a[steps - 1])
2684
2685    @onlyCUDA
2686    @largeTensorTest('16GB')
2687    def test_range_factories_64bit_indexing(self, device):
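        # just over 2**31 elements, so indexing cannot be done with 32-bit arithmetic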
2688        bigint = 2 ** 31 + 1
2689        t = torch.arange(bigint, dtype=torch.long, device=device)
2690        self.assertEqual(t[-1].item(), bigint - 1)
2691        del t
2692        t = torch.linspace(0, 1, bigint, dtype=torch.float, device=device)
2693        self.assertEqual(t[-1].item(), 1)
2694        del t
2695        t = torch.logspace(0, 1, bigint, 2, dtype=torch.float, device=device)
2696        self.assertEqual(t[-1].item(), 2)
2697        del t
2698
2699    @expectedFailureMeta  # RuntimeError: The tensor has a non-zero number of elements
2700    @onlyNativeDeviceTypes
2701    def test_tensor_ctor_device_inference(self, device):
2702        torch_device = torch.device(device)
2703        values = torch.tensor((1, 2, 3), device=device)
2704
2705        # Tests tensor and as_tensor
        # Note: any warnings raised by these constructors are suppressed below
2707        for op in (torch.tensor, torch.as_tensor):
2708            with warnings.catch_warnings():
2709                warnings.simplefilter("ignore")
2710                self.assertEqual(op(values).device, torch_device)
2711                self.assertEqual(op(values, dtype=torch.float64).device, torch_device)
2712
2713                if self.device_type == 'cuda':
2714                    with torch.cuda.device(device):
2715                        self.assertEqual(op(values.cpu()).device, torch.device('cpu'))
2716
2717        # Tests sparse ctor
2718        indices = torch.tensor([[0, 1, 1],
2719                                [2, 0, 1],
2720                                [2, 1, 0]], device=device)
2721        sparse_size = (3, 3, 3)
2722
2723        sparse_default = torch.sparse_coo_tensor(indices, values, sparse_size)
2724        self.assertEqual(sparse_default.device, torch_device)
2725
2726        sparse_with_dtype = torch.sparse_coo_tensor(indices, values, sparse_size, dtype=torch.float64)
2727        self.assertEqual(sparse_with_dtype.device, torch_device)
2728
2729        if self.device_type == 'cuda':
2730            with torch.cuda.device(device):
2731                sparse_with_dtype = torch.sparse_coo_tensor(indices.cpu(), values.cpu(),
2732                                                            sparse_size, dtype=torch.float64)
2733                self.assertEqual(sparse_with_dtype.device, torch.device('cpu'))
2734
2735    def _test_signal_window_functions(self, name, dtype, device, **kwargs):
2736        import scipy.signal as signal
2737
2738        torch_method = getattr(torch, name + '_window')
2739        if not dtype.is_floating_point:
2740            with self.assertRaisesRegex(RuntimeError, r'floating point'):
2741                torch_method(3, dtype=dtype)
2742            return
2743        for size in [0, 1, 2, 5, 10, 50, 100, 1024, 2048]:
2744            for periodic in [True, False]:
2745                res = torch_method(
2746                    size,
2747                    periodic=periodic,
2748                    layout=torch.strided,
2749                    requires_grad=False,
2750                    **kwargs,
2751                    device=device,
2752                    dtype=dtype,
2753                )
2754                # NB: scipy always returns a float64 result
2755                ref = torch.from_numpy(
2756                    signal.get_window(
2757                        (name, *(kwargs.values())), size, fftbins=periodic
2758                    )
2759                )
2760                self.assertEqual(res, ref.to(dtype))
2761        with self.assertRaisesRegex(RuntimeError, r'not implemented for sparse types'):
2762            torch_method(3, layout=torch.sparse_coo)
2763        self.assertTrue(torch_method(3, requires_grad=True).requires_grad)
2764        self.assertFalse(torch_method(3).requires_grad)
2765
2766    @onlyNativeDeviceTypes
2767    @precisionOverride({torch.bfloat16: 5e-2, torch.half: 1e-3})
2768    @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
2769    @dtypesIfCUDA(torch.float, torch.double, torch.bfloat16, torch.half, torch.long)
2770    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
2771    @dtypes(torch.float, torch.double, torch.long)
2772    @parametrize("window", ['hann', 'hamming', 'bartlett', 'blackman'])
2773    def test_signal_window_functions(self, device, dtype, window):
2774        self._test_signal_window_functions(window, dtype, device)
2775
2776    @onlyNativeDeviceTypes
2777    @precisionOverride({torch.bfloat16: 5e-2, torch.half: 1e-3})
2778    @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
2779    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
2780    @dtypesIfCUDA(torch.float, torch.double, torch.bfloat16, torch.half, torch.long)
2781    @dtypes(torch.float, torch.double, torch.long, torch.bfloat16, torch.float16)
2782    def test_kaiser_window(self, device, dtype):
2783        for num_test in range(50):
2784            self._test_signal_window_functions('kaiser', dtype, device, beta=random.random() * 30)
2785
2786    def _test_signal_windows_functions(self, name, dtype, device, **kwargs):
2787        import scipy.signal as signal
2788
2789        torch_method = getattr(torch.signal.windows, name)
2790        if not dtype.is_floating_point:
2791            with self.assertRaisesRegex(RuntimeError, r'floating point'):
2792                torch_method(3, dtype=dtype)
2793            return
2794        for size in [0, 1, 2, 5, 10, 50, 100, 1024, 2048]:
2795            for periodic in [True, False]:
2796                res = torch_method(size, sym=not periodic, **kwargs, device=device, dtype=dtype)
2797                # NB: scipy always returns a float64 result
2798                ref = torch.from_numpy(signal.get_window((name, *(kwargs.values())), size, fftbins=periodic))
2799                self.assertEqual(res, ref, exact_dtype=False)
2800        self.assertTrue(torch_method(3, requires_grad=True).requires_grad)
2801        self.assertFalse(torch_method(3).requires_grad)
2802
2803    # torch.signal.windows functions (except any with extra parameters)
2804    @onlyNativeDeviceTypes
2805    @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
2806    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
2807    @dtypes(torch.float, torch.double)
2808    @parametrize("window", ['bartlett', 'blackman', 'cosine', 'hamming', 'hann', 'nuttall'])
2809    def test_signal_windows_functions(self, device, dtype, window):
2810        self._test_signal_windows_functions(window, dtype, device)
2811
2812    # torch.signal.windows.kaiser
2813    @onlyNativeDeviceTypes
2814    @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
2815    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
2816    @dtypes(torch.float, torch.double)
2817    def test_kaiser(self, device, dtype):
2818        for num_test in range(50):
2819            self._test_signal_windows_functions('kaiser', dtype, device, beta=random.random() * 30)
2820
2821    def test_tensor_factories_empty(self, device):
2822        # ensure we can create empty tensors from each factory function
2823        shapes = [(5, 0, 1), (0,), (0, 0, 1, 0, 2, 0, 0)]
2824
2825        for shape in shapes:
2826            for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.chalf):
2827
2828                self.assertEqual(shape, torch.zeros(shape, device=device, dtype=dt).shape)
2829                self.assertEqual(shape, torch.zeros_like(torch.zeros(shape, device=device, dtype=dt)).shape)
2830                self.assertEqual(shape, torch.full(shape, 3, device=device, dtype=dt).shape)
2831                self.assertEqual(shape, torch.full_like(torch.zeros(shape, device=device, dtype=dt), 3).shape)
2832                self.assertEqual(shape, torch.ones(shape, device=device, dtype=dt).shape)
2833                self.assertEqual(shape, torch.ones_like(torch.zeros(shape, device=device, dtype=dt)).shape)
2834                self.assertEqual(shape, torch.empty(shape, device=device, dtype=dt).shape)
2835                self.assertEqual(shape, torch.empty_like(torch.zeros(shape, device=device, dtype=dt)).shape)
2836                self.assertEqual(shape, torch.empty_strided(shape, (0,) * len(shape), device=device, dtype=dt).shape)
2837
2838                if dt == torch.bool:
2839                    self.assertEqual(shape, torch.randint(2, shape, device=device, dtype=dt).shape)
2840                    self.assertEqual(shape, torch.randint_like(torch.zeros(shape, device=device, dtype=dt), 2).shape)
2841                elif dt.is_complex:
2842                    self.assertRaises(RuntimeError, lambda: torch.randint(6, shape, device=device, dtype=dt).shape)
2843                else:
2844                    self.assertEqual(shape, torch.randint(6, shape, device=device, dtype=dt).shape)
2845                    self.assertEqual(shape, torch.randint_like(torch.zeros(shape, device=device, dtype=dt), 6).shape)
2846
2847                if dt not in {torch.double, torch.float, torch.half, torch.bfloat16,
2848                              torch.complex32, torch.complex64, torch.complex128}:
2849                    self.assertRaises(RuntimeError, lambda: torch.rand(shape, device=device, dtype=dt).shape)
2850
2851                if dt == torch.double or dt == torch.float or dt.is_complex:
2852                    self.assertEqual(shape, torch.randn(shape, device=device, dtype=dt).shape)
2853                    self.assertEqual(shape, torch.randn_like(torch.zeros(shape, device=device, dtype=dt)).shape)
2854
2855        self.assertEqual((0,), torch.arange(0, device=device).shape)
2856        self.assertEqual((0, 0), torch.eye(0, device=device).shape)
2857        self.assertEqual((0, 0), torch.eye(0, 0, device=device).shape)
2858        self.assertEqual((5, 0), torch.eye(5, 0, device=device).shape)
2859        self.assertEqual((0, 5), torch.eye(0, 5, device=device).shape)
2860        self.assertEqual((0,), torch.linspace(1, 1, 0, device=device).shape)
2861        self.assertEqual((0,), torch.logspace(1, 1, 0, device=device).shape)
2862        self.assertEqual((0,), torch.randperm(0, device=device).shape)
2863        self.assertEqual((0,), torch.bartlett_window(0, device=device).shape)
2864        self.assertEqual((0,), torch.bartlett_window(0, periodic=False, device=device).shape)
2865        self.assertEqual((0,), torch.hamming_window(0, device=device).shape)
2866        self.assertEqual((0,), torch.hann_window(0, device=device).shape)
2867        self.assertEqual((0,), torch.kaiser_window(0, device=device).shape)
2868        self.assertEqual((1, 1, 0), torch.tensor([[[]]], device=device).shape)
2869        self.assertEqual((1, 1, 0), torch.as_tensor([[[]]], device=device).shape)
2870
2871    @onlyCUDA
2872    def test_tensor_factory_gpu_type_inference(self, device):
2873        with set_default_tensor_type(torch.cuda.DoubleTensor):
2874            with set_default_dtype(torch.float32):
2875                self.assertIs(torch.float32, torch.tensor(0.).dtype)
2876                self.assertEqual(torch.device(device), torch.tensor(0.).device)
2877            with set_default_dtype(torch.float64):
2878                self.assertIs(torch.float64, torch.tensor(0.).dtype)
2879                self.assertEqual(torch.device(device), torch.tensor(0.).device)
2880
2881    @onlyCUDA
2882    def test_tensor_factory_gpu_type(self, device):
2883        with set_default_tensor_type(torch.cuda.FloatTensor):
2884            x = torch.zeros((5, 5))
2885            self.assertIs(torch.float32, x.dtype)
2886            self.assertTrue(x.is_cuda)
2887        with set_default_tensor_type(torch.cuda.DoubleTensor):
2888            x = torch.zeros((5, 5))
2889            self.assertIs(torch.float64, x.dtype)
2890            self.assertTrue(x.is_cuda)
2891
2892    @skipCPUIf(True, 'compares device with cpu')
2893    @dtypes(torch.int, torch.long, torch.float, torch.double)
2894    def test_arange_device_vs_cpu(self, device, dtype):
2895        cpu_tensor = torch.arange(0, 10, dtype=dtype, device='cpu')
2896        device_tensor = torch.arange(0, 10, dtype=dtype, device=device)
2897        self.assertEqual(cpu_tensor, device_tensor)
2898
2899    @dtypes(torch.bfloat16, torch.float16)
2900    def test_arange_lowp(self, device, dtype):
2901        ref_tensor = torch.tensor([0, 1, 2, 3], dtype=dtype, device=device)
2902        f16_tensor = torch.arange(0, 4, dtype=dtype, device=device)
2903        self.assertEqual(ref_tensor, f16_tensor)
2904
2905        # step=2
2906        ref_tensor = torch.tensor([0, 2, 4], dtype=dtype, device=device)
2907        f16_tensor = torch.arange(0, 6, step=2, dtype=dtype, device=device)
2908        self.assertEqual(ref_tensor, f16_tensor)
2909
2910    @dtypes(*all_types_and_complex_and(torch.bfloat16))
2911    @dtypesIfCUDA(*all_types_and_complex_and(torch.bfloat16))
2912    def test_linspace(self, device, dtype):
2913        _from = random.random()
2914        to = _from + random.random()
2915        res1 = torch.linspace(_from, to, 137, device=device, dtype=dtype)
2916        res2 = torch.tensor((), device=device, dtype=dtype)
2917        torch.linspace(_from, to, 137, dtype=dtype, out=res2)
2918        self.assertEqual(res1, res2, atol=0, rtol=0)
2919
2920        # small tensor
2921        self.assertEqual(torch.linspace(10, 20, 11, device=device, dtype=dtype),
2922                         torch.tensor(list(range(10, 21)), device=device, dtype=dtype))
2923        # large tensor
2924        if dtype not in (torch.int8, torch.uint8):
2925            self.assertEqual(torch.linspace(10, 2000, 1991, device=device, dtype=dtype),
2926                             torch.tensor(list(range(10, 2001)), device=device, dtype=dtype))
2927
2928        # Vectorization on non-contiguous tensors
2929        if dtype not in (torch.int8, torch.uint8):  # int8 and uint8 are too small for this test
2930            res = torch.rand(3, 3, 1000, device=device).to(dtype)
2931            res = res.permute(2, 0, 1)
2932            torch.linspace(0, 1000 * 3 * 3, 1000 * 3 * 3, out=res)
2933            self.assertEqual(res.flatten(), torch.linspace(0, 1000 * 3 * 3, 1000 * 3 * 3, device=device, dtype=dtype))
2934
2935        self.assertRaises(RuntimeError, lambda: torch.linspace(0, 1, -1, device=device, dtype=dtype))
2936        # steps = 1
2937        self.assertEqual(torch.linspace(0, 1, 1, device=device, dtype=dtype),
2938                         torch.zeros(1, device=device, dtype=dtype), atol=0, rtol=0)
2939        # steps = 0
2940        self.assertEqual(torch.linspace(0, 1, 0, device=device, dtype=dtype).numel(), 0, atol=0, rtol=0)
2941
2942        # steps not provided
2943        self.assertRaises(TypeError, lambda: torch.linspace(0, 1, device=device, dtype=dtype))
2944
2945        if dtype == torch.float:
2946            # passed dtype can't be safely casted to inferred dtype
2947            with self.assertRaisesRegex(RuntimeError, r"torch.linspace\(\): inferred dtype"):
2948                torch.linspace(0, 1j, 5, device=device, dtype=dtype)
2949            with self.assertRaisesRegex(RuntimeError, r"torch.linspace\(\): inferred dtype"):
2950                torch.linspace(0j, 1, 5, device=device, dtype=dtype)
2951            with self.assertRaisesRegex(RuntimeError, r"torch.linspace\(\): inferred dtype"):
2952                torch.linspace(0j, 1j, 5, device=device, dtype=dtype)
2953
2954        # Check linspace for generating the correct output for each dtype.
2955        start = 0 if dtype == torch.uint8 else -100
2956        expected_lin = torch.tensor([start + .5 * i for i in range(401)], device=device, dtype=torch.double)
2957        actual_lin = torch.linspace(start, start + 200, 401, device=device, dtype=dtype)
2958        # If on GPU, allow for minor error depending on dtype.
2959        tol = 0.
2960        if device != 'cpu':
2961            if dtype == torch.half:
2962                tol = 1e-1
2963            elif dtype == torch.float:
2964                tol = 1e-5
2965            elif dtype == torch.double:
2966                tol = 1e-10
2967
2968        self.assertEqual(expected_lin.to(dtype), actual_lin, atol=tol, rtol=0)
2969
2970        # Check linspace for generating with start > end.
2971        self.assertEqual(torch.linspace(2, 0, 3, device=device, dtype=dtype),
2972                         torch.tensor((2, 1, 0), device=device, dtype=dtype),
2973                         atol=0, rtol=0)
2974
2975        # Check for race condition (correctness when applied on a large tensor).
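        # a correct kernel produces strictly increasing values; a racy or out-of-order
        # write would break the monotonicity verified below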
2976        if dtype not in (torch.int8, torch.uint8, torch.int16, torch.half, torch.bfloat16):
2977            y = torch.linspace(0, 999999 + (999999j if dtype.is_complex else 0),
2978                               1000000, device=device, dtype=dtype)
2979            if dtype.is_complex:
2980                cond = torch.logical_and(y[:-1].real < y[1:].real, y[:-1].imag < y[1:].imag)
2981            else:
2982                cond = y[:-1] < y[1:]
2983            correct = all(cond)
2984            self.assertTrue(correct)
2985
2986        # Check linspace for non-contiguous tensors.
2987        x = torch.zeros(2, 3, device=device, dtype=dtype)
2988        y = torch.linspace(0, 3, 4, out=x.narrow(1, 1, 2), dtype=dtype)
2989        self.assertEqual(x, torch.tensor(((0, 0, 1), (0, 2, 3)), device=device, dtype=dtype), atol=0, rtol=0)
2990
2991    def _test_linspace_logspace_deduction_helper(self, fn, device):
2992        for start, end in [(1, 2), (1., 2), (1., -2.), (1j, 2j), (0., 2j), (1j, 2)]:
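            # a complex start or end should promote the deduced dtype to the default
            # complex dtype (cfloat); otherwise float32 is expected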
2993            dtype = torch.float32
2994            if isinstance(start, complex) or isinstance(end, complex):
2995                dtype = torch.cfloat
2996
2997            self.assertEqual(fn(start, end, steps=100, device=device).dtype, dtype)
2998
2999    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
3000    def test_linspace_deduction(self, device):
3001        # Test deduction from input parameters.
3002        self._test_linspace_logspace_deduction_helper(torch.linspace, device)
3003
3004    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
3005    def test_logspace_deduction(self, device):
3006        # Test deduction from input parameters.
3007        self._test_linspace_logspace_deduction_helper(torch.logspace, device)
3008
    # The implementation of linspace and logspace takes a different path when
    # the steps arg is equal to 0 or 1; for other values of `steps` it dispatches
    # to specialized linspace (or logspace) kernels.
3012    LINSPACE_LOGSPACE_SPECIAL_STEPS = [0, 1]
3013
3014    # NOTE [Linspace+Logspace precision override]
3015    # Our Linspace and logspace torch.half CUDA kernels are not very precise.
3016    # Since linspace/logspace are deterministic, we can compute an expected
3017    # amount of error (by testing without a precision override), adding a tiny
3018    # amount (EPS) to that, and using that value as the override.
3019    LINSPACE_LOGSPACE_EXTRA_EPS = 1e-5
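    # (the 0.0039, 0.025 and 0.0201 overrides used below follow this recipe: the
    # observed error plus LINSPACE_LOGSPACE_EXTRA_EPS)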
3020
3021    # Compares linspace device vs. cpu
3022    def _test_linspace(self, device, dtype, steps):
3023        a = torch.linspace(0, 10, steps=steps, dtype=dtype, device=device)
3024        b = torch.linspace(0, 10, steps=steps)
3025        self.assertEqual(a, b, exact_dtype=False)
3026
3027    # See NOTE [Linspace+Logspace precision override]
3028    @skipCPUIf(True, "compares with CPU")
3029    @precisionOverride({torch.half: 0.0039 + LINSPACE_LOGSPACE_EXTRA_EPS})
3030    @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
3031    def test_linspace_device_vs_cpu(self, device, dtype):
3032        self._test_linspace(device, dtype, steps=10)
3033
3034    @skipCPUIf(True, "compares with CPU")
3035    @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
3036    def test_linspace_special_steps(self, device, dtype):
3037        for steps in self.LINSPACE_LOGSPACE_SPECIAL_STEPS:
3038            self._test_linspace(device, dtype, steps=steps)
3039
3040    # Compares logspace device vs cpu
3041    def _test_logspace(self, device, dtype, steps):
3042        a = torch.logspace(1, 1.1, steps=steps, dtype=dtype, device=device)
3043        b = torch.logspace(1, 1.1, steps=steps)
3044        self.assertEqual(a, b, exact_dtype=False)
3045
    # Compares base-2 logspace device vs cpu
3047    def _test_logspace_base2(self, device, dtype, steps):
3048        a = torch.logspace(1, 1.1, steps=steps, base=2, dtype=dtype, device=device)
3049        b = torch.logspace(1, 1.1, steps=steps, base=2)
3050        self.assertEqual(a, b, exact_dtype=False)
3051
3052    # See NOTE [Linspace+Logspace precision override]
3053    @skipCPUIf(True, "compares with CPU")
3054    @precisionOverride({torch.half: 0.025 + LINSPACE_LOGSPACE_EXTRA_EPS})
3055    @dtypesIfCUDA(torch.half, torch.float, torch.double)
3056    @dtypes(torch.float, torch.double)
3057    def test_logspace_device_vs_cpu(self, device, dtype):
3058        self._test_logspace(device, dtype, steps=10)
3059
3060    # See NOTE [Linspace+Logspace precision override]
3061    @skipCPUIf(True, "compares with CPU")
3062    @precisionOverride({torch.half: 0.0201 + LINSPACE_LOGSPACE_EXTRA_EPS})
3063    @dtypesIfCUDA(torch.half, torch.float, torch.double)
3064    @dtypes(torch.float, torch.double)
3065    def test_logspace_base2(self, device, dtype):
3066        self._test_logspace_base2(device, dtype, steps=10)
3067
3068    @skipCPUIf(True, "compares with CPU")
3069    @dtypesIfCUDA(torch.half, torch.float, torch.double)
3070    @dtypes(torch.float, torch.double)
3071    def test_logspace_special_steps(self, device, dtype):
3072        for steps in self.LINSPACE_LOGSPACE_SPECIAL_STEPS:
3073            self._test_logspace(device, dtype, steps=steps)
3074            self._test_logspace_base2(device, dtype, steps=steps)
3075
3076    @dtypes(*all_types_and(torch.bfloat16))
3077    @dtypesIfCUDA(*integral_types_and(torch.half, torch.bfloat16, torch.float32, torch.float64) if TEST_WITH_ROCM else
3078                  all_types_and(torch.half, torch.bfloat16))
3079    def test_logspace(self, device, dtype):
3080        _from = random.random()
3081        to = _from + random.random()
3082        res1 = torch.logspace(_from, to, 137, device=device, dtype=dtype)
3083        res2 = torch.tensor((), device=device, dtype=dtype)
3084        torch.logspace(_from, to, 137, device=device, dtype=dtype, out=res2)
3085        self.assertEqual(res1, res2, atol=0, rtol=0)
3086        self.assertRaises(RuntimeError, lambda: torch.logspace(0, 1, -1, device=device, dtype=dtype))
3087        # steps not provided
3088        self.assertRaises(TypeError, lambda: torch.logspace(0, 1, device=device, dtype=dtype))
3089        self.assertEqual(torch.logspace(0, 1, 1, device=device, dtype=dtype),
3090                         torch.ones(1, device=device, dtype=dtype), atol=0, rtol=0)
3091
3092        if dtype == torch.float:
3093            # passed dtype can't be safely casted to inferred dtype
3094            with self.assertRaisesRegex(RuntimeError, r"torch.logspace\(\): inferred dtype"):
3095                torch.logspace(0, 1j, 5, device=device, dtype=dtype)
3096            with self.assertRaisesRegex(RuntimeError, r"torch.logspace\(\): inferred dtype"):
3097                torch.logspace(0j, 1, 5, device=device, dtype=dtype)
3098            with self.assertRaisesRegex(RuntimeError, r"torch.logspace\(\): inferred dtype"):
3099                torch.logspace(0j, 1j, 5, device=device, dtype=dtype)
3100
3101        # Check precision - start, stop and base are chosen to avoid overflow
3102        # steps is chosen so that step size is not subject to rounding error
3103        # a tolerance is needed for gpu tests due to differences in computation
3104        atol = None
3105        rtol = None
3106        if self.device_type == 'cpu':
3107            atol = 0
3108            rtol = 0
3109        self.assertEqual(torch.tensor([2. ** (i / 8.) for i in range(49)], device=device, dtype=dtype),
3110                         torch.logspace(0, 6, steps=49, base=2, device=device, dtype=dtype),
3111                         atol=atol, rtol=rtol)
3112
3113        # Check non-default base=2
3114        self.assertEqual(torch.logspace(1, 1, 1, 2, device=device, dtype=dtype),
3115                         torch.ones(1, device=device, dtype=dtype) * 2)
3116        self.assertEqual(torch.logspace(0, 2, 3, 2, device=device, dtype=dtype),
3117                         torch.tensor((1, 2, 4), device=device, dtype=dtype))
3118
        # Check logspace for generating with start > end.
3120        self.assertEqual(torch.logspace(1, 0, 2, device=device, dtype=dtype),
3121                         torch.tensor((10, 1), device=device, dtype=dtype), atol=0, rtol=0)
3122
        # Check logspace for non-contiguous tensors.
3124        x = torch.zeros(2, 3, device=device, dtype=dtype)
3125        y = torch.logspace(0, 3, 4, base=2, device=device, dtype=dtype, out=x.narrow(1, 1, 2))
3126        self.assertEqual(x, torch.tensor(((0, 1, 2), (0, 4, 8)), device=device, dtype=dtype), atol=0, rtol=0)
3127
3128    @onlyNativeDeviceTypes
3129    @dtypes(torch.half, torch.float, torch.double)
3130    def test_full_inference(self, device, dtype):
3131        size = (2, 2)
3132
3133        with set_default_dtype(dtype):
3134            # Tests bool fill value inference
3135            t = torch.full(size, True)
3136            self.assertEqual(t.dtype, torch.bool)
3137
3138            # Tests integer fill value inference
3139            t = torch.full(size, 1)
3140            self.assertEqual(t.dtype, torch.long)
3141
3142            # Tests float fill value inference
3143            t = torch.full(size, 1.)
3144            self.assertEqual(t.dtype, dtype)
3145
3146            # Tests complex inference
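            # the inferred complex dtype tracks the default floating point precision:
            # a float64 default gives complex128, otherwise complex64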
3147            t = torch.full(size, (1 + 1j))
3148            ctype = torch.complex128 if dtype is torch.double else torch.complex64
3149            self.assertEqual(t.dtype, ctype)
3150
3151    def test_full_out(self, device):
3152        size = (5,)
3153        o = torch.empty(size, device=device, dtype=torch.long)
3154
3155        # verifies dtype/out conflict throws a RuntimeError
3156        with self.assertRaises(RuntimeError):
3157            torch.full(o.shape, 1., dtype=torch.float, out=o)
3158
3159        # verifies out dtype overrides inference
3160        self.assertEqual(torch.full(o.shape, 1., out=o).dtype, o.dtype)
3161        self.assertEqual(torch.full(size, 1, out=o).dtype, o.dtype)
3162
    # check that the warning about a non-writable numpy array is suppressed
    # when a copy of the array is made.
    # see issue #47160
3166    def test_tensor_from_non_writable_numpy(self, device):
3167        with warnings.catch_warnings(record=True) as w:
3168            a = np.arange(5.)
3169            a.flags.writeable = False
3170            t = torch.tensor(a)
3171            self.assertEqual(len(w), 0)
3172
3173    @onlyCPU
3174    @parametrize('shared', [True, False])
3175    @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
3176    def test_from_file(self, device, shared):
3177        dtype = torch.float64
3178        t = torch.randn(2, 5, dtype=dtype, device=device)
3179        with tempfile.NamedTemporaryFile() as f:
3180            expected_filename = f.name if shared else None
3181            t.numpy().tofile(f)
3182            t_mapped = torch.from_file(f.name, shared=shared, size=t.numel(), dtype=dtype)
3183            self.assertTrue(t_mapped.untyped_storage().filename == expected_filename)
3184            self.assertEqual(torch.flatten(t), t_mapped)
3185
3186            s = torch.UntypedStorage.from_file(f.name, shared, t.numel() * dtype.itemsize)
3187            self.assertTrue(s.filename == expected_filename)
3188
3189    @onlyCPU
3190    def test_storage_filename(self, device):
3191        t = torch.randn(2, 5, device=device)
3192        self.assertIsNone(t.untyped_storage().filename)
3193
3194
3195# Class for testing random tensor creation ops, like torch.randint
3196class TestRandomTensorCreation(TestCase):
3197    exact_dtype = True
3198
3199    # TODO: add torch.complex64, torch.complex128
3200    @dtypes(torch.float, torch.double)
3201    def test_normal(self, device, dtype):
3202
3203        def helper(self, device, dtype, ptype, t_transform, std_transform):
3204            q = torch.empty(100, 100, dtype=dtype, device=device)
3205
3206            q.normal_()
3207            self.assertEqual(t_transform(q).mean(), 0, atol=0.2, rtol=0)
3208            self.assertEqual(t_transform(q).std(), std_transform(1), atol=0.2, rtol=0)
3209
3210            q.normal_(2, 3)
3211            self.assertEqual(t_transform(q).mean(), 2, atol=0.3, rtol=0)
3212            self.assertEqual(t_transform(q).std(), std_transform(3), atol=0.3, rtol=0)
3213
3214            q = torch.empty(100, 100, dtype=dtype, device=device)
3215            q_row1 = q[0:1].clone()
3216            q[99:100].normal_()
3217            self.assertEqual(t_transform(q[99:100]).mean(), 0, atol=0.2, rtol=0)
3218            self.assertEqual(t_transform(q[99:100]).std(), std_transform(1), atol=0.2, rtol=0)
3219            self.assertEqual(t_transform(q[0:1]).clone(), t_transform(q_row1))
3220
3221            mean = torch.empty(100, 100, dtype=dtype, device=device)
3222            mean[:50].fill_(ptype(0))
3223            mean[50:].fill_(ptype(1))
3224
3225            std = torch.empty(100, 100, dtype=torch.float, device=device)
3226            std[:, :50] = 4
3227            std[:, 50:] = 1
3228
3229            r = torch.normal(mean)
3230            self.assertEqual(r.dtype, dtype)
3231            self.assertEqual(str(r.device), device)
3232            self.assertEqual(t_transform(r[:50]).mean(), 0, atol=0.2, rtol=0)
3233            self.assertEqual(t_transform(r[50:]).mean(), 1, atol=0.2, rtol=0)
3234            self.assertEqual(t_transform(r).std(), std_transform(1), atol=0.2, rtol=0)
3235
3236            r.fill_(42)
3237            r = torch.normal(mean, 3)
3238            self.assertEqual(r.dtype, dtype)
3239            self.assertEqual(str(r.device), device)
3240            self.assertEqual(t_transform(r[:50]).mean(), 0, atol=0.2, rtol=0)
3241            self.assertEqual(t_transform(r[50:]).mean(), 1, atol=0.2, rtol=0)
3242            self.assertEqual(t_transform(r).std(), std_transform(3), atol=0.2, rtol=0)
3243
3244            r.fill_(42)
3245            torch.normal(mean, 3, out=r)
3246            self.assertEqual(r.dtype, dtype)
3247            self.assertEqual(str(r.device), device)
3248            self.assertEqual(t_transform(r[:50]).mean(), 0, atol=0.2, rtol=0)
3249            self.assertEqual(t_transform(r[50:]).mean(), 1, atol=0.2, rtol=0)
3250            self.assertEqual(t_transform(r).std(), std_transform(3), atol=0.2, rtol=0)
3251
3252            r.fill_(42)
3253            r = torch.normal(2, std)
3254            self.assertFalse(r.dtype.is_complex)
3255            self.assertEqual(str(r.device), device)
3256            self.assertEqual(r.mean(), 2, atol=0.2, rtol=0)
3257            self.assertEqual(r[:, :50].std(), 4, atol=0.3, rtol=0)
3258            self.assertEqual(r[:, 50:].std(), 1, atol=0.2, rtol=0)
3259
3260            r.fill_(42)
3261            torch.normal(2, std, out=r)
3262            self.assertFalse(r.dtype.is_complex)
3263            self.assertEqual(str(r.device), device)
3264            self.assertEqual(r.mean(), 2, atol=0.2, rtol=0)
3265            self.assertEqual(r[:, :50].std(), 4, atol=0.3, rtol=0)
3266            self.assertEqual(r[:, 50:].std(), 1, atol=0.2, rtol=0)
3267
3268            r.fill_(42)
3269            r = torch.normal(mean, std)
3270            self.assertEqual(r.dtype, dtype)
3271            self.assertEqual(str(r.device), device)
3272            self.assertEqual(t_transform(r[:50]).mean(), 0, atol=0.2, rtol=0)
3273            self.assertEqual(t_transform(r[50:]).mean(), 1, atol=0.2, rtol=0)
3274            self.assertEqual(t_transform(r[:, :50]).std(), std_transform(4), atol=0.3, rtol=0)
3275            self.assertEqual(t_transform(r[:, 50:]).std(), std_transform(1), atol=0.2, rtol=0)
3276
3277            r.fill_(42)
3278            torch.normal(mean, std, out=r)
3279            self.assertEqual(r.dtype, dtype)
3280            self.assertEqual(str(r.device), device)
3281            self.assertEqual(t_transform(r[:50]).mean(), 0, atol=0.2, rtol=0)
3282            self.assertEqual(t_transform(r[50:]).mean(), 1, atol=0.2, rtol=0)
3283            self.assertEqual(t_transform(r[:, :50]).std(), std_transform(4), atol=0.3, rtol=0)
3284            self.assertEqual(t_transform(r[:, 50:]).std(), std_transform(1), atol=0.2, rtol=0)
3285
3286            # test empty mean/std
3287            out = torch.normal(mean=torch.empty((0, 2)), std=torch.empty((0, 1)))
3288            self.assertEqual(out.size(), torch.Size([0, 2]))
3289
3290            r.fill_(42)
3291            r = torch.normal(2, 3, (100, 100), dtype=dtype, device=device)
3292            self.assertEqual(r.dtype, dtype)
3293            self.assertEqual(str(r.device), device)
3294            self.assertEqual(t_transform(r).mean(), 2, atol=0.3, rtol=0)
3295            self.assertEqual(t_transform(r).std(), std_transform(3), atol=0.3, rtol=0)
3296
3297            r.fill_(42)
3298            torch.normal(2, 3, (100, 100), dtype=dtype, device=device, out=r)
3299            self.assertEqual(r.dtype, dtype)
3300            self.assertEqual(str(r.device), device)
3301            self.assertEqual(t_transform(r).mean(), 2, atol=0.3, rtol=0)
3302            self.assertEqual(t_transform(r).std(), std_transform(3), atol=0.3, rtol=0)
3303
3304            # float std 0 with float mean
3305            r.fill_(42)
3306            torch.normal(2, 0, (10, 10), dtype=dtype, device=device, out=r)
3307            self.assertEqual(r.dtype, dtype)
3308            self.assertEqual(str(r.device), device)
3309            self.assertTrue(r.eq(2).all())
3310
3311            # float std 0 with tensor mean
3312            r.fill_(42)
3313            mean_rand = torch.randn(10, 10, dtype=dtype, device=device)
3314            torch.normal(mean_rand, 0, out=r)
3315            self.assertEqual(r.dtype, dtype)
3316            self.assertEqual(str(r.device), device)
3317            self.assertEqual(mean_rand, r, atol=0, rtol=0)
3318
3319            # tensor std 0 with float mean
3320            r.fill_(42)
3321            std_zeros = torch.zeros(10, 10, dtype=dtype, device=device)
3322            torch.normal(2, std_zeros, out=r)
3323            self.assertEqual(r.dtype, dtype)
3324            self.assertEqual(str(r.device), device)
3325            self.assertTrue(r.eq(2).all())
3326
3327            # tensor std 0 with tensor mean
3328            r.fill_(42)
3329            torch.normal(mean_rand, std_zeros, out=r)
3330            self.assertEqual(r.dtype, dtype)
3331            self.assertEqual(str(r.device), device)
3332            self.assertEqual(mean_rand, r, atol=0, rtol=0)
3333
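        # For complex dtypes the real and imaginary parts each carry half the variance,
        # so the expected per-component std is scaled by 1 / sqrt(2).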
3334        if dtype.is_complex:
3335            helper(self, device, dtype, lambda x: complex(x, x),
3336                   lambda t: torch.real(t).to(torch.float), lambda mean: mean / math.sqrt(2))
3337            helper(self, device, dtype, lambda x: complex(x, x),
3338                   lambda t: torch.imag(t).to(torch.float), lambda mean: mean / math.sqrt(2))
3339            self.assertRaisesRegex(
3340                RuntimeError, "normal expects standard deviation to be non-complex",
3341                lambda: torch.normal(0, torch.empty(100, 100, dtype=dtype, device=device)))
3342            out = torch.empty(100, 100, dtype=dtype, device=device)
3343            self.assertRaisesRegex(
3344                RuntimeError, "normal expects standard deviation to be non-complex",
3345                lambda: torch.normal(0, torch.empty(100, 100, dtype=dtype, device=device), out=out))
3346        else:
3347            helper(self, device, dtype, lambda x: x, lambda t: t, lambda mean: mean)
3348
3349    # Ensure that normal raises appropriate error when `std` < 0
3350    def test_normal_std_error(self, device):
3351        a = torch.tensor(0, dtype=torch.float32, device=device)
3352        std = torch.tensor(-1, dtype=torch.float32, device=device)
3353
        for mean in [0, a]:
            with self.assertRaisesRegex(RuntimeError, r'normal expects std >= 0.0, but found std'):
                torch.normal(mean, -1, (10,))

            with self.assertRaisesRegex(RuntimeError, r'normal expects all elements of std >= 0.0'):
                torch.normal(mean, std)
3360
3361    # https://github.com/pytorch/pytorch/issues/126834
3362    @xfailIfTorchDynamo
3363    @dtypes(torch.float, torch.double, torch.half)
3364    @dtypesIfCUDA(torch.float, torch.double, torch.half, torch.bfloat16)
3365    def test_uniform_from_to(self, device, dtype):
3366        size = 2000
3367        alpha = 0.1
3368
3369        float_min = torch.finfo(torch.float).min
3370        float_max = torch.finfo(torch.float).max
3371        double_min = torch.finfo(torch.double).min
3372        double_max = torch.finfo(torch.double).max
3373
3374        if dtype == torch.bfloat16:
3375            min_val = -3.389531389251535e+38
3376            max_val = 3.389531389251535e+38
3377        else:
3378            min_val = torch.finfo(dtype).min
3379            max_val = torch.finfo(dtype).max
3380
3381        values = [double_min, float_min, -42, 0, 42, float_max, double_max]
3382
3383        for from_ in values:
3384            for to_ in values:
3385                t = torch.empty(size, dtype=dtype, device=device)
3386                if not (min_val <= from_ <= max_val) or not (min_val <= to_ <= max_val):
3387                    pass
3388                elif to_ < from_:
3389                    self.assertRaisesRegex(
3390                        RuntimeError,
3391                        "uniform_ expects to return",
3392                        lambda: t.uniform_(from_, to_)
3393                    )
3394                elif to_ - from_ > max_val:
3395                    self.assertRaisesRegex(
3396                        RuntimeError,
3397                        "uniform_ expects to-from",
3398                        lambda: t.uniform_(from_, to_)
3399                    )
3400                else:
3401                    t.uniform_(from_, to_)
3402                    range_ = to_ - from_
3403                    if not (dtype == torch.bfloat16) and not (
3404                            dtype == torch.half and device == 'cpu') and not torch.isnan(t).all():
3405                        delta = alpha * range_
3406                        double_t = t.to(torch.double)
3407                        if range_ == 0:
3408                            self.assertTrue(double_t.min() == from_)
3409                            self.assertTrue(double_t.max() == to_)
3410                        elif dtype == torch.half:
3411                            self.assertTrue(from_ <= double_t.min() <= (from_ + delta))
3412                            self.assertTrue((to_ - delta) <= double_t.max() <= to_)
3413                        else:
3414                            self.assertTrue(from_ <= double_t.min() <= (from_ + delta))
3415                            self.assertTrue((to_ - delta) <= double_t.max() < to_)
3416
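    # A minimal sketch (illustrative only, assuming a float32 tensor) of the uniform_ contract
    # exercised by the test above:
    #   t = torch.empty(4)
    #   t.uniform_(0., 1.)       # OK: samples lie in [0, 1)
    #   t.uniform_(1., 0.)       # RuntimeError: uniform_ expects to return a [from, to) range
    #   fmax = torch.finfo(torch.float).max
    #   t.uniform_(-fmax, fmax)  # RuntimeError: the requested range exceeds the dtype's maximum
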
3417    def test_random_neg_values(self, device):
3418        SIZE = 10
3419        signed_dtypes = [torch.double, torch.float, torch.long, torch.int, torch.short]
3420        for dtype in signed_dtypes:
3421            res = torch.rand(SIZE, SIZE).to(device=device, dtype=dtype)
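            # random_(from, to) draws integers uniformly from [from, to), so every element
            # should land in [-10, -2]; the bounds asserted below are deliberately loose.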
3422            res.random_(-10, -1)
3423            self.assertLessEqual(res.max().item(), 9)
3424            self.assertGreaterEqual(res.min().item(), -10)
3425
3426    # TODO: this test should be updated
3427    @onlyCPU
3428    def test_randint_inference(self, device):
3429        size = (2, 1)
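        # dtype resolution exercised below: an explicit dtype= argument wins, otherwise the
        # dtype of the out= tensor is used, and with neither the result defaults to torch.int64.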
3430        for args in [(3,), (1, 3)]:  # (low,) and (low, high)
3431            self.assertIs(torch.int64, torch.randint(*args, size=size).dtype)
3432            self.assertIs(torch.int64, torch.randint(*args, size=size, layout=torch.strided).dtype)
3433            self.assertIs(torch.int64, torch.randint(*args, size=size, generator=torch.default_generator).dtype)
3434            self.assertIs(torch.float32, torch.randint(*args, size=size, dtype=torch.float32).dtype)
3435            out = torch.empty(size, dtype=torch.float32)
3436            self.assertIs(torch.float32, torch.randint(*args, size=size, out=out).dtype)
3437            self.assertIs(torch.float32, torch.randint(*args, size=size, out=out, dtype=torch.float32).dtype)
3438            out = torch.empty(size, dtype=torch.int64)
3439            self.assertIs(torch.int64, torch.randint(*args, size=size, out=out).dtype)
3440            self.assertIs(torch.int64, torch.randint(*args, size=size, out=out, dtype=torch.int64).dtype)
3441
3442    # TODO: this test should be updated
3443    @onlyCPU
3444    def test_randint(self, device):
3445        SIZE = 100
3446
3447        def seed(generator):
3448            if generator is None:
3449                torch.manual_seed(123456)
3450            else:
3451                generator.manual_seed(123456)
3452            return generator
3453
3454        for generator in (None, torch.Generator()):
3455            generator = seed(generator)
3456            res1 = torch.randint(0, 6, (SIZE, SIZE), generator=generator)
3457            res2 = torch.empty((), dtype=torch.int64)
3458            generator = seed(generator)
3459            torch.randint(0, 6, (SIZE, SIZE), generator=generator, out=res2)
3460            generator = seed(generator)
3461            res3 = torch.randint(6, (SIZE, SIZE), generator=generator)
3462            res4 = torch.empty((), dtype=torch.int64)
3463            generator = seed(generator)
3464            torch.randint(6, (SIZE, SIZE), out=res4, generator=generator)
3465            self.assertEqual(res1, res2)
3466            self.assertEqual(res1, res3)
3467            self.assertEqual(res1, res4)
3468            self.assertEqual(res2, res3)
3469            self.assertEqual(res2, res4)
3470            self.assertEqual(res3, res4)
3471            self.assertTrue((res1 < 6).all().item())
3472            self.assertTrue((res1 >= 0).all().item())
3473
3474    @dtypes(torch.half, torch.float, torch.bfloat16, torch.double,
3475            torch.complex32, torch.complex64, torch.complex128)
3476    def test_randn(self, device, dtype):
3477        SIZE = 100
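        # Re-seeding with the same value before each call should make the two draws identical,
        # including for the degenerate size-0 case.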
3478        for size in [0, SIZE]:
3479            torch.manual_seed(123456)
3480            res1 = torch.randn(size, size, dtype=dtype, device=device)
3481            res2 = torch.tensor([], dtype=dtype, device=device)
3482            torch.manual_seed(123456)
3483            torch.randn(size, size, out=res2)
3484            self.assertEqual(res1, res2)
3485
3486    @dtypes(torch.float, torch.double, torch.complex32, torch.complex64, torch.complex128)
3487    def test_rand(self, device, dtype):
3488        SIZE = 100
3489        for size in [0, SIZE]:
3490            torch.manual_seed(123456)
3491            res1 = torch.rand(size, size, dtype=dtype, device=device)
3492            res2 = torch.tensor([], dtype=dtype, device=device)
3493            torch.manual_seed(123456)
3494            torch.rand(size, size, out=res2)
3495            self.assertEqual(res1, res2)
3496
3497    def test_randperm(self, device):
3498        if device == 'cpu' or device == 'meta':
3499            rng_device = None
3500        else:
3501            # TODO: This won't actually work for non-CUDA device
3502            # see https://github.com/pytorch/pytorch/issues/54282
3503            rng_device = [device]
3504
3505        # Test core functionality. On CUDA, different values of n take different
3506        # code paths.
3507        for n in (5, 100, 50000, 100000):
3508            # Ensure both integer and floating-point numbers are tested. Half follows an execution path that is
3509            # different from the others on CUDA.
3510            for dtype in (torch.long, torch.half, torch.float, torch.bfloat16):
3511                if n > 2049 and dtype == torch.half:  # Large n for torch.half will raise an exception, do not test here.
3512                    continue
3513                if dtype == torch.bfloat16 and device != 'cpu':
3514                    continue
3515                if n > 256 and dtype == torch.bfloat16:
3516                    continue
3517                with torch.random.fork_rng(devices=rng_device):
3518                    res1 = torch.randperm(n, dtype=dtype, device=device)
3519                res2 = torch.empty(0, dtype=dtype, device=device)
3520                torch.randperm(n, out=res2, dtype=dtype, device=device)
3521                self.assertEqual(res1, res2, atol=0, rtol=0)
3522                self.assertEqual(res1.sort().values.long(), torch.arange(n, device=device))
3523
3524        # Default type is long
3525        for n in (100, 10000):
3526            self.assertEqual(torch.randperm(n, device=device).dtype, torch.long)
3527
3528        # randperm of 0 elements is an empty tensor
3529        res1 = torch.randperm(0)
3530        res2 = torch.tensor(5, dtype=dtype, device=device)
3531        torch.randperm(0, out=res2)
3532        self.assertEqual(res1.numel(), 0)
3533        self.assertEqual(res2.numel(), 0)
3534
3535        # Test exceptions when n is too large to be represented exactly in the given dtype
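        # (randperm(n) must represent every value in 0..n-1 exactly in the output dtype; the
        #  small_n/large_n pairs below sit just inside and just outside that per-dtype limit.)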
3536        for dtype, small_n, large_n in ((torch.uint8, 2**8, 2**8 + 1),
3537                                        (torch.half, 2**11 + 1, 2**11 + 2),
3538                                        (torch.float, 2**24 + 1, 2**24 + 2),
3539                                        (torch.double, 2**25,  # 2**53 + 1 is too large to run
3540                                         2**53 + 2)):
3541            res = torch.empty(0, dtype=dtype, device=device)
3542            torch.randperm(small_n, out=res)  # No exception expected
3543            self.assertRaises(RuntimeError, lambda: torch.randperm(large_n, out=res, device=device))
3544
3545        # Test non-contiguous tensors
3546        for n in (4, 5, 6, 10, 20):
3547            non_contiguous_tensor = torch.zeros((2, 3), dtype=torch.long, device=device).t()
3548            self.assertFalse(non_contiguous_tensor.is_contiguous())
3549            with torch.random.fork_rng(devices=rng_device):
3550                res = torch.randperm(n, dtype=torch.long, device=device)
3551            torch.randperm(n, out=non_contiguous_tensor)
3552            self.assertEqual(non_contiguous_tensor, res)
3553            self.assertEqual(res.sort().values.long(), torch.arange(n, device=device))
3554
3555    # Test exceptions when device and generator types are incompatible
3556    @onlyCUDA
3557    def test_randperm_device_compatibility(self, device):
3558        cuda_gen = torch.Generator(device='cuda')
3559        cpu_gen = torch.Generator(device='cpu')
3560
3561        # n=0 is a special case where the generator is not needed, so there is no error even if
3562        # the device and generator don't match
3563        torch.randperm(0, device='cuda:0', generator=torch.Generator(device='cuda:1'))
3564        if torch.cuda.device_count() > 1:
3565            torch.randperm(0, device='cuda:1', generator=torch.Generator(device='cuda:0'))
3566        torch.randperm(0, device='cuda', generator=torch.Generator(device='cpu'))
3567        torch.randperm(0, device='cpu', generator=torch.Generator(device='cuda'))
3568
3569        for n in (1, 3, 100, 30000):
3570            torch.randperm(n, device='cuda', generator=torch.Generator(device='cuda:0'))
3571            torch.randperm(n, device='cuda:0', generator=torch.Generator(device='cuda'))
3572            # A cuda:0 generator may be paired with a cuda:1 device (and vice versa): only
3573            # the device *type* has to match, consistent with torch.randint. Longer term, the
3574            # generator should ignore the device ordinal entirely, since it isn't used anyway.
3575            torch.randint(low=0, high=n + 1, size=(1,), device="cuda:0", generator=torch.Generator(device='cuda:1'))
3576            torch.randperm(n, device='cuda:0', generator=torch.Generator(device='cuda:1'))
3577            if torch.cuda.device_count() > 1:
3578                torch.randint(low=0, high=n + 1, size=(1,), device="cuda:1", generator=torch.Generator(device='cuda:0'))
3579                torch.randperm(n, device='cuda:1', generator=torch.Generator(device='cuda:0'))
3580
3581            regex = 'Expected a .* device type for generator but found .*'
3582            cuda_t = torch.tensor(n, device='cuda')
3583            self.assertRaisesRegex(RuntimeError, regex, lambda: torch.randperm(n, device='cuda', generator=cpu_gen))
3584            self.assertRaisesRegex(RuntimeError, regex, lambda: torch.randperm(n, device='cuda', generator=cpu_gen, out=cuda_t))
3585            cpu_t = torch.tensor(n, device='cpu')
3586            self.assertRaisesRegex(RuntimeError, regex, lambda: torch.randperm(n, device='cpu', generator=cuda_gen))
3587            self.assertRaisesRegex(RuntimeError, regex, lambda: torch.randperm(n, device='cpu', generator=cuda_gen, out=cpu_t))
3588            self.assertRaisesRegex(RuntimeError, regex, lambda: torch.randperm(n, generator=cuda_gen))  # implicitly on CPU
3589
3590# Class for testing *like ops, like torch.ones_like
3591class TestLikeTensorCreation(TestCase):
3592    exact_dtype = True
3593
3594    # TODO: this test should be updated
3595    def test_ones_like(self, device):
3596        expected = torch.ones(100, 100, device=device)
3597
3598        res1 = torch.ones_like(expected)
3599        self.assertEqual(res1, expected)
3600
3601        # test boolean tensor
3602        expected = torch.tensor([True, True], device=device, dtype=torch.bool)
3603        res1 = torch.ones_like(expected)
3604        self.assertEqual(res1, expected)
3605
3606    # TODO: this test should be updated
3607    @onlyCPU
3608    def test_empty_like(self, device):
3609        x = torch.autograd.Variable(torch.tensor([]))
3610        y = torch.autograd.Variable(torch.randn(4, 4))
3611        z = torch.autograd.Variable(torch.IntTensor([1, 2, 3]))
3612        for a in (x, y, z):
3613            self.assertEqual(torch.empty_like(a).shape, a.shape)
3614            self.assertEqualTypeString(torch.empty_like(a), a)
3615
3616    def test_zeros_like(self, device):
3617        expected = torch.zeros((100, 100,), device=device)
3618
3619        res1 = torch.zeros_like(expected)
3620        self.assertEqual(res1, expected)
3621
3622    @deviceCountAtLeast(2)
3623    def test_zeros_like_multiple_device(self, devices):
3624        expected = torch.zeros(100, 100, device=devices[0])
3625        x = torch.randn(100, 100, device=devices[1], dtype=torch.float32)
3626        output = torch.zeros_like(x)
3627        self.assertEqual(output, expected)
3628
3629    @deviceCountAtLeast(2)
3630    def test_ones_like_multiple_device(self, devices):
3631        expected = torch.ones(100, 100, device=devices[0])
3632        x = torch.randn(100, 100, device=devices[1], dtype=torch.float32)
3633        output = torch.ones_like(x)
3634        self.assertEqual(output, expected)
3635
3636    # full_like dtype precedence: an explicitly passed dtype wins; otherwise the dtype
3637    # of the "like" tensor is used.
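    # Illustrative sketch (mirrors the assertions below):
    #   like = torch.empty((5,), dtype=torch.long)
    #   torch.full_like(like, 1.).dtype                          # torch.long, inherited from 'like'
    #   torch.full_like(like, 1., dtype=torch.complex64).dtype   # torch.complex64, explicit dtype wins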
3638    @onlyNativeDeviceTypes
3639    def test_full_like_inference(self, device):
3640        size = (2, 2)
3641        like = torch.empty((5,), device=device, dtype=torch.long)
3642
3643        self.assertEqual(torch.full_like(like, 1.).dtype, torch.long)
3644        self.assertEqual(torch.full_like(like, 1., dtype=torch.complex64).dtype,
3645                         torch.complex64)
3646
3647# Tests for the `frombuffer` function (only works on CPU):
3648#   Constructs tensors from Python objects that implement the buffer protocol,
3649#   without copying data.
3650SIZE = 5
3651SHAPE = (SIZE,)
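
# A minimal, illustrative sketch of the zero-copy behavior TestBufferProtocol exercises
# (defined for documentation only; nothing in the suite calls it):
def _frombuffer_sharing_sketch():
    arr = np.array([1, 2, 3, 4], dtype=np.int32)
    t = torch.frombuffer(arr, dtype=torch.int32)
    t[0] = 42            # writes through the tensor...
    assert arr[0] == 42  # ...are visible in the source array, since the memory is shared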
3652
3653def may_require_grad(dtype):
3654    return dtype.is_floating_point or dtype.is_complex
3655
3656def get_dtype_size(dtype):
3657    return int(torch.empty((), dtype=dtype).element_size())
3658
3659class TestBufferProtocol(TestCase):
3660    def _run_test(self, shape, dtype, count=-1, first=0, offset=None, **kwargs):
3661        numpy_dtype = torch_to_numpy_dtype_dict[dtype]
3662
3663        if offset is None:
3664            offset = first * get_dtype_size(dtype)
3665
3666        numpy_original = make_tensor(shape, dtype=dtype, device="cpu").numpy()
3667        original = memoryview(numpy_original)
3668        # Call PyTorch's version first, so that any error is raised by it.
3669        # If this call succeeds, the NumPy version below must also succeed.
3670        torch_frombuffer = torch.frombuffer(original, dtype=dtype, count=count, offset=offset, **kwargs)
3671        numpy_frombuffer = np.frombuffer(original, dtype=numpy_dtype, count=count, offset=offset)
3672
3673        self.assertEqual(numpy_frombuffer, torch_frombuffer)
3674        self.assertEqual(numpy_frombuffer.__array_interface__["data"][0], torch_frombuffer.data_ptr())
3675        return (numpy_original, torch_frombuffer)
3676
3677    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
3678    def test_same_type(self, device, dtype):
3679        self._run_test((), dtype)
3680        self._run_test((4,), dtype)
3681        self._run_test((10, 10), dtype)
3682
3683    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
3684    def test_requires_grad(self, device, dtype):
3685        def _run_test_and_check_grad(requires_grad, *args, **kwargs):
3686            kwargs["requires_grad"] = requires_grad
3687            _, tensor = self._run_test(*args, **kwargs)
3688            self.assertTrue(tensor.requires_grad == requires_grad)
3689
3690        requires_grad = may_require_grad(dtype)
3691        _run_test_and_check_grad(requires_grad, (), dtype)
3692        _run_test_and_check_grad(requires_grad, (4,), dtype)
3693        _run_test_and_check_grad(requires_grad, (10, 10), dtype)
3694        _run_test_and_check_grad(False, (), dtype)
3695        _run_test_and_check_grad(False, (4,), dtype)
3696        _run_test_and_check_grad(False, (10, 10), dtype)
3697
3698    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
3699    def test_with_offset(self, device, dtype):
3700        # Offset should be valid whenever there is at least
3701        # one remaining element
3702        for i in range(SIZE):
3703            self._run_test(SHAPE, dtype, first=i)
3704
3705    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
3706    def test_with_count(self, device, dtype):
3707        # Count should be valid for any value in the interval
3708        # [-1, len(input)], except for 0
3709        for i in range(-1, SIZE + 1):
3710            if i != 0:
3711                self._run_test(SHAPE, dtype, count=i)
3712
3713    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
3714    def test_with_count_and_offset(self, device, dtype):
3715        # Explicit count over [-1, 1, 2, ..., len] with the default offset
3716        for i in range(-1, SIZE + 1):
3717            if i != 0:
3718                self._run_test(SHAPE, dtype, count=i)
3719        # Explicit offset over [0, 1, ..., len - 1] with the default count
3720        for i in range(SIZE):
3721            self._run_test(SHAPE, dtype, first=i)
3722        # All valid combinations of count and dtype-aligned
3723        # offset for 'input':
3724        # count: [1, 2, ..., len - 1] x first: [0, 1, ..., len - count]
3725        for i in range(1, SIZE):
3726            for j in range(SIZE - i + 1):
3727                self._run_test(SHAPE, dtype, count=i, first=j)
3728
3729    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
3730    def test_invalid_positional_args(self, device, dtype):
3731        bytes = get_dtype_size(dtype)
3732        in_bytes = SIZE * bytes
3733        # Empty array
3734        with self.assertRaisesRegex(ValueError,
3735                                    r"both buffer length \(0\) and count"):
3736            empty = np.array([])
3737            torch.frombuffer(empty, dtype=dtype)
3738        # Count equals 0
3739        with self.assertRaisesRegex(ValueError,
3740                                    r"both buffer length .* and count \(0\)"):
3741            self._run_test(SHAPE, dtype, count=0)
3742        # Offset that is negative, or at least as large as the total buffer length
3743        with self.assertRaisesRegex(ValueError,
3744                                    rf"offset \(-{bytes} bytes\) must be"):
3745            self._run_test(SHAPE, dtype, first=-1)
3746        with self.assertRaisesRegex(ValueError,
3747                                    rf"offset \({in_bytes} bytes\) must be .* "
3748                                    rf"buffer length \({in_bytes} bytes\)"):
3749            self._run_test(SHAPE, dtype, first=SIZE)
3750        # Offset that is not a multiple of the element size, while reading all elements
3751        if bytes > 1:
3752            offset = bytes - 1
3753            with self.assertRaisesRegex(ValueError,
3754                                        rf"buffer length \({in_bytes - offset} bytes\) after "
3755                                        rf"offset \({offset} bytes\) must be"):
3756                self._run_test(SHAPE, dtype, offset=bytes - 1)
3757        # Count too large for each valid first element
3758        for first in range(SIZE):
3759            count = SIZE - first + 1
3760            with self.assertRaisesRegex(ValueError,
3761                                        rf"requested buffer length \({count} \* {bytes} bytes\) "
3762                                        rf"after offset \({first * bytes} bytes\) must .*"
3763                                        rf"buffer length \({in_bytes} bytes\)"):
3764                self._run_test(SHAPE, dtype, count=count, first=first)
3765
3766    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
3767    def test_shared_buffer(self, device, dtype):
3768        x = make_tensor((1,), dtype=dtype, device=device)
3769        # Modify the whole tensor
3770        arr, tensor = self._run_test(SHAPE, dtype)
3771        tensor[:] = x
3772        self.assertEqual(arr, tensor)
3773        self.assertTrue((tensor == x).all().item())
3774
3775        # Modify the whole tensor from all valid offsets, given
3776        # a count value
3777        for count in range(-1, SIZE + 1):
3778            if count == 0:
3779                continue
3780
3781            actual_count = count if count > 0 else SIZE
3782            for first in range(SIZE - actual_count):
3783                last = first + actual_count
3784                arr, tensor = self._run_test(SHAPE, dtype, first=first, count=count)
3785                tensor[:] = x
3786                self.assertEqual(arr[first:last], tensor)
3787                self.assertTrue((tensor == x).all().item())
3788
3789                # Modify the first value in the array
3790                arr[first] = x.item() - 1
3791                self.assertEqual(arr[first:last], tensor)
3792
3793    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
3794    def test_not_a_buffer(self, device, dtype):
3795        with self.assertRaisesRegex(ValueError,
3796                                    r"object does not implement Python buffer protocol."):
3797            torch.frombuffer([1, 2, 3, 4], dtype=dtype)
3798
3799    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
3800    def test_non_writable_buffer(self, device, dtype):
3801        numpy_arr = make_tensor((1,), dtype=dtype, device=device).numpy()
3802        byte_arr = numpy_arr.tobytes()
3803        with self.assertWarnsOnceRegex(UserWarning,
3804                                       r"The given buffer is not writable."):
3805            torch.frombuffer(byte_arr, dtype=dtype)
3806
3807    def test_byte_to_int(self):
3808        byte_array = np.array([-1, 0, 0, 0, -1, 0, 0, 0], dtype=np.byte) if sys.byteorder == 'little' \
3809            else np.array([0, 0, 0, -1, 0, 0, 0, -1], dtype=np.byte)
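        # On either byte order, each four-byte group reinterprets as the int32 value 255, so
        # the eight input bytes decode to two elements equal to 255.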
3810        tensor = torch.frombuffer(byte_array, dtype=torch.int32)
3811        self.assertEqual(tensor.numel(), 2)
3812        self.assertSequenceEqual(tensor, [255, 255])
3813
3814# Tests for the `asarray` function:
3815#   Constructs tensors from a Python object that has one of the following
3816#   characteristics:
3817#       1. is a Tensor
3818#       2. is a DLPack capsule
3819#       3. implements the Python Buffer protocol
3820#       4. is an arbitrary list
3821#   The implementation itself is based on the Python Array API:
3822#   https://data-apis.org/array-api/latest/API_specification/creation_functions.html
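
# A minimal, illustrative sketch of the aliasing rules TestAsArray checks (defined for
# documentation only; nothing in the suite calls it):
def _asarray_alias_sketch():
    t = torch.arange(4, dtype=torch.float32)
    alias = torch.asarray(t)                  # same dtype/device: the result aliases t's storage
    assert alias.data_ptr() == t.data_ptr()
    copied = torch.asarray(t, copy=True)      # an explicit copy never aliases
    assert copied.data_ptr() != t.data_ptr()
    widened = torch.asarray(t, dtype=torch.float64)  # a dtype change forces a copy as well
    assert widened.data_ptr() != t.data_ptr()
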
3823def get_another_device(device):
3824    return "cuda" if torch.device(device).type == "cpu" else "cpu"
3825
3826def identity(tensor):
3827    return tensor
3828def to_numpy(tensor):
3829    return tensor.numpy()
3830def to_memview(tensor):
3831    return memoryview(to_numpy(tensor))
3832
3833class TestAsArray(TestCase):
3834    def _check(self, original, cvt=lambda t: t, is_alias=True, same_dtype=True, same_device=True, **kwargs):
3835        """Check the output of 'asarray', given its input and assertion information.
3836
3837        Besides calling 'asarray' itself, this function does 4 different checks:
3838            1. Whether the result is aliased or not, depending on 'is_alias'
3839            2. Whether the result has the expected dtype and elements
3840            3. Whether the result lives in the expected device
3841            4. Whether the result has its 'requires_grad' set or not
3842        """
3843        result = torch.asarray(cvt(original), **kwargs)
3844        self.assertTrue(isinstance(result, torch.Tensor))
3845
3846        # 1. The storage pointers should be equal only if 'is_alias' is set
3847        if is_alias:
3848            self.assertEqual(result.data_ptr(), original.data_ptr())
3849        else:
3850            self.assertNotEqual(result.data_ptr(), original.data_ptr())
3851
3852        # 2. Comparison of the elements only takes place if the original
3853        # sequence and the resulting tensor have the same data type
3854        if same_dtype:
3855            self.assertEqual(original, result)
3856        else:
3857            dtype = kwargs.get("dtype", torch.get_default_dtype())
3858            self.assertEqual(original.shape, result.shape)
3859            self.assertEqual(dtype, result.dtype)
3860
3861        # 3. Given the specified target device, we first check that its
3862        # type matches, and then that its index matches (when the index
3863        # is not None)
3864        if same_device:
3865            device = original.device
3866        else:
3867            device = torch.device(kwargs.get("device", "cpu"))
3868
3869        # Compare the target device type, and its index
3870        self.assertEqual(device.type, result.device.type)
3871        if device.index is not None:
3872            self.assertEqual(device.index, result.device.index)
3873
3874        # 4. By default, 'requires_grad' is unset
3875        self.assertEqual(result.requires_grad, kwargs.get("requires_grad", False))
3876
3877    def _test_alias_with_cvt(self, cvt, device, dtype, shape=(5, 5), only_with_dtype=False):
3878        original = make_tensor(shape, dtype=dtype, device=device)
3879
3880        def check(**kwargs):
3881            self._check(original, cvt=cvt, **kwargs)
3882
3883        if not only_with_dtype:
3884            check(copy=False)
3885            check(device=device)
3886            check(device=device, copy=False)
3887
3888        check(dtype=dtype)
3889        check(dtype=dtype, copy=False)
3890        check(requires_grad=False, dtype=dtype)
3891        check(requires_grad=may_require_grad(dtype), dtype=dtype)
3892        check(device=device, dtype=dtype)
3893        check(device=device, dtype=dtype, copy=False)
3894
3895    # Skipping 'meta' devices: comparing data pointers (which is the whole point of
3896    # these checks) is meaningless there, since data_ptr() always
3897    # returns 0.
3898    @skipMeta
3899    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3900    def test_alias_from_tensor(self, device, dtype):
3901        self._test_alias_with_cvt(identity, device, dtype)
3902
3903    @onlyCPU
3904    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
3905    def test_alias_from_numpy(self, device, dtype):
3906        self._test_alias_with_cvt(to_numpy, device, dtype)
3907
3908    # Skipping 'meta', since 'to_dlpack' does not work for meta tensors.
3909    @skipMeta
3910    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
3911    def test_alias_from_dlpack(self, device, dtype):
3912        self._test_alias_with_cvt(to_dlpack, device, dtype)
3913
3914    @onlyCPU
3915    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
3916    def test_alias_from_buffer(self, device, dtype):
3917        self._test_alias_with_cvt(to_memview, device, dtype, shape=(5,), only_with_dtype=True)
3918
3919    def _test_copy_with_cvt(self, cvt, device, dtype, shape=(5, 5), only_with_dtype=False):
3920        original = make_tensor(shape, dtype=dtype, device=device)
3921
3922        def check(**kwargs):
3923            self._check(original, cvt=cvt, is_alias=False, **kwargs)
3924
3925        if not only_with_dtype:
3926            check(copy=True)
3927            check(device=device, copy=True)
3928
3929        check(requires_grad=False, dtype=dtype, copy=True)
3930        check(requires_grad=may_require_grad(dtype), dtype=dtype, copy=True)
3931        check(dtype=dtype, copy=True)
3932        check(device=device, dtype=dtype, copy=True)
3933
3934        # Copy is forced because of different device
3935        if torch.cuda.is_available():
3936            other = get_another_device(device)
3937            check(same_device=False, device=other, dtype=dtype)
3938            check(same_device=False, device=other, dtype=dtype, copy=True)
3939
3940        # Copy is forced because of different dtype
3941        if not only_with_dtype:
3942            for other in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
3943                if dtype != other:
3944                    check(same_dtype=False, dtype=other)
3945                    check(same_dtype=False, dtype=other, copy=True)
3946
3947    @skipMeta
3948    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3949    def test_copy_tensor(self, device, dtype):
3950        self._test_copy_with_cvt(identity, device, dtype)
3951
3952    @onlyCPU
3953    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
3954    def test_copy_from_numpy(self, device, dtype):
3955        self._test_copy_with_cvt(to_numpy, device, dtype)
3956
3957    @skipMeta
3958    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
3959    def test_copy_from_dlpack(self, device, dtype):
3960        self._test_copy_with_cvt(to_dlpack, device, dtype)
3961
3962    @onlyCPU
3963    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
3964    def test_copy_from_buffer(self, device, dtype):
3965        self._test_copy_with_cvt(to_memview, device, dtype, shape=(5,), only_with_dtype=True)
3966
3967    def _test_copy_mult_devices(self, devices, dtype, cvt):
3968        cuda1 = devices[0]
3969        cuda2 = devices[1]
3970        original = make_tensor((5, 5), dtype=dtype, device=cuda1)
3971
3972        def check(**kwargs):
3973            self._check(original, cvt, is_alias=False, same_device=False, device=cuda2, **kwargs)
3974
3975        check()
3976        check(copy=True)
3977        check(dtype=dtype, copy=True)
3978
3979    @onlyCUDA
3980    @deviceCountAtLeast(2)
3981    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
3982    def test_copy_from_tensor_mult_devices(self, devices, dtype):
3983        self._test_copy_mult_devices(devices, dtype, identity)
3984
3985    @onlyCUDA
3986    @deviceCountAtLeast(2)
3987    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
3988    def test_copy_from_dlpack_mult_devices(self, devices, dtype):
3989        self._test_copy_mult_devices(devices, dtype, to_dlpack)
3990
3991    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3992    def test_copy_list(self, device, dtype):
3993        original = make_tensor((5, 5), dtype=dtype, device=torch.device("cpu"))
3994
3995        def check(**kwargs):
3996            self._check(original, torch.Tensor.tolist, is_alias=False, **kwargs)
3997
3998        same_device = torch.device("cpu") == device
3999        check(same_device=same_device, device=device, dtype=dtype)
4000        check(same_device=same_device, device=device, dtype=dtype, requires_grad=False)
4001        check(same_device=same_device, device=device, dtype=dtype, requires_grad=may_require_grad(dtype))
4002        check(same_device=same_device, device=device, dtype=dtype, copy=True)
4003
4004    @dtypes(torch.float32)
4005    def test_unsupported_alias(self, device, dtype):
4006        original = make_tensor((5, 5), dtype=dtype, device=device)
4007
4008        if torch.cuda.is_available():
4009            other_device = get_another_device(device)
4010            with self.assertRaisesRegex(ValueError,
4011                                        f"from device '{device}' to '{other_device}'"):
4012                torch.asarray(original, device=other_device, copy=False)
4013
4014        with self.assertRaisesRegex(ValueError,
4015                                    "with dtype '.*' into dtype '.*'"):
4016            torch.asarray(original, dtype=torch.float64, copy=False)
4017
4018        with self.assertRaisesRegex(ValueError,
4019                                    "can't alias arbitrary sequence"):
4020            torch.asarray(original.tolist(), copy=False)
4021
4022    @onlyCUDA
4023    @deviceCountAtLeast(2)
4024    @dtypes(torch.float32)
4025    def test_unsupported_alias_mult_devices(self, devices, dtype):
4026        dev1, dev2 = devices[:2]
4027        original = make_tensor((5, 5), dtype=dtype, device=dev1)
4028        with self.assertRaisesRegex(ValueError,
4029                                    f"from device '{dev1}' to '{dev2}'"):
4030            torch.asarray(original, device=dev2, copy=False)
4031
4032    @dtypes(torch.float32, torch.complex64)
4033    def test_retain_autograd_history(self, device, dtype):
4034        original = make_tensor((5, 5), dtype=dtype, device=device, requires_grad=True)
4035        # 'cloned' carries autograd history (grad_fn=<CloneBackward0>)
4036        cloned = original.clone()
4037
4038        def check(**kwargs):
4039            a = torch.asarray(cloned, **kwargs)
4040            requires_grad = kwargs.get("requires_grad", False)
4041            self.assertEqual(a.requires_grad, requires_grad)
4042            # Autograd history shouldn't be retained when requires_grad is False
4043            self.assertEqual(a.grad_fn is None, not requires_grad)
4044
4045        check()
4046        check(requires_grad=True)
4047        check(copy=True)
4048        check(requires_grad=True, copy=True)
4049        check(requires_grad=False)
4050        check(requires_grad=False, copy=True)
4051
4052    @onlyCPU
4053    def test_astensor_consistency(self, device):
4054        # See PR: https://github.com/pytorch/pytorch/pull/71757
4055
4056        examples = [
4057            # Scalars
4058            True,
4059            42,
4060            1.0,
4061            # Homogeneous Lists
4062            [True, True, False],
4063            [1, 2, 3, 42],
4064            [0.0, 1.0, 2.0, 3.0],
4065            # Mixed Lists
4066            [True, False, 0],
4067            [0.0, True, False],
4068            [0, 1.0, 42],
4069            [0.0, True, False, 42],
4070            # With Complex
4071            [0.0, True, False, 42, 5j],
4072            # With Range
4073            range(5),
4074        ]
4075
4076        for e in examples:
4077            original = torch.as_tensor(e)
4078            t = torch.asarray(e)
4079            self.assertEqual(t, original)
4080
4081    @skipIfTorchDynamo()
4082    @onlyCPU
4083    def test_numpy_scalars(self, device):
4084        scalar = np.float64(0.5)
4085
4086        with self.assertRaisesRegex(RuntimeError, "can't alias NumPy scalars."):
4087            torch.asarray(scalar, copy=False)
4088
4089        tensor = torch.asarray(scalar)
4090        self.assertEqual(tensor.dim(), 0)
4091        self.assertEqual(tensor.item(), scalar.item())
4092        self.assertEqual(tensor.dtype, torch.float64)
4093        # Regression test for https://github.com/pytorch/pytorch/issues/97021
4094        zerodim_arr = np.array(1.)
4095        tensor = torch.asarray(zerodim_arr, dtype=torch.int32)
4096        self.assertEqual(tensor.dim(), 0)
4097        self.assertEqual(tensor.item(), zerodim_arr.item())
4098        self.assertEqual(tensor.dtype, torch.int32)
4099
4100    def test_default_device(self, device):
4101        original = torch.arange(5)
4102
4103        examples: List[Tuple[Any, Dict]] = [
4104            (3, {}),
4105            (original, {}),
4106            (to_numpy(original), {}),
4107            (to_memview(original), {"dtype": original.dtype}),
4108        ]
4109
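        # Entering 'with torch.device(device)' makes that device the default for factory
        # functions, so each torch.asarray call below is expected to allocate its result there.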
4110        for data, kwargs in examples:
4111            with torch.device(device):
4112                tensor = torch.asarray(data, **kwargs)
4113                self.assertEqual(tensor.device, torch.device(device))
4114
4115                # Check the contents of the tensor.
4116                if isinstance(data, int):
4117                    self.assertEqual(data, tensor.item())
4118                else:
4119                    self.assertEqual(data, tensor)
4120
4121    @onlyCUDA
4122    def test_device_without_index(self, device):
4123        original = torch.arange(5, device="cuda")
4124
4125        tensor = torch.asarray(original, device="cuda")
4126        # The storage pointers should be equal
4127        self.assertEqual(original.data_ptr(), tensor.data_ptr())
4128
4129        tensor = torch.asarray(original, copy=True, device="cuda")
4130        # The storage pointers should not be equal
4131        self.assertNotEqual(original.data_ptr(), tensor.data_ptr())
4132
4133
4134instantiate_device_type_tests(TestTensorCreation, globals())
4135instantiate_device_type_tests(TestRandomTensorCreation, globals())
4136instantiate_device_type_tests(TestLikeTensorCreation, globals())
4137instantiate_device_type_tests(TestBufferProtocol, globals(), only_for="cpu")
4138instantiate_device_type_tests(TestAsArray, globals())
4139
4140if __name__ == '__main__':
4141    TestCase._default_dtype_check_enabled = True
4142    run_tests()
4143