1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3# Owner(s): ["module: tests"]
4
5import torch
6import torch.utils.data
7import numpy as np
8
9import contextlib
10import gc
11import io
12import inspect
13import itertools
14import math
15import random
16import re
17import copy
18import os
19import tempfile
20import unittest
21import warnings
22import types
23import pickle
24import textwrap
25import subprocess
26import weakref
27import sys
28import copyreg
29from torch import inf, nan
30from itertools import product, combinations, permutations, chain
31from functools import partial
32from torch import multiprocessing as mp
33from torch.testing import make_tensor
34from torch.testing._internal.common_optimizers import (
35    optim_db, optims, _get_optim_inputs_including_global_cliquey_kwargs)
36
37from torch.testing._internal.common_utils import (  # type: ignore[attr-defined]
38    TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, run_tests, IS_JETSON,
39    IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
40    IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfTorchInductor, load_tests, slowTest, slowTestIf,
41    TEST_WITH_CROSSREF, skipIfTorchDynamo, skipRocmIfTorchInductor, set_default_dtype,
42    skipCUDAMemoryLeakCheckIf, BytesIOContext,
43    skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
44    wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard,
45    bytes_to_scalar, parametrize, skipIfMps, noncontiguous_like,
46    AlwaysWarnTypedStorageRemoval, TEST_WITH_TORCHDYNAMO, xfailIfTorchDynamo)
47from multiprocessing.reduction import ForkingPickler
48from torch.testing._internal.common_device_type import (
49    expectedFailureMeta,
50    expectedFailureXLA,
51    instantiate_device_type_tests,
52    onlyCUDA, onlyCPU,
53    dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast,
54    skipMeta, PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyNativeDeviceTypes,
55    get_all_device_types, skipXLA)
56from typing import Tuple
57import torch.backends.quantized
58import torch.testing._internal.data
59from torch.testing._internal.common_cuda import (
60    tf32_on_and_off, tf32_is_not_fp32, TEST_CUDNN, TEST_MULTIGPU,
61    _create_scaling_case, _create_scaling_models_optimizers)
62from torch.testing._internal.common_mkldnn import bf32_on_and_off
63from torch.testing._internal.common_dtype import (
64    floating_types_and, get_all_math_dtypes, all_types_and_complex_and, complex_types,
65    all_types_and, floating_types, floating_and_complex_types, integral_types_and,
66    get_all_qint_dtypes,
67)
68from torch.testing._internal.two_tensor import TwoTensor
69
70if TEST_WITH_TORCHINDUCTOR:
71    from torch._inductor.test_case import TestCase
72else:
73    from torch.testing._internal.common_utils import TestCase  # type: ignore[assignment]
74
75
76# Protects against imports accidentally setting the default dtype
77assert torch.get_default_dtype() is torch.float32
78
79# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
80# sharding on sandcastle. This line silences flake warnings
81load_tests = load_tests
82
83AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
84
85@contextlib.contextmanager
86def torch_vital_set(value):
87    stash = None
88    if 'TORCH_VITAL' in os.environ:
89        stash = os.environ['TORCH_VITAL']
90    os.environ['TORCH_VITAL'] = value
91    try:
92        yield
93    finally:
94        if stash is not None:
95            os.environ['TORCH_VITAL'] = stash
96        else:
97            del os.environ['TORCH_VITAL']
98
99# Tests Vital Signs for Torch
100# FIXME: document or deprecate whatever this is
101class TestBasicVitalSigns(TestCase):
102    def test_basic_vitals(self):
103        with torch_vital_set(''):
104            self.assertFalse(torch.vitals_enabled())
105        with torch_vital_set('ON'):
106            self.assertTrue(torch.vitals_enabled())
107
108    def test_basic_vitals_read_write(self):
109        with torch_vital_set('ON'):
110            self.assertTrue(torch.vitals_enabled())
111            # This tests the code path of setting a vital
112            self.assertTrue(torch.set_vital('Dataloader', 'basic_unit_test', 'TEST_VALUE_STRING'))
113            self.assertIn('TEST_VALUE_STRING', torch.read_vitals())
114            self.assertIn('CUDA.used', torch.read_vitals())
115
116    def test_dataloader_vitals(self):
117        with torch_vital_set('ON'):
118            inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
119            tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
120            dataset = torch.utils.data.TensorDataset(inps, tgts)
121            loader = torch.utils.data.DataLoader(dataset, batch_size=2)
122            self.assertIn('Dataloader.enabled\t\t True', torch.read_vitals())
123
124# FIXME: document or deprecate whatever this is
125class TestVitalSignsCuda(TestCase):
126    @onlyCUDA
127    def test_cuda_vitals_gpu_only(self, device):
128        with torch_vital_set('ON'):
129            self.assertIn('CUDA.used\t\t true', torch.read_vitals())
130
131
132is_cuda_sm86 = torch.cuda.is_available() and torch.cuda.get_device_capability(0) == (8, 6)
133
134class TestTorchDeviceType(TestCase):
135    exact_dtype = True
136
137    # TODO: move all tensor creation to common ops
138    def _rand_shape(self, dim, min_size, max_size):
139        shape = []
140        for i in range(dim):
141            shape.append(random.randint(min_size, max_size))
142        return tuple(shape)
143
144    # Validates that mathematical constants are defined properly, as required by
145    # the Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html)
146    @onlyCPU
147    def test_constants(self, device):
148        self.assertIsInstance(torch.e, float)
149        self.assertEqual(torch.e, math.e, atol=0, rtol=0)
150
151        self.assertIsInstance(torch.pi, float)
152        self.assertEqual(torch.pi, math.pi, atol=0, rtol=0)
153
154        self.assertIsInstance(torch.nan, float)
155        self.assertEqual(torch.nan, math.nan, equal_nan=True)
156
157        self.assertIsInstance(torch.inf, float)
158        self.assertEqual(torch.inf, math.inf)
159
160    @onlyNativeDeviceTypes
161    @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
162            torch.bool, torch.float32, torch.complex64, torch.float64,
163            torch.complex128, torch.uint16, torch.uint32, torch.uint64)
164    def test_bytes_to_scalar(self, device, dtype):
165        def rand_byte():
166            if dtype == torch.bool:
167                return torch.randint(0, 2, ()).item()
168            else:
169                return torch.randint(0, 256, ()).item()
170
171        element_size = torch._utils._element_size(dtype)
172
173        for i in range(10):
174            bytes_list = [rand_byte() for _ in range(element_size)]
175            scalar = bytes_to_scalar(bytes_list, dtype, device)
176            self.assertEqual(scalar.storage().untyped().tolist(), bytes_list)
177
178    @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
179            torch.bool, torch.float32, torch.complex64, torch.float64,
180            torch.complex128, torch.uint16, torch.uint32, torch.uint64)
181    def test_storage(self, device, dtype):
182        v = make_tensor((3, 5), dtype=dtype, device=device, low=-9, high=9)
183        self.assertEqual(v.storage()[0], v[0][0])
184        self.assertEqual(v.storage()[14], v[2][4])
185        v_s = v.storage()
186
187        for el_num in range(v.numel()):
188            dim0 = el_num // v.size(1)
189            dim1 = el_num % v.size(1)
190            self.assertEqual(
191                v_s[el_num],
192                v[dim0][dim1])
193
194        v_s_byte = v.storage().untyped()
195        el_size = v.element_size()
196
197        for el_num in range(v.numel()):
198            start = el_num * el_size
199            end = start + el_size
200            dim0 = el_num // v.size(1)
201            dim1 = el_num % v.size(1)
202            self.assertEqual(
203                bytes_to_scalar(v_s_byte[start:end], dtype, device),
204                v[dim0][dim1])
205
206    @onlyNativeDeviceTypes
207    @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
208            torch.bool, torch.float32, torch.complex64, torch.float64,
209            torch.complex128, torch.quint8, torch.qint8, torch.qint32,
210            torch.quint4x2)
211    def test_storage_setitem(self, device, dtype):
212        # Skip quantized dtypes for CUDA, since they're not supported
213        if torch.device(device).type == 'cuda':
214            if dtype in [torch.quint8, torch.qint8, torch.qint32, torch.quint4x2]:
215                return
216
217        storage_type_name = torch.storage._dtype_to_storage_type_map()[dtype]
218        if torch.device(device).type == 'cuda':
219            storage_type = getattr(torch.cuda, storage_type_name)
220        else:
221            storage_type = getattr(torch, storage_type_name)
222
223        N = 10
224
225        s = storage_type(N)
226        s[:] = 0
227        l = [0] * N
228        self.assertEqual(s, storage_type(l))
229
230        for i in range(N):
231            s[i] = i
232            l[i] = i
233
234        self.assertEqual(s, storage_type(l))
235
236        l[2:7] = [1] * 5
237        s[2:7] = 1
238        self.assertEqual(s, storage_type(l))
239
240    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
241    @onlyNativeDeviceTypes
242    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
243    def test_tensor_storage_type(self, device, dtype):
244        a = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9)
245
246        module = torch.cuda if (torch.device(device).type == 'cuda') else torch
247        expected_storage_type = getattr(module, torch.storage._dtype_to_storage_type_map()[dtype])
248
249        self.assertEqual(a.storage_type(), expected_storage_type)
250
251    @onlyNativeDeviceTypes
252    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64))
253    def test_tensor_from_storage(self, device, dtype):
254        a = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9)
255        a_s = a.storage()
256        b = torch.tensor(a_s, device=device, dtype=dtype).reshape(a.size())
257        self.assertEqual(a, b)
258        c = torch.tensor(a_s.untyped(), device=device, dtype=dtype).reshape(a.size())
259        self.assertEqual(a, c)
260
261        for error_dtype in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
262            if error_dtype == dtype:
263                continue
264            with self.assertRaisesRegex(RuntimeError, r'Expected a Storage of type'):
265                error_storage = a.to(error_dtype).storage()
266                torch.tensor(error_storage, device=device, dtype=dtype)
267
268    @onlyNativeDeviceTypes
269    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
270    def test_set_storage(self, device, dtype):
271        a = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9)
272        a_s = a.storage()
273        b = torch.tensor([], device=device, dtype=dtype).set_(a_s).reshape(a.size())
274        self.assertEqual(a, b)
275        c = torch.tensor([], device=device, dtype=dtype).set_(a_s.untyped()).reshape(a.size())
276        self.assertEqual(a, c)
277
278        for error_dtype in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
279            if error_dtype == dtype:
280                continue
281            with self.assertRaisesRegex(RuntimeError, r'Expected a Storage of type'):
282                error_storage = a.to(error_dtype).storage()
283                b = torch.tensor([], device=device, dtype=dtype).set_(error_storage)
284
285    def _check_storage_meta(self, s, s_check):
286        self.assertTrue(
287            isinstance(s, (torch.UntypedStorage, torch.TypedStorage)) and
288            isinstance(s_check, type(s)),
289            (
290                's and s_check must both be one of UntypedStorage or '
291                'TypedStorage, but got'
292                f' {type(s).__name__} and {type(s_check).__name__}'))
293
294        self.assertEqual(s.device.type, 'meta')
295        self.assertEqual(s.nbytes(), s_check.nbytes())
296        self.assertEqual(s.size(), s_check.size())
297        self.assertEqual(s.data_ptr(), 0)
298
299        with self.assertRaisesRegex(NotImplementedError, r'Not available'):
300            s[0]
301
302        if isinstance(s, torch.TypedStorage):
303            self.assertEqual(s.dtype, s_check.dtype)
304            self._check_storage_meta(s.untyped(), s_check.untyped())
305
306    @onlyNativeDeviceTypes
307    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
308    def test_typed_storage_meta(self, device, dtype):
309        args_list = [
310            [],
311            [0],
312            [100],
313            [[1, 2, 3, 4, 5, 6]],
314        ]
315        for args in args_list:
316            s_check = torch.TypedStorage(*args, dtype=dtype, device=device)
317            s = torch.TypedStorage(*args, dtype=dtype, device='meta')
318            self._check_storage_meta(s, s_check)
319
320    @onlyNativeDeviceTypes
321    def test_untyped_storage_meta(self, device):
322        args_list = [
323            [],
324            [0],
325            [100],
326            [[1, 2, 3, 4, 5, 6]],
327        ]
328        for args in args_list:
329            s_check = torch.UntypedStorage(*args, device=device)
330            s = torch.UntypedStorage(*args, device='meta')
331            self._check_storage_meta(s, s_check)
332
333    @onlyNativeDeviceTypes
334    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
335    def test_storage_meta_from_tensor(self, device, dtype):
336        t_check = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9)
337        t = t_check.to('meta')
338
339        s_check = t_check.storage()
340        s = t.storage()
341        self._check_storage_meta(s, s_check)
342
343    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
344    def test_storage_meta_errors(self, device, dtype):
345        s0 = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype)
346
347        with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
348            s0.cpu()
349
350        with self.assertRaisesRegex(RuntimeError, r'only available on CPU'):
351            s0._share_fd_cpu_()
352
353        with self.assertRaisesRegex(RuntimeError, r'only available on CPU'):
354            s0._share_filename_cpu_()
355
356        if torch.cuda.is_available():
357            with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
358                s0.cuda()
359
360            with self.assertRaisesRegex(RuntimeError, r'only available on CUDA'):
361                s0._share_cuda_()
362
363            with self.assertRaisesRegex(TypeError, r"cannot pin 'torch.storage.UntypedStorage' only CPU memory can be pinned"):
364                s0.pin_memory()
365
366        with self.assertRaisesRegex(RuntimeError, r'only available on CPU'):
367            s0.share_memory_()
368
369        with self.assertRaisesRegex(NotImplementedError, r'Not available'):
370            s0.tolist()
371
372        with tempfile.NamedTemporaryFile() as f:
373            with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
374                s0._write_file(f, True, True, s0.element_size())
375
376        for device in ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']:
377            s1 = torch.TypedStorage([1, 2, 3, 4], device=device, dtype=dtype)
378
379            with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
380                s1.copy_(s0)
381
382    @onlyCPU
383    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
384    def test_storage_meta_ok(self, device, dtype):
385        s0 = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype)
386
387        # This is OK, it changes the meta storage size without allocating
388        s0.resize_(10)
389
390    @onlyCUDA
391    def test_module_share_memory(self):
392        # Test fix for issue #80733
393        # See https://github.com/pytorch/pytorch/issues/80733
394        model = torch.nn.Linear(3, 1)
395        model_cuda = model.to('cuda')
396        model.share_memory()
397
398    @dtypes(torch.float32, torch.complex64)
399    def test_deepcopy(self, device, dtype):
400        from copy import deepcopy
401        a = torch.randn(5, 5, dtype=dtype, device=device)
402        b = torch.randn(5, 5, dtype=dtype, device=device)
403        c = a.view(25)
404        q = [a, [a.storage(), b.storage()], b, c]
405        w = deepcopy(q)
406        self.assertEqual(w[0], q[0], atol=0, rtol=0)
407        self.assertEqual(w[1][0], q[1][0], atol=0, rtol=0)
408        self.assertEqual(w[1][1], q[1][1], atol=0, rtol=0)
409        self.assertEqual(w[1], q[1], atol=0, rtol=0)
410        self.assertEqual(w[2], q[2], atol=0, rtol=0)
411
412        # Check that deepcopy preserves sharing
413        w[0].add_(1)
414        for i in range(a.numel()):
415            self.assertEqual(w[1][0][i], q[1][0][i] + 1)
416        self.assertEqual(w[3], c + 1)
417        w[2].sub_(1)
418        for i in range(a.numel()):
419            self.assertEqual(w[1][1][i], q[1][1][i] - 1)
420
421        # Check that deepcopy preserves attributes
422        a.foo = 3
423        self.assertEqual(deepcopy(a).foo, 3)
424
425    @dtypes(torch.float32, torch.complex64)
426    def test_deepcopy_scalar(self, device, dtype):
427        from copy import deepcopy
428        a = torch.tensor(5, dtype=dtype, device=device)
429        self.assertEqual(a.size(), deepcopy(a).size())
430        self.assertEqual(a, deepcopy(a))
431
432    def check_internal_mem_overlap(self, inplace_op, num_inputs,
433                                   dtype, device,
434                                   expected_failure=False):
435        if isinstance(inplace_op, str):
436            inplace_op = getattr(torch.Tensor, inplace_op)
437        input = torch.randn(1, dtype=dtype, device=device).expand(3, 3)
438        inputs = [input] + [torch.randn_like(input)
439                            for i in range(num_inputs - 1)]
440        if not expected_failure:
441            with self.assertRaisesRegex(RuntimeError, 'single memory location'):
442                inplace_op(*inputs)
443        else:
444            with self.assertRaises(AssertionError):
445                with self.assertRaisesRegex(RuntimeError, 'single memory location'):
446                    inplace_op(*inputs)
447
448    def unary_check_input_output_mem_overlap(self, data, sz, op,
449                                             expected_failure=False):
450
451        def _test(op, output, input):
452            output_exp = torch.empty_like(output)
453            op(input, out=output_exp)
454            self.assertEqual(op(input, out=output), output_exp, msg=op.__name__)
455
456        # output is identical to input:
457        _test(op, output=data[0:sz], input=data[0:sz])
458        # output and input are independent:
459        _test(op, output=data[0:sz], input=data[sz:2 * sz])
460        # output partially overlaps with input:
461        if not expected_failure:
462            with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
463                _test(op, data[0:sz], data[1:sz + 1])
464        else:
465            with self.assertRaises(AssertionError):
466                with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
467                    _test(op, data[0:sz], data[1:sz + 1])
468        # output is transpose of input:
469        length = int(math.sqrt(sz))
470        input = data[:length**2].view([length, length])
471        out = input.t()
472        if not expected_failure:
473            with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
474                _test(op, out, input)
475        else:
476            with self.assertRaises(AssertionError):
477                with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
478                    _test(op, out, input)
479
480    def ternary_check_input_output_mem_overlap(self, op, device,
481                                               expected_failure=False):
482        sz = 9
483        data = torch.randn(2 * sz, device=device)
484        other1 = torch.randn(sz, device=device)
485        other2 = torch.randn(sz, device=device)
486
487        self.unary_check_input_output_mem_overlap(
488            data, sz, lambda input, out:
489                op(input, other1.view(input.shape), other2.view(input.shape), out=out),
490            expected_failure=expected_failure)
491
492        self.unary_check_input_output_mem_overlap(
493            data, sz, lambda input, out:
494                op(other1.view(input.shape), input, other2.view(input.shape), out=out),
495            expected_failure=expected_failure)
496
497        self.unary_check_input_output_mem_overlap(
498            data, sz, lambda input, out:
499                op(other1.view(input.shape), other2.view(input.shape), input, out=out),
500            expected_failure=expected_failure)
501
502    def _select_broadcastable_dims(self, dims_full=None):
503        # select full dimensionality
504        if dims_full is None:
505            dims_full = []
506            ndims = random.randint(1, 4)
507            dims_full = [random.randint(1, 8) for _ in range(ndims)]
508        else:
509            ndims = len(dims_full)
510
511        # select actual dimensions for ops:
512        # larger: full ndims, individual sizes may be reduced
513        # smaller: possibly reduced ndims, sizes may be reduced
514        smaller_ndims = random.randint(1, ndims)
515        dims_small = []
516        dims_large = []
517        for i in range(ndims - 1, -1, -1):
518            j = random.randint(1, 3)
519            if j == 1:  # no reduced singleton dimension
520                ds = dims_full[i]
521                dl = dims_full[i]
522            elif j == 2:  # larger may have reduced singleton dimension
523                ds = dims_full[i]
524                dl = 1 if len(dims_small) < smaller_ndims else dims_full[i]
525            elif j == 3:  # smaller may have reduced singleton dimension
526                ds = 1
527                dl = dims_full[i]
528            dims_large = [dl] + dims_large
529            if len(dims_small) < smaller_ndims:
530                dims_small = [ds] + dims_small
531        return (dims_small, dims_large, dims_full)
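        # Illustrative example (one possible random draw, not exhaustive): for
        # dims_full = [2, 3, 4] this can return dims_large = [2, 1, 4] and
        # dims_small = [3, 4], both of which broadcast against dims_full.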
532
533    # collected tests of ops that used scalar_check in Declarations.cwrap for
534    # correctness
535    def test_scalar_check(self, device):
536        zero_d = torch.randn((), device=device)
537        one_d = torch.randn((1,), device=device)
538
539        # remainder
540        self.assertEqual((), torch.remainder(zero_d, zero_d).shape)
541        self.assertEqual((), torch.remainder(zero_d, 2).shape)
542        self.assertEqual((1,), torch.remainder(zero_d, one_d).shape)
543        self.assertEqual((1,), torch.remainder(one_d, zero_d).shape)
544
545        # fmod
546        self.assertEqual((), torch.fmod(zero_d, zero_d).shape)
547        self.assertEqual((), torch.fmod(zero_d, 2).shape)
548        self.assertEqual((1,), torch.fmod(zero_d, one_d).shape)
549        self.assertEqual((1,), torch.fmod(one_d, zero_d).shape)
550
551        # exp, cos, cosh, tan, atan, tanh, erf, erfc, reciprocal
552        self.assertEqual((), torch.exp(zero_d).shape)
553        self.assertEqual((), torch.cos(zero_d).shape)
554        self.assertEqual((), torch.cosh(zero_d).shape)
555        self.assertEqual((), torch.tan(zero_d).shape)
556        self.assertEqual((), torch.atan(zero_d).shape)
557        self.assertEqual((), torch.acosh(zero_d).shape)
558        self.assertEqual((), torch.asinh(zero_d).shape)
559        self.assertEqual((), torch.atanh(zero_d).shape)
560        self.assertEqual((), torch.tanh(zero_d).shape)
561        self.assertEqual((), torch.erf(zero_d).shape)
562        self.assertEqual((), torch.erfc(zero_d).shape)
563        self.assertEqual((), torch.reciprocal(zero_d).shape)
564        self.assertEqual((1,), torch.exp(one_d).shape)
565        self.assertEqual((1,), torch.cos(one_d).shape)
566        self.assertEqual((1,), torch.cosh(one_d).shape)
567        self.assertEqual((1,), torch.tan(one_d).shape)
568        self.assertEqual((1,), torch.atan(one_d).shape)
569        self.assertEqual((1,), torch.acosh(one_d).shape)
570        self.assertEqual((1,), torch.asinh(one_d).shape)
571        self.assertEqual((1,), torch.atanh(one_d).shape)
572        self.assertEqual((1,), torch.tanh(one_d).shape)
573        self.assertEqual((1,), torch.erf(one_d).shape)
574        self.assertEqual((1,), torch.erfc(one_d).shape)
575        self.assertEqual((1,), torch.reciprocal(one_d).shape)
576
577        # clamp
578        self.assertEqual((), torch.clamp(zero_d, min=0, max=1).shape)
579        self.assertEqual((), torch.clamp(zero_d, min=0).shape)
580        self.assertEqual((), torch.clamp(zero_d, max=1).shape)
581        self.assertEqual((1,), torch.clamp(one_d, min=0, max=1).shape)
582        self.assertEqual((1,), torch.clamp(one_d, min=0).shape)
583        self.assertEqual((1,), torch.clamp(one_d, max=1).shape)
584
585        # cumsum, cumprod, cummax, cummin
586        self.assertEqual((), torch.logcumsumexp(zero_d, 0).shape)
587        self.assertEqual((), torch.cumsum(zero_d, 0).shape)
588        self.assertEqual((), torch.cumprod(zero_d, 0).shape)
589        self.assertEqual((), torch.cummax(zero_d, 0)[0].shape)
590        self.assertEqual((), torch.cummin(zero_d, 0)[0].shape)
591
592        # sort, topk
593        self.assertEqual([(), ()], [x.shape for x in torch.sort(zero_d, 0, False)])
594        self.assertEqual([(), ()], [x.shape for x in torch.sort(zero_d, 0, True)])
595        self.assertEqual([(), ()], [x.shape for x in torch.topk(zero_d, 1, 0, False)])
596        self.assertEqual([(), ()], [x.shape for x in torch.topk(zero_d, 1, 0, True)])
597
598        # max, min
599        self.assertEqual((), torch.max(zero_d, zero_d).shape)
600        self.assertEqual((1,), torch.max(one_d, zero_d).shape)
601        self.assertEqual((1,), torch.max(zero_d, one_d).shape)
602        self.assertEqual((), torch.min(zero_d, zero_d).shape)
603        self.assertEqual((1,), torch.min(one_d, zero_d).shape)
604        self.assertEqual((1,), torch.min(zero_d, one_d).shape)
605
606        zero_d_int = torch.tensor(1, device=device)
607        one_d_int = torch.tensor([1], device=device)
608
609        # lshift, rshift
610        self.assertEqual((), (zero_d_int >> zero_d_int).shape)
611        self.assertEqual((), (zero_d_int >> 1).shape)
612        self.assertEqual((1,), (one_d_int >> zero_d_int).shape)
613        self.assertEqual((1,), (zero_d_int >> one_d_int).shape)
614        self.assertEqual((1,), (one_d_int >> 1).shape)
615
616        self.assertEqual((), (zero_d_int << zero_d_int).shape)
617        self.assertEqual((), (zero_d_int << 1).shape)
618        self.assertEqual((1,), (one_d_int << zero_d_int).shape)
619        self.assertEqual((1,), (zero_d_int << one_d_int).shape)
620        self.assertEqual((1,), (one_d_int << 1).shape)
621
622        # or
623        self.assertEqual((), (zero_d_int | zero_d_int).shape)
624        self.assertEqual((), (zero_d_int | 1).shape)
625        self.assertEqual((1,), (one_d_int | zero_d_int).shape)
626        self.assertEqual((1,), (zero_d_int | one_d_int).shape)
627        self.assertEqual((1,), (one_d_int | 1).shape)
628
629        # and
630        self.assertEqual((), (zero_d_int & zero_d_int).shape)
631        self.assertEqual((), (zero_d_int & 1).shape)
632        self.assertEqual((1,), (one_d_int & zero_d_int).shape)
633        self.assertEqual((1,), (zero_d_int & one_d_int).shape)
634        self.assertEqual((1,), (one_d_int & 1).shape)
635
636        # clone
637        self.assertEqual((), zero_d.clone().shape)
638
639        zero_d_bool = torch.tensor(True, device=device)
640        one_d_bool = torch.tensor([True], device=device)
641
642        # masked_select
643        self.assertEqual((1,), torch.masked_select(zero_d_bool, zero_d_bool).shape)
644        self.assertEqual((1,), torch.masked_select(zero_d_bool, one_d_bool).shape)
645        self.assertEqual((1,), torch.masked_select(one_d_bool, zero_d_bool).shape)
646
647        zero_d_uint8 = torch.tensor(1, dtype=torch.uint8, device=device)
648        one_d_uint8 = torch.tensor([1], dtype=torch.uint8, device=device)
649
650        # mode
651        self.assertEqual([(), ()], [x.shape for x in torch.mode(zero_d, dim=0, keepdim=True)])
652        self.assertEqual([(), ()], [x.shape for x in torch.mode(zero_d, dim=0, keepdim=False)])
653        self.assertEqual([(1,), (1,)], [x.shape for x in torch.mode(one_d, dim=0, keepdim=True)])
654        self.assertEqual([(), ()], [x.shape for x in torch.mode(one_d, dim=0, keepdim=False)])
655
656        # max
657        self.assertEqual([(), ()], [x.shape for x in torch.max(zero_d, dim=0, keepdim=True)])
658        self.assertEqual([(), ()], [x.shape for x in torch.max(zero_d, dim=0, keepdim=False)])
659        self.assertEqual([(1,), (1,)], [x.shape for x in torch.max(one_d, dim=0, keepdim=True)])
660        self.assertEqual([(), ()], [x.shape for x in torch.max(one_d, dim=0, keepdim=False)])
661
662        # amax
663        self.assertEqual((), torch.amax(zero_d, dim=0, keepdim=True).shape)
664        self.assertEqual((), torch.amax(zero_d, dim=0, keepdim=False).shape)
665        self.assertEqual((1,), torch.amax(one_d, dim=0, keepdim=True).shape)
666        self.assertEqual((), torch.amax(one_d, dim=0, keepdim=False).shape)
667
668        # min
669        self.assertEqual([(), ()], [x.shape for x in torch.min(zero_d, dim=0, keepdim=True)])
670        self.assertEqual([(), ()], [x.shape for x in torch.min(zero_d, dim=0, keepdim=False)])
671        self.assertEqual([(1,), (1,)], [x.shape for x in torch.min(one_d, dim=0, keepdim=True)])
672        self.assertEqual([(), ()], [x.shape for x in torch.min(one_d, dim=0, keepdim=False)])
673
674        # amin
675        self.assertEqual((), torch.amin(zero_d, dim=0, keepdim=True).shape)
676        self.assertEqual((), torch.amin(zero_d, dim=0, keepdim=False).shape)
677        self.assertEqual((1,), torch.amin(one_d, dim=0, keepdim=True).shape)
678        self.assertEqual((), torch.amin(one_d, dim=0, keepdim=False).shape)
679
680        # set_
681        zero_d_clone = zero_d.clone()
682        one_d_clone = one_d.clone()
683        self.assertEqual((), zero_d_clone.set_(one_d.storage(), 0, (), ()).shape)
684        self.assertEqual((1,), zero_d_clone.set_(one_d.storage(), 0, (1,), (1,)).shape)
685        self.assertEqual((), one_d_clone.set_(one_d.storage(), 0, (), ()).shape)
686        self.assertEqual((1,), one_d_clone.set_(one_d.storage(), 0, (1,), (1,)).shape)
687
688        self.assertEqual((), zero_d.clone().set_(zero_d).shape)
689        self.assertEqual((), one_d.clone().set_(zero_d).shape)
690        self.assertEqual((1,), zero_d.clone().set_(one_d).shape)
691        self.assertEqual((1,), one_d.clone().set_(one_d).shape)
692
693        # take
694        self.assertEqual((), torch.randn((2, 3), device=device).take(zero_d_int).shape)
695        self.assertEqual((1,), torch.randn((2, 3), device=device).take(one_d_int).shape)
696
697        # gather
698        self.assertEqual((), torch.gather(zero_d, 0, torch.zeros((), dtype=torch.int64, device=device)).shape)
699        self.assertEqual((1,), torch.gather(zero_d, 0, torch.zeros((1,), dtype=torch.int64, device=device)).shape)
700        self.assertEqual((), torch.gather(one_d, 0, torch.zeros((), dtype=torch.int64, device=device)).shape)
701        self.assertEqual((1,), torch.gather(one_d, 0, torch.zeros((1,), dtype=torch.int64, device=device)).shape)
702
703        # normal
704        # std must be >= 0
705        zero_d_ge_0 = torch.rand((), device=device)
706        # documentation says out shape matches shape of mean
707        self.assertEqual((), torch.normal(zero_d, zero_d_ge_0).shape)
708        self.assertEqual((1,), torch.normal(one_d, zero_d_ge_0).shape)
709        self.assertEqual((), torch.normal(1, zero_d_ge_0).shape)
710        self.assertEqual((), torch.normal(zero_d, 1).shape)
711        self.assertEqual((1,), torch.normal(one_d, 1).shape)
712        # TODO: this behavior differs on CPU and GPU, see https://github.com/pytorch/pytorch/issues/30480.
713        # self.assertEqual((), torch.normal(zero_d, one_d).shape)
714        # self.assertEqual((), torch.normal(1, one_d).shape)
715
716        # convolutions.  Yes, we are testing nn.functional here; seems justified
717        # given it's similar to the other tests
718        w = torch.randn(2, 1, 3, 3, device=device).div_(2).requires_grad_()
719        self.assertRaises(RuntimeError, lambda: torch.nn.functional.conv2d(zero_d, w, groups=1))
720        self.assertRaises(RuntimeError, lambda: torch.nn.functional.conv2d(zero_d, w, groups=2))
721
722        # nll_loss -- verify input can't be 0-dimensional.
723        self.assertRaises(ValueError, lambda: torch.nn.functional.nll_loss(zero_d, zero_d, reduction='none'))
724        self.assertRaises(ValueError, lambda: torch.nn.functional.nll_loss(zero_d, one_d, reduction='none'))
725        # verify output is 0-dimensional when reduction != 'none'
726        for (input, target) in ((torch.randn(1, 1, device=device), torch.tensor([0], device=device)),
727                                (torch.randn(1, 1, 1, 1, device=device), torch.tensor([[[0]]], device=device))):
728            self.assertEqual((), torch.nn.functional.nll_loss(input, target, reduction='mean').shape)
729            self.assertEqual((), torch.nn.functional.nll_loss(input, target, reduction='sum').shape)
730
731    # Test that `torch._check_tensor_all` raises errors in the correct cases
732    def test_check_tensor_all(self, device):
733        default_message = 'Expected cond to be True'
734        check_fn = torch._check_tensor_all
735        expected_error = RuntimeError
736
737        # cond must be a tensor
738        with self.assertRaisesRegex(TypeError, 'cond must be a tensor'):
739            check_fn(True)
740
741        # cond tensor must be boolean
742        with self.assertRaisesRegex(TypeError, 'cond tensor must have dtype torch.bool'):
743            check_fn(torch.ones(1, device=device))
744
745        test_sizes = [
746            (),
747            (1,),
748            (10,),
749            (1, 1),
750            (1, 10),
751            (10, 1),
752            (10, 10),
753            (1, 1, 1),
754            (10, 1, 1),
755            (1, 10, 1),
756            (10, 10, 10),
757        ]
758        for size in test_sizes:
759            t_all_true = torch.ones(size, dtype=torch.bool, device=device)
760            t_all_false = torch.zeros(size, dtype=torch.bool, device=device)
761
762            # Should not raise error
763            check_fn(t_all_true)
764
765            with self.assertRaisesRegex(expected_error, default_message):
766                check_fn(t_all_false)
767
768            if t_all_true.numel() > 1:
769                t_all_true_but_one = t_all_true.clone()
770                # Choose a random element to set to false
771                idx = (random.choice(range(dim_size)) for dim_size in size)
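                # (..., *idx) forms a full index tuple; the Ellipsis matches zero
                # remaining dimensions here, so exactly one element is set to False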
772                t_all_true_but_one[(..., *idx)] = False
773
774                with self.assertRaisesRegex(expected_error, default_message):
775                    check_fn(t_all_true_but_one)
776
777            # Test a simple failure message
778            message = 'message'
779            with self.assertRaisesRegex(expected_error, message):
780                check_fn(t_all_false, lambda: message)
781
782            # Test message with tensor
783            def message():
784                return torch.arange(4)
785
786            with self.assertRaisesRegex(expected_error, re.escape(str(message()))):
787                check_fn(t_all_false, message)
788
789            # Test format string message
790            def message():
791                return f"{'test'} {[1, 2, 'a', True]} {True} {100} {torch.arange(4)}"
792
793            with self.assertRaisesRegex(expected_error, re.escape(str(message()))):
794                check_fn(t_all_false, message)
795
796    # Test that `TORCH_CHECK_TENSOR_ALL` raises errors that propagate from C++ to Python
797    def test_check_tensor_internal(self, device):
798        test_sizes = [
799            (),
800            (1,),
801            (10,),
802            (1, 1),
803            (1, 10),
804            (10, 1),
805            (10, 10),
806            (1, 1, 1),
807            (10, 1, 1),
808            (1, 10, 1),
809            (10, 10, 10),
810        ]
811        for size in test_sizes:
812            t_all_true = torch.ones(size, dtype=torch.bool, device=device)
813            t_all_false = torch.zeros(size, dtype=torch.bool, device=device)
814
815            # Should not raise error
816            torch._test_check_tensor(t_all_true)
817
818            with self.assertRaisesRegex(RuntimeError, "Test message for TORCH_CHECK_TENSOR_ALL"):
819                torch._test_check_tensor(t_all_false)
820
821            if t_all_true.numel() > 1:
822                t_all_true_but_one = t_all_true.clone()
823                # Choose a random element to set to false
824                idx = (random.choice(range(dim_size)) for dim_size in size)
825                t_all_true_but_one[(..., *idx)] = False
826
827                with self.assertRaisesRegex(RuntimeError, "Test message for TORCH_CHECK_TENSOR_ALL"):
828                    torch._test_check_tensor(t_all_true_but_one)
829
830    # Uses mismatched arange out size to trigger a warning
831    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
832    @unittest.skipIf(TEST_WITH_CROSSREF, "crossref perturbs line numbering")
833    def test_cpp_warnings_have_python_context(self, device):
834        # Creates long string in advance to avoid a too-long Python line
835        s = ".+Triggered internally at.+RangeFactories.+"
836        # nvfuser deprecation warning filter
837        warnings.filterwarnings("ignore", "torch::jit::fuser::cuda", UserWarning)
838
839        def cpp_warn_fn():
840            out = torch.empty((5,))
841            torch.arange(0, 3, out=out)
842            return out
843
844        # Checks eager-mode cpp warning
845        with warnings.catch_warnings(record=True) as w:
846            cpp_warn_fn()
847            frameinfo = inspect.getframeinfo(inspect.currentframe())
848            warning = w[0]
849
850            # Checks for cpp context in the warning message
851            escaped_warning_message = str(warning.message).encode('unicode_escape')
852            self.assertTrue(re.search(s, repr(escaped_warning_message), re.IGNORECASE) is not None)
853
854            # Checks the Python features of the warning
855            # Note: the eager mode warning refers to the line in the function
856            # that throws the warning.
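            # (the torch.arange call inside cpp_warn_fn sits six source lines above
            # the frameinfo capture, hence the expected lineno of frameinfo.lineno - 6)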
857            self.assertEqual(frameinfo.lineno - 6, warning.lineno)
858            self.assertEqual(len(w), 1)
859
860        # Checks jitted cpp warning
861        with warnings.catch_warnings(record=True) as w:
862            scripted_cpp_warn_fn = torch.jit.script(cpp_warn_fn)
863            scripted_cpp_warn_fn()
864            warning = w[0]
865
866            # Checks for cpp context in the warning message
867            escaped_warning_message = str(warning.message).encode('unicode_escape')
868            self.assertTrue(re.search(s, repr(escaped_warning_message), re.IGNORECASE) is not None)
869
870            # Checks the Python features of the warning
871            # Note: the jitted warning's lineno refers to the call to the jitted
872            # function, which in our test suite has a layer of indirection
873            # that makes checking the Python lineno fragile
874            self.assertEqual(len(w), 1)
875
876        # Checks jitted Python warning
877        def warn_fn():
878            warnings.warn("Warning!")
879
880        # The jit mimics an eager-mode Python warning in this case
881        with warnings.catch_warnings(record=True) as w:
882            scripted_warn_fn = torch.jit.script(warn_fn)
883            scripted_warn_fn()
884            frameinfo = inspect.getframeinfo(inspect.currentframe())
885            warning = w[0]
886
887            self.assertTrue(re.search('Warning!', str(warning.message)) is not None)
888
889            # Checks the Python features of the warning
890            self.assertEqual(frameinfo.lineno - 6, warning.lineno)
891            self.assertEqual(len(w), 1)
892
893    # FIXME: move to test_testing
894    @onlyCPU
895    def test_warn_always_caught(self, device):
896        # Check that we can catch a TORCH_WARN_ONCE warning twice
897        # since assertWarnsOnceRegex uses set_warn_always(True) which changes
898        # TORCH_WARN_ONCE to TORCH_WARN
899        a = np.arange(10)
900        a.flags.writeable = False
901        with self.assertWarnsOnceRegex(UserWarning, '.*non-writable.*'):
902            torch.from_numpy(a)
903
904        # OK, got it once, now try again
905        with self.assertWarnsOnceRegex(UserWarning, '.*non-writable.*'):
906            torch.from_numpy(a)
907
908        # Make sure emitting two warnings will pass the assertWarnsOnceRegex
909        # context manager
910        with self.assertWarnsOnceRegex(UserWarning, '.*non-writable.*'):
911            torch.from_numpy(a)
912            torch.from_numpy(a)
913
914    @onlyNativeDeviceTypes
915    def test_complex_half_experimental_warning(self, device):
916        msg = 'ComplexHalf support is experimental'
917        with self.assertWarnsOnceRegex(UserWarning, msg):
918            t = torch.randn(3, dtype=torch.chalf, device=device)
919
920        with self.assertWarnsOnceRegex(UserWarning, msg):
921            torch.rand(3, dtype=torch.chalf, device=device)
922
923        with self.assertWarnsOnceRegex(UserWarning, msg):
924            torch.empty(3, dtype=torch.chalf, device=device)
925
926        with self.assertWarnsOnceRegex(UserWarning, msg):
927            torch.ones(3, dtype=torch.chalf, device=device)
928
929        with self.assertWarnsOnceRegex(UserWarning, msg):
930            torch.zeros(3, dtype=torch.chalf, device=device)
931
932        with self.assertWarnsOnceRegex(UserWarning, msg):
933            torch.randn_like(t)
934
935        with self.assertWarnsOnceRegex(UserWarning, msg):
936            torch.rand_like(t)
937
938        with self.assertWarnsOnceRegex(UserWarning, msg):
939            torch.empty_like(t)
940
941        with self.assertWarnsOnceRegex(UserWarning, msg):
942            torch.ones_like(t)
943
944        with self.assertWarnsOnceRegex(UserWarning, msg):
945            torch.zeros_like(t)
946
947        with self.assertWarnsOnceRegex(UserWarning, msg):
948            # t + 1 allocates a new tensor for the result using empty
949            t + 1
950
951    @onlyCUDA
952    def test_dtypetensor_warnings(self, device):
953        msg = 'The torch.cuda.*DtypeTensor constructors are no longer recommended'
954        with self.assertWarnsOnceRegex(UserWarning, msg):
955            t = torch.cuda.FloatTensor([0])
956
957        with self.assertWarnsOnceRegex(UserWarning, msg):
958            t = torch.cuda.DoubleTensor([0])
959
960    def test_set_default_tensor_type_warnings(self, device):
961        msg = '.*is deprecated as of PyTorch 2.1, please use torch.set_default_dtype().*'
962        default_type = torch.tensor([]).type()
963        try:
964            with self.assertWarnsOnceRegex(UserWarning, msg):
965                torch.set_default_tensor_type(torch.FloatTensor)
966
967            if torch.cuda.is_available():
968                with self.assertWarnsOnceRegex(UserWarning, msg):
969                    torch.set_default_tensor_type(torch.cuda.FloatTensor)
970        finally:
971            torch.set_default_tensor_type(default_type)
972
973    # TODO: this test should be in test_nn.py
974    def test_conv_transposed_backward_agnostic_to_memory_format(self, device):
975        in_channels = 64
976        out_channels = 128
977        scale_factor = 8
978        batch_size = 8
979        length = 16
980
981        conv = torch.nn.ConvTranspose1d(
982            in_channels, out_channels, kernel_size=scale_factor * 2, stride=scale_factor).to(device)
983        layer_norm = torch.nn.LayerNorm(out_channels).to(device)
984
985        input_ = torch.randn(batch_size, in_channels, length).to(device).contiguous()
986        input_ = conv(input_).contiguous()
987        input_ = layer_norm(input_.transpose(1, 2).contiguous()).contiguous()
988        input_.sum().backward()
989
990        # 3d
991        conv = torch.nn.ConvTranspose3d(3, 3, kernel_size=3).to(device)
992        input = torch.randn(batch_size, 3, length, length, length, device=device)
993        out = conv(input)
994        out.backward(torch.ones_like(out).transpose(-2, -1))
995
996    # TODO: this test should be in test_nn.py
997    @onlyCUDA
998    @largeTensorTest('12GB')
999    def test_conv_transposed_large(self, device):
1000        # ConvTranspose3d works for large input tensors (gh-32866)
1001        in_channels = 64
1002        out_channels = 128
1003        kernel_size = 5
1004
1005        conv = torch.nn.ConvTranspose3d(
1006            in_channels, out_channels, kernel_size=kernel_size,
1007            stride=2, padding=2, output_padding=1).to(device)
1008
1009        x = torch.rand([1, 64, 8, 128, 172]).to(device)
1010        y = conv(x)
1011
1012    def test_is_set_to(self, device):
1013        t1 = torch.empty(3, 4, 9, 10, device=device)
1014        t2 = torch.empty(3, 4, 9, 10, device=device)
1015        t3 = torch.tensor([], device=device).set_(t1)
1016        t4 = t3.clone().resize_(12, 90)
1017        self.assertFalse(t1.is_set_to(t2))
1018        self.assertTrue(t1.is_set_to(t3))
1019        self.assertTrue(t3.is_set_to(t1), "is_set_to should be symmetric")
1020        self.assertFalse(t1.is_set_to(t4))
1021        self.assertFalse(torch.tensor([]).is_set_to(torch.tensor([])),
1022                         "Tensors with no storages should not appear to be set "
1023                         "to each other")
1024
1025        t1 = torch.tensor([True, True], dtype=torch.bool, device=device)
1026        t2 = torch.tensor([0], dtype=torch.bool, device=device).set_(t1)
1027        self.assertTrue(t1.is_set_to(t2))
1028
1029        # test that sizes must match
1030        t1 = torch.empty([2, 3, 4], device=device)
1031        t2 = t1.view(4, 3, 2)
1032        self.assertFalse(t1.is_set_to(t2))
1033        self.assertFalse(t2.is_set_to(t1))
1034
1035        # test that the legacy empty-size behavior (where all empty tensors
1036        # were logically collapsed to size [0]) is no longer respected
1037        t1 = torch.empty([2, 5, 0], device=device)
1038        t2 = t1.view([0])
1039        self.assertFalse(t1.is_set_to(t2))
1040        self.assertFalse(t2.is_set_to(t1))
1041
1042    # See https://github.com/pytorch/pytorch/issues/72650
1043    @skipIfMps
1044    @skipMeta
1045    @parametrize(
1046        "fn",
1047        [
1048            "dist", "atan2", "pow", "lerp", "add", "sub", "mul", "div", "fmod", "remainder", "eq", "ge", "gt", "le",
1049            "lt", "max", "min", "ne", "addcdiv", "addcmul", "masked_scatter", "masked_select", "masked_fill", "map",
1050            "map2", "copy",
1051        ],
1052    )
1053    def test_broadcast(self, fn, device):
1054        # functions with three tensor arguments
1055        fns_3_args = {"map2"}
1056        fns_value_kwarg = {"addcdiv", "addcmul"}
1057
1058        (dims_small, dims_large, dims_full) = self._select_broadcastable_dims()
1059        full1d = torch.randn(*dims_full, device=device).flatten().float()
1060        small = torch.randn(*dims_small, device=device).float()
1061        large = torch.randn(*dims_large, device=device).float()
1062        small_expanded = small.expand(*dims_full)
1063        large_expanded = large.expand(*dims_full)
1064        small2 = None
1065        small2_expanded = None
1066        if fn in fns_3_args or fn in fns_value_kwarg:
1067            # create another smaller tensor
1068            (dims_small2, _, _) = self._select_broadcastable_dims(dims_full)
1069            small2 = torch.randn(*dims_small2, device=device).float()
1070            small2_expanded = small2.expand(*dims_full)
1071
1072        if small.is_cuda and fn in ['map', 'map2']:
1073            # map and map2 are not implemented on CUDA tensors
1074            return
1075
1076        if hasattr(large_expanded, fn):
1077            # run through tensor versions of functions
1078            # and verify fully expanded inputs give same results
1079            expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}
1080
1081            def tensorfn(myfn, t1, t2):
1082                if fn == "lerp":
1083                    return myfn(t1, 0.5)
1084                elif fn == "masked_select":
1085                    return myfn(t1 < 0)
1086                elif fn == "masked_scatter":
1087                    return myfn(t1 < 0.5, full1d)
1088                elif fn == "masked_fill":
1089                    return myfn(t1 < 0.5, 1.0)
1090                elif fn in fns_3_args:
1091                    return myfn(1, t1, t2)
1092                elif fn in fns_value_kwarg:
1093                    return myfn(t1, t2, value=1)
1094                else:
1095                    return myfn(t1)
1096
1097            # test various orders
1098            for first, second, third in [(large, small, small2), (small, large, small2),
1099                                         (small2, small, large), (small2, large, small)]:
1100                if first is None:
1101                    break  # ignore last iter when small2 is None
1102                method_expanded = getattr(expanded[first], fn)
1103                method = getattr(first, fn)
1104                r1 = tensorfn(method_expanded, expanded[second], expanded[third])
1105                r2 = tensorfn(method, second, third)
1106                self.assertEqual(r1, r2)
1107
1108        # now for torch. versions of functions
1109        if hasattr(torch, fn):
1110            fntorch = getattr(torch, fn)
1111            expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}
1112
1113            def torchfn(t1, t2, t3):
1114                if fn == "lerp":
1115                    return fntorch(t1, t2, 0.5)
1116                elif fn == "masked_select":
1117                    return fntorch(t1, t2 < 0)
1118                elif fn == "masked_scatter":
1119                    return fntorch(t1, t2 < 0.5, full1d)
1120                elif fn == "masked_fill":
1121                    return fntorch(t1, t2 < 0.5, 1.0)
1122                elif fn in fns_3_args:
1123                    return fntorch(t1, 1.0, t2, t3)
1124                elif fn in fns_value_kwarg:
1125                    return fntorch(t1, t2, t3, value=1.0)
1126                else:
1127                    return fntorch(t1, t2)
1128
1129            # test various orders
1130            for first, second, third in [(large, small, small2), (small, large, small2),
1131                                         (small2, small, large), (small2, large, small)]:
1132                if first is None:
1133                    break  # ignore last iter when small2 is None
1134                r1 = torchfn(expanded[first], expanded[second], expanded[third])
1135                r2 = torchfn(first, second, third)
1136                self.assertEqual(r1, r2)
1137
1138        # now for in-place functions
1139        # the in-place tensor is not broadcastable; the test is only guaranteed
1140        # to work by broadcasting the other argument(s)
1141        if not hasattr(large_expanded, fn + "_"):
1142            return
1143
1144        # need to clone large_expanded so we can reuse it, since the functions are in-place
1145        large_expanded_clone = large_expanded.clone()
1146
1147        def tensorfn_inplace(t0, t1, t2=None):
1148            t0_fn = getattr(t0, fn + "_")
1149            if fn == "lerp":
1150                return t0_fn(t1, 0.5)
1151            elif fn == "masked_scatter":
1152                return t0_fn(t1 < 0.5, full1d)
1153            elif fn == "masked_fill":
1154                return t0_fn(t1 < 0.5, 1.0)
1155            elif fn == "map":
1156                return t0_fn(t1, lambda x, y: x + y)
1157            elif fn == "map2":
1158                return t0_fn(t1, t2, lambda x, y, z: x + y + z)
1159            elif fn in fns_3_args:
1160                return t0_fn(1.0, t1, t2)
1161            elif fn in fns_value_kwarg:
1162                return t0_fn(t1, t2, value=1.0)
1163            else:
1164                return t0_fn(t1)
1165        # in-place pointwise operations don't actually work if the in-place
1166        # tensor is 0-strided (numpy has the same issue)
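        # (e.g. torch.randn(1).expand(3, 3) has stride (0, 0), so every element
        # aliases the same memory location)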
1167        if (0 not in large_expanded.stride() and 0 not in large_expanded_clone.stride()):
1168            r1 = tensorfn_inplace(large_expanded, small_expanded, small2_expanded)
1169            r2 = tensorfn_inplace(large_expanded_clone, small, small2)
1170            self.assertEqual(r1, r2)
1171
1172        def broadcastable(t0, t1, t2=None):
1173            try:
1174                t1.expand_as(t0)
1175                if t2 is not None:
1176                    t2.expand_as(t0)
1177            except RuntimeError:
1178                return False
1179            return True
1180
1181        def _test_in_place_broadcastable(t0, t1, t2=None):
1182            if not broadcastable(t0, t1, t2):
1183                same_size = t0.numel() == t1.numel() and (t0.numel() == t2.numel() if t2 is not None else True)
1184                if not same_size:
1185                    # Functionalization converts the inplace to an out-of-place, which causes us to error.
1186                    # We should fix this, but "error probably on bad inputs" isn't a hi-pri PT2 item.
1187                    if not TEST_WITH_TORCHINDUCTOR:
1188                        self.assertRaises(RuntimeError, lambda: tensorfn_inplace(t0, t1, t2))
1189            else:
1190                tensorfn_inplace(t0, t1, t2)
1191
1192        if fn not in fns_3_args and fn not in fns_value_kwarg:
1193            _test_in_place_broadcastable(small, large_expanded)
1194            _test_in_place_broadcastable(small, large)
1195        else:
1196            _test_in_place_broadcastable(small2, small_expanded, large_expanded)
1197            _test_in_place_broadcastable(small2, small, large)
1198
1199    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
1200    @onlyCUDA
1201    @wrapDeterministicFlagAPITest
1202    def test_cublas_config_nondeterministic_alert(self, device):
1203        test_cases = [
1204            # (function, (tensor sizes))
1205            ('mm', ((2, 2), (2, 2),)),
1206            ('mv', ((2, 2), (2,),)),
1207            ('bmm', ((1, 2, 2), (1, 2, 2),))]
1208
1209        test_configs = [
1210            # (CuBLAS workspace config, is deterministic)
1211            ('garbage', False),
1212            (None, False),
1213            (':4096:8', True),
1214            (':16:8', True)]
1215
1216        cublas_var_name = 'CUBLAS_WORKSPACE_CONFIG'
1217        is_cuda10_2_or_higher = (
1218            (torch.version.cuda is not None)
1219            and ([int(x) for x in torch.version.cuda.split(".")] >= [10, 2]))
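        # (lexicographic list comparison, e.g. '11.8' parses to [11, 8] >= [10, 2])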
1220
1221        def test_case_info(fn_name, config):
1222            return f'function "{fn_name}" with config "{"" if config is None else config}"'
1223
1224        # Create processes to test each combination of test cases and config settings
1225        processes = []
1226        for fn_name, arg_sizes in test_cases:
1227            for config, is_config_deterministic in test_configs:
1228                env = os.environ.copy()
1229                if config is None:
1230                    if env.get(cublas_var_name) is not None:
1231                        del env[cublas_var_name]
1232                else:
1233                    env[cublas_var_name] = config
1234                should_throw_error = is_cuda10_2_or_higher and not is_config_deterministic
1235                script = f"""
1236import torch
1237torch.use_deterministic_algorithms(True)
1238fn = torch.{fn_name}
1239arg_sizes = {arg_sizes}
1240device = '{device}'
1241should_throw_error = {should_throw_error}
1242args = []
1243for arg_size in arg_sizes:
1244    args.append(torch.randn(*arg_size, device=device))
1245try:
1246    fn(*args)
1247except RuntimeError as e:
1248    if not should_throw_error:
1249        raise RuntimeError('Did not expect any error to be raised')
1250    elif 'Deterministic behavior was enabled with either' not in str(e):
1251        raise RuntimeError('Expected a CuBLAS nondeterministic error, but got a different error')
1252else:
1253    if should_throw_error:
1254        raise RuntimeError('Expected a CuBLAS nondeterministic error, but it was not raised')
1255
1256"""
1257                try:
1258                    subprocess.check_output(
1259                        [sys.executable, '-c', script],
1260                        stderr=subprocess.STDOUT,
1261                        # On Windows, opening the subprocess with the default CWD makes `import torch`
1262                        # fail, so just set CWD to this script's directory
1263                        cwd=os.path.dirname(os.path.realpath(__file__)),
1264                        env=env)
1265                except subprocess.CalledProcessError as e:
1266                    self.fail(msg=(
1267                        f'Subprocess exception while attempting to run {test_case_info(fn_name, config)}:\n'
1268                        + e.output.decode("utf-8")))
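
        # A minimal sketch of the cuBLAS behavior the subprocess script above exercises,
        # assuming a CUDA >= 10.2 build (the variable must be set before the process makes
        # its first cuBLAS call):
        #
        #   os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'   # or ':16:8'
        #   torch.use_deterministic_algorithms(True)
        #   torch.mm(torch.randn(2, 2, device='cuda'),
        #            torch.randn(2, 2, device='cuda'))          # allowed to run
        #
        # With the variable unset or set to any other value, the same call is expected to
        # raise "Deterministic behavior was enabled with either ...".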
1269
1270    @onlyCPU
1271    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1272    @dtypes(*get_all_qint_dtypes())
1273    def test_nondeterministic_resize_quantized(self, device, dtype):
1274        a = torch.tensor([-1, 0, 1, 2, 3], dtype=torch.float, device=device)
1275        b = torch.quantize_per_tensor(a, 0.1, 10, dtype)
1276        self.check_nondeterministic_alert(
1277            lambda: b.resize_((10,)),
1278            'quantized_resize_cpu_')
1279
1280    @skipXLA
1281    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1282    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64))
1283    def test_deterministic_resize(self, device, dtype):
1284        test_cases = [
1285            # size, stride, resize_size
1286            ((10,), (1,), (5,)),
1287            ((10,), (0,), (10,)),
1288            ((10,), (1,), (20,)),
1289            ((2, 3, 4), None, (2, 3, 4)),
1290            ((2, 3, 4), None, (6, 3, 4)),
1291            ((2, 3, 4), None, (2, 5, 4)),
1292            ((2, 3, 4), None, (2, 3, 6)),
1293            ((2, 3, 4), None, (3, 4, 5)),
1294            ((2, 3, 4), (1, 4, 12), (2, 3, 4)),
1295            ((2, 3, 4), (1, 4, 12), (4, 3, 4)),
1296            ((2, 3, 4), (1, 4, 12), (2, 4, 4)),
1297            ((2, 3, 4), (1, 4, 12), (2, 3, 5)),
1298            ((2, 3, 4), (1, 4, 12), (3, 4, 5)),
1299            ((2, 3, 4), (1, 0, 1), (2, 4, 5)),
1300        ]
1301
1302        for size, stride, resize_size in test_cases:
1303            if stride is None:
1304                a = torch.zeros(size, dtype=dtype, device=device)
1305            else:
1306                a = torch.empty_strided(size, stride, dtype=dtype, device=device).fill_(0)
1307            old_storage = a.untyped_storage().clone()
1308            with DeterministicGuard(True, fill_uninitialized_memory=True):
1309                a.resize_(resize_size)
1310
1311            new_storage = a.untyped_storage()
1312
1313            # If storage size was increased, check that the new section is
1314            # filled with NaN/MAX_INT. Otherwise, check that the storages are
1315            # equal.
1316            old_tensor = torch.tensor(old_storage, dtype=dtype)
1317            old_numel = old_tensor.numel()
1318            new_tensor = torch.tensor(new_storage, dtype=dtype)
1319            new_numel = new_tensor.numel()
1320
1321            if new_numel > old_numel:
1322                self.assertEqual(new_tensor[:old_numel], old_tensor)
1323                fill_section = new_tensor[old_numel:]
1324
1325                if dtype.is_floating_point or dtype.is_complex:
1326                    self.assertTrue(fill_section.isnan().all())
1327                else:
1328                    if dtype == torch.bool:
1329                        max_val = True
1330                    else:
1331                        max_val = torch.iinfo(dtype).max
1332                    self.assertTrue(fill_section.eq(max_val).all())
1333            else:
1334                self.assertEqual(old_tensor, new_tensor)
1335
1336    # When deterministic algorithms are enabled, `torch.empty` should fill floating
1337    # point tensors with NaN and integer tensors with MAX_INT
1338    @skipXLA
1339    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1340    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64))
1341    def test_deterministic_empty(self, device, dtype):
1342        gen_fns = [
1343            lambda: torch.empty(10, 9, device=device, dtype=dtype),
1344            lambda: torch.empty(10, 9, out=torch.zeros(1, device=device, dtype=dtype)),
1345            lambda: torch.empty_like(torch.zeros(10, 9, device=device, dtype=dtype)),
1346            lambda: torch.empty_like(torch.zeros(10, 9, device=device, dtype=dtype), memory_format=torch.contiguous_format),
1347            lambda: torch.empty_strided((10, 9), (1, 5), device=device, dtype=dtype),
1348            lambda: torch.empty_permuted((2, 3, 5), (1, 0, 2), device=device, dtype=dtype),
1349        ]
1350
1351        for gen_fn in gen_fns:
1352            with DeterministicGuard(True, fill_uninitialized_memory=True):
1353                res = gen_fn()
1354
1355            if dtype.is_floating_point or dtype.is_complex:
1356                self.assertTrue(res.isnan().all())
1357            else:
1358                if dtype == torch.bool:
1359                    max_val = True
1360                else:
1361                    max_val = torch.iinfo(dtype).max
1362                self.assertTrue(res.eq(max_val).all())
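
        # A minimal sketch of the fill-on-allocation behavior asserted above (normally
        # torch.empty leaves memory uninitialized):
        #
        #   with DeterministicGuard(True, fill_uninitialized_memory=True):
        #       f = torch.empty(3)                       # floating point -> all NaN
        #       i = torch.empty(3, dtype=torch.int32)    # integer -> all torch.iinfo(torch.int32).max
        #       b = torch.empty(3, dtype=torch.bool)     # bool -> all True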
1363
1364    # FIXME: update OpInfos to support "nondeterministic samples" and port these tests
1365    #   to that architecture
1366    @skipIfMps
1367    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1368    def test_nondeterministic_alert_AvgPool3d(self, device):
1369        module = torch.nn.AvgPool3d(3)
1370        input = torch.randn(2, 3, 3, 3, requires_grad=True, device=device)
1371        res = module(input)
1372        grad = torch.ones_like(res)
1373
1374        self.check_nondeterministic_alert(
1375            lambda: res.backward(grad, retain_graph=True),
1376            'avg_pool3d_backward_cuda',
1377            torch.device(device).type == 'cuda')
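
        # The alert tests below share one pattern: a lambda that triggers the op (usually
        # its CUDA backward), the name of the kernel expected in the alert message, and an
        # optional flag saying whether the alert should fire at all (typically only when
        # the device is CUDA). Roughly, with deterministic algorithms enabled, an op that
        # lacks a deterministic implementation raises:
        #
        #   torch.use_deterministic_algorithms(True)
        #   res.backward(grad)   # RuntimeError: avg_pool3d_backward_cuda does not have a
        #                        # deterministic implementation ...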
1378
1379    @skipIfMps
1380    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1381    def test_nondeterministic_alert_AdaptiveAvgPool2d(self, device):
1382        module = torch.nn.AdaptiveAvgPool2d(3)
1383        input = torch.randn(2, 3, 3, requires_grad=True, device=device)
1384        res = module(input)
1385        grad = torch.ones_like(res)
1386
1387        self.check_nondeterministic_alert(
1388            lambda: res.backward(grad, retain_graph=True),
1389            'adaptive_avg_pool2d_backward_cuda',
1390            torch.device(device).type == 'cuda')
1391
1392    @skipIfMps
1393    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1394    def test_nondeterministic_alert_AdaptiveAvgPool3d(self, device):
1395        module = torch.nn.AdaptiveAvgPool3d(3)
1396        input = torch.randn(2, 3, 3, 3, requires_grad=True, device=device)
1397        res = module(input)
1398        grad = torch.ones_like(res)
1399
1400        self.check_nondeterministic_alert(
1401            lambda: res.backward(grad, retain_graph=True),
1402            'adaptive_avg_pool3d_backward_cuda',
1403            torch.device(device).type == 'cuda')
1404
1405    @skipIfMps
1406    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1407    def test_nondeterministic_alert_MaxPool3d(self, device):
1408        module = torch.nn.MaxPool3d(3)
1409        input = torch.randn(2, 3, 3, 3, requires_grad=True, device=device)
1410        res = module(input)
1411        grad = torch.ones_like(res)
1412
1413        self.check_nondeterministic_alert(
1414            lambda: res.backward(grad, retain_graph=True),
1415            'max_pool3d_with_indices_backward_cuda',
1416            torch.device(device).type == 'cuda')
1417
1418    @skipIfMps
1419    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1420    def test_nondeterministic_alert_AdaptiveMaxPool2d(self, device):
1421        module = torch.nn.AdaptiveMaxPool2d(3)
1422        input = torch.randn(2, 3, 3, requires_grad=True, device=device)
1423        res = module(input)
1424        grad = torch.ones_like(res)
1425
1426        self.check_nondeterministic_alert(
1427            lambda: res.backward(grad, retain_graph=True),
1428            'adaptive_max_pool2d_backward_cuda',
1429            torch.device(device).type == 'cuda')
1430
1431    @skipIfMps
1432    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1433    def test_nondeterministic_alert_FractionalMaxPool2d(self, device):
1434        module = torch.nn.FractionalMaxPool2d(2, output_ratio=0.5)
1435        input = torch.randn(2, 3, 3, 3, requires_grad=True, device=device)
1436        res = module(input)
1437        grad = torch.ones_like(res)
1438
1439        self.check_nondeterministic_alert(
1440            lambda: res.backward(grad, retain_graph=True),
1441            'fractional_max_pool2d_backward_cuda',
1442            torch.device(device).type == 'cuda')
1443
1444    @skipIfMps
1445    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1446    def test_nondeterministic_alert_FractionalMaxPool3d(self, device):
1447        module = torch.nn.FractionalMaxPool3d(2, output_ratio=0.5)
1448        input = torch.randn(2, 3, 3, 3, 3, requires_grad=True, device=device)
1449        res = module(input)
1450        grad = torch.ones_like(res)
1451
1452        self.check_nondeterministic_alert(
1453            lambda: res.backward(grad, retain_graph=True),
1454            'fractional_max_pool3d_backward_cuda',
1455            torch.device(device).type == 'cuda')
1456
1457    @dtypes(*floating_types_and(torch.half))
1458    @onlyNativeDeviceTypes
1459    def test_nondeterministic_alert_MaxUnpool1d(self, device, dtype):
1460        if dtype == torch.half and torch.device(device).type == 'cpu':
1461            self.skipTest('float16 not implemented on CPU')
1462
1463        module = torch.nn.MaxUnpool1d(3, 1)
1464        input = torch.randn(1, 1, 7, dtype=dtype, device=device)
1465        indices = torch.zeros_like(input, dtype=torch.long, device=device)
1466
1467        self.check_nondeterministic_alert(
1468            lambda: module(input, indices),
1469            'max_unpooling2d_forward_out')
1470
1471    @dtypes(*floating_types_and(torch.half))
1472    @onlyNativeDeviceTypes
1473    def test_nondeterministic_alert_MaxUnpool2d(self, device, dtype):
1474        if dtype == torch.half and torch.device(device).type == 'cpu':
1475            self.skipTest('float16 not implemented on CPU')
1476
1477        module = torch.nn.MaxUnpool2d(3, 1)
1478        input = torch.randn(1, 1, 7, 7, dtype=dtype, device=device)
1479        indices = torch.zeros_like(input, dtype=torch.long, device=device)
1480
1481        self.check_nondeterministic_alert(
1482            lambda: module(input, indices),
1483            'max_unpooling2d_forward_out')
1484
1485    @dtypes(*floating_types_and(torch.half))
1486    @onlyNativeDeviceTypes
1487    def test_nondeterministic_alert_MaxUnpool3d(self, device, dtype):
1488        if dtype == torch.half and torch.device(device).type == 'cpu':
1489            self.skipTest('float16 not implemented on CPU')
1490
1491        module = torch.nn.MaxUnpool3d(3, 1)
1492        input = torch.randn(1, 1, 7, 7, 7, dtype=dtype, device=device)
1493        indices = torch.zeros_like(input, dtype=torch.long, device=device)
1494
1495        self.check_nondeterministic_alert(
1496            lambda: module(input, indices),
1497            'max_unpooling3d_forward_out')
1498
1499    @skipIfMps
1500    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1501    def test_nondeterministic_alert_interpolate_linear(self, device):
1502        input = torch.randn(1, 2, 4, device=device, requires_grad=True)
1503        res = torch.nn.functional.interpolate(
1504            input,
1505            size=12,
1506            mode='linear',
1507            align_corners=False)
1508        grad = torch.ones_like(res)
1509
1510        self.check_nondeterministic_alert(
1511            lambda: res.backward(grad),
1512            'upsample_linear1d_backward_out_cuda',
1513            torch.device(device).type == 'cuda')
1514
1515    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1516    def test_nondeterministic_alert_interpolate_bilinear(self, device):
1517        input = torch.randn(1, 2, 4, 4, device=device, requires_grad=True)
1518        res = torch.nn.functional.interpolate(
1519            input,
1520            size=12,
1521            mode='bilinear',
1522            align_corners=False)
1523        grad = torch.ones_like(res)
1524
1525        self.check_nondeterministic_alert(
1526            lambda: res.backward(grad),
1527            'upsample_bilinear2d_backward_out_cuda',
1528            torch.device(device).type == 'cuda')
1529
1530    @skipIfTorchInductor("aot-autograd issue")
1531    def test_deterministic_replication_pad2d(self, device):
1532        test_cases = [
1533            # size, padding
1534            [(1, 2, 4, 4), (0, 0, 0, 0)],
1535            [(1, 2, 4, 4), (3, 4, 5, 6)],
1536            [(3, 8, 7), (0, 0, 0, 0)],
1537            [(3, 8, 7), (4, 3, 2, 7)],
1538        ]
1539
1540        if torch.device(device).type != 'xla':
1541            test_cases += [
1542                [(4, 3, 5, 10), (-9, 4, 5, 6)],
1543                [(3, 8, 7), (-4, -2, -2, -3)],
1544            ]
1545
1546        for size, padding in test_cases:
1547            input = torch.randn(*size, device=device, requires_grad=True)
1548            grad = None
1549            with DeterministicGuard(True):
1550                res = torch.nn.functional.pad(
1551                    input,
1552                    padding,
1553                    mode='replicate')
1554                res.backward(torch.ones_like(res))
1555                if grad is None:
1556                    grad = input.grad
1557                else:
1558                    self.assertEqual(grad, input.grad, atol=0, rtol=0)
1559                input.grad = None
1560
1561    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1562    def test_deterministic_interpolate_bilinear(self, device):
1563        input = torch.randn(1, 2, 4, 4, device=device, requires_grad=True)
1564        grad = None
1565        with DeterministicGuard(True):
1566            for _ in range(5):
1567                res = torch.nn.functional.interpolate(
1568                    input,
1569                    size=12,
1570                    mode='bilinear',
1571                    align_corners=False)
1572                res.backward(torch.ones_like(res))
1573                if grad is None:
1574                    grad = input.grad
1575                else:
1576                    self.assertEqual(grad, input.grad, atol=0, rtol=0)
1577                input.grad = None
1578
1579    @skipIfMps
1580    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1581    def test_nondeterministic_alert_interpolate_bicubic(self, device):
1582        input = torch.randn(1, 2, 4, 4, device=device, requires_grad=True)
1583        res = torch.nn.functional.interpolate(
1584            input,
1585            size=12,
1586            mode='bicubic',
1587            align_corners=False)
1588        grad = torch.ones_like(res)
1589
1590        self.check_nondeterministic_alert(
1591            lambda: res.backward(grad),
1592            'upsample_bicubic2d_backward_out_cuda',
1593            torch.device(device).type == 'cuda')
1594
1595    @skipIfMps
1596    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1597    def test_nondeterministic_alert_interpolate_trilinear(self, device):
1598        input = torch.randn(1, 2, 4, 4, 4, device=device, requires_grad=True)
1599        res = torch.nn.functional.interpolate(
1600            input,
1601            size=12,
1602            mode='trilinear',
1603            align_corners=False)
1604        grad = torch.ones_like(res)
1605
1606        self.check_nondeterministic_alert(
1607            lambda: res.backward(grad),
1608            'upsample_trilinear3d_backward_out_cuda',
1609            torch.device(device).type == 'cuda')
1610
1611    @skipIfMps
1612    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1613    def test_nondeterministic_alert_ReflectionPad1d(self, device):
1614        module = torch.nn.ReflectionPad1d((1, 2))
1615        input = torch.randn(2, 3, 8, device=device, requires_grad=True)
1616        res = module(input)
1617        grad = torch.ones_like(res)
1618
1619        self.check_nondeterministic_alert(
1620            lambda: res.backward(grad, retain_graph=True),
1621            'reflection_pad1d_backward_out_cuda',
1622            torch.device(device).type == 'cuda')
1623
1624    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1625    def test_nondeterministic_alert_ReflectionPad2d(self, device):
1626        module = torch.nn.ReflectionPad2d((1, 2, 3, 4))
1627        input = torch.randn(2, 3, 8, 8, device=device, requires_grad=True)
1628        res = module(input)
1629        grad = torch.ones_like(res)
1630
1631        self.check_nondeterministic_alert(
1632            lambda: res.backward(grad, retain_graph=True),
1633            'reflection_pad2d_backward_cuda',
1634            torch.device(device).type == 'cuda')
1635
1636    @skipIfMps
1637    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1638    def test_nondeterministic_alert_ReflectionPad3d(self, device):
1639        module = torch.nn.ReflectionPad3d((1, 2, 3, 4, 5, 6))
1640        input = torch.randn(2, 3, 8, 8, 8, device=device, requires_grad=True)
1641        res = module(input)
1642        grad = torch.ones_like(res)
1643
1644        self.check_nondeterministic_alert(
1645            lambda: res.backward(grad, retain_graph=True),
1646            'reflection_pad3d_backward_out_cuda',
1647            torch.device(device).type == 'cuda')
1648
1649    @skipIfMps
1650    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1651    def test_nondeterministic_alert_ReplicationPad1d(self, device):
1652        module = torch.nn.ReplicationPad1d((1, 2))
1653        input = torch.randn(2, 3, 4, device=device, requires_grad=True)
1654        res = module(input)
1655        grad = torch.ones_like(res)
1656
1657        self.check_nondeterministic_alert(
1658            lambda: res.backward(grad, retain_graph=True),
1659            'replication_pad1d_backward_cuda',
1660            torch.device(device).type == 'cuda')
1661
1662    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1663    def test_nondeterministic_alert_ReplicationPad2d(self, device):
1664        module = torch.nn.ReplicationPad2d((1, 2, 3, 4))
1665        input = torch.randn(2, 3, 4, 4, device=device, requires_grad=True)
1666        res = module(input)
1667        grad = torch.ones_like(res)
1668
1669        # Nondeterministic alert should only be raised if the forward call was
1670        # nondeterministic
1671        self.check_nondeterministic_alert(
1672            lambda: res.backward(grad, retain_graph=True),
1673            'replication_pad2d_backward_cuda',
1674            torch.device(device).type == 'cuda')
1675
1676        with DeterministicGuard(True):
1677            res = module(input)
1678
1679        grad = torch.ones_like(res)
1680
1681        # If the forward call was deterministic, nondeterministic alert should
1682        # not be raised
1683        self.check_nondeterministic_alert(
1684            lambda: res.backward(grad, retain_graph=True),
1685            'replication_pad2d_backward_cuda',
1686            False)
1687
1688    @skipIfMps
1689    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1690    def test_nondeterministic_alert_ReplicationPad3d(self, device):
1691        module = torch.nn.ReplicationPad3d((1, 2, 3, 4, 5, 6))
1692        input = torch.randn(2, 3, 4, 4, 4, device=device, requires_grad=True)
1693        res = module(input)
1694        grad = torch.ones_like(res)
1695
1696        self.check_nondeterministic_alert(
1697            lambda: res.backward(grad, retain_graph=True),
1698            'replication_pad3d_backward_cuda',
1699            torch.device(device).type == 'cuda')
1700
1701    @skipIfTorchDynamo("Warning is not raised.")
1702    def test_nondeterministic_alert_NLLLoss(self, device):
1703        module = torch.nn.NLLLoss()
1704        input = torch.randn(2, 3, 5, 5, device=device)
1705        target = torch.rand(2, 5, 5, device=device).mul(3).floor().long()
1706
1708        self.check_nondeterministic_alert(
1709            lambda: module(input, target),
1710            'nll_loss2d_forward_out_cuda_template',
1711            torch.device(device).type == 'cuda')
1712
1713    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1714    def test_nondeterministic_alert_CTCLoss(self, device):
1715        module = torch.nn.CTCLoss()
1716        input = torch.randn(50, 3, 15, device=device, requires_grad=True)
1717        target = torch.randint(0, 14, (3, 30), device=device)
1718        input_lengths = [50, 50, 50]
1719        target_lengths = [30, 25, 20]
1720        res = module(input, target, input_lengths, target_lengths)
1721        grad = torch.ones_like(res)
1722
1723        self.check_nondeterministic_alert(
1724            lambda: res.backward(grad, retain_graph=True),
1725            'ctc_loss_backward_gpu',
1726            torch.device(device).type == 'cuda')
1727
1728    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1729    def test_nondeterministic_alert_EmbeddingBag_max(self, device):
1730        module = torch.nn.EmbeddingBag(
1731            4, 3, None, 2., False, 'max',
1732            _weight=torch.randn(4, 3, device=device, requires_grad=True))
1733        input = torch.randint(0, 3, (4, 3), device=device)
1734        res = module(input)
1735        grad = torch.ones_like(res)
1736
1737        self.check_nondeterministic_alert(
1738            lambda: res.backward(grad, retain_graph=True),
1739            'embedding_bag_backward_cuda_max',
1740            torch.device(device).type == 'cuda')
1741
1742    @dtypes(*all_types_and_complex_and(torch.bool))
1743    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1744    def test_nondeterministic_alert_cumsum(self, device, dtype):
1745        input = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9)
1746        should_alert = torch.device(device).type == 'cuda' and (dtype.is_floating_point or dtype.is_complex)
1747
1748        for op_call in [torch.Tensor.cumsum, torch.cumsum]:
1749            self.check_nondeterministic_alert(
1750                lambda: op_call(input, 0),
1751                'cumsum_cuda_kernel',
1752                should_alert)
1753
1754    @expectedFailureMeta  # expected a non-deterministic error, but it was not raised
1755    @onlyNativeDeviceTypes
1756    def test_nondeterministic_alert_put(self, device):
1757        a = torch.randn(10, device=device)
1758        indices = torch.tensor([0, 0], device=device)
1759        values = torch.tensor([0., 1.], device=device)
1760
1761        for op_call in [torch.Tensor.put, torch.Tensor.put_]:
1762            self.check_nondeterministic_alert(
1763                lambda: op_call(a, indices, values, accumulate=False),
1764                'put_')
1765
1766    # warn_only=False correctly raises RuntimeError: put_ does not have a deterministic implementation
1767    # warn_only=True logs a warning from the FallbackKernel (torch.ops.aten.put_.default) instead of the expected UserWarning:
1768    # [W Context.cpp:%(lineno)] Warning: put_ does not have a deterministic implementation
1769    @skipIfTorchInductor("warning is logged from the FallbackKernel: torch.ops.aten.put_.default when warn_only=True")
1770    def test_nondeterministic_alert_put_accumulate(self, device):
1771        a = torch.randn(10, device=device)
1772        indices = torch.tensor([0, 0], device=device)
1773        values = torch.tensor([0., 1.], device=device)
1774
1775        for op_call in [torch.Tensor.put, torch.Tensor.put_]:
1776            self.check_nondeterministic_alert(
1777                lambda: op_call(a, indices, values, accumulate=True),
1778                'put_',
1779                torch.device(device).type == 'cuda')
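
        # A minimal sketch of the warn_only distinction described in the comment above:
        #
        #   torch.use_deterministic_algorithms(True)                  # put_ with accumulate=True raises on CUDA
        #   torch.use_deterministic_algorithms(True, warn_only=True)  # the same call only warns
        #
        # check_nondeterministic_alert presumably exercises both modes for the callables
        # it is given.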
1780
1781    @skipIfMps
1782    def test_nondeterministic_alert_histc(self, device):
1783        a = torch.tensor([], device=device)
1784        for op_call in [torch.histc, torch.Tensor.histc]:
1785            self.check_nondeterministic_alert(
1786                lambda: op_call(a, min=0, max=3),
1787                '_histc_cuda',
1788                torch.device(device).type == 'cuda')
1789
1790    @skipIfMps
1791    def test_nondeterministic_alert_bincount(self, device):
1792        a = torch.tensor([], device=device, dtype=torch.long)
1793        weights = torch.tensor([], device=device)
1794
1795        for op_call in [torch.bincount, torch.Tensor.bincount]:
1796            # Error should only be raised when device is CUDA and weights are
1797            # given
1798            self.check_nondeterministic_alert(
1799                lambda: op_call(a, weights),
1800                '_bincount_cuda',
1801                torch.device(device).type == 'cuda')
1802
1803            self.check_nondeterministic_alert(
1804                lambda: op_call(a),
1805                '_bincount_cuda',
1806                False)
1807
1808    # Ensures that kthvalue throws nondeterministic alerts in the correct cases
1809    @dtypes(torch.double)
1810    def test_nondeterministic_alert_kthvalue(self, device, dtype):
1811        def test_func(call_type):
1812            S = 10
1813            k = 5
1814            a = torch.randn(S, device=device)
1815            if call_type == 'function':
1816                torch.kthvalue(a, k)
1817            elif call_type == 'method':
1818                a.kthvalue(k)
1819            elif call_type == 'out':
1820                values = torch.empty_like(a)
1821                indices = torch.empty((), device=device, dtype=torch.long)
1822                torch.kthvalue(a, k, out=(values, indices))
1823            else:
1824                self.fail(f"'{call_type}' is not a valid call type")
1825
1826        for call_type in ['function', 'method', 'out']:
1827            self.check_nondeterministic_alert(
1828                lambda: test_func(call_type),
1829                'kthvalue CUDA',
1830                torch.device(device).type == 'cuda')
1831
1832    @skipIfMps
1833    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1834    def test_nondeterministic_alert_grid_sample_2d(self, device):
1835        input = torch.empty(1, 1, 2, 2, device=device, requires_grad=True)
1836        grid = torch.empty(1, 1, 1, 2, device=device)
1837        res = torch.nn.functional.grid_sample(input, grid, align_corners=False)
1838        grad = torch.ones_like(res)
1839
1840        self.check_nondeterministic_alert(
1841            lambda: res.backward(grad, retain_graph=True),
1842            'grid_sampler_2d_backward_cuda',
1843            torch.device(device).type == 'cuda')
1844
1845    @skipIfMps
1846    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1847    def test_nondeterministic_alert_grid_sample_3d(self, device):
1848        input = torch.empty(1, 1, 2, 2, 2, device=device, requires_grad=True)
1849        grid = torch.empty(1, 1, 1, 2, 3, device=device)
1850        res = torch.nn.functional.grid_sample(input, grid, align_corners=False)
1851        grad = torch.ones_like(res)
1852
1853        self.check_nondeterministic_alert(
1854            lambda: res.backward(grad, retain_graph=True),
1855            'grid_sampler_3d_backward_cuda',
1856            torch.device(device).type == 'cuda')
1857
1858    def test_invalid_shapes_grid_sampler(self, device):
1859        make_arg = partial(
1860            make_tensor, device=device, dtype=torch.float64, requires_grad=True)
1861
1862        inputs = (
1863            # input, grid
1864            ((5, 5, 5, 5, 5,), (1, 1, 1, 4, 4,)),  # 3d
1865            ((5, 5, 5, 5,), (1, 1, 4, 4,)),  # 2d
1866        )
1867
1868        interpolation_mode = 0
1869        padding_mode = 0
1870        align_corners = True
1871
1872        err = "expected grid and input to have same batch size"
1873
1874        for input, grid in inputs:
1875            input = make_arg(input)
1876            grid = make_arg(grid, low=-1, high=1)
1877
1878            # Wrapper for the 2d, 3d, and cuDNN functions listed below.
1879            with self.assertRaisesRegex(RuntimeError, err):
1880                torch.grid_sampler(
1881                    input, grid, interpolation_mode, padding_mode,
1882                    align_corners)
1883
1884            # Expects 2d input.
1885            with self.assertRaisesRegex(RuntimeError, err):
1886                torch.grid_sampler_2d(
1887                    input, grid, interpolation_mode, padding_mode,
1888                    align_corners)
1889
1890            # Expects 3d input.
1891            with self.assertRaisesRegex(RuntimeError, err):
1892                torch.grid_sampler_3d(
1893                    input, grid, interpolation_mode, padding_mode,
1894                    align_corners)
1895
1896            # Expects 2d input.
1897            with self.assertRaisesRegex(RuntimeError, err):
1898                torch._grid_sampler_2d_cpu_fallback(
1899                    input, grid, interpolation_mode, padding_mode,
1900                    align_corners)
1901
1902            # Expects 2d input, on CUDA.
1903            # Doesn't work on CPU or ROCm.
1904            if device != 'cpu' and TEST_CUDNN and not TEST_WITH_ROCM:
1905                with self.assertRaisesRegex(RuntimeError, err):
1906                    torch.cudnn_grid_sampler(input, grid)
1907
1908    def test_dist(self, device):
1909        def run_test(x, y):
1910            for p in [0, 1, 2, 3, 4, inf, -inf]:
1911                dist_xy = torch.dist(x, y, p)
1912                dist_xy_norm = torch.norm(x - y, p)
1913                self.assertEqual(dist_xy, dist_xy_norm)
1914
1915        run_test(torch.randn(5, device=device), torch.randn(5, device=device))
1916
1917        x = torch.zeros(3, device=device)
1918        y = torch.zeros(3, device=device)
1919        y[1] = 1.
1920        run_test(x, y)
1921
1922    # Ensures that median throws nondeterministic alerts in the correct cases
1923    @dtypes(torch.double)
1924    def test_nondeterministic_alert_median(self, device, dtype):
1925        def test_func(call_type):
1926            S = 10
1927            a = torch.randn(S, device=device)
1928            if call_type == 'function':
1929                torch.median(a)
1930            elif call_type == 'function with indices':
1931                torch.median(a, 0)
1932            elif call_type == 'method':
1933                a.median()
1934            elif call_type == 'method with indices':
1935                a.median(0)
1936            elif call_type == 'out with indices':
1937                result = torch.empty_like(a)
1938                indices = torch.empty((), dtype=torch.long, device=device)
1939                torch.median(a, 0, out=(result, indices))
1940            else:
1941                self.fail(f"'{call_type}' is not a valid call type")
1942
1943        def test_func_expect_error(call_type, should_error):
1944            self.check_nondeterministic_alert(
1945                lambda: test_func(call_type),
1946                'median CUDA with indices output',
1947                should_error)
1948
1949        is_cuda = torch.device(device).type == 'cuda'
1950
1951        test_func_expect_error('function', False)
1952        test_func_expect_error('function with indices', is_cuda)
1953        test_func_expect_error('method', False)
1954        test_func_expect_error('method with indices', is_cuda)
1955        test_func_expect_error('out with indices', is_cuda)
1956
1957    # FIXME: move to test_scatter_gather_ops
1958    def _test_gather_backward_one_dim(self, device, deterministic: bool = False) -> None:
1959        with DeterministicGuard(deterministic):
1960            m = random.randint(2000, 3000)
1961            elems = random.randint(10 * m, 20 * m)
1962            dim = 0
1963            src = torch.randn(m, device=device, requires_grad=True)
1964            idx = torch.randint(m, (elems,), device=device)
1965            res = torch.gather(src, dim, idx)
1966            weight = torch.rand_like(res, device=device) * 10 ** 6
1967            res.backward(weight)
1968            assert src.grad is not None
1969            grad = src.grad.detach().clone()
1970
1971            if torch.device(device).type == 'cuda':
1972                for _ in range(2):
1973                    src.grad.data.zero_()
1974                    res = torch.gather(src, dim, idx)
1975                    res.backward(weight)
1976                    self.assertEqual(src.grad, grad, atol=0, rtol=0)
1977            else:
1978                expected = torch.zeros_like(src, device=device)
1979                for i in range(elems):
1980                    expected[idx[i]] += weight[i]
1981                self.assertEqual(grad, expected, atol=0, rtol=0)
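
        # The helper above checks gather's backward two ways: on CUDA it re-runs the
        # backward pass and requires bitwise-identical gradients, while on CPU it compares
        # against an explicit accumulation reference, effectively
        #
        #   expected[idx[i]] += weight[i]   # for every gathered element i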
1982
1983    # FIXME: move to test_scatter_gather_ops
1984    @onlyNativeDeviceTypes
1985    def test_gather_backward_deterministic_path(self, device) -> None:
1986        self._test_gather_backward_one_dim(device, True)
1987
1988    # FIXME: move to test_scatter_gather_ops
1989    @onlyCPU
1990    def test_gather_backward_one_dim(self, device) -> None:
1991        self._test_gather_backward_one_dim(device, False)
1992
1993    # FIXME: move to test_scatter_gather_ops
1994    @onlyNativeDeviceTypes
1995    def test_scatter_add_one_dim_deterministic(self, device) -> None:
1996        with DeterministicGuard(True):
1997            m = random.randint(20, 30)
1998            elems = random.randint(2000 * m, 3000 * m)
1999            dim = 0
2000            src = torch.randn(elems, device=device)
2001            idx = torch.randint(m, (elems,), device=device)
2002
2003            x = torch.zeros(m, device=device)
2004            res = x.scatter_add(dim, idx, src)
2005
2006            # Checking if scatter_add is deterministic
2007            for i in range(5):
2008                res_next = x.scatter_add(dim, idx, src)
2009                self.assertEqual(res, res_next, atol=0, rtol=0)
2010                res = res_next
2011
2012            expected = torch.zeros(m, device=device)
2013            for i in range(elems):
2014                expected[idx[i]] += src[i]
2015
2016            self.assertEqual(res, expected, atol=1e-4, rtol=1e-5)
2017
2018    # FIXME: move to test_scatter_gather_ops
2019    @onlyNativeDeviceTypes
2020    def test_scatter_zero_size_index(self, device) -> None:
2021        null_index = torch.zeros((0, 4), dtype=torch.int64)
2022        null_arr = torch.zeros((0, 4))
2023        original = torch.arange(4, dtype=torch.float32)
2024        result = original.scatter(0, null_index, null_arr)
2025        self.assertEqual(result, original, atol=0, rtol=0)
2026
2027    @onlyCUDA
2028    @skipIfTorchInductor("FIXME")
2029    def test_sync_warning(self, device):
2030
2031        def _sync_raises_helper(f, level):
2032            with CudaSyncGuard(level):
2033                if level == 1:
2034                    with self.assertWarnsRegex(UserWarning, "called a synchronizing "):
2035                        f()
2036                elif level == 2:
2037                    with self.assertRaisesRegex(RuntimeError, "called a synchronizing "):
2038                        f()
2039
2040        def _no_sync_helper(f, level):
2041            with CudaSyncGuard(level):
2042                f()
2043
2044        def _ind_put_fn(x, ind, val):
2045            x[ind] = val
2046            return x
2047
2048        def _ind_get_fn(x, ind):
2049            return x[ind]
2050
2051        def _cond_fn(x):
2052            if x:  # taking boolean value of a tensor synchronizes
2053                return x
2054            else:
2055                return 2 * x
2056
2057        # prepare inputs for subsequent ops
2058        size = 4
2059        x = torch.rand(size, device=device)
2060        y = torch.rand((), device=device)
2061        ind = torch.randint(size, (3,), device=device)
2062        ind_cpu = ind.cpu()
2063        repeats = torch.full((1,), 2, device=device)
2064        mask = torch.randint(2, (size,), device=device, dtype=bool)
2065        expect_no_sync = (lambda: _ind_put_fn(x, mask, 1.),
2066                          lambda: _ind_put_fn(x, ind, y),
2067                          lambda: _ind_get_fn(x, ind),
2068                          lambda: torch.nn.functional.one_hot(ind, num_classes=size),
2069                          lambda: torch.randperm(20000, device=device),
2070                          lambda: torch.repeat_interleave(x, 2, output_size=2 * size),
2071                          lambda: torch.repeat_interleave(x, repeats, output_size=2 * size),
2072                          lambda: torch.any(y))
2073        expect_sync = (lambda: _ind_put_fn(x, mask, y),
2074                       lambda: _ind_put_fn(x, ind_cpu, y),
2075                       lambda: _ind_get_fn(x, mask),
2076                       lambda: _ind_get_fn(x, ind_cpu),
2077                       lambda: x.nonzero(),
2078                       lambda: _cond_fn(y),
2079                       lambda: torch.nn.functional.one_hot(ind),
2080                       lambda: torch.repeat_interleave(x, repeats))
2081        for f, level in product(expect_no_sync, (1, 2)):
2082            _no_sync_helper(f, level)
2083        for f, level in product(expect_sync, (1, 2)):
2084            _sync_raises_helper(f, level)
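
        # The levels used above line up with torch.cuda.set_sync_debug_mode: 1 turns every
        # device synchronization into a warning, 2 into an error. A minimal sketch:
        #
        #   torch.cuda.set_sync_debug_mode("error")   # same effect as level 2
        #   torch.ones(4, device="cuda").nonzero()    # synchronizes -> RuntimeError
        #   torch.cuda.set_sync_debug_mode("default")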
2085
2086
2087    @dtypes(*floating_types_and(torch.half, torch.bfloat16))
2088    @skipIfMps
2089    def test_log_normal(self, device, dtype):
2090        a = torch.tensor([10], dtype=dtype, device=device).log_normal_()
2091        self.assertEqual(a.dtype, dtype)
2092        self.assertEqual(a.size(), torch.Size([1]))
2093
2094    @dtypes(*all_types_and(torch.half, torch.bfloat16))
2095    @skipIfMps
2096    def test_geometric(self, device, dtype):
2097        a = torch.tensor([10], dtype=dtype, device=device).geometric_(0.5)
2098        self.assertEqual(a.dtype, dtype)
2099        self.assertEqual(a.size(), torch.Size([1]))
2100
2101    @skipIfMps
2102    def test_repeat_interleave(self, device):
2103        y = torch.tensor([[1, 2], [3, 4]], device=device)
2104        # exercise single argument function signature
2105        temp = y.repeat_interleave(2)
2106        self.assertEqual(torch.Size([8]), temp.size())
2107
2108        for dtype in [torch.int, torch.long]:
2109            lengths = torch.tensor([1, 2], dtype=dtype, device=device)
2110            output_size = torch.sum(lengths)
2111            a = torch.repeat_interleave(
2112                y,
2113                lengths,
2114                dim=0,
2115            )
2116            self.assertEqual(a.dtype, y.dtype)
2117            self.assertEqual(a.size(), torch.Size([3, 2]))
2118
2119            a_with_output = torch.repeat_interleave(
2120                y,
2121                lengths,
2122                dim=0,
2123                output_size=output_size,
2124            )
2125            self.assertEqual(a_with_output.dtype, y.dtype)
2126            self.assertEqual(a_with_output.size(), torch.Size([3, 2]))
2127
2128    @dtypes(*floating_types())
2129    @dtypesIfCPU(*floating_types_and(torch.bfloat16, torch.half))
2130    @dtypesIfCUDA(*floating_types_and(torch.half))
2131    def test_bernoulli_p(self, device, dtype):
2132        for trivial_p in ([0, 1], [1, 0, 1, 1, 0, 1]):
2133            x = torch.tensor(trivial_p, dtype=dtype, device=device)
2134            self.assertEqual(x.bernoulli().tolist(), trivial_p)
2135
2136        def isBinary(t):
2137            return torch.ne(t, 0).mul_(torch.ne(t, 1)).sum().item() == 0
2138
2139        p = torch.rand(5, 5, dtype=dtype, device=device)
2140        self.assertTrue(isBinary(p.bernoulli()))
2141
2142        p = torch.rand(5, dtype=dtype, device=device).expand(5, 5)
2143        self.assertTrue(isBinary(p.bernoulli()))
2144
2145        p = torch.rand(5, 5, dtype=dtype, device=device)
2146        torch.bernoulli(torch.rand_like(p), out=p)
2147        self.assertTrue(isBinary(p))
2148
2149    # RngUniform not implemented for Integral type in XLA test
2150    @dtypes(*floating_types())
2151    @dtypesIfCPU(*all_types_and(torch.bool, torch.half))
2152    @dtypesIfCUDA(*all_types_and(torch.bool, torch.half))
2153    def test_bernoulli_self(self, device, dtype):
2154
2155        def isBinary(t):
2156            return torch.ne(t, 0).mul_(torch.ne(t, 1)).sum().item() == 0
2157
2158        t = torch.empty(10, 10, dtype=dtype, device=device)
2159
2160        t.fill_(2)
2161        t.bernoulli_(0.5)
2162        self.assertTrue(isBinary(t))
2163
2164        for p_dtype in floating_types_and(*[torch.half] if device.startswith('cuda') else []):
2165            p = torch.rand(10, dtype=p_dtype, device=device).expand(10, 10)
2166            t.fill_(2)
2167            t.bernoulli_(p)
2168            self.assertTrue(isBinary(t))
2169
2170            t.fill_(2)
2171            torch.bernoulli(torch.rand_like(t, dtype=p_dtype), out=t)
2172            self.assertTrue(isBinary(t))
2173
2174            t.fill_(2)
2175            t.bernoulli_(torch.rand_like(t, dtype=p_dtype))
2176            self.assertTrue(isBinary(t))
2177
2178    @slowTest
2179    @dtypes(*floating_types_and(torch.half))
2180    @dtypesIfCUDA(*floating_types_and(torch.half))
2181    def test_bernoulli_edge_cases(self, device, dtype):
2182        # Need to draw a lot of samples to cover every random floating point number.
2183        a = torch.zeros(10000, 10000, dtype=dtype, device=device)  # probability of drawing "1" is 0
2184        num_ones = (torch.bernoulli(a) == 1).sum()
2185        self.assertEqual(num_ones, 0)
2186
2187        b = torch.ones(10000, 10000, dtype=dtype, device=device)  # probability of drawing "1" is 1
2188        num_zeros = (torch.bernoulli(b) == 0).sum()
2189        self.assertEqual(num_zeros, 0)
2190
2191    @dtypes(*floating_types_and(torch.half, torch.bfloat16))
2192    @skipIfMps
2193    def test_exponential(self, device, dtype):
2194        a = torch.tensor([10], dtype=dtype, device=device).exponential_(0.5)
2195        self.assertEqual(a.dtype, dtype)
2196        self.assertEqual(a.size(), torch.Size([1]))
2197
2198        # Tests extremal behavior
2199        t = torch.empty((1,), device=device, dtype=dtype).exponential_(float('inf'))
2200        self.assertTrue(t.item() == 0)
2201
2202        # Tests that negative lambda fails
2203        with self.assertRaises(RuntimeError):
2204            torch.empty((1,), device=device, dtype=dtype).exponential_(-0.5)
2205
2206    @onlyCUDA
2207    @dtypes(torch.half, torch.float)
2208    def test_exponential_no_zero(self, device, dtype):
2209        # naively, exponential_ could produce 0 with probability about 2^-24, so we
2210        # draw a large number of samples and check that 0 is never generated, rather
2211        # than checking a single draw
2212        # don't test on CPU; at this sample count the test would take too long
2213        x = torch.empty(50000000, device=device, dtype=dtype).exponential_()
2214        self.assertTrue(x.min() > 0)
2215
2216    def _generate_correlation_tensors(self, device, dtype):
2217        yield make_tensor((0, 0), dtype=dtype, device=device)
2218        yield make_tensor((1, 0), dtype=dtype, device=device)
2219        yield make_tensor((0, 1), dtype=dtype, device=device)
2220        yield make_tensor((2,), dtype=dtype, device=device)
2221        yield make_tensor((2, 1), dtype=dtype, device=device)
2222        yield make_tensor((2, 2), dtype=dtype, device=device)
2223        yield make_tensor((2, 3), dtype=dtype, device=device)
2224        yield make_tensor((5, 10), dtype=dtype, device=device)
2225        yield make_tensor((5, 10), dtype=dtype, device=device, noncontiguous=True)
2226        if dtype != torch.int:
2227            yield torch.tensor([0, -2, nan, 10.2, inf], dtype=dtype, device=device)
2228
2229    @onlyNativeDeviceTypes
2230    @dtypes(torch.int, torch.float, torch.cfloat)
2231    def test_corrcoef(self, device, dtype):
2232        for x in self._generate_correlation_tensors(device, dtype):
2233            res = torch.corrcoef(x)
2234            ref = np.corrcoef(x.cpu().numpy())
2235            self.assertEqual(res, ref, exact_dtype=False)
2236
2237    @skipRocmIfTorchInductor
2238    @dtypes(torch.int, torch.float, torch.cfloat)
2239    def test_cov(self, device, dtype):
2240        def check(t, correction=1, fweights=None, aweights=None):
2241            res = torch.cov(t, correction=correction, fweights=fweights, aweights=aweights)
2242            t = t.cpu().numpy()
2243            fweights = fweights.cpu().numpy() if fweights is not None else None
2244            aweights = aweights.cpu().numpy() if aweights is not None else None
2245            ref = np.cov(t, ddof=correction, fweights=fweights, aweights=aweights)
2246            self.assertEqual(res, ref, atol=1e-05, rtol=1e-05, exact_dtype=False)
2247
2248        for x in self._generate_correlation_tensors(device, dtype):
2249            check(x)
2250            num_observations = x.numel() if x.ndim < 2 else x.size(1)
2251            if num_observations > 0:
2252                fweights = torch.randint(1, 10, (num_observations,), device=device)
2253                aweights = make_tensor((num_observations,), dtype=torch.float, device=device, low=1)
2254                for correction, fw, aw in product([0, 1, 2], [None, fweights], [None, aweights]):
2255                    check(x, correction, fw, aw)
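
        # The reference is NumPy: torch.cov's `correction` plays the role of np.cov's
        # `ddof`, and fweights / aweights carry the same meaning in both APIs, e.g.
        #
        #   t = torch.randn(3, 10)
        #   torch.cov(t, correction=0)   # should match np.cov(t.numpy(), ddof=0)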
2256
2257    @skipIfNoSciPy
2258    @dtypes(*floating_types_and(torch.half, torch.bfloat16))
2259    def test_uniform_kstest(self, device, dtype):
2260        from scipy import stats
2261        size = 1000
2262        for from_ in [-42, 0, 4.2]:
2263            for to_ in [-4.2, 0, 42]:
2264                if to_ > from_:
2265                    t = torch.empty(size, dtype=dtype, device=device).uniform_(from_, to_)
2266                    res = stats.kstest(t.cpu().to(torch.double), 'uniform', args=(from_, (to_ - from_)))
2267                    self.assertTrue(res.statistic < 0.1)
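
        # scipy.stats.kstest parametrizes 'uniform' as (loc, scale), so args=(from_, to_ - from_)
        # above describes U(from_, to_); a small KS statistic means the sample is consistent
        # with that distribution:
        #
        #   stats.kstest(samples, 'uniform', args=(loc, scale))   # scale = interval width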
2268
2269    @skipIfNoSciPy
2270    @dtypes(*floating_types_and(torch.half))
2271    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
2272    def test_normal_kstest(self, device, dtype):
2273        from scipy import stats
2274        size = 1000
2275        for mean in [-10, 0, 50]:
2276            for std in [1, 5, 10]:
2277                t = torch.empty(size, dtype=dtype, device=device).normal_(mean=mean, std=std)
2278                res = stats.kstest(t.cpu().to(torch.double), 'norm', args=(mean, std))
2279                self.assertTrue(res.statistic < 0.1)
2280
2281    @skipIfMps
2282    @skipIfNoSciPy
2283    @skipRocmIfTorchInductor
2284    @dtypes(*floating_types_and(torch.half, torch.bfloat16))
2285    def test_lognormal_kstest(self, device, dtype):
2286        from scipy import stats
2287        size = 1000
2288        for mean in [-3, 0, 7]:
2289            for std in [1, 5, 7]:
2290                t = torch.empty(size, dtype=dtype, device=device).log_normal_(mean=mean, std=std)
2291                res = stats.kstest(t.cpu().to(torch.double), 'lognorm', args=(std, 0, math.exp(mean)))
2292                if dtype == torch.half:
2293                    self.assertTrue(res.statistic < 0.3)
2294                else:
2295                    self.assertTrue(res.statistic < 0.1)
2296
2297    @skipIfMps
2298    @skipIfNoSciPy
2299    @dtypes(*floating_types_and(torch.half, torch.bfloat16))
2300    def test_exponential_kstest(self, device, dtype):
2301        from scipy import stats
2302        size = 1000
2303        for lambd in [0.5, 1.0, 5.0]:
2304            t = torch.empty(size, dtype=dtype, device=device).exponential_(lambd=lambd)
2305            res = stats.kstest(t.cpu().to(torch.double), 'expon', args=(0, 1 / lambd,))
2306            self.assertTrue(res.statistic < 0.1)
2307
2308    @skipIfMps
2309    @skipIfNoSciPy
2310    @skipRocmIfTorchInductor
2311    @dtypes(*floating_types_and(torch.half, torch.bfloat16))
2312    def test_cauchy_kstest(self, device, dtype):
2313        from scipy import stats
2314        size = 1000
2315        for median in [-10, 0, 50]:
2316            for sigma in [0.5, 1.0, 10.0]:
2317                t = torch.empty(size, dtype=dtype, device=device).cauchy_(median=median, sigma=sigma)
2318                res = stats.kstest(t.cpu().to(torch.double), 'cauchy', args=(median, sigma))
2319                self.assertTrue(res.statistic < 0.1)
2320
2321    @slowTest
2322    @onlyCUDA
2323    @dtypes(torch.bfloat16, torch.float32)
2324    def test_cauchy_no_inf(self, device, dtype):
2325        # torch.float16 will have `inf` because of its smaller range.
2326        for _ in range((2**16) * 2):
2327            x = torch.empty((2**16), dtype=dtype, device=device)
2328            x.cauchy_()
2329            self.assertFalse(x.isinf().sum())
2330
2331    @dtypes(*floating_types_and(torch.half, torch.bfloat16))
2332    def test_cauchy(self, device, dtype):
2333        a = torch.tensor([10], dtype=dtype, device=device).cauchy_(0.0, 0.5)
2334        self.assertEqual(a.dtype, dtype)
2335        self.assertEqual(a.size(), torch.Size([1]))
2336
2337        # Tests extremal behavior
2338        t = torch.empty((1,), device=device, dtype=dtype).cauchy_(float('inf'), 0.5)
2339        self.assertTrue(t.item() == float('inf'))
2340
2341        # Tests that non-positive sigma fails
2342        with self.assertRaises(RuntimeError):
2343            torch.empty((1,), device=device, dtype=dtype).cauchy_(0.0, 0.0)
2344
2345    @skipIfMps
2346    @skipIfNoSciPy
2347    @skipRocmIfTorchInductor
2348    @dtypes(*all_types_and(torch.half, torch.bfloat16))
2349    def test_geometric_kstest(self, device, dtype):
2350        from scipy import stats
2351        size = 1000
2352        for p in [0.2, 0.5, 0.8]:
2353            t = torch.empty(size, dtype=dtype, device=device).geometric_(p=p)
2354            actual = np.histogram(t.cpu().to(torch.double), np.arange(1, 100))[0]
2355            expected = stats.geom(p).pmf(np.arange(1, 99)) * size
2356            res = stats.chisquare(actual, expected)
2357            self.assertEqual(res.pvalue, 1.0, atol=0.1, rtol=0)
2358
2359    # FIXME: find test suite for pdist and cdist
2360    def test_pairwise_distance_empty(self, device):
2361        shape = (2, 0)
2362        x = torch.randn(shape, device=device)
2363        y = torch.randn(shape, device=device)
2364
2365        self.assertEqual(torch.zeros(2, device=device), torch.pairwise_distance(x, y))
2366        self.assertEqual(torch.zeros((2, 1), device=device), torch.pairwise_distance(x, y, keepdim=True))
2367
2368        shape = (0, 2)
2369        x = torch.randn(shape, device=device)
2370        y = torch.randn(shape, device=device)
2371        self.assertEqual(torch.zeros(0, device=device), torch.pairwise_distance(x, y))
2372        self.assertEqual(torch.zeros((0, 1), device=device), torch.pairwise_distance(x, y, keepdim=True))
2373
2374    def test_pdist_empty(self, device):
2375        shape = (0, 2)
2376        x = torch.randn(shape, device=device)
2377        self.assertEqual(torch.empty(0, device=device), torch.pdist(x))
2378
2379        shape = (1, 2)
2380        x = torch.randn(shape, device=device)
2381        self.assertEqual(torch.empty(0, device=device), torch.pdist(x))
2382
2383        shape = (3, 0)
2384        x = torch.randn(shape, device=device)
2385        self.assertEqual(torch.zeros(3, device=device), torch.pdist(x))
2386
2387    def test_cdist_empty(self, device):
2388        x = torch.randn((0, 5), device=device)
2389        y = torch.randn((4, 5), device=device)
2390        self.assertEqual(torch.empty(0, 4, device=device), torch.cdist(x, y))
2391
2392        x = torch.randn((2, 5), device=device)
2393        y = torch.randn((0, 5), device=device)
2394        self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))
2395
2396        x = torch.randn((2, 0), device=device)
2397        y = torch.randn((3, 0), device=device)
2398        self.assertEqual(torch.zeros(2, 3, device=device), torch.cdist(x, y))
2399
2400        x = torch.randn((2, 0), device=device)
2401        y = torch.randn((0, 0), device=device)
2402        self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))
2403
2404    def _brute_cdist(self, x, y, p=2):
2405        r1 = x.shape[-2]
2406        r2 = y.shape[-2]
2407        if r1 == 0 or r2 == 0:
2408            return torch.empty(r1, r2, device=x.device)
2409        return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)
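
        # Shape sketch for the brute-force reference above:
        #   x: (..., r1, m), y: (..., r2, m)
        #   x[..., None, :] - y[..., None, :, :]  -> (..., r1, r2, m)
        #   torch.norm(..., p=p, dim=-1)          -> (..., r1, r2), matching torch.cdist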
2410
2411    @skipIfMps
2412    def test_cdist_norm(self, device):
2413        for r1 in [3, 4, 5, 6]:
2414            for m in [2, 3, 4, 10]:
2415                for r2 in [4, 6, 7, 8]:
2416                    for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
2417                        x = torch.randn(r1, m, device=device)
2418                        y = torch.randn(r2, m, device=device)
2419                        if p == 2:
2420                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
2421                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
2422                                expected = self._brute_cdist(x, y, p=2)
2423                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
2424                        else:
2425                            actual = torch.cdist(x, y, p=p)
2426                            expected = self._brute_cdist(x, y, p=p)
2427                            self.assertEqual(expected, actual)
2428
2429    @skipIfMps
2430    def test_cdist_norm_batch(self, device):
2431        for r1 in [3, 4, 5, 6]:
2432            for m in [2, 3, 4, 10]:
2433                for r2 in [4, 6, 7, 8]:
2434                    for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
2435                        x = torch.randn(2, 3, 6, r1, m, device=device)
2436                        y = torch.randn(2, 3, 6, r2, m, device=device)
2437                        if p == 2:
2438                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
2439                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
2440                                expected = self._brute_cdist(x, y, p=2)
2441                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
2442                        else:
2443                            actual = torch.cdist(x, y, p=p)
2444                            expected = self._brute_cdist(x, y, p=p)
2445                            self.assertEqual(expected, actual)
2446
2447    @onlyCUDA
2448    def test_cdist_cuda_backward(self, device):
2449        for l1 in [1, 511, 513]:
2450            for l2 in [1, 511, 513]:
2451                for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
2452                    x1 = torch.randn(4, l1, 32, device=device, requires_grad=True)
2453                    x2 = x1.clone().detach_().requires_grad_()
2454                    y1 = torch.randn(4, l2, 32, device=device, requires_grad=True)
2455                    y2 = y1.clone().detach_().requires_grad_()
2456                    if p == 2:
2457                        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
2458                            z1 = torch.cdist(x1, y1, p=2, compute_mode=cm).mean()
2459                            z2 = self._brute_cdist(x2, y2, p=2).mean()
2460                            z1.backward()
2461                            z2.backward()
2462                            self.assertEqual(x1.grad, x2.grad, rtol=0, atol=0.001)
2463                            self.assertEqual(y1.grad, y2.grad, rtol=0, atol=0.001)
2464                    else:
                        z1 = torch.cdist(x1, y1, p=p).mean()
                        z2 = self._brute_cdist(x2, y2, p=p).mean()
                        z1.backward()
                        z2.backward()
                        self.assertEqual(x1.grad, x2.grad, rtol=0, atol=0.001)
                        self.assertEqual(y1.grad, y2.grad, rtol=0, atol=0.001)
2469
2470    @tf32_on_and_off(0.005)
2471    @bf32_on_and_off(0.005)
2472    def test_cdist_large(self, device):
2473        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
2474            x = torch.randn(1000, 10, device=device)
2475            y = torch.randn(1000, 10, device=device)
2476            actual = torch.cdist(x, y, p=2, compute_mode=cm)
2477            expected = self._brute_cdist(x, y, p=2)
2478            self.assertEqual(expected, actual)
2479
2480    @slowTest
2481    @tf32_on_and_off(0.01)
2482    @bf32_on_and_off(0.01)
2483    def test_cdist_large_batch(self, device):
2484        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
2485            x = torch.randn(4, 3, 1000, 10, device=device)
2486            y = torch.randn(4, 3, 1000, 10, device=device)
2487            actual = torch.cdist(x, y, p=2, compute_mode=cm)
2488            expected = self._brute_cdist(x, y, p=2)
2489            self.assertEqual(expected, actual)
2490
2491    @tf32_on_and_off(0.005)
2492    @bf32_on_and_off(0.005)
2493    def test_cdist_non_contiguous(self, device):
2494        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
2495            x = torch.randn(5, 7, device=device).mT
2496            y = torch.randn(5, 3, device=device).mT
2497            actual = torch.cdist(x, y, p=2, compute_mode=cm)
2498            expected = self._brute_cdist(x, y, p=2)
2499            self.assertFalse(x.is_contiguous())
2500            self.assertFalse(y.is_contiguous())
2501            self.assertEqual(expected, actual)
2502
2503            x = torch.randn(7, 5, device=device)
2504            y = torch.randn(5, 3, device=device).t()
2505            actual = torch.cdist(x, y, p=2, compute_mode=cm)
2506            expected = self._brute_cdist(x, y, p=2)
2507            self.assertTrue(x.is_contiguous())
2508            self.assertFalse(y.is_contiguous())
2509            self.assertEqual(expected, actual)
2510
2511            x = torch.randn(5, 7, device=device).t()
2512            y = torch.randn(3, 5, device=device)
2513            actual = torch.cdist(x, y, p=2, compute_mode=cm)
2514            expected = self._brute_cdist(x, y, p=2)
2515            self.assertFalse(x.is_contiguous())
2516            self.assertTrue(y.is_contiguous())
2517            self.assertEqual(expected, actual)
2518
2519    @tf32_on_and_off(0.005)
2520    @bf32_on_and_off(0.005)
2521    def test_cdist_non_contiguous_batch(self, device):
2522        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
2523            x = torch.randn(4, 3, 2, 5, 7, device=device).mT
2524            y = torch.randn(4, 3, 2, 5, 3, device=device).mT
2525            actual = torch.cdist(x, y, p=2, compute_mode=cm)
2526            expected = self._brute_cdist(x, y, p=2)
2527            self.assertFalse(x.is_contiguous())
2528            self.assertFalse(y.is_contiguous())
2529            self.assertEqual(expected, actual)
2530
2531            x = torch.randn(7, 2, 7, 5, device=device)
2532            y = torch.randn(7, 2, 5, 3, device=device).mT
2533            actual = torch.cdist(x, y, p=2, compute_mode=cm)
2534            expected = self._brute_cdist(x, y, p=2)
2535            self.assertTrue(x.is_contiguous())
2536            self.assertFalse(y.is_contiguous())
2537            self.assertEqual(expected, actual)
2538
2539            x = torch.randn(4, 5, 7, device=device).mT
2540            y = torch.randn(4, 3, 5, device=device)
2541            actual = torch.cdist(x, y, p=2, compute_mode=cm)
2542            expected = self._brute_cdist(x, y, p=2)
2543            self.assertFalse(x.is_contiguous())
2544            self.assertTrue(y.is_contiguous())
2545            self.assertEqual(expected, actual)
2546
2547    # Maybe merge into OpInfo?
2548    def test_cdist_euclidean_large(self, device):
2549        def _test_euclidean_large_cdist(sizex, sizey=None):
2550            if sizey is None:
2551                sizey = sizex
2552            x = torch.randn(sizex, device=device, dtype=torch.float)
2553            y = torch.randn(sizey, device=device, dtype=torch.float)
2554            eps = 1e-6
            # nudge x away from y to avoid the non-differentiable point at zero distance
2556            x = x - (((x - y) < eps).float() * 2 * eps)
2557            x.requires_grad = True
2558            y.requires_grad = True
2559            dist = torch.cdist(x, y, p=2)
2560            # Do a backward pass to check that it is valid for large
2561            # matrices
2562            loss = dist.sum()
2563            loss.backward()
2564
2565        _test_euclidean_large_cdist((2000, 5))
2566
2567    # Ensure that cdist backward with p<1 does not produce NaNs
2568    @skipIfMps
2569    def test_cdist_grad_p_lt_1_no_nan(self, device):
2570        for p in [0.99, 0.7, 0.5, 0.1, 0.01]:
2571            x = torch.randn(1, 2, device=device)
2572            y = x.clone().detach() + torch.tensor([[1., 0.]], device=device)
2573            x.requires_grad = True
2574            y.requires_grad = True
2575            result = torch.cdist(x, y, p=p)
2576            result.backward(torch.ones_like(result))
2577            self.assertFalse(torch.isnan(x.grad).any())
2578            self.assertFalse(torch.isnan(y.grad).any())
2579
2580    def test_cdist_same_inputs(self, device):
        # Test to detect issues in the cdist gradient calculation
        # when the distances are 0
2583        sizex = (1, 27, 32)
2584        for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
2585            x = torch.randn(sizex, device=device, dtype=torch.float)
2586            dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
2587            y = x.clone()
2588            eps = 1e-6
2589            x.requires_grad = True
2590            d = torch.cdist(x, y)
2591            d.backward(dist_grad)
            # Check that the backward pass does not produce invalid
            # values such as nan or inf
2594            assert torch.isfinite(x.grad).all()
2595
2596    @skipIfMps
2597    def test_cumsum(self, device):
2598        x = torch.rand(100, 100, device=device)
2599        res1 = torch.cumsum(x, 1)
2600        res2 = torch.tensor([]).to(device)
2601        torch.cumsum(x, 1, out=res2)
2602        self.assertEqual(res1, res2)
2603        x.cumsum_(1)
2604        self.assertEqual(res1, x)
2605
2606        a = torch.tensor([[True, False, True],
2607                          [False, False, False],
2608                          [True, True, True]], device=device)
2609        b = a.byte()
2610        aRes = torch.cumsum(a, 0)
2611        bRes = torch.cumsum(b, 0)
2612        self.assertEqual(aRes, bRes)
2613        self.assertEqual(aRes, torch.tensor([[1, 0, 1],
2614                                             [1, 0, 1],
2615                                             [2, 1, 2]]))
2616
2617        aRes = torch.cumsum(a, 1)
2618        bRes = torch.cumsum(b, 1)
2619        self.assertEqual(aRes, bRes)
2620        self.assertEqual(aRes, torch.tensor([[1, 1, 2],
2621                                             [0, 0, 0],
2622                                             [1, 2, 3]]))
2623
        # Check that cumulative sum over a zero-length dimension doesn't crash on backprop.
        # Also check that cumsum over other dimensions in a tensor with a zero-length
        # dimension also works.
        # Also include a basic suite of similar tests for other base cases.
2628        shapes = [[2, 0], [2, 1, 4], [0, 2, 3], [1], [5]]
2629        for shape in shapes:
2630            for dim in range(len(shape)):
2631                raw_tensor = torch.zeros(*shape, requires_grad=True)
2632                integrated = raw_tensor.cumsum(dim=dim)
2633                # Check that backward does not crash
2634                integrated.sum().backward()
2635                # Check that output maintained correct shape
2636                self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)
2637
2638        # Check a scalar example
2639        raw_tensor = torch.tensor(3., requires_grad=True)
2640        integrated = raw_tensor.cumsum(dim=-1)
2641        self.assertEqual(raw_tensor, integrated)
2642        # Check that backward does not crash
2643        integrated.sum().backward()
2644        # Check that output maintained correct shape
2645        self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)
2646
2647    @skipIfMps
2648    def test_cumprod(self, device):
2649        x = torch.rand(100, 100, device=device)
2650        res1 = torch.cumprod(x, 1)
2651        res2 = torch.tensor([]).to(device)
2652        if not TEST_WITH_TORCHINDUCTOR:
2653            torch.cumprod(x, 1, out=res2)
2654            self.assertEqual(res1, res2)
2655        x.cumprod_(1)
2656        self.assertEqual(res1, x)
2657
2658        a = torch.tensor([[True, False, True],
2659                          [False, False, False],
2660                          [True, True, True]], dtype=torch.bool, device=device)
2661        b = a.byte()
2662        aRes = torch.cumprod(a, 0)
2663        bRes = torch.cumprod(b, 0)
2664        self.assertEqual(aRes, bRes)
2665        self.assertEqual(aRes, torch.tensor([[1, 0, 1],
2666                                             [0, 0, 0],
2667                                             [0, 0, 0]]))
2668
2669        aRes = torch.cumprod(a, 1)
2670        bRes = torch.cumprod(b, 1)
2671        self.assertEqual(aRes, bRes)
2672        self.assertEqual(aRes, torch.tensor([[1, 0, 0],
2673                                             [0, 0, 0],
2674                                             [1, 1, 1]]))
2675
        # Check that cumulative prod over a zero-length dimension doesn't crash on backprop.
        # Also check that cumprod over other dimensions in a tensor with a zero-length
        # dimension also works.
        # Also include a basic suite of similar tests for other base cases.
2680        shapes = [[2, 0], [2, 1, 4], [0, 2, 3], [1], [5]]
2681        for shape in shapes:
2682            for dim in range(len(shape)):
2683                raw_tensor = torch.zeros(*shape, requires_grad=True)
2684                integrated = raw_tensor.cumprod(dim=dim)
2685                # Check that backward does not crash
2686                integrated.sum().backward()
2687                # Check that output maintained correct shape
2688                self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)
2689
2690        # Check a scalar example
2691        raw_tensor = torch.tensor(3., requires_grad=True)
2692        integrated = raw_tensor.cumprod(dim=-1)
2693        self.assertEqual(raw_tensor, integrated)
2694        # Check that backward does not crash
2695        integrated.sum().backward()
2696        # Check that output maintained correct shape
2697        self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)
2698
2699    @skipIfMps
2700    def test_cummax_cummin(self, device):
2701        def test_ops(op, string_of_function_name, expected_output1, expected_output2):
2702            x = torch.rand(100, 100, device=device)
2703            out1 = op(x, 1)
2704            res2 = torch.empty(0, device=device)
2705            indices2 = torch.empty(0, dtype=torch.int64, device=device)
2706            op(x, 1, out=(res2, indices2))
2707            self.assertEqual(out1[0], res2)
2708            self.assertEqual(out1[1], indices2)
2709
2710            a = torch.tensor([[True, False, True],
2711                              [False, False, False],
2712                              [True, True, True]], dtype=torch.bool, device=device)
2713            b = a.byte()
2714            aRes = op(a, 0)
2715            bRes = op(b, 0)
2716            self.assertEqual(aRes[0], bRes[0].bool())
2717            self.assertEqual(aRes[0], expected_output1.bool())
2718
2719            # test inf and nan input
2720            x = torch.tensor([4, inf, 1.5, -inf, 0, nan, 1])
2721            xRes = op(x, 0)[0]
2722            self.assertEqual(xRes, expected_output2)
2723
            # op shouldn't accept values/indices with a dtype, device type, or layout
            # different from that of the input tensor
2726            t = torch.randn(10)
2727            values = torch.empty(0, dtype=torch.int16)
2728            indices = torch.empty(0, dtype=torch.int64)
2729            with self.assertRaisesRegex(
2730                    RuntimeError,
2731                    'expected scalar_type Float but found Short'):
2732                op(t, 0, out=(values, indices))
2733
            # Check that op over a zero-length dimension doesn't crash on backprop.
            # Also check that op over other dimensions in a tensor with a zero-length
            # dimension also works.
            # Also include a basic suite of similar tests for other base cases.
2738            shapes = [[2, 0], [2, 1, 4], [0, 2, 3], [1], [5]]
2739            for shape in shapes:
2740                for dim in range(len(shape)):
2741                    raw_tensor = torch.zeros(*shape, requires_grad=True)
2742                    integrated = getattr(raw_tensor, string_of_function_name)(dim=dim)
2743                    # Check that backward does not crash
2744                    integrated[0].sum().backward()
2745                    # Check that output maintained correct shape
2746                    self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)
2747
2748            # Check a scalar example
2749            raw_tensor = torch.tensor(3., requires_grad=True)
2750            integrated = getattr(raw_tensor, string_of_function_name)(dim=-1)
2751            # Check that backward does not crash
2752            integrated[0].sum().backward()
2753            # Check that output maintained correct shape
2754            self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)
2755
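        # The expected cumulative max/min below reflect that +/-inf, once seen, dominates the
        # running extremum and that nan propagates to every subsequent element.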
2756        expected_out = torch.tensor([4, inf, inf, inf, inf, nan, nan])
2757        test_ops(torch.cummax, "cummax", torch.tensor([[1, 0, 1],
2758                                                       [1, 0, 1],
2759                                                       [1, 1, 1]]), expected_out)
2760
2761        expected_out = torch.tensor([4, 4, 1.5, -inf, -inf, nan, nan])
2762        test_ops(torch.cummin, "cummin", torch.tensor([[1, 0, 1],
2763                                                       [0, 0, 0],
2764                                                       [0, 0, 0]]), expected_out)
2765
2766    @skipIfMps
2767    def test_logcumsumexp(self, device):
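        # Naive reference: log(cumsum(exp(x))). Unlike torch.logcumsumexp it is not
        # numerically stabilized, which is acceptable here because the randn inputs stay small.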
2768        def logcumsumexp(a, axis):
2769            return torch.cumsum(a.exp(), axis=axis).log_()
2770
2771        axis = -1
2772        a = torch.randn(100, 100, device=device)
2773
2774        actual = a.logcumsumexp(axis)
2775        expected = logcumsumexp(a, axis)
2776        self.assertEqual(a.dtype, actual.dtype)
2777        self.assertEqual(expected.shape, actual.shape)
2778        self.assertEqual(expected, actual)
2779
2780        # check -inf and nan handling
2781        x = torch.tensor([-float('inf'), -float('inf'), 1.0, 1.0, float('inf'),
2782                         float('inf'), float('nan'), 1.0, 1.0], device=device)
2783        x2d = x.unsqueeze(0).expand(2, -1)
2784
2785        for inp in (x, x2d):
2786            actual = inp.logcumsumexp(axis)
2787            expected = logcumsumexp(inp, axis)
2788            self.assertEqual(expected, actual)
2789
        # Check that the out= argument actually receives the result in place
2791        b = torch.randn(5, 2, device=device)
2792        inplace_out = torch.zeros(5, 2, device=device)
2793
2794        expected = logcumsumexp(b, axis)
2795        torch.logcumsumexp(b, axis=axis, out=inplace_out)
2796
2797        self.assertEqual(inplace_out, expected)
2798
2799        # Check input and inplace_output type mismatch
2800        b = torch.randn(5, 2, device=device, dtype=torch.float64)
2801        inplace_out = torch.zeros(5, 2, device=device, dtype=torch.float32)
2802        with self.assertRaisesRegex(
2803                RuntimeError,
2804                'expected scalar_type Double but found Float'):
2805            torch.logcumsumexp(b, axis, out=inplace_out)
2806
2807    def _test_diff_numpy(self, t, dims=None):
2808        # Helper for test_diff to compare with NumPy reference implementation
2809        def to_np(t):
2810            if t.dtype == torch.bfloat16:
2811                return t.to(dtype=torch.float, device="cpu").numpy()
2812            else:
2813                return t.cpu().numpy()
2814
2815        for dim in dims if dims else range(t.dim()):
2816            prepend = t.narrow(dim, 0, 1)
2817            append = t.narrow(dim, 0, 1)
2818            np_t = to_np(t)
2819
            # test with no prepend or append
2821            for n in range(t.size(dim)):
2822                actual = torch.diff(t, dim=dim, n=n)
2823                expected = torch.from_numpy(np.diff(np_t, axis=dim, n=n))
2824                self.assertEqual(actual, expected.to(t.dtype))
2825
2826            # test when prepend and append's size along dim is 1
2827            for n in range(1, t.size(dim) + 4):
2828                actual = torch.diff(t, dim=dim, n=n, prepend=prepend, append=append)
2829                expected = torch.from_numpy(np.diff(np_t, axis=dim, n=n, prepend=to_np(prepend), append=to_np(append)))
2830                self.assertEqual(actual, expected.to(t.dtype))
2831
2832            # test when prepend and append's size along dim != 1
2833            for n in range(1, t.size(dim) * 3):
2834                actual = torch.diff(t, dim=dim, n=n, prepend=t, append=t)
2835                expected = torch.from_numpy(np.diff(np_t, axis=dim, n=n, prepend=np_t, append=np_t))
2836                self.assertEqual(actual, expected.to(t.dtype))
2837
2838    # All tensors appear contiguous on XLA
2839    @onlyNativeDeviceTypes
2840    @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
2841    def test_diff_noncontig(self, device, dtype):
2842        shapes = (
2843            (1,),
2844            (1, 5),
2845            (3, 5),
2846            (1, 5, 1),
2847            (2, 3, 5))
2848
2849        for shape in shapes:
2850            contig = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9)
2851
2852            non_contig = torch.empty(shape + (2, 2), device=device, dtype=dtype)[..., 0]
2853            non_contig = non_contig.select(-1, -1)
2854            non_contig.copy_(contig)
2855            self.assertTrue(not non_contig.is_contiguous() or shape == (1,))
2856
2857            self._test_diff_numpy(non_contig)
2858
2859    # RngNormal not implemented for type f16 for XLA
2860    @dtypes(*all_types_and_complex_and(torch.bool))
2861    @dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool))
2862    @dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool))
2863    def test_diff(self, device, dtype):
2864        shapes = (
2865            (1,),
2866            (1, 5),
2867            (3, 5),
2868            (1, 5, 1),
2869            (2, 3, 5))
2870
2871        for shape in shapes:
2872            contig = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9)
2873            self._test_diff_numpy(contig)
2874
2875        t = torch.ones(2, 3)
2876
2877        with self.assertRaisesRegex(
2878                RuntimeError, 'diff expects prepend or append to be the same dimension as input'):
2879            invalid_prepend = torch.tensor([1, 2, 3], device=device, dtype=dtype)
2880            t.diff(dim=0, prepend=invalid_prepend)
2881
2882        with self.assertRaisesRegex(
2883                RuntimeError, 'diff expects the shape of tensor to prepend or append to match that of input'):
2884            invalid_prepend = torch.tensor([[0, 1]], device=device, dtype=dtype)
2885            t.diff(dim=0, prepend=invalid_prepend)
2886
2887        with self.assertRaisesRegex(
2888                RuntimeError, 'diff expects input to be at least one-dimensional'):
2889            scalar = torch.tensor(2, device=device, dtype=dtype)
2890            torch.diff(scalar)
2891
    # If the given input arg is not a list, returns a single-element list: [arg]
2893    def _wrap_to_list(self, input_array):
2894        return input_array if isinstance(input_array, list) else [input_array]
2895
    # Ensures that inf, -inf, and nan values do not cause divergence between NumPy and PyTorch.
    # There are two sources of possible divergence:
    # 1. When a and b are both real numbers with very small absolute values (i.e. very close to 0.0),
    # a/b can come out as inf, -inf, or nan, which causes divergence.
    # 2. When dividing complex numbers by zero. For example, with a = torch.tensor(3+5j),
    # a/0 equals nan + nan*j in PyTorch but inf + inf*j in NumPy.
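    # Both cases are therefore normalized before comparison: every inf/-inf/nan on either side
    # is mapped to nan (see the nan_to_num calls below), and callers compare with equal_nan=True.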
2902    def _inf_nan_preprocess(self, actual, expected):
2903        for i in range(len(expected)):
2904            expected[i] = np.nan_to_num(expected[i], nan=nan, posinf=nan, neginf=nan)
2905            # nan_to_num is not defined for complex tensors in PyTorch.
            if actual[i].dtype == torch.complex64:
2907                actual[i].real = torch.nan_to_num(actual[i].real, nan=nan, posinf=nan, neginf=nan)
2908                actual[i].imag = torch.nan_to_num(actual[i].imag, nan=nan, posinf=nan, neginf=nan)
2909            else:
2910                actual[i] = torch.nan_to_num(actual[i], nan=nan, posinf=nan, neginf=nan)
2911
2912        return actual, expected
2913
2914    @onlyNativeDeviceTypes
2915    @dtypes(torch.long, torch.float32, torch.complex64)
2916    def test_gradient_all(self, device, dtype):
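        # Compares torch.gradient against np.gradient for three spacing forms (a single scalar,
        # a per-dimension list of scalars, and per-dimension coordinate tensors), on contiguous
        # and non-contiguous inputs with edge_order 1 and 2.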
2917        def create_scalar(shape):
2918            return make_tensor((1,), device='cpu', dtype=dtype, low=1.).item()
2919
2920        def create_list(shape):
2921            return make_tensor((len(shape),), device='cpu', dtype=dtype, low=1.).tolist()
2922
2923        def create_coordinate_tensors(shape):
2924            tensor_list = []
2925            for i in range(len(shape)):
2926                tensor_list.append(make_tensor((shape[i],), device=device, dtype=dtype))
2927            return tensor_list
2928
2929        def filter_shape(shape, dim):
2930            filtered_shape = []
2931            for i in range(len(dim)):
2932                filtered_shape.append(shape[dim[i]])
2933            return filtered_shape
2934
2935        # shape, dims format
2936        test_cases = (
2937            ((5,), (0,)),
2938            ((4, 4), (0, 1)),
2939            ((3, 3, 3), (-1, 0)),
2940            ((4, 4, 4), (2,)),
2941            ((4, 4, 4), (0, 1)),
2942            ((4, 4, 4, 3), (0, 2, 3)),
2943            ((4, 5, 3, 4, 3), (1, 2)),
2944            ((4, 3, 6, 5, 3), (2, 4)),
2945            ((4, 3, 3, 5, 3), (0, 1, 2, 3, 4)),
2946            ((1, 3, 3), (1, 2)),
2947            ((1, 5), (1,)),
2948        )
2949
2950        for case, contig, edge_order, space_fn in product(test_cases, [True, False], [1, 2],
2951                                                          (create_scalar, create_list, create_coordinate_tensors)):
2952            shape, dims = case
2953            # filter shape by dims before passing filtered shape to create_* functions
2954            filtered_shape = filter_shape(shape, dims)
2955
2956            spacing = space_fn(filtered_shape)
2957            t = make_tensor(shape, device=device, dtype=dtype, noncontiguous=not contig)
2958            t_np = t.cpu().numpy()
2959
2960            actual = torch.gradient(t, spacing=spacing, dim=dims, edge_order=edge_order)
2961            if space_fn == create_coordinate_tensors and spacing[0].device != 'cpu':
2962                spacing = [space.cpu().detach().numpy() for space in spacing]
2963            expected = np.gradient(t_np, *self._wrap_to_list(spacing), axis=dims, edge_order=edge_order)
2964            actual, expected = self._inf_nan_preprocess(list(actual), self._wrap_to_list(expected))
2965            self.assertEqual(actual, expected, equal_nan=True, atol=1e-4, rtol=0, exact_dtype=False)
2966
2967    @onlyNativeDeviceTypes
2968    @slowTestIf(TEST_WITH_TORCHINDUCTOR)
2969    @dtypes(torch.long, torch.float32, torch.complex64)
2970    def test_gradient_extreme_cases(self, device, dtype):
2971        # Test behaviour for inf and nan values
2972        actual = torch.gradient(torch.tensor([2, -2, inf, inf, -inf, -inf, inf, 3, -inf, 2, nan, nan, 3, inf, nan]))
2973        expected = np.gradient(np.array([2, -2, inf, inf, -inf, -inf, inf, 3, -inf, 2, nan, nan, 3, inf, nan]))
2974        self.assertEqual(actual, self._wrap_to_list(expected), exact_dtype=False)
2975
2976        # Test behaviour in very big tensors
2977        large_size = 100000
2978        t = make_tensor((large_size,), dtype=dtype, device=device)
2979        t_np = t.cpu().numpy()
2980        coordinates_np = np.random.randn(large_size)
2981        coordinates = [torch.tensor(coordinates_np, device=device)]
2982        actual = torch.gradient(t, spacing=coordinates, dim=0, edge_order=1)
2983        expected = [np.gradient(t_np, coordinates_np, axis=0, edge_order=1)]
2984        self.assertEqual(actual, expected, exact_dtype=False)
2985
2986        actual = torch.gradient(t, spacing=coordinates, dim=0, edge_order=2)
2987        expected = [np.gradient(t_np, coordinates_np, axis=0, edge_order=2)]
2988        self.assertEqual(actual, expected, exact_dtype=False)
2989
2990    @onlyNativeDeviceTypes
2991    def test_gradient_type_promotion(self, device):
2992        inputs = (
2993            make_tensor((4, 4), device=device, dtype=torch.float32),
2994            make_tensor((4, 4), device=device, dtype=torch.complex64),
2995            make_tensor((4, 4), device=device, dtype=torch.int64),
2996        )
2997
2998        spacing = (
2999            make_tensor((1,), device='cpu', dtype=torch.float32).item(),
3000            make_tensor((1,), device='cpu', dtype=torch.int64).item(),
3001            make_tensor((1,), device='cpu', dtype=torch.complex64).item(),
3002            make_tensor((2,), device='cpu', dtype=torch.float32, low=0.1).tolist(),
3003            make_tensor((2,), device='cpu', dtype=torch.int64, low=1).tolist(),
3004            make_tensor((2,), device='cpu', dtype=torch.complex64).tolist(),
3005            [make_tensor((4,), device=device, dtype=torch.float32),
3006             make_tensor((4,), device=device, dtype=torch.float32)],
3007            [make_tensor((4,), device=device, dtype=torch.int64),
3008             make_tensor((4,), device=device, dtype=torch.int64)],
3009            [make_tensor((4,), device=device, dtype=torch.complex64),
3010             make_tensor((4,), device=device, dtype=torch.complex64)],
3011        )
3012
3013        for input, spacing_or_coord, edge_order in product(inputs, spacing, [1, 2]):
            input_np = input.cpu().numpy()
3016            actual = torch.gradient(input, spacing=spacing_or_coord, dim=(0, 1), edge_order=edge_order)
3017            spacing_or_coord_wrapped = self._wrap_to_list(spacing_or_coord)
3018            spacing_or_coord_np = []
3019            if torch.is_tensor(spacing_or_coord_wrapped[0]) and torch.device(spacing_or_coord_wrapped[0].device).type != 'cpu':
3020                for i in range(len(spacing_or_coord_wrapped)):
3021                    spacing_or_coord_np.append(spacing_or_coord_wrapped[i].detach().clone().cpu().numpy())
3022            else:
3023                spacing_or_coord_np = spacing_or_coord_wrapped
3024            expected = np.gradient(input_np, *spacing_or_coord_np, axis=(0, 1), edge_order=edge_order)
3025            if actual[0].dtype == torch.complex64 and input.dtype != torch.complex64:
3026                for i in range(len(actual)):
3027                    self.assertEqual(actual[i].real, expected[i].real, exact_dtype=False)
                    # Type promotion fails in NumPy when the spacing is complex and the input is real:
                    # NumPy returns a plain real result, so all imaginary parts are expected to be zero.
3030                    self.assertEqual(expected[i].imag, torch.zeros(actual[i].shape), exact_dtype=False)
3031            else:
3032                actual, expected = self._inf_nan_preprocess(list(actual), expected)
3033                self.assertEqual(actual, expected, equal_nan=True, exact_dtype=False)
3034
3035    @onlyNativeDeviceTypes
3036    @dtypes(torch.long, torch.float32, torch.complex64)
3037    def test_gradient_spacing_list_length_error(self, device, dtype):
3038        t = make_tensor((2, 2), device=device, dtype=dtype)
3039
3040        spacing = (make_tensor((2,), device=device, dtype=dtype),)
3041        with self.assertRaisesRegex(RuntimeError, r'expected spacing to be'):
3042            torch.gradient(t, spacing=spacing)
3043
3044        spacing = (make_tensor((2,), device=device, dtype=dtype),) * 2
3045        torch.gradient(t, spacing=spacing)
3046
3047        spacing = (make_tensor((2,), device=device, dtype=dtype),) * 3
3048        with self.assertRaisesRegex(RuntimeError, r'expected spacing to be'):
3049            torch.gradient(t, spacing=spacing)
3050
3051        spacing = (2,)
3052        with self.assertRaisesRegex(RuntimeError, r'expected spacing to be'):
3053            torch.gradient(t, spacing=spacing)
3054
3055        spacing = (2, 2)
3056        torch.gradient(t, spacing=spacing)
3057
3058        spacing = (2, 2, 2)
3059        with self.assertRaisesRegex(RuntimeError, r'expected spacing to be'):
3060            torch.gradient(t, spacing=spacing)
3061
3062    def _test_large_cum_fn_helper(self, x, fn):
3063        expected = fn(x.cpu().float())
3064        actual = fn(x).cpu().float()
3065        # Avoid self.assertEqual to save memory.
3066        torch.testing.assert_close(expected, actual)
3067
3068    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "sandcastle OOM with current tpx gpu/re configuration")
3069    @unittest.skipIf(IS_JETSON, "psutil issue for largeTensorTest. Too large for Jetson.")
3070    @onlyCUDA
3071    @dtypes(torch.half)  # only small dtype not to get oom
3072    @largeTensorTest('25GB', device='cpu')
3073    @largeTensorTest('4GB', device='cuda')
3074    def test_large_cumsum(self, device, dtype):
3075        # initialization to avoid overflow and half caveats
3076        x = torch.empty(2**30 + 200, device=device, dtype=dtype)
3077        x[::3] = -3
3078        x[1::3] = 2
3079        x[2::3] = 1
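        # each block of three elements sums to zero, so the running sum stays small enough for fp16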
3080        self._test_large_cum_fn_helper(x, lambda x: torch.cumsum(x, 0))
3081
3082    @onlyCUDA
3083    @dtypes(torch.half)  # only small dtype not to get oom
3084    @largeTensorTest('25GB', device='cpu')
3085    @largeTensorTest('4GB', device='cuda')
3086    @unittest.skipIf(IS_JETSON, "psutil issue for largeTensorTest. Too large for Jetson.")
3087    def test_large_cumprod(self, device, dtype):
3088        # initialization to avoid overflow and half caveats
3089        x = torch.empty(2**30 + 200, device=device, dtype=dtype)
3090        x[::3] = 8
3091        x[1::3] = .25
3092        x[2::3] = .5
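        # each block of three elements multiplies to one, so the running product stays bounded in fp16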
3093        self._test_large_cum_fn_helper(x, lambda x: torch.cumprod(x, 0))
3094
3095    @skipIfTorchDynamo("Torchdynamo fails with unknown reason")
3096    @skipIfMps
3097    def test_discontiguous_out_cumsum(self, device):
3098        x = torch.randn(4, 8, device=device)
3099        y = torch.empty(4, 16, device=device)[:, ::2]
3100        out = torch.cumsum(x, 0)
3101        torch.cumsum(x, 0, out=y)
3102        self.assertFalse(y.is_contiguous())
3103        self.assertEqual(out, y, atol=0., rtol=0.)
3104
3105    def _test_cumminmax_helper(self, x, fn, expected_val, expected_ind):
3106        val, ind = fn(x, -1)
3107        self.assertEqual(val, expected_val, atol=0, rtol=0)
3108        self.assertEqual(ind, expected_ind, atol=0, rtol=0)
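        # .t().contiguous().t() yields tensors with swapped (column-major) strides, i.e.
        # deliberately non-contiguous out= destinations.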
3109        out_val = torch.empty_like(val).t().contiguous().t()
3110        out_ind = torch.empty_like(ind).t().contiguous().t()
3111        fn(x, -1, out=(out_val, out_ind))
3112        # TODO: Fix this. It reproduces with aot_eager too, and looks like a functionalization bug.
3113        # (the problematic case seems rare, as we're calling an out= op directly from user code,
3114        # where the passed-in out tensors are non-contiguous).
3115        if not TEST_WITH_TORCHINDUCTOR:
3116            self.assertFalse(out_val.is_contiguous())
3117            self.assertFalse(out_ind.is_contiguous())
3118        self.assertEqual(out_val, expected_val, atol=0, rtol=0)
3119        self.assertEqual(out_ind, expected_ind, atol=0, rtol=0)
3120
3121    @skipIfMps
3122    def test_cummax_discontiguous(self, device):
3123        x = torch.tensor([[0, 1, 2, 3, 2, 1], [4, 5, 6, 5, 6, 7]], device=device, dtype=torch.float).t().contiguous().t()
3124        expected_val = torch.tensor([[0, 1, 2, 3, 3, 3], [4, 5, 6, 6, 6, 7]], device=device, dtype=torch.float)
3125        expected_ind = torch.tensor([[0, 1, 2, 3, 3, 3], [0, 1, 2, 2, 4, 5]], device=device, dtype=torch.long)
3126        self._test_cumminmax_helper(x, torch.cummax, expected_val, expected_ind)
3127
3128    @skipIfMps
3129    def test_cummin_discontiguous(self, device):
3130        x = torch.tensor([[3, 2, 1, 0, 1, 2], [7, 6, 5, 4, 5, 2]], device=device, dtype=torch.float).t().contiguous().t()
3131        expected_val = torch.tensor([[3, 2, 1, 0, 0, 0], [7, 6, 5, 4, 4, 2]], device=device, dtype=torch.float)
3132        expected_ind = torch.tensor([[0, 1, 2, 3, 3, 3], [0, 1, 2, 3, 3, 5]], device=device, dtype=torch.long)
3133        self._test_cumminmax_helper(x, torch.cummin, expected_val, expected_ind)
3134
3135    def test_bool_tensor_value_change(self, device):
3136        x = torch.tensor([True, False], dtype=torch.bool, device=device)
3137        x[0] = False
3138        x[1] = True
3139        self.assertEqual(x, torch.tensor([False, True], dtype=torch.bool, device=device))
3140
3141    # FIXME: move to shape ops test suite
3142    def test_unfold_all_devices_and_dtypes(self, device):
3143        for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
3144
            x = torch.empty((0, 1, 3, 0), dtype=dt, device=device)
            self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)
3151
3152    # FIXME: move to shape ops test suite
3153    def test_unfold_scalars(self, device):
3154        x = torch.tensor(0.5, device=device)
        # unfold on a 0-dimensional tensor should always return a 1-dimensional
        # tensor of shape [size] (i.e., the second parameter to unfold)
3157
3158        self.assertEqual(torch.empty(0, device=device), x.unfold(0, 0, 1))
3159        self.assertEqual(torch.empty(0, device=device), x.unfold(0, 0, 2))
3160        self.assertEqual(torch.tensor([0.5], device=device), x.unfold(0, 1, 1))
3161
3162    # FIXME: move to data movement test suite
3163    def test_copy_all_dtypes_and_devices(self, device):
3164        from copy import copy
3165        for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
3166            x = torch.tensor([1, 2, 3, 4], dtype=dt, device=device)
3167            x_clone = x.clone()
3168            y = copy(x)
3169            y.fill_(1)
3170            # copy is a shallow copy, only copies the tensor view,
3171            # not the data
3172            self.assertEqual(x, y)
3173
3174    @onlyCPU
3175    def test_bfloat16_neg_abs(self, device):
3176        src = torch.randn(256)
3177        src[0] = torch.nan
3178        src[1] = -torch.nan
3179        src[2] = torch.inf
3180        src[3] = -torch.inf
3181        src_bf16 = src.bfloat16()
3182        self.assertEqual(src.neg().bfloat16(), src_bf16.neg())
3183        self.assertEqual(src.abs().bfloat16(), src_bf16.abs())
3184
3185    @onlyCPU
3186    @dtypes(torch.bfloat16, torch.half)
3187    def test_reduced_type_float_copy(self, device, dtype):
3188        for shape in [(20, 7), (249, 137), (1029, 917), (1, 7, 19, 17), (3, 77, 1091)]:
3189            input = torch.randn(shape, dtype=torch.float, device=device)
3190            out1 = input.to(dtype=dtype)
3191            self.assertEqual(input, out1, atol=None, rtol=None, exact_dtype=False)
3192            out2 = out1.to(torch.float)
3193            self.assertEqual(out2, out1, atol=0, rtol=0, exact_dtype=False)
3194
3195            input_s = input[..., ::2, :]
3196            out1 = input_s.to(dtype=dtype)
3197            self.assertEqual(input_s, out1, atol=None, rtol=None, exact_dtype=False)
3198            out2 = out1.to(torch.float)
3199            self.assertEqual(out2, out1, atol=0, rtol=0, exact_dtype=False)
3200
3201    # FIXME: move to data movement test suite
3202    @onlyNativeDeviceTypes
3203    def test_copy_math_view(self, device):
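        # Copying through math views (_neg_view / conj()) must apply the deferred negation or
        # conjugation, so the destination ends up holding materialized values consistent with the view.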
3204        for dst_dtype, src_dtype in [
3205                (torch.float32, torch.float32),
3206                (torch.float64, torch.float32),
3207                (torch.int64, torch.int32),
3208                (torch.complex128, torch.complex64),
3209        ]:
3210            src = make_tensor((100,), dtype=src_dtype, device=device)
3211            dst = torch.empty(100, dtype=dst_dtype, device=device)
3212
3213            dst.copy_(src)
3214            self.assertEqual(dst, src, exact_dtype=False)
3215
3216            dst.copy_(src._neg_view())
3217            self.assertEqual(dst, src.neg(), exact_dtype=False)
3218
3219            dst._neg_view().copy_(torch._neg_view(src))
3220            self.assertEqual(dst, src, exact_dtype=False)
3221
3222            dst._neg_view().copy_(src)
3223            self.assertEqual(dst, src.neg(), exact_dtype=False)
3224
3225            # issue: https://github.com/pytorch/pytorch/issues/106051
3226            dst._neg_view().copy_(dst)
3227            self.assertEqual(dst, src, exact_dtype=False)
3228
3229        for dst_dtype, src_dtype in [
3230                (torch.complex64, torch.complex64),
3231                (torch.complex128, torch.complex64),
3232        ]:
3233            src = make_tensor((100,), dtype=src_dtype, device=device)
3234            dst = torch.empty(100, dtype=dst_dtype, device=device)
3235
3236            dst.conj().copy_(src)
3237            self.assertEqual(dst, src.conj_physical(), exact_dtype=False)
3238
3239            dst.conj().copy_(src._neg_view())
3240            self.assertEqual(dst, src.neg().conj_physical(), exact_dtype=False)
3241
3242    # FIXME: move to data movement test suite
3243    @onlyNativeDeviceTypes
3244    @dtypes(torch.int64, torch.float32, torch.complex64)
3245    def test_copy_transpose_math_view(self, device, dtype):
3246        src = make_tensor((100, 100), dtype=dtype, device=device).transpose(0, 1)
3247        dst = torch.empty((100, 100), dtype=dtype, device=device)
3248
3249        dst._neg_view().copy_(src)
3250        self.assertEqual(dst, -src)
3251        dst._neg_view().copy_(src._neg_view())
3252        self.assertEqual(dst, src)
3253        dst.copy_(src._neg_view())
3254        self.assertEqual(dst, -src)
3255
3256        if dtype.is_complex:
3257            dst.conj().copy_(src)
3258            self.assertEqual(dst, src.conj_physical())
3259            dst.conj().copy_(src.conj())
3260            self.assertEqual(dst, src)
3261            dst.copy_(src.conj())
3262            self.assertEqual(dst, src.conj_physical())
3263
3264    def test_clone_all_dtypes_and_devices(self, device):
3265        for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
3266            x = torch.tensor((1, 1), dtype=dt, device=device)
3267            y = x.clone()
3268            self.assertEqual(x, y)
3269
3270    def test_clone_zero_stride_dim(self, device):
3271        # stride zero, size 1 axis, not contiguous
3272        x = torch.randn(10)
3273        y = x.as_strided([2, 1, 5], [1, 0, 2])
3274        self.assertEqual(y, y.clone())
3275
3276    def test_clone_not_memory_dense(self):
3277        # github issue: https://github.com/pytorch/pytorch/issues/64176
3278        x = torch.randn(10, 8).t()[::2, ::2]
3279        y = x.clone()
3280        # should retain permutation after densification
3281        self.assertTrue(y.stride() == (1, 4))
3282
3283    # FIXME: move to elementwise ternary test suite
3284    @dtypesIfCUDA(*set(get_all_math_dtypes('cuda')))
3285    @dtypes(*set(get_all_math_dtypes('cpu')))
3286    def test_addcmul(self, device, dtype):
3287        # Returns floating or integral scalar corresponding to dtype
3288        def _number(floating, integer, dtype):
3289            if dtype in [torch.half, torch.float, torch.double, torch.bfloat16]:
3290                return floating
3291            elif dtype in [torch.cfloat, torch.cdouble]:
3292                return floating * (1 + 1j)
3293            else:
3294                return integer
3295
3296        def rand_tensor(size, dtype, device):
3297            if dtype.is_floating_point or dtype.is_complex:
3298                return torch.rand(size=size, dtype=dtype, device=device)
3299            if dtype == torch.uint8:
3300                return torch.randint(1, 5, size=size, dtype=dtype, device=device)
3301            else:
3302                return torch.randint(-5, 5, size=size, dtype=dtype, device=device)
3303
3304        a = rand_tensor((2, 2), dtype=dtype, device=device)
3305        b = rand_tensor((2, 2), dtype=dtype, device=device)
3306        c = rand_tensor((2, 2), dtype=dtype, device=device)
3307
3308        alpha = _number(0.5, 3, dtype)
3309
3310        actual = torch.addcmul(a, b, c, value=alpha)
3311        expected = a + alpha * b * c
3312
3313        self.assertEqual(expected, actual)
3314
3315        with self.assertWarnsOnceRegex(
3316                UserWarning, "This overload of addcmul is deprecated"):
3317            self.assertEqual(actual, torch.addcmul(a, alpha, b, c))
3318
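        # In float16, b * c = 120000 overflows (max ~65504); the CUDA kernel is expected to carry
        # the intermediate product in a wider type so that a + value * b * c (-60000 here) stays finite.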
3319        if self.device_type == 'cuda' and dtype == torch.half:
3320            a = torch.tensor([60000.0], device=device, dtype=dtype)
3321            b = torch.tensor([60000.0], device=device, dtype=dtype)
3322            c = torch.tensor([2.0], device=device, dtype=dtype)
3323            out = torch.addcmul(a, b, c, value=-1)
3324            self.assertTrue(not (out.isnan() or out.isinf()))
3325
3326    # FIXME: move to shape ops test suite
3327    def test_narrow_empty(self, device):
3328        x = torch.randn(2, 3, 4, device=device)
3329        for d in range(x.dim()):
3330            y = x.narrow(d, x.size(d), 0)
3331            sz = list(x.size())
3332            sz[d] = 0
3333            self.assertEqual(sz, y.size())
3334
3335    def test_narrow_copy_non_contiguous(self, device):
3336        # see https://github.com/pytorch/pytorch/issues/91690.
3337        inp = torch.randn(10, 2, device=device).movedim(-1, 0)
3338        expected = torch.narrow_copy(inp.contiguous(), 1, 0, 10)
3339        actual = torch.narrow_copy(inp, 1, 0, 10)
3340        self.assertEqual(expected, actual)
3341
3342    # FIXME: move to indexing test suite
3343    @parametrize("reduce", ['prod', 'amin', 'amax', 'mean'])
3344    @dtypes(*all_types_and(torch.half, torch.bfloat16))
3345    def test_index_reduce(self, device, dtype, reduce):
3346        size = (3, 4, 5)
3347        index_dtypes = [torch.int, torch.long]
3348        include_selfs = [True, False]
3349        amin_init = float('inf') if dtype.is_floating_point else torch.iinfo(dtype).max
3350        amax_init = -float('inf') if dtype.is_floating_point else torch.iinfo(dtype).min
3351        reduction_init = {'prod': 1, 'mean': 0, 'amin': amin_init, 'amax': amax_init}
3352
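        # Reference strategy: move the reduced dim to the front, fold src into `expected` row by
        # row with the matching elementwise op, and for 'mean' divide by per-slot counts
        # accumulated via index_add_.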
3353        for dest_noncontig, src_noncontig, index_noncontig in product([True, False], repeat=3):
3354            for idx_dtype, include_self in product(index_dtypes, include_selfs):
3355                for dim in range(len(size)):
3356                    num_src = np.random.randint(10)
3357                    num_dest = size[dim]
3358                    dest = make_tensor(size, device=device, dtype=dtype, noncontiguous=dest_noncontig)
3359                    src_size = size[:dim] + (num_src,) + size[dim + 1:]
3360                    src = make_tensor(src_size, device=device, dtype=dtype, noncontiguous=src_noncontig)
3361                    idx = torch.testing.make_tensor(
3362                        num_src, low=0, high=num_dest, dtype=idx_dtype, device=device, noncontiguous=index_noncontig
3363                    )
3364                    expected = dest.clone()
3365                    dest.index_reduce_(dim, idx, src, reduce, include_self=include_self)
3366                    # fill rows in idx with reduction inits if include_self=False
3367                    if (not include_self):
3368                        expected.index_fill_(dim, idx.long(), reduction_init[reduce])
3369                    expected = expected.transpose(0, dim)
3370                    src = src.transpose(0, dim)
3371                    for i in range(num_src):
3372                        if reduce == 'prod':
3373                            expected[idx[i]] *= src[i]
3374                        elif reduce == 'amin':
3375                            torch.minimum(expected[idx[i]], src[i], out=expected[idx[i]])
3376                        elif reduce == 'amax':
3377                            torch.maximum(expected[idx[i]], src[i], out=expected[idx[i]])
3378                        else:
3379                            expected[idx[i]] += src[i]
3380                    if reduce == 'mean':
3381                        counts = torch.ones_like(expected) if include_self else torch.zeros_like(expected)
3382                        counts.index_add_(0, idx, torch.ones_like(src))
3383                        counts.masked_fill_(counts == 0, 1)
3384                        if (dtype.is_floating_point):
3385                            expected.div_(counts)
3386                        else:
3387                            expected.div_(counts, rounding_mode="floor")
3388                    expected = expected.transpose(0, dim)
3389
3390                    self.assertEqual(dest, expected)
3391
3392    # FIXME: move to test indexing
3393    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3394    def test_index_copy(self, device, dtype):
3395        # We just test for num_copy <= num_dest, as otherwise there are repeated indices
3396        # and the behavior is undefined
3397        num_copy, num_dest = 3, 5
3398
3399        def make_arg(batch_sizes, n, dim, contig):
3400            size_arg = batch_sizes[:dim] + (n,) + batch_sizes[dim:]
3401            return make_tensor(size_arg, dtype=dtype, device=device, low=None, high=None, noncontiguous=not contig)
3402
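        # Reference: for each i, copy src[..., i, ...] into tgt[..., idx[i], ...] along `dim`;
        # the index tuples are built with `dim * (slice(None),)` as the full-slice prefix.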
3403        def ref_index_copy(tgt, dim, idx, src):
3404            for i in range(idx.size(0)):
3405                idx_dest = dim * (slice(None),) + (idx[i],)
3406                idx_src = dim * (slice(None),) + (i,)
3407                tgt[idx_dest] = src[idx_src]
3408
3409        # More thorough testing as in index_add
3410        for dest_contig, src_contig, index_contig in product([True, False], repeat=3):
3411            for other_sizes in ((), (4, 5)):
3412                for dim in range(len(other_sizes)):
3413                    dest = make_arg(other_sizes, num_dest, dim, dest_contig)
3414                    src = make_arg(other_sizes, num_copy, dim, src_contig)
3415                    idx = torch.randperm(num_dest, dtype=torch.int64, device=device)[:num_copy]
3416                    if not index_contig:
3417                        idx = torch.repeat_interleave(idx, 2, dim=-1)
3418                        idx = idx[..., ::2]
3419                    dest2 = dest.clone()
3420                    dest.index_copy_(dim, idx, src)
3421                    ref_index_copy(dest2, dim, idx, src)
3422                    self.assertEqual(dest, dest2)
3423
3424    # FIXME: move to test indexing
3425    # onlyNativeDeviceTypes due to an XLA error:
3426    # https://github.com/pytorch/pytorch/issues/53256
3427    @onlyNativeDeviceTypes
3428    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3429    def test_index_copy_scalars(self, device, dtype):
3430        # Create the 8 possible combinations of scalar sizes for target / index / source
3431        scalars = ((make_tensor(size_t, dtype=dtype, device=device, low=None, high=None),
3432                    make_tensor(size_i, dtype=torch.int64, device=device, low=0, high=1),
3433                    make_tensor(size_s, dtype=dtype, device=device, low=None, high=None))
3434                   for size_t, size_i, size_s in product([(), (1,)], repeat=3))
3435        for target, idx, source in scalars:
3436            target.index_copy_(0, idx, source)
3437            self.assertEqual(target.item(), source.item())
3438
3439    # FIXME: move to test indexing
3440    @onlyCPU
3441    def test_errors_index_copy(self, device):
3442        # We do not test the GPU as the CUDA_ASSERT would break the CUDA context
3443        idx_dim = 8
3444        tgt_dim = 5
3445        batch_dim = 3
3446
3447        # Too large of an index
3448        a = torch.randn(batch_dim, tgt_dim, device=device)
3449        idx = torch.full((idx_dim,), tgt_dim, device=device)
3450        c = torch.zeros(batch_dim, idx_dim, device=device)
3451        with self.assertRaises(IndexError):
3452            a.index_copy_(1, idx, c)
3453
3454        # Too small (negative indices)
3455        idx = torch.full((idx_dim,), -1, device=device)
3456        with self.assertRaises(IndexError):
3457            a.index_copy_(1, idx, c)
3458
3459        # Too small (very negative indices) - they should be unsupported even
3460        # when support for negative indices is implemented for index_copy_
3461        idx = torch.full((idx_dim,), -tgt_dim - 1, device=device)
3462        with self.assertRaises(IndexError):
3463            a.index_copy_(1, idx, c)
3464
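    # Builds a destination of length 2000 along `dim` and a source 20x longer along that dim,
    # so the 40000 random indices into 2000 slots contain many repeats and would expose any
    # nondeterministic accumulation order in index_copy_ / index_add_.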
3465    def _prepare_data_for_index_copy_and_add_deterministic(
3466        self, dim: int, device: torch.device
3467    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
3468        assert (dim >= 0 and dim < 3)
3469        a = [5, 4, 3]
3470        a[dim] = 2000
3471        x = torch.zeros(a, device=device)
3472        b = a.copy()
3473        elems = a[dim] * 20
3474        b[dim] = elems
3475        src = torch.rand(b, device=device)
3476        index = torch.randint(a[dim], (elems,), device=device)
3477        return (x, index, src)
3478
3479    # FIXME: move to test indexing
3480    @onlyNativeDeviceTypes
3481    def test_index_copy_deterministic(self, device: torch.device) -> None:
3482        for dim in range(3):
3483            x, index, src = self._prepare_data_for_index_copy_and_add_deterministic(dim, device)
3484            with DeterministicGuard(True):
3485                y0 = torch.index_copy(x, dim, index, src)
3486
3487            x0 = x.clone().detach()
3488            index_list = index.tolist()
3489            for i in range(len(index_list)):
3490                if dim == 0:
3491                    x0[index_list[i], :, :] = src[i, :, :]
3492                elif dim == 1:
3493                    x0[:, index_list[i], :] = src[:, i, :]
3494                elif dim == 2:
3495                    x0[:, :, index_list[i]] = src[:, :, i]
3496
3497            self.assertEqual(x0, y0, atol=0, rtol=0)
3498
3499    # FIXME: move to test indexing
3500    @onlyNativeDeviceTypes
3501    def test_index_add_deterministic(self, device: torch.device) -> None:
3502        for dim in range(3):
3503            x, index, src = self._prepare_data_for_index_copy_and_add_deterministic(dim, device)
3504            alpha = random.random() + 1
3505            # on CPU it should be deterministic regardless of the deterministic mode
3506            with DeterministicGuard(True):
3507                y0 = torch.index_add(x, dim, index, src, alpha=alpha)
3508                for _ in range(3):
3509                    y = torch.index_add(x, dim, index, src, alpha=alpha)
3510                    self.assertEqual(y, y0, atol=0, rtol=0)
3511
3512            with DeterministicGuard(False):
3513                for _ in range(3):
3514                    y_nd = torch.index_add(x, dim, index, src, alpha=alpha)
3515                    self.assertEqual(y_nd, y0, atol=1e-3, rtol=1e-5)
3516
3517    # FIXME: find a test suite for the put operator
3518    @onlyNativeDeviceTypes
3519    def test_index_put_non_accumulate_deterministic(self, device) -> None:
3520        with DeterministicGuard(True):
3521            for i in range(3):
3522                m = random.randint(10, 20)
3523                elems = random.randint(20000, 30000)
3524                values = torch.rand(elems, device=device)
3525                indices = torch.randint(m, (elems,), device=device)
3526                input = torch.rand(m, device=device)
3527                output = input.index_put((indices,), values, accumulate=False)
3528
3529                input_list = input.tolist()
3530                indices_list = indices.tolist()
3531                values_list = values.tolist()
3532                for i, v in zip(indices_list, values_list):
3533                    input_list[i] = v
3534
3535                self.assertEqual(output, input_list)
3536
3537    # FIXME: move to test indexing
3538    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3539    @skipIfMps
3540    def test_index_fill(self, device, dtype):
3541        x = torch.tensor([[1, 2], [4, 5]], dtype=dtype, device=device)
3542        index = torch.tensor([0], device=device)
3543        x.index_fill_(1, index, 0)
3544        self.assertEqual(x, torch.tensor([[0, 2], [0, 5]], dtype=dtype, device=device))
3545        if not x.is_complex() and not device == "meta":
3546            with self.assertRaisesRegex(RuntimeError, r"Scalar"):
3547                x.index_fill_(1, index, 1 + 1j)
        # Make sure that the result stays 0-dim when applied to
        # a 0-dim input
3550        x = torch.tensor(1, dtype=dtype, device=device)
3551        self.assertEqual(0, x.index_fill(0, index, -1).dim())
3552        self.assertEqual(0, x.index_fill_(0, index, -1).dim())
3553
3554    # FIXME: move to test indexing
3555    # The test fails for zero-dimensional tensors on XLA
3556    @onlyNativeDeviceTypes
3557    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3558    def test_index_select(self, device, dtype):
3559        num_src, num_out = 3, 5
3560
3561        def make_arg(batch_sizes, n, dim, contig):
3562            size_arg = batch_sizes[:dim] + (n,) + batch_sizes[dim:]
3563            return make_tensor(size_arg, dtype=dtype, device=device, low=None, high=None, noncontiguous=not contig)
3564
3565        def ref_index_select(src, dim, idx):
            # bfloat16 is not supported by numpy, so compute the reference in float
3567            if dtype == torch.bfloat16:
3568                src = src.float()
3569            out = torch.from_numpy(np.take(src.cpu().numpy(), idx.cpu().numpy(), axis=dim))
3570            if dtype == torch.bfloat16:
3571                out = out.to(device=device, dtype=dtype)
3572            return out
3573
3574        for src_contig, idx_contig in product([True, False], repeat=2):
3575            for other_sizes in ((), (4, 5)):
3576                for dim in range(len(other_sizes)):
3577                    src = make_arg(other_sizes, num_src, dim, src_contig)
3578                    idx = make_tensor(
3579                        (num_out,), dtype=torch.int64, device=device, low=0, high=num_src, noncontiguous=not idx_contig
3580                    )
3581                    out = torch.index_select(src, dim, idx)
3582                    out2 = ref_index_select(src, dim, idx)
3583                    self.assertEqual(out, out2)
3584
3585        for idx_type in (torch.int32, torch.int64):
3586            other_sizes = (3, 2)
3587            dim = 1
3588            src = make_arg(other_sizes, num_src, dim, True)
3589            idx = make_tensor((num_out,), dtype=idx_type, device=device, low=0, high=num_src, noncontiguous=False)
3590            out = torch.index_select(src, dim, idx)
3591            out2 = ref_index_select(src, dim, idx)
3592            self.assertEqual(out, out2)
3593
3594        # Create the 4 possible combinations of scalar sizes for source / index
3595        scalars = ((make_tensor(size_s, dtype=dtype, device=device),
3596                    torch.zeros(size_i, dtype=torch.int64, device=device))
3597                   for size_s, size_i in product([(), (1,)], repeat=2))
3598        for source, idx in scalars:
3599            out = source.index_select(0, idx)
3600            self.assertEqual(out.item(), source.item())
3601
3602    # FIXME: find a test suite for the take operator
3603    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3604    def test_take(self, device, dtype):
3605        idx_size = (4,)
3606
3607        make_arg = partial(make_tensor, device=device, dtype=dtype)
3608        make_idx = partial(make_tensor, low=0, device=device, dtype=torch.int64)
3609
3610        def ref_take(src, idx):
3611            if dtype == torch.bfloat16:
3612                src = src.half()
3613            src = src.cpu().numpy()
3614            idx = idx.cpu().numpy()
3615            out = torch.from_numpy(np.take(src, idx)).to(device=device, dtype=dtype)
3616            return out
3617
3618        for src_contig, idx_contig, idx_reshape in product([True, False], repeat=3):
3619            for src_size in ((5,), (4, 5)):
3620                src = make_arg(src_size, noncontiguous=not src_contig)
3621                idx = make_idx(idx_size, high=src.numel(), noncontiguous=not idx_contig)
3622                if idx_reshape:
3623                    idx = idx.reshape(2, 2)
3624                out = torch.take(src, idx)
3625                out2 = ref_take(src, idx)
3626                self.assertEqual(out, out2)
3627
3628        # Create the 4 possible combinations of scalar sizes for source / index
3629        for size_s, size_i in product([(), (1,)], repeat=2):
3630            source = make_arg(size_s)
3631            idx = make_idx(size_i, high=1)
3632            out = source.take(idx)
3633            self.assertEqual(out.item(), source.item())
3634
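    # A small sketch of the semantics ref_take relies on (hand-picked values, not
    # from the test): torch.take indexes the input as if it were flattened to 1-D,
    # which is why np.take on the raveled numpy array is a valid reference.
    def _sketch_take_flattened_indexing(self, device='cpu'):
        src = torch.arange(6, device=device).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
        idx = torch.tensor([0, 4, 5], device=device)
        # flat positions 0, 4 and 5 -> tensor([0, 4, 5])
        return torch.take(src, idx)
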
3635    # FIXME: find a test suite for the put operator
3636    # The bool case does not work on GPU. See
3637    # https://github.com/pytorch/pytorch/issues/54317
3638    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
3639    def test_put(self, device, dtype):
3640        src_size = (4,)
3641
3642        make_arg = partial(make_tensor, device=device, dtype=dtype)
3643        make_idx = partial(make_tensor, low=0, device=device, dtype=torch.int64)
3644
3645        def ref_put(dst, idx, src, accumulate):
3646            new_dst = dst.clone(memory_format=torch.contiguous_format).view(-1)
3647            new_idx = idx.contiguous().view(-1)
3648            new_src = src.contiguous().view(-1)
3649            method = new_dst.index_add_ if accumulate else new_dst.index_copy_
3650            return method(0, new_idx, new_src).view_as(dst)
3651
3652        for dst_contig, src_contig, idx_contig, idx_reshape, accumulate in product([True, False], repeat=5):
3653            for dst_size in ((5,), (4, 5)):
3654                dst = make_arg(dst_size, noncontiguous=not dst_contig)
3655                src = make_arg(src_size, noncontiguous=not src_contig)
3656
3657                # If accumulate=True, `put_` should be deterministic regardless of the inputs on CPU
3658                # On CUDA it may not be, but the test has enough tolerance to account for this
3659                if accumulate:
3660                    idx = make_idx(src_size, high=dst.numel())
3661                else:
3662                    idx = torch.randperm(dst.numel(), dtype=torch.int64, device=device)[:src_size[0]]
3663                if not idx_contig:
3664                    idx = torch.repeat_interleave(idx, 2, dim=-1)[..., ::2]
3665                if idx_reshape:
3666                    idx = idx.reshape(2, 2)
3667                out = torch.put(dst, idx, src, accumulate)
3668                # out-place
3669                reference = ref_put(dst, idx, src, accumulate)
3670                self.assertEqual(out, reference)
3671
3672                # in-place
3673                dst.put_(idx, src, accumulate)
3674                self.assertEqual(dst, reference)
3675
3676
3677        # Create the 8 possible combinations of scalar sizes for target / index / source
3678        scalars = ((make_arg(size_t),
3679                    make_idx(size_i, high=1),
3680                    make_arg(size_s))
3681                   for size_t, size_i, size_s in product([(), (1,)], repeat=3))
3682        for (dest, idx, source), accumulate in product(scalars, [True, False]):
3683            dest_init = dest.clone()
3684            # out-place
3685            out = torch.put(dest, idx, source, accumulate=accumulate)
3686            # in-place
3687            dest1 = dest.clone()
3688            dest1.put_(idx, source, accumulate=accumulate)
3689            for d in [out, dest1]:
3690                if accumulate:
3691                    self.assertEqual(d.item(), (dest_init + source).item())
3692                else:
3693                    self.assertEqual(d.item(), source.item())
3694
3695        # Empty case
3696        dest = make_arg((3, 2))
3697        reference = dest.clone()
3698        idx = make_idx((0,), high=1)
3699        source = make_arg((0,))
3700        for accumulate in [True, False]:
3701            out = torch.put(dest, idx, source, accumulate=accumulate)
3702            self.assertEqual(out, reference)
3703            dest.put_(idx, source, accumulate=accumulate)
3704            self.assertEqual(dest, reference)
3705
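    # A minimal sketch of the reference used above (values are illustrative only):
    # put_ addresses the flattened destination, and with accumulate=True repeated
    # indices sum their contributions, mirroring the index_add_-based ref_put.
    def _sketch_put_accumulate_duplicates(self, device='cpu'):
        dst = torch.zeros(2, 2, device=device)
        idx = torch.tensor([0, 0, 3], device=device)
        src = torch.tensor([1., 2., 5.], device=device)
        dst.put_(idx, src, accumulate=True)
        # flat slot 0 accumulates 1 + 2, flat slot 3 receives 5 -> [[3., 0.], [0., 5.]]
        return dst
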
3706    # FIXME: find a test suite for the put operator
3707    # The bool case does not work on GPU. See
3708    # https://github.com/pytorch/pytorch/issues/54317
3709    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
3710    def test_put_accumulate(self, device, dtype):
3711        # Test for parallel adds with accumulate == True
3712        low_precision = dtype == torch.half or dtype == torch.bfloat16
3713        # Use fewer elements to avoid overflow with low precision
3714        # Grain size is 3000, so sizes above 3000 make the loop parallelize on CPU
3715        sizes = ((100,),) if low_precision else ((200,), (3002,))
3716        # bfloat16 is particularly inaccurate here
3717        # This operation is nondeterministic on GPU, so we are generous with the rtol
3718        rtol, atol = (1e-1, 1e-2) if low_precision else (1e-3, 1e-4)
3719
3720        make_arg = partial(make_tensor, low=-2, high=3, device=device, dtype=dtype)
3721        # Dump everything into the 0-th position
3722        make_idx = partial(torch.zeros, device=device, dtype=torch.int64)
3723        args = ((make_idx(size), make_arg(size)) for size in sizes)
3724
3725        for idx, source in args:
3726            orig = make_arg((1,))
3727            out = orig.put(idx, source, accumulate=True)
3728            self.assertEqual(out, orig + source.sum(), rtol=rtol, atol=atol)
3729
3730    # FIXME: find a test suite for the take operator
3731    @skipIfMps
3732    def test_take_empty(self, device):
3733        for input_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]:
3734            for indices_shape in [(0,), (0, 1, 2, 0)]:
3735                input = torch.empty(input_shape, device=device)
3736                indices = torch.empty(indices_shape, dtype=torch.int64, device=device)
3737                self.assertEqual(indices, torch.take(input, indices), exact_dtype=False)
3738
3739    # FIXME: find a test suite for the put operator
3740    def test_put_empty(self, device):
3741        for dst_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]:
3742            for indices_shape in [(0,), (0, 1, 2, 0)]:
3743                for accumulate in [False, True]:
3744                    dst = torch.randn(dst_shape, device=device)
3745                    indices = torch.empty(indices_shape, dtype=torch.int64, device=device)
3746                    src = torch.randn(indices_shape, device=device)
3747                    self.assertEqual(dst, dst.put_(indices, src, accumulate=accumulate))
3748
3749    # FIXME: port to test_scatter_gather_ops.py
3750    def scatter_allow_reduce(self, device, dtype, reduceop):
3751        device_type = torch.device(device).type
3752        return device_type != 'cuda' or (reduceop == 'multiply' and dtype.is_floating_point)
3753
3754    @dtypes(*floating_and_complex_types())
3755    @dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3756    @dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3757    def test_scatter_reduce_operations_to_large_input(self, device, dtype):
3758        index = torch.tensor([[1], [2]], device=device, dtype=torch.long)
3759        test_data = [
3760            (torch.zeros(4, 4, device=device, dtype=dtype),
3761             torch.ones(2, 2, device=device, dtype=dtype),
3762             torch.tensor([[0, 0, 0, 0],
3763                           [1, 0, 0, 0],
3764                           [1, 0, 0, 0],
3765                           [0, 0, 0, 0]],
3766                          device=device, dtype=dtype), "add"),
3767            (torch.tensor([2], device=device, dtype=dtype).repeat(4, 4),
3768             torch.tensor([6], device=device, dtype=dtype).repeat(2, 2),
3769             torch.tensor([[2, 2, 2, 2],
3770                           [12, 2, 2, 2],
3771                           [12, 2, 2, 2],
3772                           [2, 2, 2, 2]], device=device, dtype=dtype), "multiply"),
3773        ]
3774
3775        for input, src, result, operation in test_data:
3776            if not self.scatter_allow_reduce(device, dtype, operation):
3777                continue
3778            input.scatter_(0, index, src, reduce=operation)
3779            self.assertEqual(input, result)
3780
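    # A small sketch of the reduce= overload of scatter_ exercised above (hand-picked
    # values, CPU tensors for simplicity): every source element routed by the index is
    # combined into the destination with the requested binary op.
    def _sketch_scatter_reduce_add(self):
        dst = torch.zeros(4, 4)
        src = torch.ones(2, 2)
        index = torch.tensor([[1], [2]])
        dst.scatter_(0, index, src, reduce='add')
        # only column 0 of rows 1 and 2 is touched, matching the expected matrices above
        return dst
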
3781    @dtypes(*floating_and_complex_types())
3782    @dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3783    @dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3784    def test_scatter_reduce_scalar(self, device, dtype):
3785        index = torch.tensor([[1], [2]], device=device, dtype=torch.long)
3786        test_data = [
3787            (torch.zeros(4, 4, device=device, dtype=dtype), 1,
3788             torch.tensor([[0, 0, 0, 0],
3789                           [1, 0, 0, 0],
3790                           [1, 0, 0, 0],
3791                           [0, 0, 0, 0]],
3792                          device=device, dtype=dtype), "add"),
3793            (torch.tensor([2], device=device, dtype=dtype).repeat(4, 4), 2,
3794             torch.tensor([[2, 2, 2, 2],
3795                           [4, 2, 2, 2],
3796                           [4, 2, 2, 2],
3797                           [2, 2, 2, 2]], device=device, dtype=dtype), "multiply"),
3798        ]
3799
3800        for input, src, result, operation in test_data:
3801            if not self.scatter_allow_reduce(device, dtype, operation):
3802                continue
3803            input.scatter_(0, index, src, reduce=operation)
3804            self.assertEqual(input, result)
3805
3806    # FIXME: port to test_scatter_gather_ops.py
3807    # TODO: remove this after scatter_add_ is deprecated.
3808    def test_scatter_add_non_unique_index(self, device):
3809        height = 2
3810        width = 65536
3811        input = torch.ones(height, width, device=device)
3812        index = torch.zeros(height, width, dtype=torch.long, device=device)
3813        src = torch.ones(height, width, device=device)
3814        input.scatter_add_(0, index, src)
3815
3816        self.assertEqual(input,
3817                         torch.tensor([[3], [1]], device=device,
3818                                      dtype=torch.float32).repeat(1, width))
3819
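    # A minimal illustration of the accumulation above (small shapes instead of the
    # 65536-wide input; values are illustrative): with an all-zeros index every
    # source row lands in row 0, so scatter_add_ turns 1 into 1 + 1 + 1 = 3 there.
    def _sketch_scatter_add_duplicate_index(self, device='cpu'):
        dst = torch.ones(2, 3, device=device)
        index = torch.zeros(2, 3, dtype=torch.long, device=device)
        src = torch.ones(2, 3, device=device)
        dst.scatter_add_(0, index, src)
        # row 0 becomes 3 in every column, row 1 stays 1
        return dst
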
3820    @dtypes(*floating_and_complex_types())
3821    @dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3822    @dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3823    def test_scatter_reduce_non_unique_index(self, device, dtype):
3824        height = 2
3825        width = 2
3826        index = torch.zeros(height, width, dtype=torch.long, device=device)
3827        test_data = [
3828            (torch.ones(height, width, device=device, dtype=dtype),
3829             torch.ones(height, width, device=device, dtype=dtype),
3830             torch.tensor([[3], [1]], device=device, dtype=dtype).repeat(1, width), "add"),
3831            (torch.tensor([2], device=device, dtype=dtype).repeat(height, width),
3832             torch.tensor([2], device=device, dtype=dtype).repeat(height, width),
3833             torch.tensor([[8], [2]], device=device,
3834                          dtype=dtype).repeat(1, width), "multiply"),
3835        ]
3836
3837        for input, src, result, operation in test_data:
3838            if not self.scatter_allow_reduce(device, dtype, operation):
3839                continue
3840            input.scatter_(0, index, src, reduce=operation)
3841            self.assertEqual(input, result, msg=f"result: {result} input: {input} method: {str(operation)}")
3842
3843    @onlyCUDA
3844    @dtypes(*complex_types())
3845    def test_scatter_reduce_multiply_unsupported_dtypes(self, device, dtype):
3846        height = 2
3847        width = 2
3848        index = torch.zeros(height, width, dtype=torch.long, device=device)
3849        input = torch.ones(height, width, device=device, dtype=dtype)
3850        src = torch.ones(height, width, device=device, dtype=dtype)
3851        with self.assertRaises(RuntimeError):
3852            input.scatter_(0, index, src, reduce="multiply")
3853
3854    # FIXME: port to test_scatter_gather_ops.py
3855    def test_scatter_to_large_input(self, device):
3856        input = torch.zeros(4, 4, device=device)
3857        src = torch.ones(2, 2, device=device)
3858        index = torch.tensor([[1], [2]], device=device, dtype=torch.long)
3859        input.scatter_(0, index, src)
3860        self.assertEqual(input, torch.tensor([[0, 0, 0, 0],
3861                                              [1, 0, 0, 0],
3862                                              [1, 0, 0, 0],
3863                                              [0, 0, 0, 0]], device=device, dtype=torch.float32))
3864
3865    # FIXME: port to test_scatter_gather_ops.py
3866    def test_scatter_add_to_large_input(self, device):
3867        input = torch.zeros(4, 4, device=device)
3868        src = torch.ones(2, 2, device=device)
3869        index = torch.tensor([[1], [2]], device=device, dtype=torch.long)
3870        input.scatter_add_(0, index, src)
3871        self.assertEqual(input, torch.tensor([[0, 0, 0, 0],
3872                                              [1, 0, 0, 0],
3873                                              [1, 0, 0, 0],
3874                                              [0, 0, 0, 0]], device=device, dtype=torch.float32))
3875
3876    # FIXME: port to test_scatter_gather_ops.py
3877    def test_scatter_bool(self, device):
3878        x = torch.tensor([[True, True, True], [True, True, True]], device=device)
3879        res = torch.zeros(3, 3, dtype=torch.bool, device=device)
3880        res = res.scatter_(0, torch.tensor([[0, 1, 2], [0, 1, 2]], device=device), x)
3881        self.assertEqual(res, torch.tensor([[True, False, False],
3882                                            [False, True, False],
3883                                            [False, False, True]], device=device))
3884
3885    # FIXME: port to test_scatter_gather_ops.py
3886    def test_scatter_add_bool(self, device):
3887        x = torch.tensor([[True, True, True, True, True], [True, True, True, True, True]], device=device)
3888        res = torch.zeros(3, 5, dtype=torch.bool, device=device)
3889        res = res.scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]], device=device), x)
3890        self.assertEqual(res, torch.tensor([[True, True, True, True, True],
3891                                            [False, True, False, True, False],
3892                                            [True, False, True, False, True]], device=device))
3893
3894    # FIXME: find a test suite for the masked scatter operator
3895    @onlyNativeDeviceTypes
3896    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
3897    def test_masked_scatter(self, device, dtype):
3898        dt = dtype
3899        num_copy, num_dest = 3, 10
3900        dest = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dt, device=device)
3901        dest2 = dest.clone()
3902        dest_ones = dest.clone()
3903        dest_ones_expected = dest.clone()
3904        src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dt, device=device)
3905        src_ones = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=dt, device=device)
3906        mask = torch.tensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0), dtype=torch.bool, device=device)
3907
3908        dest.masked_scatter_(mask, src)
3909        j = 0
3910        for i in range(num_dest):
3911            if mask[i]:
3912                dest2[i] = src[j]
3913                dest_ones_expected[i] = src_ones[j]
3914                j += 1
3915        self.assertEqual(dest, dest2, atol=0, rtol=0)
3916
3917        dest_ones.masked_scatter_(mask, src_ones)
3918        self.assertEqual(dest_ones, dest_ones_expected, atol=0, rtol=0)
3919
3920        # Bound checking in CUDA is done inside a kernel
3921        # in order to avoid synchronization, but this means
3922        # we cannot clear the failures. So there is no way
3923        # to test it and then recover.
3924        if self.device_type != 'cuda':
3925            # make src smaller. this should fail
3926            src = torch.zeros(num_copy - 1, dtype=dt, device=device)
3927            with self.assertRaises(RuntimeError):
3928                dest.masked_scatter_(mask, src)
3929
3930        # empty tensor
3931        dest = torch.empty((5, 0, 5), dtype=dt, device=device)
3932        mask = torch.ones_like(dest, dtype=torch.bool, device=device)
3933        src = torch.empty((0,), dtype=dt, device=device)
3934        dest.masked_scatter_(mask, src)
3935
3936        dest = torch.empty((5, 0, 5), dtype=dt, device=device)
3937        mask = torch.ones((5, 1, 5), dtype=torch.bool, device=device)
3938        src = torch.empty((0,), dtype=dt, device=device)
3939        dest.masked_scatter_(mask, src)
3940
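    # A minimal sketch of the ordering the reference loop above encodes (values are
    # illustrative): masked_scatter_ consumes source elements in order, placing the
    # k-th source element at the k-th True position of the mask.
    def _sketch_masked_scatter_order(self, device='cpu'):
        dst = torch.zeros(5, device=device)
        mask = torch.tensor([False, True, False, True, True], device=device)
        src = torch.tensor([10., 20., 30.], device=device)
        dst.masked_scatter_(mask, src)
        # True positions 1, 3 and 4 receive 10, 20, 30 -> tensor([0., 10., 0., 20., 30.])
        return dst
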
3941    # FIXME: find a test suite for the masked scatter operator
3942    @skipIfMps
3943    def test_masked_scatter_bool_tensor(self, device):
3944        src = torch.tensor([True, True, True], device=device)
3945        dst = torch.tensor([False, False, False], device=device)
3946        mask = torch.tensor([False, True, False], device=device)
3947
3948        dst.masked_scatter_(mask, src)
3949        self.assertEqual(dst, torch.tensor([False, True, False], device=device))
3950
3951        mask = torch.tensor([True, False, True], device=device)
3952        dst = dst.masked_scatter(mask, src)
3953        self.assertEqual(dst, torch.tensor([True, True, True], device=device))
3954
3955    # FIXME: find a test suite for the masked scatter operator
3956    #   test_scatter_gather_ops or test_masked_ops?
3957    @onlyCUDA
3958    @largeTensorTest('30GB')
3959    def test_masked_scatter_large_tensor(self, device):
3960        t_cpu = torch.empty(2**31 + 1, dtype=torch.bool).random_()
3961        t = t_cpu.to(device)
3962        result_cpu = t_cpu.masked_scatter(t_cpu, t_cpu)
3963        result = t.masked_scatter(t, t)
3964        self.assertEqual(result, result_cpu)
3965
3966    # FIXME: find a test suite for the masked select operator
3967    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3968    def test_masked_select(self, device, dtype):
3969        if device == 'cpu':
3970            warn = 'masked_select received a mask with dtype torch.uint8,'
3971        else:
3972            warn = 'indexing with dtype torch.uint8 is now deprecated, pl'
3973        for maskType in integral_types_and(torch.bool):
3974            num_src = 10
3975            src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dtype, device=device)
3976            mask = torch.randint(2, (num_src,), device=device, dtype=maskType)
3977
3978            if maskType is not torch.bool:
3979                with self.assertRaisesRegex(RuntimeError, r'expected BoolTensor for mask'):
3980                    dst = src.masked_select(mask)
3981                continue
3982            else:
3983                dst = src.masked_select(mask)
3984            dst2 = []
3985            for i in range(num_src):
3986                if mask[i]:
3987                    dst2 += [src[i]]
3988            self.assertEqual(dst, torch.tensor(dst2), atol=0, rtol=0)
3989
3990            dst3 = torch.empty(0, device=device, dtype=dtype)
3991            torch.masked_select(src, mask, out=dst3)
3992            self.assertEqual(dst3, torch.tensor(dst2, dtype=dst3.dtype), atol=0, rtol=0)
3993
3994        # Since half is not supported on CPU here, skip the remaining test cases
3995        if dtype == torch.half and torch.device(device).type == 'cpu':
3996            return
3997
3998        # Ensure that masks are expanded to match tensor properly
3999        a = torch.rand(100, 100, device=device).mul(100).to(dtype)
4000        mask_first_el_each_row = torch.zeros(100, device=device, dtype=torch.bool)
4001        mask_first_el_each_row[0] = True
4002        a_masked = a.masked_select(mask_first_el_each_row)
4003        self.assertEqual(a_masked, a[:, 0])
4004
4005        mask_first_row = torch.zeros(100, 1, device=device, dtype=torch.bool)
4006        mask_first_row[0][0] = True
4007        a_masked = a.masked_select(mask_first_row)
4008        self.assertEqual(a_masked, a[0, :])
4009
4010        # Ensure that tensor is expanded to match mask properly
4011        a = torch.rand(100, device=device).mul(100).to(dtype)
4012        mask_copy_3_times = torch.tensor([[True], [True], [False], [True]], device=device)
4013        a_masked = a.masked_select(mask_copy_3_times)
4014        self.assertEqual(a_masked, a.unsqueeze(0).expand(3, 100).flatten())
4015
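    # A small sketch of the broadcasting checked at the end of the test above
    # (shapes are hand-picked for illustration): mask and input are broadcast
    # against each other before the selection happens.
    def _sketch_masked_select_broadcast(self, device='cpu'):
        a = torch.arange(6., device=device).reshape(2, 3)
        col_mask = torch.tensor([True, False, False], device=device)  # shape (3,)
        # the (3,) mask broadcasts across rows, keeping column 0 of each row
        return a.masked_select(col_mask)  # tensor([0., 3.]), i.e. a[:, 0]
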
4016    # FIXME: find a test suite for the masked select operator
4017    def test_masked_select_discontiguous(self, device):
4018        for size in (10, 200):
4019            vals = torch.rand(size, size, device=device)
4020            mask = torch.full((size, size), False, dtype=torch.bool, device=device)
4021            mask[:, ::2] = True
4022            vals_list = (vals, vals.t())
4023            mask_list = (mask, mask.t())
4024            out_dc = torch.empty(size * size, device=device)[::2]
4025            for v, m in product(vals_list, mask_list):
4026                if m.is_contiguous():
4027                    expected = v[:, ::2].clone().reshape((-1, ))
4028                else:
4029                    expected = v[::2].clone().reshape((-1, ))
4030                out = torch.masked_select(v, m)
4031                self.assertEqual(out, expected, atol=0, rtol=0)
4032                torch.masked_select(v, m, out=out_dc)
4033                self.assertEqual(out_dc, expected, atol=0, rtol=0)
4034
4035    # FIXME: find a test suite for the masked fill operator
4036    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16), (torch.uint8, torch.bool)))
4037    def test_masked_fill(self, device, dtypes):
4038        dtype = dtypes[0]
4039        mask_dtype = dtypes[1]
4040
4041        num_dest = 10
4042        dst = torch.zeros(num_dest, dtype=dtype)
4043        mask = torch.randint(2, (num_dest,), dtype=mask_dtype)
4044        val = random.random()
4045        dst2 = dst.clone()
4046
4047        if mask_dtype is not torch.bool:
4048            with self.assertRaisesRegex(RuntimeError, 'only supports boolean masks'):
4049                dst.masked_fill_(mask, val)
4050            return
4051
4052        dst.masked_fill_(mask, val)
4053        for i in range(num_dest):
4054            if mask[i]:
4055                dst2[i] = val
4056        self.assertEqual(dst, dst2, atol=0, rtol=0)
4057
4058        # test non-contiguous case
4059        dst = ((torch.randn(num_dest, num_dest, num_dest) * 10).to(dtype)).permute((2, 0, 1))
4060        dst2 = dst.contiguous()
4061        if dtype.is_complex:
4062            mask = dst.abs() > 0
4063        else:
4064            mask = dst > 0
4065        self.assertTrue(not dst.is_contiguous())
4066        self.assertTrue(dst2.is_contiguous())
4067        dst.masked_fill_(mask.to(mask_dtype), val)
4068        dst2.masked_fill_(mask.to(mask_dtype), val)
4069        self.assertEqual(dst, dst2, atol=0, rtol=0)
4070
4071    # FIXME: find a test suite for the masked fill operator
4072    def test_masked_fill_bool_tensor(self, device):
4073        dst = torch.tensor([True, False, True], device=device)
4074        mask = torch.tensor([False, True, False], device=device)
4075
4076        dst.masked_fill_(mask, True)
4077        self.assertEqual(dst, torch.tensor([True, True, True], device=device))
4078
4079        dst = dst.masked_fill(mask, False)
4080        self.assertEqual(dst, torch.tensor([True, False, True], device=device))
4081
4082    def test_tensor_shape_empty(self, device):
4083        x = torch.randn((0, 1, 3, 0), device=device)
4084        # flatten
4085        self.assertEqual((0,), torch.flatten(x, 0, 3).shape)
4086        self.assertEqual((0, 0), torch.flatten(x, 0, 2).shape)
4087        self.assertEqual((0, 3, 0), torch.flatten(x, 1, 2).shape)
4088
4089        # squeeze, unsqueeze
4090        self.assertEqual((0, 1, 1, 3, 0), torch.unsqueeze(x, 1).shape)
4091        self.assertEqual((0, 3, 0), torch.squeeze(x, 1).shape)
4092        self.assertEqual((0, 3, 0), torch.squeeze(x).shape)
4093
4094        # transpose, t
4095        self.assertEqual((0, 0, 3, 1), torch.transpose(x, 1, 3).shape)
4096        y = torch.randn((5, 0), device=device)
4097        self.assertEqual((0, 5), y.t().shape)
4098
4099        # select
4100        self.assertEqual((0, 1, 0), torch.select(x, 2, 2).shape)
4101
4102        # repeat, permute
4103        self.assertEqual((9, 0, 5, 6, 0), x.repeat(9, 7, 5, 2, 3).shape)
4104        self.assertEqual((3, 0, 0, 1), x.permute(2, 3, 0, 1).shape)
4105
4106        # diagonal, diagflat
4107        self.assertEqual((0,), torch.diagonal(torch.randn((5, 0), device=device)).shape)
4108        self.assertEqual((0,), torch.diagonal(torch.randn((0, 5), device=device)).shape)
4109        # offsets that run off the end are valid
4110        self.assertEqual((0,), torch.diagonal(torch.randn((5, 0), device=device), offset=1).shape)
4111        self.assertEqual((0,), torch.diagonal(torch.randn((0, 5), device=device), offset=1).shape)
4112        # check non-zero sized offsets off the end
4113        self.assertEqual((5, 6, 0), torch.diagonal(torch.randn((3, 4, 5, 6), device=device), offset=45252).shape)
4114        self.assertEqual((5, 6, 0), torch.diagonal(torch.randn((3, 4, 5, 6), device=device), offset=-45252).shape)
4115
4116        self.assertEqual((0, 0), torch.diagflat(torch.tensor([], device=device)).shape)
4117        self.assertEqual(torch.zeros(1, 1), torch.diagflat(torch.tensor([], device=device), offset=1))
4118        self.assertEqual((0, 0), torch.diagflat(torch.tensor([[]], device=device)).shape)
4119        self.assertEqual(torch.zeros(1, 1), torch.diagflat(torch.tensor([[]], device=device), offset=1))
4120
4121        # stack, split, chunk
4122        self.assertEqual((4, 0, 1, 3, 0), torch.stack((x, x, x, x)).shape)
4123        self.assertEqual([(0, 1, 3, 0)],
4124                         [z.shape for z in torch.chunk(x, 1, dim=0)])
4125
4126        self.assertEqual([(0, 1, 3, 0), ] * 3, [z.shape for z in torch.chunk(x, 3, dim=0)])
4127        self.assertEqual([(0, 1, 1, 0), ] * 3, [z.shape for z in torch.chunk(x, 3, dim=2)])
4128
4129        # NOTE: split_with_sizes behaves differently from NumPy in that it
4130        # takes sizes rather than offsets (see the sketch after this test)
4131        self.assertEqual([(0, 1, 0, 0), (0, 1, 1, 0), (0, 1, 2, 0)],
4132                         [z.shape for z in torch.split(x, (0, 1, 2), dim=2)])
4133
4134        self.assertRaises(RuntimeError, lambda: torch.split(x, 0, dim=1))
4135        # This is strange because the split size is larger than the dim size, but consistent with
4136        # how split handles that case generally (when no 0s are involved).
4137        self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 1, dim=0)])
4138        self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 0, dim=0)])
4139
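    # The sketch referenced in the NOTE above (illustrative values): torch.split /
    # split_with_sizes take a list of chunk sizes, whereas numpy.split takes the
    # offsets at which to cut, so the same partition is spelled differently.
    def _sketch_split_sizes_vs_numpy_offsets(self):
        x = torch.arange(6)
        torch_chunks = torch.split(x, (1, 2, 3))    # chunk sizes 1, 2 and 3
        numpy_chunks = np.split(x.numpy(), (1, 3))  # cut at offsets 1 and 3
        # both produce pieces [0], [1, 2] and [3, 4, 5]
        return torch_chunks, numpy_chunks
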
4140    # functions that operate over a dimension but don't reduce.
4141    def test_dim_function_empty(self, device):
4142        shape = (0, 1, 2, 0)
4143        x = torch.randn(shape, device=device)
4144
4145        # size stride
4146        self.assertEqual(0, x.size(3))
4147        self.assertEqual(2, x.size(2))
4148        self.assertEqual(2, x.stride(0))
4149        self.assertEqual(1, x.stride(2))
4150
4151        self.assertEqual(x, torch.nn.functional.glu(x, 0))
4152        self.assertEqual((0, 1, 1, 0), torch.nn.functional.glu(x, 2).shape)
4153
4154        # softmax, logsoftmax
4155        self.assertEqual(x, torch.nn.functional.softmax(x, 0))
4156        self.assertEqual(x, torch.nn.functional.softmax(x, 2))
4157        self.assertEqual(x, torch.nn.functional.softmax(x, 3))
4158
4159        self.assertEqual(x, torch.nn.functional.log_softmax(x, 0))
4160        self.assertEqual(x, torch.nn.functional.log_softmax(x, 2))
4161        self.assertEqual(x, torch.nn.functional.log_softmax(x, 3))
4162
4163        # cumsum, cumprod, cummax, cummin
4164        self.assertEqual(shape, torch.cumsum(x, 0).shape)
4165        self.assertEqual(shape, torch.cumsum(x, 2).shape)
4166        self.assertEqual(shape, torch.cumprod(x, 0).shape)
4167        self.assertEqual(shape, torch.cumprod(x, 2).shape)
4168        self.assertEqual(shape, torch.cummax(x, 0)[0].shape)
4169        self.assertEqual(shape, torch.cummax(x, 2)[0].shape)
4170        self.assertEqual(shape, torch.cummin(x, 0)[0].shape)
4171        self.assertEqual(shape, torch.cummin(x, 2)[0].shape)
4172        self.assertEqual(shape, torch.logcumsumexp(x, 0).shape)
4173        self.assertEqual(shape, torch.logcumsumexp(x, 2).shape)
4174
4175        # flip
4176        self.assertEqual(x, x.flip(0))
4177        self.assertEqual(x, x.flip(2))
4178
4179        # roll
4180        self.assertEqual(x, x.roll(0, 1).roll(0, -1))
4181        self.assertEqual(x, x.roll(1, x.size(1)))
4182        self.assertEqual(x, x.roll(1))
4183        self.assertEqual(x, x.roll((1, 1), (3, 1)))
4184
4185        # unbind
4186        self.assertEqual((), x.unbind(0))
4187        self.assertEqual((torch.empty((0, 1, 0), device=device), torch.empty((0, 1, 0), device=device)),
4188                         x.unbind(2))
4189
4190        # cross
4191        y = torch.randn((0, 1, 3, 0), device=device)
4192        self.assertEqual(y.shape, torch.cross(y, y).shape)
4193
4194        # renorm
4195        self.assertEqual(shape, torch.renorm(x, 1, 0, 5).shape)
4196        self.assertEqual(shape, torch.renorm(x, 1, 2, 5).shape)
4197
4198        # sort
4199        self.assertEqual([shape, shape], [z.shape for z in torch.sort(x, dim=0)])
4200        self.assertEqual([shape, shape], [z.shape for z in torch.sort(x, dim=2)])
4201
4202        # topk
4203        self.assertEqual([shape, shape], [z.shape for z in torch.topk(x, 0, dim=0)])
4204        self.assertEqual([(0, 1, 1, 0), (0, 1, 1, 0)], [z.shape for z in torch.topk(x, 1, dim=2)])
4205
4206        y = torch.randn((2, 3, 4), device=device)
4207        self.assertEqual([(2, 3, 0), (2, 3, 0)], [z.shape for z in torch.topk(y, 0)])
4208
4209        # gather
4210        self.assertEqual(shape, torch.gather(x, 0, torch.empty(shape, dtype=torch.int64, device=device)).shape)
4211        self.assertEqual(shape, torch.gather(x, 2, torch.empty(shape, dtype=torch.int64, device=device)).shape)
4212        larger_shape = torch.empty((0, 1, 3, 0), dtype=torch.int64, device=device)
4213        self.assertEqual(larger_shape.shape, torch.gather(x, 2, larger_shape).shape)
4214        smaller_shape = torch.empty((0, 1, 0, 0), dtype=torch.int64, device=device)
4215        self.assertEqual(smaller_shape.shape, torch.gather(x, 2, smaller_shape).shape)
4216        y = torch.randn((2, 3, 4), device=device)
4217        self.assertEqual((0, 3, 4),
4218                         torch.gather(y, 0, torch.empty((0, 3, 4), dtype=torch.int64, device=device)).shape)
4219
4220        # scatter, scatter_add
4221        for dim in [0, 2]:
4222            y = torch.randn(shape, device=device)
4223            y_src = torch.randn(shape, device=device)
4224            ind = torch.empty(shape, dtype=torch.int64, device=device)
4225            self.assertEqual(shape, y.scatter_(dim, ind, y_src).shape)
4226            self.assertEqual(shape, y.scatter_add_(dim, ind, y_src).shape)
4227
4228        z = torch.randn((2, 3, 4), device=device)
4229        z_src = torch.randn((2, 3, 4), device=device)
4230        self.assertEqual(z, z.scatter_(2, torch.empty((2, 3, 0), dtype=torch.int64, device=device), z_src))
4231        self.assertEqual(z, z.scatter_add_(2, torch.empty((2, 3, 0), dtype=torch.int64, device=device), z_src))
4232
4233        # index_fill, index_copy, index_add
4234        c = x.clone()
4235        c_clone = c.clone()
4236        ind_empty = torch.tensor([], dtype=torch.int64, device=device)
4237        ind_01 = torch.tensor([0, 1], dtype=torch.int64, device=device)
4238        self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1))
4239        self.assertEqual(c_clone, c.index_fill_(2, ind_empty, -1))
4240        self.assertEqual(c_clone, c.index_fill_(2, ind_01, -1))
4241        self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2, 0), device=device)))
4242        self.assertEqual(c_clone, c.index_copy_(2, ind_empty, torch.empty((0, 1, 0, 0), device=device)))
4243        self.assertEqual(c_clone, c.index_copy_(2, ind_01, torch.empty((0, 1, 2, 0), device=device)))
4244        self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2, 0), device=device)))
4245        self.assertEqual(c_clone, c.index_add_(2, ind_empty, torch.empty((0, 1, 0, 0), device=device)))
4246        self.assertEqual(c_clone, c.index_add_(2, ind_01, torch.empty((0, 1, 2, 0), device=device)))
4247
4248        c = torch.randn((0, 1, 2), device=device)
4249        c_clone = c.clone()
4250        self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1))
4251        self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2), device=device)))
4252        self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2), device=device)))
4253        self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1))
4254        self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2), device=device)))
4255        self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2), device=device)))
4256
4257        # index fill/copy/add non-empty
4258        z = torch.randn((2, 3, 4), device=device)
4259        self.assertEqual(z, z.index_fill_(0, ind_empty, -1))
4260        z = torch.randn((2, 3, 4), device=device)
4261        self.assertEqual(z, z.index_copy_(0, ind_empty, torch.empty((0, 3, 4), device=device)))
4262        z = torch.randn((2, 3, 4), device=device)
4263        self.assertEqual(z, z.index_add_(0, ind_empty, torch.empty((0, 3, 4), device=device)))
4264
4265        # index_select
4266        self.assertEqual(x, x.index_select(0, ind_empty))
4267        self.assertEqual((0, 1, 0, 0), x.index_select(2, ind_empty).shape)
4268        self.assertEqual(x, x.index_select(2, ind_01))
4269        z = torch.randn((2, 3, 4), device=device)  # non-empty
4270        self.assertEqual((0, 3, 4), z.index_select(0, ind_empty).shape)
4271        c = torch.randn((0, 1, 2), device=device)
4272        self.assertEqual(c, c.index_select(0, ind_empty))
4273        c = torch.randn((0, 1, 2), device=device)
4274        self.assertEqual(c, c.index_select(0, ind_empty))
4275        w = torch.randn((0, 3), device=device)
4276        self.assertEqual((0, 2), w.index_select(1, ind_01).shape)
4277        w = torch.randn((3, 0), device=device)
4278        self.assertEqual((2, 0), w.index_select(0, ind_01).shape)
4279        ind_01_int32 = torch.tensor([0, 1], dtype=torch.int32, device=device)
4280        self.assertEqual((2, 0), w.index_select(0, ind_01_int32).shape)
4281        s = torch.randn([], device=device)
4282        ind_0 = torch.tensor([0], dtype=torch.int32, device=device)
4283        self.assertEqual([], s.index_select(0, ind_0).shape)
4284        if device == 'cpu':
4285            w = torch.randn((0, 3), device=device)
4286            with self.assertRaisesRegex(RuntimeError, "self indexing axis dim should be positive"):
4287                torch.index_select(w, 0, ind_01)
4288            ind_05 = torch.tensor([0, 5], dtype=torch.int64, device=device)
4289            with self.assertRaisesRegex(RuntimeError, "INDICES element is out of DATA bounds"):
4290                torch.index_select(w, 1, ind_05)
4291            with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"):
4292                torch.index_select(s, 0, ind_empty)
4293        with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"):
4294            torch.ones([]).index_select(0, torch.Tensor([0, 0]).int())
4295
4296    # FIXME: find a test suite for the pdist operator
4297    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "sandcastle OOM with current tpx gpu/re configuration")
4298    @skipIfRocm
4299    @onlyCUDA
4300    @largeTensorTest('32GB', device='cpu')
4301    @largeTensorTest('5GB', device='cuda')
4302    def test_pdist_norm_large(self, device):
4303        # use dim0>=46342 for forward, see:
4304        # https://github.com/pytorch/pytorch/issues/30583
4305        # Compare output using GPU with the CPU implementation
4306        x = torch.randn(50000, 1, dtype=torch.float32)      # 50k * 4 bytes = 200 KB
4307        # Will require 1249975000 float32s
4308        expected_cpu = torch.pdist(x, p=2)                  # ~1250M * 4 bytes = 5 GB on CPU
4309        actual_cpu = torch.pdist(x.to(device), p=2).cpu()         # 5 GB on GPU + 5GB on CPU
4310        # Workaround for large memory overhead of self.assertTrue (see #84944)
4311        self.assertTrue(torch.allclose(expected_cpu, actual_cpu))  # ~20GB in allclose
4312
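    # The arithmetic behind the memory figures above (a sketch; sizes assumed from
    # the comments): pdist over n rows returns the strict upper triangle of the
    # pairwise-distance matrix, i.e. n * (n - 1) / 2 entries.
    def _sketch_pdist_output_size(self, n=50000):
        expected_numel = n * (n - 1) // 2   # 1249975000 for n = 50000
        approx_bytes = expected_numel * 4   # float32 -> roughly 5 GB
        return expected_numel, approx_bytes
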
4313    # FIXME: move to elementwise ternary test suite
4314    @onlyNativeDeviceTypes
4315    @dtypesIfCUDA(*set(get_all_math_dtypes('cuda')))
4316    @dtypes(*set(get_all_math_dtypes('cpu')))
4317    def test_addcdiv(self, device, dtype):
4318        # Returns floating or integral scalar corresponding to dtype
4319        def _number(floating, integer, dtype):
4320            if dtype in [torch.half, torch.float, torch.double, torch.bfloat16]:
4321                return floating
4322            elif dtype in [torch.cfloat, torch.cdouble]:
4323                return floating * (1 + 1j)
4324            else:
4325                return integer
4326
4327        def non_zero_rand(size, dtype, device):
4328            if dtype.is_floating_point or dtype.is_complex:
4329                a = torch.rand(size=size, dtype=dtype, device=device)
4330            elif dtype == torch.uint8:
4331                a = torch.randint(1, 5, size=size, dtype=dtype, device=device)
4332            else:
4333                a = torch.randint(-5, 5, size=size, dtype=dtype, device=device)
4334            return a + (a == 0).to(dtype)
4335
4336        def _test_addcdiv():
4337            a = non_zero_rand((2, 2), dtype=dtype, device=device)
4338            b = non_zero_rand((2, 2), dtype=dtype, device=device)
4339            c = non_zero_rand((2, 2), dtype=dtype, device=device)
4340            alpha = _number(0.5, 3, dtype)
4341
4342            expected = a + (alpha * b) / c
4343            actual = torch.addcdiv(a, b, c, value=alpha)
4344            self.assertEqual(expected, actual)
4345
4346            with self.assertWarnsOnceRegex(
4347                    UserWarning, "This overload of addcdiv is deprecated"):
4348                self.assertEqual(actual, torch.addcdiv(a, alpha, b, c))
4349
4350        if not (dtype.is_floating_point or dtype.is_complex):
4351            # Integer division with addcdiv is prohibited
4352            with self.assertRaises(RuntimeError):
4353                _test_addcdiv()
4354        else:
4355            _test_addcdiv()
4356
4357        if self.device_type == 'cuda' and dtype == torch.half:
4358            a = torch.tensor([60000.0], device=device, dtype=dtype)
4359            b = torch.tensor([60000.0], device=device, dtype=dtype)
4360            c = torch.tensor([1.0], device=device, dtype=dtype)
4361            out = torch.addcdiv(a, b, c, value=-2)
4362            self.assertTrue(not (out.isnan() or out.isinf()))
4363
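    # A minimal sketch of the formula the test builds by hand as a + (alpha * b) / c
    # (values are illustrative): addcdiv computes input + value * tensor1 / tensor2
    # elementwise, which is why integral dtypes are rejected above.
    def _sketch_addcdiv_formula(self, device='cpu'):
        a = torch.tensor([1., 2.], device=device)
        b = torch.tensor([4., 9.], device=device)
        c = torch.tensor([2., 3.], device=device)
        # 1 + 0.5 * 4 / 2 = 2.0 and 2 + 0.5 * 9 / 3 = 3.5
        return torch.addcdiv(a, b, c, value=0.5)
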
4364    def test_nullary_op_mem_overlap(self, device):
4365        ops = (
4366            ("random_", ()),
4367            ("uniform_", ()),
4368            ("cauchy_", ()),
4369            ("log_normal_", ()),
4370            ("exponential_", ()),
4371            ("geometric_", (0.5,)),
4372            ("normal_", ()),
4373        )
4374
4375        x = torch.rand((1, 3)).expand((3, 3))
4376        for op, args in ops:
4377            with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4378                getattr(x, op)(*args)
4379
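    # A minimal sketch of why the in-place ops above must fail (illustrative shape):
    # expand() produces a view whose rows alias the same storage (stride 0 along the
    # expanded dimension), so writing it in place would hit the same memory several
    # times, which PyTorch reports as an 'unsupported operation'.
    def _sketch_expanded_tensor_overlap(self, device='cpu'):
        base = torch.rand(1, 3, device=device)
        expanded = base.expand(3, 3)
        # all three rows alias the single row of base
        return expanded.stride()  # (0, 1)
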
4380    # FIXME: move to an elementwise ternary test suite and make this an OpInfo test
4381    # https://github.com/pytorch/pytorch/issues/126474
4382    @xfailIfTorchDynamo
4383    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/126474")
4384    @dtypes(torch.double)
4385    def test_ternary_op_mem_overlap(self, device, dtype):
4386        if device == "cpu" and TEST_WITH_TORCHINDUCTOR:
4387            self.skipTest("Failing on cpu")
4388
4389        ops = [
4390            ("addcmul", True, True, 'cpu'),
4391            ("addcmul", True, True, 'cuda'),
4392            ("addcdiv", True, True, 'cpu'),
4393            ("addcdiv", True, True, 'cuda'),
4394            ("lerp", True, True, 'cpu'),
4395            ("lerp", True, True, 'cuda')
4396        ]
4397
4398        for (fn, has_input_output_mem_overlap_check,
4399             has_internal_mem_overlap_check, dev) in ops:
4400            if dev != device:
4401                continue
4402            out_op = getattr(torch, fn)
4403            inplace_op = getattr(torch.Tensor, fn + '_')
4404            self.check_internal_mem_overlap(
4405                inplace_op, 3, dtype, device,
4406                expected_failure=not has_internal_mem_overlap_check)
4407            self.ternary_check_input_output_mem_overlap(out_op, dev,
4408                                                        expected_failure=not has_input_output_mem_overlap_check)
4409
4410    @expectedFailureMeta  # RuntimeError not raised
4411    @dtypes(torch.double)
4412    @onlyNativeDeviceTypes
4413    def test_copy_mem_overlap(self, device, dtype):
4414        self.check_internal_mem_overlap(
4415            torch.Tensor.copy_, num_inputs=2, dtype=dtype, device=device)
4416        sz = 9
4417        doubles = torch.randn(2 * sz, dtype=dtype, device=device)
4418        self.unary_check_input_output_mem_overlap(
4419            doubles, sz, lambda input, out: out.copy_(input))
4420
4421    # FIXME: convert to ErrorInputs
4422    # (but have to extend ErrorInputs to handle inplace-only errors!)
4423    @onlyNativeDeviceTypes
4424    def test_index_add_mem_overlap(self, device):
4425        x = torch.rand((1,), device=device).expand((6,))
4426        y = torch.rand((6,), device=device)
4427        ind = torch.tensor([2, 1, 0], device=device)
4428        value = torch.rand((3,), device=device)
4429        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4430            x.index_add_(0, ind, value)
4431        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4432            y.index_add_(0, ind, y[:3])
4433        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4434            ind.index_add_(0, ind, ind.clone())
4435        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4436            ind.index_add_(0, ind.clone(), ind)
4437
4438    # FIXME: convert to ErrorInputs
4439    # (but have to extend ErrorInputs to handle inplace-only errors!)
4440    @onlyNativeDeviceTypes
4441    def test_index_copy_mem_overlap(self, device):
4442        x = torch.rand((1,), device=device).expand((6,))
4443        y = torch.rand((6,), device=device)
4444        ind = torch.tensor([2, 1, 0], device=device)
4445        value = torch.rand((3,), device=device)
4446        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4447            x.index_copy_(0, ind, value)
4448        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4449            y.index_copy_(0, ind, y[:3])
4450        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4451            ind.index_copy_(0, ind, ind.clone())
4452        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4453            ind.index_copy_(0, ind.clone(), ind)
4454
4455    # FIXME: convert to ErrorInputs
4456    # (but have to extend ErrorInputs to handle inplace-only errors!)
4457    @expectedFailureMeta  # Warning not triggered
4458    @onlyNativeDeviceTypes
4459    def test_index_fill_mem_overlap(self, device):
4460        x = torch.rand((1,), device=device).expand((6,))
4461        y = torch.rand((6,), device=device)
4462        ind = torch.tensor([2, 1, 0], device=device)
4463        value = torch.rand((3,), device=device)
4464
4465        with self.assertWarnsRegex(UserWarning, "index_fill_ on expanded tensors"):
4466            x.index_fill_(0, ind, 1.0)
4467        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4468            ind.index_fill_(0, ind, 0)
4469
4470    # FIXME: convert to ErrorInputs
4471    @expectedFailureMeta  # RuntimeError not raised
4472    @onlyNativeDeviceTypes
4473    def test_shift_mem_overlap(self, device):
4474        x = torch.rand(3, device=device)
4475        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4476            x[:-1] <<= x[1:]
4477        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4478            x[:-1] >>= x[1:]
4479
4480    # FIXME: convert to ErrorInputs
4481    # (but have to extend ErrorInputs to handle inplace-only errors)
4482    @expectedFailureMeta  # RuntimeError not raised
4483    @onlyNativeDeviceTypes
4484    def test_bernoulli_mem_overlap(self, device):
4485        x = torch.rand((1,), device=device).expand((6,))
4486
4487        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4488            x.bernoulli_()
4489        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4490            x.bernoulli_(p=0.1)
4491        p = torch.rand(6, device=device)
4492        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4493            x.bernoulli_(p=p)
4494
4495    # FIXME: convert to ErrorInputs
4496    # (but have to extend ErrorInputs to handle inplace-only errors!)
4497    @expectedFailureMeta  # RuntimeError not raised
4498    @onlyNativeDeviceTypes
4499    def test_put_mem_overlap(self, device):
4500        x = torch.rand((1,), device=device).expand((6,))
4501        y = torch.rand((6,), device=device)
4502        ind = torch.tensor([2, 1, 0], device=device)
4503        value = torch.rand((3,), device=device)
4504        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4505            x.put_(ind, value)
4506        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4507            y.put_(ind[0], y[0])
4508        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4509            ind.put_(ind, ind)
4510        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4511            y.put_(ind, y[:3])
4512        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4513            ind.put_(ind, ind.clone())
4514        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4515            ind.put_(ind.clone(), ind)
4516
4517    # FIXME: convert to ErrorInputs
4518    # (but have to extend ErrorInputs to handle inplace-only errors!)
4519    @expectedFailureMeta  # UserWarning not triggered
4520    @onlyNativeDeviceTypes
4521    def test_index_put_mem_overlap(self, device):
4522        x = torch.rand((1,), device=device).expand((6,))
4523        y = torch.rand((6,), device=device)
4524        ind = torch.tensor([2, 1, 0], device=device)
4525        value = torch.rand((3,), device=device)
4526        with self.assertWarnsRegex(UserWarning, 'expanded tensors'):
4527            x.index_put_((ind,), value)
4528        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4529            y.index_put_((ind,), y[0])
4530        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4531            ind.index_put_((ind,), ind)
4532        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4533            y.index_put_((ind,), y[:3])
4534        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4535            ind.index_put_((ind,), ind.clone())
4536        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4537            ind.index_put_((ind.clone(),), ind)
4538
4539    # FIXME: convert to ErrorInputs
4540    # (but have to extend ErrorInputs to handle inplace-only errors!)
4541    @expectedFailureMeta  # UserWarning not triggered
4542    @onlyNativeDeviceTypes
4543    def test_masked_fill_mem_overlap(self, device):
4544        x = torch.rand((1,), device=device).expand((6,))
4545        mask = torch.tensor([True, False, True, True, False, False], device=device)
4546        with self.assertWarnsRegex(UserWarning, 'expanded tensors'):
4547            x.masked_fill_(mask, 0.)
4548
4549        fill_val = torch.tensor(0., device=device)
4550        with self.assertWarnsRegex(UserWarning, 'expanded tensors'):
4551            x.masked_fill_(mask, fill_val)
4552
4553        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4554            mask[1:].masked_fill_(mask[:-1], False)
4555
4556    # FIXME: convert to ErrorInputs
4557    # (but have to extend ErrorInputs to handle inplace-only errors!)
4558    @expectedFailureMeta  # RuntimeError not raised
4559    @onlyNativeDeviceTypes
4560    def test_masked_scatter_mem_overlap(self, device):
4561        x = torch.rand((1,), device=device).expand((6,))
4562        src = torch.rand((3,), device=device)
4563        mask = torch.tensor([True, False, True, True, False, False], device=device)
4564
4565        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4566            x.masked_scatter_(mask, src)
4567
4568    # FIXME: convert to ErrorInputs
4569    # (but have to extend ErrorInputs to handle inplace-only errors!)
4570    @onlyNativeDeviceTypes
4571    def test_scatter_mem_overlap(self, device):
4572        x = torch.rand((1,), device=device).expand((6,))
4573        src = torch.rand((3,), device=device)
4574        ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64)
4575
4576        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4577            x.scatter_(0, ind, src)
4578        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4579            src.scatter_(0, ind, src)
4580        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4581            ind.scatter_(0, ind, ind.clone())
4582
4583    # FIXME: move to test distributions
4584    @onlyCUDA
4585    def test_multinomial_device_constrain(self, device):
4586        x = torch.empty(3, device="cpu")
4587        y = torch.empty(3, device=device)
4588        self.assertRaisesRegex(
4589            RuntimeError, "Expected all tensors to be on the same device",
4590            lambda: torch.multinomial(x, 2, out=y))
4591
4592    # FIXME: move to test distributions
4593    @deviceCountAtLeast(2)
4594    @onlyCUDA
4595    @skipIfTorchInductor("FIXME: error not thrown")
4596    def test_multinomial_gpu_device_constrain(self, devices):
4597        x = torch.empty(3, device=devices[0])
4598        y = torch.empty(3, device=devices[1], dtype=torch.long)
4599        self.assertRaisesRegex(
4600            RuntimeError, "Expected all tensors to be on the same device",
4601            lambda: torch.multinomial(x, 2, out=y))
4602
4603    # FIXME: convert this to an automated OpInfo test
4604    @deviceCountAtLeast(2)
4605    @onlyCUDA
4606    def test_device_guard(self, devices):
4607        # verify that all operators with `device_guard: False` behave properly with multiple devices.
4608        # TODO: if we had operator introspection we could figure out this set of operators automatically...
4609        x = torch.randn((1, 2, 3), device=devices[1])
4610        y = torch.zeros((1, 3, 2), device=devices[1])
4611        scalar = torch.tensor(5, device=devices[1])
4612
4613        # property ops
4614        torch.cudnn_is_acceptable(x)
4615        x.is_distributed()
4616        x.is_floating_point()
4617        x.is_complex()
4618        x.is_same_size(y)
4619        x.is_signed()
4620        x.size(0)
4621        x.stride(0)
4622        x.numel()
4623        x.is_set_to(y)
4624        x.data_ptr()
4625        scalar.is_nonzero()
4626
4627        # sparse property ops
4628        y[0][1] = 5
4629        y_sparse = y.to_sparse()
4630        y_sparse.sparse_dim()
4631        y_sparse._dimI()
4632        y_sparse.dense_dim()
4633        y_sparse._dimV()
4634        y_sparse._nnz()
4635        y_sparse.is_coalesced()
4636        y_sparse._indices()
4637        y_sparse._values()
4638        y_sparse.indices()
4639        y_sparse.values()
4640
4641        # in-place ops
4642        def inplace():
4643            return torch.randn((1, 2, 3), device=devices[1])
4644        inplace().as_strided_(y.size(), y.stride())
4645        inplace().resize_(y.size())
4646        inplace().squeeze_()
4647        inplace().squeeze_(0)
4648        inplace().unsqueeze_(2)
4649        inplace().transpose_(1, 2)
4650        inplace().squeeze_().t_()
4651        inplace().set_(x.storage())
4652        inplace().set_(x.storage(), x.storage_offset(), x.size(), x.stride())
4653        inplace().set_(x)
4654        inplace().set_()
4655        y_sparse._coalesced_(True)
4656
4657        # shape modification
4658        x.as_strided(y.size(), y.stride())
4659        x.expand((5, 2, 3))
4660        x.expand_as(x)
4661        x.sum_to_size((1,))
4662        torch.broadcast_tensors(x, x)
4663        x.reshape((1, 3, 2))
4664        x.reshape_as(y)
4665        x.squeeze()
4666        x.squeeze(0)
4667        x.squeeze().t()
4668        x.transpose(1, 2)
4669        x.unsqueeze(2)
4670        x.view((1, 3, 2))
4671        x.view_as(y)
4672
4673        # chunk, split, etc.
4674        x.chunk(2, dim=1)
4675        x.split(1, dim=2)
4676        x.split_with_sizes([1, 2], dim=2)
4677        x.unfold(dimension=2, size=1, step=1)
4678
4679        x.narrow(1, 1, 1)
4680        x.select(1, 1)
4681        torch.isnan(x)
4682
4683        torch.empty((1, 3, 2), out=y)
4684        torch.empty_like(x)
4685        torch.empty_like(x, dtype=torch.int64)
4686
4687        # to
4688        x.to(x)
4689        x.to(y)
4690        x.to(x, copy=True)
4691
4692    def test_is_signed(self, device):
4693        self.assertEqual(torch.IntTensor(5).to(device).is_signed(), True)
4694        self.assertEqual(torch.ByteTensor(5).to(device).is_signed(), False)
4695        self.assertEqual(torch.CharTensor(5).to(device).is_signed(), True)
4696        self.assertEqual(torch.FloatTensor(5).to(device).is_signed(), True)
4697        self.assertEqual(torch.HalfTensor(10).to(device).is_signed(), True)
4698
4699    def test_tensor_type(self):
4700        for t in torch._tensor_classes:
4701            if 'cuda' in t.__module__:
4702                self.assertEqual(t.is_cuda, True)
4703            else:
4704                self.assertEqual(t.is_cuda, False)
4705            if 'xpu' in t.__module__:
4706                self.assertEqual(t.is_xpu, True)
4707            else:
4708                self.assertEqual(t.is_xpu, False)
4709
4710    # Note - reports a leak of 512 bytes on CUDA device 1
4711    @deviceCountAtLeast(2)
4712    @skipCUDAMemoryLeakCheckIf(True)
4713    @onlyCUDA
4714    def test_tensor_set_errors_multigpu(self, devices):
4715        f_cuda0 = torch.randn((2, 3), dtype=torch.float32, device=devices[0])
4716        f_cuda1 = torch.randn((2, 3), dtype=torch.float32, device=devices[1])
4717
4718        self.assertRaises(RuntimeError, lambda: f_cuda0.set_(f_cuda1.storage()))
4719        self.assertRaises(RuntimeError,
4720                          lambda: f_cuda0.set_(f_cuda1.storage(), 0, f_cuda1.size(), f_cuda1.stride()))
4721        self.assertRaises(RuntimeError, lambda: f_cuda0.set_(f_cuda1))
4722
4723    # FIXME: move to test_serialization
4724    @onlyCUDA
4725    @deviceCountAtLeast(1)  # Note: Test works with one device but prefers more
4726    def test_serialization(self, devices):
4727        def _test_serialization(filecontext_lambda):
4728            t0 = torch.cuda.FloatTensor(5).fill_(1)
4729            with torch.cuda.device(devices[-1]):
4730                tn = torch.cuda.FloatTensor(3).fill_(2)
4731            torch.cuda.set_device(devices[0])
4732            b = (t0, tn)
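            # Round-tripping through torch.save/torch.load should restore each
            # tensor onto the CUDA device it originally lived on, regardless of
            # the current device at load time.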
4733            with filecontext_lambda() as f:
4734                torch.save(b, f)
4735                f.seek(0)
4736                c = torch.load(f)
4737                self.assertEqual(b, c, atol=0, rtol=0)
4738                u0, un = c
4739                self.assertEqual(str(u0.device), devices[0])
4740                self.assertEqual(str(un.device), devices[-1])
4741
4742        _test_serialization(tempfile.NamedTemporaryFile)
4743        _test_serialization(BytesIOContext)
4744
4745    # FIXME: move memory format tests to their own test class/suite
4746    def test_memory_format_preserved_after_permute(self, device):
4747        x = torch.randn(4, 3, 8, 8, device=device)
4748        nhwc = x.contiguous(memory_format=torch.channels_last)
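        # Applying the same permutation twice is the identity, so the strides
        # (and hence the channels-last layout) should be preserved.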
4749        y = nhwc.permute(0, 1, 3, 2).permute(0, 1, 3, 2)
4750        self.assertTrue(y.is_contiguous(memory_format=torch.channels_last))
4751
4752        x = torch.randn(4, 3, 8, 8, 8, device=device)
4753        ndhwc = x.contiguous(memory_format=torch.channels_last_3d)
4754        y = ndhwc.permute(0, 1, 4, 3, 2).permute(0, 1, 4, 3, 2)
4755        self.assertTrue(y.is_contiguous(memory_format=torch.channels_last_3d))
4756
4757    def test_memory_format_propagation_rules(self, device):
4758
4759        contiguous = torch.rand(10, 3, 5, 5, device=device)
4760        cl = torch.rand(10, 3, 5, 5, device=device).contiguous(memory_format=torch.channels_last)
4761        ambiguous = torch.rand(10, 3, 1, 1, device=device).contiguous(memory_format=torch.channels_last)
4762        self.assertTrue(ambiguous.is_contiguous(memory_format=torch.channels_last))
4763        self.assertTrue(ambiguous.is_contiguous(memory_format=torch.contiguous_format))
4764        bias = torch.rand(1, 1, 1, 1, device=device).contiguous(memory_format=torch.channels_last)
4765
4766        def _test_propagation_rules(self, contiguous, cl, ambiguous, bias):
4767            options = ((ambiguous, contiguous, torch.contiguous_format),
4768                       (ambiguous, cl, torch.channels_last),
4769                       (contiguous, ambiguous, torch.contiguous_format),
4770                       (contiguous, cl, torch.contiguous_format),
4771                       (cl, ambiguous, torch.channels_last),
4772                       (cl, contiguous, torch.channels_last),
4773                       (bias, cl, torch.channels_last),
4774                       (cl, bias, torch.channels_last),)
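            # Each (a, b, expected) triple encodes the propagation rule being
            # checked, roughly: an ambiguous or broadcasted (bias) operand defers
            # to the other operand's layout, and when both operands have an
            # unambiguous layout the first operand's layout wins.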
4775
4776            for a, b, mf in options:
4777                result = a + b
4778                self.assertTrue(result.is_contiguous(memory_format=mf))
4779
4780        _test_propagation_rules(self, contiguous, cl, ambiguous, bias)
4781
4782        cl = cl.to(memory_format=torch.channels_last)
4783        ambiguous = ambiguous.to(memory_format=torch.channels_last)
4784        bias = bias.to(memory_format=torch.channels_last)
4785
4786        _test_propagation_rules(self, contiguous, cl, ambiguous, bias)
4787
4788        # test cases when strides matter in ambiguous tensors
4789        for mf in (torch.channels_last, torch.contiguous_format):
4790            ambiguous = torch.rand(10, 3, 1, 1, device=device).to(memory_format=mf)
4791            bias = torch.rand(3, 1, 1, device=device)
4792            result = ambiguous + bias
4793            self.assertEqual(ambiguous.stride(), result.stride())
4794            result = bias + ambiguous
4795            self.assertEqual(ambiguous.stride(), result.stride())
4796            result = ambiguous * 5
4797            self.assertEqual(ambiguous.stride(), result.stride())
4798
4799    @skipIfMps
4800    def test_memory_format_empty_like(self, device):
4801        def test_helper(x, memory_format):
4802            xc = x.contiguous(memory_format=memory_format)
4803
4804            like = torch.empty_like(xc, memory_format=torch.preserve_format)
4805            self.assertFalse(like.is_contiguous())
4806            self.assertTrue(like.is_contiguous(memory_format=memory_format))
4807
4808            like_x = torch.empty_like(x, memory_format=torch.preserve_format)
4809            self.assertTrue(like_x.is_contiguous())
4810            self.assertFalse(like_x.is_contiguous(memory_format=memory_format))
4811
4812            like = torch.empty_like(x, memory_format=memory_format)
4813            self.assertFalse(like.is_contiguous())
4814            self.assertTrue(like.is_contiguous(memory_format=memory_format))
4815
4816            like = torch.empty_like(xc, memory_format=torch.contiguous_format)
4817            self.assertTrue(like.is_contiguous())
4818            self.assertFalse(like.is_contiguous(memory_format=memory_format))
4819
4820            like = torch.empty_like(xc)
4821            self.assertFalse(like.is_contiguous())
4822            self.assertTrue(like.is_contiguous(memory_format=memory_format))
4823
4824            sparse = x.to_sparse()
4825            with self.assertRaises(RuntimeError):
4826                z = torch.empty_like(sparse, memory_format=torch.preserve_format)
4827
4828        test_helper(torch.randn(4, 3, 8, 8, device=device), torch.channels_last)
4829        test_helper(torch.randn(4, 3, 8, 8, 8, device=device), torch.channels_last_3d)
4830
4831    def test_memory_format_consistency(self, device):
4832        x = torch.randn(10, 3, 1, 1, device=device)
4833        x_rep = x.as_strided(x.size(), x.stride())
4834        self.assertEqual(x.size(), x_rep.size())
4835        self.assertEqual(x.stride(), x_rep.stride())
4836        self.assertEqual(x.is_contiguous(), x_rep.is_contiguous())
4837        self.assertEqual(x.is_contiguous(memory_format=torch.channels_last), x_rep.is_contiguous(memory_format=torch.channels_last))
4838        self.assertEqual(
4839            x.is_contiguous(memory_format=torch.channels_last_3d), x_rep.is_contiguous(memory_format=torch.channels_last_3d))
4840
4841    # FIXME: make this an elementwise unary and elementwise binary OpInfo test
4842    def test_memory_format_operators(self, device):
4843        def _chunk_op(x, y):
4844            x1, x2 = x.chunk(2, dim=1)
4845            return x1 + x2
4846
4847        def _unsqueeze_op_add(x, y):
4848            return x[0].unsqueeze(0) + 3
4849
4850        def _unsqueeze_op_clone(x, y):
4851            return x[0].unsqueeze(0).clone()
4852
4853        def _test_helper(x, y, bias, memory_format):
4854            return_contig_fns = [
4855                lambda x, y: y + x,
4856                lambda x, y: y * x,
4857                lambda x, y: y.addcdiv(x, y, value=2),
4858                lambda x, y: y.addcmul(x, y, value=2),
4859            ]
4860            bias_fns = [
4861                lambda x, b: x + b,
4862                lambda x, b: b + x,
4863            ]
4864            fns = [
4865                lambda x, y: x.clone(),
4866                lambda x, y: x + 3,
4867                lambda x, y: 3 * x,
4868                lambda x, y: x + y,
4869                lambda x, y: x * y,
4870                lambda x, y: abs(x),
4871                lambda x, y: x.abs(),
4872                lambda x, y: x.abs_(),
4873                lambda x, y: x.acos(),
4874                lambda x, y: x.acos_(),
4875                lambda x, y: x.add(y, alpha=3),
4876                lambda x, y: x.add_(y, alpha=3),
4877                lambda x, y: x.addcdiv(y, y, value=2),
4878                lambda x, y: x.addcdiv_(y, y, value=2),
4879                lambda x, y: x.addcmul(y, y, value=2),
4880                lambda x, y: x.addcmul_(y, y, value=2),
4881                lambda x, y: x.acosh(),
4882                lambda x, y: x.acosh_(),
4883                lambda x, y: x.asinh(),
4884                lambda x, y: x.asinh_(),
4885                lambda x, y: x.atanh(),
4886                lambda x, y: x.atanh_(),
4887                lambda x, y: x.asin(),
4888                lambda x, y: x.asin_(),
4889                lambda x, y: x.atan(),
4890                lambda x, y: x.atan2(y),
4891                lambda x, y: x.atan2_(y),
4892                lambda x, y: x.ceil(),
4893                lambda x, y: x.ceil_(),
4894                lambda x, y: x.clamp(-1, 1),
4895                lambda x, y: x.cos(),
4896                lambda x, y: x.cosh(),
4897                lambda x, y: x.div(0.5),
4898                lambda x, y: x.div_(0.5),
4899                lambda x, y: x.div(y),
4900                lambda x, y: x.div_(y),
4901                lambda x, y: x.digamma(),
4902                lambda x, y: x.digamma_(),
4903                lambda x, y: x.erf(),
4904                lambda x, y: x.erfc(),
4905                lambda x, y: x.erfinv(),
4906                lambda x, y: x.erfinv_(),
4907                lambda x, y: x.exp(),
4908                lambda x, y: x.expm1(),
4909                lambda x, y: x.expm1_(),
4910                lambda x, y: x.floor(),
4911                lambda x, y: x.floor_(),
4912                lambda x, y: x.fmod(2),
4913                lambda x, y: x.frac(),
4914                lambda x, y: x.hypot(y),
4915                lambda x, y: x.hypot_(y),
4916                lambda x, y: x.i0(),
4917                lambda x, y: x.i0_(),
4918                lambda x, y: x.lerp(y, 0.5),
4919                lambda x, y: x.log(),
4920                lambda x, y: x.log_(),
4921                lambda x, y: x.log10(),
4922                lambda x, y: x.log10_(),
4923                lambda x, y: x.log1p(),
4924                lambda x, y: x.log1p_(),
4925                lambda x, y: x.log2(),
4926                lambda x, y: x.log2_(),
4927                lambda x, y: x.mul(3),
4928                lambda x, y: x.mul_(3),
4929                lambda x, y: x.neg(),
4930                lambda x, y: x.neg_(),
4931                lambda x, y: x.pow(3),
4932                lambda x, y: x.pow_(3),
4933                lambda x, y: x.pow(0.0),
4934                lambda x, y: x.pow(1.0),
4935                lambda x, y: x.reciprocal(),
4936                lambda x, y: x.remainder(2),
4937                lambda x, y: x.round(),
4938                lambda x, y: x.round_(),
4939                lambda x, y: x.rsqrt(),
4940                lambda x, y: x.rsqrt_(),
4941                lambda x, y: x.sigmoid(),
4942                lambda x, y: x.sigmoid_(),
4943                lambda x, y: x.logit(),
4944                lambda x, y: x.logit_(),
4945                lambda x, y: x.logit(1e-6),
4946                lambda x, y: x.logit_(1e-6),
4947                lambda x, y: x.sign(),
4948                lambda x, y: x.sign_(),
4949                lambda x, y: x.sgn(),
4950                lambda x, y: x.sgn_(),
4951                lambda x, y: x.sin(),
4952                lambda x, y: x.sin_(),
4953                lambda x, y: x.sinh(),
4954                lambda x, y: x.sinh_(),
4955                lambda x, y: x.sqrt(),
4956                lambda x, y: x.sqrt_(),
4957                lambda x, y: x.tan(),
4958                lambda x, y: x.tanh(),
4959                lambda x, y: x.trunc(),
4960                lambda x, y: x.trunc_(),
4961                _chunk_op,
4962                _unsqueeze_op_add,
4963                _unsqueeze_op_clone,
4964            ]
4965            x_c = x.contiguous()
4966            y_c = y.contiguous()
4967            b_c = bias.contiguous()
4968            for fn in fns:
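                # Heuristic: in-place methods end with a trailing underscore, so
                # their call sites contain "_(" in the lambda source. Clone the
                # inputs for those so later iterations see unmodified data.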
4969                is_inplace = '_(' in inspect.getsource(fn)
4970                x_clone = x.clone() if is_inplace else x
4971                x_c_clone = x_c.clone() if is_inplace else x_c
4972                result_c = fn(x_c_clone, y_c)
4973                result = fn(x_clone, y)
4974                self.assertEqual(result, result_c, f"Failed for '{inspect.getsource(fn).strip()}'")
4975                self.assertTrue(
4976                    result.is_contiguous(memory_format=memory_format),
4977                    f"result of the '{inspect.getsource(fn).strip()}' is not in '{memory_format}' format")
4978
4979            for fn in bias_fns:
4980                result_c = fn(x_c, b_c)
4981                result = fn(x, bias)
4982                self.assertEqual(result, result_c, f"Failed for '{inspect.getsource(fn).strip()}'")
4983                self.assertTrue(
4984                    result.is_contiguous(memory_format=memory_format),
4985                    f"result of the '{inspect.getsource(fn).strip()}' is not in '{memory_format}' format")
4986
4987            for fn in return_contig_fns:
4988                result_c = fn(x_c, y_c)
4989                result = fn(x, y)
4990                self.assertEqual(result, result_c, f"Failed for '{inspect.getsource(fn).strip()}'")
4991                self.assertTrue(
4992                    result.is_contiguous(memory_format=torch.contiguous_format),
4993                    f"result of the '{inspect.getsource(fn).strip()}' is not in '{torch.contiguous_format}' format")
4994
4995        _test_helper(
4996            torch.randn((4, 3, 8, 8), device=device).contiguous(memory_format=torch.channels_last),
4997            abs(torch.randn((4, 3, 8, 8), device=device)) + 1,
4998            torch.randn((1, 3, 1, 1), device=device).contiguous(memory_format=torch.channels_last),
4999            torch.channels_last)
5000        _test_helper(
5001            torch.randn((4, 3, 8, 8, 8), device=device).contiguous(memory_format=torch.channels_last_3d),
5002            abs(torch.randn((4, 3, 8, 8, 8), device=device)) + 1,
5003            torch.randn((1, 3, 1, 1, 1), device=device).contiguous(memory_format=torch.channels_last_3d),
5004            torch.channels_last_3d)
5005
5006    # FIXME: make this an elementwise unary and elementwise binary OpInfo test
5007    def test_strides_propagation(self, device):
5008        def _test_helper(x, op, unary=False):
5009            def compare_strides(s1, s2, div):
5010                sdiv = [s // div for s in s1]
5011                self.assertEqual(sdiv, s2)
5012
5013            dim = x.dim()
5014            # we produce memory dense outputs, so when input is strided on the last dimension
5015            # we need to divide by that dimension stride to compare input and result strides
5016            div = x.stride(-1)
5017            for p in permutations(range(dim)):
5018                xp = x.permute(p)
5019                if not unary:
5020                    y = torch.randn(xp.size(-1), device=x.device, dtype=x.dtype)
5021                    for inputs in ((xp, xp), (xp, y), (y, xp)):
5022                        res = op(*inputs)
5023                        compare_strides(xp.stride(), res.stride(), div)
5024                        self.assertEqual(xp.size(), res.size())
5025                        out = torch.empty(0, device=xp.device, dtype=res.dtype)
5026                        res = op(*inputs, out=out)
5027                        compare_strides(xp.stride(), res.stride(), div)
5028                        self.assertEqual(xp.size(), res.size())
5029                else:
5030                    res = op(xp)
5031                    compare_strides(xp.stride(), res.stride(), div)
5032                    self.assertEqual(xp.size(), res.size())
5033                    out = torch.empty(0, device=xp.device, dtype=res.dtype)
5034                    res = op(xp, out=out)
5035                    compare_strides(xp.stride(), res.stride(), div)
5036                    self.assertEqual(xp.size(), res.size())
5037
5038        # torch.eq by default calls TensorIterator with defined output, torch.add with undefined
5039        binary_ops = (torch.eq, torch.add)
5040        unary_ops = (torch.exp,)
5041        # memory dense, sliced and ambiguous sliced (ambiguous dense loses permutation information)
5042        xs = (torch.randn(2, 3, 4, device=device), torch.randn(2, 3, 8, device=device)[:, :, ::2],
5043              torch.randn(1, 1, 4, 12, device=device)[:, :, :, ::2])
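        # xs[0] is memory-dense, xs[1] is sliced on the last dim (stride 2), and
        # xs[2] additionally has leading size-1 dims, which makes its layout
        # ambiguous under permutation.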
5044        for op in binary_ops:
5045            for x in xs:
5046                _test_helper(x, op)
5047        for op in unary_ops:
5048            for x in xs:
5049                _test_helper(x, op, unary=True)
5050
5051    @onlyCUDA
5052    @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
5053    @skipIfTorchDynamo("NotImplementedError: PrimTorch does not support pinned memory")
5054    def test_pin_memory_from_constructor(self, device):
5055        def _get_like(t, **kwargs):
5056            return [
5057                torch.rand_like(t, **kwargs),
5058                torch.randn_like(t, **kwargs),
5059                torch.empty_like(t, **kwargs),
5060                torch.full_like(t, 4, **kwargs),
5061                torch.zeros_like(t, **kwargs),
5062                torch.ones_like(t, **kwargs),
5063            ]
5064
5065        def _get_tensors(**kwargs):
5066            return [
5067                torch.tensor([10, 11], **kwargs),
5068                torch.randn(3, 5, **kwargs),
5069                torch.rand(3, **kwargs),
5070                # torch.randint(3, 5, **kwargs),  # unsupported
5071                torch.zeros(3, **kwargs),
5072                torch.randperm(3, **kwargs),
5073                torch.empty(6, **kwargs),
5074                torch.ones(6, **kwargs),
5075                torch.eye(6, **kwargs),
5076                torch.arange(3, 5, **kwargs)]
5077
5078        pinned_tensors = _get_tensors(pin_memory=True) + _get_like(torch.empty(5, dtype=torch.float64), pin_memory=True)
5079        for x in pinned_tensors:
5080            self.assertTrue(x.is_pinned())
5081
5082        tensors = _get_tensors() + _get_like(torch.empty(5, dtype=torch.float64, pin_memory=True))
5083        for x in tensors:
5084            self.assertFalse(x.is_pinned())
5085
5086    @deviceCountAtLeast(1)
5087    @onlyCUDA
5088    def test_storage_all_devices(self, devices):
5089        for device in devices:
5090            t = torch.tensor((), device=device)
5091            self.assertEqual(t.dtype, t.storage().dtype)
5092
5093    # Note [lazy_clone_ tests with inductor enabled]
5094    # These `lazy_clone_` tests are written in a way that makes them pass in
5095    # both eager mode and compiled mode (`PYTORCH_TEST_WITH_INDUCTOR=1`). There
5096    # are cases where COW tensors can materialize at different times and in
5097    # different ways in compiled mode versus eager mode, and those cases need to
5098    # be avoided. There are two main wrinkles to be aware of.
5099    #
5100    # The first wrinkle is that these tests have to check the internal
5101    # properties of tensors to make sure they materialize in the expected way,
5102    # and those checks cause dynamo graph breaks. Depending on the situation, a
5103    # graph break in-between two compiled graphs that operate on the same COW
5104    # tensor can make the tensor materialize when it would not materialize in
5105    # eager mode, causing the checks to fail. The strategy for avoiding this is
5106    # to make all the operations on COW tensors get compiled into the same
5107    # graph, by not doing any checks between the operations, and just do all the
5108    # checks at the end of the test. If we really do want to perform checks
5109    # between two operations, `op1` and `op2`, the solution is to create two
5110    # different tests. One test performs just `op1` and then checks. The other
5111    # test performs `op1` followed immediately by `op2` and then checks.
5112    #
5113    # The second wrinkle is that in eager mode, if we perform writes on two COW
5114    # tensors where one is a lazy clone of the other, the first tensor to be
5115    # written will be materialized with a new data pointer, and the second
5116    # tensor will just reuse the original data pointer when it is materialized.
5117    # But in compiled mode, if these writes happen in the same graph, the order
5118    # in which the tensors materialize can be different than in eager mode. So
5119    # in this case the strategy is to purposefully cause a graph break to happen
5120    # in-between the two write operations, by adding checks between them, so
5121    # that they have to materialize in the expected order.
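    #
    # A minimal eager-mode sketch of the behavior being checked (assumes the
    # private torch._C introspection helpers used in the tests below):
    #
    #   t = torch.tensor([1., 2.])
    #   c = t._lazy_clone()                  # t and c are COW, sharing one buffer
    #   assert torch._C._is_cow_tensor(c)
    #   c += 1                               # first write materializes c
    #   assert not torch._C._is_cow_tensor(c)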
5122    @skipXLA
5123    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
5124    def test_lazy_clone(self, device, dtype):
5125        t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
5126        t_orig_storage_addr = torch._C._storage_address(t)
5127        orig_data_ptr = torch._C._data_address(t)
5128        clone = t._lazy_clone()
5129
5130        # Lazy cloning a tensor should cause both it and its clone to become COW
5131        # tensors. They should have different storages, but the same data
5132        # pointer.
5133
5134        self.assertTrue(torch._C._is_cow_tensor(clone))
5135        self.assertTrue(torch._C._is_cow_tensor(t))
5136
5137        self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
5138        self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)
5139
5140        self.assertTrue(torch._C._data_address(t) == orig_data_ptr)
5141        self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)
5142
5143    # See Note [lazy_clone_ tests with inductor enabled]
5144    @skipXLA
5145    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
5146    def test_lazy_clone_view(self, device, dtype):
5147        t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
5148        t_orig_storage_addr = torch._C._storage_address(t)
5149        orig_data_ptr = torch._C._data_address(t)
5150        clone = t._lazy_clone()
5151        view = t.view([4])
5152
5153        # Viewing `t` should not cause a copy (materialize) to happen. All the
5154        # tensors should still be COW and have the same data pointer. `view` and
5155        # `t` should have the same storage, and `clone` should have a different
5156        # storage.
5157
5158        self.assertTrue(torch._C._is_cow_tensor(t))
5159        self.assertTrue(torch._C._is_cow_tensor(view))
5160        self.assertTrue(torch._C._is_cow_tensor(clone))
5161
5162        self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
5163        self.assertTrue(torch._C._storage_address(view) == t_orig_storage_addr)
5164        self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)
5165
5166        self.assertTrue(torch._C._data_address(t) == orig_data_ptr)
5167        self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)
5168        self.assertTrue(torch._C._data_address(view) == orig_data_ptr)
5169
5170    # See Note [lazy_clone_ tests with inductor enabled]
5171    @skipXLA
5172    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
5173    def test_lazy_clone_view_materialize(self, device, dtype):
5174        t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
5175        t_orig_storage_addr = torch._C._storage_address(t)
5176        orig_data_ptr = torch._C._data_address(t)
5177        clone = t._lazy_clone()
5178        view = t.view([4])
5179        view += torch.ones(1, device=device, dtype=dtype)
5180
5181        # Writing through `view` (which shares `t`'s storage) should cause that
5182        # storage to be copied (materialized), but should not affect `clone`.
5183
5184        self.assertFalse(torch._C._is_cow_tensor(t))
5185        self.assertFalse(torch._C._is_cow_tensor(view))
5186        self.assertTrue(torch._C._is_cow_tensor(clone))
5187
5188        self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
5189        self.assertTrue(torch._C._storage_address(view) == t_orig_storage_addr)
5190        self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)
5191
5192        t_new_data_addr = torch._C._data_address(t)
5193        self.assertTrue(t_new_data_addr != orig_data_ptr)
5194        self.assertTrue(torch._C._data_address(view) == t_new_data_addr)
5195        self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)
5196
5197        clone += torch.ones(1, device=device, dtype=dtype)
5198
5199        # Writing to `clone` should materialize it, so it should no longer
5200        # be COW. However, since `clone`'s storage is the only COW storage
5201        # left that holds a reference to the original data pointer, this
5202        # materialization should not actually cause a copy--it should
5203        # just reuse the original data pointer.
5204
5205        self.assertFalse(torch._C._is_cow_tensor(t))
5206        self.assertFalse(torch._C._is_cow_tensor(view))
5207        self.assertFalse(torch._C._is_cow_tensor(clone))
5208
5209        self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
5210        self.assertTrue(torch._C._storage_address(view) == t_orig_storage_addr)
5211        self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)
5212
5213        self.assertTrue(torch._C._data_address(t) == t_new_data_addr)
5214        self.assertTrue(torch._C._data_address(view) == t_new_data_addr)
5215        self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)
5216
5217    @skipXLA
5218    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
5219    def test_lazy_clone_binary_op_no_materialize(self, device, dtype):
5220        t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
5221        clone = t._lazy_clone()
5222        res = t + clone
5223        self.assertTrue(torch._C._is_cow_tensor(t))
5224        self.assertTrue(torch._C._is_cow_tensor(clone))
5225
5226    # This tests that if a COW materialization is attempted inside an
5227    # `at::parallel_for` loop function, then an error is raised. This test is
5228    # implemented in Python rather than C++ because the C++ tests are built
5229    # without multithreading support in `at::parallel_for`.
5230    @skipXLA
5231    @skipIfTorchDynamo("Torchdynamo fails and we do not need to test it here anyway")
5232    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
5233    def test_parallel_cow_materialize_error(self, device, dtype):
5234
5235        def run(num_threads, num_parallel, skip_first, should_error):
5236            orig_num_threads = torch.get_num_threads()
5237
5238            try:
5239                torch.set_num_threads(num_threads)
5240
5241                a = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)._lazy_clone()
5242
5243                if should_error:
5244                    with self.assertRaisesRegex(RuntimeError, r'Materializing a storage'):
5245                        torch._test_parallel_materialize(
5246                            a, num_parallel, skip_first)
5247                else:
5248                    torch._test_parallel_materialize(a, num_parallel, skip_first)
5249
5250                # Error should not raise in any case if the tensor is not COW
5251                b = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
5252                torch._test_parallel_materialize(b, num_parallel, skip_first)
5253
5254            finally:
5255                torch.set_num_threads(orig_num_threads)
5256
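        # Argument order: run(num_threads, num_parallel, skip_first, should_error)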
5257        run(1, 1, False, True)
5258        run(1, 1, True, False)
5259        run(1, 10, False, True)
5260        run(1, 10, True, True)
5261        run(10, 1, False, True)
5262        run(10, 1, True, False)
5263        run(10, 10, False, True)
5264        run(10, 10, True, True)
5265        run(10, 2, False, True)
5266        run(10, 2, True, True)
5267
5268    # FIXME: move to test distributions
5269    @skipIfMps
5270    @dtypesIfCUDA(torch.float, torch.double, torch.half)
5271    @dtypes(torch.float, torch.double, torch.half)
5272    def test_multinomial(self, device, dtype):
5273        def make_prob_dist(shape, is_contiguous):
5274            if is_contiguous:
5275                if dtype == torch.half:
5276                    return torch.zeros(shape, device=device).uniform_().to(dtype=torch.half)
5277                return torch.zeros(shape, device=device, dtype=dtype).uniform_()
5278            elif len(shape) == 1:
5279                if dtype == torch.half:
5280                    return torch.zeros((shape + [5]), device=device).uniform_().to(dtype=torch.half)[:, 2]
5281                return torch.zeros((shape + [5]), device=device, dtype=dtype).uniform_()[:, 2]
5282            else:
5283                # num dim = 2
5284                new_shape = [2, shape[1], 7, 1, shape[0], 1, 10]
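                # Build a larger tensor, then transpose and index it so the
                # resulting 2D view has the requested shape but is noncontiguous.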
5285                if dtype == torch.half:
5286                    prob_dist = torch.zeros(new_shape, device=device).uniform_().to(dtype=torch.half)
5287                else:
5288                    prob_dist = torch.zeros(new_shape, device=device, dtype=dtype).uniform_()
5289                prob_dist = prob_dist.transpose(1, 4)
5290                prob_dist = prob_dist[1, :, 5, 0, :, 0, 4]
5291                assert not prob_dist.is_contiguous()  # sanity check
5292                return prob_dist
5293
5294        for is_contiguous in (True, False):
5295            # with replacement
5296            n_row = 3
5297            for n_col in range(4, 5 + 1):
5298                prob_dist = make_prob_dist([n_row, n_col], is_contiguous)
5299                # indices that shouldn't be sampled (<0 means none)
5300                zero_prob_indices = torch.LongTensor(n_row).random_(-2, n_col).tolist()
5301                for i, j in enumerate(zero_prob_indices):
5302                    if j >= 0:
5303                        prob_dist[i, j] = 0
5304                n_sample = n_col * 3
5305                sample_indices = torch.multinomial(prob_dist, n_sample, True)
5306                self.assertEqual(prob_dist.dim(), 2)
5307                self.assertEqual(sample_indices.size(1), n_sample)
5308                for i in range(n_row):
5309                    zero_prob_idx = zero_prob_indices[i]
5310                    if zero_prob_idx < 0:
5311                        continue
5312                    for j in range(n_sample):
5313                        self.assertNotEqual(sample_indices[i, j], zero_prob_idx,
5314                                            msg="sampled an index with zero probability")
5315
5316            # without replacement
5317            n_row = 3
5318            for n_col in range(2, 10 + 1, 2):
5319                prob_dist = make_prob_dist([n_row, n_col], is_contiguous)
5320                # indices that shouldn't be sampled (<0 means none)
5321                zero_prob_indices = torch.LongTensor(n_row).random_(-1, n_col).tolist()
5322                for i, j in enumerate(zero_prob_indices):
5323                    if j >= 0:
5324                        prob_dist[i, j] = 0
5325                n_sample = max(1, n_col - 2)
5326                sample_indices = torch.multinomial(prob_dist, n_sample, False)
5327                self.assertEqual(prob_dist.dim(), 2)
5328                self.assertEqual(sample_indices.size(1), n_sample)
5329                for i in range(n_row):
5330                    row_samples = {}
5331                    zero_prob_idx = zero_prob_indices[i]
5332                    for j in range(n_sample):
5333                        sample_idx = sample_indices[i, j]
5334                        if zero_prob_idx >= 0:
5335                            self.assertNotEqual(sample_idx, zero_prob_idx,
5336                                                msg="sampled an index with zero probability")
5337                        self.assertNotIn(sample_idx, row_samples, "sampled an index twice")
5338                        row_samples[sample_idx] = True
5339
5340            # vector
5341            n_col = 4
5342            prob_dist = make_prob_dist([n_col], is_contiguous).fill_(1)
5343            zero_prob_idx = 1  # index that shouldn't be sampled
5344            prob_dist[zero_prob_idx] = 0
5345            n_sample = 20
5346            sample_indices = torch.multinomial(prob_dist, n_sample, True)
5347            for sample_index in sample_indices:
5348                self.assertNotEqual(sample_index, zero_prob_idx, msg="sampled an index with zero probability")
5349            s_dim = sample_indices.dim()
5350            self.assertEqual(s_dim, 1, msg="wrong number of dimensions")
5351            self.assertEqual(prob_dist.dim(), 1, msg="wrong number of prob_dist dimensions")
5352            self.assertEqual(sample_indices.size(0), n_sample, msg="wrong number of samples")
5353
5354        # CUDA misalignment issue (#46702)
5355        n_row, n_col = 2, 3
5356        prob_dist = make_prob_dist([n_row, n_col], True)
5357        n_sample = 1
5358        sample_indices = torch.multinomial(prob_dist, n_sample, True)
5359        self.assertEqual(sample_indices.dim(), 2, msg="wrong number of dimensions")
5360        self.assertEqual(sample_indices.size(1), n_sample, msg="wrong number of samples")
5361
5362    # FIXME: move to test distributions
5363    @onlyCUDA
5364    @dtypes(torch.float, torch.double, torch.half)
5365    def test_multinomial_deterministic(self, device, dtype):
5366        gen = torch.Generator(device=device)
5367
5368        trials = 5
5369        seed = 0
5370        prob_dist = torch.rand(10000, 1000, device=device, dtype=dtype)
5371        n_sample = 1
5372
5373        for i in range(trials):
5374            gen.manual_seed(seed)
5375            samples_1 = torch.multinomial(prob_dist, n_sample, True, generator=gen)
5376
5377            gen.manual_seed(seed)
5378            samples_2 = torch.multinomial(prob_dist, n_sample, True, generator=gen)
5379
5380            self.assertEqual(samples_1, samples_2)
5381            self.assertEqual(samples_1.dim(), 2, msg="wrong number of dimensions")
5382            self.assertEqual(samples_1.size(1), n_sample, msg="wrong number of samples")
5383
5384    # FIXME: move to test distributions
5385    @slowTest
5386    @dtypes(torch.float)
5387    def test_multinomial_rng_state_advance(self, device, dtype):
5388        corpus_size = 100000
5389        freqs = torch.ones(corpus_size, dtype=torch.float, device=device)
5390        n_sample = 100
5391        samples1 = torch.multinomial(freqs, n_sample, replacement=True)
5392        samples2 = torch.multinomial(freqs, n_sample, replacement=True)
5393        samples = torch.cat([samples1, samples2])
5394        # expect no more than 1 repeated element generated in 2 attempts
5395        # the probability of at least one element being repeated is surprisingly large, ~18%
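        # (a birthday-problem estimate: 1 - exp(-(200 * 199 / 2) / 100000) ~= 0.18)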
5396        self.assertLessEqual(2 * n_sample - samples.unique().size(0), 2)
5397        samples1 = torch.multinomial(freqs, n_sample, replacement=False)
5398        samples2 = torch.multinomial(freqs, n_sample, replacement=False)
5399        samples = torch.cat([samples1, samples2])
5400        # expect no more than 1 repeated element generated in 2 attempts
5401        self.assertLessEqual(2 * n_sample - samples.unique().size(0), 1)
5402
5403    def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn,
5404                                            memory_format, compare_data=True, default_is_preserve=False):
5405
5406        assert memory_format == torch.channels_last or memory_format == torch.channels_last_3d
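        # The helper checks three things: `memory_format=torch.preserve_format`
        # keeps the channels-last layout, `torch.contiguous_format` forces a
        # contiguous result, and the default behavior depends on
        # `default_is_preserve`.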
5407
5408        # xc is a channels last tensor
5409        xc = input_generator_fn(device)
5410        # xc is not memory dense, but looks like channels last
5411        # We don't preserve non-dense striding
5412        if not TEST_WITH_TORCHINDUCTOR:
5413            if memory_format == torch.channels_last:
5414                xc = xc[..., ::2, ::2]
5415            else:
5416                xc = xc[..., ::2, ::2, ::2]
5417
5418        clone = transformation_fn(xc, memory_format=torch.preserve_format)
5419
5420
5421        self.assertFalse(clone.is_contiguous())
5422        self.assertTrue(clone.is_contiguous(memory_format=memory_format))
5423        if not TEST_WITH_TORCHINDUCTOR:
5424            self.assertFalse(xc.is_contiguous())
5425            self.assertFalse(xc.is_contiguous(memory_format=memory_format))
5426        if compare_data:
5427            self.assertEqual(xc, clone.to(xc))
5428
5429        xc = input_generator_fn(device)
5430        clone = transformation_fn(xc, memory_format=torch.contiguous_format)
5431        self.assertTrue(clone.is_contiguous())
5432        self.assertFalse(clone.is_contiguous(memory_format=memory_format))
5433        if compare_data:
5434            self.assertEqual(xc, clone.to(xc))
5435
5436        xc = input_generator_fn(device)
5437        clone = transformation_fn(xc)
5438
5439        if default_is_preserve:
5440            self.assertFalse(clone.is_contiguous())
5441            self.assertTrue(clone.is_contiguous(memory_format=memory_format))
5442        else:
5443            self.assertTrue(clone.is_contiguous())
5444            self.assertFalse(clone.is_contiguous(memory_format=memory_format))
5445        if compare_data:
5446            self.assertEqual(xc, clone.to(xc))
5447
5448        # TODO: make the _like constructors copy the stride permutation instead of just the layout
5449        if not TEST_WITH_TORCHINDUCTOR:
5450            x = torch.randn((3, 4, 5, 6, 7, 8, 9), device=device)
5451            for i in range(10):
5452                permutation = list(range(len(x.shape)))
5453                random.shuffle(permutation)
5454                x = x.permute(permutation)
5455                self.assertEqual(x.stride(), transformation_fn(x, memory_format=torch.preserve_format).stride())
5456
5457    def test_memory_format_to(self, device):
5458        def get_generator(memory_format, shape):
5459            def input_generator_fn(device):
5460                return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
5461            return input_generator_fn
5462
5463        def transformation_fn(tensor, **kwargs):
5464            return tensor.to(dtype=torch.float64, **kwargs)
5465
5466        formats_shapes = (
5467            (torch.channels_last, (4, 3, 8, 8)),
5468            (torch.channels_last_3d, (4, 3, 8, 8, 8)))
5469
5470        for mf, shape in formats_shapes:
5471            self._test_memory_format_transformations(
5472                device, get_generator(mf, shape), transformation_fn, mf, default_is_preserve=True)
5473
5474    def test_memory_format_type(self, device):
5475        def get_generator(memory_format, shape):
5476            def input_generator_fn(device):
5477                return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
5478            return input_generator_fn
5479
5480        def transformation_fn(tensor, **kwargs):
5481            return tensor.to(torch.float64, **kwargs)
5482
5483        formats_shapes = (
5484            (torch.channels_last, (4, 3, 8, 8)),
5485            (torch.channels_last_3d, (4, 3, 8, 8, 8)))
5486
5487        for mf, shape in formats_shapes:
5488            self._test_memory_format_transformations(
5489                device, get_generator(mf, shape), transformation_fn, mf, default_is_preserve=True)
5490
5491    def test_memory_format_clone(self, device):
5492        def get_generator(memory_format, shape):
5493            def input_generator_fn(device):
5494                return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
5495            return input_generator_fn
5496
5497        def transformation_fn(tensor, **kwargs):
5498            return tensor.clone(**kwargs)
5499
5500        formats_shapes = (
5501            (torch.channels_last, (4, 3, 8, 8)),
5502            (torch.channels_last_3d, (4, 3, 8, 8, 8)))
5503
5504        for mf, shape in formats_shapes:
5505            self._test_memory_format_transformations(
5506                device, get_generator(mf, shape), transformation_fn, mf, True, default_is_preserve=True)
5507
5508    def test_memory_format_factory_like_functions_preserve(self, device):
5509        def get_generator(memory_format, shape):
5510            def input_generator_fn(device):
5511                return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
5512            return input_generator_fn
5513
5514        transformation_fns = [
5515            lambda t, **kwargs: torch.zeros_like(t, **kwargs),
5516            lambda t, **kwargs: torch.ones_like(t, **kwargs),
5517            lambda t, **kwargs: torch.randint_like(t, 10, 100, **kwargs),
5518            lambda t, **kwargs: torch.randint_like(t, 100, **kwargs),
5519            lambda t, **kwargs: torch.randn_like(t, **kwargs),
5520            lambda t, **kwargs: torch.rand_like(t, **kwargs),
5521            lambda t, **kwargs: torch.full_like(t, 7, **kwargs),
5522            lambda t, **kwargs: torch.empty_like(t, **kwargs)]
5523
5524        formats_shapes = (
5525            (torch.channels_last, (4, 3, 8, 8)),
5526            (torch.channels_last_3d, (4, 3, 8, 8, 8)))
5527
5528        for mf, shape, in formats_shapes:
5529            for transformation_fn in transformation_fns:
5530                self._test_memory_format_transformations(
5531                    device, get_generator(mf, shape), transformation_fn, mf, compare_data=False, default_is_preserve=True)
5532
5533    def test_memory_format_type_shortcuts(self, device):
5534        def get_generator(memory_format, shape, dtype):
5535            def input_generator_fn(device):
5536                return torch.randn(shape, device=device, dtype=dtype).clamp(0, 1) \
5537                    .round().contiguous(memory_format=memory_format)
5538            return input_generator_fn
5539
5540
5541        def get_fn(fn_name):
5542            def transformation_fn(tensor, **kwargs):
5543                fn = getattr(tensor, fn_name)
5544                return fn(**kwargs)
5545            return transformation_fn
5546
5547        shortcuts = ['byte', 'char', 'double', 'bool', 'half', 'int', 'long', 'short']
5548        if device == 'cpu':
5549            shortcuts += ['bfloat16']
5550
5551        formats_shapes = (
5552            (torch.channels_last, (4, 3, 8, 8)),
5553            (torch.channels_last_3d, (4, 3, 8, 8, 8)))
5554
5555        for mf, shape in formats_shapes:
5556            for fn_name in shortcuts:
5557                self._test_memory_format_transformations(
5558                    device, get_generator(mf, shape, torch.float32), get_fn(fn_name), mf, default_is_preserve=True)
5559
5560        # Test 'float' separately to avoid float->float no-op.
5561        for mf, shape in formats_shapes:
5562            self._test_memory_format_transformations(
5563                device, get_generator(mf, shape, torch.float64), get_fn('float'), mf, default_is_preserve=True)
5564
5565    @onlyCUDA
5566    def test_memory_format_cpu_and_cuda_ops(self, device):
5567        def get_generator(memory_format, shape):
5568            def input_generator_fn(device):
5569                return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
5570            return input_generator_fn
5571
5572        def transformation_cpu_fn(tensor, **kwargs):
5573            return tensor.cpu(**kwargs)
5574
5575        def transformation_cuda_fn(tensor, **kwargs):
5576            return tensor.cuda(**kwargs)
5577
5578        formats_shapes = (
5579            (torch.channels_last, (4, 3, 8, 8)),
5580            (torch.channels_last_3d, (4, 3, 8, 8, 8)))
5581
5582        for mf, shape in formats_shapes:
5583            self._test_memory_format_transformations(
5584                'cuda', get_generator(mf, shape), transformation_cpu_fn, mf, default_is_preserve=True)
5585            self._test_memory_format_transformations(
5586                'cpu', get_generator(mf, shape), transformation_cuda_fn, mf, default_is_preserve=True)
5587
5588    # FIXME: move to test_serialization
5589    @onlyNativeDeviceTypes
5590    def test_pickle_gradscaler(self, device):
5591        # This test should pass in 3 cases for cuda:
5592        #  1. cuda is not available.
5593        #  2. cuda is available but device is not cuda.
5594        #  3. cuda is available and device is cuda.
5595        # In case 1, a and b disable themselves on construction and shouldn't try to pickle workhorse attributes.
5596        # In case 2, a and b are enabled.  Workhorse attributes participate in pickling, but none are lazy-inited
5597        # to cuda Tensors, because I don't want to do cuda things if device is not cuda.
5598        # In case 3, a and b are enabled and we may also try lazy-initing _scale to a cuda tensor.
5599        device = torch.device(device)
5600        try_lazy_inits = (True, False)
5601        GradScaler = partial(torch.GradScaler, device=device.type)
5602        for lazy_init_scale in try_lazy_inits:
5603            a = GradScaler(init_scale=3., growth_factor=4., backoff_factor=.5, growth_interval=2)
5604            if device.type == "cuda":
5605                self.assertTrue(not a.is_enabled() if torch.cuda.amp.common.amp_definitely_not_available() else a.is_enabled())
5606            else:
5607                self.assertTrue(a.is_enabled())
5608            if lazy_init_scale:
5609                # Dummy a.scale() call lazy-inits a._scale Tensor.
5610                a.scale(torch.tensor([4.0], dtype=torch.float32, device=device))
5611                self.assertTrue(a._scale.device.type == device.type)
5612            # The following three lines should work whether or not cuda is available.
5613            serialized = pickle.dumps(a)
5614            b = pickle.loads(serialized)
5615            self.assertEqual(b.is_enabled(), a.is_enabled())
5616            if a.is_enabled():
5617                self.assertEqual(b.get_scale(), 3.)
5618                self.assertEqual(b.get_growth_factor(), 4.)
5619                self.assertEqual(b.get_backoff_factor(), .5)
5620                self.assertEqual(b.get_growth_interval(), 2)
5621                self.assertEqual(b._init_growth_tracker, 0)
5622                # supplies a dummy key to test the defaultdict's default_factory
5623                self.assertEqual(b._per_optimizer_states["fdsa"],
5624                                 torch.amp.grad_scaler._refresh_per_optimizer_state())
5625                if lazy_init_scale:
5626                    self.assertEqual(b.scale(torch.tensor([4.0], dtype=torch.float32, device=device)), 12.0)
5627
5628    # FIXME: move to test distributions
5629    def _test_multinomial_empty(self, device, replacement, num_samples):
5630        probs = torch.ones(0, 3, device=device)
5631        expected = torch.empty(0, num_samples, dtype=torch.int64)
5632        out = torch.multinomial(probs, num_samples=num_samples, replacement=replacement)
5633        self.assertEqual(out, expected)
5634
5635    # FIXME: move to test distributions
5636    def test_multinomial_empty_w_replacement(self, device):
5637        self._test_multinomial_empty(device, True, 1)
5638        self._test_multinomial_empty(device, True, 2)
5639
5640    # FIXME: move to test distributions
5641    def test_multinomial_empty_wo_replacement(self, device):
5642        self._test_multinomial_empty(device, False, 1)
5643        self._test_multinomial_empty(device, False, 2)
5644
5645    @onlyNativeDeviceTypes
5646    @dtypes(torch.float, torch.double)
5647    def test_grad_scaling_unscale(self, device, dtype):
5648        device = torch.device(device)
5649        device0 = "cuda:0" if device.type == "cuda" else "cpu"
5650        inv_scale = torch.full((1,), 0.25, dtype=torch.float, device=device0)
5651        found_inf = torch.full((1,), 0.0, dtype=torch.float, device=device0)
5652
5653        size = 20
5654        g = torch.full((size, size), 4.0, dtype=dtype, device=device0)
5655        ginf = g.clone()
5656        ginf[2, 2] = float('inf')
5657        gnan = g.clone()
5658        gnan[2, 2] = float('nan')
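        # With grads filled with 4.0 and inv_scale = 0.25, unscaling should
        # produce all-ones tensors, which is what the checks below compare
        # against.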
5659
5660        # Tries selected combinations of
5661        #  - contiguous grads
5662        #  - g.clone().t() which is not contiguous but still non-overlapping and dense
5663        #  - variants of g.clone()[:, :5] which are not non-overlapping and dense
5664        # Non-overlapping and dense grads route into a multi-tensor-apply kernel,
5665        # others use a fallback per-tensor kernel, so we should try both.
5666        cases = (
5667            ([g.clone(), g.clone()], False),
5668            ([g.clone(), g.clone().t()], False),
5669            ([g.clone(), g.clone()[:, :5]], False),
5670            ([g.clone()[:, :5], g.clone()[:, :5]], False),
5671            ([g.clone(), ginf.clone()], True),
5672            ([g.clone(), gnan.clone()], True),
5673            ([g.clone(), ginf.clone()[:, :5]], True),
5674            ([g.clone(), gnan.clone()[:, :5]], True),
5675            ([ginf.clone(), g.clone()[:, :5]], True),
5676            ([ginf.clone()[:, :5], g.clone()[:, :5]], True),
5677        )
5678
5679        for grads, has_inf in cases:
5680            found_inf.zero_()
5681            torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)
5682            if has_inf:
5683                self.assertEqual(found_inf, 1.0)
5684            else:
5685                self.assertEqual(found_inf, 0.0)
5686                for grad in grads:
5687                    self.assertEqual(grad, torch.ones_like(grad), rtol=1e-5, atol=1e-7)
5688
5689        # When passing lists with mismatched dtypes to a raw
5690        # _amp_foreach_non_finite_check_and_unscale_ call on CUDA,
5691        # it's expected to fall back to the single-tensor TensorIterator kernel.
5692        grads = [g.clone(), g.to(dtype=torch.float16)]
5693        torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)
5694        for grad in grads:
5695            self.assertEqual(grad, torch.ones_like(grad), rtol=1e-5, atol=1e-7)
5696
5697        # Passing lists with mismatched devices to a raw
5698        # _amp_foreach_non_finite_check_and_unscale_ call should raise errors.
5699        if device.type == "cuda" and TEST_MULTIGPU:
5700            with self.assertRaisesRegex(RuntimeError, r"Expected all tensors to be on the same device"):
5701                torch._amp_foreach_non_finite_check_and_unscale_([g.clone(), g.to(device="cuda:1")],
5702                                                                 found_inf,
5703                                                                 inv_scale)
5704
5705        # Creates a list of grads with mismatched dtypes and devices, to ensure
5706        # scaler._unscale_grads_ organizes grads by dtype and device before calling
5707        # _amp_foreach_non_finite_check_and_unscale_ on each set.
5708        # If inject_inf >= 0, writes an inf into one grad for _unscale_grads_ to find.
5709        def perfect_storm_grads(inject_inf):
5710            grads = [g.clone(), g.clone()[:, :5], g.to(dtype=torch.float16), g.to(dtype=torch.float16)]
5711            if device.type == "cuda" and TEST_MULTIGPU:
5712                grads += [g.to(device="cuda:1"),
5713                          g.to(device="cuda:1")[:, :5],
5714                          g.to(device="cuda:1", dtype=torch.float16),
5715                          g.to(device="cuda:1", dtype=torch.float16)]
5716            if inject_inf >= 0:
5717                grads[inject_inf][2, 2] = float('inf')
5718            return grads
5719
5720        GradScaler = partial(torch.GradScaler, device=device.type)
5721        scaler = GradScaler()
5722        dummy_params = [torch.empty_like(g) for g in perfect_storm_grads(-1)]
5723        dummy_opt = torch.optim.SGD(dummy_params, lr=1.)
5724
5725        # Ensures the inf/nan checking can find an inf injected onto any grad in the perfect storm.
5726        for inject_inf in range(-1, len(dummy_params)):
5727            found_inf = torch.full((1,), 0.0, dtype=torch.float, device=device0)
5728            grads = perfect_storm_grads(inject_inf)
5729            for i, p in enumerate(dummy_params):
5730                p.grad = grads[i]
5731            found_inf_per_device = scaler._unscale_grads_(dummy_opt, inv_scale, found_inf, True)
5732            if inject_inf < 0:
5733                # No inf was injected, ensures unscaling worked normally.
5734                self.assertTrue(sum(v.item() for v in found_inf_per_device.values()) == 0)
5735                for grad in grads:
5736                    self.assertEqual(grad, torch.ones_like(grad), rtol=1e-5, atol=1e-7)
5737            else:
5738                # inf was injected, ensures inf was found.
5739                self.assertTrue(sum(v.item() for v in found_inf_per_device.values()) == 1)
5740
5741    @onlyNativeDeviceTypes
5742    @dtypes(torch.float)
5743    def test_grad_scaling_update_scale(self, device, dtype):
5744        growth = 2.0
5745        backoff = 0.25
5746        growth_interval = 2
5747        scale = torch.full((1,), 4.0, dtype=dtype, device=device)
5748        growth_tracker = torch.full((1,), 0.0, dtype=torch.int32, device=device)
5749        found_inf = torch.full((1,), 0.0, dtype=torch.float, device=device)
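        # Expected dynamics: with growth_interval=2, the first clean step only
        # bumps the growth tracker, the second multiplies the scale by the
        # growth factor (4.0 -> 8.0), and a step that found an inf multiplies
        # the scale by backoff (8.0 * 0.25 = 2.0) and resets the tracker.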
5750
5751        # Simulates 2 consecutive unskipped iterations
5752        torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval)
5753        self.assertEqual(growth_tracker, 1)
5754        self.assertEqual(scale, 4.0)
5755        torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval)
5756        self.assertEqual(growth_tracker, 0)
5757        self.assertEqual(scale, 8.0)
5758
5759        # Simulates a skipped iteration
5760        found_inf.fill_(1.0)
5761        torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval)
5762        self.assertEqual(growth_tracker, 0)
5763        self.assertEqual(scale, 2.0)
5764
5765    @skipIfTorchDynamo("Failed running call_function for sparse_coo_tensor. See https://github.com/pytorch/pytorch/issues/118856")
5766    @onlyNativeDeviceTypes
5767    @dtypes(torch.float)
5768    def test_grad_scaling_unscale_sparse(self, device, dtype):
5769        device = torch.device(device)
5770        scaler = torch.GradScaler(device=device.type)
5771
5772        inv_scale = torch.full((1,), 0.25, dtype=dtype, device=device)
5773        found_inf = torch.empty((1,), dtype=dtype, device=device)
5774        cur = found_inf.device
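        # _unscale_grads_ returns a dict of found_inf tensors keyed by device,
        # so index it with the current device to get this device's flag.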
5775
5776        i = torch.tensor([[0, 1, 1],
5777                          [2, 0, 2]], device=device, dtype=torch.int64)
5778        v = torch.tensor([16., 32., 64.], device=device, dtype=torch.float)
5779        s = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device=device, dtype=dtype)
5780
5781        p = s.clone()
5782        assert p.is_sparse
5783        opt = torch.optim.SGD([p], lr=1.)
5784
5785        p.grad = s.clone()
5786        found_inf.zero_()
5787        found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, False)[cur]
5788        self.assertEqual(found_inf, 0.0)
5789        self.assertEqual(p.grad.to_dense(), (s / 4).to_dense())
5790
5791        v = torch.FloatTensor([16., 32., float('inf')])
5792        p.grad = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device=device, dtype=dtype)
5793        found_inf.zero_()
5794        found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, False)[cur]
5795        self.assertEqual(found_inf, 1.0)
5796
5797        v = torch.FloatTensor([16., 32., float('nan')])
5798        p.grad = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device=device, dtype=dtype)
5799        found_inf.zero_()
5800        found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, False)[cur]
5801        self.assertEqual(found_inf, 1.0)
5802
5803        p = s.clone().half()
5804        assert p.is_sparse
5805        opt = torch.optim.SGD([p], lr=1.)
5806
5807        p.grad = s.clone().half()
5808        found_inf.zero_()
5809        found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, True)[cur]
5810        self.assertEqual(found_inf, 0.0)
5811        self.assertEqual(p.grad.to_dense(), (s.half() / 4).to_dense())
5812
5813        # Creates fp16 sparse tensor with duplicated indices (uncoalesced).  The uncoalesced representation
5814        # does not overflow in fp16, but the coalesced representation would, because 64000 + 64000 > fp16 max.
5815        # _amp_non_finite_check_and_unscale_ should report an overflow here.
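        # (The fp16 maximum is 65504, so the coalesced value 128000 overflows to inf.)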
5816        i = torch.LongTensor([[0, 1, 0],
5817                              [2, 0, 2]])
5818        v = torch.FloatTensor([64000., 32., 64000.])
5819        p.grad = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device=device, dtype=torch.float16)
5820        found_inf.zero_()
5821        found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, True)[cur]
5822        self.assertEqual(found_inf, 1.0)
5823
5824    @onlyNativeDeviceTypes
5825    def test_grad_scaling_state_dict(self, device):
5826        device = torch.device(device)
5827        GradScaler = partial(torch.GradScaler, device=device.type)
5828        for lazy_init_scale in True, False:
5829            s0 = GradScaler(init_scale=3., growth_factor=4., backoff_factor=.5, growth_interval=2)
5830            s1 = GradScaler(init_scale=6., growth_factor=7., backoff_factor=.8, growth_interval=1)
5831
5832            # Set an arbitrary value for load_state_dict to overwrite
5833            s1._init_growth_tracker = 7
5834
5835            if lazy_init_scale:
5836                # Dummy scale() call to trigger lazy initialization of the scale tensor.
5837                s1.scale(torch.full((1,), 4.0, dtype=torch.float32, device=device))
5838                if "cuda" == device.type:
5839                    self.assertTrue(isinstance(s1._scale, torch.cuda.FloatTensor))
5840                else:
5841                    self.assertTrue(isinstance(s1._scale, torch.FloatTensor))
5842
5843            s1.load_state_dict(s0.state_dict())
5844
5845            self.assertEqual(s1.get_scale(), 3.)
5846            self.assertEqual(s1.get_growth_factor(), 4.)
5847            self.assertEqual(s1.get_backoff_factor(), .5)
5848            self.assertEqual(s1.get_growth_interval(), 2)
5849            self.assertEqual(s1._init_growth_tracker, 0)
5850
5851    # _run_scaling_case generalizes some single-optimizer test logic to avoid too much copy-pasting below.
5852    def _run_scaling_case(self, device, run, unskipped, skipped, atol=1e-7, optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
5853        # Ensure scaling can be disabled without changing user control flow.
5854        for enabled in True, False:
5855            (
5856                mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, skip_iter,
5857            ) = _create_scaling_case(device=device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs)
5858
5859            # For the functionality test, use a modest initial scale and an unrealistically large growth factor
5860            # so any potential errors in the growth factor handling are magnified.
5861            GradScaler = partial(torch.GradScaler, device=device)
5862            scaler = GradScaler(init_scale=128., growth_factor=2.0, enabled=enabled, growth_interval=1)
5863
5864            _ = run(device, data, mod_control, opt_control, scaler, loss_fn, skip_iter, False)
5865            ret = run(device, data, mod_scaling, opt_scaling, scaler, loss_fn, skip_iter, True)
5866
5867            # Allows run() to optionally return a different scaler instance.
5868            scaler = ret if ret else scaler
5869
5870            # If scaling was enabled, the scale factor should have been multiplied by the growth factor
5871            # `unskipped` times (i.e., len(data) - skipped) and by the backoff factor `skipped` times.
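            # For example, the callers below that use unskipped=3 and skipped=1 (with the
            # default backoff_factor of 0.5) expect a final scale of 128. * 2.**3 * 0.5 == 512.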
5872            if enabled:
5873                net_growth = scaler.get_growth_factor()**unskipped if unskipped > 0 else 1.0
5874                net_backoff = scaler.get_backoff_factor()**skipped if skipped > 0 else 1.0
5875                self.assertTrue(scaler.get_scale() == (128. * net_growth * net_backoff))
5876            else:
5877                self.assertTrue(scaler.get_scale() == 1.0)
5878
5879            for c, s in zip(mod_control.parameters(), mod_scaling.parameters()):
5880                self.assertEqual(c.grad, s.grad, atol=atol, rtol=1e-05)
5881
5882                c_state, s_state = opt_control.state[c], opt_scaling.state[s]
5883                for k in c_state:
5884                    self.assertEqual(c_state[k], s_state[k], atol=atol, rtol=1e-05, msg=k)
5885
5886                self.assertEqual(c, s, atol=atol, rtol=1e-05)
5887
5888    @onlyNativeDeviceTypes
5889    @parametrize("foreach, fused", [(None, None), (True, None), (None, True)])
5890    @optims(
5891        [optim for optim in optim_db if optim.optim_cls in [torch.optim.AdamW, torch.optim.Adam, torch.optim.SGD]],
5892        dtypes=[torch.float32]
5893    )
5894    def test_grad_scaling_autocast(self, device, dtype, optim_info, foreach, fused):
5895        try_pickle = False
5896
5897        def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):
5898            for i, (input, target) in enumerate(data):
5899                optimizer.zero_grad()
5900                with torch.autocast(device_type=device, dtype=torch.half, enabled=try_scaling_api):
5901                    output = model(input)
5902                    loss = loss_fn(output, target)
5903                if try_scaling_api:
5904                    scaler.scale(loss).backward()
5905                    if i == skip_iter and scaler.is_enabled():
5906                        with torch.no_grad():
5907                            model[1].weight.grad.fill_(float('inf'))
5908                    scaler.step(optimizer)
5909                    scaler.update()
5910                    if try_pickle:
5911                        scaler = pickle.loads(pickle.dumps(scaler))
5912                else:
5913                    loss.backward()
5914                    if (not scaler.is_enabled()) or (i != skip_iter):
5915                        optimizer.step()
5916            return scaler
5917
5918        optimizer_ctor = optim_info.optim_cls
5919
5920        # Compares no scaling + no autocasting against scaling + autocasting.
5921        # NOTE(mkozuki): With the current way of testing, `torch.optim.Adam` fails regardless of `foreach` and `fused`.
5922        #   Giving this test some flexibility might help.
5923        context = contextlib.nullcontext
5924        if optimizer_ctor in (torch.optim.Adam, torch.optim.AdamW):
5925            from functools import partial
5926            context = partial(self.assertRaises, AssertionError)
5927        with context():
5928            # sets atol=1e-3 because we're comparing pure fp32 arithmetic vs a mixture of fp16 and fp32
5929            self._run_scaling_case(
5930                device, run, unskipped=3, skipped=1, atol=1e-3,
5931                optimizer_ctor=optimizer_ctor, optimizer_kwargs={"foreach": foreach, "fused": fused},
5932            )
5933            # This is picked up by the `try_pickle` closure inside run():
5934            try_pickle = True
5935            self._run_scaling_case(
5936                device, run, unskipped=3, skipped=1, atol=1e-3,
5937                optimizer_ctor=optimizer_ctor, optimizer_kwargs={"foreach": foreach, "fused": fused},
5938            )
5939
5940    # Make sure that the parameters become nonsense when the scaled gradients are finite
5941    # but are invalidated between `GradScaler.unscale_` and `optimizer.step`.
5942
5943    @onlyNativeDeviceTypes
5944    @optims(
5945        [optim for optim in optim_db if optim.optim_cls in [torch.optim.AdamW, torch.optim.Adam, torch.optim.SGD]],
5946        dtypes=[torch.float32]
5947    )
5948    def test_params_invalidated_with_grads_invalidated_between_unscale_and_step(self, device, dtype, optim_info):
5949        optimizer_ctor = optim_info.optim_cls
5950        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
5951            device, dtype, optim_info, skip=("differentiable",))
5952
5953        for optim_input in all_optim_inputs:
5954            model, _, optimizer, _, data, loss_fn, _ = _create_scaling_case(
5955                device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optim_input.kwargs,
5956            )
5957            scaler = torch.GradScaler(device=device, init_scale=128.0)
5958
5959            for input, target in data:
5960                optimizer.zero_grad()
5961                with torch.autocast(device_type=device, dtype=torch.half):
5962                    output = model(input)
5963                    loss = loss_fn(output, target)
5964                scaler.scale(loss).backward()
5965                scaler.unscale_(optimizer)
5966
5967                # deliberately break grads
5968                for j, param in enumerate(model.parameters()):
5969                    param.grad.copy_(torch.inf if j % 2 else torch.nan)
5970
5971                scaler.step(optimizer)
5972                scaler.update()
5973
5974            self.assertTrue(all((p.isnan().any() or p.isinf().any()) for p in model.parameters()))
5975
5976    @onlyNativeDeviceTypes
5977    def test_grad_scale_will_not_overflow(self, device):
5978        device = torch.device(device)
5979        model = torch.nn.Linear(5, 1).to(device)
5980        optimizer = torch.optim.Adam(model.parameters())
5981        scaler = torch.GradScaler(device=device.type, growth_interval=1, growth_factor=2**4, init_scale=1e38)
5982        optimizer.zero_grad()
5983        x = torch.randn(1, 5).to(device)
5984        y = 1e-30 * torch.randn(1, 1).to(device)
5985        l = ((model(x) - y) ** 2).mean()
5986        scaler.scale(l).backward()
5987        scaler.step(optimizer)
5988        scaler.update()
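        # init_scale * growth_factor is 1e38 * 16, which exceeds the float32 maximum (~3.4e38),
        # so the update must keep the scale finite instead of letting it overflow.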
5989        self.assertTrue(torch.isfinite(scaler._scale).all())
5990
5991    @onlyNativeDeviceTypes
5992    def test_grad_scaling_clipping(self, device):
5993        device = torch.device(device)
5994
5995        def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):
5996            max_norm = 0.2  # A reasonable value that actually has an effect, based on printouts of grads
5997            for i, (input, target) in enumerate(data):
5998                optimizer.zero_grad()
5999                output = model(input)
6000                loss = loss_fn(output, target)
6001                if try_scaling_api:
6002                    scaler.scale(loss).backward()
6003                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm * scaler.get_scale())
6004                    if i == skip_iter and scaler.is_enabled():
6005                        model[1].weight.grad.data.fill_(float('inf'))
6006                    scaler.step(optimizer)
6007                    scaler.update()
6008                else:
6009                    loss.backward()
6010                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
6011                    if (not scaler.is_enabled()) or (i != skip_iter):
6012                        optimizer.step()
6013
6014        self._run_scaling_case(device.type, run, unskipped=3, skipped=1, atol=1e-5)
6015
6016    @onlyNativeDeviceTypes
6017    def test_grad_scaling_clipping_separate_unscale(self, device):
6018        device = torch.device(device)
6019
6020        def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):
6021            max_norm = 0.2  # A reasonable value that actually has an effect, based on printouts of grads
6022            for i, (input, target) in enumerate(data):
6023                optimizer.zero_grad()
6024                output = model(input)
6025                loss = loss_fn(output, target)
6026                if try_scaling_api:
6027                    scaler.scale(loss).backward()
6028                    if i == skip_iter and scaler.is_enabled():
6029                        model[1].weight.grad.data.fill_(float('inf'))
6030                    scaler.unscale_(optimizer)
6031                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, error_if_nonfinite=False)
6032                    scaler.step(optimizer)
6033                    scaler.update()
6034                else:
6035                    loss.backward()
6036                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
6037                    if (not scaler.is_enabled()) or (i != skip_iter):
6038                        optimizer.step()
6039
6040        self._run_scaling_case(device.type, run, unskipped=3, skipped=1)
6041
6042    @onlyNativeDeviceTypes
6043    @unittest.skipIf(IS_WINDOWS, 'FIXME: fix this test for Windows')
6044    def test_grad_scaling_penalty(self, device):
6045        device = torch.device(device)
6046
6047        def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):
6048            for i, (input, target) in enumerate(data):
6049                optimizer.zero_grad()
6050                output = model(input)
6051                loss = loss_fn(output, target)
6052
6053                if try_scaling_api:
6054                    grad_params = torch.autograd.grad(scaler.scale(loss),
6055                                                      model.parameters(), create_graph=True)
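                    # grad_params were computed from the scaled loss, so unscale them manually
                    # before they are folded into the penalty term.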
6056                    inv_scale = 1. / scaler.get_scale()
6057                    grad_params = [p * inv_scale for p in grad_params]
6058                else:
6059                    grad_params = torch.autograd.grad(loss, model.parameters(), create_graph=True)
6060
6061                grad_norm = 0
6062                for grad in grad_params:
6063                    grad_norm += grad.pow(2).sum()
6064                grad_norm = grad_norm.sqrt()
6065                loss = loss + grad_norm
6066
6067                if try_scaling_api:
6068                    scaler.scale(loss).backward()
6069                    if i == skip_iter and scaler.is_enabled():
6070                        model[1].weight.grad.data.fill_(float('inf'))
6071                    scaler.step(optimizer)
6072                    scaler.update()
6073                else:
6074                    loss.backward()
6075                    if (not scaler.is_enabled()) or (i != skip_iter):
6076                        optimizer.step()
6077
6078        self._run_scaling_case(device.type, run, unskipped=3, skipped=1)
6079
6080    @onlyNativeDeviceTypes
6081    def test_grad_scaling_accumulation(self, device):
6082        device = torch.device(device)
6083
6084        def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):
6085            iters_to_accumulate = 2
6086            for i, (input, target) in enumerate(data):
6087                output = model(input)
6088                loss = loss_fn(output, target)
6089                loss = loss / iters_to_accumulate
6090                if try_scaling_api:
6091                    scaler.scale(loss).backward()
6092                else:
6093                    loss.backward()
6094                if (i + 1) % iters_to_accumulate == 0:
6095                    if try_scaling_api:
6096                        scaler.step(optimizer)
6097                        scaler.update()
6098                        optimizer.zero_grad()
6099                    else:
6100                        optimizer.step()
6101                        optimizer.zero_grad()
6102
6103        self._run_scaling_case(device.type, run, unskipped=2, skipped=0)
6104
6105    @onlyNativeDeviceTypes
6106    def test_grad_scaling_multiple(self, device):
6107        device = torch.device(device)
6108        # Tests gradient scaling with 2 models and 2 optimizers that both receive gradients from 2 losses.
6109        # Some of the logic here cannot reuse the generic helper functions created for the 1-optimizer cases.
6110        for enabled in True, False:
6111            mod_control0, mod_scaling0, opt_control0, opt_scaling0, data, loss_fn, skip_iter = \
6112                _create_scaling_case(device.type)
6113            mod_control1, mod_scaling1, opt_control1, opt_scaling1 = \
6114                _create_scaling_models_optimizers(device.type)
6115
6116            GradScaler = partial(torch.GradScaler, device=device.type)
6117            scaler = GradScaler(init_scale=128., growth_factor=2.0, enabled=enabled, growth_interval=1)
6118
6119            def run(model0, model1, optimizer0, optimizer1, try_scaling_api):
6120                for i, (input, target) in enumerate(data):
6121                    optimizer0.zero_grad()
6122                    optimizer1.zero_grad()
6123                    output0 = model0(input)
6124                    output1 = model1(input)
6125                    loss0 = loss_fn(0.3 * output0 + 0.7 * output1, target)
6126                    loss1 = loss_fn(0.6 * output0 - 0.4 * output1, target)
6127
6128                    if try_scaling_api:
6129                        scaler.scale(loss0).backward(retain_graph=True)
6130                        scaler.scale(loss1).backward()
6131                        if i == skip_iter and scaler.is_enabled():
6132                            model1[1].weight.grad.data.fill_(float('inf'))
6133
6134                        # As an additional stress test, separately unscale for one of the optimizers.
6135                        scaler.unscale_(optimizer0)
6136
6137                        scaler.step(optimizer0)
6138                        scaler.step(optimizer1)
6139                        scaler.update()
6140                    else:
6141                        loss0.backward(retain_graph=True)
6142                        loss1.backward()
6143                        optimizer0.step()
6144                        if (not scaler.is_enabled()) or (i != skip_iter):
6145                            optimizer1.step()
6146
6147            run(mod_control0, mod_control1, opt_control0, opt_control1, False)
6148            run(mod_scaling0, mod_scaling1, opt_scaling0, opt_scaling1, True)
6149
6150            # The loss scale should have been multiplied by the growth factor 3 times and the backoff factor once.
6151            expected_scale = (128. * scaler.get_growth_factor()**3 * scaler.get_backoff_factor()**1) if enabled else 1.0
6152            self.assertTrue(scaler.get_scale() == expected_scale)
6153
6154            for c, s in zip(chain(mod_control0.parameters(), mod_control1.parameters()),
6155                            chain(mod_scaling0.parameters(), mod_scaling1.parameters())):
6156                self.assertEqual(c, s, rtol=1e-5, atol=1e-7)
6157
6158    @onlyNativeDeviceTypes
6159    def test_grad_scaler_pass_itself(self, device):
6160        device = torch.device(device)
6161        GradScaler = partial(torch.amp.GradScaler, device=device.type)
6162
6163        class _PlaceHolderOptimizer(torch.optim.Optimizer):
6164            tester = self
6165
6166            def __init__(self, params, defaults=None):
6167                if defaults is None:
6168                    defaults = {}
6169                super().__init__(params, defaults)
6170                self._step_supports_amp_scaling = True
6171
6172        class Optimizer1(_PlaceHolderOptimizer):
6173            def step(self, closure=None, *, grad_scaler=None):
6174                self.tester.assertTrue(isinstance(grad_scaler, torch.amp.GradScaler))
6175                self.tester.assertFalse(hasattr(self, "grad_scale"))
6176                self.tester.assertFalse(hasattr(self, "found_inf"))
6177
6178        class Optimizer2(_PlaceHolderOptimizer):
6179            def step(self, closure=None):
6180                self.tester.assertTrue(isinstance(self.grad_scale, torch.Tensor))
6181                self.tester.assertTrue(isinstance(self.found_inf, torch.Tensor))
6182
6183        x = torch.randn(4, 4).to(device)
6184        m = torch.nn.Linear(4, 1).to(device)
6185        o1 = Optimizer1(m.parameters())
6186        o2 = Optimizer2(m.parameters())
6187        scaler = GradScaler(init_scale=2.0)
6188
6189        with torch.autocast(device_type=device.type, dtype=torch.half):
6190            y = m(x)
6191            loss = y.mean()
6192        scaler.scale(loss).backward()
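        # GradScaler.step passes itself to Optimizer1 via the deprecated `grad_scaler`
        # keyword argument (hence the FutureWarning), while for Optimizer2 it instead sets
        # the `grad_scale` and `found_inf` attributes checked in Optimizer2.step above.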
6193        with self.assertWarns(FutureWarning):
6194            scaler.step(o1)
6195        scaler.step(o2)
6196        scaler.update()
6197
6198    @onlyNativeDeviceTypes
6199    def test_grad_scaler_deprecated_warning(self, device):
6200        device = torch.device(device)
6201        GradScaler = torch.cuda.amp.GradScaler if "cuda" == device.type else torch.cpu.amp.GradScaler
6202
6203        with self.assertWarnsRegex(
6204            FutureWarning,
6205            rf"`torch.{device.type}.amp.GradScaler\(args...\)` is deprecated.",
6206        ):
6207            _ = GradScaler(init_scale=2.0)
6208
6209    @dtypesIfCUDA(torch.float, torch.double, torch.half)
6210    @dtypesIfCPU(torch.float, torch.double, torch.bfloat16, torch.half)
6211    @dtypes(torch.float, torch.double)
6212    def test_multinomial_cpu(self, device, dtype):
6213        def make_prob_dist(shape, is_contiguous):
6214            if is_contiguous:
6215                if dtype == torch.half or dtype == torch.bfloat16:
6216                    return torch.zeros(shape, device=device).uniform_().to(dtype=dtype)
6217                return torch.zeros(shape, device=device, dtype=dtype).uniform_()
6218            elif len(shape) == 1:
6219                if dtype == torch.half or dtype == torch.bfloat16:
6220                    return torch.zeros((shape + [5]), device=device).uniform_().to(dtype=dtype)[:, 2]
6221                return torch.zeros((shape + [5]), device=device, dtype=dtype).uniform_()[:, 2]
6222            else:
6223                # num dim = 2
6224                new_shape = [2, shape[1], 7, 1, shape[0], 1, 10]
6225                if dtype == torch.half or dtype == torch.bfloat16:
6226                    prob_dist = torch.zeros(new_shape, device=device).uniform_().to(dtype=dtype)
6227                else:
6228                    prob_dist = torch.zeros(new_shape, device=device, dtype=dtype).uniform_()
6229                prob_dist = prob_dist.transpose(1, 4)
6230                prob_dist = prob_dist[1, :, 5, 0, :, 0, 4]
6231                assert not prob_dist.is_contiguous()  # sanity check
6232                return prob_dist
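        # Illustrative smoke check (added): exercise both branches of make_prob_dist and
        # verify its shape/dtype/contiguity contract.
        for is_contiguous in (True, False):
            prob_dist = make_prob_dist([5, 10], is_contiguous)
            self.assertEqual(prob_dist.shape, torch.Size([5, 10]))
            self.assertEqual(prob_dist.dtype, dtype)
            self.assertEqual(prob_dist.is_contiguous(), is_contiguous)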
6233
6234    # FIXME: move to elementwise ternary test suite
6235    # as the test fails with "RuntimeError not raised" on XLA
6236    @onlyNativeDeviceTypes
6237    def test_where_scalar_handcrafted_values(self, device):
6238        # Tests ScalarxScalar, ScalarxTensor and TensorxScalar
6239        # variants of `where` against the NumPy version with
6240        # handcrafted values.
6241        condition_shape = (5, 5)
6242        dtypes = (
6243            torch.bool, torch.uint8, torch.int8, torch.int16, torch.int64,
6244            torch.float16, torch.float32, torch.float64,
6245            torch.complex64, torch.complex128,
6246        )
6247        shapes = ((), (5,), (1, 5),)
6248
6249        with torch.no_grad():
6250            tensors = tuple(torch.empty(shape, dtype=dtype, device=device).fill_(17)
6251                            for shape, dtype in product(shapes, dtypes))
6252
6253        # Use different values for `x` and `y`
6254        # as they are the output values which are compared.
6255        x_vals = (True, 3, 7.0, 1 + 0.5j)
6256        y_vals = tuple(itertools.chain((False, 4, 8.0, 2 + 0.5j), tensors))
6257        for x in x_vals:
6258            for y in y_vals:
6259                condition = torch.empty(*condition_shape, dtype=torch.bool, device=device).bernoulli_()
6260                common_dtype = torch.result_type(x, y)
6261
6262                def check_equal(condition, x, y):
6263                    condition_np = condition.cpu().numpy()
6264                    x_np = x.cpu().numpy() if isinstance(x, torch.Tensor) else x
6265                    y_np = y.cpu().numpy() if isinstance(y, torch.Tensor) else y
6266
6267                    # NumPy aggressively promotes to double, hence cast the output to the correct dtype
6268                    expected = torch.from_numpy(np.where(condition_np, x_np, y_np)).to(common_dtype)
6269                    result = torch.where(condition, x, y)
6270                    self.assertEqual(expected, result)
6271
6272                check_equal(condition, x, y)
6273                check_equal(condition, y, x)
6274                if self.device_type == "cuda":
6275                    check_equal(condition, torch.tensor(x), y)
6276                    check_equal(condition, y, torch.tensor(x))
6277                    if not isinstance(y, torch.Tensor):
6278                        check_equal(condition, torch.tensor(y), torch.tensor(x))
6279                    if isinstance(y, torch.Tensor) and y.ndim > 0:
6280                        check_equal(torch.tensor(True), x, y)
6281                        check_equal(torch.tensor(True), y, x)
6282
6283
6284    @skipIfTorchInductor("FIXME")
6285    def test_hook_remove(self, device):
6286        # Reference: https://github.com/pytorch/pytorch/issues/58354
6287        def _test_helper(remove_hook):
6288            def install_hook(tensor):
6289                handle = None
6290
6291                def hook(tensor):
6292                    if remove_hook:
6293                        handle.remove()
6294                    return torch.zeros_like(tensor)
6295                handle = tensor.register_hook(hook)
6296
6297            t = torch.ones((1, 5), device=device, requires_grad=True)
6298            install_hook(t)
6299
6300            # First call to backward
6301            t.mean().backward()
6302            self.assertEqual(t.grad, torch.zeros_like(t))
6303
6304            # Second call to backward
6305            t.mean().backward()
6306            if remove_hook:
6307                # After removing the hook, make sure the usual gradient is returned
6308                self.assertEqual(t.grad, 0.2 * torch.ones_like(t))
6309            else:
6310                self.assertEqual(t.grad, torch.zeros_like(t))
6311
6312        _test_helper(remove_hook=True)
6313        _test_helper(remove_hook=False)
6314
6315    # FIXME: get PyTorch/XLA to run test_testing
6316    # This test should ideally be in test_testing.py,
6317    # but since pytorch/xla runs tests from test_torch.py, we have it here.
6318    @skipXLA
6319    def test_skip_xla(self, device):
6320        if self.device_type == 'xla':
6321            # Should not reach here!
6322            self.assertTrue(False)
6323
6324    # FIXME: get PyTorch/XLA to run test_testing
6325    # This test should ideally be in test_testing.py,
6326    # but since pytorch/xla runs tests from test_torch.py, we have it here.
6327    @expectedFailureXLA
6328    def test_expected_failure_xla(self, device):
6329        if self.device_type == 'xla':
6330            self.assertTrue(False)
6331
6332    # FIXME: get PyTorch/XLA to run test_testing
6333    # This test should ideally be in test_testing.py,
6334    # but since pytorch/xla runs tests from test_torch.py, we have it here.
6335    def test_assertRaisesRegex_ignore_msg_non_native_device(self, device):
6336        # Verify that self.assertRaisesRegex only checks the Error and ignores
6337        # the message for non-native devices.
6338        x = torch.randn((10, 3), device=device)
6339        t = torch.empty(10, dtype=torch.int64, device=device).random_(0, 3)
6340        invalid_weight = torch.randn(4, device=device)
6341        msg = "weight tensor should be defined either for all 3 classes or no classes"
6342
6343        # XLA raises RuntimeError with a different message.
6344        with self.assertRaisesRegex(RuntimeError, msg):
6345            torch.nn.functional.nll_loss(x, t, weight=invalid_weight)
6346
6347    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.complex32))
6348    def test_copy_(self, device, dtype):
6349        def can_cast(src_dtype, dst_dtype):
6350            # torch.can_cast(torch.int16, torch.uint8) returns True,
6351            # but that is not actually a safe cast.
6352            # This function returns False in that case.
6353            def is_unsigned_int(dtype):
6354                return dtype is torch.uint8
6355
6356            if is_unsigned_int(dst_dtype):
6357                return is_unsigned_int(src_dtype)
6358            return torch.can_cast(src_dtype, dst_dtype)
6359
6360        def make_tensor_wrapper(shape, dtype):
6361            if dtype is not torch.complex32:
6362                # make_tensor does not support generating
6363                # complex32 tensors
6364                return make_tensor(shape, device=device, dtype=dtype)
6365            return torch.randn(shape, device=device, dtype=dtype)
6366
6367        t = make_tensor_wrapper((50,), dtype)
6368        src_dtypes = all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.complex32)
6369        for src_dtype in src_dtypes:
6370            src = make_tensor_wrapper((50,), dtype=src_dtype)
6371            t.copy_(src)
6372            dst = make_tensor_wrapper((50, ), dtype=src_dtype)
6373            if can_cast(src_dtype, dtype):
6374                rtol = None
6375                atol = None
6376                if dtype in (torch.half, torch.complex32):
6377                    rtol = 1e-3
6378                    atol = 1e-3
6379                if dtype in (torch.bfloat16,):
6380                    rtol = 1e-2
6381                    atol = 1e-2
6382                self.assertEqual(src, dst.copy_(t), rtol=rtol, atol=atol)
6383
6384    @dtypes(*all_types_and_complex_and(
6385        torch.bool, torch.half, torch.bfloat16, torch.complex32,
6386        torch.uint16, torch.uint32, torch.uint64))
6387    def test_item(self, device, dtype):
6388        if torch.device(device).type == 'xla' and dtype in [torch.uint16, torch.uint32, torch.uint64]:
6389            self.skipTest('uint16,32,64 not implemented on XLA')
6390        t = torch.ones((), device=device, dtype=dtype)
6391        self.assertEqual(1, t.item())
6392
6393    @onlyNativeDeviceTypes
6394    def test_masked_scatter_inplace_noncontiguous(self, device):
6395        t = torch.zeros(5, 2, dtype=torch.long, device=device)
6396        t_non_contig = t.transpose(0, 1)
6397        t_contig = t_non_contig.contiguous()
6398
6399        assert t_contig.is_contiguous()
6400        assert not t_non_contig.is_contiguous()
6401
6402        mask = torch.tensor([[False, True], [False, True], [False, False], [True, True], [True, True]], device=device)
6403        mask_non_contig = mask.transpose(0, 1)
6404        mask_contig = mask_non_contig.contiguous()
6405
6406        assert mask_contig.is_contiguous()
6407        assert not mask_non_contig.is_contiguous()
6408
6409        # source is always converted to contiguous by the op.
6410        source = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 9]], device=device)
6411
6412        # t: contig, mask: contig
6413        expected = t_contig.masked_scatter_(mask_contig, source)
6414
6415        # t: non-contig, mask: non-contig
6416        actual = t_non_contig.masked_scatter_(mask_non_contig, source)
6417        self.assertEqual(actual, expected)
6418
6419        # t: contig, mask: non-contig
6420        actual = t_contig.masked_scatter_(mask_non_contig, source)
6421        self.assertEqual(actual, expected)
6422
6423        # t: non-contig, mask: contig
6424        actual = t_non_contig.masked_scatter_(mask_contig, source)
6425        self.assertEqual(actual, expected)
6426
6427
6428# Tests that compare a device's computation with the (gold-standard) CPU's.
6429class TestDevicePrecision(TestCase):
6430    exact_dtype = True
6431
6432    # FIXME: move to indexing test suite
6433    @onlyCUDA
6434    def test_index_add_bfloat16(self, device):
6435        inp_tensor = torch.randn(5, 3, device='cpu').bfloat16()
6436        t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.bfloat16, device='cpu')
6437        index = torch.tensor([0, 4, 2], device='cpu')
6438        out_cpu = inp_tensor.index_add(0, index, t)
6439
6440        inp_tensor = inp_tensor.to(device=device)
6441        t = t.to(device=device)
6442        index = index.to(device=device)
6443        out_gpu = inp_tensor.index_add(0, index, t)
6444
6445        self.assertEqual(out_cpu, out_gpu, atol=1e-2, rtol=0)
6446
6447    # FIXME: move to serialization test suite
6448    def test_device_serialization(self, device):
6449        x = torch.randn(4, 4, device=device)
6450
6451        with tempfile.NamedTemporaryFile() as f:
6452            torch.save(x, f)
6453            f.seek(0)
6454            x_copy = torch.load(f)
6455
6456        self.assertEqual(x_copy, x)
6457        self.assertIs(type(x_copy), type(x))
6458        self.assertEqual(x_copy.device, x.device)
6459
6460    # FIXME: move to serialization test suite
6461    @deviceCountAtLeast(2)
6462    def test_multidevice_serialization(self, devices):
6463        x = [torch.randn(4, 4, device=devices[0]),
6464             torch.randn(4, 4, device=devices[1])]
6465
6466        with tempfile.NamedTemporaryFile() as f:
6467            torch.save(x, f)
6468            f.seek(0)
6469            x_copy = torch.load(f)
6470
6471        for original, cp in zip(x, x_copy):
6472            self.assertEqual(cp, original)
6473            self.assertIs(type(cp), type(original))
6474            self.assertEqual(cp.device, original.device)
6475
6476    # FIXME: move to data movement test suite
6477    @deviceCountAtLeast(1)
6478    def test_copy_noncontig(self, devices):
6479        def do_test(d0, d1):
6480            x = torch.tensor([1.5, 2.5, 3.5, 4.5, 5.5, 6.5], device=d0)
6481            y = torch.tensor([0, 0, 0, 0, 0, 0], device=d1)
6482            self.assertNotEqual(x.dtype, y.dtype)
6483
6484            y[::2].copy_(x[::2])
6485            self.assertEqual(y, [1, 0, 3, 0, 5, 0])
6486
6487        do_test('cpu', devices[0])
6488        do_test(devices[0], 'cpu')
6489
6490        if len(devices) > 1:
6491            do_test(devices[0], devices[1])
6492
6493    @deviceCountAtLeast(2)
6494    def test_type_conversions_same_device(self, devices):
6495        x = torch.randn(5, 5, device=devices[1])
6496        self.assertEqual(x.int().device, torch.device(devices[1]))
6497        self.assertEqual(x.type(torch.int).device, torch.device(devices[1]))
6498        self.assertEqual(x.to(torch.int).device, torch.device(devices[1]))
6499
6500    @dtypesIfCUDA(torch.half, torch.float, torch.double,
6501                  torch.int8, torch.short, torch.int, torch.long,
6502                  torch.uint8)
6503    @dtypes(torch.float, torch.double,
6504            torch.int8, torch.short, torch.int, torch.long,
6505            torch.uint8)
6506    def test_from_sequence(self, device, dtype):
6507        seq = [list(range(i * 4, i * 4 + 4)) for i in range(5)]
6508        reference = torch.arange(0, 20).resize_(5, 4)
6509        self.assertEqual(torch.tensor(seq, dtype=dtype, device=device), reference, exact_dtype=False)
6510
6511    # FIXME: move to indexing test suite
6512    @deviceCountAtLeast(1)
6513    def test_advancedindex_mixed_cpu_devices(self, devices) -> None:
6514        def test(x: torch.Tensor, ia: torch.Tensor, ib: torch.Tensor) -> None:
6515            # test getitem
6516            self.assertEqual(x[:, ia, None, ib, 0].cpu(),
6517                             x.cpu()[:, ia.cpu(), None, ib.cpu(), 0])
6518            self.assertEqual(x[ia], x.cpu()[ia.cpu()])
6519            # test setitem
6520            x_clone1 = x.clone()
6521            x_clone2 = x.clone()
6522            first_shape = x[:, ia, None, ib, 0].shape
6523            second_shape = x[ia].shape
6524            x_clone1[:, ia, None, ib, 0] = torch.randn(first_shape).to(x_clone1)
6525            x_clone2[ia] = torch.randn(second_shape).to(x_clone2)
6526
6527        cpu = torch.device('cpu')
6528        for device in devices:
6529            x = torch.randn(3, 4, 4, 4, 3)
6530            ia = torch.tensor([0, 2, 1])
6531            ib = torch.tensor([0, 2, 1])
6532
6533            # Index device tensor with cpu tensor
6534            x = x.to(device)
6535            ia = ia.to(cpu)
6536            ib = ib.to(cpu)
6537            test(x, ia, ib)
6538
6539            # Index device tensor with mixed cpu, device tensors
6540            x = x.to(device)
6541            ia = ia.to(cpu)
6542            ib = ib.to(device)
6543            test(x, ia, ib)
6544
6545    @deviceCountAtLeast(1)
6546    def test_advancedindex_mixed_devices_error(self, devices) -> None:
6547        def test(x: torch.Tensor, ia: torch.Tensor, ib: torch.Tensor) -> None:
6548            # test getitem
6549            with self.assertRaisesRegex(RuntimeError, fr"indices should be either .* \({x.device}\)"):
6550                value = x[:, ia, None, ib, 0]
6551            with self.assertRaisesRegex(RuntimeError, fr"indices should be either .* \({x.device}\)"):
6552                value = x[ib]
6553
6554        cpu = torch.device('cpu')
6555        for device in devices:
6556            # Index cpu tensor with device tensor
6557            x = torch.randn(3, 4, 4, 4, 3)
6558            ia = torch.tensor([0, 2, 1]).to(device)
6559            ib = torch.tensor([0, 2, 1]).to(device)
6560            test(x, ia, ib)
6561
6562            # Index cpu tensor with mixed cpu, device tensors
6563            x = x.to(cpu)
6564            ia = ia.to(cpu)
6565            ib = ib.to(device)
6566            test(x, ia, ib)
6567
6568            if len(devices) > 1:
6569                other_device = devices[0] if device == devices[1] else devices[1]
6570
6571                # Index device tensor with mixed cpu, device tensors on different devices
6572                x = x.to(device)
6573                ia = ia.to(cpu)
6574                ib = ib.to(other_device)
6575                test(x, ia, ib)
6576
6577    # FIXME: move to data movement test suite
6578    def test_copy_broadcast(self, device) -> None:
6579        x = torch.randn(10, 5)
6580        y = torch.randn(5, device=device)
6581        x.copy_(y)
6582        self.assertEqual(x[3], y)
6583
6584        x = torch.randn(10, 5, device=device)
6585        y = torch.randn(5)
6586        x.copy_(y)
6587        self.assertEqual(x[3], y)
6588
6589    # FIXME: move to an elementwise ternary test suite
6590    @dtypes(torch.int64, torch.float32, torch.float64)
6591    def test_clamp(self, device, dtype):
6592        test_args = [
6593            *product(
6594                [(100, 50), (10, 64), (97,)],  # shape
6595                (True, False),  # non-contiguous
6596            )
6597        ]
6598
6599        for shape, noncontig in test_args:
6600            x = make_tensor(shape, device=device, dtype=dtype,
6601                            noncontiguous=noncontig)
6602            ub = make_tensor(shape, device=device, dtype=dtype,
6603                             noncontiguous=noncontig)
6604            lb = make_tensor(shape, device=device, dtype=dtype,
6605                             noncontiguous=noncontig)
6606
6607            expect = x.max(lb).min(ub)
6608            actual = x.clamp(lb, ub)
6609            self.assertEqual(expect, actual)
6610
6611            expect = np.clip(x.cpu().numpy(), lb.cpu().numpy(), ub.cpu().numpy())
6612            self.assertEqual(expect, actual)
6613
6614            expect = x.max(lb)
6615            actual = x.clamp(min=lb)
6616            self.assertEqual(expect, actual)
6617
6618            expect = x.min(ub)
6619            actual = x.clamp(max=ub)
6620            self.assertEqual(expect, actual)
6621
6622            # Test broadcasting min & max
6623            expect = x.max(lb[0]).min(ub[..., :1])
6624            actual = x.clamp(lb[0], ub[..., :1])
6625            self.assertEqual(expect, actual)
6626
6627            # Test broadcasting x
6628            expect = x[..., :1].max(lb).min(ub)
6629            actual = x[..., :1].clamp(lb, ub)
6630            self.assertEqual(expect, actual)
6631
6632    def test_cuda_device_idx(self, device):
6633        x = torch.zeros(3, device=device)
6634        y = torch._efficientzerotensor(3, device=device)
6635        self.assertEqual(x.device, y.device)
6636
6637# We implemented custom deallocation for subclasses, so it behooves
6638# us to make sure all of these bits work.  We'll use __del__ to
6639# track whether objects die or not.
6640class Tracker:
6641    def __init__(self, marker):
6642        self.marker = marker
6643
6644    @staticmethod
6645    def make():
6646        marker = [False]
6647        return marker, Tracker(marker)
6648
6649    def __del__(self):
6650        self.marker[0] = True
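
# Illustrative usage of the Tracker pattern (the shared marker list records deallocation):
#   marker, t = Tracker.make()
#   del t            # with no other references, CPython refcounting runs __del__ here
#   assert marker[0]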
6651
6652@contextlib.contextmanager
6653def disable_gc():
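    # Temporarily disables the cyclic garbage collector (when it is enabled) so the enclosed
    # code relies on reference counting alone for deallocation.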
6654    if gc.isenabled():
6655        try:
6656            gc.disable()
6657            yield
6658        finally:
6659            gc.enable()
6660    else:
6661        yield
6662
6663class TestTorch(TestCase):
6664    exact_dtype = True
6665
6666    def test_dir(self):
6667        dir(torch)
6668
6669    def test_wildcard_import(self):
6670        exec('from torch import *')
6671
6672    def test_newaxis_numpy_comparison(self):
6673        def run_test(tensor, *idx):
6674            npt = tensor.numpy()
6675            self.assertEqual(tensor[idx], npt[idx])
6676
6677        # 1D Tensor Tests
6678        x = torch.arange(0, 10)
6679        cases = [
6680            [None],
6681            [None, None],
6682            [Ellipsis, None],
6683            [None, Ellipsis],
6684            [2, None],
6685            [None, 2],
6686            [Ellipsis, None, 2],
6687            [Ellipsis, 2, None],
6688            [2, Ellipsis, None],
6689            [2, None, Ellipsis],
6690            [None, 2, Ellipsis],
6691            [None, Ellipsis, 2],
6692        ]
6693
6694        for case in cases:
6695            run_test(x, *case)
6696
6697        # 2D Tensor Tests
6698        x = torch.arange(0, 12).view(3, 4)
6699        cases = [
6700            [None],
6701            [None, None],
6702            [None, None, None],
6703            [Ellipsis, None],
6704            [Ellipsis, None, None],
6705            [None, Ellipsis],
6706            [None, Ellipsis, None],
6707            [None, None, Ellipsis],
6708            [2, None],
6709            [2, None, Ellipsis],
6710            [2, Ellipsis, None],
6711            [None, 2, Ellipsis],
6712            [Ellipsis, 2, None],
6713            [Ellipsis, None, 2],
6714            [None, Ellipsis, 2],
6715            [1, 2, None],
6716            [1, 2, Ellipsis, None],
6717            [1, Ellipsis, 2, None],
6718            [Ellipsis, 1, None, 2],
6719            [Ellipsis, 1, 2, None],
6720            [1, None, 2, Ellipsis],
6721            [None, 1, Ellipsis, 2],
6722            [None, 1, 2, Ellipsis],
6723        ]
6724
6725        for case in cases:
6726            run_test(x, *case)
6727
6728    def _consecutive(self, size, start=1):
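        # Returns a tensor of the given size filled with consecutive values, e.g.
        # _consecutive((2, 2)) -> tensor([[1., 2.], [3., 4.]]).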
6729        sequence = torch.ones(torch.tensor(size).prod(0)).cumsum(0)
6730        sequence.add_(start - 1)
6731        return sequence.resize_(*size)
6732
6733    def test_newindex(self):
6734        reference = self._consecutive((3, 3, 3))
6735        # This relies on __index__() being correct - but we have separate tests for that
6736
6737        def checkPartialAssign(index):
6738            reference = torch.zeros(3, 3, 3)
6739            reference[index] = self._consecutive((3, 3, 3))[index]
6740            self.assertEqual(reference[index], self._consecutive((3, 3, 3))[index], atol=0, rtol=0)
6741            reference[index] = 0
6742            self.assertEqual(reference, torch.zeros(3, 3, 3), atol=0, rtol=0)
6743
6744        checkPartialAssign(0)
6745        checkPartialAssign(1)
6746        checkPartialAssign(2)
6747        checkPartialAssign((0, 1))
6748        checkPartialAssign((1, 2))
6749        checkPartialAssign((0, 2))
6750        checkPartialAssign(torch.LongTensor((0, 2)))
6751
6752        with self.assertRaises(IndexError):
6753            reference[1, 1, 1, 1] = 1
6754        with self.assertRaises(IndexError):
6755            reference[1, 1, 1, (1, 1)] = 1
6756        with self.assertRaises(IndexError):
6757            reference[3, 3, 3, 3, 3, 3, 3, 3] = 1
6758        with self.assertRaises(IndexError):
6759            reference[0.0] = 1
6760        with self.assertRaises(TypeError):
6761            reference[0.0:2.0] = 1
6762        with self.assertRaises(IndexError):
6763            reference[0.0, 0.0:2.0] = 1
6764        with self.assertRaises(IndexError):
6765            reference[0.0, :, 0.0:2.0] = 1
6766        with self.assertRaises(IndexError):
6767            reference[0.0, ..., 0.0:2.0] = 1
6768        with self.assertRaises(IndexError):
6769            reference[0.0, :, 0.0] = 1
6770
6771    # Test `torch._check*` functions
6772    def test_check(self):
6773        test_cases = [
6774            # check function, expected error
6775            (torch._check, RuntimeError),
6776            (torch._check_index, IndexError),
6777            (torch._check_value, ValueError),
6778            (torch._check_type, TypeError),
6779            (torch._check_not_implemented, NotImplementedError),
6780        ]
6781
6782        for check_fn, expected_error in test_cases:
6783            # cond=True should not raise an error
6784            check_fn(True)
6785
6786            # Test default failure message for cond=False
6787            default_message = 'Expected cond to be True'
6788            with self.assertRaisesRegex(expected_error, default_message):
6789                check_fn(False)
6790
6791            # Test a simple failure message
6792            message = 'message'
6793            with self.assertRaisesRegex(expected_error, message):
6794                check_fn(False, lambda: message)
6795
6796            # Test message with tensor
6797            def message():
6798                return torch.arange(4)
6799
6800            with self.assertRaisesRegex(expected_error, re.escape(str(message()))):
6801                check_fn(False, message)
6802
6803            # Test format string message
6804            def message():
6805                return f"{'test'} {[1, 2, 'a', True]} {True} {100} {torch.arange(4)}"
6806
6807            with self.assertRaisesRegex(expected_error, re.escape(str(message()))):
6808                check_fn(False, message)
6809
6810            # Test incorrect `cond` arg type
6811            with self.assertRaisesRegex(TypeError, 'cond must be a bool'):
6812                check_fn('wrong type')
6813
6814            with self.assertRaisesRegex(TypeError, 'cond must be a bool'):
6815                check_fn(torch.tensor(True))
6816
6817    # FIXME: move to indexing test suite
6818    def test_index_add(self):
6819        for device in get_all_device_types():
6820            for dest_contig, src_contig, index_contig in product([True, False], repeat=3):
6821                for other_sizes in ((), (4, 5)):
6822                    for dtype in [torch.int, torch.long]:
6823                        num_copy, num_dest = 3, 3
6824                        dest = torch.randn(num_dest, *other_sizes, device=device)
6825                        if not dest_contig:
6826                            dest = make_tensor(dest.shape, device=device, dtype=dest.dtype, noncontiguous=True)
6827                        src = torch.randn(num_copy, *other_sizes, device=device)
6828                        if not src_contig:
6829                            src = noncontiguous_like(src)
6830                        idx = torch.randperm(num_dest, dtype=dtype, device=device).narrow(0, 0, num_copy)
6831                        if not index_contig:
6832                            idx = noncontiguous_like(idx)
6833                        # index_add_ without alpha argument
6834                        dest2 = dest.clone()
6835                        dest.index_add_(0, idx, src)
6836                        for i in range(idx.size(0)):
6837                            dest2[idx[i]] += src[i]
6838                        self.assertEqual(dest, dest2)
6839                        # index_add_ with alpha argument
6840                        dest2 = dest.clone()
6841                        dest.index_add_(0, idx, src, alpha=2)
6842                        for i in range(idx.size(0)):
6843                            dest2[idx[i]] += src[i] * 2
6844                        self.assertEqual(dest, dest2)
6845
6846    # FIXME: resolve comment below and move this to indexing test suite
6847    # add coverage for issue with atomic add that appeared only for
6848    # specific dtypes on cuda:
6849    # https://github.com/pytorch/pytorch/issues/29153
6850    def test_index_add_all_dtypes(self):
6851        for device in get_all_device_types():
6852            for dtype in get_all_math_dtypes(device):
6853                for idx_dtype in [torch.int, torch.long]:
6854                    size = [5, 5]
6855                    if dtype.is_floating_point or dtype.is_complex:
6856                        tensor = torch.rand(size, dtype=dtype, device=device)
6857                    elif dtype.is_signed:
6858                        tensor = torch.randint(-5, 15, size, dtype=dtype, device=device)
6859                    else:
6860                        tensor = torch.randint(0, 10, size, dtype=dtype, device=device)
6861
6862                    # index_add calls atomicAdd on cuda.
6863                    zeros = torch.zeros(size, dtype=dtype, device=device)
6864
6865                    added = zeros.index_add(0, torch.arange(0, size[0], dtype=idx_dtype, device=device), tensor)
6866                    self.assertEqual(added, tensor)
6867
6868                    added = zeros.index_add(0, torch.arange(0, size[0], dtype=idx_dtype, device=device), tensor, alpha=-1)
6869                    self.assertEqual(added, -tensor)
6870
6871    @unittest.mock.patch.object(torch._dynamo.config, "suppress_errors", False)
6872    @set_default_dtype(torch.double)
6873    def test_index_add_correctness(self):
6874        # Check whether index_add produces the correct result when
6875        # alpha is 1 and the index dtype is torch.long,
6876        # i.e., the case that uses scatter_add.
6877        def helper(dim, dtype, device, size_result, size_source):
6878            tensor = torch.zeros(size_result, dtype=dtype, device=device)
6879            index = torch.randint(0, size_result[dim], (size_source[dim],),
6880                                  dtype=torch.long, device=device)
6881            if dtype.is_floating_point or dtype.is_complex:
6882                source = torch.rand(size_source, dtype=dtype, device=device)
6883            elif dtype.is_signed:
6884                source = torch.randint(-2, 5, size_source, dtype=dtype, device=device)
6885            else:
6886                source = torch.randint(0, 5, size_source, dtype=dtype, device=device)
6887
6888            ref_out = tensor.index_add(dim, index, source, alpha=2.) / 2.
6889            ref_out = ref_out.to(dtype=dtype)
6890            out = tensor.index_add(dim, index, source)
6891            if device == 'cuda':
6892                self.assertEqual(out, ref_out, atol=1e-2, rtol=1e-2)
6893            else:
6894                # scatter_add uses fp32 as accumulate type, while index_add doesn't.
6895                self.assertEqual(out, ref_out.to(dtype=dtype), atol=1e-2, rtol=1e-2)
6896
6897        for dim in [-1, -2, -3]:
6898            for dtype in all_types_and_complex_and(torch.half, torch.bfloat16):
6899                for device in get_all_device_types():
6900                    for size in [(2, 512, 256), (5, 256, 256)]:
6901                        helper(dim, dtype, device, size, size)
6902
6903                # Check bound
6904                result = torch.zeros(1, 512, 256, dtype=dtype)
6905                source = torch.ones(1, 512, 256, dtype=dtype)
6906                index = torch.ones(257).to(dtype=torch.long)
6907                self.assertRaises(RuntimeError, lambda: result.index_add_(dim, index, source))
6908                index = (torch.ones(256) * 257).to(dtype=torch.long)
6909                self.assertRaises(RuntimeError, lambda: result.index_add_(dim, index, source))
6910
6911    def test_index_add_cornercase(self):
6912        for device in get_all_device_types():
6913            dest = torch.randn((), device=device)
6914            index = torch.tensor([0], device=device)
6915            source = torch.randn(1, 1, 1, device=device)
6916            with self.assertRaisesRegex(
6917                RuntimeError,
6918                r"source tensor shape must match self tensor shape, excluding the specified dimension",
6919            ):
6920                dest.index_add(0, index, source)
6921
6922    def test_linspace_logspace(self):
6923        # Ensure the output does not require grad regardless of whether the inputs require grad.
6924        # The output of factory functions should not be part of any computational graph.
6925        start = 0.0
6926        end = 3.0
6927
6928        for step in [0, 1, 2]:
6929            self.assertFalse(
6930                torch.linspace(
6931                    torch.tensor(start, requires_grad=True),
6932                    torch.tensor(end, requires_grad=True), step
6933                ).requires_grad
6934            )
6935            self.assertFalse(torch.linspace(torch.tensor(start, requires_grad=True), end, step).requires_grad)
6936            self.assertFalse(torch.linspace(start, torch.tensor(end, requires_grad=True), step).requires_grad)
6937            self.assertFalse(
6938                torch.logspace(
6939                    torch.tensor(start, requires_grad=True),
6940                    torch.tensor(end, requires_grad=True), step
6941                ).requires_grad
6942            )
6943            self.assertFalse(torch.logspace(torch.tensor(start, requires_grad=True), end, step).requires_grad)
6944            self.assertFalse(torch.logspace(start, torch.tensor(end, requires_grad=True), step).requires_grad)
6945
6946    # FIXME: move to shape ops test suite
6947    def test_unflatten(self):
6948        # test args: tensor, int, sizes
6949        self.assertEqual(torch.tensor([]).unflatten(0, (0, 1)), torch.empty(0, 1))
6950        self.assertEqual(torch.tensor([1]).unflatten(0, (1, 1)), torch.tensor([[1]]))
6951        self.assertEqual(torch.tensor([1, 2, 3, 4]).unflatten(0, (2, 2)), torch.tensor([[1, 2], [3, 4]]))
6952        self.assertEqual(torch.tensor([1, 2, 3, 4]).unflatten(0, [2, 2]), torch.tensor([[1, 2], [3, 4]]))
6953        self.assertEqual(torch.tensor([1, 2, 3, 4]).unflatten(0, torch.Size([2, 2])), torch.tensor([[1, 2], [3, 4]]))
6954        self.assertEqual(torch.ones(2, 10).unflatten(1, (5, 2)), torch.ones(2, 5, 2))
6955        self.assertEqual(torch.tensor([1, 2, 3, 4]).unflatten(0, (-1, 2)),
6956                         torch.tensor([[1, 2], [3, 4]]))
6957        self.assertEqual(torch.ones(2, 10).unflatten(1, (5, -1)),
6958                         torch.ones(2, 5, 2))
6959        self.assertEqual(torch.ones(2, 10).unflatten(1, (-1,)),
6960                         torch.ones(2, 10))
6961        self.assertEqual(torch.ones(2, 3 * 4 * 5 * 6).unflatten(1, (3, 4, -1, 6)),
6962                         torch.ones(2, 3, 4, 5, 6))
6963        self.assertEqual(torch.ones(2, 0, 2).unflatten(1, (3, -1, 4, 5)),
6964                         torch.ones(2, 3, 0, 4, 5, 2))
6965
6966        # test invalid args: tensor, str, sizes
6967        with self.assertRaisesRegex(TypeError, r"unflatten\(\): argument 'dim' \(position 1\) must be int, not str"):
6968            torch.tensor([1]).unflatten('A', (1, 1))
6969
6970        # test invalid args: tensor, str, namedshape
6971        with self.assertRaisesRegex(RuntimeError, r"Name 'A' not found in Tensor\[None\]."):
6972            torch.ones(4).unflatten('A', (('A', 2), ('B', 2)))
6973
6974        # test other invalid arguments
6975        with self.assertRaisesRegex(RuntimeError, r"sizes must be non-empty"):
6976            torch.tensor([1]).unflatten(0, [])
6977        with self.assertRaisesRegex(RuntimeError, r"Provided sizes \[2, 2\] don't multiply up to the size of dim 0 \(1\)"):
6978            torch.tensor([1]).unflatten(0, [2, 2])
6979        with self.assertRaisesRegex(IndexError, r"Dimension specified as 0 but tensor has no dimensions"):
6980            torch.tensor(1).unflatten(0, [0])
6981        with self.assertRaisesRegex(RuntimeError, r"only one dimension can be inferred"):
6982            torch.randn(5, 10).unflatten(1, (-1, -1))
6983        with self.assertRaisesRegex(RuntimeError,
6984                                    r"Provided sizes \[-1, 4\] don't multiply up to the size of dim 1 \(10\)"):
6985            torch.randn(5, 10).unflatten(1, (-1, 4))
6986        with self.assertRaisesRegex(RuntimeError,
6987                                    r"the unspecified dimension size -1 can be any value and is ambiguous"):
6988            torch.randn(2, 0).unflatten(1, (2, -1, 0))
6989
6990    # Test that warnings generated from C++ are translated to the correct type
6991    def test_warn_types(self):
6992        test_cases = [
6993            # function, warning type, message
6994            (torch._C._warn, UserWarning, r"Test message for TORCH_WARN"),
6995            (torch._C._warn_deprecation, DeprecationWarning, r"Test message for TORCH_WARN_DEPRECATION"),
6996        ]
6997
6998        for fn, warning_type, message in test_cases:
6999            with warnings.catch_warnings(record=True) as w:
7000                warnings.resetwarnings()
7001                warnings.filterwarnings('always', category=warning_type)
7002                fn()
7003
7004                self.assertEqual(len(w), 1, msg=f'{warning_type} not raised')
7005                warning = w[0].message
7006                self.assertTrue(isinstance(warning, warning_type), msg=f'{warning_type} not raised')
7007                self.assertTrue(re.search(
7008                    message,
7009                    str(warning)))
7010
7011    def test_structseq_repr(self):
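        # max(dim) returns a structseq (torch.return_types.max) holding `values`
        # and `indices`; its repr should list both fields.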
7012        a = torch.arange(250).reshape(5, 5, 10)
7013        expected = """
7014        torch.return_types.max(
7015        values=tensor([[ 40,  41,  42,  43,  44,  45,  46,  47,  48,  49],
7016                [ 90,  91,  92,  93,  94,  95,  96,  97,  98,  99],
7017                [140, 141, 142, 143, 144, 145, 146, 147, 148, 149],
7018                [190, 191, 192, 193, 194, 195, 196, 197, 198, 199],
7019                [240, 241, 242, 243, 244, 245, 246, 247, 248, 249]]),
7020        indices=tensor([[4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
7021                [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
7022                [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
7023                [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
7024                [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]]))"""
7025        self.assertEqual(repr(a.max(1)), textwrap.dedent(expected).strip())
7026
7027    def test_is_same_size(self):
7028        t1 = torch.empty(3, 4, 9, 10)
7029        t2 = torch.empty(3, 4)
7030        t3 = torch.empty(1, 9, 3, 3)
7031        t4 = torch.empty(3, 4, 9, 10)
7032
7033        self.assertFalse(t1.is_same_size(t2))
7034        self.assertFalse(t1.is_same_size(t3))
7035        self.assertTrue(t1.is_same_size(t4))
7036
7037        nt1 = torch.nested.nested_tensor([torch.ones(2, 4), torch.ones(3, 4), torch.ones(5, 4)])
7038        nt2 = torch.nested.nested_tensor([torch.ones(2, 4), torch.ones(2, 4), torch.ones(2, 4)])
7039        nt3 = torch.nested.nested_tensor([torch.ones(2, 4, 5), torch.ones(2, 6, 5)])
7040        nt4 = torch.nested.nested_tensor([torch.ones(2, 4), torch.ones(3, 4), torch.ones(5, 4)])
7041
7042        self.assertFalse(nt1.is_same_size(nt2))
7043        self.assertFalse(nt1.is_same_size(nt3))
7044        self.assertTrue(nt1.is_same_size(nt4))
7045        with self.assertRaisesRegex(RuntimeError, "Expected both self and other to be nested tensors."):
7046            t1.is_same_size(nt1)
7047
7048        with self.assertRaisesRegex(RuntimeError, "Expected both self and other to be nested tensors."):
7049            nt1.is_same_size(t1)
7050
7051    def test_tensor_set(self):
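        # set_() points `t1` at `t2`'s storage (optionally with an explicit offset,
        # size and stride), so both tensors share the same underlying data.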
7052        t1 = torch.tensor([])
7053        t2 = torch.empty(3, 4, 9, 10).uniform_()
7054        t1.set_(t2)
7055        self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
7056        size = torch.Size([9, 3, 4, 10])
7057        t1.set_(t2.storage(), 0, size)
7058        self.assertEqual(t1.size(), size)
7059        t1.set_(t2.storage(), 0, tuple(size))
7060        self.assertEqual(t1.size(), size)
7061        self.assertEqual(t1.stride(), (120, 40, 10, 1))
7062        stride = (10, 360, 90, 1)
7063        t1.set_(t2.storage(), 0, size, stride)
7064        self.assertEqual(t1.stride(), stride)
7065        t1.set_(t2.storage(), 0, size=size, stride=stride)
7066        self.assertEqual(t1.size(), size)
7067        self.assertEqual(t1.stride(), stride)
7068
7069        # test argument names
7070        t1 = torch.tensor([])
7071        # 1. case when source is tensor
7072        t1.set_(source=t2)
7073        self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
7074        # 2. case when source is storage
7075        t1.set_(source=t2.storage())
7076        self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
7077        # 3. case when source is storage, and other args also specified
7078        t1.set_(source=t2.storage(), storage_offset=0, size=size, stride=stride)
7079        self.assertEqual(t1.size(), size)
7080        self.assertEqual(t1.stride(), stride)
7081
7082        t1 = torch.tensor([True, True], dtype=torch.bool)
7083        t2 = torch.tensor([False, False], dtype=torch.bool)
7084        t1.set_(t2)
7085        self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
7086
7087    def test_tensor_set_errors(self):
7088        f_cpu = torch.randn((2, 3), dtype=torch.float32)
7089        d_cpu = torch.randn((2, 3), dtype=torch.float64)
7090
7091        # change dtype
7092        self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu.storage()))
7093        self.assertRaises(RuntimeError,
7094                          lambda: f_cpu.set_(d_cpu.storage(), 0, d_cpu.size(), d_cpu.stride()))
7095        self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu))
7096
7097        # change device
7098        if torch.cuda.is_available():
7099            f_cuda = torch.randn((2, 3), dtype=torch.float32, device='cuda')
7100
7101            # cpu -> cuda
7102            self.assertRaises(RuntimeError, lambda: f_cpu.set_(f_cuda.storage()))
7103            self.assertRaises(RuntimeError,
7104                              lambda: f_cpu.set_(f_cuda.storage(), 0, f_cuda.size(), f_cuda.stride()))
7105            self.assertRaises(RuntimeError, lambda: f_cpu.set_(f_cuda))
7106
7107            # cuda -> cpu
7108            self.assertRaises(RuntimeError, lambda: f_cuda.set_(f_cpu.storage()))
7109            self.assertRaises(RuntimeError,
7110                              lambda: f_cuda.set_(f_cpu.storage(), 0, f_cpu.size(), f_cpu.stride()))
7111            self.assertRaises(RuntimeError, lambda: f_cuda.set_(f_cpu))
7112
7113    # FIXME: move this test to test_testing.py (along with allclose testing)
7114    # NOTE: test_equal will be deprecated in favor of torch.testing.assert_close
7115    #   once torch.testing is out of beta
7116    def test_equal(self):
7117        devices = ["cpu", "cuda"]
7118        for device in devices:
7119            if device == "cuda" and not torch.cuda.is_available():
7120                continue
7121
7122            # Contiguous, 1D
7123            t1 = torch.tensor((3., 4., 9., 10.), device=device)
7124            t2 = t1.contiguous()
7125            t3 = torch.tensor((1., 9., 3., 10.), device=device)
7126            t4 = torch.tensor((3., 4., 9.), device=device)
7127            t5 = torch.tensor([], device=device)
7128            self.assertTrue(t1.equal(t2))
7129            self.assertFalse(t1.equal(t3))
7130            self.assertFalse(t1.equal(t4))
7131            self.assertFalse(t1.equal(t5))
7132            self.assertTrue(torch.equal(t1, t2))
7133            self.assertFalse(torch.equal(t1, t3))
7134            self.assertFalse(torch.equal(t1, t4))
7135            self.assertFalse(torch.equal(t1, t5))
7136
7137            # Non contiguous, 2D
7138            s = torch.tensor(((1, 2, 3, 4), (5, 6, 7, 8)), device=device)
7139            s1 = s[:, 1:3]
7140            s2 = s1.clone()
7141            s3 = torch.tensor(((2, 3), (6, 7)), device=device)
7142            s4 = torch.tensor(((0, 0), (0, 0)), device=device)
7143
7144            self.assertFalse(s1.is_contiguous())
7145            self.assertTrue(s1.equal(s2))
7146            self.assertTrue(s1.equal(s3))
7147            self.assertFalse(s1.equal(s4))
7148            self.assertTrue(torch.equal(s1, s2))
7149            self.assertTrue(torch.equal(s1, s3))
7150            self.assertFalse(torch.equal(s1, s4))
7151
7152            # Different dtypes
7153            x = torch.tensor((1, 2, 3), dtype=torch.float, device=device)
7154            y = torch.tensor((1, 2, 3), dtype=torch.int, device=device)
7155            z = torch.tensor((1, -1), dtype=torch.int, device=device)
7156            self.assertTrue(torch.equal(x, y))
7157            self.assertFalse(torch.equal(z, x))
7158
7159            # Fast path test: tensor flags, like neg and conj
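            # _neg_view() aliases the same storage but sets the negative bit, so even
            # though all the metadata matches, torch.equal must still return False.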
7160            neg_0 = torch.tensor((1, 2, 3), dtype=torch.float, device=device)
7161            neg_1 = neg_0._neg_view()
7162            self.assertTrue(neg_1.is_neg())
7163            self.assertEqual(neg_0.data_ptr(), neg_1.data_ptr())
7164            self.assertEqual(neg_0.storage_offset(), neg_1.storage_offset())
7165            self.assertEqual(neg_0.stride(), neg_1.stride())
7166            self.assertEqual(neg_0.size(), neg_1.size())
7167            self.assertFalse(torch.equal(neg_0, neg_1))
7168            # FIXME: Disable the following check due to the inductor failure
7169            # See https://github.com/pytorch/pytorch/issues/100340 and
7170            # https://github.com/pytorch/pytorch/issues/98175
7171            if not TEST_WITH_TORCHINDUCTOR:
7172                self.assertTrue(torch.equal(neg_0, neg_1._neg_view()))
7173
7174            conj_0 = torch.tensor([1.0 + 2.0j, 2.0 + 1.0j], device=device)
7175            conj_1 = conj_0.conj()
7176            self.assertTrue(conj_1.is_conj())
7177            self.assertEqual(conj_0.data_ptr(), conj_1.data_ptr())
7178            self.assertEqual(conj_0.storage_offset(), conj_1.storage_offset())
7179            self.assertEqual(conj_0.stride(), conj_1.stride())
7180            self.assertEqual(conj_0.size(), conj_1.size())
7181            self.assertFalse(torch.equal(conj_0, conj_1))
7182            # FIXME: Disable the following check due to the inductor failure
7183            # See https://github.com/pytorch/pytorch/issues/100340 and
7184            # https://github.com/pytorch/pytorch/issues/98175
7185            if not TEST_WITH_TORCHINDUCTOR:
7186                self.assertTrue(torch.equal(conj_0, conj_1.conj()))
7187
7188            # Fast path test: two tensors share the same storage but have different dtypes
7189            s_0 = torch.rand((2, 3), dtype=torch.float, device=device)
7190            s_1 = s_0.view(dtype=torch.int32)
7191            self.assertEqual(s_0.data_ptr(), s_1.data_ptr())
7192            self.assertEqual(s_0.storage_offset(), s_1.storage_offset())
7193            self.assertEqual(s_0.stride(), s_1.stride())
7194            self.assertEqual(s_0.size(), s_1.size())
7195            self.assertFalse(torch.equal(s_0, s_1))
7196
7197            # Fast path test: two tensors share the same storage but have different strides
7198            t_0 = torch.rand((2, 3), dtype=torch.float, device=device)
7199            t_1 = t_0.t()
7200            self.assertEqual(t_0.data_ptr(), t_1.data_ptr())
7201            self.assertEqual(t_0.storage_offset(), t_1.storage_offset())
7202            self.assertNotEqual(t_0.stride(), t_1.stride())
7203            self.assertNotEqual(t_0.size(), t_1.size())
7204            self.assertFalse(torch.equal(t_0, t_1))
7205
7206            # Fast path: tensor containing `nan` is not equal to self
7207            for dtype in floating_and_complex_types():
7208                t = torch.tensor([1., float('nan')], dtype=dtype)
7209                self.assertFalse(torch.equal(t, t))
7210
7211    def test_element_size(self):
7212        byte = torch.ByteStorage().element_size()
7213        char = torch.CharStorage().element_size()
7214        short = torch.ShortStorage().element_size()
7215        int = torch.IntStorage().element_size()
7216        long = torch.LongStorage().element_size()
7217        float = torch.FloatStorage().element_size()
7218        double = torch.DoubleStorage().element_size()
7219        bool = torch.BoolStorage().element_size()
7220        bfloat16 = torch.BFloat16Storage().element_size()
7221        complexfloat = torch.ComplexFloatStorage().element_size()
7222        complexdouble = torch.ComplexDoubleStorage().element_size()
7223
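        # Tensor.itemsize is a NumPy-style alias for element_size()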
7224        self.assertEqual(byte, torch.ByteTensor().element_size())
7225        self.assertEqual(byte, torch.ByteTensor().itemsize)
7226        self.assertEqual(char, torch.CharTensor().element_size())
7227        self.assertEqual(char, torch.CharTensor().itemsize)
7228        self.assertEqual(short, torch.ShortTensor().element_size())
7229        self.assertEqual(short, torch.ShortTensor().itemsize)
7230        self.assertEqual(int, torch.IntTensor().element_size())
7231        self.assertEqual(int, torch.IntTensor().itemsize)
7232        self.assertEqual(long, torch.LongTensor().element_size())
7233        self.assertEqual(long, torch.LongTensor().itemsize)
7234        self.assertEqual(float, torch.FloatTensor().element_size())
7235        self.assertEqual(float, torch.FloatTensor().itemsize)
7236        self.assertEqual(double, torch.DoubleTensor().element_size())
7237        self.assertEqual(double, torch.DoubleTensor().itemsize)
7238        self.assertEqual(bool, torch.BoolTensor().element_size())
7239        self.assertEqual(bool, torch.BoolTensor().itemsize)
7240        self.assertEqual(bfloat16, torch.tensor([], dtype=torch.bfloat16).element_size())
7241        self.assertEqual(bfloat16, torch.tensor([], dtype=torch.bfloat16).itemsize)
7242        self.assertEqual(complexfloat, torch.tensor([], dtype=torch.complex64).element_size())
7243        self.assertEqual(complexfloat, torch.tensor([], dtype=torch.complex64).itemsize)
7244        self.assertEqual(complexdouble, torch.tensor([], dtype=torch.complex128).element_size())
7245        self.assertEqual(complexdouble, torch.tensor([], dtype=torch.complex128).itemsize)
7246
7247        self.assertGreater(byte, 0)
7248        self.assertGreater(char, 0)
7249        self.assertGreater(short, 0)
7250        self.assertGreater(int, 0)
7251        self.assertGreater(long, 0)
7252        self.assertGreater(float, 0)
7253        self.assertGreater(double, 0)
7254        self.assertGreater(bool, 0)
7255        self.assertGreater(bfloat16, 0)
7256        self.assertGreater(complexfloat, 0)
7257        self.assertGreater(complexdouble, 0)
7258
7259        # These checks hold on any platform; where sizes can vary they assert lower bounds rather than exact values.
7260        self.assertEqual(byte, 1)
7261        self.assertEqual(char, 1)
7262        self.assertEqual(bool, 1)
7263        self.assertGreaterEqual(short, 2)
7264        self.assertGreaterEqual(int, 2)
7265        self.assertGreaterEqual(int, short)
7266        self.assertGreaterEqual(long, 4)
7267        self.assertGreaterEqual(long, int)
7268        self.assertGreaterEqual(double, float)
7269
7270    def test_permute(self):
7271        orig = [1, 2, 3, 4, 5, 6, 7]
7272        perm = torch.randperm(7).tolist()
7273        x = torch.empty(*orig).fill_(0)
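        # sizes are 1..7, so subtracting 1 from each permuted size recovers `perm`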
7274        new = [i - 1 for i in x.permute(*perm).size()]
7275        self.assertEqual(perm, new)
7276        self.assertEqual(x.size(), orig)
7277
7278    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
7279    def test_reversed(self):
7280        val = torch.arange(0, 10)
7281        self.assertEqual(reversed(val), torch.arange(9, -1, -1))
7282
7283        val = torch.arange(1, 10).view(3, 3)
7284        self.assertEqual(reversed(val), torch.tensor([[7, 8, 9], [4, 5, 6], [1, 2, 3]]))
7285
7286        val = torch.tensor(42)
7287        self.assertEqual(reversed(val), torch.tensor(42))
7288
7289    def test_contains(self):
7290        x = torch.arange(0, 10)
7291        self.assertEqual(4 in x, True)
7292        self.assertEqual(12 in x, False)
7293
7294        x = torch.arange(1, 10).view(3, 3)
7295        val = torch.arange(1, 4)
7296        self.assertEqual(val in x, True)
7297        val += 10
7298        self.assertEqual(val in x, False)
7299
7300        self.assertRaisesRegex(
7301            RuntimeError,
7302            f"Tensor.__contains__ only supports Tensor or scalar, but you passed in a {str}.",
7303            lambda: "foo" in x)
7304        self.assertRaisesRegex(
7305            RuntimeError,
7306            f"Tensor.__contains__ only supports Tensor or scalar, but you passed in a {type([1, 2])}.",
7307            lambda: [1, 2] in x)
7308
7309    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
7310    def test_deepcopy_parameter(self):
7311        from copy import deepcopy
7312        l = torch.nn.Linear(10, 1)
7313        s = l.state_dict(keep_vars=True)
7314        self.assertEqual(torch.nn.Parameter, type(s['weight']))
7315        self.assertEqual(torch.nn.Parameter, type(s['bias']))
7316
7317        s2 = deepcopy(s)
7318        self.assertEqual(torch.nn.Parameter, type(s2['weight']))
7319        self.assertEqual(torch.nn.Parameter, type(s2['bias']))
7320
7321    def test_pickle(self):
7322        import pickle
7323        a = torch.randn(5, 5)
7324        serialized = pickle.dumps(a)
7325        b = pickle.loads(serialized)
7326        self.assertEqual(a, b)
7327
7328    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
7329    def test_pickle_parameter(self):
7330        import pickle
7331        a = torch.nn.Parameter(torch.randn(5, 5))
7332        serialized = pickle.dumps(a)
7333        b = pickle.loads(serialized)
7334        self.assertTrue(isinstance(b, torch.nn.Parameter))
7335        self.assertEqual(a.requires_grad, b.requires_grad)
7336        self.assertEqual(a, b)
7337
7338    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
7339    def test_pickle_parameter_no_requires_grad(self):
7340        import pickle
7341        a = torch.nn.Parameter(torch.randn(5, 5), requires_grad=False)
7342        serialized = pickle.dumps(a)
7343        b = pickle.loads(serialized)
7344        self.assertTrue(isinstance(b, torch.nn.Parameter))
7345        self.assertEqual(a.requires_grad, b.requires_grad)
7346        self.assertEqual(a, b)
7347
7348    def test_pickle_dtype(self):
7349        t = torch.float32
7350        serialized = pickle.dumps(t)
7351        b = pickle.loads(serialized)
7352        self.assertTrue(isinstance(b, torch.dtype))
7353        self.assertEqual(id(b), id(t))
7354
7355    def test_pickle_size(self):
7356        a = torch.rand(10).size()
7357        serialized = pickle.dumps(a)
7358        b = pickle.loads(serialized)
7359        self.assertTrue(isinstance(b, torch.Size))
7360        self.assertEqual(a, b)
7361
7362    def test_pickle_function(self):
7363        # https://github.com/pytorch/pytorch/issues/37703
7364        a = torch.tanh
7365        serialized = pickle.dumps(a)
7366        b = pickle.loads(serialized)
7367        self.assertEqual(a, b)
7368
7369    def test_generator_cpu(self):
7370        # test default generators are equal
7371        self.assertEqual(torch.default_generator, torch.default_generator)
7372
7373        # tests Generator API
7374        # manual_seed, seed, initial_seed, get_state, set_state
7375        g1 = torch.Generator()
7376        g2 = torch.Generator()
7377        g1.manual_seed(12345)
7378        g2.manual_seed(12345)
7379        self.assertEqual(g1.initial_seed(), g2.initial_seed())
7380
7381        g1.seed()
7382        g2.seed()
7383        self.assertNotEqual(g1.initial_seed(), g2.initial_seed())
7384
7385        g1 = torch.Generator()
7386        g2_state = g2.get_state()
7387        g2_randn = torch.randn(1, generator=g2)
7388        g1.set_state(g2_state)
7389        g1_randn = torch.randn(1, generator=g1)
7390        self.assertEqual(g1_randn, g2_randn)
7391
7392        default_state = torch.default_generator.get_state()
7393        q = torch.empty(100)
7394        g1_normal = q.normal_()
7395        g2 = torch.Generator()
7396        g2.set_state(default_state)
7397        g2_normal = q.normal_(generator=g2)
7398        self.assertEqual(g1_normal, g2_normal)
7399
7400    def test_invalid_generator_raises(self):
7401        self.assertRaises(RuntimeError, lambda: torch.Generator('opengl'))
7402
7403    def test_pickle_generator(self) -> None:
7404        devices = ['cpu']
7405        if torch.cuda.is_available():
7406            devices += ['cuda']
7407
7408        for device in devices:
7409            with self.subTest(device=device):
7410                generator = torch.Generator(device=device).manual_seed(12345)
7411                if device != "cpu":
7412                    generator.set_offset(100)
7413                torch.randn((100, 100), generator=generator, device=device)  # advance the RNG state
7414
7415                reserialized: torch.Generator = pickle.loads(pickle.dumps(generator))
7416
7417                self.assertEqual(generator.device, reserialized.device)
7418                self.assertEqual(generator.initial_seed(), reserialized.initial_seed())
7419                if device != "cpu":
7420                    self.assertEqual(generator.get_offset(), reserialized.get_offset())
7421                torch.testing.assert_close(generator.get_state(), reserialized.get_state())
7422
7423    def _sobol_reference_samples(self, scramble: bool) -> torch.Tensor:
7424        if not scramble:
7425            # theoretical values from Joe Kuo 2010
7426            return torch.tensor(
7427                [
7428                    [0., 0.],
7429                    [0.5, 0.5],
7430                    [0.75, 0.25],
7431                    [0.25, 0.75],
7432                    [0.375, 0.375],
7433                    [0.875, 0.875],
7434                    [0.625, 0.125],
7435                    [0.125, 0.625],
7436                ],
7437            )
7438        else:
7439            # theoretical values unknown: convergence properties checked
7440            return torch.tensor(
7441                [
7442                    [0.50860737, 0.29320504],
7443                    [0.07116939, 0.89594537],
7444                    [0.49354145, 0.11524881],
7445                    [0.93097717, 0.70244044],
7446                    [0.87266153, 0.23887917],
7447                    [0.31021884, 0.57600391],
7448                    [0.13687253, 0.42054182],
7449                    [0.69931293, 0.77336788],
7450                ],
7451            )
7452
7453    def test_sobolengine_bounds(self, scramble: bool = False):
7454        engine = torch.quasirandom.SobolEngine(100, scramble=scramble, seed=123456)
7455        sample = engine.draw(512)
7456        self.assertTrue(torch.all(sample >= 0))
7457        self.assertTrue(torch.all(sample <= 1))
7458
7459    def test_sobolengine_bounds_scrambled(self):
7460        self.test_sobolengine_bounds(scramble=True)
7461
7462    def test_sobolengine_draw(self, scramble: bool = False):
7463        ref_sample = self._sobol_reference_samples(scramble=scramble)
7464        engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456)
7465        sample = engine.draw(n=len(ref_sample))
7466        self.assertEqual(sample, ref_sample)
7467        self.assertEqual(engine.num_generated, len(ref_sample))
7468
7469    def test_sobolengine_draw_scrambled(self):
7470        self.test_sobolengine_draw(scramble=True)
7471
7472    def test_sobolengine_first_point(self):
7473        for dtype in (torch.float, torch.double):
7474            engine = torch.quasirandom.SobolEngine(2, scramble=False)
7475            sample = engine.draw(1, dtype=dtype)
7476            self.assertTrue(torch.all(sample == 0))
7477            self.assertEqual(sample.dtype, dtype)
7478        for dtype in (torch.float, torch.double):
7479            engine = torch.quasirandom.SobolEngine(2, scramble=True, seed=123456)
7480            sample = engine.draw(1, dtype=dtype)
7481            self.assertTrue(torch.all(sample != 0))
7482            self.assertEqual(sample.dtype, dtype)
7483
7484    def test_sobolengine_continuing(self, scramble: bool = False):
7485        ref_sample = self._sobol_reference_samples(scramble=scramble)
7486        engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456)
7487        n_half = len(ref_sample) // 2
7488        _ = engine.draw(n=n_half)
7489        sample = engine.draw(n=n_half)
7490        torch.testing.assert_close(sample, ref_sample[n_half:])
7491
7492    def test_sobolengine_continuing_scrambled(self):
7493        self.test_sobolengine_continuing(scramble=True)
7494
7495    def test_sobolengine_reset(self, scramble: bool = False):
7496        ref_sample = self._sobol_reference_samples(scramble=scramble)
7497        engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456)
7498        _ = engine.draw(n=len(ref_sample) // 2)
7499        engine.reset()
7500        self.assertEqual(engine.num_generated, 0)
7501        sample = engine.draw(n=len(ref_sample))
7502        torch.testing.assert_close(sample, ref_sample)
7503
7504    def test_sobolengine_reset_scrambled(self):
7505        self.test_sobolengine_reset(scramble=True)
7506
7507    def test_sobolengine_fast_forward(self, scramble: bool = False):
7508        ref_sample = self._sobol_reference_samples(scramble=scramble)
7509        engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456)
7510        engine.fast_forward(4)
7511        sample = engine.draw(n=4)
7512        torch.testing.assert_close(sample, ref_sample[4:])
7513        # alternate fast forwarding with sampling
7514        engine.reset()
7515        even_draws = []
7516        for i in range(8):
7517            if i % 2 == 0:
7518                even_draws.append(engine.draw())
7519            else:
7520                engine.fast_forward(1)
7521        torch.testing.assert_close(
7522            ref_sample[[i for i in range(8) if i % 2 == 0]],
7523            torch.from_numpy(np.concatenate(even_draws)),
7524        )
7525
7526    def test_sobolengine_fast_forward_scrambled(self):
7527        self.test_sobolengine_fast_forward(scramble=True)
7528
7529    def test_sobolengine_default_dtype(self):
7530        engine = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123456)
7531        # Check that default dtype is correctly handled
7532        self.assertEqual(engine.draw(n=5).dtype, torch.float32)
7533        with set_default_dtype(torch.float64):
7534            engine = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123456)
7535            # Check that default dtype is correctly handled (when set to float64)
7536            self.assertEqual(engine.draw(n=5).dtype, torch.float64)
7537            # Check that explicitly passed dtype is adhered to
7538            self.assertEqual(engine.draw(n=5, dtype=torch.float32).dtype, torch.float32)
7539            # Reinitialize the engine and check that first draw dtype is correctly handled
7540            engine = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123456)
7541            self.assertEqual(engine.draw(n=5, dtype=torch.float32).dtype, torch.float32)
7542
7543    @skipIfTorchDynamo("np.float64 restored as float32 after graph break.")
7544    def test_sobolengine_distribution(self, scramble=False):
7545        d = 50
7546        engine = torch.quasirandom.SobolEngine(d, scramble=scramble, seed=123456)
7547        sample = engine.draw(1024)
7548        torch.testing.assert_close(
7549            torch.mean(sample, dim=0), torch.full((d,), 0.5), atol=2, rtol=2
7550        )
7551        torch.testing.assert_close(
7552            np.percentile(sample, 25, axis=0), np.repeat(0.25, d), atol=2, rtol=2
7553        )
7554        torch.testing.assert_close(
7555            np.percentile(sample, 75, axis=0), np.repeat(0.75, d), atol=2, rtol=2
7556        )
7557
7558    @skipIfTorchDynamo("np.float64 restored as float32 after graph break.")
7559    def test_sobolengine_distribution_scrambled(self):
7560        self.test_sobolengine_distribution(scramble=True)
7561
7562    def test_sobolengine_draw_base2(self, scramble=False):
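        # draw_base2(m) draws 2**m points, keeping the cumulative sample size a power of two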
7563        ref_sample = self._sobol_reference_samples(scramble=scramble)
7564        engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456)
7565        sample = engine.draw_base2(2)
7566        self.assertEqual(ref_sample[:4], sample)
7567        # a second draw_base2(2) continues the sequence, still drawing 2**n points
7568        sample = engine.draw_base2(2)
7569        self.assertEqual(ref_sample[4:8], sample)
7570
7571    def test_sobolengine_draw_base2_scrambled(self):
7572        self.test_sobolengine_draw_base2(scramble=True)
7573
7574    def test_sobolengine_raise(self):
7575        maxdim = torch.quasirandom.SobolEngine.MAXDIM
7576        with self.assertRaises(ValueError):
7577            torch.quasirandom.SobolEngine(maxdim + 1)
7578
7579    def test_sobolengine_high_dim(self):
7580        engine = torch.quasirandom.SobolEngine(1111, scramble=False, seed=123456)
7581        samples1 = engine.draw()
7582        vals1, counts1 = torch.unique(samples1, return_counts=True)
7583        samples2 = engine.draw()
7584        vals2, counts2 = torch.unique(samples2, return_counts=True)
7585        self.assertEqual(vals1.item(), 0.0)
7586        self.assertEqual(counts1.item(), 1111)
7587        self.assertEqual(vals2.item(), 0.5)
7588        self.assertEqual(counts2.item(), 1111)
7589
7590    def test_parsing_int64(self):
7591        # accepts integer arguments
7592        x = torch.cumsum(torch.ones(5, 5), 0)
7593        self.assertEqual(x, torch.cumsum(torch.ones(5, 5), torch.tensor(0)))
7594        # doesn't accept floating point variables
7595        self.assertRaises(TypeError, lambda: torch.cumsum(torch.ones(5, 5), torch.tensor(0.)))
7596
7597    def test_parsing_double(self):
7598        # accepts floating point and integer arguments
7599        x = torch.randn(2, 3)
7600        torch.isclose(x, x, 1, 1)
7601        self.assertTrue(torch.isclose(x, x, 1, 1).all())
7602        self.assertTrue(torch.isclose(x, x, 1.5, 1.).all())
7603        # accepts floating point and integer tensors
7604        self.assertTrue(torch.isclose(x, x, torch.tensor(1), torch.tensor(1)).all())
7605        self.assertTrue(torch.isclose(x, x, torch.tensor(1.5), torch.tensor(1.)).all())
7606        # doesn't accept variables with requires_grad
7607        self.assertRaises(TypeError,
7608                          lambda: torch.isclose(x, x, torch.tensor(1.5), torch.tensor(1., requires_grad=True)).all())
7609
7610    def test_parsing_intlist(self):
7611        #  parse with integer variables
7612        self.assertEqual(torch.Size([3, 4]), torch.ones((torch.tensor(3), torch.tensor(4))).shape)
7613        self.assertEqual(torch.Size([3, 4]), torch.ones(torch.tensor(3), torch.tensor(4)).shape)
7614        # parse with numpy integers
7615        self.assertEqual(torch.Size([3, 4]), torch.ones((np.array(3), np.int64(4))).shape)
7616        self.assertEqual(torch.Size([3, 4]), torch.ones(np.array(3), np.int64(4)).shape)
7617        self.assertEqual(torch.Size([3, 4]), torch.ones((np.int64(3), np.array(4))).shape)
7618        self.assertEqual(torch.Size([3, 4]), torch.ones(np.int64(3), np.array(4)).shape)
7619
7620        # fail parse with float variables
7621        self.assertRaises(TypeError, lambda: torch.ones((torch.tensor(3.), torch.tensor(4))))
7622        # fail parse with numpy floats
7623        self.assertRaises(TypeError, lambda: torch.ones((3., torch.tensor(4))))
7624        self.assertRaises(TypeError, lambda: torch.ones((np.array(3.), torch.tensor(4))))
7625
7626        # fail parse with > 1 element variables
7627        self.assertRaises(TypeError, lambda: torch.ones(torch.tensor(3, 3)))
7628        self.assertRaises(TypeError, lambda: torch.ones(torch.tensor(3, 3)))
7629        self.assertRaises(TypeError, lambda: torch.ones(np.array(3, 3)))
7630        self.assertRaises(TypeError, lambda: torch.ones(np.array(3, 3)))
7631
7632        # fail parse with additional positional args after intlist arg
7633        self.assertRaisesRegex(TypeError,
7634                               "received an invalid combination of arguments",
7635                               lambda: torch.LongTensor((6, 0), 1, 1, 0))
7636        self.assertRaisesRegex(TypeError,
7637                               "missing 1 required positional arguments",
7638                               lambda: torch.tensor().new_zeros((5, 5), 0))
7639
7640    def test_from_buffer(self):
7641        a = bytearray([1, 2, 3, 4])
7642        self.assertEqual(torch.ByteStorage.from_buffer(a).tolist(), [1, 2, 3, 4])
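        # 'big' reads each pair of bytes as a big-endian int16: 0x0102 == 258, 0x0304 == 772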
7643        shorts = torch.ShortStorage.from_buffer(a, 'big')
7644        self.assertEqual(shorts.size(), 2)
7645        self.assertEqual(shorts.tolist(), [258, 772])
7646        ints = torch.IntStorage.from_buffer(a, 'little')
7647        self.assertEqual(ints.size(), 1)
7648        self.assertEqual(ints[0], 67305985)
7649        f = bytearray([0x40, 0x10, 0x00, 0x00])
7650        floats = torch.FloatStorage.from_buffer(f, 'big')
7651        self.assertEqual(floats.size(), 1)
7652        self.assertEqual(floats[0], 2.25)
7653
7654        f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40])
7655        bools = torch.BoolStorage.from_buffer(f, 'big')
7656        self.assertEqual(bools.size(), 8)
7657        self.assertEqual(bools.tolist(), [False, True, True, True, True, True, True, True])
7658        self.assertEqual(bools.type(), 'torch.BoolStorage')
7659        self.assertTrue(isinstance(bools, torch.BoolStorage))
7660
7661        f = bytearray(b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9')
7662        bools = torch.BoolStorage.from_buffer(f, 'big')
7663        self.assertEqual(bools.size(), 19)
7664
7665        f = bytearray(b'\0x4A')
7666        bools = torch.BoolStorage.from_buffer(f, 'big')
7667        self.assertEqual(bools.size(), 4)
7668        self.assertEqual(bools.tolist(), [False, True, True, True])
7669        bytes = torch.ByteStorage.from_buffer(a)
7670        self.assertEqual(bytes.nbytes(), 4)
7671        self.assertEqual(bytes.tolist(), [1, 2, 3, 4])
7672        self.assertTrue(isinstance(bytes, torch.ByteStorage))
7673
7674    def test_storage_error(self):
7675        quantized_storages = [
7676            torch.QInt32Storage,
7677            torch.QInt8Storage,
7678            torch.QUInt2x4Storage,
7679            torch.QUInt4x2Storage,
7680            torch.QUInt8Storage,
7681        ]
7682
7683        with self.assertRaisesRegex(RuntimeError, r"Only child classes of _LegacyStorage can be instantiated"):
7684            torch.storage._LegacyStorage()
7685
7686        for storage_class in torch._storage_classes:
7687            if storage_class in [torch.UntypedStorage, torch.TypedStorage]:
7688                continue
7689
7690            device = 'cuda' if storage_class.__module__ == 'torch.cuda' else 'cpu'
7691            dtype = storage_class.dtype
7692
7693            if device == 'cuda' and not torch.cuda.is_available():
7694                continue
7695
7696            # Legacy <type>Storage constructor errors
7697            with self.assertRaisesRegex(RuntimeError, r"'device' cannot be specified"):
7698                storage_class(device='cpu')
7699
7700            with self.assertRaisesRegex(RuntimeError, r"'dtype' cannot be specified"):
7701                storage_class(dtype=torch.float)
7702
7703            with self.assertRaisesRegex(TypeError, r"got an unexpected keyword"):
7704                storage_class(sdlkjf=torch.float)
7705
7706            with self.assertRaisesRegex(RuntimeError, r"Too many positional arguments"):
7707                storage_class(0, 0)
7708
7709            with self.assertRaisesRegex(TypeError, r"invalid data type"):
7710                storage_class('string')
7711
7712            with self.assertRaisesRegex(TypeError, r"Argument type not recognized"):
7713                storage_class(torch.tensor([]))
7714
7715            s = storage_class()
7716
7717            with self.assertRaisesRegex(RuntimeError, r"No positional arguments"):
7718                storage_class(0, wrap_storage=s.untyped())
7719
7720            with self.assertRaisesRegex(TypeError, r"must be UntypedStorage"):
7721                storage_class(wrap_storage=s)
7722
7723            if torch.cuda.is_available():
7724                if storage_class in quantized_storages:
7725                    with self.assertRaisesRegex(RuntimeError, r"Cannot create CUDA storage with quantized dtype"):
7726                        s.cuda()
7727
7728                else:
7729
7730                    if s.is_cuda:
7731                        s_other_device = s.cpu()
7732                    else:
7733                        s_other_device = s.cuda()
7734
7735                    with self.assertRaisesRegex(RuntimeError, r"Device of 'wrap_storage' must be"):
7736                        storage_class(wrap_storage=s_other_device.untyped())
7737
7738            # TypedStorage constructor errors
7739            with self.assertRaisesRegex(RuntimeError, r"No positional arguments"):
7740                torch.TypedStorage(0, wrap_storage=s.untyped(), dtype=dtype)
7741
7742            with self.assertRaisesRegex(RuntimeError, r"Argument 'dtype' must be specified"):
7743                torch.TypedStorage(wrap_storage=s.untyped())
7744
7745            with self.assertRaisesRegex(TypeError, r"Argument 'dtype' must be torch.dtype"):
7746                torch.TypedStorage(wrap_storage=s.untyped(), dtype=0)
7747
7748            with self.assertRaisesRegex(RuntimeError, r"Argument 'device' should not be specified"):
7749                torch.TypedStorage(wrap_storage=s.untyped(), dtype=dtype, device=device)
7750
7751            with self.assertRaisesRegex(TypeError, r"Argument 'wrap_storage' must be UntypedStorage"):
7752                torch.TypedStorage(wrap_storage=s, dtype=dtype)
7753
7754            with self.assertRaisesRegex(RuntimeError, r"Storage device not recognized"):
7755                torch.TypedStorage(dtype=dtype, device='xla')
7756
7757            if torch.cuda.is_available():
7758                if storage_class in quantized_storages:
7759                    with self.assertRaisesRegex(RuntimeError, r"Cannot create CUDA storage with quantized dtype"):
7760                        torch.TypedStorage(dtype=dtype, device='cuda')
7761
7762            with self.assertRaisesRegex(TypeError, r"Argument type not recognized"):
7763                torch.TypedStorage(torch.tensor([]), dtype=dtype, device=device)
7764
7765            with self.assertRaisesRegex(RuntimeError, r"Too many positional arguments"):
7766                torch.TypedStorage(0, 0, dtype=dtype, device=device)
7767
7768            if isinstance(s, torch.TypedStorage):
7769                s_other = torch.TypedStorage([1, 2, 3, 4], device=device, dtype=dtype)
7770
7771                with self.assertRaisesRegex(RuntimeError, r'cannot set item'):
7772                    s.fill_(s_other)
7773
7774    def test_storage_error_no_attribute(self):
7775        storage_classes = [
7776            torch.cuda.ByteStorage,
7777            torch.cuda.FloatStorage,
7778        ]
7779        for storage_class in storage_classes:
7780            with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
7781                storage_class.from_buffer()
7782
7783            with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
7784                storage_class._new_with_weak_ptr()
7785
7786            with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
7787                storage_class._new_shared_filename(0, 0, 0)
7788
7789    def test_storage_casts(self):
7790        storage = torch.IntStorage([-1, 0, 1, 2, 3, 4])
7791        self.assertEqual(storage.size(), 6)
7792        self.assertEqual(storage.tolist(), [-1, 0, 1, 2, 3, 4])
7793        self.assertEqual(storage.type(), 'torch.IntStorage')
7794        self.assertIs(storage.dtype, torch.int32)
7795
7796        floatStorage = storage.float()
7797        self.assertEqual(floatStorage.size(), 6)
7798        self.assertEqual(floatStorage.tolist(), [-1, 0, 1, 2, 3, 4])
7799        self.assertEqual(floatStorage.type(), 'torch.FloatStorage')
7800        self.assertEqual(floatStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
7801        self.assertIs(floatStorage.dtype, torch.float32)
7802
7803        halfStorage = storage.half()
7804        self.assertEqual(halfStorage.size(), 6)
7805        self.assertEqual(halfStorage.tolist(), [-1, 0, 1, 2, 3, 4])
7806        self.assertEqual(halfStorage.type(), 'torch.HalfStorage')
7807        self.assertEqual(halfStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
7808        self.assertIs(halfStorage.dtype, torch.float16)
7809
7810        bfloat16Storage = storage.bfloat16()
7811        self.assertEqual(bfloat16Storage.size(), 6)
7812        self.assertEqual(bfloat16Storage.tolist(), [-1, 0, 1, 2, 3, 4])
7813        self.assertEqual(bfloat16Storage.type(), 'torch.BFloat16Storage')
7814        self.assertEqual(bfloat16Storage.int().tolist(), [-1, 0, 1, 2, 3, 4])
7815        self.assertIs(bfloat16Storage.dtype, torch.bfloat16)
7816
7817        longStorage = storage.long()
7818        self.assertEqual(longStorage.size(), 6)
7819        self.assertEqual(longStorage.tolist(), [-1, 0, 1, 2, 3, 4])
7820        self.assertEqual(longStorage.type(), 'torch.LongStorage')
7821        self.assertEqual(longStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
7822        self.assertIs(longStorage.dtype, torch.int64)
7823
7824        shortStorage = storage.short()
7825        self.assertEqual(shortStorage.size(), 6)
7826        self.assertEqual(shortStorage.tolist(), [-1, 0, 1, 2, 3, 4])
7827        self.assertEqual(shortStorage.type(), 'torch.ShortStorage')
7828        self.assertEqual(shortStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
7829        self.assertIs(shortStorage.dtype, torch.int16)
7830
7831        doubleStorage = storage.double()
7832        self.assertEqual(doubleStorage.size(), 6)
7833        self.assertEqual(doubleStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
7834        self.assertEqual(doubleStorage.type(), 'torch.DoubleStorage')
7835        self.assertEqual(doubleStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
7836        self.assertIs(doubleStorage.dtype, torch.float64)
7837
7838        charStorage = storage.char()
7839        self.assertEqual(charStorage.size(), 6)
7840        self.assertEqual(charStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
7841        self.assertEqual(charStorage.type(), 'torch.CharStorage')
7842        self.assertEqual(charStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
7843        self.assertIs(charStorage.dtype, torch.int8)
7844
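        # casting to uint8 wraps negative values around: -1 becomes 255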
7845        byteStorage = storage.byte()
7846        self.assertEqual(byteStorage.size(), 6)
7847        self.assertEqual(byteStorage.tolist(), [255, 0, 1, 2, 3, 4])
7848        self.assertEqual(byteStorage.type(), 'torch.ByteStorage')
7849        self.assertEqual(byteStorage.int().tolist(), [255, 0, 1, 2, 3, 4])
7850        self.assertIs(byteStorage.dtype, torch.uint8)
7851
7852        boolStorage = storage.bool()
7853        self.assertEqual(boolStorage.size(), 6)
7854        self.assertEqual(boolStorage.tolist(), [True, False, True, True, True, True])
7855        self.assertEqual(boolStorage.type(), 'torch.BoolStorage')
7856        self.assertEqual(boolStorage.int().tolist(), [1, 0, 1, 1, 1, 1])
7857        self.assertIs(boolStorage.dtype, torch.bool)
7858
7859        complexfloat_storage = torch.ComplexFloatStorage([-1, 0, 1 + 2j, 2.5j, 3.5, 4 - 2j])
7860        self.assertEqual(complexfloat_storage.size(), 6)
7861        self.assertEqual(complexfloat_storage.tolist(), [-1, 0, 1 + 2j, 2.5j, 3.5, 4 - 2j])
7862        self.assertEqual(complexfloat_storage.type(), 'torch.ComplexFloatStorage')
7863        self.assertIs(complexfloat_storage.dtype, torch.complex64)
7864
7865        complexdouble_storage = complexfloat_storage.complex_double()
7866        self.assertEqual(complexdouble_storage.size(), 6)
7867        self.assertEqual(complexdouble_storage.tolist(), [-1, 0, 1 + 2j, 2.5j, 3.5, 4 - 2j])
7868        self.assertEqual(complexdouble_storage.type(), 'torch.ComplexDoubleStorage')
7869        self.assertIs(complexdouble_storage.dtype, torch.complex128)
7870
7871    def test_storage_byteswap(self):
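        # byteswap(dtype) reverses the byte order of each element; the element width
        # comes from the dtype (8 bytes for float64/int64, 4 for float32/int32,
        # 2 for float16/bfloat16/int16, 1 for the 1-byte types), and complex dtypes
        # swap each real/imaginary component separately.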
7872        input = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
7873        swapped_8bytes = [7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8]
7874        swapped_4bytes = [3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12]
7875        swapped_2bytes = [1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14]
7876        swapped_1byte = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
7877
7878        storage = torch.storage.TypedStorage(input, dtype=torch.uint8)._untyped_storage
7879
7880        storage_f64 = storage.__copy__()
7881        storage_f64.byteswap(torch.float64)
7882        self.assertEqual(storage_f64.tolist(), swapped_8bytes)
7883
7884        storage_f32 = storage.__copy__()
7885        storage_f32.byteswap(torch.float32)
7886        self.assertEqual(storage_f32.tolist(), swapped_4bytes)
7887
7888        storage_f16 = storage.__copy__()
7889        storage_f16.byteswap(torch.float16)
7890        self.assertEqual(storage_f16.tolist(), swapped_2bytes)
7891
7892        storage_bf16 = storage.__copy__()
7893        storage_bf16.byteswap(torch.bfloat16)
7894        self.assertEqual(storage_bf16.tolist(), swapped_2bytes)
7895
7896        storage_i64 = storage.__copy__()
7897        storage_i64.byteswap(torch.int64)
7898        self.assertEqual(storage_i64.tolist(), swapped_8bytes)
7899
7900        storage_i32 = storage.__copy__()
7901        storage_i32.byteswap(torch.int32)
7902        self.assertEqual(storage_i32.tolist(), swapped_4bytes)
7903
7904        storage_i16 = storage.__copy__()
7905        storage_i16.byteswap(torch.int16)
7906        self.assertEqual(storage_i16.tolist(), swapped_2bytes)
7907
7908        storage_i8 = storage.__copy__()
7909        storage_i8.byteswap(torch.int8)
7910        self.assertEqual(storage_i8.tolist(), swapped_1byte)
7911
7912        storage_ui8 = storage.__copy__()
7913        storage_ui8.byteswap(torch.uint8)
7914        self.assertEqual(storage_ui8.tolist(), swapped_1byte)
7915
7916        storage_bool = storage.__copy__()
7917        storage_bool.byteswap(torch.bool)
7918        self.assertEqual(storage_bool.tolist(), swapped_1byte)
7919
7920        storage_c128 = storage.__copy__()
7921        storage_c128.byteswap(torch.complex128)
7922        self.assertEqual(storage_c128.tolist(), swapped_8bytes)
7923
7924        storage_c64 = storage.__copy__()
7925        storage_c64.byteswap(torch.complex64)
7926        self.assertEqual(storage_c64.tolist(), swapped_4bytes)
7927
7928    # Test that internal versions of functions related to TypedStorage do not
7929    # produce a deprecation warning
7930    def test_typed_storage_internal_no_warning(self):
7931        s0 = torch.FloatStorage(10)
7932        s0_untyped = s0.untyped()
7933        t0 = torch.randn(10)
7934
7935        funcs = [
7936            lambda: torch.FloatStorage(_internal=True),
7937            lambda: torch.TypedStorage(
7938                dtype=torch.float,
7939                device='cpu',
7940                _internal=True),
7941            lambda: torch.TypedStorage(
7942                wrap_storage=s0_untyped,
7943                dtype=s0.dtype,
7944                _internal=True),
7945            lambda: torch.FloatStorage._dtype,
7946            lambda: s0._resize_(20),
7947            lambda: s0._size(),
7948            lambda: s0._untyped_storage,
7949            lambda: s0._is_shared(),
7950            lambda: s0._share_memory_(),
7951            lambda: s0._pickle_storage_type(),
7952            lambda: s0._setitem(slice(0, s0._size()), 1),
7953            lambda: s0._element_size(),
7954            lambda: s0._deepcopy({}),
7955            lambda: s0._data_ptr(),
7956            lambda: s0._nbytes(),
7957            lambda: t0._typed_storage(),
7958        ]
7959
7960        if torch.cuda.is_available():
7961            s1 = torch.cuda.FloatStorage(10)
7962            s1_untyped = s1.untyped()
7963            t1 = torch.randn(10, device='cuda')
7964
7965            funcs += [
7966                lambda: torch.cuda.FloatStorage(_internal=True),
7967                lambda: torch.TypedStorage(
7968                    dtype=torch.float,
7969                    device='cuda',
7970                    _internal=True),
7971                lambda: torch.TypedStorage(
7972                    wrap_storage=s1_untyped,
7973                    dtype=s1.dtype,
7974                    _internal=True),
7975                lambda: torch.cuda.FloatStorage._dtype,
7976                lambda: s1._resize_(20),
7977                lambda: s1._size(),
7978                lambda: s1._untyped_storage,
7979                lambda: s1._is_shared(),
7980                lambda: s1._share_memory_(),
7981                lambda: s1._pickle_storage_type(),
7982                lambda: s1._setitem(slice(0, s1._size()), 1),
7983                lambda: s1._element_size(),
7984                lambda: s1._deepcopy({}),
7985                lambda: s1._data_ptr(),
7986                lambda: s1._nbytes(),
7987                lambda: t1._typed_storage(),
7988            ]
7989
7990        # Check that none of the TypedStorage internal function calls
7991        # produces a deprecation warning
7992        for f in funcs:
7993            with warnings.catch_warnings():
7994                warnings.filterwarnings('error', "TypedStorage is deprecated")
7995                f()
7996
7997    # Test that public functions related to TypedStorage produce a deprecation
7998    # warning
7999    @skipIfTorchInductor("FIXME")
8000    def test_typed_storage_deprecation_warning(self):
8001        s0 = torch.FloatStorage(10)
8002        funcs = [
8003            lambda: torch.FloatStorage(),
8004            lambda: torch.FloatStorage.dtype,
8005            lambda: s0.fill_(0),
8006            lambda: s0.is_cuda,
8007            lambda: s0.untyped(),
8008            lambda: len(s0),
8009            lambda: s0[0],
8010        ]
8011
8012        if torch.cuda.is_available():
8013            s1 = torch.cuda.FloatStorage(10)
8014            funcs += [
8015                lambda: torch.cuda.FloatStorage(),
8016                lambda: torch.cuda.FloatStorage.dtype,
8017                lambda: s1.fill_(0),
8018                lambda: s1.is_cuda,
8019                lambda: s1.untyped(),
8020                lambda: len(s1),
8021                lambda: s1[0],
8022            ]
8023
8024        # Check that each of the TypedStorage function calls produces a warning
8025        # when the warning state is reset between calls
8026        for f in funcs:
8027            with AlwaysWarnTypedStorageRemoval(True):
8028                with warnings.catch_warnings(record=True) as w:
8029                    warnings.resetwarnings()
8030                    f()
8031                    self.assertEqual(len(w), 1, msg=str([str(a) for a in w]))
8032                    warning = w[0].message
8033                    self.assertIsInstance(warning, UserWarning)
8034                    self.assertTrue(re.search(
8035                        '^TypedStorage is deprecated',
8036                        str(warning)))
8037
8038        # Test that only the first warning is raised by default
8039        torch.storage._reset_warn_typed_storage_removal()
8040        with warnings.catch_warnings(record=True) as w:
8041            warnings.resetwarnings()
8042            torch.FloatStorage()
8043            torch.randn(10).storage()
8044            self.assertEqual(len(w), 1, msg=str([str(a) for a in w]))
8045            warning = w[0].message
8046            self.assertTrue(re.search(
8047                '^TypedStorage is deprecated',
8048                str(warning)))
8049            # Check the line of code from the warning's stack
8050            with open(w[0].filename, encoding="utf-8") as f:
8051                code_line = f.readlines()[w[0].lineno - 1]
8052            self.assertTrue(re.search(re.escape('torch.FloatStorage()'), code_line))
8053
8054        # Check that no further warning is emitted once one has already been raised
8055        with warnings.catch_warnings(record=True) as w:
8056            warnings.resetwarnings()
8057            torch.FloatStorage()
8058            torch.randn(10).storage()
8059            self.assertEqual(len(w), 0, msg=str([str(a) for a in w]))
8060
8061    def test_from_file(self):
8062        def assert_with_filename(filename):
8063            size = 10000
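            # shared=True memory-maps the file, so s1 and s2 below alias the same bytes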
8064            s1 = torch.FloatStorage.from_file(filename, True, size)
8065            t1 = torch.FloatTensor(s1).copy_(torch.randn(size))
8066            self.assertEqual(s1.data_ptr(), torch.FloatTensor(s1).data_ptr())
8067
8068            # check mapping
8069            s2 = torch.FloatStorage.from_file(filename, True, size)
8070            t2 = torch.FloatTensor(s2)
8071            self.assertEqual(t1, t2, atol=0, rtol=0)
8072
8073            # check changes to t1 from t2
8074            rnum = random.uniform(-1, 1)
8075            t1.fill_(rnum)
8076            self.assertEqual(t1, t2, atol=0, rtol=0)
8077
8078            # check changes to t2 from t1
8079            rnum = random.uniform(-1, 1)
8080            t2.fill_(rnum)
8081            self.assertEqual(t1, t2, atol=0, rtol=0)
8082
8083            # release the tensors
8084            del s1, t1, s2, t2
8085
8086        with TemporaryFileName() as fname:
8087            assert_with_filename(fname)
8088
8089        if IS_FILESYSTEM_UTF8_ENCODING:
8090            with TemporaryDirectoryName(suffix='\u4e2d\u6587') as dname, TemporaryFileName(dir=dname) as fname:
8091                assert_with_filename(fname)
8092
8093    def test_torch_from_file(self):
8094        def assert_with_filename(filename):
8095            size = 10000
8096            s1 = torch.from_file(filename, True, size, dtype=torch.float)
8097            t1 = torch.FloatTensor(s1).copy_(torch.randn(size))
8098
8099            # check mapping
8100            s2 = torch.from_file(filename, True, size, dtype=torch.float)
8101            t2 = torch.FloatTensor(s2)
8102            self.assertEqual(t1, t2, atol=0, rtol=0)
8103
8104            # check changes to t1 from t2
8105            rnum = random.uniform(-1, 1)
8106            t1.fill_(rnum)
8107            self.assertEqual(t1, t2, atol=0, rtol=0)
8108
8109            # check changes to t2 from t1
8110            rnum = random.uniform(-1, 1)
8111            t2.fill_(rnum)
8112            self.assertEqual(t1, t2, atol=0, rtol=0)
8113
8114            # release the tensors
8115            del s1, t1, s2, t2
8116
8117        with TemporaryFileName() as fname:
8118            assert_with_filename(fname)
8119
8120        if IS_FILESYSTEM_UTF8_ENCODING:
8121            with TemporaryDirectoryName(suffix='\u4e2d\u6587') as dname, TemporaryFileName(dir=dname) as fname:
8122                assert_with_filename(fname)
8123
8124    def test_print(self):
8125        default_type = torch.tensor([]).type()
8126        for t in torch._tensor_classes:
8127            if t == torch.HalfTensor:
8128                continue  # HalfTensor does not support fill
8129            if t.is_sparse:
8130                continue
8131            if t.is_cuda and not torch.cuda.is_available():
8132                continue
8133            obj = t(100, 100).fill_(1)
8134            obj.__repr__()
8135            str(obj)
8136        # test half tensor
8137        obj = torch.rand(100, 100, device='cpu').half()
8138        obj.__repr__()
8139        str(obj)
8140        for t in torch._storage_classes:
8141            if t == torch.BFloat16Storage:
8142                continue  # Fix once fill is enabled for bfloat16
8143            if t.is_cuda and not torch.cuda.is_available():
8144                continue
8145            if t == torch.BoolStorage or t == torch.cuda.BoolStorage:
8146                obj = t(100).fill_(True)
8147            else:
8148                obj = t(100).fill_(1)
8149            obj.__repr__()
8150            str(obj)
8151
8152        # test complex tensor
8153        # complex tensor printing uses two formatters, one for the real values
8154        # and the other for the imaginary values; this is consistent with NumPy
8155        x = torch.tensor([2.3 + 4j, 7 + 6j])
8156        self.assertEqual(x.__repr__(), str(x))
8157        self.assertExpectedInline(str(x), '''tensor([2.3000+4.j, 7.0000+6.j])''')
8158
8159        # test complex half tensor
8160        x = torch.tensor([1.25 + 4j, -7. + 6j], dtype=torch.chalf)
8161        self.assertEqual(x.__repr__(), str(x))
8162        self.assertExpectedInline(str(x), '''tensor([ 1.2500+4.j, -7.0000+6.j], dtype=torch.complex32)''')
8163
8164        # test scientific notation for complex tensors
8165        x = torch.tensor([1e28 + 2j , -1e-28j])
8166        self.assertEqual(x.__repr__(), str(x))
8167        self.assertExpectedInline(str(x), '''tensor([1.0000e+28+2.0000e+00j, -0.0000e+00-1.0000e-28j])''')
8168
8169        # test big integer
8170        x = torch.tensor(2341234123412341)
8171        self.assertEqual(x.__repr__(), str(x))
8172        self.assertExpectedInline(str(x), '''tensor(2341234123412341)''')
8173
8174        # test scientific notation
8175        x = torch.tensor([1e28, 1e-28])
8176        self.assertEqual(x.__repr__(), str(x))
8177        self.assertExpectedInline(str(x), '''tensor([1.0000e+28, 1.0000e-28])''')
8178
8179        # test scientific notation using set_printoptions
8180        x = torch.tensor([1e2, 1e-2])
8181        torch.set_printoptions(sci_mode=True)
8182        self.assertEqual(x.__repr__(), str(x))
8183        self.assertExpectedInline(str(x), '''tensor([1.0000e+02, 1.0000e-02])''')
8184        torch.set_printoptions(sci_mode=False)
8185        self.assertEqual(x.__repr__(), str(x))
8186        self.assertExpectedInline(str(x), '''tensor([  100.0000,     0.0100])''')
8187        torch.set_printoptions(sci_mode=None)  # reset to the default value
8188
8189        # test no leading space if all elements positive
8190        x = torch.tensor([1, 2])
8191        self.assertEqual(x.__repr__(), str(x))
8192        self.assertExpectedInline(str(x), '''tensor([1, 2])''')
8193
8194        # test for leading space if there are negative elements
8195        x = torch.tensor([1, -2])
8196        self.assertEqual(x.__repr__(), str(x))
8197        self.assertExpectedInline(str(x), '''tensor([ 1, -2])''')
8198
8199        # test inf and nan
8200        x = torch.tensor([4, inf, 1.5, -inf, 0, nan, 1])
8201        self.assertEqual(x.__repr__(), str(x))
8202        self.assertExpectedInline(str(x), '''tensor([4.0000,    inf, 1.5000,   -inf, 0.0000,    nan, 1.0000])''')
8203
8204        y = torch.tensor([4, inf, complex(1.5, inf), complex(-inf, 4), 0, complex(nan, inf), complex(3, nan)])
8205        self.assertEqual(y.__repr__(), str(y))
8206        expected_str = '''\
8207tensor([4.0000+0.j,    inf+0.j, 1.5000+infj,   -inf+4.j, 0.0000+0.j,    nan+infj,
8208        3.0000+nanj])'''
8209        self.assertExpectedInline(str(y), expected_str)
8210
8211        # test dtype
8212        with set_default_dtype(torch.float):
8213            x = torch.tensor([1e-324, 1e-323, 1e-322, 1e307, 1e308, 1e309], dtype=torch.float64)
8214            self.assertEqual(x.__repr__(), str(x))
8215            expected_str = '''\
8216tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308,
8217                inf], dtype=torch.float64)'''
8218            self.assertExpectedInline(str(x), expected_str)
8219
8220        # test changing default dtype
8221        with set_default_dtype(torch.float64):
8222            self.assertEqual(x.__repr__(), str(x))
8223            expected_str = '''\
8224tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308,
8225                inf])'''
8226            self.assertExpectedInline(str(x), expected_str)
8227
8228        # test summary
8229        x = torch.zeros(10000)
8230        self.assertEqual(x.__repr__(), str(x))
8231        self.assertExpectedInline(str(x), '''tensor([0., 0., 0.,  ..., 0., 0., 0.])''')
8232
8233        # test internal summary function
8234        x = torch.rand(1, 20, 5, 30)
8235        summary = torch._tensor_str.get_summarized_data(x)
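        # With the default edgeitems=3, dims whose size exceeds 2 * edgeitems (here the
        # dims of size 20 and 30) are reduced to their first and last 3 entries, giving
        # size 6; the size-1 and size-5 dims are left untouched.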
8236        self.assertEqual(summary.shape, (1, 6, 5, 6))
8237        first_and_last = [0, 1, 2, -3, -2, -1]
8238        self.assertEqual(summary, x[:, first_and_last][..., first_and_last])
8239
8240        # test device
8241        if torch.cuda.is_available():
8242            x = torch.tensor([123], device='cuda:0')
8243            self.assertEqual(x.__repr__(), str(x))
8244            self.assertExpectedInline(str(x), '''tensor([123], device='cuda:0')''')
8245
8246            # test changing default to cuda
8247            torch.set_default_tensor_type(torch.cuda.FloatTensor)
8248            self.assertEqual(x.__repr__(), str(x))
8249            self.assertExpectedInline(str(x), '''tensor([123])''')
8250
8251            # test printing a tensor on a different GPU than the current one.
8252            if torch.cuda.device_count() >= 2:
8253                with torch.cuda.device(1):
8254                    self.assertEqual(x.__repr__(), str(x))
8255                    self.assertExpectedInline(str(x), '''tensor([123], device='cuda:0')''')
8256
8257            # test printing cpu tensor when default device is cuda
8258            y = torch.tensor([123], device='cpu')
8259            self.assertEqual(y.__repr__(), str(y))
8260            self.assertExpectedInline(str(y), '''tensor([123], device='cpu')''')
8261        torch.set_default_tensor_type(default_type)
8262
8263
8264        # test integral floats and requires_grad
8265        x = torch.tensor([123.], requires_grad=True)
8266        self.assertEqual(x.__repr__(), str(x))
8267        self.assertExpectedInline(str(x), '''tensor([123.], requires_grad=True)''')
8268
8269        # test non-contiguous print
8270        # sliced tensor should have > PRINT_OPTS.threshold elements
8271        x = torch.ones(100, 2, 2, 10)
8272        y = x.as_strided(size=(100, 2, 10), stride=(2 * 2 * 10, 2 * 10, 1))
8273        self.assertEqual(str(y), y.__repr__())
8274        expected_str = '''\
8275tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
8276         [1., 1., 1.,  ..., 1., 1., 1.]],
8277
8278        [[1., 1., 1.,  ..., 1., 1., 1.],
8279         [1., 1., 1.,  ..., 1., 1., 1.]],
8280
8281        [[1., 1., 1.,  ..., 1., 1., 1.],
8282         [1., 1., 1.,  ..., 1., 1., 1.]],
8283
8284        ...,
8285
8286        [[1., 1., 1.,  ..., 1., 1., 1.],
8287         [1., 1., 1.,  ..., 1., 1., 1.]],
8288
8289        [[1., 1., 1.,  ..., 1., 1., 1.],
8290         [1., 1., 1.,  ..., 1., 1., 1.]],
8291
8292        [[1., 1., 1.,  ..., 1., 1., 1.],
8293         [1., 1., 1.,  ..., 1., 1., 1.]]])\
8294'''
8295
8296        self.assertExpectedInline(str(y), expected_str)
8297
8298        x = torch.ones(100, 2, 2, 10) * (1 + 1j)
8299        y = x.as_strided(size=(100, 2, 10), stride=(2 * 2 * 10, 2 * 10, 1))
8300        self.assertEqual(str(y), y.__repr__())
8301        expected_str = '''\
8302tensor([[[1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j],
8303         [1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j]],
8304
8305        [[1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j],
8306         [1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j]],
8307
8308        [[1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j],
8309         [1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j]],
8310
8311        ...,
8312
8313        [[1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j],
8314         [1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j]],
8315
8316        [[1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j],
8317         [1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j]],
8318
8319        [[1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j],
8320         [1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j]]])\
8321'''
8322        self.assertExpectedInline(str(y), expected_str)
8323
8324        # test print 0-dim tensor: there's no 0-dim in Numpy, so we match arrayprint style
8325        x = torch.tensor(0.00002)
8326        self.assertEqual(x.__repr__(), str(x))
8327        self.assertExpectedInline(str(x), '''tensor(2.0000e-05)''')
8328
8329        # test print boolean tensor
8330        x = torch.tensor([True])
8331        self.assertEqual(x.__repr__(), str(x))
8332        self.assertExpectedInline(str(x), '''tensor([True])''')
8333
8334        x = torch.tensor(True)
8335        self.assertEqual(x.__repr__(), str(x))
8336        self.assertExpectedInline(str(x), '''tensor(True)''')
8337
8338        # [Numpy] test print float in sci_mode when min < 0.0001.
8339        x = torch.tensor([0.00002])
8340        self.assertEqual(x.__repr__(), str(x))
8341        self.assertExpectedInline(str(x), '''tensor([2.0000e-05])''')
8342
8343        # [Numpy] test print complex in sci_mode when real_min < 0.0001 and/or imag_min < 0.0001.
8344        x = torch.tensor([0.00002]) * (1 + 1j)
8345        self.assertEqual(x.__repr__(), str(x))
8346        self.assertExpectedInline(str(x), '''tensor([2.0000e-05+2.0000e-05j])''')
8347
8348        # [Numpy] test print float in sci_mode when max > 1e8.
8349        # TODO: PyTorch uses fixed precision to print, while Numpy uses dragon4_scientific
8350        # to do automatic trimming and padding.
8351        x = torch.tensor([123456789.])
8352        self.assertEqual(x.__repr__(), str(x))
8353        self.assertExpectedInline(str(x), '''tensor([1.2346e+08])''')
8354
8355        # [Numpy] test print float in sci_mode when max / min > 1000.
8356        x = torch.tensor([0.01, 11])
8357        self.assertEqual(x.__repr__(), str(x))
8358        self.assertExpectedInline(str(x), '''tensor([1.0000e-02, 1.1000e+01])''')
8359
8360        # [Numpy] test print int max / min > 1000, no sci_mode
8361        x = torch.tensor([1, 1010])
8362        self.assertEqual(x.__repr__(), str(x))
8363        self.assertExpectedInline(str(x), '''tensor([   1, 1010])''')
8364
8365        # [Numpy] test print int > 1e8, no sci_mode
8366        x = torch.tensor([1000000000])  # 1e9
8367        self.assertEqual(x.__repr__(), str(x))
8368        self.assertExpectedInline(str(x), '''tensor([1000000000])''')
8369
8370        # [Numpy] test printing float in int_mode
8371        x = torch.tensor([1., 1000.])
8372        self.assertEqual(x.__repr__(), str(x))
8373        self.assertExpectedInline(str(x), '''tensor([   1., 1000.])''')
8374
8375        # [Numpy] test printing float in int_mode in sci format when max / min > 1000.
8376        x = torch.tensor([1., 1010.])
8377        self.assertEqual(x.__repr__(), str(x))
8378        self.assertExpectedInline(str(x), '''tensor([1.0000e+00, 1.0100e+03])''')
8379
8380    def test_sizeof(self) -> None:
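        # Subtracting the empty storage's size cancels any fixed per-object overhead, so
        # what remains is the per-element byte count, which must scale exactly linearly.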
8381        sizeof_empty = torch.randn(0).storage().__sizeof__()
8382        sizeof_10 = torch.randn(10).storage().__sizeof__()
8383        sizeof_100 = torch.randn(100).storage().__sizeof__()
8384        self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10)
8385        self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0)
8386
8387        sizeof_empty = torch.randn(0).to(torch.uint8).storage().__sizeof__()
8388        sizeof_10 = torch.randn(10).to(torch.uint8).storage().__sizeof__()
8389        sizeof_100 = torch.randn(100).to(torch.uint8).storage().__sizeof__()
8390        self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10)
8391        self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0)
8392
8393    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
8394    def test_resizable(self) -> None:
8395        x = torch.randn(5)
8396        self.assertTrue(x.storage().resizable())
8397        x.numpy()
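        # numpy() shares the underlying buffer with the returned ndarray, so the storage
        # becomes locked against resizing to keep that view valid.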
8398        self.assertFalse(x.storage().resizable())
8399
8400    def test_iter(self) -> None:
8401        x = torch.randn(5, 5)
8402        for i, sub in enumerate(x):
8403            self.assertEqual(sub, x[i])  # noqa: PLR1736
8404
8405        x = torch.tensor([])
8406        self.assertEqual(list(x), [])
8407
8408    def test_new(self) -> None:
8409        x = torch.autograd.Variable(torch.tensor([]))
8410        y = torch.autograd.Variable(torch.randn(4, 4))
8411        z = torch.autograd.Variable(torch.IntTensor([1, 2, 3]))
8412        self.assertEqual(x.new().shape, [0])
8413        self.assertEqual(x.new(), x)
8414        self.assertEqual(x.new(1, 2).shape, [1, 2])
8415        self.assertEqual(x.new(torch.Size([3, 4])).shape, [3, 4])
8416        self.assertEqual(x.new([3, 4]).shape, [2])
8417        self.assertEqual(x.new([3, 4]).tolist(), [3, 4])
8418        self.assertEqual(x.new((3, 4)).tolist(), [3, 4])
8419        self.assertEqual(x.new([np.int32(3), np.float64(4)]).tolist(), [3, 4])
8420        self.assertEqual(x.new(np.array((3, 4))).tolist(), [3, 4])
8421        self.assertEqual(x.new([z[2], z[0] + 3]).tolist(), [3, 4])
8422        self.assertEqual(x.new(size=(3, 4)).shape, [3, 4])
8423        self.assertEqual(x.new(()).shape, [0])
8424        self.assertEqual(x.new(y.storage()).data_ptr(), y.data_ptr())
8425        self.assertEqual(x.new(y).data_ptr(), y.data_ptr())
8426        self.assertIsNot(x.new(y), y)
8427
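        # x is a float tensor while z is an int tensor, so constructing one from the
        # other's data or storage must fail with a type mismatch.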
8428        self.assertRaises(TypeError, lambda: x.new(z))
8429        # TypeError would be better
8430        self.assertRaises(RuntimeError, lambda: x.new(z.storage()))
8431
8432    @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
8433    def test_pin_memory(self):
8434        x = torch.randn(3, 5)
8435        self.assertFalse(x.is_pinned())
8436        if torch.cuda.is_available():
8437            pinned = x.pin_memory()
8438            self.assertTrue(pinned.is_pinned())
8439            self.assertEqual(pinned, x)
8440            self.assertNotEqual(pinned.data_ptr(), x.data_ptr())
8441            # test that pin_memory on an already-pinned tensor has no effect
8442            self.assertIs(pinned, pinned.pin_memory())
8443            self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())
8444
8445    def test_error_msg_type_translation(self):
8446        with self.assertRaisesRegex(
8447                RuntimeError,
8448                # message includes both Double and Long
8449                '(?=.*Double)(?=.*Long)'):
8450
8451            # Calls model with a LongTensor input but DoubleTensor weights
8452            input = torch.zeros(1, 1, 1, 6, dtype=torch.long)
8453            weight = torch.nn.Parameter(torch.zeros(1, 1, 1, 3, dtype=torch.double))
8454            model = torch.nn.Conv2d(1, 1, (1, 3), stride=1, padding=0, bias=False)
8455            model.weight = weight
8456            out = model(input)
8457
8458    def test_apply(self):
8459        x = torch.arange(1, 6)
8460        res = x.clone().apply_(lambda k: k + k)
8461        self.assertEqual(res, x * 2)
8462        self.assertRaises(TypeError, lambda: x.apply_(lambda k: "str"))
8463
8464    def test_map(self):
8465        x = torch.autograd.Variable(torch.randn(3, 3))
8466        y = torch.autograd.Variable(torch.randn(3))
8467        res = x.clone()
8468        res.map_(y, lambda a, b: a + b)
8469        self.assertEqual(res, x + y)
8470        self.assertRaisesRegex(TypeError, "not callable", lambda: res.map_(y, "str"))
8471
8472    def test_map2(self):
8473        x = torch.autograd.Variable(torch.randn(3, 3))
8474        y = torch.autograd.Variable(torch.randn(3))
8475        z = torch.autograd.Variable(torch.randn(1, 3))
8476        res = x.clone()
8477        res.map2_(y, z, lambda a, b, c: a + b * c)
8478        self.assertEqual(res, x + y * z)
8479        z.requires_grad = True
8480        self.assertRaisesRegex(
8481            RuntimeError, "requires grad",
8482            lambda: res.map2_(y, z, lambda a, b, c: a + b * c))
8483
8484    def test_Size(self):
8485        x = torch.Size([1, 2, 3])
8486        self.assertIsInstance(x, tuple)
8487        self.assertEqual(x[0], 1)
8488        self.assertEqual(x[1], 2)
8489        self.assertEqual(x[2], 3)
8490        self.assertEqual(len(x), 3)
8491        self.assertRaises(TypeError, lambda: torch.Size(torch.ones(3)))
8492
8493        self.assertIsInstance(x * 2, torch.Size)
8494        self.assertIsInstance(x[:-1], torch.Size)
8495        self.assertIsInstance(x + x, torch.Size)
8496
8497    def test_Size_scalar(self):
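        # 0-dim integer tensors are accepted inside a torch.Size and compare equal to
        # the plain ints they wrap.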
8498        three = torch.tensor(3)
8499        two = torch.tensor(2)
8500        x = torch.Size([0, 1, two, three, 4])
8501        for i in range(1, 5):
8502            self.assertEqual(x[i], i)
8503
8504    def test_Size_iter(self):
8505        for sizes in [iter([1, 2, 3, 4, 5]), range(1, 6)]:
8506            x = torch.Size(sizes)
8507            for i in range(0, 5):
8508                self.assertEqual(x[i], i + 1)
8509
8510    def test_t_not_2d_error(self):
8511        self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t())
8512        self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t_())
8513
8514    # skip this test for now as it affects all tests
8515    @unittest.skipIf(True, "flush_denormal not supported")
8516    def test_set_flush_denormal(self):
8517        tiny_float = 1e-42
8518        tiny_double = 1e-320
8519        float_tensor = torch.FloatTensor([1.0, tiny_float])
8520        double_tensor = torch.DoubleTensor([1.0, tiny_float, tiny_double])
8521
8522        self.assertEqual(float_tensor[0], 1.0, atol=0.0, rtol=0)
8523        self.assertEqual(float_tensor[1], tiny_float, atol=tiny_float / 16, rtol=0)
8524        self.assertEqual(double_tensor[0], 1.0, atol=0.0, rtol=0)
8525        self.assertEqual(double_tensor[1], tiny_float, atol=0.0, rtol=0)
8526        self.assertEqual(double_tensor[2], tiny_double, atol=0.0, rtol=0)
8527
8528        torch.set_flush_denormal(True)
8529        self.assertEqual(float_tensor[0], 1.0, atol=0.0, rtol=0)
8530        self.assertEqual(float_tensor[1], 0.0, atol=0.0, rtol=0)  # tiny_float to zero
8531        self.assertEqual(double_tensor[0], 1.0, atol=0.0, rtol=0)
8532        # tiny_float is not converted to zero in double type
8533        self.assertEqual(double_tensor[1], tiny_float, atol=0.0, rtol=0)
8534        self.assertEqual(double_tensor[2], 0.0, atol=0.0, rtol=0)  # tiny_double to zero
8535        torch.set_flush_denormal(False)
8536
8537    def test_show_config(self):
8538        # We can't usefully test the output; just make sure this doesn't crash
8539        torch.__config__.show()
8540
8541    @unittest.skipIf(IS_FBCODE, "CXX_FLAGS is only for OSS build.")
8542    def test_cxx_flags(self):
8543        torch.__config__._cxx_flags()
8544
8545    def test_parallel_info(self):
8546        torch.__config__.parallel_info()
8547
8548    def test_get_cpu_capability(self):
8549        # This method is primarily exposed for torchvision's resize
8550        torch.backends.cpu.get_cpu_capability()
8551
8552        # We have to ensure that this method is torchscriptable, as torchvision's resize
8553        # should be torchscriptable
8554        torch.jit.script(torch.backends.cpu.get_cpu_capability)
8555
8556    @slowTest
8557    def test_slow_test(self):
8558        # Just a smoketest to make sure our slowTest decorator works.
8559        pass
8560
8561    def test_is_nonzero(self):
8562        with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with no values is ambiguous"):
8563            torch.tensor([]).is_nonzero()
8564        with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with more than one value is ambiguous"):
8565            torch.tensor([0, 0]).is_nonzero()
8566        self.assertFalse(torch.tensor(0).is_nonzero())
8567        self.assertTrue(torch.tensor(1).is_nonzero())
8568        self.assertFalse(torch.tensor([0]).is_nonzero())
8569        self.assertTrue(torch.tensor([1]).is_nonzero())
8570        self.assertFalse(torch.tensor([[0]]).is_nonzero())
8571        self.assertTrue(torch.tensor([[1]]).is_nonzero())
8572        self.assertTrue(torch.tensor(0.1).is_nonzero())
8573        self.assertTrue(torch.tensor(-0.1).is_nonzero())
8574        self.assertFalse(torch.tensor(0.0).is_nonzero())
8575        self.assertTrue(torch.tensor(True).is_nonzero())
8576        self.assertFalse(torch.tensor(False).is_nonzero())
8577        self.assertFalse(torch.tensor(0 + 0j).is_nonzero())
8578        self.assertTrue(torch.tensor(0 + 0.1j).is_nonzero())
8579
8580    def test_assert_async(self):
8581        with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with no values is ambiguous"):
8582            torch._assert_async(torch.tensor([]))
8583        with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with more than one value is ambiguous"):
8584            torch._assert_async(torch.tensor([0, 0]))
8585        with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
8586            torch._assert_async(torch.tensor(0))
8587        torch._assert_async(torch.tensor(1))
8588        torch._assert_async(torch.tensor(0.1))
8589        torch._assert_async(torch.tensor(-0.1))
8590        with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
8591            torch._assert_async(torch.tensor(0.0))
8592        torch._assert_async(torch.tensor(True))
8593        with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
8594            torch._assert_async(torch.tensor(False))
8595        torch._assert_async(torch.tensor(0 + 0.1j))
8596        with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
8597            torch._assert_async(torch.tensor(0 + 0j))
8598
8599    # NB: we must not be built with CUDA; if we are built with CUDA but no CUDA
8600    # is available, we get a different error.
8601    @unittest.skipIf(torch.backends.cuda.is_built() or IS_SANDCASTLE, "CUDA is built, can't test CUDA not built error")
8602    def test_cuda_not_built(self):
8603        msg = "Torch not compiled with CUDA enabled"
8604        self.assertRaisesRegex(AssertionError, msg, lambda: torch.cuda.current_device())
8605        self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1], device="cuda"))
8606        self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1]).cuda())
8607        self.assertRaisesRegex(TypeError, msg, lambda: torch.cuda.FloatTensor())
8608        self.assertRaisesRegex(TypeError, msg, lambda: torch.set_default_tensor_type(torch.cuda.FloatTensor))
8609        self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1]).to(device="cuda"))
8610
8611    def test_has_internal_overlap(self):
8612        OVERLAP_NO = 0
8613        OVERLAP_YES = 1
8614        OVERLAP_TOO_HARD = 2
8615
8616        # Check for contiguous tensors
8617        a = torch.randn(3, 3)
8618        self.assertEqual(torch._debug_has_internal_overlap(a), OVERLAP_NO)
8619
8620        # Checks for zero strides
8621        b = torch.randn(1, 3)
8622        b_expanded = b.expand(4, 3)
8623        self.assertEqual(torch._debug_has_internal_overlap(b_expanded), OVERLAP_YES)
8624
8625        # Check for zero strided, size 1 axis, in non-contiguous storage (gh-33812)
8626        c = torch.randn(10).as_strided([2, 1, 5], [1, 0, 2])
8627        self.assertEqual(torch._debug_has_internal_overlap(c), OVERLAP_NO)
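        # When the viewed storage comes from a non-contiguous base, the checker cannot
        # cheaply prove whether elements overlap, so it reports OVERLAP_TOO_HARD.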
8628        c = torch.randn(2, 1, 10)[::2].as_strided((2, 1, 5), (10, 0, 2))
8629        self.assertEqual(torch._debug_has_internal_overlap(c), OVERLAP_TOO_HARD)
8630
8631    def test_allow_tensor_metadata_change(self):
8632        a = torch.ones(2, 3)
8633        # Metadata changes are allowed on view tensors that are created from detach().
8634
8635    def test_memory_format(self):
8636        def test_helper(x, memory_format):
8637            y = x.contiguous(memory_format=memory_format)
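            # For these shapes a channels_last(_3d) tensor is not contiguous in the
            # default sense but is contiguous under its own memory format.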
8638            self.assertFalse(y.is_contiguous())
8639            self.assertTrue(y.is_contiguous(memory_format=memory_format))
8640            self.assertEqual(y, x)
8641
8642        test_helper(torch.randn(4, 3, 8, 8), torch.channels_last)
8643        test_helper(torch.randn(4, 3, 8, 8, 8), torch.channels_last_3d)
8644
8645    def test_memory_format_contiguous_returns_same_tensor_if_already_satisfies(self):
8646        def test_helper(x, memory_format):
8647            alias = x.contiguous(memory_format=memory_format)
8648            alias.fill_(7)
8649            self.assertEqual(x, alias)
8650
8651        test_helper(torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2), torch.channels_last)
8652        test_helper(torch.randn(4, 8, 8, 8, 3).permute(0, 4, 1, 2, 3), torch.channels_last_3d)
8653
8654    def test_memory_format_empty(self):
8655        def test_helper(dim1, dim2, memory_format):
8656            with self.assertRaises(RuntimeError):
8657                x = torch.empty(dim1, memory_format=memory_format)
8658            x = torch.empty(dim2, memory_format=memory_format)
8659            self.assertTrue(x.is_contiguous(memory_format=memory_format))
8660
8661        test_helper((3, 3), (3, 3, 3, 3), torch.channels_last)
8662        test_helper((3, 3, 3), (3, 3, 3, 3, 3), torch.channels_last_3d)
8663
8664    def test_dim_order(self):
8665        shape = (2, 3, 5, 7)
8666
8667        t = torch.empty(shape)
8668        self.assertSequenceEqual(t.dim_order(), (0, 1, 2, 3), seq_type=tuple)
8669        # transpose doesn't really change the underlying physical memory,
8670        # so dim_order is expected to change to reflect that (like strides)
8671        self.assertSequenceEqual(t.transpose(0, 1).dim_order(), (1, 0, 2, 3))
8672
8673        t = torch.empty(shape, memory_format=torch.channels_last)
8674        self.assertSequenceEqual(t.dim_order(), (0, 2, 3, 1))
8675
8676        t = torch.empty((2, 3, 5, 7, 8), memory_format=torch.channels_last_3d)
8677        self.assertSequenceEqual(t.dim_order(), (0, 2, 3, 4, 1))
8678
8679        for dim_order in itertools.permutations(range(4)):
8680            self.assertSequenceEqual(
8681                dim_order, torch.empty_permuted(shape, dim_order).dim_order()
8682            )
8683
8684        for shape in [(2, 2, 2, 2), (2, 1, 2, 2), (2, 2, 1, 2), (2, 2, 2, 1), (2, 2, 1, 1), (2, 1, 1, 2)]:
8685            for memory_format in (torch.contiguous_format, torch.channels_last):
8686                t = torch.empty(shape).to(memory_format=memory_format)
8687                if memory_format == torch.contiguous_format:
8688                    dim_order_target = list(range(len(shape)))
8689                elif memory_format == torch.channels_last:
8690                    dim_order_target = [0, *list(range(2, len(shape))), 1]
8691
8692                self.assertSequenceEqual(dim_order_target, t.dim_order())
8693
8694    def test_subclass_tensors(self):
8695        # raise an error when trying to subclass FloatTensor
8696        with self.assertRaisesRegex(TypeError, "type 'torch.FloatTensor' is not an acceptable base type"):
8697            class Foo1(torch.FloatTensor):
8698                pass
8699
8700        # but allow subclassing Tensor:
8701        class Foo2(torch.Tensor):
8702            def foo(self):
8703                return 5
8704        f = Foo2()
8705        self.assertEqual(f.foo(), 5)
8706
8707    def test_ndim(self):
8708        a = torch.randn(1, 2, 3)
8709        self.assertEqual(3, a.ndim)
8710        b = torch.randn(())
8711        self.assertEqual(0, b.ndim)
8712        c = torch.randn(1, 0)
8713        self.assertEqual(2, c.ndim)
8714
8715    def test_nbytes(self):
8716        a = torch.randn(1, 2, 3, dtype=torch.float64)
8717        self.assertEqual(a.numel() * a.element_size(), a.nbytes)
8718        b = torch.randn(())
8719        self.assertEqual(b.numel() * b.element_size(), b.nbytes)
8720        c = torch.randn(1, 0)
8721        self.assertEqual(c.numel() * c.element_size(), c.nbytes)
8722
8723    def test_fill_diagonal(self):
8724        a1 = torch.randn(7, 3)
8725        a2 = a1.clone()
8726        v = 1
8727        for i in range(3):
8728            a2[i][i] = v
8729        a1.fill_diagonal_(v)
8730        self.assertEqual(a1, a2)
8731
8732        b1 = torch.randn(7, 3)
8733        b2 = b1.clone()
8734        for i in range(3):
8735            b2[i][i] = v
8736            b2[i + 4][i] = v
8737        b1.fill_diagonal_(v, wrap=True)
8738        self.assertEqual(b1, b2)
8739
8740        c1 = torch.rand(3, 3, 3)
8741        c2 = c1.clone()
8742        for i in range(3):
8743            c2[i][i][i] = v
8744        c1.fill_diagonal_(v)
8745        self.assertEqual(c1, c2)
8746
8747        # non-contiguous tensor
8748        d1 = torch.rand(3, 3, 3)[:, 1, ...]
8749        d2 = d1.clone()
8750        for i in range(3):
8751            d2[i][i] = v
8752        d1.fill_diagonal_(v)
8753        self.assertEqual(d1, d2)
8754
8755        e1 = torch.rand(7, 3, 3)[:, 1, ...]
8756        e2 = e1.clone()
8757        for i in range(3):
8758            e2[i][i] = v
8759            e2[i + 4][i] = v
8760        e1.fill_diagonal_(v, wrap=True)
8761        self.assertEqual(e1, e2)
8762
8763    def test_setting_real_imag_to_a_number(self):
8764        x = torch.randn(4, dtype=torch.cfloat)
8765        x.real = 0
8766        x.imag = 0
8767        zeros = torch.zeros(4)
8768        self.assertEqual(x.real, zeros)
8769        self.assertEqual(x.imag, zeros)
8770
8771    def test_batch_norm_cpu_inference(self):
8772        # input nchw in (2,1,1,1), (2,2,2,2)
8773        inputs = [
8774            torch.tensor([[[[-0.5000]]], [[[0.5000]]]]),
8775            torch.tensor([
8776                [
8777                    [[-0.5000, 0.5000], [-1.0000, 1.0000]],
8778                    [[-0.2500, -0.5000], [0.2500, 0.5000]]
8779                ],
8780                [
8781                    [[0.1000, 1.0000], [1.0000, 0.1000]],
8782                    [[1.0000, 0.5000], [1.5000, -1.5000]]
8783                ]])]
8784        # output nchw in (2,1,1,1), (2,2,2,2)
8785        outputs = [
8786            torch.tensor([
8787                [[[-0.499997496604919433593750000]]],
8788                [[[0.499997496604919433593750000]]]]),
8789            torch.tensor([
8790                [[[-0.499997496604919433593750000, 0.499997496604919433593750000],
8791                  [-0.999994993209838867187500000, 0.999994993209838867187500000]],
8792                 [[-0.249998748302459716796875000, -0.499997496604919433593750000],
8793                  [0.249998748302459716796875000, 0.499997496604919433593750000]]],
8794                [[[0.099999502301216125488281250, 0.999994993209838867187500000],
8795                  [0.999994993209838867187500000, 0.099999502301216125488281250]],
8796                 [[0.999994993209838867187500000, 0.499997496604919433593750000],
8797                  [1.499992489814758300781250000, -1.499992489814758300781250000]]]])]
8798
8799
8800        for i in range(len(inputs)):
8801            for affine in [False, True]:
8802                m = torch.nn.BatchNorm2d(inputs[i].size()[1], 1e-05, 0.1, affine=affine)
8803                m.eval()
8804                # contiguous case
8805                input1 = inputs[i].contiguous()
8806                output1 = m(input1)
8807                # non-contiguous case
8808                input2 = input1.permute(0, 1, 3, 2)
8809                output2 = m(input2).permute(0, 1, 3, 2)
8810                # channels last case
8811                input3 = input1.contiguous(memory_format=torch.channels_last)
8812                output3 = m(input3)
8813                self.assertEqual(output3, outputs[i])
8814                self.assertEqual(output3, output1)
8815                self.assertEqual(output3, output2)
8816
8817    # FIXME: move these meta tests to their own test suite/class or
8818    #   distribute them among the appropriate test suites for their ops
8819    @skipIfTorchDynamo("Fails after Triton update, see https://github.com/pytorch/pytorch/issues/94687")
8820    def test_empty_meta(self):
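        # 'meta' tensors carry only shape and dtype metadata with no backing data, so
        # shape propagation through ops works but reading a value via .item() must fail.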
8821        x = torch.empty(2 ** 20, 2 ** 20, device='meta')
8822        y = torch.empty(2 ** 20, device='meta')
8823        z = x + y
8824        self.assertEqual(z.size(), (2 ** 20, 2 ** 20))
8825        self.assertRaises(RuntimeError, lambda: z[0][0].item())
8826
8827    @skipIfTorchDynamo("Fails after Triton update, see https://github.com/pytorch/pytorch/issues/94687")
8828    def test_format_scalar_meta(self):
8829        x = torch.empty((), device='meta')
8830        self.assertEqual(format(x), repr(x))
8831
8832    def test_upsample_nearest1d_meta(self):
8833        # TODO: this test should be triggered by test_nn.py but right
8834        # now meta is not enabled (and even if it was, we are probably
8835        # missing too many meta functions to get through the test unmolested)
8836
8837        # NB: Can't make the exponent too big, or it will overflow
8838        # signed 64-bit integer
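        # (here the upsampled output has 2e8 * 3 * 4e8 = 2.4e17 elements, well below
        # the int64 max of 2**63 - 1, roughly 9.2e18)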
8839        x = torch.empty(2 * 10 ** 8, 3, 2 * 10 ** 8, device='meta')
8840        z = torch.nn.functional.interpolate(x, scale_factor=2)
8841        self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
8842        self.assertRaises(RuntimeError, lambda: z[0][0][0].item())
8843
8844        # TODO: the out tests cannot be triggered by test_nn.py because
8845        # we don't actually do out= arguments for nn functions, so there
8846        # is no public API by which to get the out version
8847
8848        # interpolate doesn't seem to support out=
8849        # (not sure why passing None here doesn't work? How strange...)
8850        z = torch.empty(0, device='meta')
8851        torch._C._nn.upsample_nearest1d(x, (4 * 10 ** 8,), 2, out=z)
8852        self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
8853        self.assertRaises(RuntimeError, lambda: z[0][0][0].item())
8854
8855    def test_upsample_nearest2d_meta(self):
8856        # TODO: the out tests cannot be triggered by test_nn.py because
8857        # we don't actually do out= arguments for nn functions, so there
8858        # is no public API by which to get the out version
8859
8860        # Make sure we don't clobber strides of out tensor.  NB: this
8861        # test must be done on 2d/3d, because 1d doesn't have any meaningful
8862        # layout support
8863        x = torch.empty(4, 3, 8, 8, device='meta')
8864        out = torch.empty(4, 3, 16, 16, device='meta', memory_format=torch.channels_last)
8865        torch._C._nn.upsample_nearest2d(x, (16, 16), out=out)
8866        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
8867
8868        x = torch.empty(4, 3, 8, 8, device='meta', memory_format=torch.channels_last)
8869        out = torch.empty(4, 3, 16, 16, device='meta')
8870        torch._C._nn.upsample_nearest2d(x, (16, 16), out=out)
8871        self.assertTrue(out.is_contiguous())
8872
8873        # But if resize occurs, do clobber
8874        x = torch.empty(4, 3, 8, 8, device='meta', memory_format=torch.channels_last)
8875        out = torch.empty(0, device='meta')
8876        torch._C._nn.upsample_nearest2d(x, (16, 16), out=out)
8877        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
8878
8879        # Complain if out dtype mismatch
8880        x = torch.empty(4, 3, 8, 8, device='meta', dtype=torch.float)
8881        out = torch.empty(4, 3, 16, 16, device='meta', dtype=torch.double)
8882        self.assertExpectedRaisesInline(
8883            RuntimeError, lambda: torch._C._nn.upsample_nearest2d(x, (16, 16), out=out),
8884            """Expected out tensor to have dtype torch.float32 but got torch.float64 instead"""
8885        )
8886
8887        # Complain if out device mismatch
8888        x = torch.empty(0, 3, 8, 8, device='meta')
8889        out = torch.empty(0, 3, 16, 16, device='cpu')
8890        # FIXME: compiling should properly error with a device mismatch.
8891        if not TEST_WITH_TORCHINDUCTOR:
8892            self.assertExpectedRaisesInline(
8893                RuntimeError, lambda: torch._C._nn.upsample_nearest2d(x, (16, 16), out=out),
8894                """Attempting to copy from device meta to device cpu, but cross-device copies are not allowed!"""
8895            )
8896
8897    def test_add_meta_scalar(self):
8898        # From https://github.com/pytorch/pytorch/issues/53815
8899        x = torch.empty(2, device='meta')
8900        y = x + 2
8901        self.assertEqual(y.size(), x.size())
8902
8903    def test_normal_shape(self):
8904        warned = False
8905        for device in get_all_device_types():
8906            tensor1 = torch.rand(1, device=device)
8907            tensor4 = torch.rand(4, device=device)
8908            tensor120 = torch.rand(120, device=device)
8909            tensor2145 = torch.rand(2, 1, 4, 5, device=device)
8910            tensor2345 = torch.rand(2, 3, 4, 5, device=device)
8911            tensor2345_non_contiguous = torch.rand(2, 4, 3, 5, device=device).permute(0, 2, 1, 3)
8912            tensor2345_channels_last = tensor2345.contiguous(memory_format=torch.channels_last)
8913            output2345 = torch.zeros(2, 3, 4, 5, device=device)
8914            output345 = torch.zeros(3, 4, 5, device=device)
8915
8916            # inputs have same size
8917            self.assertEqual(torch.normal(tensor2345, tensor2345).size(), (2, 3, 4, 5))
8918            self.assertEqual(torch.normal(tensor2345_non_contiguous, tensor2345).size(), (2, 3, 4, 5))
8919            self.assertEqual(torch.normal(tensor2345, tensor2345_channels_last).size(), (2, 3, 4, 5))
8920            self.assertEqual(torch.normal(tensor2345_non_contiguous, tensor2345_channels_last).size(), (2, 3, 4, 5))
8921
8922            # scalar case
8923            self.assertEqual(torch.normal(tensor2345, 2).size(), (2, 3, 4, 5))
8924            self.assertEqual(torch.normal(2, tensor2345).size(), (2, 3, 4, 5))
8925
8926            # inputs are expandable tensors
8927            self.assertEqual(torch.normal(tensor2345, tensor1).size(), (2, 3, 4, 5))
8928            self.assertEqual(torch.normal(tensor2145, tensor2345).size(), (2, 3, 4, 5))
8929
8930            # inputs are non-expandable tensors, but they have same number of elements
8931            with self.assertRaisesRegex(
8932                    RuntimeError,
8933                    r"The size of tensor a \(120\) must match the size of "
8934                    r"tensor b \(5\) at non-singleton dimension 3"):
8935                self.assertEqual(torch.normal(tensor120, tensor2345).size(), (120,))
8936            with self.assertRaisesRegex(
8937                    RuntimeError,
8938                    r"The size of tensor a \(5\) must match the size of "
8939                    r"tensor b \(120\) at non-singleton dimension 3"):
8940                self.assertEqual(torch.normal(tensor2345, tensor120).size(), (2, 3, 4, 5))
8941
8942            # inputs are non-expandable tensors and they don't have the same number of elements
8943            with self.assertRaisesRegex(
8944                    RuntimeError,
8945                    r"The size of tensor a \(5\) must match the size of "
8946                    r"tensor b \(4\) at non-singleton dimension 3"):
8947                torch.normal(tensor2345, tensor4)
8948
8949            # output and inputs are size compatible
8950            self.assertEqual(torch.normal(tensor2345, tensor2345, out=output2345).size(), (2, 3, 4, 5))
8951
8952            # output and inputs are not size compatible
8953            with self.assertWarnsRegex(
8954                    UserWarning,
8955                    "This behavior is deprecated, and in a future PyTorch "
8956                    "release outputs will not be resized unless they have "
8957                    "zero elements"):
8958                self.assertEqual(torch.normal(tensor2345, tensor2145, out=output345).size(), (2, 3, 4, 5))
8959            with self.assertRaisesRegex(
8960                    RuntimeError,
8961                    r"The size of tensor a \(5\) must match the size of "
8962                    r"tensor b \(120\) at non-singleton dimension 3"):
8963                # inputs are not expandable, output size is not the same as mean
8964                torch.normal(tensor2345, tensor120, out=output345)
8965
8966    def test_tensoriterator_output_setup(self):
8967        # Test whether the output's memory layout is correct
8968        def test_memory_layout(x, y, scale, zero_point, out):
8969            self.assertEqual(x.dim(), 4)
8970            self.assertEqual(x.size(), y.size())
8971            self.assertEqual(y.size(), out.size())
8972
8973            shape = x.size()
8974            for n in range(shape[0]):
8975                for c in range(shape[1]):
8976                    for h in range(shape[2]):
8977                        for w in range(shape[3]):
8978                            if scale is not None and zero_point is not None:
8979                                self.assertEqual(
8980                                    out[n][c][h][w],
8981                                    torch.ops.quantized.add(x[n][c][h][w], y[n][c][h][w], scale, zero_point))
8982                            else:
8983                                self.assertEqual(out[n][c][h][w], x[n][c][h][w] + y[n][c][h][w])
8984
8985        xraw = torch.rand(2, 3, 4, 4)
8986        yraw = torch.rand(2, 3, 4, 4)
8987        qxraw = torch.quantize_per_tensor(xraw, 0.1, 5, torch.quint8)
8988        qyraw = torch.quantize_per_tensor(yraw, 0.1, 5, torch.quint8)
8989
8990        # contiguous case fast setup
8991        test_memory_layout(xraw, yraw, None, None, xraw + yraw)
8992        test_memory_layout(qxraw, qyraw, 0.1, 5, torch.ops.quantized.add(qxraw, qyraw, 0.1, 5))
8993
8994        # channels last case fast setup
8995        x = xraw.contiguous(memory_format=torch.channels_last)
8996        y = yraw.contiguous(memory_format=torch.channels_last)
8997        test_memory_layout(x, y, None, None, x + y)
8998        qx = qxraw.contiguous(memory_format=torch.channels_last)
8999        qy = qyraw.contiguous(memory_format=torch.channels_last)
9000        test_memory_layout(qx, qy, 0.1, 5, torch.ops.quantized.add(qx, qy, 0.1, 5))
9001
9002        # non contiguous case fast setup (dense, non-overlapping, same shape and strides)
9003        x = xraw.permute(0, 2, 3, 1)
9004        y = yraw.permute(0, 2, 3, 1)
9005        test_memory_layout(x, y, None, None, x + y)
9006        qx = qxraw.permute(0, 2, 3, 1)
9007        qy = qyraw.permute(0, 2, 3, 1)
9008        test_memory_layout(qx, qy, 0.1, 5, torch.ops.quantized.add(qx, qy, 0.1, 5))
9009
9010        # non contiguous case fast setup (dense, non-overlapping)
9011        # input tensors have same shape and strides
9012        # output tensor has the same shape as the input tensors but different strides
9013        # output tensor should preserve its strides in this case
9014        x = xraw.permute(0, 2, 3, 1)
9015        y = yraw.permute(0, 2, 3, 1)
9016        out = torch.empty_like(xraw)
9017        out = out.permute(0, 3, 2, 1)
9018        expected_stride = out.stride()
9019        test_memory_layout(x, y, None, None, torch.add(x, y, out=out))
9020        self.assertEqual(expected_stride, out.stride())
9021
9022        # non contiguous case non fast setup
9023        x = xraw.permute(0, 2, 3, 1)
9024        y = yraw.permute(0, 3, 2, 1)
9025        test_memory_layout(x, y, None, None, x + y)
9026        qx = qxraw.permute(0, 2, 3, 1)
9027        qy = qyraw.permute(0, 3, 2, 1)
9028        test_memory_layout(qx, qy, 0.1, 5, torch.ops.quantized.add(qx, qy, 0.1, 5))
9029
9030    # Tests to make sure we still handle .data properly until it is removed
9031    def test_dot_data_use(self):
9032        # .data allows changing a Tensor's type in place; check that we still
9033        # raise a nice error.
9034        with self.assertRaisesRegex(
9035                RuntimeError,
9036                # message includes both Double and ComplexFloat
9037                '(?=.*Double)(?=.*ComplexFloat)'):
9038
9039            # Calls the model with a DoubleTensor input but ComplexFloat weights
9040            input = torch.randn(1, 1, 1, 6, dtype=torch.double)
9041            weight = torch.zeros(1, 1, 1, 3, dtype=torch.complex64)
9042            model = torch.nn.Conv2d(1, 1, (1, 3), stride=1, padding=0, bias=False)
9043            model.weight.data = weight
9044            out = model(input)
9045
9046    def test_empty_storage_view(self):
9047        # we should be able to "modify" slices of a 0-element
9048        # array without an error being raised due to
9049        # trying to resize its storage
9050        t = torch.from_numpy(np.empty((0, 4)))
9051        t[:, 1::2] *= 1
9052
9053    def test_has_storage(self):
9054        self.assertIsNotNone(torch.tensor([]).storage())
9055        self.assertIsNotNone(torch.empty(0).storage())
9056        self.assertIsNotNone(torch.tensor([]).clone().storage())
9057        self.assertIsNotNone(torch.tensor([0, 0, 0]).nonzero().storage())
9058        self.assertIsNotNone(torch.tensor([]).new().storage())
9059
9060    # FIXME: Extend this test and put in a TensorProperties test class
9061    def test_numel(self):
9062        b = torch.ByteTensor(3, 100, 100)
9063        self.assertEqual(b.nelement(), 3 * 100 * 100)
9064        self.assertEqual(b.numel(), 3 * 100 * 100)
9065
9066    # Verifies that (deep)copies of dtypes are the same objects
9067    def test_copy_dtypes(self):
9068        for dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
9069            copied_dtype = copy.deepcopy(dtype)
9070            self.assertIs(dtype, copied_dtype)
9071
9072    def test_dtype_is_signed(self):
9073        for dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
9074            self.assertEqual(dtype.is_signed, torch.is_signed(torch.tensor(0, dtype=dtype)))
9075
9076        self.assertRaisesRegex(RuntimeError, 'not supported for quantized', lambda: torch.quint8.is_signed)
9077        self.assertRaisesRegex(RuntimeError, 'not supported for quantized', lambda: torch.qint8.is_signed)
9078        self.assertRaisesRegex(RuntimeError, 'not supported for quantized', lambda: torch.qint32.is_signed)
9079
9080    # FIXME: Put the following random tests into their own test class or test suite
9081    @skipIfTorchDynamo("requires https://github.com/pytorch/torchdynamo/pull/1098")
9082    def test_RNGState(self):
9083        state = torch.get_rng_state()
9084        stateCloned = state.clone()
9085        before = torch.rand(1000)
9086
9087        self.assertEqual(state.ne(stateCloned).long().sum(), 0, atol=0, rtol=0)
9088
9089        torch.set_rng_state(state)
9090        after = torch.rand(1000)
9091        self.assertEqual(before, after, atol=0, rtol=0)
9092
9093    @skipIfTorchDynamo("requires https://github.com/pytorch/torchdynamo/pull/1098")
9094    def test_RNGStateAliasing(self):
9095        # Fork the random number stream at this point
9096        gen = torch.Generator()
9097        gen.set_state(torch.get_rng_state())
9098        self.assertEqual(gen.get_state(), torch.get_rng_state())
9099
9100        target_value = torch.rand(1000)
9101        # Dramatically alter the internal state of the main generator
9102        _ = torch.rand(100000)
9103        forked_value = torch.rand(1000, generator=gen)
9104        self.assertEqual(target_value, forked_value, atol=0, rtol=0, msg="RNG has not forked correctly.")
9105
9106    @skipIfTorchDynamo("requires https://github.com/pytorch/torchdynamo/pull/1098")
9107    def test_RNG_after_pickle(self):
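        # Pickling a tensor the way multiprocessing does (via ForkingPickler) must not
        # consume values from the global RNG stream.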
9108        torch.random.manual_seed(100)
9109        before = torch.rand(10)
9110
9111        torch.random.manual_seed(100)
9112        buf = io.BytesIO()
9113        tensor = torch.tensor([1, 2, 3])
9114        ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(tensor)
9115        after = torch.rand(10)
9116
9117        self.assertEqual(before, after, atol=0, rtol=0)
9118
9119    @skipIfTorchDynamo("requires https://github.com/pytorch/torchdynamo/pull/1098")
9120    def test_boxMullerState(self):
9121        torch.manual_seed(123)
9122        odd_number = 101
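        # Normal variates are typically produced in pairs (Box-Muller style), so drawing
        # an odd count leaves a cached value inside the generator; get/set_rng_state must
        # capture and restore that cached value for the sequences below to match.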
9123        seeded = torch.randn(odd_number)
9124        state = torch.get_rng_state()
9125        midstream = torch.randn(odd_number)
9126        torch.set_rng_state(state)
9127        repeat_midstream = torch.randn(odd_number)
9128        torch.manual_seed(123)
9129        reseeded = torch.randn(odd_number)
9130        self.assertEqual(midstream, repeat_midstream, atol=0, rtol=0,
9131                         msg='get_rng_state/set_rng_state not generating same sequence of normally distributed numbers')
9132        self.assertEqual(seeded, reseeded, atol=0, rtol=0,
9133                         msg='repeated calls to manual_seed not generating same sequence of normally distributed numbers')
9134
9135    @skipIfTorchDynamo("requires https://github.com/pytorch/torchdynamo/pull/1098")
9136    def test_manual_seed(self):
9137        rng_state = torch.get_rng_state()
9138        torch.manual_seed(2)
9139        x = torch.randn(100)
9140        self.assertEqual(torch.initial_seed(), 2)
9141        torch.manual_seed(2)
9142        y = torch.randn(100)
9143        self.assertEqual(x, y)
9144
9145        max_int64 = 0x7fff_ffff_ffff_ffff
9146        min_int64 = -max_int64 - 1
9147        max_uint64 = 0xffff_ffff_ffff_ffff
9148        # Check all boundary cases of valid seed value inputs
9149        test_cases = [
9150            # (seed, expected_initial_seed)
9151            # Positive seeds should be unchanged
9152            (max_int64, max_int64),
9153            (max_int64 + 1, max_int64 + 1),
9154            (max_uint64, max_uint64),
9155            (0, 0),
9156            # Negative seeds wrap around starting from the largest seed value
9157            (-1, max_uint64),
9158            (min_int64, max_int64 + 1)
9159        ]
9160        for seed, expected_initial_seed in test_cases:
9161            torch.manual_seed(seed)
9162            actual_initial_seed = torch.initial_seed()
9163            msg = (f"expected initial_seed() = {expected_initial_seed:x} "
9164                   f"after calling manual_seed({seed:x}), but got {actual_initial_seed:x} instead")
9165            self.assertEqual(expected_initial_seed, actual_initial_seed, msg=msg)
9166        for invalid_seed in [min_int64 - 1, max_uint64 + 1]:
9167            with self.assertRaisesRegex(RuntimeError, r'Overflow when unpacking long'):
9168                torch.manual_seed(invalid_seed)
9169
9170        torch.set_rng_state(rng_state)
9171
9172    # FIXME: Describe this test and port to the generic device framework in a more
9173    #   appropriate test suite for the copy operation
9174    def test_copy_transpose(self):
9175        x = torch.arange(100 * 100, dtype=torch.float).reshape(100, 100).t()
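        # x is a transposed (column-major) view, so this copy exercises the transposed-source
        # handling; column j of the result holds the values j*100 .. j*100+99.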
9176        y = torch.empty(100, 100, dtype=torch.float)
9177        y.copy_(x)
9178        self.assertEqual(y[:, 0], range(100))
9179        self.assertEqual(y[:, 40], range(4000, 4100))
9180
9181        y = torch.empty(100, 100, dtype=torch.double)
9182        y.copy_(x)
9183        self.assertEqual(y[:, 0], range(100))
9184        self.assertEqual(y[:, 40], range(4000, 4100))
9185
9186        # Validates regression reported in https://github.com/pytorch/pytorch/issues/45269
9187        x = torch.arange(100 * 100).reshape(100, 100).to(dtype=torch.cfloat).t()
9188        y = torch.empty(100, 100, dtype=torch.cfloat)
9189        y.copy_(x)
9190        self.assertEqual(y[:, 0], range(100))
9191        self.assertEqual(y[:, 40], range(4000, 4100))
9192
9193        x = torch.arange(100 * 100).reshape(100, 100).to(dtype=torch.complex32).t()
9194        y = torch.empty(100, 100, dtype=torch.complex32)
9195        y.copy_(x)
9196        self.assertEqual(y[:, 0], range(100))
9197        self.assertEqual(y[:, 40], range(4000, 4100))
9198
9199    # FIXME: Port to a more appropriate test suite
9200    def test_copy_broadcast(self):
9201        torch.zeros(5, 6).copy_(torch.zeros(6))
9202        self.assertRaises(RuntimeError, lambda: torch.zeros(5, 6).copy_(torch.zeros(30)))
9203
9204    # FIXME: Port to a more appropriate test suite
9205    # Fails with inductor (and aot_eager) because functionalization replaces copy_ with copy,
9206    # which doesn't properly error on bad inputs.
9207    def test_copy_many_to_one(self):
9208        # Test that an in-place copy into an expanded tensor, where many output
9209        # elements alias a single storage location, raises a RuntimeError
9210        self.assertRaises(RuntimeError, lambda: torch.zeros(1, 6).expand(5, 6).copy_(torch.zeros(5, 6)))
9211
9212    def test_copy_float16(self):
9213        # Check that fbgemm code no longer reads memory out of bounds, see
9214        # copy_impl and fbgemm::Float16ToFloat_ref.
9215        # https://github.com/pytorch/pytorch/issues/88543
9216
9217        # Types to test different code paths in copy_impl.
9218        dtypes = (
9219            # out_dtype, src_dtype
9220            (torch.float32, torch.float16),  # fbgemm
9221            (torch.float16, torch.float32),  # fbgemm
9222            (torch.float32, torch.float32),  # TensorIterator
9223        )
9224
9225        cases = (
9226            # out_shape, src_shape, is_ok
9227            # These cases used to crash with fbgemm, make sure these also raise
9228            # exceptions with TensorIterator.
9229            ((1, 2, 3), (0, 2, 3), False),  # same strides, not allowed by TI
9230            ((1, 5, 6), (4, 5, 6), False),  # same strides, not allowed by TI
9231            (1, (0, 2, 3), False),  # different strides
9232            ((4, 5, 6), (0, 2, 3), False),  # different strides
9233            ((4, 5, 6), (1, 2, 3), False),  # different strides
9234            ((4, 5, 6), (6, 5, 4), False),  # same numel
9235
9236            # These cases should pass with fbgemm and TensorIterator.
9237            ((4, 5, 6), (1, 5, 6), True),  # same strides
9238            ((4, 5, 6), (4, 5, 6), True),  # same strides
9239            ((0, 2, 3), 1, True),  # different strides, allowed by TI
9240            ((4, 5, 6), (4, 5, 1), True),  # different strides, allowed by TI
9241        )
9242
9243        for (out_shape, src_shape, is_ok), (out_dtype, src_dtype) in itertools.product(cases, dtypes):
9244            out = torch.zeros(out_shape, dtype=out_dtype, device=torch.device('cpu'))
9245            src = torch.ones(src_shape, dtype=src_dtype, device=torch.device('cpu'))
9246            if is_ok:
9247                if torch.cuda.is_available():
9248                    out_cuda = out.cuda()
9249                    src_cuda = src.cuda()
9250                res = out.copy_(src)
9251                if torch.cuda.is_available():
9252                    res_cuda = out_cuda.copy_(src_cuda)
9253                    self.assertEqual(res, res_cuda)
9254            else:
9255                self.assertRaises(RuntimeError, lambda: out.copy_(src))
9256
9257    # FIXME: Port to a more appropriate test suite
9258    def _test_to_with_layout(self, layout):
9259        def test_copy_behavior(t, non_blocking=False):
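            # .to() is a no-op returning the tensor itself when dtype and device already
            # match; copy=True must always produce a distinct tensor.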
9260            self.assertIs(t, t.to(t, non_blocking=non_blocking))
9261            self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking))
9262            self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking))
9263            self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=True))
9264            self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=True))
9265            self.assertIsNot(t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True))
9266
9267            devices = [t.device]
9268            if t.device.type == 'cuda':
9269                if t.device.index == -1:
9270                    devices.append(f'cuda:{torch.cuda.current_device()}')
9271                elif t.device.index == torch.cuda.current_device():
9272                    devices.append('cuda')
9273            for device in devices:
9274                self.assertIs(t, t.to(device, non_blocking=non_blocking))
9275                self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking))
9276                self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=True))
9277                self.assertIsNot(t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True))
9278
9279        a = torch.tensor(5)
9280        if layout == torch.sparse_csr:
9281            a = torch.tensor([[0, 1, 2], [2, 0, 3]]).to_sparse_csr()
9282        test_copy_behavior(a)
9283        self.assertEqual(a.device, a.to('cpu').device)
9284        self.assertEqual(a.device, a.to('cpu', dtype=torch.float32).device)
9285        self.assertIs(torch.float32, a.to('cpu', dtype=torch.float32).dtype)
9286        self.assertEqual(a.device, a.to(torch.float32).device)
9287        self.assertIs(torch.float32, a.to(dtype=torch.float32).dtype)
9288
9289        def test_data_ptr(getter):
9290            self.assertEqual(getter(a), getter(a.to('cpu')))
9291            self.assertEqual(getter(a), getter(a.to(dtype=a.dtype, device=a.device, copy=False)))
9292            self.assertEqual(getter(a), getter(a.to('cpu', copy=False)))
9293            self.assertNotEqual(getter(a), getter(a.to('cpu', copy=True)))
9294        if layout == torch.sparse_csr:
9295            # TODO: compressed sparse tensors currently don't support data_ptr.
9296            # Exercising the failure will allow us to widen coverage of this test once they do.
9297            with self.assertRaisesRegex(RuntimeError, "Cannot access data pointer of Tensor that doesn't have storage"):
9298                a.data_ptr()
9299            # While compressed sparse tensors don't have a concept of data_ptr,
9300            # the underlying component tensors do. The implementation of to appropriately
9301            # forwards the call to the components, which is what we're testing here.
9302            test_data_ptr(lambda a: a.values().data_ptr())
9303            test_data_ptr(lambda a: a.crow_indices().data_ptr())
9304            test_data_ptr(lambda a: a.col_indices().data_ptr())
9305        else:
9306            test_data_ptr(lambda a: a.data_ptr())
9307
9308        if torch.cuda.is_available():
9309            for non_blocking in [True, False]:
9310                for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
9311                    b = torch.tensor(5., device=cuda)
9312                    test_copy_behavior(b, non_blocking)
9313                    self.assertEqual(b.device, b.to(cuda, non_blocking=non_blocking).device)
9314                    self.assertEqual(a.device, b.to('cpu', non_blocking=non_blocking).device)
9315                    self.assertEqual(b.device, a.to(cuda, non_blocking=non_blocking).device)
9316                    self.assertIs(torch.int32, b.to('cpu', dtype=torch.int32, non_blocking=non_blocking).dtype)
9317                    self.assertEqual(a.device, b.to('cpu', dtype=torch.int32, non_blocking=non_blocking).device)
9318                    self.assertIs(torch.int32, b.to(dtype=torch.int32).dtype)
9319                    self.assertEqual(b.device, b.to(dtype=torch.int32).device)
9320
9321    def test_to(self):
9322        self._test_to_with_layout(torch.strided)
9323        is_cuda10_2_or_higher = (
9324            (torch.version.cuda is not None)
9325            and ([int(x) for x in torch.version.cuda.split(".")] >= [10, 2]))
9326        if is_cuda10_2_or_higher:  # in cuda10_1 sparse_csr is beta
9327            self._test_to_with_layout(torch.sparse_csr)
9328
9329    # FIXME: describe this test
9330    def test_as_subclass(self):
9331        class SubTensor(torch.Tensor):
9332            member_var = object()
9333
9334        t0 = torch.tensor(0)
9335        t1 = torch.tensor([1, 2])
9336        t2 = torch.tensor([[3, 4], [5, 6]])
9337
9338        s0 = t0.as_subclass(SubTensor)
9339        s1 = t1.as_subclass(SubTensor)
9340        s2 = t2.as_subclass(SubTensor)
9341
9342        # Check that the correct type is returned.
9343        self.assertTrue(type(s0) is SubTensor)
9344        self.assertTrue(type(s1) is SubTensor)
9345        self.assertTrue(type(s2) is SubTensor)
9346
9347        # Check that the data is equal.
9348        self.assertEqual(t0, s0)
9349        self.assertEqual(t1, s1)
9350        self.assertEqual(t2, s2)
9351
9352        t0[()] = 1
9353        t1[1] = 3
9354        t2[1, 1] = 7
9355
9356        # Check that the data is equal even after modification.
9357        self.assertEqual(t0, s0)
9358        self.assertEqual(t1, s1)
9359        self.assertEqual(t2, s2)
9360
9361        # Check that member variables are passed through.
9362        self.assertTrue(s0.member_var is SubTensor.member_var)
9363        self.assertTrue(s1.member_var is SubTensor.member_var)
9364        self.assertTrue(s2.member_var is SubTensor.member_var)
9365
9366        # Test that autograd is propagated.
9367        t = torch.tensor(5, dtype=torch.float32, requires_grad=True)
9368
9369        # Run a calculation on the tensor.
9370        exp_t = torch.exp(t)
9371
9372        # Cast exp_t to a subclass.
9373        exp_s = exp_t.as_subclass(SubTensor)
9374
9375        # Make sure that t.grad was initially None
9376        self.assertTrue(t.grad is None)
9377
9378        # Run the autograd calculation.
9379        exp_s.backward()
9380
9381        # Make sure autograd was propagated to the original tensor
9382        # declared with requires_grad.
9383        self.assertTrue(t.grad is not None)
9384
9385        # Make sure invalid subclasses raise nice errors
9386        class BadSubTensor:
9387            member_var = object()
9388
9389        err_msg = "Creating a Tensor subclass from a class that does not inherit from Tensor"
9390        with self.assertRaisesRegex(RuntimeError, err_msg):
9391            s0 = t0.as_subclass(BadSubTensor)
9392
9393    # FIXME: Port to a test suite that better fits slicing
9394    def test_slice(self):
9395        empty = torch.empty(0, 4)
9396        x = torch.arange(0., 16).view(4, 4)
9397        self.assertEqual(x[:], x)
9398        self.assertEqual(x[:4], x)
9399        # start and stop are clamped to the size of dim
9400        self.assertEqual(x[:5], x)
9401        # if start >= stop then the result is empty
9402        self.assertEqual(x[2:1], empty)
9403        self.assertEqual(x[2:2], empty)
9404        # out of bounds is also empty
9405        self.assertEqual(x[10:12], empty)
9406        # additional correctness checks
9407        self.assertEqual(x[:1].tolist(), [[0, 1, 2, 3]])
9408        self.assertEqual(x[:-3].tolist(), [[0, 1, 2, 3]])
9409        self.assertEqual(x[:, -2:3].tolist(), [[2], [6], [10], [14]])
9410        self.assertEqual(x[0:-1:2].tolist(), [[0, 1, 2, 3], [8, 9, 10, 11]])
9411
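    # torch.split_with_sizes_copy writes each split of the input into the
    # preallocated tensors passed via out= (one per split), rather than
    # returning views.  Rough usage sketch (shapes are illustrative only):
    #   x = torch.rand(30, 40, 50)
    #   out = [torch.empty(3, 40, 50), torch.empty(27, 40, 50)]
    #   torch.split_with_sizes_copy(x, [3, 27], dim=0, out=out)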
9412    def test_split_with_sizes_copy_out(self):
9413        device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
9414        shape = (30, 40, 50)
9415        x = torch.rand(*shape, device=device)
9416        cases = [
9417            (0, [3, 7, 8, 12]),
9418            (1, [3, 7, 10, 20]),
9419            (-2, [3, 7, 10, 20]),
9420            (2, [3, 7, 10, 12, 18]),
9421            (-1, [3, 7, 10, 12, 18]),
9422            (2, [3, 7, 10, 0, 30]),
9423        ]
9424        for dim, split_sizes in cases:
9425            views = x.split_with_sizes(split_sizes, dim=dim)
9426            expects = [v.clone() for v in views]
9427            out = [torch.zeros_like(v) for v in views]
9428            for expect, t in zip(expects, out):
9429                if expect.numel() != 0:
9430                    self.assertFalse(expect.eq(t).all().item())
9431
9432            torch.split_with_sizes_copy(x, split_sizes, dim=dim, out=out)
9433            for expect, t in zip(expects, out):
9434                self.assertTrue(expect.eq(t).all().item())
9435
9436            if not torch.cuda.is_available():
9437                continue
9438
9439            # Test with cuda graph
9440            out = [torch.zeros_like(v) for v in views]
9441            for expect, t in zip(expects, out):
9442                if expect.numel() != 0:
9443                    self.assertFalse(expect.eq(t).all().item())
9444
9445            g = torch.cuda.CUDAGraph()
9446            with torch.cuda.graph(g):
9447                torch.split_with_sizes_copy(x, split_sizes, dim=dim, out=out)
9448
9449            g.replay()
9450            for expect, t in zip(expects, out):
9451                self.assertTrue(expect.eq(t).all().item())
9452
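    # Tensor.type() accepts a legacy type name string ('torch.FloatTensor'),
    # a legacy tensor type object, or a torch.dtype; calling it with
    # torch.Tensor converts to the current default dtype.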
9453    def test_type(self):
9454        x = torch.randn(3, 3).double()
9455        self.assertEqual(x.type('torch.FloatTensor').dtype, torch.float32)
9456        self.assertEqual(x.type(torch.FloatTensor).dtype, torch.float32)
9457        self.assertEqual(x.int().type(torch.Tensor).dtype, torch.get_default_dtype())
9458        self.assertEqual(x.type(torch.int32).dtype, torch.int32)
9459
9460    # FIXME: port to a quantization test suite
9461    def test_qengine(self):
9462        qengines = torch.backends.quantized.supported_engines
9463        original_qe = torch.backends.quantized.engine
9464        for qe in qengines:
9465            torch.backends.quantized.engine = qe
9466            assert torch.backends.quantized.engine == qe, 'qengine not set successfully'
9467        torch.backends.quantized.engine = original_qe
9468
9469    def test_terminate_handler_on_crash(self):
9470        cmd = [sys.executable, '-c', "import os; os.environ[\"TORCH_CUSTOM_TERMINATE\"] ='1'; \
9471               import torch; import torch._C; torch._C._abort()"]
9472        with self.assertRaises(subprocess.CalledProcessError) as cm:
9473            subprocess.check_output(cmd, shell=False)
9474        e = cm.exception
9475        output = e.stdout.decode("utf-8")
9476        self.assertNotEqual(e.returncode, 0)
9477        self.assertNotEqual(output, None)
9478        self.assertIn('Unhandled exception caught in c10/util/AbortHandler.h', output)
9479
9480    # FIXME: port to a distributed test suite -- also... how could this be OOMing on Windows CUDA?
9481    @slowTest
9482    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
9483                        don't support multiprocessing with spawn start method")
9484    @unittest.skipIf(IS_WINDOWS, 'FIXME: CUDA OOM error on Windows')
9485    def test_multinomial_invalid_probs(self):
9486        def _spawn_method(method, arg):
9487            try:
9488                mp.set_start_method('spawn')
9489            except RuntimeError:
9490                pass
9491            with mp.Pool(1) as pool:
9492                out = pool.map(method, [arg])
9493                self.assertTrue(out[0])
9494
9495        def _test_multinomial_invalid_probs(probs):
9496            try:
9497                # n_sample = 1 is a special case, test n_sample=2 which is more general
9498                torch.multinomial(probs.to('cpu'), 2)
9499                return False  # Should not be reached
9500            except RuntimeError as e:
9501                return 'probability tensor contains either `inf`, `nan` or element < 0' in str(e)
9502
9503        _spawn_method(_test_multinomial_invalid_probs, torch.tensor([1., -1., 1.]))
9504        _spawn_method(_test_multinomial_invalid_probs, torch.tensor([1., inf, 1.]))
9505        _spawn_method(_test_multinomial_invalid_probs, torch.tensor([1., -inf, 1.]))
9506        _spawn_method(_test_multinomial_invalid_probs, torch.tensor([1., 1., nan]))
9507
9508    # FIXME: port to more appropriate test suite
9509    def test_to_with_tensor(self):
9510        a = torch.tensor(5)
9511        self.assertEqual(a.device, a.to(a).device)
9512
9513        if torch.cuda.is_available():
9514            for non_blocking in [True, False]:
9515                for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
9516                    b = torch.tensor(5., device=cuda)
9517                    self.assertEqual(b.device, b.to(b, non_blocking=non_blocking).device)
9518                    self.assertEqual(a.device, b.to(a, non_blocking=non_blocking).device)
9519                    self.assertEqual(b.device, a.to(b, non_blocking=non_blocking).device)
9520
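    # torch.device can be constructed from a combined string ('cuda:1') or from
    # a (type, index) pair, e.g. torch.device('cuda', 1); malformed device
    # strings and negative indices raise RuntimeError.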
9521    def test_device(self):
9522        cpu = torch.device('cpu')
9523        self.assertEqual('cpu', str(cpu))
9524        self.assertEqual('cpu', cpu.type)
9525        self.assertEqual(None, cpu.index)
9526
9527        cpu0 = torch.device('cpu:0')
9528        self.assertEqual('cpu:0', str(cpu0))
9529        self.assertEqual('cpu', cpu0.type)
9530        self.assertEqual(0, cpu0.index)
9531
9532        cpu0 = torch.device('cpu', 0)
9533        self.assertEqual('cpu:0', str(cpu0))
9534        self.assertEqual('cpu', cpu0.type)
9535        self.assertEqual(0, cpu0.index)
9536
9537        cuda = torch.device('cuda')
9538        self.assertEqual('cuda', str(cuda))
9539        self.assertEqual('cuda', cuda.type)
9540        self.assertEqual(None, cuda.index)
9541
9542        cuda1 = torch.device('cuda:1')
9543        self.assertEqual('cuda:1', str(cuda1))
9544        self.assertEqual('cuda', cuda1.type)
9545        self.assertEqual(1, cuda1.index)
9546
9547        cuda1 = torch.device('cuda', 1)
9548        self.assertEqual('cuda:1', str(cuda1))
9549        self.assertEqual('cuda', cuda1.type)
9550        self.assertEqual(1, cuda1.index)
9551
9552        cuda90 = torch.device('cuda', 90)
9553        self.assertEqual('cuda:90', str(cuda90))
9554        self.assertEqual('cuda', cuda90.type)
9555        self.assertEqual(90, cuda90.index)
9556
9557        self.assertRaises(RuntimeError, lambda: torch.device('cpu:-1'))
9558        self.assertRaises(RuntimeError, lambda: torch.device('cuda:-1'))
9559        self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 '))
9560        self.assertRaises(RuntimeError, lambda: torch.device('cuda: 2'))
9561        self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 2'))
9562        self.assertRaises(RuntimeError, lambda: torch.device('cuda:2.'))
9563        self.assertRaises(RuntimeError, lambda: torch.device('cuda:2?'))
9564        self.assertRaises(RuntimeError, lambda: torch.device('cuda:?2'))
9565        self.assertRaises(RuntimeError, lambda: torch.device('cuda:'))
9566        self.assertRaises(RuntimeError, lambda: torch.device('cuda:2.232'))
9567        self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 cuda:3'))
9568        self.assertRaises(RuntimeError, lambda: torch.device('cuda:2+cuda:3'))
9569        self.assertRaises(RuntimeError, lambda: torch.device('cuda:2cuda:3'))
9570        self.assertRaises(RuntimeError, lambda: torch.device(-1))
9571
9572        self.assertRaises(RuntimeError, lambda: torch.device('other'))
9573        self.assertRaises(RuntimeError, lambda: torch.device('other:0'))
9574
9575        device_set = {'cpu', 'cpu:0', 'cuda', 'cuda:0', 'cuda:1', 'cuda:10', 'cuda:100'}
9576        device_hash_set = set()
9577        device_hash_set.update(hash(torch.device(device)) for device in device_set)
9578        self.assertEqual(len(device_set), len(device_hash_set))
9579
9580        def get_expected_device_repr(device):
9581            if device.index is not None:
9582                return f"device(type='{device.type}', index={device.index})"
9583
9584            return f"device(type='{device.type}')"
9585
9586        for device in device_set:
9587            dev = torch.device(device)
9588            self.assertEqual(repr(dev), get_expected_device_repr(dev))
9589
9590    # Tests that torch.use_deterministic_algorithms() can be set as expected
9591    @wrapDeterministicFlagAPITest
9592    def test_deterministic_flag(self):
9593        for deterministic, warn_only in product([True, False], [True, False]):
9594            torch.use_deterministic_algorithms(deterministic, warn_only=warn_only)
9595            self.assertEqual(deterministic, torch.are_deterministic_algorithms_enabled())
9596            self.assertEqual(warn_only, torch.is_deterministic_algorithms_warn_only_enabled())
9597
9598            if deterministic:
9599                if warn_only:
9600                    debug_mode = 1
9601                else:
9602                    debug_mode = 2
9603            else:
9604                debug_mode = 0
9605
9606            self.assertEqual(debug_mode, torch.get_deterministic_debug_mode())
9607
9608        for debug_mode in [0, 1, 2]:
9609            torch.set_deterministic_debug_mode(debug_mode)
9610            self.assertEqual(debug_mode, torch.get_deterministic_debug_mode())
9611            deterministic = debug_mode in [1, 2]
9612            warn_only = debug_mode == 1
9613
9614            self.assertEqual(deterministic, torch.are_deterministic_algorithms_enabled())
9615            self.assertEqual(warn_only, torch.is_deterministic_algorithms_warn_only_enabled())
9616
9617        for debug_mode, debug_mode_str in [(0, 'default'), (1, 'warn'), (2, 'error')]:
9618            torch.set_deterministic_debug_mode(debug_mode_str)
9619            self.assertEqual(debug_mode, torch.get_deterministic_debug_mode())
9620
9621        with self.assertRaisesRegex(
9622                TypeError,
9623                r"_set_deterministic_algorithms\(\): argument 'mode' \(position 1\) must be bool, not int"):
9624            torch.use_deterministic_algorithms(1)
9625
9626        with self.assertRaisesRegex(
9627                TypeError,
9628                r"_set_deterministic_algorithms\(\): argument 'warn_only' must be bool, not int"):
9629            torch.use_deterministic_algorithms(False, warn_only=1)
9630
9631    # Tests that torch.utils.deterministic.fill_uninitialized_memory can be set as expected
9632    def test_deterministic_fill_uninitialized_memory(self):
9633        with DeterministicGuard(True, fill_uninitialized_memory=False):
9634            self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
9635            self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
9636
9637            with DeterministicGuard(True, fill_uninitialized_memory=True):
9638                self.assertTrue(torch.utils.deterministic.fill_uninitialized_memory)
9639                self.assertTrue(torch._C._get_deterministic_fill_uninitialized_memory())
9640
9641            self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
9642            self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
9643
9644            torch.utils.deterministic.fill_uninitialized_memory = False
9645            self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
9646            self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
9647
9648            torch.utils.deterministic.fill_uninitialized_memory = True
9649            self.assertTrue(torch.utils.deterministic.fill_uninitialized_memory)
9650            self.assertTrue(torch._C._get_deterministic_fill_uninitialized_memory())
9651
9652            torch._C._set_deterministic_fill_uninitialized_memory(False)
9653            self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
9654            self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
9655
9656            torch._C._set_deterministic_fill_uninitialized_memory(True)
9657            self.assertTrue(torch.utils.deterministic.fill_uninitialized_memory)
9658            self.assertTrue(torch._C._get_deterministic_fill_uninitialized_memory())
9659
9660            with self.assertRaisesRegex(RuntimeError, r"expected a bool, but got int"):
9661                torch.utils.deterministic.fill_uninitialized_memory = 1
9662
9663    def test_type_conversion_via_dtype_name(self):
9664        x = torch.tensor([1])
9665        self.assertEqual(x.byte().dtype, torch.uint8)
9666        self.assertEqual(x.bool().dtype, torch.bool)
9667        self.assertEqual(x.char().dtype, torch.int8)
9668        self.assertEqual(x.double().dtype, torch.float64)
9669        self.assertEqual(x.float().dtype, torch.float32)
9670        self.assertEqual(x.half().dtype, torch.float16)
9671        self.assertEqual(x.int().dtype, torch.int32)
9672        self.assertEqual(x.bfloat16().dtype, torch.bfloat16)
9673        cfloat = x.cfloat()
9674        self.assertEqual(cfloat.dtype, torch.complex64)
9675        self.assertEqual(cfloat.real, x.float())
9676        self.assertEqual(cfloat.imag, torch.zeros_like(cfloat.imag))
9677        cdouble = x.cdouble()
9678        self.assertEqual(cdouble.dtype, torch.complex128)
9679        self.assertEqual(cdouble.real, x.double())
9680        self.assertEqual(cdouble.imag, torch.zeros_like(cdouble.imag))
9681        chalf = x.chalf()
9682        self.assertEqual(chalf.dtype, torch.complex32)
9683        self.assertEqual(chalf.real, x.half())
9684        self.assertEqual(chalf.imag, torch.zeros_like(chalf.imag))
9685
9686    def test_type_alias(self):
9687        type_alias_map = {torch.float64: torch.double,
9688                          torch.float32: torch.float,
9689                          torch.int32: torch.int,
9690                          torch.int64: torch.long,
9691                          torch.int16: torch.short,
9692                          torch.float16: torch.half,
9693                          torch.complex32: torch.chalf,
9694                          torch.complex64: torch.cfloat}
9695        for dtype, alias in type_alias_map.items():
9696            self.assertIs(alias, dtype)
9697
9698    def test_doc_template(self) -> None:
9699        """
9700        Test that all public API doc strings use the same standard template for
9701        all common arguments such as tensor or dim
9702        """
9703        from torch._torch_docs import __file__ as doc_file
9704        from torch._torch_docs import multi_dim_common, single_dim_common, factory_common_args, factory_like_common_args
9705
9706        with open(doc_file, encoding="utf-8") as f:
9707            doc_strs = f.read()
9708
9709        matches = re.findall(
9710            r'add_docstr\(([^,]+?),[^"\']*?(?:"""|\'\'\')(.*?)(?:"""|\'\'\')(?:\.|,?[^,\)]*?\))',
9711            doc_strs,
9712            re.MULTILINE | re.DOTALL,
9713        )
9714        self.assertTrue(matches)
9715
9716        for m in matches:
9717            func = m[0].strip()
9718            desc = m[1].strip()
9719
9720            for common_args in [multi_dim_common, single_dim_common, factory_common_args, factory_like_common_args]:
9721                for k, v in common_args.items():
9722                    self.assertNotIn(v, desc, f'The argument description "{v}" in {func} can be '
9723                                              f'replaced by {{{k}}}')
9724
9725    def test_doc(self):
9726        checked_types = (types.MethodType, types.FunctionType,
9727                         types.BuiltinFunctionType, types.BuiltinMethodType)
9728
9729        def _test_namespace(ns, *skips):
9730            if isinstance(ns, object):
9731                ns_name = ns.__class__.__name__
9732            else:
9733                ns_name = ns.__name__
9734            skip_regexes = []
9735            for r in skips:
9736                if isinstance(r, str):
9737                    skip_regexes.append(re.compile(f'^{re.escape(r)}$'))
9738                else:
9739                    skip_regexes.append(r)
9740
9741            for name in dir(ns):
9742                if name.startswith('_'):
9743                    continue
9744                if name in ['real', 'imag']:
9745                    y = torch.randn(1, dtype=torch.cfloat)
9746                    var = getattr(y, name)
9747                elif name in ["H", "mT", "mH"]:
9748                    y = torch.randn(1, 1)
9749                    var = getattr(y, name)
9750                else:
9751                    var = getattr(ns, name)
9752                if not isinstance(var, checked_types):
9753                    continue
9754                doc = var.__doc__
9755                has_doc = doc is not None and len(doc.strip()) > 0
9756                full_name = ns_name + '.' + name
9757                if any(r.match(name) for r in skip_regexes):
9758                    self.assertFalse(has_doc,
9759                                     f'New docs have been added for {full_name}, please remove '
9760                                     'it from the skipped list in TestTorch.test_doc')
9761                else:
9762                    self.assertTrue(has_doc, f'{full_name} is missing documentation')
9763
9764        # FIXME: All of the following should be marked as expected failures
9765        # so that it is easier to tell when missing documentation has been added.
9766        # FIXME: fix all the skipped ones below!
9767        _test_namespace(torch.randn(1),
9768                        'as_strided_',
9769                        re.compile('^clamp_(min|max)_?$'),
9770                        'is_distributed',
9771                        'is_nonzero',
9772                        'is_same_size',
9773                        'log_softmax',
9774                        'map2_',
9775                        'new',
9776                        'reinforce',
9777                        'relu',
9778                        'relu_',
9779                        'prelu',
9780                        'resize',
9781                        'resize_as',
9782                        'softmax',
9783                        'split_with_sizes',
9784                        'unsafe_split_with_sizes',
9785                        '_autocast_to_fp16',
9786                        '_autocast_to_fp32',
9787                        )
9788
9789        _test_namespace(torch.nn)
9790        _test_namespace(torch.nn.functional, 'assert_int_or_pair')
9791        # TODO: add torch.* tests when we have proper namespacing on ATen functions
9792        # _test_namespace(torch)
9793
9794    # FIXME: deprecate torch.Tensor constructor
9795    def test_tensor_ctor_scalar(self):
9796        x = torch.Tensor(torch.tensor(1.0))
9797        self.assertEqual(x, torch.tensor(1.0))
9798
9799    def test_deepcopy_gradient(self):
9800        from copy import deepcopy
9801        a = torch.zeros(10)
9802        a.grad = torch.ones(10)
9803        self.assertEqual(a.grad, deepcopy(a).grad)
9804        s = torch.zeros(10).to_sparse()
9805        s.grad = torch.ones(10).to_sparse()
9806        self.assertEqual(s.grad, deepcopy(s).grad)
9807
9808        # ensure sharing is not broken
9809        c = deepcopy([a, a.grad])
9810        self.assertTrue(c[0].grad is c[1])
9811
9812    def test_tensor_base_init(self):
9813        # Direct construction not OK
9814        self.assertRaises(RuntimeError, lambda: torch._C.TensorBase())
9815
9816        # Subclassing it directly is not OK
9817        with self.assertRaisesRegex(RuntimeError, "Cannot subclass"):
9818            class Tfail(torch._C.TensorBase):
9819                pass
9820
9821        # Doing so with Tensor is ok though
9822        class T(torch.Tensor):
9823            pass
9824
9825        T()
9826
9827    def test_storage_base_init(self):
9828        # Direct construction not OK
9829        self.assertRaises(RuntimeError, lambda: torch._C.StorageBase())
9830
9831        # But construction of subclass is OK
9832        class T(torch._C.StorageBase):
9833            pass
9834
9835        T()
9836
9837    def test_tensor_base_new(self):
9838
9839        # OK to call super().__new__, see
9840        # https://github.com/pytorch/pytorch/issues/57421
9841        class TestTensor(torch.Tensor):
9842            @staticmethod
9843            def __new__(cls, x, *args, **kwargs):
9844                return super().__new__(cls, x, *args, **kwargs)
9845
9846        x = torch.ones(5)
9847        test_tensor = TestTensor(x)
9848
9849    def test_storage_base_new(self):
9850
9851        # OK to call super().__new__, see
9852        # https://github.com/pytorch/pytorch/issues/57421
9853        class TestStorage(torch._C.StorageBase):
9854            @staticmethod
9855            def __new__(cls, x, *args, **kwargs):
9856                return super().__new__(cls, x, *args, **kwargs)
9857
9858        x = torch.UntypedStorage(5)
9859        test_storage = TestStorage(x)
9860
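    # The next two tests check that the Python wrapper object (including
    # anything hung off its __dict__, such as the 'foo' attribute, or its
    # subclass type) survives while the underlying C++ tensor is still
    # referenced, e.g. through y.grad.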
9861    def test_pyobj_preserved(self):
9862        x = torch.empty(2)
9863        x.foo = 2  # put something on __dict__
9864        y = torch.empty(2)
9865        y.grad = x
9866        del x  # x is dead in Python
9867        self.assertEqual(y.grad.foo, 2)
9868        z = y.grad  # it's live
9869        del z  # it's dead again
9870        self.assertEqual(y.grad.foo, 2)
9871
9872    def test_subclass_preserved(self):
9873        class MyTensor(torch.Tensor):
9874            pass
9875
9876        x = MyTensor(torch.empty(2))
9877        y = torch.empty(2)
9878        y.grad = x
9879        del x  # x is dead in Python
9880        self.assertEqual(type(y.grad), MyTensor)
9881        z = y.grad  # it's live
9882        del z  # it's dead again
9883        self.assertEqual(type(y.grad), MyTensor)
9884
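    # The *_dealloc tests below rely on the Tracker helper used throughout this
    # file: Tracker.make() returns (m, t), where m[0] flips to True once the
    # tracked object t has been finalized, which lets the tests observe exactly
    # when storages and tensors are deallocated, kept as zombies, or resurrected.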
9885    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
9886    def test_storage_dealloc(self):
9887        m, t = Tracker.make()
9888        s0 = torch.UntypedStorage(10)
9889        s1 = s0
9890        s0._tracker = t
9891        del t
9892
9893        self.assertFalse(m[0])
9894        del s0
9895        self.assertFalse(m[0])
9896        del s1
9897        self.assertTrue(m[0])
9898
9899    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
9900    def test_storage_from_tensor_dealloc(self):
9901        m, t = Tracker.make()
9902        a = torch.randn(10)
9903        s0 = a.untyped_storage()
9904        s0._tracker = t
9905        del t
9906
9907        s1 = a.untyped_storage()
9908        self.assertTrue(s0 is s1)
9909        self.assertTrue(hasattr(s1, '_tracker'))
9910
9911        del a
9912
9913        self.assertFalse(m[0])
9914        del s0
9915        self.assertFalse(m[0])
9916        del s1
9917        self.assertTrue(m[0])
9918
9919    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
9920    def test_storage_from_tensor_dealloc_zombie(self):
9921        m, t = Tracker.make()
9922        a = torch.randn(10)
9923        s0 = a.untyped_storage()
9924        s0._tracker = t
9925        del t
9926
9927        s1 = a.untyped_storage()
9928        self.assertTrue(s0 is s1)
9929        self.assertTrue(hasattr(s1, '_tracker'))
9930
9931        self.assertFalse(m[0])
9932        del s0
9933        self.assertFalse(m[0])
9934        del s1
9935        self.assertFalse(m[0])
9936        del a
9937        self.assertTrue(m[0])
9938
9939    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
9940    def test_storage_from_tensor_dealloc_resurrected(self):
9941        m, t = Tracker.make()
9942        a = torch.randn(10)
9943        s0 = a.untyped_storage()
9944        s0._tracker = t
9945        del t
9946
9947        s1 = a.untyped_storage()
9948        self.assertTrue(s0 is s1)
9949        self.assertTrue(hasattr(s1, '_tracker'))
9950
9951        self.assertFalse(m[0])
9952        del s0
9953        self.assertFalse(m[0])
9954        del s1
9955        self.assertFalse(m[0])
9956
9957        s0 = a.untyped_storage()
9958        self.assertTrue(isinstance(s0, torch.UntypedStorage))
9959
9960        del a
9961        self.assertFalse(m[0])
9962        del s0
9963        self.assertTrue(m[0])
9964
9965    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
9966    def test_storage_dealloc_resurrected(self):
9967        m, t = Tracker.make()
9968        s = torch.UntypedStorage(10)
9969        s._tracker = t
9970        del t
9971
9972        a = torch.tensor(s)
9973        self.assertFalse(m[0])
9974        del s
9975
9976        self.assertFalse(m[0])
9977
9978        s = a.untyped_storage()
9979        self.assertTrue(isinstance(s, torch.UntypedStorage))
9980
9981        del a
9982        self.assertFalse(m[0])
9983        del s
9984        self.assertTrue(m[0])
9985
9986    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
9987    def test_storage_dealloc_subclass_zombie(self):
9988        class MyStorage(torch.UntypedStorage):
9989            finalized_count = 0
9990
9991            def __del__(self):
9992                MyStorage.finalized_count += 1
9993
9994        m, t = Tracker.make()
9995        s = MyStorage(10)
9996        s._tracker = t
9997        del t
9998
9999        a = torch.tensor(s)
10000        self.assertFalse(m[0])
10001        del s
10002
10003        self.assertEqual(MyStorage.finalized_count, 0)
10004        self.assertFalse(m[0])
10005
10006        del a
10007        self.assertEqual(MyStorage.finalized_count, 1)
10008        self.assertTrue(m[0])
10009
10010    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
10011    def test_storage_dealloc_subclass_resurrected(self):
10012        class MyStorage(torch.UntypedStorage):
10013            finalized_count = 0
10014
10015            def __del__(self):
10016                MyStorage.finalized_count += 1
10017
10018        m, t = Tracker.make()
10019        s = MyStorage(10)
10020        s._tracker = t
10021        del t
10022
10023        a = torch.tensor(s)
10024        self.assertFalse(m[0])
10025        del s
10026
10027        self.assertEqual(MyStorage.finalized_count, 0)
10028        self.assertFalse(m[0])
10029
10030        s = a.untyped_storage()
10031        del a
10032        self.assertFalse(m[0])
10033        self.assertEqual(MyStorage.finalized_count, 0)
10034        self.assertTrue(isinstance(s, MyStorage))
10035        del s
10036        self.assertEqual(MyStorage.finalized_count, 1)
10037        self.assertTrue(m[0])
10038
10039    def test_tensor_slot_dealloc(self):
10040
10041        class SlotTensor1(torch.Tensor):
10042            __slots__ = ['slot1']
10043
10044        class SlotTensor2(SlotTensor1):
10045            __slots__ = ['slot2']
10046
10047        m1, t1 = Tracker.make()
10048        m2, t2 = Tracker.make()
10049        slot_tensor = SlotTensor2(torch.empty(2))
10050        slot_tensor.slot1 = t1
10051        slot_tensor.slot2 = t2
10052        del t1
10053        del t2
10054        self.assertFalse(m1[0])
10055        self.assertFalse(m2[0])
10056        del slot_tensor
10057        self.assertTrue(m1[0])
10058        self.assertTrue(m2[0])
10059
10060    def test_storage_slot_dealloc(self):
10061
10062        class SlotStorage1(torch._C.StorageBase):
10063            __slots__ = ['slot1']
10064
10065        class SlotStorage2(SlotStorage1):
10066            __slots__ = ['slot2']
10067
10068        m1, t1 = Tracker.make()
10069        m2, t2 = Tracker.make()
10070        slot_storage = SlotStorage2(torch.UntypedStorage(2))
10071        slot_storage.slot1 = t1
10072        slot_storage.slot2 = t2
10073        del t1
10074        del t2
10075        self.assertFalse(m1[0])
10076        self.assertFalse(m2[0])
10077        del slot_storage
10078        self.assertTrue(m1[0])
10079        self.assertTrue(m2[0])
10080
10081    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
10082    def test_tensor_dict_dealloc(self):
10083        m, t = Tracker.make()
10084        x = torch.empty(2)
10085        x.arf = t
10086        del t
10087        self.assertFalse(m[0])
10088        del x
10089        self.assertTrue(m[0])
10090
10091    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
10092    def test_storage_dict_dealloc(self):
10093        m, t = Tracker.make()
10094        x = torch.UntypedStorage(2)
10095        x.arf = t
10096        del t
10097        self.assertFalse(m[0])
10098        del x
10099        self.assertTrue(m[0])
10100
10101    def test_tensor_finalizer_dealloc(self):
10102        m = [False]
10103
10104        class FinalizerTensor(torch.Tensor):
10105            def __del__(self):
10106                m[0] = True
10107
10108        fin_tensor = FinalizerTensor(torch.empty(2))
10109        self.assertFalse(m[0])
10110        del fin_tensor
10111        self.assertTrue(m[0])
10112
10113    def test_storage_finalizer_dealloc(self):
10114        m = [False]
10115
10116        class FinalizerStorage(torch._C.StorageBase):
10117            def __del__(self):
10118                m[0] = True
10119
10120        fin_storage = FinalizerStorage(torch.UntypedStorage(2))
10121        self.assertFalse(m[0])
10122        del fin_storage
10123        self.assertTrue(m[0])
10124
10125    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
10126    def test_tensor_weakref_dealloc(self):
10127        x = torch.empty(2)
10128        m = [False]
10129
10130        def cb(r):
10131            m[0] = True
10132
10133        wref = weakref.ref(x, cb)
10134        del x
10135        self.assertTrue(m[0])
10136        self.assertEqual(wref(), None)
10137
10138    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
10139    def test_storage_weakref_dealloc(self):
10140
10141        x = torch.UntypedStorage(2)
10142        m = [False]
10143
10144        def cb(r):
10145            m[0] = True
10146
10147        wref = weakref.ref(x, cb)
10148        del x
10149        self.assertTrue(m[0])
10150        self.assertEqual(wref(), None)
10151
10152    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
10153    def test_tensor_cycle_via_dict(self):
10154        m1, t1 = Tracker.make()
10155        x = torch.empty(2)
10156        x._tracker = t1
10157        del t1
10158
10159        m2, t2 = Tracker.make()
10160        y = torch.empty(2)
10161        y._tracker = t2
10162        del t2
10163
10164        x._loop = y
10165        y._loop = x
10166
10167        # C++ reference should keep the cycle live!
10168        # This exercises THPVariable_subtype_traverse
10169        # NB: Because z.grad is a reference done entirely in C++, cycles
10170        # involving it directly are NOT broken by Python GC; you've
10171        # set up a good old C++ reference cycle which we cannot safely
10172        # break (because C++ references are allowed to be accessed
10173        # multithreaded-ly) (TODO: except maybe if you can prove that
10174        # only Python has access to the C++ object, in which case you can
10175        # also prove that no multithreaded access occurs)
10176        z = torch.empty(2)
10177        z.grad = x
10178
10179        del x
10180        del y
10181
10182        gc.collect()
10183        self.assertFalse(m1[0])
10184        self.assertFalse(m2[0])
10185
10186        with disable_gc():
10187            del z
10188            self.assertFalse(m1[0])
10189            self.assertFalse(m2[0])
10190
10191        gc.collect()
10192        self.assertTrue(m1[0])
10193        self.assertTrue(m2[0])
10194
10195    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
10196    def test_storage_cycle_via_dict(self):
10197        m1, t1 = Tracker.make()
10198        x = torch.UntypedStorage(2)
10199        x._tracker = t1
10200        del t1
10201
10202        m2, t2 = Tracker.make()
10203        y = torch.UntypedStorage(2)
10204        y._tracker = t2
10205        del t2
10206
10207        x._loop = y
10208        y._loop = x
10209
10210        # C++ reference should keep the cycle live!
10211        # This exercise THPVariable_subtype_traverse
10212        # This exercises THPVariable_subtype_traverse
10213        # involving it directly are NOT broken by Python GC; you've
10214        # set up a good old C++ reference cycle which we cannot safely
10215        # break (because C++ references are allowed to be accessed
10216        # multithreaded-ly) (TODO: except maybe if you can prove that
10217        # only Python has access to the C++ object, in which case you can
10218        # also prove that no multithreaded access occurs)
10219        z = torch.UntypedStorage(2)
10220        z.grad = x
10221
10222        del x
10223        del y
10224
10225        gc.collect()
10226        self.assertFalse(m1[0])
10227        self.assertFalse(m2[0])
10228
10229        with disable_gc():
10230            del z
10231            self.assertFalse(m1[0])
10232            self.assertFalse(m2[0])
10233
10234        gc.collect()
10235        self.assertTrue(m1[0])
10236        self.assertTrue(m2[0])
10237
10238    def test_tensor_cycle_via_slots(self):
10239        m1 = [False]
10240        m2 = [False]
10241
10242        class SlotTensor1(torch.Tensor):
10243            __slots__ = ['slot1']
10244
10245            def __del__(self):
10246                m1[0] = True
10247
10248        class SlotTensor2(SlotTensor1):
10249            __slots__ = ['slot2']
10250
10251            def __del__(self):
10252                m2[0] = True
10253
10254        x = SlotTensor1(torch.empty(2))
10255        y = SlotTensor2(torch.empty(2))
10256
10257        x.slot1 = y
10258        y.slot2 = x
10259
10260        del x
10261        with disable_gc():
10262            del y
10263            self.assertFalse(m1[0])
10264            self.assertFalse(m2[0])
10265
10266        gc.collect()
10267        self.assertTrue(m1[0])
10268        self.assertTrue(m2[0])
10269
10270    def test_storage_cycle_via_slots(self):
10271        m1 = [False]
10272        m2 = [False]
10273
10274        class SlotStorage1(torch._C.StorageBase):
10275            __slots__ = ['slot1']
10276
10277            def __del__(self):
10278                m1[0] = True
10279
10280        class SlotStorage2(SlotStorage1):
10281            __slots__ = ['slot2']
10282
10283            def __del__(self):
10284                m2[0] = True
10285
10286        x = SlotStorage1(torch.UntypedStorage(2))
10287        y = SlotStorage2(torch.UntypedStorage(2))
10288
10289        x.slot1 = y
10290        y.slot2 = x
10291
10292        del x
10293        with disable_gc():
10294            del y
10295            self.assertFalse(m1[0])
10296            self.assertFalse(m2[0])
10297
10298        gc.collect()
10299        self.assertTrue(m1[0])
10300        self.assertTrue(m2[0])
10301
10302    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
10303    def test_storage_preserve_nonhermetic_in_hermetic_context(self):
10304        from torch.library import Library, impl
10305        global _my_storage
10306
10307        my_lib = Library("my_lib", "DEF")  # noqa: TOR901
10308        my_lib.define('my_func() -> None')
10309
10310        a = torch.tensor([1.])
10311        _my_storage = a.untyped_storage()
10312
10313        m, t = Tracker.make()
10314        _my_storage._tracker = t
10315        del t
10316
10317        @impl(my_lib, 'my_func', '')
10318        def my_func():
10319            global _my_storage
10320            del _my_storage
10321
10322        self.assertFalse(m[0])
10323        torch.ops.my_lib.my_func()
10324        self.assertFalse(m[0])
10325
10326        s = a.untyped_storage()
10327        del a
10328        del s
10329        self.assertTrue(m[0])
10330
10331    # FIXME: move to test_autograd?
10332    @skipIfTorchDynamo("TorchDynamo does not work well with hooks")
10333    def test_backward_hooks_traverse(self):
10334        m1, t1 = Tracker.make()
10335        m2, t2 = Tracker.make()
10336        x = torch.empty(2, requires_grad=True)
10337        x._tracker = t1
10338        y = torch.empty(2, requires_grad=True)
10339        y._tracker = t2
10340        del t1
10341        del t2
10342
10343        # this hits a special setter, it's not just a __dict__ entry
10344        x._backward_hooks = y
10345        y._backward_hooks = x
10346
10347        del x
10348        with disable_gc():
10349            del y
10350            self.assertFalse(m1[0])
10351            self.assertFalse(m2[0])
10352
10353        gc.collect()
10354
10355        self.assertTrue(m1[0])
10356        self.assertTrue(m2[0])
10357
10358    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
10359    def test_tensor_dead_weak_ref(self):
10360        x = torch.empty(2)
10361        w_x = weakref.ref(x)
10362        y = torch.empty(2)
10363        y.grad = x
10364        del x
10365
10366        x = w_x()
10367        # Ideally, x would keep the tensor live.  But CPython doesn't
10368        # provide enough hooks to do this.  So it will go dead and x
10369        # will transmute into an undefined tensor.  Not great, but the
10370        # best we can do.
10371        del y
10372
10373        self.assertRaises(RuntimeError, lambda: x.sigmoid())
10374
10375    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
10376    def test_storage_dead_weak_ref(self):
10377        x = torch.UntypedStorage(2)
10378        w_x = weakref.ref(x)
10379        y = torch.tensor(x)
10380        del x
10381
10382        x = w_x()
10383        # Ideally, x would keep the storage live.  But CPython doesn't
10384        # provide enough hooks to do this.  So it will go dead and x
10385        # will transmute into storage with null StorageImpl. Not great, but the
10386        # best we can do.
10387        del y
10388
10389        self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x[0])
10390        self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x.float())
10391
10392    def test_tensor_resurrected_weak_ref(self):
10393        x = torch.empty(2)
10394        w_x = weakref.ref(x)
10395        y = torch.empty(2)
10396        y.grad = x
10397        del x
10398
10399        x = w_x()
10400        # Use this to manually fix weak references after dereferencing them
10401        x._fix_weakref()
10402        del y
10403        x.sigmoid()
10404
10405    def test_storage_resurrected_weak_ref(self):
10406        x = torch.UntypedStorage(2)
10407        w_x = weakref.ref(x)
10408        y = torch.tensor(x)
10409        del x
10410
10411        x = w_x()
10412        # Use this to manually fix weak references after dereferencing them
10413        x._fix_weakref()
10414        del y
10415        x.float()
10416
10417    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
10418    def test_tensor_fix_weakref_no_leak(self):
10419        import weakref
10420
10421        called = False
10422
10423        a = torch.randn(1)
10424
10425        def callback(w):
10426            nonlocal called
10427            called = True
10428        wa = weakref.ref(a, callback)
10429        a._fix_weakref()
10430        del a
10431
10432        self.assertTrue(called)
10433
10434    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
10435    def test_storage_fix_weakref_no_leak(self):
10436        import weakref
10437
10438        called = False
10439
10440        a = torch.UntypedStorage(1)
10441
10442        def callback(w):
10443            nonlocal called
10444            called = True
10445        wa = weakref.ref(a, callback)
10446        a._fix_weakref()
10447        del a
10448
10449        self.assertTrue(called)
10450
10451    # FIXME: move to test_linalg
10452    @torch.inference_mode()
10453    def test_bmm_multithreaded(self):
10454        device = 'cpu'
10455        num_threads = torch.get_num_threads()
10456
10457        torch.set_num_threads(4)
10458        batch_sizes = [1, 10]
10459        M, N, O = 23, 8, 12
10460        dtype = torch.float32
10461        numpy_dtype = dtype
10462
10463        def invert_perm(p):
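        # invert_perm(p) returns the permutation q such that
        # x.permute(p).permute(q) restores the original dimension order.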
10464            d = {x: i for i, x in enumerate(p)}
10465            return (d[0], d[1], d[2])
10466
10467        def generate_inputs(num_batches):
10468            # transposed tensors
10469            for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2):
10470                b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1)
10471                b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1)
10472                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
10473                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
10474                yield b1, b2
10475            # broadcasting tensors
10476            for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6):
10477                shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1)
10478                shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1)
10479                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, M, N)
10480                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, N, O)
10481                yield b1, b2
10482            # zero-sized tensors
10483            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
10484                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
10485                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
10486                b1 = torch.randn(shape1, dtype=dtype, device=device)
10487                b2 = torch.randn(shape2, dtype=dtype, device=device)
10488                yield b1, b2
10489
10490        try:
10491            for num_batches in batch_sizes:
10492                for (b1, b2), perm3 in itertools.product(generate_inputs(num_batches), itertools.permutations((0, 1, 2))):
10493                    res1 = torch.bmm(b1, b2)
10494                    res2 = torch.full((num_batches, M, O), math.nan, dtype=dtype, device=device) \
10495                        .permute(perm3).contiguous().permute(invert_perm(perm3))
10496                    torch.bmm(b1, b2, out=res2)
10497                    expect = torch.from_numpy(
10498                        b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
10499                    self.assertEqual(expect, res1)
10500                    self.assertEqual(expect, res2)
10501        finally:
10502            torch.set_num_threads(num_threads)
10503
10504    def test_conj_neg_tolist(self):
10505        x = torch.randn(2, dtype=torch.cfloat)
10506        y1 = x.conj()
10507        y1_expect = x.conj_physical()
10508        y2 = y1.imag
10509        self.assertEqual(y1, y1_expect.tolist())
10510        self.assertEqual(y2, y1_expect.imag.tolist())
10511
10512    @unittest.skipIf(torch.backends.cuda.is_built(), "Skipped for cuda-enabled build")
10513    def test_no_cuda_monkeypatch(self):
10514        # Note that this is not in test_cuda.py as this whole file is skipped when cuda
10515        # is not available.
10516        with self.assertRaisesRegex(RuntimeError, "Tried to instantiate dummy base class Stream"):
10517            torch.cuda.Stream()
10518
10519        with self.assertRaisesRegex(RuntimeError, "Tried to instantiate dummy base class Event"):
10520            torch.cuda.Event()
10521
10522        with self.assertRaisesRegex(RuntimeError, "Tried to instantiate dummy base class CUDAGraph"):
10523            torch.cuda.graphs.CUDAGraph()
10524
10525    def test_tensor_where_scalar(self):
10526
10527        a = torch.arange(4.0)
10528        not_zero = 0.001
10529
10530        # b is generated through torch.where function with not_zero being a scalar parameter
10531        b = torch.where(a != 0, a, not_zero)
10532        # c is generated through Tensor.where method with not_zero being a scalar parameter
10533        c = a.where(a != 0, not_zero)
10534
10535        self.assertEqual(b, c)
10536
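    # A tensor (or view) that still owns a storage but has zero elements
    # reports data_ptr() == 0; the next two tests check this after resizing
    # and after taking an empty view.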
10537    def test_data_ptr_of_empty_tensor_with_storage(self):
10538        t = torch.empty((2, 2))
10539        self.assertNotEqual(t.data_ptr(), 0)
10540        t.resize_((0, 2))
10541        self.assertEqual(t.data_ptr(), 0)
10542
10543    def test_data_ptr_of_empty_view_with_storage(self):
10544        t = torch.empty((2, 2))
10545        self.assertNotEqual(t.data_ptr(), 0)
10546        t2 = t[0:0].view(0, 1)
10547        self.assertEqual(t2.data_ptr(), 0)
10548
10549    def test_size_stride(self) -> None:
10550        t = torch.rand(2, 3, dtype=torch.float32)
10551        self.assertEqual(t.size(0), 2)
10552        self.assertEqual(t.size(dim=None), torch.Size([2, 3]))
10553        self.assertEqual(t.stride(dim=None), torch.Size([3, 1]))
10554        self.assertEqual(t.t().stride(), torch.Size([1, 3]))
10555
10556    def test_invalid_arg_error_handling(self) -> None:
10557        """ Tests that errors from old TH functions are propagated back """
10558        for invalid_val in [-1, 2**65]:
10559            self.assertRaises(RuntimeError, lambda: torch.set_num_threads(invalid_val))
10560            self.assertRaises(RuntimeError, lambda: torch.set_num_interop_threads(invalid_val))
10561
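    # Helper for the swap_tensors tests below: snapshots the properties that
    # torch.utils.swap_tensors should preserve (object identity, refcount)
    # separately from those it should move between the two tensors (slot names,
    # __dict__ identity and contents, slot values).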
10562    def _get_tensor_prop(self, t):
10563        preserved = (
10564            id(t),
10565            # Refcount values get modified by Dynamo resume frames
10566            0 if TEST_WITH_TORCHDYNAMO else sys.getrefcount(t),
10567        )
10568        slotnames = copyreg._slotnames(t.__class__)
10569        moved = (
10570            slotnames,
10571            id(t.__dict__),
10572            tuple(t.__dict__.keys()),
10573            [getattr(t, name, None) for name in slotnames]
10574        )
10575        return preserved, moved
10576
10577    def _checked_swap(self, t1, t2):
10578        t1_pres, t1_moved = self._get_tensor_prop(t1)
10579        t2_pres, t2_moved = self._get_tensor_prop(t2)
10580
10581        torch.utils.swap_tensors(t1, t2)
10582
10583        new_t1_pres, new_t1_moved = self._get_tensor_prop(t1)
10584        new_t2_pres, new_t2_moved = self._get_tensor_prop(t2)
10585        self.assertEqual(t1_pres, new_t1_pres)
10586        self.assertEqual(t2_pres, new_t2_pres)
10587        self.assertEqual(t1_moved, new_t2_moved)
10588        self.assertEqual(t2_moved, new_t1_moved)
10589
10590        # Tests that the PyObject slots on TensorImpl are correctly swapped by
10591        # checking that, when a function applied to a swapped tensor returns the
10592        # same TensorImpl (fill_ returns self), the Python object handed back
10593        # (i.e. the reference to the PyObject stored in the TensorImpl's
10594        # PyObjectSlot) is still the correct, swapped wrapper.
10595        self.assertEqual(id(t1.fill_(0.5)), id(t1))
10596        self.assertEqual(id(t2.fill_(0.5)), id(t2))
10597
10598    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo adds weakrefs")
10599    def test_swap_basic(self):
10600        ts = [
10601            torch.rand(2),
10602            torch.rand(3, 3),
10603            torch.empty(3, dtype=torch.int),
10604            TwoTensor(torch.rand(4), torch.rand(4))
10605        ]
10606
10607        for t1, t2 in itertools.combinations(ts, 2):
10608            t1 = t1.clone()
10609            t2 = t2.clone()
10610            t2.foo = "bar"
10611            holder = []
10612            holder.append(t1)
10613
10614            self._checked_swap(t1, t2)
10615
10616            self.assertIs(holder[0], t1)
10617            self.assertEqual(t1.foo, "bar")
10618
10619            if t1.is_floating_point():
10620                t3 = t1.clone().detach().requires_grad_(True)
10621                out = t3 * 2
10622                torch.utils.swap_tensors(t3, t2)
10623                with self.assertRaisesRegex(RuntimeError, "AccumulateGrad node that was poisoned by swap_tensors"):
10624                    out.sum().backward()
10625
10626            wr = weakref.ref(t1)
10627            with self.assertRaisesRegex(RuntimeError, "has weakref"):
10628                torch.utils.swap_tensors(t1, t2)
10629
10630
10631    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo adds weakrefs")
10632    def test_swap_fail_slots(self):
10633        class MyTwoTensor(TwoTensor):
10634            __slots__ = ("a", "b")
10635
10636        class MyTwoTensor2(TwoTensor):
10637            __slots__ = ("b", "a")
10638
10639        class MyTwoTensor3(TwoTensor):
10640            __slots__ = ("a", "b", "c", "d")
10641
10642        class MyTwoTensor4(TwoTensor):
10643            __slots__ = ("a", "c")
10644
10645
10646        t1 = torch.rand(4)
10647        t2 = TwoTensor(torch.rand(4), torch.rand(4))
10648        t3 = MyTwoTensor(torch.rand(4), torch.rand(4))
10649        t4 = MyTwoTensor(torch.rand(4), torch.rand(4))
10650        t5 = MyTwoTensor2(torch.rand(4), torch.rand(4))
10651        t6 = MyTwoTensor3(torch.rand(4), torch.rand(4))
10652        t7 = MyTwoTensor3(torch.rand(4), torch.rand(4))
10653        t8 = MyTwoTensor4(torch.rand(4), torch.rand(4))
10654
10655        self._checked_swap(t1, t2)
10656        with self.assertRaisesRegex(RuntimeError, "Cannot swap t1 and t2 if they have different slots"):
10657            torch.utils.swap_tensors(t1, t3)
10658        with self.assertRaisesRegex(RuntimeError, "Cannot swap t1 and t2 if they have different slots"):
10659            torch.utils.swap_tensors(t2, t3)
10660        with self.assertRaisesRegex(RuntimeError, "Cannot swap t1 and t2 if they have different slots"):
10661            torch.utils.swap_tensors(t2, t8)
10662        self._checked_swap(t3, t4)
10663        self._checked_swap(t3, t5)
10664        with self.assertRaisesRegex(RuntimeError, "Cannot swap t1 and t2 if they have different slots"):
10665            torch.utils.swap_tensors(t3, t6)
10666        t3.c = "foo"
10667        t4.d = "bar"
10668        self._checked_swap(t3, t4)
10669        self.assertEqual(t4.c, "foo")
10670        self.assertEqual(t3.d, "bar")
10671        t6.c = "cat"
10672        t7.d = "dog"
10673        self._checked_swap(t6, t7)
10674
10675    @unittest.skipIf(torch.cuda.is_available(), "Test specific for CPU")
10676    def test_bf16_supported_on_cpu(self):
10677        self.assertFalse(torch.cuda.is_bf16_supported())
10678
10679
10680# The following block extends TestTorch with negative dim wrapping tests
10681# FIXME: replace these with OpInfo sample inputs or systematic OpInfo tests
10682# Functions to test negative dimension wrapping
10683METHOD = 1
10684INPLACE_METHOD = 2
10685FUNCTIONAL = 4
10686DIM_ARG: None = None
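# These flags describe how a generated test invokes the op under test: METHOD
# calls x.op(*args), INPLACE_METHOD calls x.op_(*args), and FUNCTIONAL calls
# torch.op(x, *args). DIM_ARG marks the positions in the constructed argument
# list that the test fills with a dimension index, once as a non-negative value
# and once as the equivalent negative index, checking that the results agree.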
10687
10688def make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim=0):
10689    def neg_dim_test(self):
10690        if isinstance(tensor_arg, list):
10691            assert METHOD not in types and INPLACE_METHOD not in types
10692            x = [torch.randn(arg) for arg in tensor_arg]
10693            ndim = len(tensor_arg[-1])
10694        else:
10695            x = torch.randn(*tensor_arg)
10696            ndim = len(tensor_arg)
10697        ndim += extra_dim
10698
10699        n_dim_to_test = sum(e is DIM_ARG for e in arg_constr())
10700
10701        for dims_val in combinations(range(ndim), n_dim_to_test):
10702            arg = arg_constr()
10703            arg_neg = copy.deepcopy(arg)
10704            idx = 0
10705            for i, v in enumerate(arg):
10706                if v is DIM_ARG:
10707                    arg[i] = dims_val[idx]
10708                    arg_neg[i] = dims_val[idx] - ndim
10709                    idx += 1
10710
10711            if METHOD in types:
10712                a = getattr(x, name)(*arg)
10713                b = getattr(x, name)(*arg_neg)
10714                self.assertEqual(a, b)
10715
10716            if INPLACE_METHOD in types:
10717                a = x.clone()
10718                getattr(a, name + '_')(*arg)
10719                b = x.clone()
10720                getattr(b, name + '_')(*arg_neg)
10721                self.assertEqual(a, b)
10722
10723            if FUNCTIONAL in types:
10724                a = getattr(torch, name)(x, *arg)
10725                b = getattr(torch, name)(x, *arg_neg)
10726                self.assertEqual(a, b)
10727
10728    return neg_dim_test
10729
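# For example (a sketch of the generated behavior, not literal generated code),
# the entry
#     ('transpose', (10, 20, 30), lambda: [DIM_ARG, DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL])
# yields a test that, for x = torch.randn(10, 20, 30) and every pair of dims
# drawn from range(3), checks e.g.
#     x.transpose(0, 2) == x.transpose(-3, -1)
#     x.clone().transpose_(0, 2) == x.clone().transpose_(-3, -1)
#     torch.transpose(x, 0, 2) == torch.transpose(x, -3, -1)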
10730def idx_tensor(size, max_val):
10731    return torch.LongTensor(*size).random_(0, max_val - 1)
10732
10733def add_neg_dim_tests():
10734    neg_dim_tests = [
10735        ('narrow', (10, 20, 30), lambda: [DIM_ARG, 0, 5], [METHOD]),
10736        ('transpose', (10, 20, 30), lambda: [DIM_ARG, DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL]),
10737        ('size', (10, 20, 30), lambda: [DIM_ARG], [METHOD]),
10738        ('cat', [(2, 3, 4), (2, 3, 4)], lambda: [DIM_ARG], [FUNCTIONAL]),
10739        ('chunk', (10, 20, 30), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]),
10740        ('gather', (10, 20), lambda: [DIM_ARG, idx_tensor((10, 20), 10)], [METHOD, FUNCTIONAL]),
10741        ('index_select', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10)], [METHOD, FUNCTIONAL]),
10742        ('split', (10, 20), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]),
10743        ('squeeze', (10, 1, 20, 1), lambda: [DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL]),
10744        ('unbind', (2, 3, 4), lambda: [DIM_ARG], [FUNCTIONAL]),
10745        ('unsqueeze', (10, 20), lambda: [DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL], 1),
10746        ('logcumsumexp', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10747        ('cumprod', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10748        ('cumsum', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10749        ('cummax', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10750        ('cummin', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10751        ('mean', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10752        ('median', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10753        ('nanmedian', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10754        ('mode', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10755        ('norm', (10, 20), lambda: [2, DIM_ARG], [METHOD, FUNCTIONAL]),
10756        ('prod', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10757        ('std', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10758        ('sum', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10759        ('var', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10760        ('kthvalue', (10, 20), lambda: [3, DIM_ARG], [METHOD, FUNCTIONAL]),
10761        ('max', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10762        ('min', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10763        ('sort', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10764        ('topk', (10, 20), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]),
10765        ('renorm', (10, 20), lambda: [2, DIM_ARG, 1], [METHOD, INPLACE_METHOD, FUNCTIONAL]),
10766        ('index_add', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), torch.randn(10, 10)], [INPLACE_METHOD]),
10767        ('index_copy', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), torch.randn(10, 10)], [INPLACE_METHOD]),
10768        ('index_fill', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), 12], [INPLACE_METHOD]),
10769        ('scatter', (10, 10), lambda: [DIM_ARG, idx_tensor((10, 10), 10), torch.randn(10, 10)], [INPLACE_METHOD]),
10770        ('select', (10, 20), lambda: [DIM_ARG, 3], [METHOD]),
10771        ('unfold', (10, 20), lambda: [DIM_ARG, 5, 2], [METHOD]),
10772    ]
10773
10774    for decl in neg_dim_tests:
10775        if len(decl) == 4:
10776            name, tensor_arg, arg_constr, types = decl
10777            extra_dim = 0
10778        elif len(decl) == 5:
10779            name, tensor_arg, arg_constr, types, extra_dim = decl
10780
10781        test_name = 'test_' + name + '_neg_dim'
10782
10783        assert not hasattr(TestTorch, test_name), "Duplicated test name: " + test_name
10784        setattr(TestTorch, test_name, make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim))
10785
10786# TODO: these empty classes are temporarily instantiated for XLA compatibility;
10787#   once XLA updates its test suite they should be removed
10788class TestViewOps(TestCase):
10789    pass
10790
10791class TestTensorDeviceOps(TestCase):
10792    pass
10793
10794# Generates tests
10795# Note: test generation must be done at file scope, not within main, or
10796# pytest will fail.
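# instantiate_device_type_tests creates device-specific variants of the generic
# test classes above (e.g. a CPU variant and, where available, a CUDA variant)
# and registers them in this module's globals() so the test runner discovers them.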
10797add_neg_dim_tests()
10798instantiate_device_type_tests(TestViewOps, globals())
10799instantiate_device_type_tests(TestVitalSignsCuda, globals())
10800instantiate_device_type_tests(TestTensorDeviceOps, globals())
10801instantiate_device_type_tests(TestTorchDeviceType, globals())
10802instantiate_device_type_tests(TestDevicePrecision, globals(), except_for='cpu')
10803
10804if __name__ == '__main__':
10805    TestCase._default_dtype_check_enabled = True
10806    run_tests()
10807